simple_shutdown/
group.rs

1use core::{
2    future::Future,
3    ops::Deref,
4    pin::Pin,
5    sync::atomic::Ordering,
6    task::{Poll, Waker},
7};
8
9#[cfg(feature = "alloc")]
10use alloc::sync::Arc;
11
12use pin_project_lite::pin_project;
13
14use crate::{intrusive::Node, spawn::Spawn, task::Task, State};
15
16/// A group of potentially related tasks. Tasks spawned by this struct can be waited on or signaled to shut down.
17pub struct TaskGroup<S: Spawn, C: Deref<Target = State>> {
18    spawner: S,
19    state: C,
20}
21
#[cfg(feature = "alloc")]
impl<S: Spawn> TaskGroup<S, Arc<State>> {
    /// Create a new task group using the provided spawner.
    ///
    /// The group's [`State`] is freshly allocated behind an [`Arc`], so the done futures, task
    /// guards and shutdown signals created by this group all share one reference-counted copy.
    pub fn new(spawner: S) -> Self {
        let state = Arc::new(State::new());
        Self { spawner, state }
    }
}
32
33impl<S: Spawn> TaskGroup<S, &'static State> {
34    /// Create a new task group using the provided spawner and state
35    pub fn with_static(spawner: S, state: &'static State) -> Self {
36        TaskGroup { spawner, state }
37    }
38}
39
40impl<S: Spawn, C: 'static + Deref<Target = State> + Clone + Send> TaskGroup<S, C> {
41    /// Signal a shutdown to all tasks in this group and wait for shutdown to finish.
42    pub async fn shutdown(&self) {
43        critical_section::with(|cs| {
44            self.state.shutdown_signaled.store(true, Ordering::SeqCst);
45
46            let list = self.state.shutdown_wakers.borrow(cs).borrow_mut();
47
48            let mut node = list.peek_front();
49            while let Some(inner_node) = node {
50                if let Some(ref waker) = inner_node.data {
51                    waker.clone().wake();
52                }
53
54                node = inner_node.next();
55            }
56        });
57
58        self.done().await;
59    }
60
61    /// Wait for all tasks in this group to finish without explicitly sending a shutdown signal.
62    pub async fn done(&self) {
63        DoneFuture {
64            state: self.state.clone(),
65        }
66        .await
67    }
68
69    /// Spawn a task as part of this task group
70    pub fn spawn(&self, future: impl Future<Output = ()> + Send + 'static) {
71        let task = Task::new(self.state.clone());
72        self.spawner.spawn(async {
73            future.await;
74            core::mem::drop(task);
75        });
76    }
77
78    pub fn spawn_with_shutdown<F>(&self, f: impl FnOnce(ShutdownSignal<C>) -> F)
79    where
80        F: Future<Output = ()> + Send + 'static,
81    {
82        let signal = ShutdownSignal {
83            state: self.state.clone(),
84            node: Node::new(None),
85        };
86        let future = f(signal);
87        self.spawn(future);
88    }
89}
90
91struct DoneFuture<C> {
92    state: C,
93}
94
95impl<C: Deref<Target = State>> Future for DoneFuture<C> {
96    type Output = ();
97
98    fn poll(
99        self: core::pin::Pin<&mut Self>,
100        cx: &mut core::task::Context<'_>,
101    ) -> core::task::Poll<Self::Output> {
102        self.state.done_waker.register(cx.waker());
103
104        if self.state.running_tasks.load(Ordering::SeqCst) == 0 {
105            Poll::Ready(())
106        } else {
107            Poll::Pending
108        }
109    }
110}
111
112pin_project! {
113    /// Future which completes once the associated task group has signaled a shutdown.
114    pub struct ShutdownSignal<C: Deref<Target = State>> {
115        state: C,
116        #[pin]
117        node: Node<Option<Waker>>,
118    }
119
120    impl<C: Deref<Target = State>> PinnedDrop for ShutdownSignal<C> {
121        fn drop(this: Pin<&mut Self>) {
122            let this = this.project();
123
124            critical_section::with(|cs| {
125                let mut list = this.state.shutdown_wakers.borrow(cs).borrow_mut();
126                if this.node.is_init() {
127                    unsafe {this.node.remove(&mut list) };
128                }
129            });
130        }
131    }
132}
133
134impl<C: Deref<Target = State>> Future for ShutdownSignal<C> {
135    type Output = ();
136
137    fn poll(
138        self: core::pin::Pin<&mut Self>,
139        cx: &mut core::task::Context<'_>,
140    ) -> Poll<Self::Output> {
141        let mut this = self.project();
142        unsafe {
143            critical_section::with(|cs| {
144                if this.state.shutdown_signaled.load(Ordering::SeqCst) {
145                    return Poll::Ready(());
146                }
147                let node = Pin::as_mut(&mut this.node).get_unchecked_mut();
148                node.data = Some(cx.waker().clone());
149                if !node.is_init() {
150                    this.state
151                        .shutdown_wakers
152                        .borrow(cs)
153                        .borrow_mut()
154                        .push_front(this.node);
155                }
156                return Poll::Pending;
157            })
158        }
159    }
160}