1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
use std::io;

use async_std::sync::{
    channel,
    Receiver,
    Sender,
};

use futures::{
    executor::ThreadPool,
    future::{Future,FutureExt},
    pin_mut,
    select,
};

/// Added functionality for the `futures::executor::ThreadPool` futures executor.
///
/// Futures are spawned to and executed by an internal, exchangeable `ThreadPool` instance, but in
/// such a way that *all* spawned futures are asked to stop on user request (`stop()`) or as soon
/// as any of them returns an error.
///
/// A notable difference to `futures::executor::ThreadPool` is that the user spawns futures with
/// `Output = Result<(), PoolError>` here instead of `Output = ()`.
///
/// Caveats: If you do not call `observe().await` once all desired futures are spawned, or if you
/// spawn additional futures after the first `observe().await`, the stopping mechanism won't work.
/// In other words, instances cannot be "reused" after they have been observed for the first time.
/// For now no measures are in place to prevent a user from doing this (maybe in a future version).
///
/// Also note that spawned tasks *cannot* be cancelled instantly. They will stop executing the next
/// time they yield to the executor.
pub struct StoppableThreadPool<PoolError>
where
    PoolError: Send + Sync + 'static,
{
    /// Executor that newly spawned futures are handed to.
    pool: ThreadPool,
    /// Cloned into every spawned task so it can report its result; `stop()` also sends on it.
    control_sender: Sender<Result<(), PoolError>>,
    /// Drained by `observe()` to collect task results and stop requests.
    control_receiver: Receiver<Result<(), PoolError>>,
    /// One stop channel per spawned task, in spawn order.
    stop_senders: Vec<Sender<()>>,
    /// Capacity used for the channels created by this instance.
    chan_cap: usize,
}

impl<PoolError> StoppableThreadPool<PoolError> 
    where
        PoolError: Send + Sync + 'static,
    {
    /// Create a new `StoppableThreadPool` instance using a default futures `ThreadPool` executor instance.
    pub fn new(chan_cap: usize) -> Result<StoppableThreadPool<PoolError>,io::Error> {
        Ok(StoppableThreadPool::new_with_pool(
            ThreadPool::new()?,
            chan_cap,
        ))
    }

    /// Create a new `StoppableThreadPool` instance using a user supplied futures `ThreadPool` executor instance.
    pub fn new_with_pool(pool: ThreadPool, chan_cap: usize) -> StoppableThreadPool<PoolError> {
        let (control_sender, control_receiver) = channel::<Result<(),PoolError>>(chan_cap);
        StoppableThreadPool::<PoolError> {
            pool,
            control_sender,
            control_receiver,
            stop_senders: Vec::new(),
            chan_cap,
        }
    }

    /// Change the underlying futures `ThreadPool` executor instance. 
    pub fn with_pool(&mut self, pool: ThreadPool) -> &mut Self {
        self.pool = pool;
        self
    }

    /// Start executing a future right away.
    pub fn spawn<Fut>(&mut self, future: Fut) -> &mut Self
    where
        Fut: Future<Output = Result<(),PoolError>> + Send + 'static,
    {
        let (tx, rx) = channel::<()>(self.chan_cap);
        self.stop_senders.push(tx);
        let control = self.control_sender.clone();
        self.pool.spawn_ok(async move {
            let future = future.fuse();
            let stopped = rx.recv().fuse();
            pin_mut!(future, stopped);
            select! {
                output = future => control.send(output).await,
                _ = stopped => control.send(Ok(())).await
            };
        });
        self
    }

    /// Ensure that all spawned tasks are canceled on individual task error or any ` stop()` request issued by the user.
    /// Call this function once all tasks are spawned.
    /// A task that fails before a call to `observe()` is being awaited will still trigger a stop as soon as you actually start awaiting here.
    pub async fn observe(&self) -> Result<(),PoolError> {
        let mut completed: usize = 0;
        while let Ok(output) = self.control_receiver.recv().await {
            completed += 1;
            if output.is_err() {
                for tx in self.stop_senders.iter() {
                    tx.send(()).await
                }
                return output
            }
            if completed == self.stop_senders.len() {
                break
            }
        }
        Ok(())
    }

    /// Stop the execution of all spawned tasks.
    pub async fn stop(&self, why: PoolError) {
        self.control_sender.send(Err(why)).await
    }
}

#[cfg(test)]
mod tests {
    use futures::{
        executor::block_on,
        executor::ThreadPool,
        future::pending,
        join,
    };

    use super::StoppableThreadPool;

    /// Completes successfully right away.
    async fn ok() -> Result<(), String> {
        Ok(())
    }

    /// Never completes. Unlike a busy `loop {}`, it yields to the executor, so the stop
    /// mechanism can actually cancel it and no pool thread is left spinning after the test.
    async fn forever() -> Result<(), String> {
        pending().await
    }

    /// Fails right away with `msg`.
    async fn fail(msg: String) -> Result<(), String> {
        Err(msg)
    }

    const CHAN_CAP: usize = 1;

    #[test]
    fn observe_ok() {
        let mut pool = StoppableThreadPool::new(CHAN_CAP).unwrap();
        for _ in 0..1000 {
            pool.spawn(ok());
        }

        block_on(async { assert_eq!(pool.observe().await, Ok(())) });
    }

    #[test]
    fn observe_err() {
        let mut pool = StoppableThreadPool::new(CHAN_CAP).unwrap();
        let err = "fail_function_called".to_string();
        pool.spawn(fail(err.clone()));
        pool.spawn(forever());

        block_on(async { assert_eq!(pool.observe().await.unwrap_err(), err) });
    }

    #[test]
    fn user_stopped() {
        let mut pool = StoppableThreadPool::new(CHAN_CAP).unwrap();
        pool.spawn(forever()).spawn(forever());
        let stop_reason = "stopped by user".to_string();

        block_on(async {
            join!(
                async { assert_eq!(pool.observe().await.unwrap_err(), stop_reason) },
                pool.stop(stop_reason.clone())
            )
        });
    }

    #[test]
    fn change_pool() {
        let mut pool = StoppableThreadPool::new(CHAN_CAP).unwrap();
        pool.spawn(forever());
        pool.with_pool(ThreadPool::new().unwrap());
        pool.spawn(fail("fail function called".to_string()));

        block_on(async {
            assert_eq!(
                pool.observe().await.unwrap_err(),
                "fail function called".to_string(),
            )
        })
    }
}