proc-heim 0.1.5

Library for running and managing short-lived and long-lived processes using an asynchronous API
Documentation
use bytes::BytesMut;
use tokio::{
    io::AsyncReadExt as _,
    net::unix::pipe::Receiver,
    select,
    sync::{
        broadcast::{self},
        mpsc::{self},
        oneshot,
    },
};

use super::BufferCapacity;

/// Request sent from a [`MessageReaderHandle`] to the running reader task.
enum MessageReaderCommand {
    /// Ask the reader task to stop its event loop.
    Abort,
    /// Ask for a broadcast receiver of the messages read from the pipe.
    Subscribe {
        /// One-shot channel on which the reader task delivers the new receiver.
        responder: oneshot::Sender<broadcast::Receiver<Vec<u8>>>,
    },
}

pub struct MessageReader {
    pipe_reader: Receiver,
    subscription_receiver: mpsc::Receiver<MessageReaderCommand>,
    message_broadcaster: broadcast::Sender<Vec<u8>>,
    message_receiver: Option<broadcast::Receiver<Vec<u8>>>,
    abort: bool,
}

impl MessageReader {
    pub fn spawn(pipe_reader: Receiver, capacity: BufferCapacity) -> MessageReaderHandle {
        let (mut reader, sender) = Self::create(pipe_reader, capacity);
        tokio::spawn(async move { reader.run().await });
        MessageReaderHandle::new(sender)
    }

    fn create(
        pipe_reader: Receiver,
        capacity: BufferCapacity,
    ) -> (Self, mpsc::Sender<MessageReaderCommand>) {
        let (sender, subscription_receiver) = mpsc::channel(32);
        let (message_broadcaster, message_receiver) = broadcast::channel(capacity.inner);
        let reader = MessageReader {
            pipe_reader,
            subscription_receiver,
            message_broadcaster,
            message_receiver: Some(message_receiver),
            abort: false,
        };
        (reader, sender)
    }

    async fn run(&mut self) {
        loop {
            select! {
                Some(msg) = self.subscription_receiver.recv() => {
                    self.handle_message(msg).await;
                },
                _ = self.pipe_reader.readable() => {
                    self.read_message().await;
                }
            }
            if self.abort {
                break;
            }
        }
    }

    async fn handle_message(&mut self, msg: MessageReaderCommand) {
        match msg {
            MessageReaderCommand::Subscribe { responder } => {
                let receiver = if self.message_receiver.is_some() {
                    self.message_receiver.take().unwrap()
                } else {
                    self.message_broadcaster.subscribe()
                };
                let _ = responder.send(receiver);
            }
            MessageReaderCommand::Abort => {
                self.abort = true;
            }
        }
    }

    async fn read_message(&mut self) {
        let mut buf = BytesMut::with_capacity(4096);
        if self.pipe_reader.read_buf(&mut buf).await.is_ok() {
            buf.split_inclusive(|byte| *byte == b'\n')
                .map(|msg| msg.strip_suffix(b"\n").unwrap_or(msg))
                .for_each(|msg| {
                    let _ = self.message_broadcaster.send(msg.into());
                });
        }
    }
}

/// Handle used to talk to a running [`MessageReader`] task through its command channel.
#[derive(Debug)]
pub struct MessageReaderHandle {
    /// Sending side of the reader task's command channel.
    sender: mpsc::Sender<MessageReaderCommand>,
}

impl MessageReaderHandle {
    /// Wraps the command channel created alongside the reader task.
    fn new(sender: mpsc::Sender<MessageReaderCommand>) -> Self {
        MessageReaderHandle { sender }
    }

    /// Requests a broadcast receiver for the messages read by the task.
    ///
    /// # Errors
    ///
    /// Returns [`oneshot::error::RecvError`] when the reader task has stopped (for
    /// example after [`abort`](Self::abort)), because nobody is left to answer.
    pub async fn subscribe(
        &self,
    ) -> Result<broadcast::Receiver<Vec<u8>>, oneshot::error::RecvError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = MessageReaderCommand::Subscribe {
            responder: reply_tx,
        };
        // A failed send hands the request back and it is dropped here together with
        // `reply_tx`, which makes awaiting `reply_rx` yield `RecvError`.
        let _ = self.sender.send(request).await;
        reply_rx.await
    }

    /// Asks the reader task to stop; a failed send (task already gone) is ignored.
    pub async fn abort(&self) {
        let request = MessageReaderCommand::Abort;
        let _ = self.sender.send(request).await;
    }
}

// Tests that drive a spawned `MessageReader` through real pipes and check what
// subscribers receive.
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use test_utils::TestPipe;
    use tokio::{io::AsyncWriteExt, net::unix::pipe};

    use super::*;

    // Anonymous pipe: two newline-terminated writes must arrive as two messages,
    // in order, with the trailing `\n` stripped.
    #[tokio::test]
    async fn should_read_messages_from_pipe() {
        let (mut sender, receiver) = pipe::pipe().unwrap();
        // NOTE(review): `8` is presumably the broadcast buffer capacity (via `BufferCapacity`).
        let reader = MessageReader::spawn(receiver, 8.try_into().unwrap());

        // The writer task is spawned before subscribing; the reader hands its first
        // subscriber the receiver created with the channel, so early messages are kept.
        let writer_handle = tokio::spawn(async move {
            sender.write_all(b"Message 1\n").await.unwrap();
            sender.write_all(b"Message 2\n").await.unwrap();
        });

        let mut receiver = reader.subscribe().await.unwrap();

        let msg = receiver.recv().await.unwrap();
        assert_eq!(b"Message 1", &msg[..]);

        let msg = receiver.recv().await.unwrap();
        assert_eq!(b"Message 2", &msg[..]);

        // Once the writer is done, no further message should be queued.
        writer_handle.await.unwrap();
        assert!(receiver.try_recv().is_err());
    }

    // Same scenario as above, but through `TestPipe`.
    // NOTE(review): `TestPipe` presumably wraps a named pipe (FIFO); confirm in `test_utils`.
    #[tokio::test]
    async fn should_read_messages_from_named_pipe() {
        let pipe = TestPipe::new();
        let reader = MessageReader::spawn(pipe.reader(), 8.try_into().unwrap());

        let mut writer = pipe.writer();
        let writer_handle = tokio::spawn(async move {
            writer.write_all(b"Message 1\n").await.unwrap();
            writer.write_all(b"Message 2\n").await.unwrap();
        });

        let mut receiver = reader.subscribe().await.unwrap();

        let msg = receiver.recv().await.unwrap();
        assert_eq!(b"Message 1", &msg[..]);

        let msg = receiver.recv().await.unwrap();
        assert_eq!(b"Message 2", &msg[..]);

        writer_handle.await.unwrap();
        assert!(receiver.try_recv().is_err());
    }

    // A second subscriber only sees messages broadcast after it subscribed, while the
    // first subscriber keeps receiving everything.
    #[tokio::test]
    async fn should_subscribe_multiple_times() {
        let pipe = TestPipe::new();
        let reader = MessageReader::spawn(pipe.reader(), 8.try_into().unwrap());

        let mut writer = pipe.writer();
        let writer_handle = tokio::spawn(async move {
            writer.write_all(b"Message 1\n").await.unwrap();
            // The delay presumably gives `receiver2` time to subscribe before
            // "Message 2" is written, so it does not see "Message 1".
            tokio::time::sleep(Duration::from_secs(1)).await;
            writer.write_all(b"Message 2\n").await.unwrap();
        });

        let mut receiver = reader.subscribe().await.unwrap();
        let msg = receiver.recv().await.unwrap();
        assert_eq!(b"Message 1", &msg[..]);

        let mut receiver2 = reader.subscribe().await.unwrap();
        let msg = receiver2.recv().await.unwrap();
        assert_eq!(b"Message 2", &msg[..]);

        let msg = receiver.recv().await.unwrap();
        assert_eq!(b"Message 2", &msg[..]);

        writer_handle.await.unwrap();
    }

    // After `abort`, the reader task leaves its loop, so a later subscription request
    // gets no reply and `subscribe` returns an error.
    #[tokio::test]
    async fn should_abort_reader_process() {
        let pipe = TestPipe::new();
        let reader = MessageReader::spawn(pipe.reader(), 8.try_into().unwrap());
        reader.abort().await;
        assert!(reader.subscribe().await.is_err());
    }
}