Skip to main content

datum/dynamic/
kill_switch.rs

1use std::sync::{
2    Arc, Mutex,
3    atomic::{AtomicU8, Ordering},
4};
5
6use crate::stream::{BoxStream, Flow};
7use crate::{StreamError, StreamResult};
8
9#[derive(Clone, Debug)]
10pub struct UniqueKillSwitch {
11    state: Arc<KillSwitchState>,
12}
13
14impl UniqueKillSwitch {
15    /// Completes downstream and cancels upstream on the next pull observed by this flow.
16    ///
17    /// In Datum's pull-based runtime the terminal signal is observed only when downstream issues
18    /// demand, so tests that expect immediate completion should request one element first.
19    pub fn shutdown(&self) {
20        self.state.shutdown();
21    }
22
23    /// Fails the stream on the next pull observed by this flow.
24    ///
25    /// In Datum's pull-based runtime the failure is observed only when downstream issues demand,
26    /// so tests that expect an immediate error should request one element first.
27    pub fn abort(&self, error: StreamError) {
28        self.state.abort(error);
29    }
30}
31
32#[derive(Clone, Debug)]
33pub struct SharedKillSwitch {
34    name: Arc<str>,
35    state: Arc<KillSwitchState>,
36}
37
38impl SharedKillSwitch {
39    fn new(name: impl Into<Arc<str>>) -> Self {
40        Self {
41            name: name.into(),
42            state: Arc::new(KillSwitchState::default()),
43        }
44    }
45
46    #[must_use]
47    pub fn name(&self) -> &str {
48        &self.name
49    }
50
51    /// Completes downstream and cancels upstream on the next pull observed by attached flows.
52    pub fn shutdown(&self) {
53        self.state.shutdown();
54    }
55
56    /// Fails attached flows on the next pull observed by each flow.
57    pub fn abort(&self, error: StreamError) {
58        self.state.abort(error);
59    }
60
61    #[must_use]
62    pub fn flow<T: Send + 'static>(&self) -> Flow<T, T, SharedKillSwitch> {
63        let state = Arc::clone(&self.state);
64        let switch = self.clone();
65        Flow::from_parts(
66            move |input| Box::new(KillSwitchStream::new(input, Arc::clone(&state))),
67            move || Ok(switch.clone()),
68        )
69    }
70}
71
72pub struct KillSwitches;
73
74impl KillSwitches {
75    #[must_use]
76    pub fn single<T: Send + 'static>() -> Flow<T, T, UniqueKillSwitch> {
77        Flow::from_materialized_factory(move || {
78            let state = Arc::new(KillSwitchState::default());
79            let switch = UniqueKillSwitch {
80                state: Arc::clone(&state),
81            };
82            let transform = Arc::new(move |input| {
83                Box::new(KillSwitchStream::new(input, Arc::clone(&state))) as BoxStream<T>
84            });
85            (transform, switch)
86        })
87    }
88
89    #[must_use]
90    pub fn shared(name: impl Into<Arc<str>>) -> SharedKillSwitch {
91        SharedKillSwitch::new(name)
92    }
93}
94
95struct KillSwitchStream<T> {
96    input: BoxStream<T>,
97    state: Arc<KillSwitchState>,
98    terminated: bool,
99}
100
/// Gate value while the switch is untriggered. Must stay `0`: `KillSwitchState` derives
/// `Default`, and a default `AtomicU8` holds `0`, so a fresh state starts out open.
const KILL_SWITCH_OPEN: u8 = 0;
/// Gate value once `shutdown` has triggered the switch.
const KILL_SWITCH_SHUTDOWN: u8 = 1;
/// Gate value once `abort` has triggered the switch; the error itself is kept behind the
/// status mutex.
const KILL_SWITCH_ABORTED: u8 = 2;
104
105impl<T> KillSwitchStream<T> {
106    fn new(input: BoxStream<T>, state: Arc<KillSwitchState>) -> Self {
107        Self {
108            input,
109            state,
110            terminated: false,
111        }
112    }
113}
114
115impl<T> Iterator for KillSwitchStream<T> {
116    type Item = StreamResult<T>;
117
118    fn next(&mut self) -> Option<Self::Item> {
119        if self.terminated {
120            return None;
121        }
122
123        match self.state.current() {
124            KillSwitchStatus::Open => {}
125            KillSwitchStatus::Shutdown => {
126                self.terminated = true;
127                return None;
128            }
129            KillSwitchStatus::Aborted(error) => {
130                self.terminated = true;
131                return Some(Err(error));
132            }
133        }
134
135        let next = self.input.next();
136        if next.is_none() {
137            self.terminated = true;
138        }
139        next
140    }
141}
142
143#[derive(Clone, Debug, Default)]
144enum KillSwitchStatus {
145    #[default]
146    Open,
147    Shutdown,
148    Aborted(StreamError),
149}
150
151#[derive(Debug, Default)]
152struct KillSwitchState {
153    gate: AtomicU8,
154    status: Mutex<KillSwitchStatus>,
155}
156
157impl KillSwitchState {
158    fn shutdown(&self) {
159        if self
160            .gate
161            .compare_exchange(
162                KILL_SWITCH_OPEN,
163                KILL_SWITCH_SHUTDOWN,
164                Ordering::AcqRel,
165                Ordering::Acquire,
166            )
167            .is_err()
168        {
169            return;
170        }
171        let mut status = self.status.lock().expect("kill switch poisoned");
172        *status = KillSwitchStatus::Shutdown;
173    }
174
175    fn abort(&self, error: StreamError) {
176        if self
177            .gate
178            .compare_exchange(
179                KILL_SWITCH_OPEN,
180                KILL_SWITCH_ABORTED,
181                Ordering::AcqRel,
182                Ordering::Acquire,
183            )
184            .is_err()
185        {
186            return;
187        }
188        let mut status = self.status.lock().expect("kill switch poisoned");
189        *status = KillSwitchStatus::Aborted(error);
190    }
191
192    fn current(&self) -> KillSwitchStatus {
193        match self.gate.load(Ordering::Acquire) {
194            KILL_SWITCH_OPEN => KillSwitchStatus::Open,
195            KILL_SWITCH_SHUTDOWN => KillSwitchStatus::Shutdown,
196            KILL_SWITCH_ABORTED => self.status.lock().expect("kill switch poisoned").clone(),
197            gate => panic!("unexpected kill switch gate state {gate}"),
198        }
199    }
200}
201
// Behavioral tests for the kill-switch flows, driven through the test source/sink probes.
// Because the runtime is pull-based, every test issues `sink.request(1)` after triggering a
// switch so that the trigger is actually observed.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Source};
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    // Shutdown after one delivered element completes the sink and cancels the source.
    #[test]
    fn unique_kill_switch_shutdown_completes_and_cancels() {
        let materializer = Materializer::new();
        let ((source, switch), sink) = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        sink.request(1);
        assert_eq!(source.expect_request(), 1);
        source.send_next(1);
        sink.assert_next(1);

        // A repeated shutdown must be harmless (only the first trigger takes effect).
        switch.shutdown();
        switch.shutdown();
        // The trigger is observed on the next pull, hence the extra request.
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    // An abort issued after a shutdown is ignored: the stream completes rather than failing.
    #[test]
    fn unique_kill_switch_abort_is_idempotent_after_shutdown() {
        let materializer = Materializer::new();
        let ((source, switch), sink) = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        switch.shutdown();
        switch.abort(StreamError::Failed("late abort".to_owned()));
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    // The switch is shut down inside `map_materialized_value`, before the sink requests
    // anything, so the very first pull already observes the shutdown.
    #[test]
    fn unique_kill_switch_pre_materialization_shutdown_completes_immediately() {
        let flow = KillSwitches::single::<i32>().map_materialized_value(|switch| {
            switch.shutdown();
            switch
        });
        let sink = Source::from_iter(1..=3)
            .via_mat(flow, Keep::right)
            .run_with(TestSink::probe())
            .expect("test sink materializes");

        sink.request(1);
        sink.expect_complete();
    }

    // One shared switch attached to two independent streams. The abort fires first, so the
    // later shutdown is ignored and both streams fail with the abort error.
    #[test]
    fn shared_kill_switch_fans_out_to_many_streams() {
        let switch = KillSwitches::shared("shared-switch");
        let materializer = Materializer::new();
        let make_stream = || TestSource::probe::<i32>().via_mat(switch.flow(), Keep::both);

        let ((source_a, shared_a), sink_a) = make_stream()
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("first stream materializes");
        let ((source_b, shared_b), sink_b) = make_stream()
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("second stream materializes");

        // Each materialized value is a clone of the same named handle.
        assert_eq!(shared_a.name(), "shared-switch");
        assert_eq!(shared_b.name(), "shared-switch");

        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(source_a.expect_request(), 1);
        assert_eq!(source_b.expect_request(), 1);

        source_a.send_next(1);
        source_b.send_next(2);
        sink_a.assert_next(1);
        sink_b.assert_next(2);

        switch.abort(StreamError::Failed("shared abort".to_owned()));
        switch.shutdown();
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(
            sink_a.expect_error(),
            StreamError::Failed("shared abort".to_owned())
        );
        assert_eq!(
            sink_b.expect_error(),
            StreamError::Failed("shared abort".to_owned())
        );
        source_a.expect_cancellation();
        source_b.expect_cancellation();
    }

    // The spawned thread and the main thread both call `shutdown` on the same switch; the
    // test passes if neither call panics and the thread joins.
    #[test]
    fn shared_kill_switch_is_thread_safe() {
        let switch = Arc::new(KillSwitches::shared("thread-safe"));
        let clone = Arc::clone(&switch);

        let handle = thread::spawn(move || {
            clone.shutdown();
        });
        switch.shutdown();
        handle.join().expect("kill switch thread joins");
    }

    // Eight threads materialize the same `single()` flow. Each materialization must get its
    // own switch: shutting down stream 3 completes only that stream, while the others keep
    // emitting their `Source::repeat` value.
    #[test]
    fn unique_kill_switch_materializations_stay_thread_local_and_independent() {
        let flow = KillSwitches::single::<usize>();
        let materializer = Arc::new(Materializer::new());
        // 8 worker threads + the main thread, so all workers start materializing together.
        let barrier = Arc::new(Barrier::new(9));

        let handles = (0..8)
            .map(|idx| {
                let flow = flow.clone();
                let materializer = Arc::clone(&materializer);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    Source::repeat(idx)
                        .via_mat(flow, Keep::right)
                        .to_mat(TestSink::probe(), Keep::both)
                        .run_with_materializer(materializer.as_ref())
                        .expect("kill switch flow materializes")
                })
            })
            .collect::<Vec<_>>();

        barrier.wait();

        // Each element is (switch, sink); the vector index matches the thread's `idx`.
        let mut streams = handles
            .into_iter()
            .map(|handle| handle.join().expect("materialization thread joins"))
            .collect::<Vec<_>>();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            sink.assert_next(idx);
        }

        streams[3].0.shutdown();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            if idx == 3 {
                sink.expect_complete();
            } else {
                sink.assert_next(idx);
            }
        }
    }
}