1use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Select, Sender, TrySendError, bounded, unbounded};
7
/// Error yielded when a job submitted to the blocking pool does not produce a value.
///
/// `BlockingResult`'s `Future` impl maps two cases to this value: a job that panicked,
/// and an error reported by the result channel (presumably the sending half was
/// dropped without sending — confirm against the `oneshot` crate).
/// NOTE(review): the `Display` text ("All threads are busy") does not describe either
/// case; confirm whether code outside this file returns it for a saturated pool.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BlockingError;
11
12impl std::error::Error for BlockingError {}
13
14impl fmt::Display for BlockingError {
15 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16 "All threads are busy".fmt(f)
17 }
18}
19
20#[derive(Debug)]
21pub struct BlockingResult<T> {
22 rx: oneshot::AsyncReceiver<Result<T, Box<dyn Any + Send>>>,
23}
24
25type BoxedDispatchable = Box<dyn Dispatchable + Send>;
26
/// A unit of work that is run exactly once on a pool worker thread.
///
/// `run` takes `self: Box<Self>` so that a type-erased `Box<dyn Dispatchable>` can be
/// consumed by value.
pub(crate) trait Dispatchable: 'static + Send {
    /// Consumes the job and executes it.
    fn run(self: Box<Self>);
}
30
31impl<F> Dispatchable for F
32where
33 F: FnOnce() + Send + 'static,
34{
35 fn run(self: Box<Self>) {
36 (*self)();
37 }
38}
39
/// Holds one unit of a shared atomic counter; its `Drop` impl decrements the counter,
/// so the unit is released on every exit path, including unwinding.
struct CounterGuard(Arc<AtomicUsize>);
41
42impl Drop for CounterGuard {
43 fn drop(&mut self) {
44 self.0.fetch_sub(1, Ordering::AcqRel);
45 }
46}
47
48fn worker(
49 receiver_high_prio: Receiver<BoxedDispatchable>,
50 receiver_low_prio: Receiver<BoxedDispatchable>,
51 counter: Arc<AtomicUsize>,
52 timeout: Duration,
53) -> impl FnOnce() {
54 move || {
55 counter.fetch_add(1, Ordering::AcqRel);
56 let _guard = CounterGuard(counter);
57 let mut sel = Select::new_biased();
58 sel.recv(&receiver_high_prio);
59 sel.recv(&receiver_low_prio);
60 while let Ok(op) = sel.select_timeout(timeout) {
61 match op {
62 op if op.index() == 0 => {
63 if let Ok(f) = op.recv(&receiver_high_prio) {
64 f.run();
65 }
66 }
67 op if op.index() == 1 => {
68 if let Ok(f) = op.recv(&receiver_low_prio) {
69 f.run();
70 }
71 }
72 _ => unreachable!(),
73 }
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
91pub struct ThreadPool {
92 name: String,
93 sender_low_prio: Sender<BoxedDispatchable>,
94 receiver_low_prio: Receiver<BoxedDispatchable>,
95 sender_high_prio: Sender<BoxedDispatchable>,
96 receiver_high_prio: Receiver<BoxedDispatchable>,
97 counter: Arc<AtomicUsize>,
98 thread_limit: usize,
99 recv_timeout: Duration,
100}
101
102impl ThreadPool {
103 pub fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
106 let (sender_low_prio, receiver_low_prio) = bounded(0);
107 let (sender_high_prio, receiver_high_prio) = unbounded();
108 Self {
109 sender_low_prio,
110 receiver_low_prio,
111 sender_high_prio,
112 receiver_high_prio,
113 thread_limit,
114 recv_timeout,
115 name: format!("{name}:pool-wrk"),
116 counter: Arc::new(AtomicUsize::new(0)),
117 }
118 }
119
120 pub(crate) fn execute_inplace<F, R>(f: F) -> BlockingResult<R>
121 where
122 F: FnOnce() -> R + Send + 'static,
123 R: Send + 'static,
124 {
125 let (tx, rx) = oneshot::async_channel();
126 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
127 let _ = tx.send(result);
128 BlockingResult { rx }
129 }
130
131 #[allow(clippy::missing_panics_doc)]
132 pub fn execute<F, R>(&self, f: F) -> BlockingResult<R>
138 where
139 F: FnOnce() -> R + Send + 'static,
140 R: Send + 'static,
141 {
142 let (tx, rx) = oneshot::async_channel();
143 let f = Box::new(move || {
144 if !tx.is_closed() {
146 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
147 let _ = tx.send(result);
148 }
149 });
150
151 match self.sender_low_prio.try_send(f) {
152 Ok(()) => BlockingResult { rx },
153 Err(e) => match e {
154 TrySendError::Full(f) => {
155 let cnt = self.counter.load(Ordering::Acquire);
156 if cnt >= self.thread_limit {
157 self.sender_high_prio
158 .send(f)
159 .expect("the channel should not be full");
160 BlockingResult { rx }
161 } else {
162 thread::Builder::new()
163 .name(format!("{}:{}", self.name, cnt))
164 .spawn(worker(
165 self.receiver_high_prio.clone(),
166 self.receiver_low_prio.clone(),
167 self.counter.clone(),
168 self.recv_timeout,
169 ))
170 .expect("Cannot construct new thread");
171 self.sender_low_prio
172 .send(f)
173 .expect("the channel should not be full");
174 BlockingResult { rx }
175 }
176 }
177 TrySendError::Disconnected(_) => {
178 unreachable!("receiver should not all disconnected")
179 }
180 },
181 }
182 }
183}
184
185impl<R> Future for BlockingResult<R> {
186 type Output = Result<R, BlockingError>;
187
188 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 let this = self.get_mut();
190
191 match Pin::new(&mut this.rx).poll(cx) {
192 Poll::Pending => Poll::Pending,
193 Poll::Ready(result) => Poll::Ready(
194 result
195 .map_err(|_| BlockingError)
196 .and_then(|res| res.map_err(|_| BlockingError)),
197 ),
198 }
199 }
200}