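//! Broadcast subscription types for process output streams: a fast path
//! served directly from `tokio::sync::broadcast`, and a shared path that
//! adds replay and per-subscriber delivery guarantees.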
use super::state::{BestEffortLiveQueue, IndexedEvent, Shared, SubscriberId};
use crate::output_stream::Subscription;
use crate::output_stream::event::StreamEvent;
use crate::output_stream::policy::{BestEffortDelivery, Delivery, NoReplay, Replay};
use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::mpsc;

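/// A subscription served straight from a `tokio::sync::broadcast` channel.
///
/// `emit_terminal_event` holds a terminal event captured at subscription
/// time; `recv` yields it before polling the channel.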
#[derive(Debug)]
pub(super) struct FastSubscription {
    pub(super) receiver: broadcast::Receiver<StreamEvent>,
    pub(super) emit_terminal_event: Option<StreamEvent>,
}

impl FastSubscription {
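    /// Returns the next event. A lagged receiver is surfaced to the
    /// subscriber as a single `StreamEvent::Gap`; a closed channel ends
    /// the subscription with `None`.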
    pub(super) async fn recv(&mut self) -> Option<StreamEvent> {
        if let Some(event) = self.emit_terminal_event.take() {
            return Some(event);
        }

        match self.receiver.recv().await {
            Ok(event) => Some(event),
            Err(RecvError::Closed) => None,
            Err(RecvError::Lagged(lagged)) => {
                tracing::warn!(lagged, "Broadcast subscriber is lagging behind");
                Some(StreamEvent::Gap)
            }
        }
    }
}

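/// The live half of a shared subscription: a bounded mpsc channel for
/// reliable delivery, a lossy queue for best-effort delivery, or `Closed`
/// when no live events will arrive.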
#[derive(Debug)]
pub(super) enum LiveReceiver {
    Reliable(mpsc::Receiver<IndexedEvent>),
    BestEffort(Arc<BestEffortLiveQueue>),
    Closed,
}

impl LiveReceiver {
    async fn recv(&mut self) -> Option<IndexedEvent> {
        match self {
            Self::Reliable(receiver) => receiver.recv().await,
            Self::BestEffort(queue) => queue.recv().await,
            Self::Closed => None,
        }
    }
}

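/// A subscription registered in the shared per-stream state.
///
/// Events snapshotted at subscription time are drained from `replay`
/// first; `live_start_seq` marks where the live receiver takes over, so
/// no event is delivered twice across the handoff.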
#[derive(Debug)]
pub(super) struct SharedSubscription<D = BestEffortDelivery, R = NoReplay>
where
    D: Delivery,
    R: Replay,
{
    pub(super) shared: Arc<Shared>,
    pub(super) id: Option<SubscriberId>,
    pub(super) replay: VecDeque<IndexedEvent>,
    pub(super) live_start_seq: u64,
    pub(super) live_receiver: LiveReceiver,
    pub(super) _marker: PhantomData<fn() -> (D, R)>,
    pub(super) done: bool,
}

impl<D, R> Drop for SharedSubscription<D, R>
where
    D: Delivery,
    R: Replay,
{
    fn drop(&mut self) {
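        // A terminal event already detached this subscription when `done`
        // is set; otherwise unregister so the stream stops buffering
        // events for a receiver that will never drain them.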
        if !self.done
            && let Some(id) = self.id.take()
        {
            let mut state = self.shared.state.lock().expect("broadcast state poisoned");
            state.remove_subscriber(id);
        }
    }
}

impl<D, R> SharedSubscription<D, R>
where
    D: Delivery,
    R: Replay,
{
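    /// Returns the next event, draining the replay snapshot before the
    /// live receiver and skipping live events the snapshot already
    /// covered. Terminal events (`Eof`, `ReadError`) detach the
    /// subscription from the shared state.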
    pub(super) async fn recv(&mut self) -> Option<StreamEvent> {
        if let Some(event) = self.replay.pop_front() {
            if matches!(event.event, StreamEvent::Eof | StreamEvent::ReadError(_)) {
                self.detach();
            }
            return Some(event.event);
        }

        loop {
            let event = self.live_receiver.recv().await?;
            if event.seq < self.live_start_seq {
                continue;
            }
            if matches!(event.event, StreamEvent::Eof | StreamEvent::ReadError(_)) {
                self.detach();
            }
            return Some(event.event);
        }
    }

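    /// Unregisters this subscriber from the shared state and marks the
    /// subscription done so `Drop` does not unregister it a second time.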
    fn detach(&mut self) {
        if let Some(id) = self.id.take() {
            let mut state = self.shared.state.lock().expect("broadcast state poisoned");
            state.remove_subscriber(id);
        }
        self.done = true;
    }
}

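/// A subscriber over either transport: the broadcast fast path or the
/// shared, replay-capable path.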
#[derive(Debug)]
pub(super) enum BroadcastSubscription<D = BestEffortDelivery, R = NoReplay>
where
    D: Delivery,
    R: Replay,
{
    Fast(FastSubscription),
    Shared(SharedSubscription<D, R>),
}

impl<D, R> BroadcastSubscription<D, R>
where
    D: Delivery,
    R: Replay,
{
    pub(super) async fn recv(&mut self) -> Option<StreamEvent> {
        match self {
            BroadcastSubscription::Fast(subscription) => subscription.recv().await,
            BroadcastSubscription::Shared(subscription) => subscription.recv().await,
        }
    }
}

impl<D, R> Subscription for BroadcastSubscription<D, R>
where
    D: Delivery,
    R: Replay,
{
    #[allow(
        clippy::manual_async_fn,
        reason = "the trait method must expose a Send future for tokio::spawn"
    )]
    fn next_event(&mut self) -> impl Future<Output = Option<StreamEvent>> + Send + '_ {
        async move { self.recv().await }
    }
}

#[cfg(test)]
mod tests {
    use super::super::state::{SubscriberSender, append_event};
    use super::*;
    use crate::StreamReadError;
    use crate::{NumBytesExt, ReliableDelivery, ReplayEnabled, ReplayRetention, StreamConfig};
    use assertr::prelude::*;
    use std::io;

    fn best_effort_options(
        retention: ReplayRetention,
    ) -> StreamConfig<BestEffortDelivery, ReplayEnabled> {
        let builder = StreamConfig::builder().best_effort_delivery();
        match retention {
            ReplayRetention::LastChunks(chunks) => builder.replay_last_chunks(chunks),
            ReplayRetention::LastBytes(bytes) => builder.replay_last_bytes(bytes),
            ReplayRetention::All => builder.replay_all(),
        }
        .read_chunk_size(3.bytes())
        .max_buffered_chunks(1)
        .build()
    }

    fn reliable_no_replay_options() -> StreamConfig<ReliableDelivery, NoReplay> {
        StreamConfig::builder()
            .reliable_for_active_subscribers()
            .no_replay()
            .read_chunk_size(1.bytes())
            .max_buffered_chunks(1)
            .build()
    }

    fn reliable_replay_options(
        retention: ReplayRetention,
    ) -> StreamConfig<ReliableDelivery, ReplayEnabled> {
        let builder = StreamConfig::builder().reliable_for_active_subscribers();
        match retention {
            ReplayRetention::LastChunks(chunks) => builder.replay_last_chunks(chunks),
            ReplayRetention::LastBytes(bytes) => builder.replay_last_bytes(bytes),
            ReplayRetention::All => builder.replay_all(),
        }
        .read_chunk_size(1.bytes())
        .max_buffered_chunks(4)
        .build()
    }

    fn subscribe<D, R>(
        shared: &Arc<Shared>,
        options: StreamConfig<D, R>,
    ) -> SharedSubscription<D, R>
    where
        D: Delivery,
        R: Replay,
    {
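        // Mirror the production wiring: reliable delivery gets a bounded
        // mpsc channel, best-effort delivery a shared lossy queue. A
        // stream that is already closed or sealed by a terminal event
        // registers no live subscriber; only the replay snapshot applies.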
        let (sender, live_receiver) = match options.delivery_guarantee() {
            crate::DeliveryGuarantee::ReliableForActiveSubscribers => {
                let (sender, receiver) = mpsc::channel(options.max_buffered_chunks);
                (
                    SubscriberSender::Reliable(sender),
                    LiveReceiver::Reliable(receiver),
                )
            }
            crate::DeliveryGuarantee::BestEffort => {
                let queue = Arc::new(BestEffortLiveQueue::new(options.max_buffered_chunks));
                (
                    SubscriberSender::BestEffort(Arc::clone(&queue)),
                    LiveReceiver::BestEffort(queue),
                )
            }
        };

        let mut state = shared.state.lock().expect("broadcast state poisoned");
        let (replay, live_start_seq) = state.replay_snapshot(options);
        let id = if state.closed || state.terminal.is_some() {
            None
        } else {
            Some(state.add_subscriber(sender))
        };
        drop(state);

        SharedSubscription {
            shared: Arc::clone(shared),
            id,
            replay,
            live_start_seq,
            live_receiver,
            _marker: PhantomData,
            done: false,
        }
    }

    async fn assert_next_chunk<D, R>(
        subscription: &mut SharedSubscription<D, R>,
        expected: &'static [u8],
    ) where
        D: Delivery,
        R: Replay,
    {
        match subscription.recv().await {
            Some(StreamEvent::Chunk(chunk)) => {
                assert_that!(chunk.as_ref()).is_equal_to(expected);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected chunk, got {other:?}"));
            }
        }
    }

    #[tokio::test]
    async fn slow_best_effort_subscriber_observes_gap_then_eof() {
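        // The live queue holds a single chunk, so "new" evicts "old" and
        // the terminal Eof evicts "new": the subscriber observes a gap
        // followed directly by Eof.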
        let options = best_effort_options(ReplayRetention::LastChunks(1));
        let shared = Arc::new(Shared::new());
        let mut subscription = subscribe(&shared, options);

        append_event(&shared, options, StreamEvent::chunk(b"old")).await;
        append_event(&shared, options, StreamEvent::chunk(b"new")).await;
        append_event(&shared, options, StreamEvent::Eof).await;

        assert_that!(subscription.recv().await)
            .is_some()
            .is_equal_to(StreamEvent::Gap);
        assert_that!(subscription.recv().await)
            .is_some()
            .is_equal_to(StreamEvent::Eof);
    }

    #[tokio::test]
    async fn eof_is_replayed_to_late_subscribers_before_seal() {
        let options = reliable_replay_options(ReplayRetention::All);
        let shared = Arc::new(Shared::new());

        append_event(&shared, options, StreamEvent::chunk(b"tail")).await;
        append_event(&shared, options, StreamEvent::Eof).await;

        let mut subscription = subscribe(&shared, options);
        assert_next_chunk(&mut subscription, b"tail").await;
        assert_that!(subscription.recv().await)
            .is_some()
            .is_equal_to(StreamEvent::Eof);
    }

    #[tokio::test]
    async fn no_replay_late_subscriber_observes_terminal_read_error() {
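        // Without replay, a late subscriber receives none of the buffered
        // output; only the sticky terminal event is delivered.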
        let options = reliable_no_replay_options();
        let shared = Arc::new(Shared::new());

        append_event(&shared, options, StreamEvent::chunk(b"booting\n")).await;
        append_event(
            &shared,
            options,
            StreamEvent::ReadError(StreamReadError::new(
                "custom",
                io::Error::from(io::ErrorKind::BrokenPipe),
            )),
        )
        .await;

        let mut subscription = subscribe(&shared, options);
        match subscription.recv().await {
            Some(StreamEvent::ReadError(err)) => {
                assert_that!(err.stream_name()).is_equal_to("custom");
                assert_that!(err.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected read error, got {other:?}"));
            }
        }
    }

    #[tokio::test]
    async fn replay_late_subscriber_observes_retained_output_then_read_error() {
        let options = reliable_replay_options(ReplayRetention::All);
        let shared = Arc::new(Shared::new());

        append_event(&shared, options, StreamEvent::chunk(b"booting\npartial")).await;
        append_event(
            &shared,
            options,
            StreamEvent::ReadError(StreamReadError::new(
                "custom",
                io::Error::from(io::ErrorKind::BrokenPipe),
            )),
        )
        .await;

        let mut subscription = subscribe(&shared, options);
        assert_next_chunk(&mut subscription, b"booting\npartial").await;
        match subscription.recv().await {
            Some(StreamEvent::ReadError(err)) => {
                assert_that!(err.stream_name()).is_equal_to("custom");
                assert_that!(err.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected read error, got {other:?}"));
            }
        }
    }

    #[tokio::test]
    async fn active_subscription_does_not_duplicate_live_handoff() {
        let options = reliable_replay_options(ReplayRetention::All);
        let shared = Arc::new(Shared::new());

        append_event(&shared, options, StreamEvent::chunk(b"replay")).await;
        let mut subscription = subscribe(&shared, options);
        append_event(&shared, options, StreamEvent::chunk(b"live")).await;
        append_event(&shared, options, StreamEvent::Eof).await;

        assert_next_chunk(&mut subscription, b"replay").await;
        assert_next_chunk(&mut subscription, b"live").await;
        assert_that!(subscription.recv().await)
            .is_some()
            .is_equal_to(StreamEvent::Eof);
    }
}