flo_task/
spawn_scope.rs

//! RAII guard used to notify child tasks that their parent scope has been closed or dropped.

3use std::future::Future;
4use tokio::sync::watch::{channel, Receiver, Sender};
5
6#[derive(Debug)]
7pub struct SpawnScope {
8    tx: Option<Sender<()>>,
9    rx: Receiver<()>,
10}
11
12impl SpawnScope {
13    pub fn new() -> Self {
14        let (tx, rx) = channel(());
15        Self { tx: Some(tx), rx }
16    }
17
18    pub fn handle(&self) -> SpawnScopeHandle {
19        let rx = self.rx.clone();
20        SpawnScopeHandle(rx)
21    }
22
23    pub fn spawn<F>(&self, future: F) 
24    where F: Future<Output = ()> + Send + 'static
25    {
26        let mut handle = self.handle();
27        tokio::spawn(async move {
28            tokio::select! {
29                _ = handle.left() => {},
30                _ = future => {},
31            }
32        });
33    }
34
35    pub fn close(&mut self) {
36        self.tx.take();
37    }
38}
39
40#[derive(Debug)]
41pub struct SpawnScopeHandle(Receiver<()>);
42
43impl Clone for SpawnScopeHandle {
44    fn clone(&self) -> Self {
45        let rx = self.0.clone();
46        SpawnScopeHandle(rx)
47    }
48}
49
50impl SpawnScopeHandle {
51    pub async fn left(&mut self) {
52        while let Some(_) = self.0.recv().await {}
53    }
54
55
56    pub fn spawn<F>(&self, future: F) 
57    where F: Future<Output = ()> + Send + 'static
58    {
59        let mut handle = self.clone();
60        tokio::spawn(async move {
61            tokio::select! {
62                _ = handle.left() => {},
63                _ = future => {},
64            }
65        });
66    }
67}
68
69#[tokio::test]
70async fn test_drop() {
71    use std::future::Future;
72    use std::time::Duration;
73    use tokio::time::delay_for;
74    let scope = SpawnScope::new();
75
76    fn get_task(mut scope: SpawnScopeHandle) -> impl Future<Output = i32> {
77        async move {
78            let mut n = 0;
79            loop {
80                tokio::select! {
81                  _ = scope.left() => {
82                    return n
83                  }
84                  _ = delay_for(Duration::from_millis(50)) => {
85                    n = n + 1
86                  }
87                }
88            }
89        }
90    }
91
92    let t1 = tokio::spawn(get_task(scope.handle()));
93    let t2 = tokio::spawn(get_task(scope.handle()));
94    let t3 = tokio::spawn(get_task(scope.handle()));
95
96    delay_for(Duration::from_millis(100)).await;
97    drop(scope);
98
99    let (v1, v2, v3) = tokio::try_join!(t1, t2, t3).unwrap();
100    assert!(v1 > 0);
101    assert!(v2 > 0);
102    assert!(v3 > 0);
103}
104
105#[tokio::test]
106async fn test_spawn() {
107    use tokio::sync::oneshot::*;
108    
109    let (tx, rx) = channel();
110
111    struct Guard(Option<Sender<()>>);
112    impl Drop for Guard {
113        fn drop(&mut self) {
114            self.0.take().unwrap().send(()).ok();
115        }
116    }
117
118    let g = Guard(tx.into());
119    let scope = SpawnScope::new();
120
121    scope.spawn(async move {
122        futures::future::pending::<()>().await;
123        drop(g)
124    });
125
126    drop(scope);
127
128    rx.await.unwrap();
129}