1use std::sync::{
2 Arc, Mutex,
3 atomic::{AtomicU8, Ordering},
4};
5
6use crate::stream::{BoxStream, Flow};
7use crate::{StreamError, StreamResult};
8
9#[derive(Clone, Debug)]
14pub struct UniqueKillSwitch {
15 state: Arc<KillSwitchState>,
16}
17
18impl UniqueKillSwitch {
19 pub fn shutdown(&self) {
24 self.state.shutdown();
25 }
26
27 pub fn abort(&self, error: StreamError) {
32 self.state.abort(error);
33 }
34}
35
36#[derive(Clone, Debug)]
41pub struct SharedKillSwitch {
42 name: Arc<str>,
43 state: Arc<KillSwitchState>,
44}
45
46impl SharedKillSwitch {
47 fn new(name: impl Into<Arc<str>>) -> Self {
48 Self {
49 name: name.into(),
50 state: Arc::new(KillSwitchState::default()),
51 }
52 }
53
54 #[must_use]
56 pub fn name(&self) -> &str {
57 &self.name
58 }
59
60 pub fn shutdown(&self) {
62 self.state.shutdown();
63 }
64
65 pub fn abort(&self, error: StreamError) {
67 self.state.abort(error);
68 }
69
70 #[must_use]
73 pub fn flow<T: Send + 'static>(&self) -> Flow<T, T, SharedKillSwitch> {
74 let state = Arc::clone(&self.state);
75 let switch = self.clone();
76 Flow::from_parts(
77 move |input| Box::new(KillSwitchStream::new(input, Arc::clone(&state))),
78 move || Ok(switch.clone()),
79 )
80 }
81}
82
83pub struct KillSwitches;
85
86impl KillSwitches {
87 #[must_use]
89 pub fn single<T: Send + 'static>() -> Flow<T, T, UniqueKillSwitch> {
90 Flow::from_materialized_factory(move || {
91 let state = Arc::new(KillSwitchState::default());
92 let switch = UniqueKillSwitch {
93 state: Arc::clone(&state),
94 };
95 let transform = Arc::new(move |input| {
96 Box::new(KillSwitchStream::new(input, Arc::clone(&state))) as BoxStream<T>
97 });
98 (transform, switch)
99 })
100 }
101
102 #[must_use]
105 pub fn shared(name: impl Into<Arc<str>>) -> SharedKillSwitch {
106 SharedKillSwitch::new(name)
107 }
108}
109
110struct KillSwitchStream<T> {
111 input: BoxStream<T>,
112 state: Arc<KillSwitchState>,
113 terminated: bool,
114}
115
116const KILL_SWITCH_OPEN: u8 = 0;
117const KILL_SWITCH_SHUTDOWN: u8 = 1;
118const KILL_SWITCH_ABORTED: u8 = 2;
119
120impl<T> KillSwitchStream<T> {
121 fn new(input: BoxStream<T>, state: Arc<KillSwitchState>) -> Self {
122 Self {
123 input,
124 state,
125 terminated: false,
126 }
127 }
128}
129
130impl<T> Iterator for KillSwitchStream<T> {
131 type Item = StreamResult<T>;
132
133 fn next(&mut self) -> Option<Self::Item> {
134 if self.terminated {
135 return None;
136 }
137
138 match self.state.current() {
139 KillSwitchStatus::Open => {}
140 KillSwitchStatus::Shutdown => {
141 self.terminated = true;
142 return None;
143 }
144 KillSwitchStatus::Aborted(error) => {
145 self.terminated = true;
146 return Some(Err(error));
147 }
148 }
149
150 let next = self.input.next();
151 if next.is_none() {
152 self.terminated = true;
153 }
154 next
155 }
156}
157
/// Snapshot of a kill switch, as returned by `KillSwitchState::current`.
#[derive(Clone, Debug, Default)]
enum KillSwitchStatus {
    /// Not tripped; guarded streams keep pulling from upstream.
    #[default]
    Open,
    /// Tripped by `shutdown`; a guarded stream returns `None` on its next pull.
    Shutdown,
    /// Tripped by `abort`; a guarded stream yields this error once and then ends.
    Aborted(StreamError),
}
165
166#[derive(Debug, Default)]
167struct KillSwitchState {
168 gate: AtomicU8,
169 status: Mutex<KillSwitchStatus>,
170}
171
172impl KillSwitchState {
173 fn shutdown(&self) {
174 if self
175 .gate
176 .compare_exchange(
177 KILL_SWITCH_OPEN,
178 KILL_SWITCH_SHUTDOWN,
179 Ordering::AcqRel,
180 Ordering::Acquire,
181 )
182 .is_err()
183 {
184 return;
185 }
186 let mut status = self.status.lock().expect("kill switch poisoned");
187 *status = KillSwitchStatus::Shutdown;
188 }
189
190 fn abort(&self, error: StreamError) {
191 if self
192 .gate
193 .compare_exchange(
194 KILL_SWITCH_OPEN,
195 KILL_SWITCH_ABORTED,
196 Ordering::AcqRel,
197 Ordering::Acquire,
198 )
199 .is_err()
200 {
201 return;
202 }
203 let mut status = self.status.lock().expect("kill switch poisoned");
204 *status = KillSwitchStatus::Aborted(error);
205 }
206
207 fn current(&self) -> KillSwitchStatus {
208 match self.gate.load(Ordering::Acquire) {
209 KILL_SWITCH_OPEN => KillSwitchStatus::Open,
210 KILL_SWITCH_SHUTDOWN => KillSwitchStatus::Shutdown,
211 KILL_SWITCH_ABORTED => self.status.lock().expect("kill switch poisoned").clone(),
212 gate => panic!("unexpected kill switch gate state {gate}"),
213 }
214 }
215}
216
// Behavioral tests for unique and shared kill switches, driven through the crate's
// `TestSource` / `TestSink` probes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Source};
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    // After one element has flowed, shutdown completes the downstream and the upstream
    // probe observes cancellation. The second `shutdown` call must be a harmless no-op.
    #[test]
    fn unique_kill_switch_shutdown_completes_and_cancels() {
        let materializer = Materializer::new();
        let ((source, switch), sink) = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        sink.request(1);
        assert_eq!(source.expect_request(), 1);
        source.send_next(1);
        sink.assert_next(1);

        switch.shutdown();
        switch.shutdown();
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    // Once the switch is shut down, a later abort is ignored: the stream completes rather
    // than failing.
    #[test]
    fn unique_kill_switch_abort_is_idempotent_after_shutdown() {
        let materializer = Materializer::new();
        let ((source, switch), sink) = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        switch.shutdown();
        switch.abort(StreamError::Failed("late abort".to_owned()));
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    // The switch is tripped inside `map_materialized_value`, presumably before any element
    // is pulled. The very first request should therefore observe completion instead of `1`.
    #[test]
    fn unique_kill_switch_pre_materialization_shutdown_completes_immediately() {
        let flow = KillSwitches::single::<i32>().map_materialized_value(|switch| {
            switch.shutdown();
            switch
        });
        let sink = Source::from_iter(1..=3)
            .via_mat(flow, Keep::right)
            .run_with(TestSink::probe())
            .expect("test sink materializes");

        sink.request(1);
        sink.expect_complete();
    }

    // One shared switch guards two independently materialized streams. Each materialized
    // value reports the shared name. The first trip (abort) wins over the later shutdown, and
    // both streams fail with the same error.
    #[test]
    fn shared_kill_switch_fans_out_to_many_streams() {
        let switch = KillSwitches::shared("shared-switch");
        let materializer = Materializer::new();
        let make_stream = || TestSource::probe::<i32>().via_mat(switch.flow(), Keep::both);

        let ((source_a, shared_a), sink_a) = make_stream()
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("first stream materializes");
        let ((source_b, shared_b), sink_b) = make_stream()
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("second stream materializes");

        assert_eq!(shared_a.name(), "shared-switch");
        assert_eq!(shared_b.name(), "shared-switch");

        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(source_a.expect_request(), 1);
        assert_eq!(source_b.expect_request(), 1);

        source_a.send_next(1);
        source_b.send_next(2);
        sink_a.assert_next(1);
        sink_b.assert_next(2);

        switch.abort(StreamError::Failed("shared abort".to_owned()));
        switch.shutdown();
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(
            sink_a.expect_error(),
            StreamError::Failed("shared abort".to_owned())
        );
        assert_eq!(
            sink_b.expect_error(),
            StreamError::Failed("shared abort".to_owned())
        );
        source_a.expect_cancellation();
        source_b.expect_cancellation();
    }

    // Shutting down the same shared switch from two threads must neither panic nor deadlock.
    #[test]
    fn shared_kill_switch_is_thread_safe() {
        let switch = Arc::new(KillSwitches::shared("thread-safe"));
        let clone = Arc::clone(&switch);

        let handle = thread::spawn(move || {
            clone.shutdown();
        });
        switch.shutdown();
        handle.join().expect("kill switch thread joins");
    }

    // Eight threads materialize clones of one `single` flow at the same moment; the barrier
    // of 9 covers the 8 workers plus this thread. Shutting down stream 3's switch must
    // complete only that stream while the others keep emitting.
    #[test]
    fn unique_kill_switch_materializations_stay_thread_local_and_independent() {
        let flow = KillSwitches::single::<usize>();
        let materializer = Arc::new(Materializer::new());
        let barrier = Arc::new(Barrier::new(9));

        let handles = (0..8)
            .map(|idx| {
                let flow = flow.clone();
                let materializer = Arc::clone(&materializer);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    Source::repeat(idx)
                        .via_mat(flow, Keep::right)
                        .to_mat(TestSink::probe(), Keep::both)
                        .run_with_materializer(materializer.as_ref())
                        .expect("kill switch flow materializes")
                })
            })
            .collect::<Vec<_>>();

        barrier.wait();

        let mut streams = handles
            .into_iter()
            .map(|handle| handle.join().expect("materialization thread joins"))
            .collect::<Vec<_>>();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            sink.assert_next(idx);
        }

        streams[3].0.shutdown();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            if idx == 3 {
                sink.expect_complete();
            } else {
                sink.assert_next(idx);
            }
        }
    }
}