1use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
7
/// Error produced when a blocking task does not yield a value.
///
/// It is reported when the pool has no idle worker and cannot spawn another
/// one. The pool's future also uses it when the task panicked or when its
/// result channel failed.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BlockingError;

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` is exactly what `str`'s own `Display` impl does, so width,
        // alignment and precision flags are honoured.
        f.pad("All threads are busy")
    }
}

impl std::error::Error for BlockingError {}
19
20#[derive(Debug)]
21pub struct BlockingResult<T> {
22 rx: Option<oneshot::Receiver<Result<T, Box<dyn Any + Send>>>>,
23}
24
/// Heap-allocated, type-erased job handed to a worker thread.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A one-shot job that can be executed from behind a `Box`.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed job and executes it.
    fn run(self: Box<Self>);
}

/// Every sendable `FnOnce()` closure is a job.
impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        let job = *self;
        job()
    }
}
41
/// Decrements the shared worker counter when dropped.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let CounterGuard(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
49
50fn worker(
51 receiver: Receiver<BoxedDispatchable>,
52 counter: Arc<AtomicUsize>,
53 timeout: Duration,
54) -> impl FnOnce() {
55 move || {
56 counter.fetch_add(1, Ordering::AcqRel);
57 let _guard = CounterGuard(counter);
58 while let Ok(f) = receiver.recv_timeout(timeout) {
59 f.run();
60 }
61 }
62}
63
64#[derive(Debug, Clone)]
66pub(crate) struct ThreadPool {
67 name: String,
68 sender: Sender<BoxedDispatchable>,
69 receiver: Receiver<BoxedDispatchable>,
70 counter: Arc<AtomicUsize>,
71 thread_limit: usize,
72 recv_timeout: Duration,
73}
74
75impl ThreadPool {
76 pub(crate) fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
79 let (sender, receiver) = bounded(0);
80 Self {
81 sender,
82 receiver,
83 thread_limit,
84 recv_timeout,
85 name: format!("{name}:pool-wrk"),
86 counter: Arc::new(AtomicUsize::new(0)),
87 }
88 }
89
90 pub(crate) fn dispatch<F, R>(&self, f: F) -> BlockingResult<R>
95 where
96 F: FnOnce() -> R + Send + 'static,
97 R: Send + 'static,
98 {
99 let (tx, rx) = oneshot::channel();
100 let f = Box::new(move || {
101 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
102 let _ = tx.send(result);
103 });
104
105 match self.sender.try_send(f) {
106 Ok(()) => BlockingResult { rx: Some(rx) },
107 Err(e) => match e {
108 TrySendError::Full(f) => {
109 let cnt = self.counter.load(Ordering::Acquire);
110 if cnt >= self.thread_limit {
111 BlockingResult { rx: None }
112 } else {
113 thread::Builder::new()
114 .name(format!("{}:{}", self.name, cnt))
115 .spawn(worker(
116 self.receiver.clone(),
117 self.counter.clone(),
118 self.recv_timeout,
119 ))
120 .expect("Cannot construct new thread");
121 self.sender.send(f).expect("the channel should not be full");
122 BlockingResult { rx: Some(rx) }
123 }
124 }
125 TrySendError::Disconnected(_) => {
126 unreachable!("receiver should not all disconnected")
127 }
128 },
129 }
130 }
131}
132
133impl<R> Future for BlockingResult<R> {
134 type Output = Result<R, BlockingError>;
135
136 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 let this = self.get_mut();
138
139 if this.rx.is_none() {
140 return Poll::Ready(Err(BlockingError));
141 }
142
143 if let Some(mut rx) = this.rx.take() {
144 match Pin::new(&mut rx).poll(cx) {
145 Poll::Pending => {
146 this.rx = Some(rx);
147 Poll::Pending
148 }
149 Poll::Ready(result) => Poll::Ready(
150 result
151 .map_err(|_| BlockingError)
152 .and_then(|res| res.map_err(|_| BlockingError)),
153 ),
154 }
155 } else {
156 unreachable!()
157 }
158 }
159}