1use std::io;
2
3use async_std::channel::{unbounded, Receiver, Sender};
4
5use futures::{
6 executor::ThreadPool,
7 future::{Future, FutureExt},
8 pin_mut, select,
9};
10
/// Message passed to `expect` when sending on the pool's control channel fails.
/// The pool struct holds both the sender and the receiver of that channel, so a
/// failure here indicates a broken internal invariant rather than a user error.
const INTERNAL_CHANNEL: &str = "Control channel closed, this should never happen.";
12
13pub struct StoppableThreadPool<PoolError>
24where
25 PoolError: Send + Sync + 'static,
26{
27 pool: ThreadPool,
28 control_sender: Sender<Result<(), PoolError>>,
29 control_receiver: Receiver<Result<(), PoolError>>,
30 stop_senders: Vec<Sender<()>>,
31}
32
33impl<PoolError> StoppableThreadPool<PoolError>
34where
35 PoolError: Send + Sync + 'static,
36{
37 pub fn new() -> Result<StoppableThreadPool<PoolError>, io::Error> {
39 Ok(StoppableThreadPool::new_with_pool(ThreadPool::new()?))
40 }
41
42 pub fn new_with_pool(pool: ThreadPool) -> StoppableThreadPool<PoolError> {
44 let (control_sender, control_receiver) = unbounded::<Result<(), PoolError>>();
45 StoppableThreadPool::<PoolError> {
46 pool,
47 control_sender,
48 control_receiver,
49 stop_senders: Vec::new(),
50 }
51 }
52
53 pub fn with_pool(&mut self, pool: ThreadPool) -> &mut Self {
55 self.pool = pool;
56 self
57 }
58
59 pub fn spawn<Fut>(&mut self, future: Fut) -> &mut Self
61 where
62 Fut: Future<Output = Result<(), PoolError>> + Send + 'static,
63 {
64 let (tx, rx) = unbounded::<()>();
65 self.stop_senders.push(tx);
66 let control = self.control_sender.clone();
67 self.pool.spawn_ok(async move {
68 let future = future.fuse();
69 let stopped = rx.recv().fuse();
70 pin_mut!(future, stopped);
71 let _ = select! {
72 output = future => control.send(output).await,
73 _ = stopped => control.send(Ok(())).await
74 };
75 });
76 self
77 }
78
79 pub async fn observe(&self) -> Result<(), PoolError> {
83 let mut completed: usize = 0;
84 while let Ok(output) = self.control_receiver.recv().await {
85 completed += 1;
86 if output.is_err() {
87 for tx in self.stop_senders.iter() {
88 if tx.send(()).await.is_err() {
89 eprintln!("Task already finished")
90 }
91 }
92 return output;
93 }
94 if completed == self.stop_senders.len() {
95 break;
96 }
97 }
98 Ok(())
99 }
100
101 pub async fn stop(&self, why: PoolError) {
103 self.control_sender
104 .send(Err(why))
105 .await
106 .expect(INTERNAL_CHANNEL)
107 }
108}
109
#[cfg(test)]
mod tests {
    use futures::{
        executor::{block_on, ThreadPool},
        future, join,
    };

    use super::StoppableThreadPool;

    /// Completes successfully right away.
    async fn ok() -> Result<(), String> {
        Ok(())
    }

    /// Never completes. Unlike a busy `loop {}`, it yields to the executor, so
    /// the pool's stop signal can interrupt it and no worker thread is left
    /// spinning at full CPU for the rest of the test run.
    async fn forever() -> Result<(), String> {
        future::pending::<Result<(), String>>().await
    }

    /// Fails immediately with `msg`.
    async fn fail(msg: String) -> Result<(), String> {
        Err(msg)
    }

    #[test]
    fn observe_ok() {
        let mut pool = StoppableThreadPool::new().unwrap();
        for _ in 0..1000 {
            pool.spawn(ok());
        }

        block_on(async { assert_eq!(pool.observe().await, Ok(())) });
    }

    #[test]
    fn observe_err() {
        let mut pool = StoppableThreadPool::new().unwrap();
        let err = "fail_function_called".to_string();
        pool.spawn(fail(err.clone()));
        pool.spawn(forever());

        block_on(async { assert_eq!(pool.observe().await.unwrap_err(), err) });
    }

    #[test]
    fn user_stopped() {
        let mut pool = StoppableThreadPool::new().unwrap();
        pool.spawn(forever()).spawn(forever());
        let stop_reason = "stopped by user".to_string();

        block_on(async {
            join!(
                async { assert_eq!(pool.observe().await.unwrap_err(), stop_reason) },
                pool.stop(stop_reason.clone())
            )
        });
    }

    #[test]
    fn change_pool() {
        let mut pool = StoppableThreadPool::new().unwrap();
        pool.spawn(forever());
        pool.with_pool(ThreadPool::new().unwrap());
        pool.spawn(fail("fail function called".to_string()));

        block_on(async {
            assert_eq!(
                pool.observe().await.unwrap_err(),
                "fail function called".to_string(),
            )
        })
    }
}