Skip to main content

ntex_rt/
pool.rs

1//! A thread pool to perform blocking operations in other threads.
2use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
7
/// Error produced by the blocking thread pool.
///
/// It is yielded when a job could not be handed to a worker because every
/// worker was busy and no more threads could be started. A `BlockingResult`
/// also resolves to this error when the dispatched closure panicked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BlockingError;

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/precision flags, exactly like `<str as Display>::fmt`.
        f.pad("All threads are busy")
    }
}

impl std::error::Error for BlockingError {}
19
20#[derive(Debug)]
21pub struct BlockingResult<T> {
22    rx: Option<oneshot::Receiver<Result<T, Box<dyn Any + Send>>>>,
23}
24
/// Type-erased, heap-allocated job as it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that can be moved to another thread and executed once.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consume the boxed job and execute it on the current thread.
    fn run(self: Box<Self>);
}

/// Any argument-less `Send + 'static` closure is a job.
impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        let job = *self;
        job();
    }
}
41
/// Decrements the shared worker counter by one when it goes out of scope.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let Self(live) = self;
        live.fetch_sub(1, Ordering::AcqRel);
    }
}
49
50fn worker(
51    receiver: Receiver<BoxedDispatchable>,
52    counter: Arc<AtomicUsize>,
53    timeout: Duration,
54) -> impl FnOnce() {
55    move || {
56        counter.fetch_add(1, Ordering::AcqRel);
57        let _guard = CounterGuard(counter);
58        while let Ok(f) = receiver.recv_timeout(timeout) {
59            f.run();
60        }
61    }
62}
63
64/// A thread pool to perform blocking operations in other threads.
65#[derive(Debug, Clone)]
66pub(crate) struct ThreadPool {
67    name: String,
68    sender: Sender<BoxedDispatchable>,
69    receiver: Receiver<BoxedDispatchable>,
70    counter: Arc<AtomicUsize>,
71    thread_limit: usize,
72    recv_timeout: Duration,
73}
74
75impl ThreadPool {
76    /// Create [`ThreadPool`] with thread number limit and channel receive
77    /// timeout.
78    pub(crate) fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
79        let (sender, receiver) = bounded(0);
80        Self {
81            sender,
82            receiver,
83            thread_limit,
84            recv_timeout,
85            name: format!("{name}:pool-wrk"),
86            counter: Arc::new(AtomicUsize::new(0)),
87        }
88    }
89
90    /// Send a dispatchable, usually a closure, to another thread. Usually the
91    /// user should not use it. When all threads are busy and thread number
92    /// limit has been reached, it will return an error with the original
93    /// dispatchable.
94    pub(crate) fn dispatch<F, R>(&self, f: F) -> BlockingResult<R>
95    where
96        F: FnOnce() -> R + Send + 'static,
97        R: Send + 'static,
98    {
99        let (tx, rx) = oneshot::channel();
100        let f = Box::new(move || {
101            let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
102            let _ = tx.send(result);
103        });
104
105        match self.sender.try_send(f) {
106            Ok(()) => BlockingResult { rx: Some(rx) },
107            Err(e) => match e {
108                TrySendError::Full(f) => {
109                    let cnt = self.counter.load(Ordering::Acquire);
110                    if cnt >= self.thread_limit {
111                        BlockingResult { rx: None }
112                    } else {
113                        thread::Builder::new()
114                            .name(format!("{}:{}", self.name, cnt))
115                            .spawn(worker(
116                                self.receiver.clone(),
117                                self.counter.clone(),
118                                self.recv_timeout,
119                            ))
120                            .expect("Cannot construct new thread");
121                        self.sender.send(f).expect("the channel should not be full");
122                        BlockingResult { rx: Some(rx) }
123                    }
124                }
125                TrySendError::Disconnected(_) => {
126                    unreachable!("receiver should not all disconnected")
127                }
128            },
129        }
130    }
131}
132
133impl<R> Future for BlockingResult<R> {
134    type Output = Result<R, BlockingError>;
135
136    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        let this = self.get_mut();
138
139        if this.rx.is_none() {
140            return Poll::Ready(Err(BlockingError));
141        }
142
143        if let Some(mut rx) = this.rx.take() {
144            match Pin::new(&mut rx).poll(cx) {
145                Poll::Pending => {
146                    this.rx = Some(rx);
147                    Poll::Pending
148                }
149                Poll::Ready(result) => Poll::Ready(
150                    result
151                        .map_err(|_| BlockingError)
152                        .and_then(|res| res.map_err(|_| BlockingError)),
153                ),
154            }
155        } else {
156            unreachable!()
157        }
158    }
159}