Skip to main content

datum/dynamic/
kill_switch.rs

1use std::sync::{
2    Arc, Mutex,
3    atomic::{AtomicU8, Ordering},
4};
5
6use crate::stream::{BoxStream, Flow};
7use crate::{StreamError, StreamResult};
8
9/// An exclusive shutdown/abort handle for one materialized stream.
10///
11/// Obtained as the materialized value of [`KillSwitches::single`]. The terminal signal is observed
12/// on the next downstream pull (Datum's runtime is pull-based).
13#[derive(Clone, Debug)]
14pub struct UniqueKillSwitch {
15    state: Arc<KillSwitchState>,
16}
17
18impl UniqueKillSwitch {
19    /// Completes downstream and cancels upstream on the next pull observed by this flow.
20    ///
21    /// In Datum's pull-based runtime the terminal signal is observed only when downstream issues
22    /// demand, so tests that expect immediate completion should request one element first.
23    pub fn shutdown(&self) {
24        self.state.shutdown();
25    }
26
27    /// Fails the stream on the next pull observed by this flow.
28    ///
29    /// In Datum's pull-based runtime the failure is observed only when downstream issues demand,
30    /// so tests that expect an immediate error should request one element first.
31    pub fn abort(&self, error: StreamError) {
32        self.state.abort(error);
33    }
34}
35
36/// A named, cloneable kill switch shared across any number of streams.
37///
38/// Created with [`KillSwitches::shared`]; wire it into a stream with [`SharedKillSwitch::flow`].
39/// A single `shutdown()`/`abort()` propagates to every attached flow on each flow's next pull.
40#[derive(Clone, Debug)]
41pub struct SharedKillSwitch {
42    name: Arc<str>,
43    state: Arc<KillSwitchState>,
44}
45
46impl SharedKillSwitch {
47    fn new(name: impl Into<Arc<str>>) -> Self {
48        Self {
49            name: name.into(),
50            state: Arc::new(KillSwitchState::default()),
51        }
52    }
53
54    /// The name given to this switch at construction.
55    #[must_use]
56    pub fn name(&self) -> &str {
57        &self.name
58    }
59
60    /// Completes downstream and cancels upstream on the next pull observed by attached flows.
61    pub fn shutdown(&self) {
62        self.state.shutdown();
63    }
64
65    /// Fails attached flows on the next pull observed by each flow.
66    pub fn abort(&self, error: StreamError) {
67        self.state.abort(error);
68    }
69
70    /// Wire this shared switch into a stream; the returned flow's materialized value is a clone of
71    /// this switch. Every flow built from the same `SharedKillSwitch` shares one on/off state.
72    #[must_use]
73    pub fn flow<T: Send + 'static>(&self) -> Flow<T, T, SharedKillSwitch> {
74        let state = Arc::clone(&self.state);
75        let switch = self.clone();
76        Flow::from_parts(
77            move |input| Box::new(KillSwitchStream::new(input, Arc::clone(&state))),
78            move || Ok(switch.clone()),
79        )
80    }
81}
82
83/// Factory for the [`UniqueKillSwitch`] and [`SharedKillSwitch`] stream controls.
84pub struct KillSwitches;
85
86impl KillSwitches {
87    /// A `Flow` whose materialized value is a [`UniqueKillSwitch`] controlling that one stream.
88    #[must_use]
89    pub fn single<T: Send + 'static>() -> Flow<T, T, UniqueKillSwitch> {
90        Flow::from_materialized_factory(move || {
91            let state = Arc::new(KillSwitchState::default());
92            let switch = UniqueKillSwitch {
93                state: Arc::clone(&state),
94            };
95            let transform = Arc::new(move |input| {
96                Box::new(KillSwitchStream::new(input, Arc::clone(&state))) as BoxStream<T>
97            });
98            (transform, switch)
99        })
100    }
101
102    /// A reusable [`SharedKillSwitch`] that can be wired into many streams via
103    /// [`SharedKillSwitch::flow`].
104    #[must_use]
105    pub fn shared(name: impl Into<Arc<str>>) -> SharedKillSwitch {
106        SharedKillSwitch::new(name)
107    }
108}
109
110struct KillSwitchStream<T> {
111    input: BoxStream<T>,
112    state: Arc<KillSwitchState>,
113    terminated: bool,
114}
115
116const KILL_SWITCH_OPEN: u8 = 0;
117const KILL_SWITCH_SHUTDOWN: u8 = 1;
118const KILL_SWITCH_ABORTED: u8 = 2;
119
120impl<T> KillSwitchStream<T> {
121    fn new(input: BoxStream<T>, state: Arc<KillSwitchState>) -> Self {
122        Self {
123            input,
124            state,
125            terminated: false,
126        }
127    }
128}
129
130impl<T> Iterator for KillSwitchStream<T> {
131    type Item = StreamResult<T>;
132
133    fn next(&mut self) -> Option<Self::Item> {
134        if self.terminated {
135            return None;
136        }
137
138        match self.state.current() {
139            KillSwitchStatus::Open => {}
140            KillSwitchStatus::Shutdown => {
141                self.terminated = true;
142                return None;
143            }
144            KillSwitchStatus::Aborted(error) => {
145                self.terminated = true;
146                return Some(Err(error));
147            }
148        }
149
150        let next = self.input.next();
151        if next.is_none() {
152            self.terminated = true;
153        }
154        next
155    }
156}
157
158#[derive(Clone, Debug, Default)]
159enum KillSwitchStatus {
160    #[default]
161    Open,
162    Shutdown,
163    Aborted(StreamError),
164}
165
166#[derive(Debug, Default)]
167struct KillSwitchState {
168    gate: AtomicU8,
169    status: Mutex<KillSwitchStatus>,
170}
171
172impl KillSwitchState {
173    fn shutdown(&self) {
174        if self
175            .gate
176            .compare_exchange(
177                KILL_SWITCH_OPEN,
178                KILL_SWITCH_SHUTDOWN,
179                Ordering::AcqRel,
180                Ordering::Acquire,
181            )
182            .is_err()
183        {
184            return;
185        }
186        let mut status = self.status.lock().expect("kill switch poisoned");
187        *status = KillSwitchStatus::Shutdown;
188    }
189
190    fn abort(&self, error: StreamError) {
191        if self
192            .gate
193            .compare_exchange(
194                KILL_SWITCH_OPEN,
195                KILL_SWITCH_ABORTED,
196                Ordering::AcqRel,
197                Ordering::Acquire,
198            )
199            .is_err()
200        {
201            return;
202        }
203        let mut status = self.status.lock().expect("kill switch poisoned");
204        *status = KillSwitchStatus::Aborted(error);
205    }
206
207    fn current(&self) -> KillSwitchStatus {
208        match self.gate.load(Ordering::Acquire) {
209            KILL_SWITCH_OPEN => KillSwitchStatus::Open,
210            KILL_SWITCH_SHUTDOWN => KillSwitchStatus::Shutdown,
211            KILL_SWITCH_ABORTED => self.status.lock().expect("kill switch poisoned").clone(),
212            gate => panic!("unexpected kill switch gate state {gate}"),
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Source};
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    #[test]
    fn unique_kill_switch_shutdown_completes_and_cancels() {
        let materializer = Materializer::new();
        let ((source, switch), sink) = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        sink.request(1);
        assert_eq!(source.expect_request(), 1);
        source.send_next(1);
        sink.assert_next(1);

        // A repeated shutdown must be a harmless no-op.
        switch.shutdown();
        switch.shutdown();
        // Pull-based runtime: the signal is only observed once downstream requests again.
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    #[test]
    fn unique_kill_switch_abort_is_idempotent_after_shutdown() {
        let materializer = Materializer::new();
        let ((source, switch), sink) = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        // The first terminal signal wins; the late abort must not turn completion into failure.
        switch.shutdown();
        switch.abort(StreamError::Failed("late abort".to_owned()));
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    #[test]
    fn unique_kill_switch_pre_materialization_shutdown_completes_immediately() {
        // The switch is tripped while the graph is still materializing, before any element flows.
        let flow = KillSwitches::single::<i32>().map_materialized_value(|switch| {
            switch.shutdown();
            switch
        });
        let sink = Source::from_iter(1..=3)
            .via_mat(flow, Keep::right)
            .run_with(TestSink::probe())
            .expect("test sink materializes");

        sink.request(1);
        sink.expect_complete();
    }

    #[test]
    fn shared_kill_switch_fans_out_to_many_streams() {
        let switch = KillSwitches::shared("shared-switch");
        let materializer = Materializer::new();
        let make_stream = || TestSource::probe::<i32>().via_mat(switch.flow(), Keep::both);

        let ((source_a, shared_a), sink_a) = make_stream()
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("first stream materializes");
        let ((source_b, shared_b), sink_b) = make_stream()
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("second stream materializes");

        assert_eq!(shared_a.name(), "shared-switch");
        assert_eq!(shared_b.name(), "shared-switch");

        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(source_a.expect_request(), 1);
        assert_eq!(source_b.expect_request(), 1);

        source_a.send_next(1);
        source_b.send_next(2);
        sink_a.assert_next(1);
        sink_b.assert_next(2);

        // Abort wins over the later shutdown, and both attached streams observe it.
        switch.abort(StreamError::Failed("shared abort".to_owned()));
        switch.shutdown();
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(
            sink_a.expect_error(),
            StreamError::Failed("shared abort".to_owned())
        );
        assert_eq!(
            sink_b.expect_error(),
            StreamError::Failed("shared abort".to_owned())
        );
        source_a.expect_cancellation();
        source_b.expect_cancellation();
    }

    #[test]
    fn shared_kill_switch_is_thread_safe() {
        let switch = Arc::new(KillSwitches::shared("thread-safe"));
        let clone = Arc::clone(&switch);

        let handle = thread::spawn(move || {
            clone.shutdown();
        });
        switch.shutdown();
        handle.join().expect("kill switch thread joins");

        // The racing shutdowns must leave the shared state terminated. A flow attached afterwards
        // completes on its first pull instead of emitting the upstream's first element.
        let sink = Source::from_iter(1..=3)
            .via_mat(switch.flow(), Keep::right)
            .run_with(TestSink::probe())
            .expect("test sink materializes");
        sink.request(1);
        sink.expect_complete();
    }

    #[test]
    fn unique_kill_switch_materializations_stay_thread_local_and_independent() {
        let flow = KillSwitches::single::<usize>();
        let materializer = Arc::new(Materializer::new());
        // Eight materializing threads plus this one start together to maximise overlap.
        let barrier = Arc::new(Barrier::new(9));

        let handles = (0..8)
            .map(|idx| {
                let flow = flow.clone();
                let materializer = Arc::clone(&materializer);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    Source::repeat(idx)
                        .via_mat(flow, Keep::right)
                        .to_mat(TestSink::probe(), Keep::both)
                        .run_with_materializer(materializer.as_ref())
                        .expect("kill switch flow materializes")
                })
            })
            .collect::<Vec<_>>();

        barrier.wait();

        let mut streams = handles
            .into_iter()
            .map(|handle| handle.join().expect("materialization thread joins"))
            .collect::<Vec<_>>();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            sink.assert_next(idx);
        }

        // Shutting down one materialization must not affect any of the others.
        streams[3].0.shutdown();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            if idx == 3 {
                sink.expect_complete();
            } else {
                sink.assert_next(idx);
            }
        }
    }
}
380}