1use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
7
/// Error yielded by a `BlockingResult` when no value could be produced for a job.
///
/// In this module it is returned when:
/// - the pool rejected the job because it was at its thread limit,
/// - the job panicked, or
/// - the job's result channel reported an error.
///
/// The same value and message are used for every cause.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BlockingError;

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (rather than `write_str`) honours width/precision flags,
        // exactly like formatting a `&str` does.
        f.pad("All threads are busy")
    }
}

impl std::error::Error for BlockingError {}
19
20pub struct BlockingResult<T> {
21 rx: Option<oneshot::Receiver<Result<T, Box<dyn Any + Send>>>>,
22}
23
/// A type-erased job as it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that can be executed exactly once from a `Box`.
///
/// Taking `self: Box<Self>` lets a boxed `FnOnce` be consumed through a trait
/// object, which a plain `Box<dyn FnOnce()>` method call could not express on
/// older compilers.
pub(crate) trait Dispatchable: 'static + Send {
    /// Consumes the job and runs it on the current thread.
    fn run(self: Box<Self>);
}

impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        // Move the closure out of its box, then call it by value.
        let job = *self;
        job();
    }
}
42
/// Decrements a shared counter by one when dropped.
///
/// Because the decrement lives in `Drop`, it also runs when the owning thread
/// unwinds.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let CounterGuard(active) = self;
        active.fetch_sub(1, Ordering::AcqRel);
    }
}
50
51fn worker(
52 receiver: Receiver<BoxedDispatchable>,
53 counter: Arc<AtomicUsize>,
54 timeout: Duration,
55) -> impl FnOnce() {
56 move || {
57 counter.fetch_add(1, Ordering::AcqRel);
58 let _guard = CounterGuard(counter);
59 while let Ok(f) = receiver.recv_timeout(timeout) {
60 f.run();
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub(crate) struct ThreadPool {
68 sender: Sender<BoxedDispatchable>,
69 receiver: Receiver<BoxedDispatchable>,
70 counter: Arc<AtomicUsize>,
71 thread_limit: usize,
72 recv_timeout: Duration,
73}
74
75impl ThreadPool {
76 pub(crate) fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
79 let (sender, receiver) = bounded(0);
80 Self {
81 sender,
82 receiver,
83 thread_limit,
84 recv_timeout,
85 counter: Arc::new(AtomicUsize::new(0)),
86 }
87 }
88
89 pub(crate) fn dispatch<F, R>(&self, f: F) -> BlockingResult<R>
94 where
95 F: FnOnce() -> R + Send + 'static,
96 R: Send + 'static,
97 {
98 let (tx, rx) = oneshot::channel();
99 let f = Box::new(move || {
100 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
101 let _ = tx.send(result);
102 });
103
104 match self.sender.try_send(f) {
105 Ok(_) => BlockingResult { rx: Some(rx) },
106 Err(e) => match e {
107 TrySendError::Full(f) => {
108 let cnt = self.counter.load(Ordering::Acquire);
109 if cnt >= self.thread_limit {
110 BlockingResult { rx: None }
111 } else {
112 thread::Builder::new()
113 .name(format!("pool-wrk:{}", cnt))
114 .spawn(worker(
115 self.receiver.clone(),
116 self.counter.clone(),
117 self.recv_timeout,
118 ))
119 .expect("Cannot construct new thread");
120 self.sender.send(f).expect("the channel should not be full");
121 BlockingResult { rx: Some(rx) }
122 }
123 }
124 TrySendError::Disconnected(_) => {
125 unreachable!("receiver should not all disconnected")
126 }
127 },
128 }
129 }
130}
131
132impl<R> Future for BlockingResult<R> {
133 type Output = Result<R, BlockingError>;
134
135 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 let this = self.get_mut();
137
138 if this.rx.is_none() {
139 return Poll::Ready(Err(BlockingError));
140 }
141
142 if let Some(mut rx) = this.rx.take() {
143 match Pin::new(&mut rx).poll(cx) {
144 Poll::Pending => {
145 this.rx = Some(rx);
146 Poll::Pending
147 }
148 Poll::Ready(result) => Poll::Ready(
149 result
150 .map_err(|_| BlockingError)
151 .and_then(|res| res.map_err(|_| BlockingError)),
152 ),
153 }
154 } else {
155 unreachable!()
156 }
157 }
158}