1use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Select, Sender, TrySendError, bounded, unbounded};
7
/// Error yielded by [`BlockingResult`] when a submitted job does not produce a value.
///
/// NOTE(review): the `Display` text says "All threads are busy", but within this
/// module the error is produced when the job panicked or its result channel
/// reported an error — confirm the message is still accurate for all producers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BlockingError;

impl std::error::Error for BlockingError {}

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/fill/alignment flags exactly as `str`'s own `Display` does.
        f.pad("All threads are busy")
    }
}
19
20#[derive(Debug)]
21pub struct BlockingResult<T> {
22 rx: oneshot::AsyncReceiver<Result<T, Box<dyn Any + Send>>>,
23}
24
/// A type-erased job as it travels through the pool's channels.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that can be run exactly once on a worker thread.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed job and executes it.
    fn run(self: Box<Self>);
}

/// Any sendable, `'static`, call-once closure with no arguments is a job.
impl<F> Dispatchable for F
where
    F: FnOnce() + Send + 'static,
{
    fn run(self: Box<Self>) {
        let job = *self;
        job();
    }
}
39
/// Decrements the shared counter by one when dropped.
///
/// Held by a worker thread for its whole lifetime so the pool's counter is
/// decremented on every exit path.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let Self(shared_count) = self;
        shared_count.fetch_sub(1, Ordering::AcqRel);
    }
}
47
48fn worker(
49 receiver_high_prio: Receiver<BoxedDispatchable>,
50 receiver_low_prio: Receiver<BoxedDispatchable>,
51 counter: Arc<AtomicUsize>,
52 timeout: Duration,
53) -> impl FnOnce() {
54 move || {
55 counter.fetch_add(1, Ordering::AcqRel);
56 let _guard = CounterGuard(counter);
57 let mut sel = Select::new_biased();
58 sel.recv(&receiver_high_prio);
59 sel.recv(&receiver_low_prio);
60 while let Ok(op) = sel.select_timeout(timeout) {
61 match op {
62 op if op.index() == 0 => {
63 if let Ok(f) = op.recv(&receiver_high_prio) {
64 f.run();
65 }
66 }
67 op if op.index() == 1 => {
68 if let Ok(f) = op.recv(&receiver_low_prio) {
69 f.run();
70 }
71 }
72 _ => unreachable!(),
73 }
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
91pub struct ThreadPool {
92 name: String,
93 sender_low_prio: Sender<BoxedDispatchable>,
94 receiver_low_prio: Receiver<BoxedDispatchable>,
95 sender_high_prio: Sender<BoxedDispatchable>,
96 receiver_high_prio: Receiver<BoxedDispatchable>,
97 counter: Arc<AtomicUsize>,
98 thread_limit: usize,
99 recv_timeout: Duration,
100}
101
102impl ThreadPool {
103 pub fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
106 let (sender_low_prio, receiver_low_prio) = bounded(0);
107 let (sender_high_prio, receiver_high_prio) = unbounded();
108 Self {
109 sender_low_prio,
110 receiver_low_prio,
111 sender_high_prio,
112 receiver_high_prio,
113 thread_limit,
114 recv_timeout,
115 name: format!("{name}:pool-wrk"),
116 counter: Arc::new(AtomicUsize::new(0)),
117 }
118 }
119
120 #[allow(clippy::missing_panics_doc)]
121 pub fn execute<F, R>(&self, f: F) -> BlockingResult<R>
127 where
128 F: FnOnce() -> R + Send + 'static,
129 R: Send + 'static,
130 {
131 let (tx, rx) = oneshot::async_channel();
132 let f = Box::new(move || {
133 if !tx.is_closed() {
135 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
136 let _ = tx.send(result);
137 }
138 });
139
140 match self.sender_low_prio.try_send(f) {
141 Ok(()) => BlockingResult { rx },
142 Err(e) => match e {
143 TrySendError::Full(f) => {
144 let cnt = self.counter.load(Ordering::Acquire);
145 if cnt >= self.thread_limit {
146 self.sender_high_prio
147 .send(f)
148 .expect("the channel should not be full");
149 BlockingResult { rx }
150 } else {
151 thread::Builder::new()
152 .name(format!("{}:{}", self.name, cnt))
153 .spawn(worker(
154 self.receiver_high_prio.clone(),
155 self.receiver_low_prio.clone(),
156 self.counter.clone(),
157 self.recv_timeout,
158 ))
159 .expect("Cannot construct new thread");
160 self.sender_low_prio
161 .send(f)
162 .expect("the channel should not be full");
163 BlockingResult { rx }
164 }
165 }
166 TrySendError::Disconnected(_) => {
167 unreachable!("receiver should not all disconnected")
168 }
169 },
170 }
171 }
172}
173
174impl<R> Future for BlockingResult<R> {
175 type Output = Result<R, BlockingError>;
176
177 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178 let this = self.get_mut();
179
180 match Pin::new(&mut this.rx).poll(cx) {
181 Poll::Pending => Poll::Pending,
182 Poll::Ready(result) => Poll::Ready(
183 result
184 .map_err(|_| BlockingError)
185 .and_then(|res| res.map_err(|_| BlockingError)),
186 ),
187 }
188 }
189}