use std::ops::ControlFlow;
use futures::future::Future;
use tokio::sync::mpsc;
#[cfg(feature = "ctrl_c")]
use tokio::sync::mpsc::error::SendError;
use crate::{InterruptSignal, InterruptibleFutureControl, InterruptibleFutureResult};
/// Extension trait that makes a [`Future`] interruptible by an [`InterruptSignal`].
///
/// Generic parameters:
///
/// * `'rx` -- lifetime for which the interrupt receiver is borrowed by the returned wrapper.
/// * `B` -- the value carried by `ControlFlow::Break` / `Result::Err` in the future's output.
/// * `T` -- the value carried by `ControlFlow::Continue` / `Result::Ok` in the future's output.
pub trait InterruptibleFutureExt<'rx, B, T> {
    /// Wraps this `ControlFlow`-returning future so that it can be interrupted via
    /// `interrupt_rx`.
    ///
    /// Behaviour of the returned [`InterruptibleFutureControl`], as exercised by this
    /// module's tests:
    ///
    /// * If an [`InterruptSignal`] is sent while the future is still pending and the
    ///   future then yields `ControlFlow::Continue(t)`, the output becomes
    ///   `ControlFlow::Break(B::from((t, InterruptSignal)))`.
    /// * A `ControlFlow::Break(b)` produced by the future is returned unchanged, even if an
    ///   interrupt was sent.
    /// * If the future completes before a signal is sent, its output is returned unchanged.
    ///
    /// # Parameters
    ///
    /// * `interrupt_rx`: Receiver of interrupt signals, mutably borrowed for `'rx`.
    fn interruptible_control(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>;

    /// Wraps this `Result`-returning future so that it can be interrupted via
    /// `interrupt_rx`.
    ///
    /// Behaviour of the returned [`InterruptibleFutureResult`], as exercised by this
    /// module's tests:
    ///
    /// * If an [`InterruptSignal`] is sent while the future is still pending, an `Ok`
    ///   output is replaced by an `Err` (in the tests `Ok(())` becomes
    ///   `Err(InterruptSignal)`). NOTE(review): how the `B` value is built is decided by
    ///   `InterruptibleFutureResult`, not visible here -- confirm there.
    /// * An `Err(b)` produced by the future is returned unchanged.
    /// * If the future completes before a signal is sent, its output is returned unchanged.
    ///
    /// # Parameters
    ///
    /// * `interrupt_rx`: Receiver of interrupt signals, mutably borrowed for `'rx`.
    fn interruptible_result(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>;

    /// Like `interruptible_control`, but interrupted by Ctrl-C (`SIGINT`) instead of a
    /// caller-supplied receiver.
    ///
    /// The implementation creates its own channel and spawns a Tokio task that sends an
    /// [`InterruptSignal`] once `tokio::signal::ctrl_c` completes. It must therefore be
    /// called from within a Tokio runtime. Only available with the `ctrl_c` feature.
    #[cfg(feature = "ctrl_c")]
    fn interruptible_control_ctrl_c(self) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>;

    /// Like `interruptible_result`, but interrupted by Ctrl-C (`SIGINT`) instead of a
    /// caller-supplied receiver.
    ///
    /// The implementation creates its own channel and spawns a Tokio task that sends an
    /// [`InterruptSignal`] once `tokio::signal::ctrl_c` completes. It must therefore be
    /// called from within a Tokio runtime. Only available with the `ctrl_c` feature.
    #[cfg(feature = "ctrl_c")]
    fn interruptible_result_ctrl_c(self) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>;
}
/// Blanket implementation giving every [`Future`] the interruptible adapters.
///
/// Each method's own `where` clause restricts it to futures with the matching output
/// type. For example, `interruptible_result` is only callable on `Result`-returning futures.
impl<'rx, B, T, Fut> InterruptibleFutureExt<'rx, B, T> for Fut
where
    Fut: Future,
{
    fn interruptible_control(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>,
    {
        // `.into()` adapts the borrowed receiver to the receiver type the wrapper stores.
        InterruptibleFutureControl::new(self, interrupt_rx.into())
    }

    fn interruptible_result(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>,
    {
        InterruptibleFutureResult::new(self, interrupt_rx.into())
    }

    #[cfg(feature = "ctrl_c")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn interruptible_control_ctrl_c(self) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>,
    {
        // Here the receiver is owned rather than borrowed; `.into()` hands it over by value.
        InterruptibleFutureControl::new(self, ctrl_c_interrupt_rx().into())
    }

    #[cfg(feature = "ctrl_c")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn interruptible_result_ctrl_c(self) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>,
    {
        InterruptibleFutureResult::new(self, ctrl_c_interrupt_rx().into())
    }
}

/// Returns a receiver that is sent one [`InterruptSignal`] when Ctrl-C (`SIGINT`) arrives.
///
/// A background task is spawned to listen for Ctrl-C. The task stops as soon as either:
///
/// * Ctrl-C arrives, or
/// * every receiver handle is dropped.
///
/// Provided the wrapper drops the receiver when it is itself dropped, wrapping many futures
/// therefore does not accumulate listener tasks that outlive them.
///
/// # Panics
///
/// * Panics if called outside a Tokio runtime, since it calls [`tokio::task::spawn`].
/// * The spawned task panics if the Ctrl-C signal handler cannot be installed.
#[cfg(feature = "ctrl_c")]
#[cfg_attr(coverage_nightly, coverage(off))]
fn ctrl_c_interrupt_rx() -> mpsc::Receiver<InterruptSignal> {
    let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
    tokio::task::spawn(
        #[cfg_attr(coverage_nightly, coverage(off))]
        async move {
            // Boxed so both futures are `Unpin`, as `futures::future::select` requires.
            let ctrl_c = Box::pin(tokio::signal::ctrl_c());
            // Completes once the receiver side of the channel has been dropped.
            let rx_dropped = Box::pin(interrupt_tx.closed());
            match futures::future::select(ctrl_c, rx_dropped).await {
                futures::future::Either::Left((ctrl_c_result, _rx_dropped)) => {
                    ctrl_c_result.expect("Failed to initialize signal handler for SIGINT");
                    // The receiver may be dropped between Ctrl-C and this send; ignore that.
                    let (Ok(()) | Err(SendError(InterruptSignal))) =
                        interrupt_tx.send(InterruptSignal).await;
                }
                // Nobody can receive the signal any more, so stop waiting for Ctrl-C.
                futures::future::Either::Right(((), _ctrl_c)) => {}
            }
        },
    );
    interrupt_rx
}
#[cfg(test)]
mod tests {
    use std::ops::ControlFlow;

    use futures::FutureExt;
    use tokio::{
        join,
        sync::{
            mpsc::{self, error::SendError},
            oneshot,
        },
    };

    use super::InterruptibleFutureExt;
    use crate::InterruptSignal;

    /// Sends an [`InterruptSignal`], then releases the wrapped future via `ready_tx`.
    ///
    /// The wrapped futures in the interrupt tests cannot produce a value until `ready_tx`
    /// fires. Sending the interrupt first guarantees it is queued before that happens.
    async fn interrupt_then_release(
        interrupt_tx: mpsc::Sender<InterruptSignal>,
        ready_tx: oneshot::Sender<()>,
    ) {
        interrupt_tx
            .send(InterruptSignal)
            .await
            .expect("Expected to send `InterruptSignal`.");
        ready_tx
            .send(())
            .expect("Expected to notify the wrapped future to proceed.");
    }

    /// Sends an [`InterruptSignal`], ignoring the error if the receiver is already gone.
    async fn send_interrupt_ignoring_closed(interrupt_tx: mpsc::Sender<InterruptSignal>) {
        let (Ok(()) | Err(SendError(InterruptSignal))) =
            interrupt_tx.send(InterruptSignal).await;
    }

    #[tokio::test]
    async fn interrupt_overrides_control_future_continue_unit_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let interruptible_control = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            ControlFlow::Continue(())
        }
        .boxed()
        .interruptible_control(&mut interrupt_rx);

        let (control_flow, ()) = join!(
            interruptible_control,
            interrupt_then_release(interrupt_tx, ready_tx)
        );

        assert_eq!(ControlFlow::Break(InterruptSignal), control_flow);
    }

    #[tokio::test]
    async fn interrupt_overrides_control_future_continue_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let interruptible_control = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            ControlFlow::Continue(FutEnd {
                value: 1,
                interrupted: false,
            })
        }
        .boxed()
        .interruptible_control(&mut interrupt_rx);

        let (control_flow, ()) = join!(
            interruptible_control,
            interrupt_then_release(interrupt_tx, ready_tx)
        );

        // `FutEnd::from((fut_end, InterruptSignal))` flips `interrupted` to `true`.
        assert_eq!(
            ControlFlow::Break(FutEnd {
                value: 1,
                interrupted: true,
            }),
            control_flow
        );
    }

    #[tokio::test]
    async fn interrupt_does_not_override_control_future_break_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let interruptible_control = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            ControlFlow::Break(FutEnd {
                value: 1,
                interrupted: false,
            })
        }
        .boxed()
        .interruptible_control(&mut interrupt_rx);

        let (control_flow, ()) = join!(
            interruptible_control,
            interrupt_then_release(interrupt_tx, ready_tx)
        );

        assert_eq!(
            ControlFlow::Break(FutEnd {
                value: 1,
                interrupted: false,
            }),
            control_flow
        );
    }

    #[tokio::test]
    async fn interrupt_after_control_future_completes_does_not_override_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        // Ready immediately: its output is expected to pass through unchanged.
        let interruptible_control = async { ControlFlow::<InterruptSignal, ()>::Continue(()) }
            .boxed()
            .interruptible_control(&mut interrupt_rx);

        let (control_flow, ()) = join!(
            interruptible_control,
            send_interrupt_ignoring_closed(interrupt_tx)
        );

        assert_eq!(ControlFlow::Continue(()), control_flow);
    }

    #[tokio::test]
    async fn interrupt_overrides_result_future_return_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let interruptible_result = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            Ok(())
        }
        .boxed()
        .interruptible_result(&mut interrupt_rx);

        let (result_flow, ()) = join!(
            interruptible_result,
            interrupt_then_release(interrupt_tx, ready_tx)
        );

        assert_eq!(Err(InterruptSignal), result_flow);
    }

    #[tokio::test]
    async fn interrupt_does_not_override_result_future_err_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let interruptible_result = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            Err(FutEnd {
                value: 1,
                interrupted: false,
            })
        }
        .boxed()
        .interruptible_result(&mut interrupt_rx);

        let (result_flow, ()) = join!(
            interruptible_result,
            interrupt_then_release(interrupt_tx, ready_tx)
        );

        assert_eq!(
            Err(FutEnd {
                value: 1,
                interrupted: false,
            }),
            result_flow
        );
    }

    #[tokio::test]
    async fn interrupt_after_result_future_completes_does_not_override_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        // Ready immediately: its output is expected to pass through unchanged.
        let interruptible_result = async { Result::<(), InterruptSignal>::Ok(()) }
            .boxed()
            .interruptible_result(&mut interrupt_rx);

        let (result_flow, ()) = join!(
            interruptible_result,
            send_interrupt_ignoring_closed(interrupt_tx)
        );

        assert_eq!(Ok(()), result_flow);
    }

    /// Test payload recording whether it was converted by an interrupt.
    #[derive(Debug, PartialEq, Eq)]
    struct FutEnd {
        value: usize,
        interrupted: bool,
    }

    /// Marks the payload as interrupted; used as the `B: From<(T, InterruptSignal)>` bound.
    impl From<(FutEnd, InterruptSignal)> for FutEnd {
        fn from((mut fut_end, InterruptSignal): (FutEnd, InterruptSignal)) -> Self {
            fut_end.interrupted = true;
            fut_end
        }
    }

    #[test]
    fn debug() {
        let fut_end = FutEnd {
            value: 1,
            interrupted: false,
        };
        assert_eq!(
            "FutEnd { value: 1, interrupted: false }",
            format!("{fut_end:?}")
        );
    }
}