//! A thread pool for blocking operations.
use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
use std::task::{Context, Poll};
use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};

use crossbeam_channel::{Receiver, Select, Sender, TrySendError, bounded, unbounded};
7
/// An error that may be emitted when all worker threads are busy.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BlockingError;

impl std::error::Error for BlockingError {}

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "All threads are busy")
    }
}
19
/// Future returned by [`ThreadPool::execute`].
///
/// Resolves to `Ok(R)` with the task's result, or `Err(BlockingError)`
/// if the task panicked or the result channel was closed before a value
/// was delivered.
#[derive(Debug)]
pub struct BlockingResult<T> {
    // Receives `Ok(result)` from the worker, or `Err(payload)` when the
    // task panicked (payload produced by `panic::catch_unwind`).
    // NOTE(review): `oneshot` is not imported in this chunk — presumably a
    // project-local channel module; confirm `AsyncReceiver` semantics there.
    rx: oneshot::AsyncReceiver<Result<T, Box<dyn Any + Send>>>,
}
24
/// Type-erased unit of work handed to worker threads.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A task that can be executed exactly once by a worker thread.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed task and runs it.
    fn run(self: Box<Self>);
}

impl<F> Dispatchable for F
where
    F: FnOnce() + Send + 'static,
{
    fn run(self: Box<Self>) {
        // Move the closure out of the box, then invoke it once.
        let task = *self;
        task();
    }
}
39
/// RAII guard that decrements the pool's live-worker counter when a
/// worker thread exits (for any reason, including a panic unwind).
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let CounterGuard(active_workers) = self;
        active_workers.fetch_sub(1, Ordering::AcqRel);
    }
}
47
48fn worker(
49    receiver_high_prio: Receiver<BoxedDispatchable>,
50    receiver_low_prio: Receiver<BoxedDispatchable>,
51    counter: Arc<AtomicUsize>,
52    timeout: Duration,
53) -> impl FnOnce() {
54    move || {
55        counter.fetch_add(1, Ordering::AcqRel);
56        let _guard = CounterGuard(counter);
57        let mut sel = Select::new_biased();
58        sel.recv(&receiver_high_prio);
59        sel.recv(&receiver_low_prio);
60        while let Ok(op) = sel.select_timeout(timeout) {
61            match op {
62                op if op.index() == 0 => {
63                    if let Ok(f) = op.recv(&receiver_high_prio) {
64                        f.run();
65                    }
66                }
67                op if op.index() == 1 => {
68                    if let Ok(f) = op.recv(&receiver_low_prio) {
69                        f.run();
70                    }
71                }
72                _ => unreachable!(),
73            }
74        }
75    }
76}
77
/// A thread pool for executing blocking operations.
///
/// The pool can be configured as either bounded or unbounded, which
/// determines how tasks are handled when all worker threads are busy.
///
/// - In a **bounded** pool, submitting a task will fail if the number of
///   concurrent operations has reached the thread limit.
/// - In an **unbounded** pool, tasks are queued and will wait until a
///   worker thread becomes available.
///
/// The number of worker threads scales dynamically with load, but will
/// never exceed the `thread_limit` parameter.
#[derive(Debug, Clone)]
pub struct ThreadPool {
    // Base name for spawned worker threads ("<name>:pool-wrk", see `new`).
    name: String,
    // Zero-capacity rendezvous channel (`bounded(0)` in `new`): `try_send`
    // succeeds only when a worker is already parked on the receiver.
    sender_low_prio: Sender<BoxedDispatchable>,
    receiver_low_prio: Receiver<BoxedDispatchable>,
    // Unbounded queue used once the pool has reached `thread_limit`.
    sender_high_prio: Sender<BoxedDispatchable>,
    receiver_high_prio: Receiver<BoxedDispatchable>,
    // Count of live worker threads (incremented in `worker`, decremented
    // by `CounterGuard` on worker exit).
    counter: Arc<AtomicUsize>,
    // Maximum number of worker threads `execute` may spawn.
    thread_limit: usize,
    // Idle timeout after which a worker thread terminates itself.
    recv_timeout: Duration,
}
101
102impl ThreadPool {
103    /// Creates a [`ThreadPool`] with a maximum number of worker threads
104    /// and a timeout for receiving tasks from the task channel.
105    pub fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
106        let (sender_low_prio, receiver_low_prio) = bounded(0);
107        let (sender_high_prio, receiver_high_prio) = unbounded();
108        Self {
109            sender_low_prio,
110            receiver_low_prio,
111            sender_high_prio,
112            receiver_high_prio,
113            thread_limit,
114            recv_timeout,
115            name: format!("{name}:pool-wrk"),
116            counter: Arc::new(AtomicUsize::new(0)),
117        }
118    }
119
120    #[allow(clippy::missing_panics_doc)]
121    /// Submits a task (closure) to the thread pool.
122    ///
123    /// The task will be executed by an available worker thread.
124    /// If no threads are available and the pool has reached its maximum size,
125    /// the work will be queued until a worker thread becomes available.
126    pub fn execute<F, R>(&self, f: F) -> BlockingResult<R>
127    where
128        F: FnOnce() -> R + Send + 'static,
129        R: Send + 'static,
130    {
131        let (tx, rx) = oneshot::async_channel();
132        let f = Box::new(move || {
133            // do not execute operation if receiver is dropped
134            if !tx.is_closed() {
135                let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
136                let _ = tx.send(result);
137            }
138        });
139
140        match self.sender_low_prio.try_send(f) {
141            Ok(()) => BlockingResult { rx },
142            Err(e) => match e {
143                TrySendError::Full(f) => {
144                    let cnt = self.counter.load(Ordering::Acquire);
145                    if cnt >= self.thread_limit {
146                        self.sender_high_prio
147                            .send(f)
148                            .expect("the channel should not be full");
149                        BlockingResult { rx }
150                    } else {
151                        thread::Builder::new()
152                            .name(format!("{}:{}", self.name, cnt))
153                            .spawn(worker(
154                                self.receiver_high_prio.clone(),
155                                self.receiver_low_prio.clone(),
156                                self.counter.clone(),
157                                self.recv_timeout,
158                            ))
159                            .expect("Cannot construct new thread");
160                        self.sender_low_prio
161                            .send(f)
162                            .expect("the channel should not be full");
163                        BlockingResult { rx }
164                    }
165                }
166                TrySendError::Disconnected(_) => {
167                    unreachable!("receiver should not all disconnected")
168                }
169            },
170        }
171    }
172}
173
174impl<R> Future for BlockingResult<R> {
175    type Output = Result<R, BlockingError>;
176
177    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178        let this = self.get_mut();
179
180        match Pin::new(&mut this.rx).poll(cx) {
181            Poll::Pending => Poll::Pending,
182            Poll::Ready(result) => Poll::Ready(
183                result
184                    .map_err(|_| BlockingError)
185                    .and_then(|res| res.map_err(|_| BlockingError)),
186            ),
187        }
188    }
189}