// ntex_rt/pool.rs

1//! A thread pool to perform blocking operations in other threads.
2use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
7
/// Error reported when the blocking thread pool has no worker available to
/// take a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockingError;

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `str`'s `Display` so width/padding flags are honoured.
        fmt::Display::fmt("All threads are busy", f)
    }
}

impl std::error::Error for BlockingError {}
19
/// Future returned for a task handed to the blocking thread pool.
///
/// Awaiting it yields the task's return value, or [`BlockingError`] when no
/// result can be delivered.
pub struct BlockingResult<T> {
    // Receiving half of the one-shot channel that carries the task outcome
    // (`Err` holds the payload captured by `catch_unwind`). `None` means there
    // is nothing to wait for: either the pool refused the task or the result
    // has already been handed out by a previous poll.
    rx: Option<oneshot::Receiver<Result<T, Box<dyn Any + Send>>>>,
}
23
/// Heap-allocated, type-erased job; this is the item type of the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;
25
/// A unit of work that a pool worker can execute.
///
/// Every closure that is `FnOnce() + Send + 'static` is a `Dispatchable`
/// through the blanket implementation below; other `Send + 'static` types may
/// implement the trait themselves.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed job and executes it.
    fn run(self: Box<Self>);
}

impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        // Move the closure out of its box, then call it exactly once.
        let job = *self;
        job()
    }
}
42
/// Drop guard that decrements the wrapped counter by one when it goes out of
/// scope. Creating the guard does not touch the counter.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let Self(active) = self;
        active.fetch_sub(1, Ordering::AcqRel);
    }
}
50
51fn worker(
52    receiver: Receiver<BoxedDispatchable>,
53    counter: Arc<AtomicUsize>,
54    timeout: Duration,
55) -> impl FnOnce() {
56    move || {
57        counter.fetch_add(1, Ordering::AcqRel);
58        let _guard = CounterGuard(counter);
59        while let Ok(f) = receiver.recv_timeout(timeout) {
60            f.run();
61        }
62    }
63}
64
65/// A thread pool to perform blocking operations in other threads.
66#[derive(Debug, Clone)]
67pub(crate) struct ThreadPool {
68    name: String,
69    sender: Sender<BoxedDispatchable>,
70    receiver: Receiver<BoxedDispatchable>,
71    counter: Arc<AtomicUsize>,
72    thread_limit: usize,
73    recv_timeout: Duration,
74}
75
76impl ThreadPool {
77    /// Create [`ThreadPool`] with thread number limit and channel receive
78    /// timeout.
79    pub(crate) fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
80        let (sender, receiver) = bounded(0);
81        Self {
82            sender,
83            receiver,
84            thread_limit,
85            recv_timeout,
86            name: format!("{}:pool-wrk", name),
87            counter: Arc::new(AtomicUsize::new(0)),
88        }
89    }
90
91    /// Send a dispatchable, usually a closure, to another thread. Usually the
92    /// user should not use it. When all threads are busy and thread number
93    /// limit has been reached, it will return an error with the original
94    /// dispatchable.
95    pub(crate) fn dispatch<F, R>(&self, f: F) -> BlockingResult<R>
96    where
97        F: FnOnce() -> R + Send + 'static,
98        R: Send + 'static,
99    {
100        let (tx, rx) = oneshot::channel();
101        let f = Box::new(move || {
102            let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
103            let _ = tx.send(result);
104        });
105
106        match self.sender.try_send(f) {
107            Ok(_) => BlockingResult { rx: Some(rx) },
108            Err(e) => match e {
109                TrySendError::Full(f) => {
110                    let cnt = self.counter.load(Ordering::Acquire);
111                    if cnt >= self.thread_limit {
112                        BlockingResult { rx: None }
113                    } else {
114                        thread::Builder::new()
115                            .name(format!("{}:{}", self.name, cnt))
116                            .spawn(worker(
117                                self.receiver.clone(),
118                                self.counter.clone(),
119                                self.recv_timeout,
120                            ))
121                            .expect("Cannot construct new thread");
122                        self.sender.send(f).expect("the channel should not be full");
123                        BlockingResult { rx: Some(rx) }
124                    }
125                }
126                TrySendError::Disconnected(_) => {
127                    unreachable!("receiver should not all disconnected")
128                }
129            },
130        }
131    }
132}
133
134impl<R> Future for BlockingResult<R> {
135    type Output = Result<R, BlockingError>;
136
137    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138        let this = self.get_mut();
139
140        if this.rx.is_none() {
141            return Poll::Ready(Err(BlockingError));
142        }
143
144        if let Some(mut rx) = this.rx.take() {
145            match Pin::new(&mut rx).poll(cx) {
146                Poll::Pending => {
147                    this.rx = Some(rx);
148                    Poll::Pending
149                }
150                Poll::Ready(result) => Poll::Ready(
151                    result
152                        .map_err(|_| BlockingError)
153                        .and_then(|res| res.map_err(|_| BlockingError)),
154                ),
155            }
156        } else {
157            unreachable!()
158        }
159    }
160}