Skip to main content

atomr_streams/
kill_switch.rs

//! KillSwitch — external shutdown for a running stream.
//!
//! `KillSwitch::shutdown()` completes every attached source early. `abort(err)`
//! completes attached sources the same way and additionally records `err`, so
//! callers detect the failure by inspecting `KillSwitch::error()` rather than
//! through the stream itself.
6
7use std::sync::Arc;
8
9use futures::stream::StreamExt;
10use parking_lot::Mutex;
11use tokio::sync::Notify;
12
13use crate::source::Source;
14
15#[derive(Clone)]
16pub struct KillSwitch {
17    inner: Arc<KillSwitchInner>,
18}
19
20struct KillSwitchInner {
21    notify: Notify,
22    state: Mutex<KillState>,
23}
24
25#[derive(Default, Clone)]
26struct KillState {
27    killed: bool,
28    error: Option<String>,
29}
30
31impl Default for KillSwitch {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl KillSwitch {
38    pub fn new() -> Self {
39        Self {
40            inner: Arc::new(KillSwitchInner {
41                notify: Notify::new(),
42                state: Mutex::new(KillState::default()),
43            }),
44        }
45    }
46
47    /// Gracefully complete any sources attached via [`Self::flow`].
48    pub fn shutdown(&self) {
49        let mut s = self.inner.state.lock();
50        s.killed = true;
51        drop(s);
52        self.inner.notify.notify_waiters();
53    }
54
55    /// Abort attached sources with the given error message.
56    pub fn abort(&self, err: impl Into<String>) {
57        let mut s = self.inner.state.lock();
58        s.killed = true;
59        s.error = Some(err.into());
60        drop(s);
61        self.inner.notify.notify_waiters();
62    }
63
64    pub fn is_shut_down(&self) -> bool {
65        self.inner.state.lock().killed
66    }
67
68    pub fn error(&self) -> Option<String> {
69        self.inner.state.lock().error.clone()
70    }
71
72    /// Wrap a source so it completes when this switch fires.
73    pub fn flow<T: Send + 'static>(&self, source: Source<T>) -> Source<T> {
74        let inner = Arc::clone(&self.inner);
75        let s = futures::stream::unfold((source.into_boxed(), inner), |(mut s, inner)| async move {
76            if inner.state.lock().killed {
77                return None;
78            }
79            let next = {
80                let notified = inner.notify.notified();
81                tokio::pin!(notified);
82                tokio::select! {
83                    biased;
84                    _ = &mut notified => None,
85                    item = s.next() => item,
86                }
87            };
88            next.map(|v| (v, (s, inner)))
89        })
90        .boxed();
91        Source { inner: s }
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a gated stream to finish, so a
    /// stream that fails to complete turns into a test failure instead of a hang.
    const COMPLETION_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn kill_switch_completes_long_running_source() {
        let ks = KillSwitch::new();
        let src = Source::tick(Duration::from_millis(1), Duration::from_millis(1), 1_u32);
        let gated = ks.flow(src);
        let handle = tokio::spawn(async move { Sink::collect(gated).await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        ks.shutdown();
        let out = tokio::time::timeout(COMPLETION_TIMEOUT, handle)
            .await
            .expect("stream should complete after shutdown")
            .unwrap();
        assert!(out.len() < 10_000, "stream should complete after shutdown");
    }

    #[tokio::test]
    async fn kill_switch_completes_source_that_never_yields() {
        let ks = KillSwitch::new();
        // An upstream that never produces an element: only the switch's
        // notification can end the gated stream.
        let src = Source {
            inner: futures::stream::pending::<u32>().boxed(),
        };
        let gated = ks.flow(src);
        let handle = tokio::spawn(async move { Sink::collect(gated).await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        ks.shutdown();
        let out = tokio::time::timeout(COMPLETION_TIMEOUT, handle)
            .await
            .expect("pending source should complete after shutdown")
            .unwrap();
        assert_eq!(out.len(), 0);
    }

    #[tokio::test]
    async fn flow_on_already_fired_switch_yields_nothing() {
        let ks = KillSwitch::new();
        ks.shutdown();
        let src = Source::tick(Duration::from_millis(1), Duration::from_millis(1), 1_u32);
        let gated = ks.flow(src);
        let out = tokio::time::timeout(COMPLETION_TIMEOUT, Sink::collect(gated))
            .await
            .expect("source attached to a fired switch should complete immediately");
        assert_eq!(out.len(), 0);
    }

    #[tokio::test]
    async fn abort_latches_error_message() {
        let ks = KillSwitch::new();
        ks.abort("boom");
        assert_eq!(ks.error().as_deref(), Some("boom"));
        assert!(ks.is_shut_down());
    }

    #[tokio::test]
    async fn shutdown_does_not_record_error() {
        let ks = KillSwitch::new();
        ks.shutdown();
        assert!(ks.is_shut_down());
        assert_eq!(ks.error(), None);
    }
}