poolparty/
lib.rs

1use std::io;
2
3use async_std::channel::{unbounded, Receiver, Sender};
4
5use futures::{
6    executor::ThreadPool,
7    future::{Future, FutureExt},
8    pin_mut, select,
9};
10
/// Panic message used by `stop()` when sending on the internal control channel fails.
const INTERNAL_CHANNEL: &str = "Control channel closed, this should never happen.";
12
13/// Added functionality for the `futures::executor::ThreadPool` futures executor.
14///
15/// Futures will be spawned to and executed by the internal and exchangeable `ThreadPool` instance, but in such a way that *all* spawned futures are asked to stop on user request or in case any of them returns an error.
16///
17/// A notable difference to `futures:executor::ThreadPool` is that the user spawns futures of type `Output<Result(),T>` here instead of type `Output<()>`.
18///
19/// Caveats: If you do not call `observe().await` once all desired futures are spawned or if you spawn additional futures after the first `observe().await` the stopping mechanism won't work. In other words, instances cannot be "reused" after they were being observed for the first time.
20/// For now no measures are in place to prevent a user from doing this (maybe in a future version).
21///
22/// Also note that spawned tasks *can not* be cancelled instantly. They will stop executing the next time they yield to the executor.
23pub struct StoppableThreadPool<PoolError>
24where
25    PoolError: Send + Sync + 'static,
26{
27    pool: ThreadPool,
28    control_sender: Sender<Result<(), PoolError>>,
29    control_receiver: Receiver<Result<(), PoolError>>,
30    stop_senders: Vec<Sender<()>>,
31}
32
33impl<PoolError> StoppableThreadPool<PoolError>
34where
35    PoolError: Send + Sync + 'static,
36{
37    /// Create a new `StoppableThreadPool` instance using a default futures `ThreadPool` executor instance.
38    pub fn new() -> Result<StoppableThreadPool<PoolError>, io::Error> {
39        Ok(StoppableThreadPool::new_with_pool(ThreadPool::new()?))
40    }
41
42    /// Create a new `StoppableThreadPool` instance using a user supplied futures `ThreadPool` executor instance.
43    pub fn new_with_pool(pool: ThreadPool) -> StoppableThreadPool<PoolError> {
44        let (control_sender, control_receiver) = unbounded::<Result<(), PoolError>>();
45        StoppableThreadPool::<PoolError> {
46            pool,
47            control_sender,
48            control_receiver,
49            stop_senders: Vec::new(),
50        }
51    }
52
53    /// Change the underlying futures `ThreadPool` executor instance.
54    pub fn with_pool(&mut self, pool: ThreadPool) -> &mut Self {
55        self.pool = pool;
56        self
57    }
58
59    /// Start executing a future right away.
60    pub fn spawn<Fut>(&mut self, future: Fut) -> &mut Self
61    where
62        Fut: Future<Output = Result<(), PoolError>> + Send + 'static,
63    {
64        let (tx, rx) = unbounded::<()>();
65        self.stop_senders.push(tx);
66        let control = self.control_sender.clone();
67        self.pool.spawn_ok(async move {
68            let future = future.fuse();
69            let stopped = rx.recv().fuse();
70            pin_mut!(future, stopped);
71            let _ = select! {
72                output = future => control.send(output).await,
73                _ = stopped => control.send(Ok(())).await
74            };
75        });
76        self
77    }
78
79    /// Ensure that all spawned tasks are canceled on individual task error or any ` stop()` request issued by the user.
80    /// Call this function once all tasks are spawned.
81    /// A task that fails before a call to `observe()` is being awaited will still trigger a stop as soon as you actually start awaiting here.
82    pub async fn observe(&self) -> Result<(), PoolError> {
83        let mut completed: usize = 0;
84        while let Ok(output) = self.control_receiver.recv().await {
85            completed += 1;
86            if output.is_err() {
87                for tx in self.stop_senders.iter() {
88                    if tx.send(()).await.is_err() {
89                        eprintln!("Task already finished")
90                    }
91                }
92                return output;
93            }
94            if completed == self.stop_senders.len() {
95                break;
96            }
97        }
98        Ok(())
99    }
100
101    /// Stop the execution of all spawned tasks.
102    pub async fn stop(&self, why: PoolError) {
103        self.control_sender
104            .send(Err(why))
105            .await
106            .expect(INTERNAL_CHANNEL)
107    }
108}
109
#[cfg(test)]
mod tests {
    use futures::{
        executor::{block_on, ThreadPool},
        future::pending,
        join,
    };

    use crate::StoppableThreadPool;

    /// Task that completes successfully right away.
    async fn ok() -> Result<(), String> {
        Ok(())
    }

    /// Task that never completes on its own.
    ///
    /// It awaits a future that is always pending instead of spinning in `loop {}`: a busy loop
    /// never yields to the executor, so it could not be stopped and would keep a worker thread
    /// busy for the rest of the test run.
    async fn forever() -> Result<(), String> {
        pending().await
    }

    /// Task that fails right away with the given message.
    async fn fail(msg: String) -> Result<(), String> {
        Err(msg)
    }

    /// `observe()` succeeds once every spawned task has succeeded.
    #[test]
    fn observe_ok() {
        let mut pool = StoppableThreadPool::new().unwrap();
        for _ in 0..1000 {
            pool.spawn(ok());
        }

        block_on(async { assert_eq!(pool.observe().await, Ok(())) });
    }

    /// A failing task makes `observe()` return that task's error.
    #[test]
    fn observe_err() {
        let mut pool = StoppableThreadPool::new().unwrap();
        let err = "fail_function_called".to_string();
        pool.spawn(fail(err.clone()));
        pool.spawn(forever());

        block_on(async { assert_eq!(pool.observe().await.unwrap_err(), err) });
    }

    /// A user issued `stop()` surfaces its reason through `observe()`.
    #[test]
    fn user_stopped() {
        let mut pool = StoppableThreadPool::new().unwrap();
        pool.spawn(forever()).spawn(forever());
        let stop_reason = "stopped by user".to_string();

        block_on(async {
            join!(
                async { assert_eq!(pool.observe().await.unwrap_err(), stop_reason.clone()) },
                pool.stop(stop_reason.clone())
            )
        });
    }

    /// Errors are still observed after the underlying executor was exchanged.
    #[test]
    fn change_pool() {
        let mut pool = StoppableThreadPool::new().unwrap();
        pool.spawn(forever());
        pool.with_pool(ThreadPool::new().unwrap());
        pool.spawn(fail("fail function called".to_string()));

        block_on(async {
            assert_eq!(
                pool.observe().await.unwrap_err(),
                "fail function called".to_string(),
            )
        })
    }
}