tokio_graceful/
trigger.rs

1//! A trigger is a way to wake up a task from another task.
2//!
3//! This is useful for implementing graceful shutdowns, among other things.
4//! The way it works is a Sender and Receiver both have access to shared data,
5//! being a WakerList and a boolean indicating whether the trigger has been triggered.
6//!
7//! The Sender can trigger the Receiver by setting the boolean to true and waking up all the wakers.
8//! The Receiver can add itself to the waker list (when being polled) and check whether the boolean
9//! has been set to true.
10//!
11//! Using Arc, Mutex and Atomic* this is all done in a safe manner.
12//! The trick is further to use Slab to store the wakers, as it allows
13//! us to very efficiently keep track of the wakers and remove them when they are no longer needed.
14//!
15//! To make this work, in a cancel safe manner, we need to make sure
16//! we remove the waker from the waker list when the Receiver is dropped.
17
18use std::{
19    future::Future,
20    pin::Pin,
21    task::{Context, Poll, Waker},
22};
23
24use pin_project_lite::pin_project;
25use slab::Slab;
26
27use crate::sync::{Arc, AtomicBool, Mutex, Ordering};
28
/// Shared, lock-protected list of wakers registered by polled receivers.
///
/// Each slot is keyed by its `Slab` index. A slot holds `Some(waker)` while
/// the receiver is waiting and becomes `None` once `Sender::trigger` has
/// taken the waker out to wake it.
type WakerList = Arc<Mutex<Slab<Option<Waker>>>>;
/// Shared flag that is set to `true` once the trigger has been fired.
type TriggerState = Arc<AtomicBool>;
31
/// A subscriber is the active state of a Receiver,
/// and is there only when the Receiver did not yet detect a trigger.
///
/// It only holds handles to the data shared with the `Sender`; cloning it
/// clones the `Arc` handles, not the underlying waker list or flag.
#[derive(Debug, Clone)]
struct Subscriber {
    /// Waker list shared with the `Sender` and all other subscribers.
    wakers: WakerList,
    /// Flag shared with the `Sender`, `true` once the trigger has fired.
    state: TriggerState,
}
39
/// The state of a [`Subscriber`] returned by `Subscriber::state`,
/// which is used to determine whether the Subscriber has been triggered
/// or has instead stored the polling task's `Waker` for being able to wake it up
/// when the trigger is triggered.
#[derive(Debug)]
enum SubscriberState {
    /// The trigger has not fired yet; carries the key under which the
    /// waker is stored in the waker list.
    Waiting(usize),
    /// The trigger has fired.
    Triggered,
}
49
50impl Subscriber {
51    /// Returns the state of the Subscriber,
52    /// which is used as a main driver in the Receiver's `Future::poll` implementation.
53    ///
54    /// If the Subscriber has been triggered, it returns `SubscriberState::Triggered`.
55    /// If the Subscriber has not yet been triggered, it returns `SubscriberState::Waiting`
56    /// with the key of the waker in the waker list.
57    ///
58    /// If the key is `Some`, it means the waker is already in the waker list,
59    /// and we can update the waker with the new waker. Otherwise we insert the waker
60    /// into the waker list as a new waker. Either way, we return the key of the waker.
61    pub fn state(&self, cx: &mut Context, key: Option<usize>) -> SubscriberState {
62        if self.state.load(Ordering::SeqCst) {
63            return SubscriberState::Triggered;
64        }
65
66        let mut wakers = self.wakers.lock().unwrap();
67
68        // check again after locking the wakers
69        // if we didn't miss this for some reason...
70        // (without this, we could miss a trigger, and never wake up...)
71        // (this was a bug detected by loom)
72        if self.state.load(Ordering::SeqCst) {
73            return SubscriberState::Triggered;
74        }
75
76        let waker = Some(cx.waker().clone());
77
78        SubscriberState::Waiting(if let Some(key) = key {
79            tracing::trace!("trigger::Subscriber: updating waker for key: {}", key);
80            *wakers.get_mut(key).unwrap() = waker;
81            key
82        } else {
83            let key = wakers.insert(waker);
84            tracing::trace!("trigger::Subscriber: insert waker for key: {}", key);
85            key
86        })
87    }
88}
89
/// The state of a [`Receiver`], which is either open, closed or pending.
/// The closed and pending states are mostly for simplification and optimization reasons.
///
/// When the Receiver is open, it contains a [`Subscriber`],
/// which is used to determine whether the Receiver has been triggered.
#[derive(Debug)]
enum ReceiverState {
    /// Still waiting for the trigger. `key` is the waker-list key of this
    /// receiver's registered waker, or `None` if it has not registered one yet.
    Open { sub: Subscriber, key: Option<usize> },
    /// Polling resolves immediately.
    Closed,
    /// Polling never resolves.
    Pending,
}
101
102impl Clone for ReceiverState {
103    /// Clone either nothing or the [`Subscriber`].
104    /// Very important however to not clone its key as
105    /// that is linked to a polled future of the original Receiver,
106    /// and not the cloned one.
107    fn clone(&self) -> Self {
108        match self {
109            ReceiverState::Open { sub, .. } => ReceiverState::Open {
110                sub: sub.clone(),
111                key: None,
112            },
113            ReceiverState::Closed => ReceiverState::Closed,
114            ReceiverState::Pending => ReceiverState::Pending,
115        }
116    }
117}
118
119impl Drop for ReceiverState {
120    /// When the Receiver is dropped, we need to remove the waker from the waker list.
121    /// As to ensure the Receiver is cancel safe.
122    fn drop(&mut self) {
123        if let ReceiverState::Open { sub, key } = self {
124            if let Some(key) = key.take() {
125                let mut wakers = sub.wakers.lock().unwrap();
126                tracing::trace!(
127                    "trigger::ReceiverState::Drop: remove waker for key: {}",
128                    key
129                );
130                wakers.remove(key);
131            }
132        }
133    }
134}
135
pin_project! {
    /// The receiving half of a trigger: a future that resolves once the
    /// trigger has been fired by a `Sender`.
    ///
    /// Cloning a receiver yields an independent future; the clone does not
    /// inherit the waker registration of the original.
    #[derive(Debug, Clone)]
    pub struct Receiver {
        // Open (still waiting), closed (always ready) or pending (never ready).
        state: ReceiverState,
    }
}
142
143impl Receiver {
144    fn new(wakers: WakerList, state: TriggerState) -> Self {
145        Self {
146            state: ReceiverState::Open {
147                sub: Subscriber { wakers, state },
148                key: None,
149            },
150        }
151    }
152
153    /// Create a always-closed [`Receiver`].
154    pub(crate) fn closed() -> Self {
155        Self {
156            state: ReceiverState::Closed,
157        }
158    }
159
160    /// Create a always-pending [`Receiver`].
161    pub(crate) fn pending() -> Self {
162        Self {
163            state: ReceiverState::Pending,
164        }
165    }
166}
167
168impl Future for Receiver {
169    type Output = ();
170
171    /// Polls the Receiver, which is either open or closed.
172    ///
173    /// When the Receiver is open, it uses the [`Subscriber`] to determine
174    /// whether the Receiver has been triggered.
175    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
176        let this = self.project();
177        match this.state {
178            ReceiverState::Open { sub, key } => {
179                let state = sub.state(cx, *key);
180                match state {
181                    SubscriberState::Waiting(new_key) => {
182                        *key = Some(new_key);
183                        std::task::Poll::Pending
184                    }
185                    SubscriberState::Triggered => {
186                        *this.state = ReceiverState::Closed;
187                        std::task::Poll::Ready(())
188                    }
189                }
190            }
191            ReceiverState::Closed => std::task::Poll::Ready(()),
192            ReceiverState::Pending => std::task::Poll::Pending,
193        }
194    }
195}
196
/// The sending half of a trigger, used to fire it and wake all waiting receivers.
///
/// Clones share the same flag and waker list, so the trigger fires at most
/// once no matter which clone calls `trigger`.
#[derive(Debug, Clone)]
pub struct Sender {
    /// Flag shared with the receivers, `true` once the trigger has fired.
    state: TriggerState,
    /// Waker list shared with the receivers.
    wakers: WakerList,
}
202
203impl Sender {
204    fn new(wakers: WakerList, state: TriggerState) -> Self {
205        Self { wakers, state }
206    }
207
208    /// Triggers the Receiver, with a short circuit if the trigger has already been triggered.
209    pub fn trigger(&self) {
210        if self.state.swap(true, Ordering::SeqCst) {
211            return;
212        }
213
214        let mut wakers = self.wakers.lock().unwrap();
215        for (key, waker) in wakers.iter_mut() {
216            match waker.take() {
217                Some(waker) => {
218                    tracing::trace!("trigger::Sender: wake up waker with key: {}", key);
219                    waker.wake();
220                }
221                None => {
222                    tracing::trace!(
223                        "trigger::Sender: nop: waker already triggered with key: {}",
224                        key
225                    );
226                }
227            }
228        }
229    }
230}
231
232pub fn trigger() -> (Sender, Receiver) {
233    let wakers = Arc::new(Mutex::new(Slab::new()));
234    let state = Arc::new(AtomicBool::new(false));
235
236    let sender = Sender::new(wakers.clone(), state.clone());
237    let receiver = Receiver::new(wakers, state);
238
239    (sender, receiver)
240}
241
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    /// A receiver resolves once the sender fires the trigger from another task.
    #[tokio::test]
    async fn test_sender_trigger() {
        let (tx, rx) = trigger();

        let trigger_task = tokio::spawn(async move {
            tx.trigger();
        });

        rx.await;

        trigger_task.await.unwrap();
    }

    /// A receiver whose trigger is never fired stays pending until the timeout elapses.
    #[tokio::test]
    async fn test_sender_never_trigger() {
        // The sender is discarded right away without ever being triggered.
        let (_, rx) = trigger();
        let wait = std::time::Duration::from_millis(100);
        let outcome = tokio::time::timeout(wait, rx).await;
        outcome.unwrap_err();
    }
}
267
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;

    use loom::{future::block_on, thread};

    /// Lets loom explore the interleavings of a thread firing the trigger
    /// while another one awaits the receiver.
    #[test]
    fn test_loom_sender_trigger() {
        loom::model(|| {
            let (tx, rx) = trigger();

            let trigger_thread = thread::spawn(move || {
                tx.trigger();
            });

            block_on(async move {
                rx.await;
            });

            trigger_thread.join().unwrap();
        });
    }
}