1use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded, unbounded};
7
/// Error yielded by [`BlockingResult`] when a submitted task produces no value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BlockingError;

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies width/precision/fill flags the same way formatting a
        // plain `&str` does, so padded output is preserved.
        f.pad("All threads are busy")
    }
}

impl std::error::Error for BlockingError {}
19
/// Future for the outcome of a task submitted through `ThreadPool::execute`.
///
/// Resolves to `Ok(value)` with the task's return value, or to
/// `Err(BlockingError)` when no value could be obtained.
#[derive(Debug)]
pub struct BlockingResult<T> {
    /// Receiving half of the task's result channel. The `Err` side of the
    /// payload is the boxed value captured by `catch_unwind` when the task
    /// panicked. `None` means there is nothing to wait for, e.g. the pool
    /// rejected the task.
    rx: Option<oneshot::AsyncReceiver<Result<T, Box<dyn Any + Send>>>>,
}
24
/// A heap-allocated job in the form it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that can be sent to another thread and run at most once.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed job and executes it.
    fn run(self: Box<Self>);
}

/// Every sendable, `'static` `FnOnce()` closure is a job.
impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        // Move the closure out of its box, then invoke it.
        let job: F = *self;
        job();
    }
}
39
/// Guard that subtracts one from the shared counter when it is dropped.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let Self(live) = self;
        live.fetch_sub(1, Ordering::AcqRel);
    }
}
47
48fn worker(
49 receiver: Receiver<BoxedDispatchable>,
50 counter: Arc<AtomicUsize>,
51 timeout: Duration,
52) -> impl FnOnce() {
53 move || {
54 counter.fetch_add(1, Ordering::AcqRel);
55 let _guard = CounterGuard(counter);
56 while let Ok(f) = receiver.recv_timeout(timeout) {
57 f.run();
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
75pub struct ThreadPool {
76 name: String,
77 sender: Sender<BoxedDispatchable>,
78 receiver: Receiver<BoxedDispatchable>,
79 counter: Arc<AtomicUsize>,
80 thread_limit: usize,
81 recv_timeout: Duration,
82}
83
84impl ThreadPool {
85 pub fn new(
88 name: &str,
89 thread_limit: usize,
90 recv_timeout: Duration,
91 bound: bool,
92 ) -> Self {
93 let (sender, receiver) = if bound { bounded(0) } else { unbounded() };
94 Self {
95 sender,
96 receiver,
97 thread_limit,
98 recv_timeout,
99 name: format!("{name}:pool-wrk"),
100 counter: Arc::new(AtomicUsize::new(0)),
101 }
102 }
103
104 #[allow(clippy::missing_panics_doc)]
105 pub fn execute<F, R>(&self, f: F) -> BlockingResult<R>
115 where
116 F: FnOnce() -> R + Send + 'static,
117 R: Send + 'static,
118 {
119 let (tx, rx) = oneshot::async_channel();
120 let f = Box::new(move || {
121 if !tx.is_closed() {
123 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
124 let _ = tx.send(result);
125 }
126 });
127
128 match self.sender.try_send(f) {
129 Ok(()) => BlockingResult { rx: Some(rx) },
130 Err(e) => match e {
131 TrySendError::Full(f) => {
132 let cnt = self.counter.load(Ordering::Acquire);
133 if cnt >= self.thread_limit {
134 BlockingResult { rx: None }
135 } else {
136 thread::Builder::new()
137 .name(format!("{}:{}", self.name, cnt))
138 .spawn(worker(
139 self.receiver.clone(),
140 self.counter.clone(),
141 self.recv_timeout,
142 ))
143 .expect("Cannot construct new thread");
144 self.sender.send(f).expect("the channel should not be full");
145 BlockingResult { rx: Some(rx) }
146 }
147 }
148 TrySendError::Disconnected(_) => {
149 unreachable!("receiver should not all disconnected")
150 }
151 },
152 }
153 }
154}
155
156impl<R> Future for BlockingResult<R> {
157 type Output = Result<R, BlockingError>;
158
159 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160 let this = self.get_mut();
161
162 if this.rx.is_none() {
163 return Poll::Ready(Err(BlockingError));
164 }
165
166 if let Some(mut rx) = this.rx.take() {
167 match Pin::new(&mut rx).poll(cx) {
168 Poll::Pending => {
169 this.rx = Some(rx);
170 Poll::Pending
171 }
172 Poll::Ready(result) => Poll::Ready(
173 result
174 .map_err(|_| BlockingError)
175 .and_then(|res| res.map_err(|_| BlockingError)),
176 ),
177 }
178 } else {
179 unreachable!()
180 }
181 }
182}