1use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
3use std::task::{Context, Poll};
4use std::{any::Any, fmt, future::Future, panic, pin::Pin, thread, time::Duration};
5
6use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
7
/// Error produced when a dispatched blocking job does not yield a value.
///
/// It is reported when the pool is saturated and the job is rejected, when the
/// job panics, and when the job's result channel closes without a value. The
/// `Display` text is the same in every case.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BlockingError;

impl std::error::Error for BlockingError {}

impl fmt::Display for BlockingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment/precision flags exactly like formatting
        // the string slice through its own `Display` impl.
        f.pad("All threads are busy")
    }
}
19
20pub struct BlockingResult<T> {
21 rx: Option<oneshot::Receiver<Result<T, Box<dyn Any + Send>>>>,
22}
23
/// Type-erased job as it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A one-shot unit of work that a worker thread can execute.
pub(crate) trait Dispatchable: Send + 'static {
    /// Consumes the boxed job and executes it on the calling thread.
    fn run(self: Box<Self>);
}

/// Every sendable, `'static` `FnOnce()` closure is a job.
impl<F> Dispatchable for F
where
    F: FnOnce() + Send + 'static,
{
    fn run(self: Box<Self>) {
        // Move the closure out of its box, then invoke it.
        let job = *self;
        job();
    }
}
42
/// Decrements the shared live-worker counter when dropped.
///
/// Dropping happens on normal exit and also during unwinding.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let live_workers: &AtomicUsize = &self.0;
        live_workers.fetch_sub(1, Ordering::AcqRel);
    }
}
50
51fn worker(
52 receiver: Receiver<BoxedDispatchable>,
53 counter: Arc<AtomicUsize>,
54 timeout: Duration,
55) -> impl FnOnce() {
56 move || {
57 counter.fetch_add(1, Ordering::AcqRel);
58 let _guard = CounterGuard(counter);
59 while let Ok(f) = receiver.recv_timeout(timeout) {
60 f.run();
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub(crate) struct ThreadPool {
68 name: String,
69 sender: Sender<BoxedDispatchable>,
70 receiver: Receiver<BoxedDispatchable>,
71 counter: Arc<AtomicUsize>,
72 thread_limit: usize,
73 recv_timeout: Duration,
74}
75
76impl ThreadPool {
77 pub(crate) fn new(name: &str, thread_limit: usize, recv_timeout: Duration) -> Self {
80 let (sender, receiver) = bounded(0);
81 Self {
82 sender,
83 receiver,
84 thread_limit,
85 recv_timeout,
86 name: format!("{}:pool-wrk", name),
87 counter: Arc::new(AtomicUsize::new(0)),
88 }
89 }
90
91 pub(crate) fn dispatch<F, R>(&self, f: F) -> BlockingResult<R>
96 where
97 F: FnOnce() -> R + Send + 'static,
98 R: Send + 'static,
99 {
100 let (tx, rx) = oneshot::channel();
101 let f = Box::new(move || {
102 let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
103 let _ = tx.send(result);
104 });
105
106 match self.sender.try_send(f) {
107 Ok(_) => BlockingResult { rx: Some(rx) },
108 Err(e) => match e {
109 TrySendError::Full(f) => {
110 let cnt = self.counter.load(Ordering::Acquire);
111 if cnt >= self.thread_limit {
112 BlockingResult { rx: None }
113 } else {
114 thread::Builder::new()
115 .name(format!("{}:{}", self.name, cnt))
116 .spawn(worker(
117 self.receiver.clone(),
118 self.counter.clone(),
119 self.recv_timeout,
120 ))
121 .expect("Cannot construct new thread");
122 self.sender.send(f).expect("the channel should not be full");
123 BlockingResult { rx: Some(rx) }
124 }
125 }
126 TrySendError::Disconnected(_) => {
127 unreachable!("receiver should not all disconnected")
128 }
129 },
130 }
131 }
132}
133
134impl<R> Future for BlockingResult<R> {
135 type Output = Result<R, BlockingError>;
136
137 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 let this = self.get_mut();
139
140 if this.rx.is_none() {
141 return Poll::Ready(Err(BlockingError));
142 }
143
144 if let Some(mut rx) = this.rx.take() {
145 match Pin::new(&mut rx).poll(cx) {
146 Poll::Pending => {
147 this.rx = Some(rx);
148 Poll::Pending
149 }
150 Poll::Ready(result) => Poll::Ready(
151 result
152 .map_err(|_| BlockingError)
153 .and_then(|res| res.map_err(|_| BlockingError)),
154 ),
155 }
156 } else {
157 unreachable!()
158 }
159 }
160}