use futures::stream::Stream;
use own::OwnedOrMutRef;
use tokio::sync::mpsc;
#[cfg(feature = "ctrl_c")]
use tokio::sync::mpsc::error::SendError;
use crate::{
InterruptSignal, InterruptStrategy, Interruptibility, InterruptibilityState,
InterruptibleStream,
};
pub trait InterruptibleStreamExt {
fn interruptible(
self,
interrupt_rx: OwnedOrMutRef<'_, mpsc::Receiver<InterruptSignal>>,
) -> InterruptibleStream<'_, 'static, Self>
where
Self: Sized;
fn interruptible_with<'rx, 'intx>(
self,
interruptibility_state: InterruptibilityState<'rx, 'intx>,
) -> InterruptibleStream<'rx, 'intx, Self>
where
Self: Sized + 'rx;
#[cfg(feature = "ctrl_c")]
fn interruptible_ctrl_c(self) -> InterruptibleStream<'static, 'static, Self>
where
Self: Sized;
}
impl<S> InterruptibleStreamExt for S
where
S: Stream,
{
fn interruptible(
self,
interrupt_rx: OwnedOrMutRef<'_, mpsc::Receiver<InterruptSignal>>,
) -> InterruptibleStream<'_, 'static, Self>
where
Self: Sized,
{
InterruptibleStream::new(
self,
Interruptibility::Interruptible {
interrupt_rx,
interrupt_strategy: InterruptStrategy::FinishCurrent,
}
.into(),
)
}
fn interruptible_with<'rx, 'intx>(
self,
interruptibility_state: InterruptibilityState<'rx, 'intx>,
) -> InterruptibleStream<'rx, 'intx, Self>
where
Self: Sized + 'rx,
{
InterruptibleStream::new(self, interruptibility_state)
}
#[cfg(feature = "ctrl_c")]
#[cfg_attr(coverage_nightly, coverage(off))]
fn interruptible_ctrl_c(self) -> InterruptibleStream<'static, 'static, Self>
where
Self: Sized,
{
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
tokio::task::spawn(
#[cfg_attr(coverage_nightly, coverage(off))]
async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to initialize signal handler for `SIGINT`.");
let (Ok(()) | Err(SendError(InterruptSignal))) =
interrupt_tx.send(InterruptSignal).await;
},
);
InterruptibleStream::new(
self,
Interruptibility::Interruptible {
interrupt_rx: interrupt_rx.into(),
interrupt_strategy: InterruptStrategy::FinishCurrent,
}
.into(),
)
}
}
#[cfg(test)]
mod tests {
use futures::{stream, StreamExt};
use tokio::{
sync::{
mpsc::{self, error::SendError},
oneshot,
},
task::yield_now,
};
use super::InterruptibleStreamExt;
use crate::{InterruptSignal, Interruptibility, PollOutcome};
#[tokio::test]
async fn interrupt_during_future_overrides_stream_return_value() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (interrupt_ready_tx, interrupt_ready_rx) = oneshot::channel::<()>();
let (interrupted_tx, interrupted_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some((interrupt_ready_tx, interrupted_rx))),
move |(n, channel_tx_rx)| async move {
if let Some((interrupt_ready_tx, interrupted_rx)) = channel_tx_rx {
interrupt_ready_tx
.send(())
.expect("Expected to send to interrupt ready channel.");
let () = interrupted_rx
.await
.expect("Expected to be notified to return value.");
}
Some((n, (n + 1, None)))
},
)
.interruptible(interrupt_rx.into());
let interrupt_task = async {
interrupt_ready_rx
.await
.expect("Expected `interrupt_ready_rx`. to receive message.");
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
interrupted_tx
.send(())
.expect("Expected to notify future to return value.");
};
let (poll_outcome, ()) = tokio::join!(interruptible_stream.next(), interrupt_task);
assert_eq!(Some(PollOutcome::Interrupted(Some(0u32))), poll_outcome);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_finish_current_before_start_returns_interrupt_before_poll() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (ready_tx, ready_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some(ready_rx)),
#[cfg_attr(coverage_nightly, coverage(off))]
move |(n, ready_rx)| {
#[cfg_attr(coverage_nightly, coverage(off))]
async move {
if let Some(ready_rx) = ready_rx {
let () = ready_rx
.await
.expect("Expected to be notified to return value.");
}
Some((n, (n + 1, None)))
}
},
)
.interruptible_with(Interruptibility::finish_current(interrupt_rx.into()).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
ready_tx
.send(())
.expect("Expected to notify future to return value.");
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_finish_current_during_future_overrides_stream_return_value() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (interrupt_ready_tx, interrupt_ready_rx) = oneshot::channel::<()>();
let (interrupted_tx, interrupted_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some((interrupt_ready_tx, interrupted_rx))),
move |(n, channel_tx_rx)| async move {
if let Some((interrupt_ready_tx, interrupted_rx)) = channel_tx_rx {
interrupt_ready_tx
.send(())
.expect("Expected to send to interrupt ready channel.");
let () = interrupted_rx
.await
.expect("Expected to be notified to return value.");
}
Some((n, (n + 1, None)))
},
)
.interruptible_with(Interruptibility::finish_current(interrupt_rx.into()).into());
let interrupt_task = async {
interrupt_ready_rx
.await
.expect("Expected `interrupt_ready_rx`. to receive message.");
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
interrupted_tx
.send(())
.expect("Expected to notify future to return value.");
};
let (poll_outcome, ()) = tokio::join!(interruptible_stream.next(), interrupt_task);
assert_eq!(Some(PollOutcome::Interrupted(Some(0u32))), poll_outcome);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_before_start_returns_n_items() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (interrupted_tx, interrupted_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some(interrupted_rx)),
move |(n, interrupted_rx)| async move {
if let Some(interrupted_rx) = interrupted_rx {
let () = interrupted_rx
.await
.expect("Expected to be notified to return value.");
}
if n < 3 {
Some((n, (n + 1, None)))
} else {
None
}
},
)
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 2).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
interrupted_tx
.send(())
.expect("Expected to notify future to return value.");
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(1u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_before_0() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 3 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 0).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_before_1() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 3 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 1).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_before_2() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 10 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 2).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(1u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_before_2b() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 2 {
yield_now().await;
yield_now().await;
yield_now().await;
yield_now().await;
}
if n < 10 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 2).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(1u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_before_3() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 10 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 3).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(1u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(2u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_before_6() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 10 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 6).into());
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(1u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(2u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(3u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(4u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(5u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_between_0() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (interrupt_ready_tx, mut interrupt_ready_rx) = mpsc::channel::<()>(2);
let (interrupted_tx, interrupted_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some((interrupt_ready_tx, interrupted_rx))),
move |(n, channel_tx_rx)| async move {
if let Some((interrupt_ready_tx, interrupted_rx)) = channel_tx_rx {
interrupt_ready_tx
.send(())
.await
.expect("Expected to send to interrupt ready channel.");
let () = interrupted_rx
.await
.expect("Expected to be notified to return value.");
}
if n < 3 {
Some((n, (n + 1, None)))
} else {
None
}
},
)
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 0).into());
let interrupt_task = async {
interrupt_ready_rx
.recv()
.await
.expect("Expected `interrupt_ready_rx`. to receive message.");
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
interrupted_tx
.send(())
.expect("Expected to notify future to return value.");
};
let ((poll_outcome_first, poll_outcome_second), ()) = tokio::join!(
async {
(
interruptible_stream.next().await,
interruptible_stream.next().await,
)
},
interrupt_task
);
assert_eq!(
Some(PollOutcome::Interrupted(Some(0u32))),
poll_outcome_first
);
assert_eq!(None, poll_outcome_second);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_between_1() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (interrupt_ready_tx, mut interrupt_ready_rx) = mpsc::channel::<()>(2);
let (interrupted_tx, interrupted_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some((interrupt_ready_tx, interrupted_rx))),
move |(n, channel_tx_rx)| async move {
if n == 0 {
return Some((n, (n + 1, channel_tx_rx)));
}
if let Some((interrupt_ready_tx, interrupted_rx)) = channel_tx_rx {
interrupt_ready_tx
.send(())
.await
.expect("Expected to send to interrupt ready channel.");
let () = interrupted_rx
.await
.expect("Expected to be notified to return value.");
yield_now().await;
}
if n < 3 {
Some((n, (n + 1, None)))
} else {
None
}
},
)
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 1).into());
let interrupt_task = async {
interrupt_ready_rx
.recv()
.await
.expect("Expected `interrupt_ready_rx`. to receive message.");
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
interrupted_tx
.send(())
.expect("Expected to notify future to return value.");
};
let ((poll_outcome_first, poll_outcome_second), ()) = tokio::join!(
async {
(
interruptible_stream.next().await,
interruptible_stream.next().await,
)
},
interrupt_task
);
assert_eq!(Some(PollOutcome::NoInterrupt(0u32)), poll_outcome_first);
assert_eq!(Some(PollOutcome::NoInterrupt(1u32)), poll_outcome_second);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_n_items_variant_interrupt_between_2() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let (interrupt_ready_tx, mut interrupt_ready_rx) = mpsc::channel::<()>(2);
let (interrupted_tx, interrupted_rx) = oneshot::channel::<()>();
let mut interruptible_stream = stream::unfold(
(0u32, Some((interrupt_ready_tx, interrupted_rx))),
move |(n, channel_tx_rx)| async move {
if n <= 1 {
return Some((n, (n + 1, channel_tx_rx)));
}
if let Some((interrupt_ready_tx, interrupted_rx)) = channel_tx_rx {
interrupt_ready_tx
.send(())
.await
.expect("Expected to send to interrupt ready channel.");
let () = interrupted_rx
.await
.expect("Expected to be notified to return value.");
yield_now().await;
yield_now().await;
}
if n < 10 {
Some((n, (n + 1, None)))
} else {
None
}
},
)
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 2).into());
let interrupt_task = async {
interrupt_ready_rx
.recv()
.await
.expect("Expected `interrupt_ready_rx`. to receive message.");
interrupt_tx
.send(InterruptSignal)
.await
.expect("Expected to send `InterruptSignal`.");
interrupted_tx
.send(())
.expect("Expected to notify future to return value.");
};
let ((poll_outcome_first, poll_outcome_second, poll_outcome_third), ()) = tokio::join!(
async {
(
interruptible_stream.next().await,
interruptible_stream.next().await,
interruptible_stream.next().await,
)
},
interrupt_task
);
assert_eq!(Some(PollOutcome::NoInterrupt(0u32)), poll_outcome_first);
assert_eq!(Some(PollOutcome::NoInterrupt(1u32)), poll_outcome_second);
assert_eq!(Some(PollOutcome::NoInterrupt(2u32)), poll_outcome_third);
assert_eq!(
Some(PollOutcome::NoInterrupt(3u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::Interrupted(None)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_with_poll_next_n_returns_no_interrupt_when_not_interrupted() {
let (_interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 3 { Some((n, n + 1)) } else { None }
})
.interruptible_with(Interruptibility::poll_next_n(interrupt_rx.into(), 1).into());
assert_eq!(
Some(PollOutcome::NoInterrupt(0u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(1u32)),
interruptible_stream.next().await
);
assert_eq!(
Some(PollOutcome::NoInterrupt(2u32)),
interruptible_stream.next().await
);
assert_eq!(None, interruptible_stream.next().await);
}
#[tokio::test]
async fn interrupt_after_stream_completes_does_not_override_value() {
let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
let mut interruptible_stream = stream::unfold(0u32, move |n| async move {
if n < 3 { Some((n, n + 1)) } else { None }
})
.interruptible(interrupt_rx.into());
let poll_outcome = interruptible_stream.next().await;
let (Ok(()) | Err(SendError(InterruptSignal))) = interrupt_tx.send(InterruptSignal).await;
assert_eq!(Some(PollOutcome::NoInterrupt(0)), poll_outcome);
}
}