simple_shutdown/
lib.rs

1#![no_std]
2
3#[cfg(feature = "alloc")]
4extern crate alloc;
5
6use core::{
7    cell::RefCell,
8    sync::atomic::{AtomicBool, AtomicUsize},
9    task::Waker,
10};
11
12use critical_section::Mutex;
13use futures_util::task::AtomicWaker;
14use intrusive::List;
15
16mod group;
17mod intrusive;
18mod spawn;
19mod task;
20
21pub use group::{ShutdownSignal, TaskGroup};
22pub use spawn::Spawn;
23
24pub struct State {
25    running_tasks: AtomicUsize,
26    done_waker: AtomicWaker,
27    shutdown_wakers: Mutex<RefCell<List<Option<Waker>>>>,
28    shutdown_signaled: AtomicBool,
29}
30
31impl State {
32    pub const fn new() -> Self {
33        State {
34            running_tasks: AtomicUsize::new(0),
35            done_waker: AtomicWaker::new(),
36            shutdown_wakers: Mutex::new(RefCell::new(List::new())),
37            shutdown_signaled: AtomicBool::new(false),
38        }
39    }
40}
41
#[cfg(test)]
mod tests {
    use core::time::Duration;

    use crate::*;

    /// A future passed to `TaskGroup::spawn` is actually run on the runtime.
    #[test]
    fn spawns_tasks() {
        static STATE: State = State::new();

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let group = TaskGroup::with_static(&runtime, &STATE);
        let (tx, rx) = tokio::sync::oneshot::channel();

        runtime.block_on(async move {
            group.spawn(async move {
                if tx.send(()).is_err() {
                    panic!("the receiver dropped");
                }
            });

            tokio::select! {
                result = rx => match result {
                    Ok(_) => {}
                    Err(_) => panic!("the sender did not spawn"),
                },
                _ = tokio::time::sleep(Duration::from_secs(10)) => {
                    panic!("timed out waiting for the spawned task to run")
                }
            }
        });
    }

    /// `done()` must not resolve while a spawned task never finishes.
    #[test]
    fn done_waits() {
        static STATE: State = State::new();

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let group = TaskGroup::with_static(&runtime, &STATE);

        runtime.block_on(async move {
            group.spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            });

            tokio::select! {
                _ = group.done() => panic!("done() resolved while a task was still running"),
                _ = tokio::time::sleep(Duration::from_secs(2)) => {}
            }
        });
    }

    /// `done()` resolves once every spawned (short-lived) task has finished.
    #[test]
    fn done_exits() {
        static STATE: State = State::new();

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let group = TaskGroup::with_static(&runtime, &STATE);

        runtime.block_on(async move {
            for _ in 0..5 {
                group.spawn(async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                });
            }

            tokio::select! {
                _ = group.done() => {},
                _ = tokio::time::sleep(Duration::from_secs(2)) => {
                    panic!("done() did not resolve after all tasks finished")
                }
            }
        });
    }

    /// `shutdown()` fires every task's shutdown signal, so no task reaches the
    /// 5-second fallback branch that sends on the channel.
    #[test]
    fn shutdown_signals() {
        static STATE: State = State::new();

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let group = TaskGroup::with_static(&runtime, &STATE);
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<()>();

        runtime.block_on(async move {
            for _ in 0..5 {
                let tx = tx.clone();
                group.spawn_with_shutdown(|shutdown| async move {
                    tokio::select! {
                        _ = shutdown => {},
                        _ = tokio::time::sleep(Duration::from_secs(5)) => {
                            let _ = tx.send(());
                        },
                    }
                    core::mem::drop(tx);
                });
            }

            // Drop the original sender so `recv()` yields `None` once every
            // task has dropped its clone.
            core::mem::drop(tx);

            // Give the spawned tasks a chance to start before shutting down.
            tokio::time::sleep(Duration::from_secs(1)).await;
            group.shutdown().await;

            assert!(
                rx.recv().await.is_none(),
                "a task hit its timeout instead of observing the shutdown signal"
            );
        });
    }
}