interruptible 0.2.4

Stops a future producer or stream from producing values when interrupted.
Documentation
use std::ops::ControlFlow;

use futures::future::Future;
use tokio::sync::mpsc;

#[cfg(feature = "ctrl_c")]
use tokio::sync::mpsc::error::SendError;

use crate::{InterruptSignal, InterruptibleFutureControl, InterruptibleFutureResult};

/// Provides the `.interruptible_control()` and `.interruptible_result()`
/// methods for `Future`s to return [`ControlFlow::Break`] or [`Result::Err`]
/// when an interrupt signal is received.
pub trait InterruptibleFutureExt<'rx, B, T> {
    /// Overrides this `Future`'s control flow when an interrupt signal is
    /// received.
    ///
    /// # Parameters
    ///
    /// * `interrupt_rx`: Channel receiver of the interrupt signal.
    fn interruptible_control(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>;

    /// Overrides this `Future`'s result when an interrupt signal is received.
    ///
    /// # Parameters
    ///
    /// * `interrupt_rx`: Channel receiver of the interrupt signal.
    fn interruptible_result(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>;

    /// Overrides this `Future`'s control flow when an interrupt signal is
    /// received through `Ctrl C`.
    ///
    /// # Parameters
    ///
    /// * `interrupt_rx`: Channel receiver of the interrupt signal.
    #[cfg(feature = "ctrl_c")]
    fn interruptible_control_ctrl_c(self) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>;

    /// Overrides this `Future`'s result when an interrupt signal is received
    /// through `Ctrl C`.
    ///
    /// # Parameters
    ///
    /// * `interrupt_rx`: Channel receiver of the interrupt signal.
    #[cfg(feature = "ctrl_c")]
    fn interruptible_result_ctrl_c(self) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>;
}

impl<'rx, B, T, Fut> InterruptibleFutureExt<'rx, B, T> for Fut
where
    Fut: Future,
{
    fn interruptible_control(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>,
    {
        InterruptibleFutureControl::new(self, interrupt_rx.into())
    }

    fn interruptible_result(
        self,
        interrupt_rx: &'rx mut mpsc::Receiver<InterruptSignal>,
    ) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>,
    {
        InterruptibleFutureResult::new(self, interrupt_rx.into())
    }

    #[cfg(feature = "ctrl_c")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn interruptible_control_ctrl_c(self) -> InterruptibleFutureControl<'rx, B, T, Self>
    where
        Self: Sized + Future<Output = ControlFlow<B, T>>,
        B: From<(T, InterruptSignal)>,
    {
        let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);

        tokio::task::spawn(
            #[cfg_attr(coverage_nightly, coverage(off))]
            async move {
                tokio::signal::ctrl_c()
                    .await
                    .expect("Failed to initialize signal handler for SIGINT");

                let (Ok(()) | Err(SendError(InterruptSignal))) =
                    interrupt_tx.send(InterruptSignal).await;
            },
        );

        InterruptibleFutureControl::new(self, interrupt_rx.into())
    }

    #[cfg(feature = "ctrl_c")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn interruptible_result_ctrl_c(self) -> InterruptibleFutureResult<'rx, T, B, Self>
    where
        Self: Sized + Future<Output = Result<T, B>>,
    {
        let (interrupt_tx, interrupt_rx) = mpsc::channel::<InterruptSignal>(16);

        tokio::task::spawn(
            #[cfg_attr(coverage_nightly, coverage(off))]
            async move {
                tokio::signal::ctrl_c()
                    .await
                    .expect("Failed to initialize signal handler for SIGINT");

                let (Ok(()) | Err(SendError(InterruptSignal))) =
                    interrupt_tx.send(InterruptSignal).await;
            },
        );

        InterruptibleFutureResult::new(self, interrupt_rx.into())
    }
}

#[cfg(test)]
mod tests {
    use std::ops::ControlFlow;

    use futures::FutureExt;
    use tokio::{
        join,
        sync::{
            mpsc::{self, error::SendError},
            oneshot,
        },
    };

    use super::InterruptibleFutureExt;
    use crate::InterruptSignal;

    #[tokio::test]
    async fn interrupt_overrides_control_future_continue_unit_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let interruptible_control = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            ControlFlow::Continue(())
        }
        .boxed()
        .interruptible_control(&mut interrupt_rx);

        let interrupter = async move {
            interrupt_tx
                .send(InterruptSignal)
                .await
                .expect("Expected to send `InterruptSignal`.");
            ready_tx
                .send(())
                .expect("Expected to notify sleep to start.");
        };

        let (control_flow, ()) = join!(interruptible_control, interrupter);

        assert_eq!(ControlFlow::Break(InterruptSignal), control_flow);
    }

    #[tokio::test]
    async fn interrupt_overrides_control_future_continue_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let interruptible_control = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            ControlFlow::Continue(FutEnd {
                value: 1,
                interrupted: false,
            })
        }
        .boxed()
        .interruptible_control(&mut interrupt_rx);

        let interrupter = async move {
            interrupt_tx
                .send(InterruptSignal)
                .await
                .expect("Expected to send `InterruptSignal`.");
            ready_tx
                .send(())
                .expect("Expected to notify sleep to start.");
        };

        let (control_flow, ()) = join!(interruptible_control, interrupter);

        assert_eq!(
            ControlFlow::Break(FutEnd {
                value: 1,
                interrupted: true,
            }),
            control_flow
        );
    }

    #[tokio::test]
    async fn interrupt_does_not_override_control_future_break_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let interruptible_control = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            ControlFlow::Break(FutEnd {
                value: 1,
                interrupted: false,
            })
        }
        .boxed()
        .interruptible_control(&mut interrupt_rx);

        let interrupter = async move {
            interrupt_tx
                .send(InterruptSignal)
                .await
                .expect("Expected to send `InterruptSignal`.");
            ready_tx
                .send(())
                .expect("Expected to notify sleep to start.");
        };

        let (control_flow, ()) = join!(interruptible_control, interrupter);

        assert_eq!(
            ControlFlow::Break(FutEnd {
                value: 1,
                interrupted: false,
            }),
            control_flow
        );
    }

    #[tokio::test]
    async fn interrupt_after_control_future_completes_does_not_override_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);

        let interruptible_control = async { ControlFlow::<InterruptSignal, ()>::Continue(()) }
            .boxed()
            .interruptible_control(&mut interrupt_rx);

        let interrupter = async move {
            let (Ok(()) | Err(SendError(InterruptSignal))) =
                interrupt_tx.send(InterruptSignal).await;
        };

        let (control_flow, ()) = join!(interruptible_control, interrupter);

        assert_eq!(ControlFlow::Continue(()), control_flow);
    }

    #[tokio::test]
    async fn interrupt_overrides_result_future_return_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let interruptible_result = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            Ok(())
        }
        .boxed()
        .interruptible_result(&mut interrupt_rx);

        let interrupter = async move {
            interrupt_tx
                .send(InterruptSignal)
                .await
                .expect("Expected to send `InterruptSignal`.");
            ready_tx
                .send(())
                .expect("Expected to notify sleep to start.");
        };

        let (result_flow, ()) = join!(interruptible_result, interrupter);

        assert_eq!(Err(InterruptSignal), result_flow);
    }

    #[tokio::test]
    async fn interrupt_does_not_override_result_future_err_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let interruptible_result = async {
            let () = ready_rx.await.expect("Expected to be notified to start.");
            Err(FutEnd {
                value: 1,
                interrupted: false,
            })
        }
        .boxed()
        .interruptible_result(&mut interrupt_rx);

        let interrupter = async move {
            interrupt_tx
                .send(InterruptSignal)
                .await
                .expect("Expected to send `InterruptSignal`.");
            ready_tx
                .send(())
                .expect("Expected to notify sleep to start.");
        };

        let (result_flow, ()) = join!(interruptible_result, interrupter);

        assert_eq!(
            Err(FutEnd {
                value: 1,
                interrupted: false,
            }),
            result_flow
        );
    }

    #[tokio::test]
    async fn interrupt_after_result_future_completes_does_not_override_value() {
        let (interrupt_tx, mut interrupt_rx) = mpsc::channel::<InterruptSignal>(16);

        let interruptible_result = async { Result::<(), InterruptSignal>::Ok(()) }
            .boxed()
            .interruptible_result(&mut interrupt_rx);

        let interrupter = async move {
            let (Ok(()) | Err(SendError(InterruptSignal))) =
                interrupt_tx.send(InterruptSignal).await;
        };

        let (result_flow, ()) = join!(interruptible_result, interrupter);

        assert_eq!(Ok(()), result_flow);
    }

    #[derive(Debug, PartialEq, Eq)]
    struct FutEnd {
        value: usize,
        interrupted: bool,
    }
    impl From<(FutEnd, InterruptSignal)> for FutEnd {
        fn from((mut fut_end, InterruptSignal): (FutEnd, InterruptSignal)) -> Self {
            fut_end.interrupted = true;
            fut_end
        }
    }

    #[test]
    fn debug() {
        let fut_end = FutEnd {
            value: 1,
            interrupted: false,
        };
        assert_eq!(
            "FutEnd { value: 1, interrupted: false }",
            format!("{fut_end:?}")
        );
    }
}