Skip to main content

ntex_rt/
pool.rs

1//! A thread pool for blocking operations.
2use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded, unbounded};
7
/// Error produced by a [`BlockingResult`] that could not yield a value.
///
/// A bounded pool reports it when every permitted worker thread is busy.
/// It is also reported when the submitted closure panicked or its result
/// was never delivered.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BlockingError;

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment flags exactly like formatting a `&str`.
        f.pad("All threads are busy")
    }
}

impl std::error::Error for BlockingError {}
19
20#[derive(Debug)]
21pub struct BlockingResult<T> {
22    rx: Option<oneshot::AsyncReceiver<Result<T, Box<dyn Any + Send>>>>,
23}
24
/// Type-erased task as it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that is executed exactly once by a worker thread.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed task and executes it.
    fn run(self: Box<Self>);
}

/// Any sendable `FnOnce()` closure can be dispatched as a task.
impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        let job = *self;
        job();
    }
}
39
/// Decrements the shared worker counter by one when dropped, including
/// during unwinding.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let CounterGuard(shared) = self;
        shared.fetch_sub(1, Ordering::AcqRel);
    }
}
47
48fn worker(
49    receiver: Receiver<BoxedDispatchable>,
50    counter: Arc<AtomicUsize>,
51    timeout: Duration,
52) -> impl FnOnce() {
53    move || {
54        counter.fetch_add(1, Ordering::AcqRel);
55        let _guard = CounterGuard(counter);
56        while let Ok(f) = receiver.recv_timeout(timeout) {
57            f.run();
58        }
59    }
60}
61
62/// A thread pool for executing blocking operations.
63///
64/// The pool can be configured as either bounded or unbounded, which
65/// determines how tasks are handled when all worker threads are busy.
66///
67/// - In a **bounded** pool, submitting a task will fail if the number of
68///   concurrent operations has reached the thread limit.
69/// - In an **unbounded** pool, tasks are queued and will wait until a
70///   worker thread becomes available.
71///
72/// The number of worker threads scales dynamically with load, but will
73/// never exceed the `thread_limit` parameter.
74#[derive(Debug, Clone)]
75pub struct ThreadPool {
76    name: String,
77    sender: Sender<BoxedDispatchable>,
78    receiver: Receiver<BoxedDispatchable>,
79    counter: Arc<AtomicUsize>,
80    thread_limit: usize,
81    recv_timeout: Duration,
82}
83
84impl ThreadPool {
85    /// Creates a [`ThreadPool`] with a maximum number of worker threads
86    /// and a timeout for receiving tasks from the task channel.
87    pub fn new(
88        name: &str,
89        thread_limit: usize,
90        recv_timeout: Duration,
91        bound: bool,
92    ) -> Self {
93        let (sender, receiver) = if bound { bounded(0) } else { unbounded() };
94        Self {
95            sender,
96            receiver,
97            thread_limit,
98            recv_timeout,
99            name: format!("{name}:pool-wrk"),
100            counter: Arc::new(AtomicUsize::new(0)),
101        }
102    }
103
104    #[allow(clippy::missing_panics_doc)]
105    /// Submits a task (closure) to the thread pool.
106    ///
107    /// The task will be executed by an available worker thread.
108    /// If no threads are available and the pool has reached its maximum size,
109    /// the behavior depends on the `boundedness` configuration:
110    ///
111    /// - For a bounded pool, the function returns an error.
112    /// - For an unbounded pool, the task is queued and executed when a worker
113    ///   becomes available.
114    pub fn execute<F, R>(&self, f: F) -> BlockingResult<R>
115    where
116        F: FnOnce() -> R + Send + 'static,
117        R: Send + 'static,
118    {
119        let (tx, rx) = oneshot::async_channel();
120        let f = Box::new(move || {
121            // do not execute operation if recevier is dropped
122            if !tx.is_closed() {
123                let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
124                let _ = tx.send(result);
125            }
126        });
127
128        match self.sender.try_send(f) {
129            Ok(()) => BlockingResult { rx: Some(rx) },
130            Err(e) => match e {
131                TrySendError::Full(f) => {
132                    let cnt = self.counter.load(Ordering::Acquire);
133                    if cnt >= self.thread_limit {
134                        BlockingResult { rx: None }
135                    } else {
136                        thread::Builder::new()
137                            .name(format!("{}:{}", self.name, cnt))
138                            .spawn(worker(
139                                self.receiver.clone(),
140                                self.counter.clone(),
141                                self.recv_timeout,
142                            ))
143                            .expect("Cannot construct new thread");
144                        self.sender.send(f).expect("the channel should not be full");
145                        BlockingResult { rx: Some(rx) }
146                    }
147                }
148                TrySendError::Disconnected(_) => {
149                    unreachable!("receiver should not all disconnected")
150                }
151            },
152        }
153    }
154}
155
156impl<R> Future for BlockingResult<R> {
157    type Output = Result<R, BlockingError>;
158
159    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160        let this = self.get_mut();
161
162        if this.rx.is_none() {
163            return Poll::Ready(Err(BlockingError));
164        }
165
166        if let Some(mut rx) = this.rx.take() {
167            match Pin::new(&mut rx).poll(cx) {
168                Poll::Pending => {
169                    this.rx = Some(rx);
170                    Poll::Pending
171                }
172                Poll::Ready(result) => Poll::Ready(
173                    result
174                        .map_err(|_| BlockingError)
175                        .and_then(|res| res.map_err(|_| BlockingError)),
176                ),
177            }
178        } else {
179            unreachable!()
180        }
181    }
182}