ntex_rt/
pool.rs

1//! A thread pool to perform blocking operations in other threads.
2use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
7
/// Error produced when a blocking job cannot deliver a result — most commonly
/// because every worker thread is busy and no new one may be started.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BlockingError;

impl std::error::Error for BlockingError {}

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment/fill flags, exactly as formatting a `str` does.
        f.pad("All threads are busy")
    }
}
19
20pub struct BlockingResult<T> {
21    rx: Option<oneshot::Receiver<Result<T, Box<dyn Any + Send>>>>,
22}
23
/// A type-erased, heap-allocated job, as sent through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that can be run on a pool thread.
///
/// It is implemented for every `FnOnce() + Send + 'static` closure, but may
/// also be implemented for any other type that is `Send` and `'static`.
pub(crate) trait Dispatchable: Send + 'static {
    /// Run the dispatchable, consuming it.
    fn run(self: Box<Self>);
}

impl<F> Dispatchable for F
where
    F: FnOnce() + Send + 'static,
{
    fn run(self: Box<Self>) {
        // Move the closure out of its box and invoke it.
        (*self)()
    }
}
42
/// Decrements the shared worker counter by one when dropped.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let CounterGuard(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
50
51fn worker(
52    receiver: Receiver<BoxedDispatchable>,
53    counter: Arc<AtomicUsize>,
54    timeout: Duration,
55) -> impl FnOnce() {
56    move || {
57        counter.fetch_add(1, Ordering::AcqRel);
58        let _guard = CounterGuard(counter);
59        while let Ok(f) = receiver.recv_timeout(timeout) {
60            f.run();
61        }
62    }
63}
64
65/// A thread pool to perform blocking operations in other threads.
66#[derive(Debug, Clone)]
67pub(crate) struct ThreadPool {
68    sender: Sender<BoxedDispatchable>,
69    receiver: Receiver<BoxedDispatchable>,
70    counter: Arc<AtomicUsize>,
71    thread_limit: usize,
72    recv_timeout: Duration,
73}
74
75impl ThreadPool {
76    /// Create [`ThreadPool`] with thread number limit and channel receive
77    /// timeout.
78    pub(crate) fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
79        let (sender, receiver) = bounded(0);
80        Self {
81            sender,
82            receiver,
83            thread_limit,
84            recv_timeout,
85            counter: Arc::new(AtomicUsize::new(0)),
86        }
87    }
88
89    /// Send a dispatchable, usually a closure, to another thread. Usually the
90    /// user should not use it. When all threads are busy and thread number
91    /// limit has been reached, it will return an error with the original
92    /// dispatchable.
93    pub(crate) fn dispatch<F, R>(&self, f: F) -> BlockingResult<R>
94    where
95        F: FnOnce() -> R + Send + 'static,
96        R: Send + 'static,
97    {
98        let (tx, rx) = oneshot::channel();
99        let f = Box::new(move || {
100            let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
101            let _ = tx.send(result);
102        });
103
104        match self.sender.try_send(f) {
105            Ok(_) => BlockingResult { rx: Some(rx) },
106            Err(e) => match e {
107                TrySendError::Full(f) => {
108                    let cnt = self.counter.load(Ordering::Acquire);
109                    if cnt >= self.thread_limit {
110                        BlockingResult { rx: None }
111                    } else {
112                        thread::Builder::new()
113                            .name(format!("pool-wrk:{}", cnt))
114                            .spawn(worker(
115                                self.receiver.clone(),
116                                self.counter.clone(),
117                                self.recv_timeout,
118                            ))
119                            .expect("Cannot construct new thread");
120                        self.sender.send(f).expect("the channel should not be full");
121                        BlockingResult { rx: Some(rx) }
122                    }
123                }
124                TrySendError::Disconnected(_) => {
125                    unreachable!("receiver should not all disconnected")
126                }
127            },
128        }
129    }
130}
131
132impl<R> Future for BlockingResult<R> {
133    type Output = Result<R, BlockingError>;
134
135    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136        let this = self.get_mut();
137
138        if this.rx.is_none() {
139            return Poll::Ready(Err(BlockingError));
140        }
141
142        if let Some(mut rx) = this.rx.take() {
143            match Pin::new(&mut rx).poll(cx) {
144                Poll::Pending => {
145                    this.rx = Some(rx);
146                    Poll::Pending
147                }
148                Poll::Ready(result) => Poll::Ready(
149                    result
150                        .map_err(|_| BlockingError)
151                        .and_then(|res| res.map_err(|_| BlockingError)),
152                ),
153            }
154        } else {
155            unreachable!()
156        }
157    }
158}