1use std::sync::{
2 Arc, Mutex,
3 atomic::{AtomicU8, Ordering},
4};
5
6use crate::stream::{BoxStream, Flow};
7use crate::{StreamError, StreamResult};
8
9#[derive(Clone, Debug)]
10pub struct UniqueKillSwitch {
11 state: Arc<KillSwitchState>,
12}
13
14impl UniqueKillSwitch {
15 pub fn shutdown(&self) {
20 self.state.shutdown();
21 }
22
23 pub fn abort(&self, error: StreamError) {
28 self.state.abort(error);
29 }
30}
31
32#[derive(Clone, Debug)]
33pub struct SharedKillSwitch {
34 name: Arc<str>,
35 state: Arc<KillSwitchState>,
36}
37
38impl SharedKillSwitch {
39 fn new(name: impl Into<Arc<str>>) -> Self {
40 Self {
41 name: name.into(),
42 state: Arc::new(KillSwitchState::default()),
43 }
44 }
45
46 #[must_use]
47 pub fn name(&self) -> &str {
48 &self.name
49 }
50
51 pub fn shutdown(&self) {
53 self.state.shutdown();
54 }
55
56 pub fn abort(&self, error: StreamError) {
58 self.state.abort(error);
59 }
60
61 #[must_use]
62 pub fn flow<T: Send + 'static>(&self) -> Flow<T, T, SharedKillSwitch> {
63 let state = Arc::clone(&self.state);
64 let switch = self.clone();
65 Flow::from_parts(
66 move |input| Box::new(KillSwitchStream::new(input, Arc::clone(&state))),
67 move || Ok(switch.clone()),
68 )
69 }
70}
71
/// Entry point for creating kill switches: see [`KillSwitches::single`] and
/// [`KillSwitches::shared`].
pub struct KillSwitches;
73
74impl KillSwitches {
75 #[must_use]
76 pub fn single<T: Send + 'static>() -> Flow<T, T, UniqueKillSwitch> {
77 Flow::from_materialized_factory(move || {
78 let state = Arc::new(KillSwitchState::default());
79 let switch = UniqueKillSwitch {
80 state: Arc::clone(&state),
81 };
82 let transform = Arc::new(move |input| {
83 Box::new(KillSwitchStream::new(input, Arc::clone(&state))) as BoxStream<T>
84 });
85 (transform, switch)
86 })
87 }
88
89 #[must_use]
90 pub fn shared(name: impl Into<Arc<str>>) -> SharedKillSwitch {
91 SharedKillSwitch::new(name)
92 }
93}
94
95struct KillSwitchStream<T> {
96 input: BoxStream<T>,
97 state: Arc<KillSwitchState>,
98 terminated: bool,
99}
100
/// `KillSwitchState::gate` value while the switch is untriggered; equals the
/// `AtomicU8` default, which is what `KillSwitchState::default()` relies on.
const KILL_SWITCH_OPEN: u8 = 0;
/// Gate value after a successful shutdown.
const KILL_SWITCH_SHUTDOWN: u8 = 1;
/// Gate value after a successful abort; the error itself lives in `KillSwitchState::status`.
const KILL_SWITCH_ABORTED: u8 = 2;
104
105impl<T> KillSwitchStream<T> {
106 fn new(input: BoxStream<T>, state: Arc<KillSwitchState>) -> Self {
107 Self {
108 input,
109 state,
110 terminated: false,
111 }
112 }
113}
114
115impl<T> Iterator for KillSwitchStream<T> {
116 type Item = StreamResult<T>;
117
118 fn next(&mut self) -> Option<Self::Item> {
119 if self.terminated {
120 return None;
121 }
122
123 match self.state.current() {
124 KillSwitchStatus::Open => {}
125 KillSwitchStatus::Shutdown => {
126 self.terminated = true;
127 return None;
128 }
129 KillSwitchStatus::Aborted(error) => {
130 self.terminated = true;
131 return Some(Err(error));
132 }
133 }
134
135 let next = self.input.next();
136 if next.is_none() {
137 self.terminated = true;
138 }
139 next
140 }
141}
142
143#[derive(Clone, Debug, Default)]
144enum KillSwitchStatus {
145 #[default]
146 Open,
147 Shutdown,
148 Aborted(StreamError),
149}
150
151#[derive(Debug, Default)]
152struct KillSwitchState {
153 gate: AtomicU8,
154 status: Mutex<KillSwitchStatus>,
155}
156
157impl KillSwitchState {
158 fn shutdown(&self) {
159 if self
160 .gate
161 .compare_exchange(
162 KILL_SWITCH_OPEN,
163 KILL_SWITCH_SHUTDOWN,
164 Ordering::AcqRel,
165 Ordering::Acquire,
166 )
167 .is_err()
168 {
169 return;
170 }
171 let mut status = self.status.lock().expect("kill switch poisoned");
172 *status = KillSwitchStatus::Shutdown;
173 }
174
175 fn abort(&self, error: StreamError) {
176 if self
177 .gate
178 .compare_exchange(
179 KILL_SWITCH_OPEN,
180 KILL_SWITCH_ABORTED,
181 Ordering::AcqRel,
182 Ordering::Acquire,
183 )
184 .is_err()
185 {
186 return;
187 }
188 let mut status = self.status.lock().expect("kill switch poisoned");
189 *status = KillSwitchStatus::Aborted(error);
190 }
191
192 fn current(&self) -> KillSwitchStatus {
193 match self.gate.load(Ordering::Acquire) {
194 KILL_SWITCH_OPEN => KillSwitchStatus::Open,
195 KILL_SWITCH_SHUTDOWN => KillSwitchStatus::Shutdown,
196 KILL_SWITCH_ABORTED => self.status.lock().expect("kill switch poisoned").clone(),
197 gate => panic!("unexpected kill switch gate state {gate}"),
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Source};
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    fn unique_kill_switch_shutdown_completes_and_cancels() {
        let materializer = Materializer::new();
        let graph = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both);
        let ((source, switch), sink) = graph
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        // One element passes while the switch is still open.
        sink.request(1);
        assert_eq!(source.expect_request(), 1);
        source.send_next(1);
        sink.assert_next(1);

        // A repeated shutdown is harmless; the stream completes and upstream is cancelled.
        for _ in 0..2 {
            switch.shutdown();
        }
        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    #[test]
    fn unique_kill_switch_abort_is_idempotent_after_shutdown() {
        let materializer = Materializer::new();
        let graph = TestSource::probe::<i32>()
            .via_mat(KillSwitches::single(), Keep::both)
            .to_mat(TestSink::probe(), Keep::both);
        let ((source, switch), sink) = graph
            .run_with_materializer(&materializer)
            .expect("graph materializes");

        // The first trigger wins: a later abort must not turn completion into failure.
        switch.shutdown();
        let late = StreamError::Failed("late abort".to_owned());
        switch.abort(late);

        sink.request(1);
        sink.expect_complete();
        source.expect_cancellation();
    }

    #[test]
    fn unique_kill_switch_pre_materialization_shutdown_completes_immediately() {
        // Shut the switch down from inside the materialized-value mapping,
        // before the test requests any element.
        let flow = KillSwitches::single::<i32>().map_materialized_value(|switch| {
            switch.shutdown();
            switch
        });
        let sink = Source::from_iter(1..=3)
            .via_mat(flow, Keep::right)
            .run_with(TestSink::probe())
            .expect("test sink materializes");

        sink.request(1);
        sink.expect_complete();
    }

    #[test]
    fn shared_kill_switch_fans_out_to_many_streams() {
        let switch = KillSwitches::shared("shared-switch");
        let materializer = Materializer::new();
        let run_stream = |label: &str| {
            TestSource::probe::<i32>()
                .via_mat(switch.flow(), Keep::both)
                .to_mat(TestSink::probe(), Keep::both)
                .run_with_materializer(&materializer)
                .expect(label)
        };

        let ((source_a, shared_a), sink_a) = run_stream("first stream materializes");
        let ((source_b, shared_b), sink_b) = run_stream("second stream materializes");

        for materialized in [&shared_a, &shared_b] {
            assert_eq!(materialized.name(), "shared-switch");
        }

        for sink in [&sink_a, &sink_b] {
            sink.request(1);
        }
        for source in [&source_a, &source_b] {
            assert_eq!(source.expect_request(), 1);
        }

        source_a.send_next(1);
        source_b.send_next(2);
        sink_a.assert_next(1);
        sink_b.assert_next(2);

        // The abort wins; the shutdown that follows is ignored by every attached stream.
        switch.abort(StreamError::Failed("shared abort".to_owned()));
        switch.shutdown();
        for sink in [&sink_a, &sink_b] {
            sink.request(1);
        }
        for sink in [&sink_a, &sink_b] {
            assert_eq!(
                sink.expect_error(),
                StreamError::Failed("shared abort".to_owned())
            );
        }
        for source in [&source_a, &source_b] {
            source.expect_cancellation();
        }
    }

    #[test]
    fn shared_kill_switch_is_thread_safe() {
        let switch = Arc::new(KillSwitches::shared("thread-safe"));
        let handle = {
            let switch = Arc::clone(&switch);
            thread::spawn(move || switch.shutdown())
        };
        switch.shutdown();
        handle.join().expect("kill switch thread joins");
    }

    #[test]
    fn unique_kill_switch_materializations_stay_thread_local_and_independent() {
        const WORKERS: usize = 8;
        const VICTIM: usize = 3;

        let flow = KillSwitches::single::<usize>();
        let materializer = Arc::new(Materializer::new());
        // The workers and this coordinating thread are released together.
        let barrier = Arc::new(Barrier::new(WORKERS + 1));

        let handles: Vec<_> = (0..WORKERS)
            .map(|idx| {
                let (flow, materializer, barrier) = (
                    flow.clone(),
                    Arc::clone(&materializer),
                    Arc::clone(&barrier),
                );
                thread::spawn(move || {
                    barrier.wait();
                    Source::repeat(idx)
                        .via_mat(flow, Keep::right)
                        .to_mat(TestSink::probe(), Keep::both)
                        .run_with_materializer(materializer.as_ref())
                        .expect("kill switch flow materializes")
                })
            })
            .collect();

        barrier.wait();

        let mut streams: Vec<_> = handles
            .into_iter()
            .map(|handle| handle.join().expect("materialization thread joins"))
            .collect();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            sink.assert_next(idx);
        }

        // Only the victim's switch is triggered; every other stream keeps flowing.
        streams[VICTIM].0.shutdown();

        for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
            sink.request(1);
            if idx == VICTIM {
                sink.expect_complete();
            } else {
                sink.assert_next(idx);
            }
        }
    }
}