use std::sync::{
Arc, Mutex,
atomic::{AtomicU8, Ordering},
};
use crate::stream::{BoxStream, Flow};
use crate::{StreamError, StreamResult};
#[derive(Clone, Debug)]
pub struct UniqueKillSwitch {
state: Arc<KillSwitchState>,
}
impl UniqueKillSwitch {
pub fn shutdown(&self) {
self.state.shutdown();
}
pub fn abort(&self, error: StreamError) {
self.state.abort(error);
}
}
#[derive(Clone, Debug)]
pub struct SharedKillSwitch {
name: Arc<str>,
state: Arc<KillSwitchState>,
}
impl SharedKillSwitch {
fn new(name: impl Into<Arc<str>>) -> Self {
Self {
name: name.into(),
state: Arc::new(KillSwitchState::default()),
}
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
pub fn shutdown(&self) {
self.state.shutdown();
}
pub fn abort(&self, error: StreamError) {
self.state.abort(error);
}
#[must_use]
pub fn flow<T: Send + 'static>(&self) -> Flow<T, T, SharedKillSwitch> {
let state = Arc::clone(&self.state);
let switch = self.clone();
Flow::from_parts(
move |input| Box::new(KillSwitchStream::new(input, Arc::clone(&state))),
move || Ok(switch.clone()),
)
}
}
/// Namespace for the kill switch constructors.
pub struct KillSwitches;

impl KillSwitches {
    /// Creates a flow stage that materializes its own [`UniqueKillSwitch`].
    ///
    /// The factory closure allocates a fresh state every time it is invoked
    /// (presumably once per materialization; confirm against
    /// `Flow::from_materialized_factory`), so the returned switch controls only
    /// the stream built by that invocation.
    #[must_use]
    pub fn single<T: Send + 'static>() -> Flow<T, T, UniqueKillSwitch> {
        Flow::from_materialized_factory(move || {
            let switch = UniqueKillSwitch {
                state: Arc::new(KillSwitchState::default()),
            };
            // The stream side keeps its own handle on the state the switch controls.
            let shared_state = Arc::clone(&switch.state);
            let transform = Arc::new(move |upstream| {
                Box::new(KillSwitchStream::new(upstream, Arc::clone(&shared_state)))
                    as BoxStream<T>
            });
            (transform, switch)
        })
    }

    /// Creates a named [`SharedKillSwitch`] in the open state.
    #[must_use]
    pub fn shared(name: impl Into<Arc<str>>) -> SharedKillSwitch {
        SharedKillSwitch::new(name)
    }
}
/// Stream adapter that forwards items from `input` until its kill switch state
/// leaves the open status.
struct KillSwitchStream<T> {
    /// Upstream being guarded.
    input: BoxStream<T>,
    /// State shared with the controlling kill switch handle(s).
    state: Arc<KillSwitchState>,
    /// Set once the adapter has ended; every later poll returns `None`.
    terminated: bool,
}
/// Gate code: no termination has been requested.
const KILL_SWITCH_OPEN: u8 = 0;
/// Gate code: shutdown was requested.
const KILL_SWITCH_SHUTDOWN: u8 = 1;
/// Gate code: abort was requested.
const KILL_SWITCH_ABORTED: u8 = 2;
impl<T> KillSwitchStream<T> {
    /// Wraps `input` so that it is governed by `state`; the adapter starts out
    /// not terminated.
    fn new(input: BoxStream<T>, state: Arc<KillSwitchState>) -> Self {
        let terminated = false;
        Self {
            input,
            state,
            terminated,
        }
    }
}
impl<T> Iterator for KillSwitchStream<T> {
    type Item = StreamResult<T>;

    /// Yields the next upstream item while the switch is open.
    ///
    /// * Already terminated: returns `None` without consulting the switch.
    /// * Shutdown observed: ends the stream with `None`.
    /// * Abort observed: yields the abort error once, then the stream is ended.
    /// * Open: forwards whatever the upstream yields, ending when it is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        if self.terminated {
            return None;
        }

        let polled = match self.state.current() {
            KillSwitchStatus::Open => self.input.next(),
            KillSwitchStatus::Shutdown => None,
            KillSwitchStatus::Aborted(error) => {
                self.terminated = true;
                return Some(Err(error));
            }
        };

        // `terminated` is known to be false here, so this only ever flips it to
        // true: on shutdown or on upstream exhaustion.
        self.terminated = polled.is_none();
        polled
    }
}
/// Status of a kill switch as observed by a guarded stream.
#[derive(Debug, Clone, Default)]
enum KillSwitchStatus {
    /// No termination requested; this is the default.
    #[default]
    Open,
    /// Shutdown requested: guarded streams end without an error.
    Shutdown,
    /// Abort requested: guarded streams yield the carried error and then end.
    Aborted(StreamError),
}
/// Termination state shared between kill switch handles and the streams they
/// guard.
///
/// `gate` is what streams poll on every item; `status` is consulted only when
/// the gate reports an abort, to retrieve the error.
#[derive(Debug, Default)]
struct KillSwitchState {
    /// One of the `KILL_SWITCH_*` codes. It leaves `KILL_SWITCH_OPEN` at most
    /// once, via a compare-exchange in `shutdown` or `abort`.
    gate: AtomicU8,
    /// Full status, including the abort error when the gate is
    /// `KILL_SWITCH_ABORTED`.
    status: Mutex<KillSwitchStatus>,
}

impl KillSwitchState {
    /// Moves the switch from open to shutdown.
    ///
    /// Does nothing if a shutdown or abort has already been recorded.
    ///
    /// # Panics
    ///
    /// Panics if the status mutex is poisoned.
    fn shutdown(&self) {
        if self
            .gate
            .compare_exchange(
                KILL_SWITCH_OPEN,
                KILL_SWITCH_SHUTDOWN,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_err()
        {
            return;
        }
        // `current` answers a shutdown from the gate alone, so it is fine for
        // this mirror write to land after the gate has been published.
        let mut status = self.status.lock().expect("kill switch poisoned");
        *status = KillSwitchStatus::Shutdown;
    }

    /// Moves the switch from open to aborted, recording `error`.
    ///
    /// Does nothing (and drops `error`) if a shutdown or abort has already been
    /// recorded.
    ///
    /// # Panics
    ///
    /// Panics if the status mutex is poisoned.
    fn abort(&self, error: StreamError) {
        // Take the lock *before* publishing the gate. `current` reads the error
        // through this mutex once it sees `KILL_SWITCH_ABORTED`; if the gate
        // were flipped first, a reader could win the lock ahead of us and
        // observe the stale `Open` status instead of the error.
        let mut status = self.status.lock().expect("kill switch poisoned");
        let won = self
            .gate
            .compare_exchange(
                KILL_SWITCH_OPEN,
                KILL_SWITCH_ABORTED,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok();
        if won {
            *status = KillSwitchStatus::Aborted(error);
        }
    }

    /// Returns the status a guarded stream should act on right now.
    ///
    /// Open and shutdown are answered from the gate alone; an aborted gate
    /// takes the mutex to clone the stored error.
    ///
    /// # Panics
    ///
    /// Panics if the status mutex is poisoned, or if the gate holds a value
    /// other than the three `KILL_SWITCH_*` codes.
    fn current(&self) -> KillSwitchStatus {
        match self.gate.load(Ordering::Acquire) {
            KILL_SWITCH_OPEN => KillSwitchStatus::Open,
            KILL_SWITCH_SHUTDOWN => KillSwitchStatus::Shutdown,
            KILL_SWITCH_ABORTED => self.status.lock().expect("kill switch poisoned").clone(),
            gate => panic!("unexpected kill switch gate state {gate}"),
        }
    }
}
// Unit tests for the unique and shared kill switches, driven through the
// crate's test probes (`TestSource` / `TestSink`).
#[cfg(test)]
mod tests {
use super::*;
use crate::testkit::{TestSink, TestSource};
use crate::{Keep, Materializer, Source};
use std::{
sync::{Arc, Barrier},
thread,
};
// Shutdown after one delivered element: the sink completes and the source is
// cancelled.
#[test]
fn unique_kill_switch_shutdown_completes_and_cancels() {
let materializer = Materializer::new();
let ((source, switch), sink) = TestSource::probe::<i32>()
.via_mat(KillSwitches::single(), Keep::both)
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("graph materializes");
sink.request(1);
assert_eq!(source.expect_request(), 1);
source.send_next(1);
sink.assert_next(1);
// Shutdown is requested twice to exercise the repeated-call path.
switch.shutdown();
switch.shutdown();
sink.request(1);
sink.expect_complete();
source.expect_cancellation();
}
// An abort issued after a shutdown must not surface: the sink still completes
// rather than receiving the error.
#[test]
fn unique_kill_switch_abort_is_idempotent_after_shutdown() {
let materializer = Materializer::new();
let ((source, switch), sink) = TestSource::probe::<i32>()
.via_mat(KillSwitches::single(), Keep::both)
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("graph materializes");
switch.shutdown();
switch.abort(StreamError::Failed("late abort".to_owned()));
sink.request(1);
sink.expect_complete();
source.expect_cancellation();
}
// The switch is shut down inside the materialized-value mapping, i.e. before
// the sink requests anything; the sink completes on its first request even
// though the source has three elements.
#[test]
fn unique_kill_switch_pre_materialization_shutdown_completes_immediately() {
let flow = KillSwitches::single::<i32>().map_materialized_value(|switch| {
switch.shutdown();
switch
});
let sink = Source::from_iter(1..=3)
.via_mat(flow, Keep::right)
.run_with(TestSink::probe())
.expect("test sink materializes");
sink.request(1);
sink.expect_complete();
}
// One shared switch attached to two independently materialized streams: an
// abort reaches both sinks with the same error and cancels both sources.
#[test]
fn shared_kill_switch_fans_out_to_many_streams() {
let switch = KillSwitches::shared("shared-switch");
let materializer = Materializer::new();
let make_stream = || TestSource::probe::<i32>().via_mat(switch.flow(), Keep::both);
let ((source_a, shared_a), sink_a) = make_stream()
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("first stream materializes");
let ((source_b, shared_b), sink_b) = make_stream()
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("second stream materializes");
// The materialized handles carry the name of the switch they came from.
assert_eq!(shared_a.name(), "shared-switch");
assert_eq!(shared_b.name(), "shared-switch");
sink_a.request(1);
sink_b.request(1);
assert_eq!(source_a.expect_request(), 1);
assert_eq!(source_b.expect_request(), 1);
source_a.send_next(1);
source_b.send_next(2);
sink_a.assert_next(1);
sink_b.assert_next(2);
// Abort first, then shutdown: the later shutdown must not replace the abort.
switch.abort(StreamError::Failed("shared abort".to_owned()));
switch.shutdown();
sink_a.request(1);
sink_b.request(1);
assert_eq!(
sink_a.expect_error(),
StreamError::Failed("shared abort".to_owned())
);
assert_eq!(
sink_b.expect_error(),
StreamError::Failed("shared abort".to_owned())
);
source_a.expect_cancellation();
source_b.expect_cancellation();
}
// Calls `shutdown` on the same switch from a spawned thread and from the test
// thread; the test passes if both calls return and the thread joins.
#[test]
fn shared_kill_switch_is_thread_safe() {
let switch = Arc::new(KillSwitches::shared("thread-safe"));
let clone = Arc::clone(&switch);
let handle = thread::spawn(move || {
clone.shutdown();
});
switch.shutdown();
handle.join().expect("kill switch thread joins");
}
// Materializes the same `single` flow on eight threads at once and checks that
// shutting down one materialization leaves the other seven running.
#[test]
fn unique_kill_switch_materializations_stay_thread_local_and_independent() {
let flow = KillSwitches::single::<usize>();
let materializer = Arc::new(Materializer::new());
// Nine parties: the eight worker threads plus this test thread.
let barrier = Arc::new(Barrier::new(9));
let handles = (0..8)
.map(|idx| {
let flow = flow.clone();
let materializer = Arc::clone(&materializer);
let barrier = Arc::clone(&barrier);
thread::spawn(move || {
// Wait for all parties so the materializations start together.
barrier.wait();
Source::repeat(idx)
.via_mat(flow, Keep::right)
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(materializer.as_ref())
.expect("kill switch flow materializes")
})
})
.collect::<Vec<_>>();
barrier.wait();
let mut streams = handles
.into_iter()
.map(|handle| handle.join().expect("materialization thread joins"))
.collect::<Vec<_>>();
// Each stream repeats its own index, so the first element identifies it.
for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
sink.request(1);
sink.assert_next(idx);
}
// Shut down only the switch materialized for stream 3.
streams[3].0.shutdown();
for (idx, (_switch, sink)) in streams.iter_mut().enumerate() {
sink.request(1);
if idx == 3 {
sink.expect_complete();
} else {
sink.assert_next(idx);
}
}
}
}