atomr_streams/
kill_switch.rs1use std::sync::Arc;
8
9use futures::stream::StreamExt;
10use parking_lot::Mutex;
11use tokio::sync::Notify;
12
13use crate::source::Source;
14
15#[derive(Clone)]
16pub struct KillSwitch {
17 inner: Arc<KillSwitchInner>,
18}
19
20struct KillSwitchInner {
21 notify: Notify,
22 state: Mutex<KillState>,
23}
24
/// Mutable part of a kill switch; `Default` means "not tripped, no error".
#[derive(Clone, Default)]
struct KillState {
    /// Set to `true` by both `shutdown` and `abort`; nothing in this module resets it.
    killed: bool,
    /// Message passed to the most recent `abort`, if any.
    error: Option<String>,
}
30
31impl Default for KillSwitch {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl KillSwitch {
38 pub fn new() -> Self {
39 Self {
40 inner: Arc::new(KillSwitchInner {
41 notify: Notify::new(),
42 state: Mutex::new(KillState::default()),
43 }),
44 }
45 }
46
47 pub fn shutdown(&self) {
49 let mut s = self.inner.state.lock();
50 s.killed = true;
51 drop(s);
52 self.inner.notify.notify_waiters();
53 }
54
55 pub fn abort(&self, err: impl Into<String>) {
57 let mut s = self.inner.state.lock();
58 s.killed = true;
59 s.error = Some(err.into());
60 drop(s);
61 self.inner.notify.notify_waiters();
62 }
63
64 pub fn is_shut_down(&self) -> bool {
65 self.inner.state.lock().killed
66 }
67
68 pub fn error(&self) -> Option<String> {
69 self.inner.state.lock().error.clone()
70 }
71
72 pub fn flow<T: Send + 'static>(&self, source: Source<T>) -> Source<T> {
74 let inner = Arc::clone(&self.inner);
75 let s = futures::stream::unfold((source.into_boxed(), inner), |(mut s, inner)| async move {
76 if inner.state.lock().killed {
77 return None;
78 }
79 let next = {
80 let notified = inner.notify.notified();
81 tokio::pin!(notified);
82 tokio::select! {
83 biased;
84 _ = &mut notified => None,
85 item = s.next() => item,
86 }
87 };
88 next.map(|v| (v, (s, inner)))
89 })
90 .boxed();
91 Source { inner: s }
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    /// Upper bound on how long a gated stream may take to finish once the switch has
    /// tripped; a regression then fails the test instead of hanging the test run.
    const COMPLETION_TIMEOUT: Duration = Duration::from_secs(5);

    /// An effectively endless source emitting `1` every millisecond.
    fn ticker() -> Source<u32> {
        Source::tick(Duration::from_millis(1), Duration::from_millis(1), 1_u32)
    }

    #[tokio::test]
    async fn kill_switch_completes_long_running_source() {
        let ks = KillSwitch::new();
        let gated = ks.flow(ticker());
        let handle = tokio::spawn(async move { Sink::collect(gated).await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        ks.shutdown();
        let out = tokio::time::timeout(COMPLETION_TIMEOUT, handle)
            .await
            .expect("stream should complete after shutdown")
            .unwrap();
        assert!(out.len() < 10_000, "stream should complete after shutdown");
    }

    #[tokio::test]
    async fn flow_on_tripped_switch_is_empty() {
        let ks = KillSwitch::new();
        ks.shutdown();
        let out = tokio::time::timeout(COMPLETION_TIMEOUT, Sink::collect(ks.flow(ticker())))
            .await
            .expect("gated stream on a tripped switch should end immediately");
        assert_eq!(out.len(), 0);
    }

    #[tokio::test]
    async fn abort_latches_error_message() {
        let ks = KillSwitch::new();
        ks.abort("boom");
        assert_eq!(ks.error().as_deref(), Some("boom"));
        assert!(ks.is_shut_down());
    }

    #[test]
    fn shutdown_records_no_error() {
        let ks = KillSwitch::new();
        assert!(!ks.is_shut_down());
        ks.shutdown();
        assert!(ks.is_shut_down());
        assert_eq!(ks.error(), None);
    }

    #[test]
    fn clones_share_one_switch() {
        let ks = KillSwitch::new();
        let remote = ks.clone();
        remote.abort("boom");
        assert!(ks.is_shut_down());
        assert_eq!(ks.error().as_deref(), Some("boom"));
    }
}