sett 0.3.0

Rust port of sett (data compression, encryption and transfer tool).
Documentation
//! Low-level I/O utilities.

use std::{
    io::{self, Write as _},
    mem, thread,
};

use anyhow::Context as _;
use bytes::{BufMut as _, BytesMut};

/// Wraps any `Debug`-printable value into an [`io::Error`] of kind `Other`.
///
/// The error message is the value's `{:?}` representation.
fn to_io_error<E>(e: E) -> io::Error
where
    E: std::fmt::Debug,
{
    let message = format!("{:?}", e);
    io::Error::new(io::ErrorKind::Other, message)
}

/// Origin of the empty buffers used by a [`ChannelWriter`].
pub(crate) enum Source {
    /// Buffers are received (blocking) from this channel and cleared before use.
    Channel(tokio::sync::mpsc::Receiver<BytesMut>),
    /// Buffers are freshly allocated with the given capacity.
    New(usize),
}

impl Source {
    /// Produces an empty buffer ready to be filled.
    ///
    /// For `Source::Channel` this blocks until a buffer arrives and clears it; it fails with
    /// "channel closed" if the channel has no more senders. For `Source::New` a new buffer with
    /// the configured capacity is allocated.
    fn get(&mut self) -> anyhow::Result<BytesMut> {
        let buffer = match self {
            Source::New(capacity) => BytesMut::with_capacity(*capacity),
            Source::Channel(rx) => {
                let mut recycled = rx.blocking_recv().context("channel closed")?;
                recycled.clear();
                recycled
            }
        };
        Ok(buffer)
    }
}

/// Writer that accumulates bytes in a buffer and sends the buffer to a channel.
///
/// A buffer is sent once it is filled to capacity, or on flush if it is not empty.
pub(crate) struct ChannelWriter {
    /// Provides the replacement buffer each time the current one is sent.
    source: Source,
    /// Channel receiving the filled buffers.
    sink: tokio::sync::mpsc::Sender<BytesMut>,
    /// Buffer currently being filled.
    buffer: BytesMut,
}

impl ChannelWriter {
    /// Creates a writer, obtaining its first buffer from `source`.
    ///
    /// Fails if `source` cannot provide a buffer.
    pub(crate) fn new(
        mut source: Source,
        sink: tokio::sync::mpsc::Sender<BytesMut>,
    ) -> anyhow::Result<Self> {
        source.get().map(|buffer| Self {
            source,
            sink,
            buffer,
        })
    }
}

impl io::Write for ChannelWriter {
    /// Copies as much of `buf` as fits into the current buffer.
    ///
    /// When the buffer becomes full it is sent to the sink and replaced by a new one from the
    /// source. Returns the number of bytes copied, which may be less than `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // NOTE: `BufMut::remaining_mut` doesn't consider the capacity of the buffer, so the free
        // space is computed by hand.
        let free = self.buffer.capacity() - self.buffer.len();
        let count = buf.len().min(free);
        self.buffer.put_slice(&buf[..count]);
        if count == free {
            // Obtain the replacement first: if that fails the current buffer stays in place.
            let fresh = self.source.get().map_err(to_io_error)?;
            let full = mem::replace(&mut self.buffer, fresh);
            self.sink
                .blocking_send(full)
                .context("channel closed while sending (write)")
                .map_err(to_io_error)?;
        }
        Ok(count)
    }

    /// Sends the current buffer to the sink if it holds any data.
    fn flush(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let fresh = self.source.get().map_err(to_io_error)?;
        let pending = mem::replace(&mut self.buffer, fresh);
        self.sink
            .blocking_send(pending)
            .context("channel closed while sending (flush)")
            .map_err(to_io_error)
    }
}

/// Message type that can be built from a buffer of bytes.
///
/// [`Tee`] wraps every buffer it sends to its sink with `from_bytes`.
pub(crate) trait Message {
    /// Builds a message that takes ownership of `bytes`.
    fn from_bytes(bytes: BytesMut) -> Self;
}

/// Splits the inner reader/writer.
///
/// Passes reads/writes to the inner type and sends a copy of read/written bytes to a channel.
/// Receives empty buffers from a source channel to avoid repetitive allocations.
pub(crate) struct Tee<'a, T, M> {
    /// Wrapped reader or writer that performs the actual I/O.
    inner: T,
    /// Channel from which replacement buffers are received (they are cleared on receipt).
    source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
    /// Channel to which copies of the data are sent, wrapped as `M`.
    sink: &'a tokio::sync::mpsc::Sender<M>,
    /// Buffer currently accumulating the copied bytes.
    buffer: BytesMut,
}

impl<'a, R, M: Message> Tee<'a, R, M> {
    /// Wraps `reader`, taking the first buffer from `source`.
    ///
    /// Fails if no buffer can be received from `source`.
    pub(crate) fn new(
        reader: R,
        source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
        sink: &'a tokio::sync::mpsc::Sender<M>,
    ) -> anyhow::Result<Self> {
        let initial = Self::get_new_buffer(source)?;
        let tee = Self {
            inner: reader,
            source,
            sink,
            buffer: initial,
        };
        Ok(tee)
    }

    /// Blocks until a buffer arrives on `source` and returns it cleared.
    #[inline]
    fn get_new_buffer(
        source: &mut tokio::sync::mpsc::Receiver<BytesMut>,
    ) -> anyhow::Result<BytesMut> {
        source
            .blocking_recv()
            .context("buffer exchange failed")
            .map(|mut recycled| {
                recycled.clear();
                recycled
            })
    }

    /// Sends the buffered bytes to the sink, if there are any, and starts a new buffer.
    pub(crate) fn flush_channel(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        // Obtain the replacement first: if that fails the current buffer stays in place.
        let fresh = Self::get_new_buffer(self.source).map_err(to_io_error)?;
        let pending = mem::replace(&mut self.buffer, fresh);
        self.sink
            .blocking_send(M::from_bytes(pending))
            .map_err(to_io_error)
    }
}

// Copies the first `$n_bytes` bytes of `$buffer` into `$self.buffer`, sending the buffer to
// `$self.sink` (and replacing it with a new one from `$self.source`) every time it fills up.
//
// Must be expanded inside a method of `Tee` returning `io::Result<_>`: the expansion refers to
// `Self::get_new_buffer`, to the message type `M`, and uses `?` to return early on failure.
macro_rules! send_to_channel {
    ($self:expr, $n_bytes:expr, $buffer:expr) => {{
        let mut offset = 0;
        while offset < $n_bytes {
            let free = $self.buffer.capacity() - $self.buffer.len();
            let chunk = std::cmp::min($n_bytes - offset, free);
            $self.buffer.put_slice(&$buffer[offset..offset + chunk]);
            if $self.buffer.len() == $self.buffer.capacity() {
                let fresh = Self::get_new_buffer($self.source).map_err(to_io_error)?;
                let full = mem::replace(&mut $self.buffer, fresh);
                $self
                    .sink
                    .blocking_send(M::from_bytes(full))
                    .map_err(to_io_error)?;
            }
            offset += chunk;
        }
    }};
}

impl<'a, R: io::Read, M: Message> io::Read for Tee<'a, R, M> {
    /// Reads from the inner reader and copies the bytes read to the channel.
    ///
    /// When the inner reader returns 0 bytes, whatever is buffered (possibly nothing) is sent to
    /// the sink and the buffer is replaced by an empty, unallocated one.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.inner.read(buf)?;
        if n != 0 {
            send_to_channel!(self, n, buf);
            return Ok(n);
        }
        let last = mem::replace(&mut self.buffer, BytesMut::new());
        self.sink
            .blocking_send(M::from_bytes(last))
            .map_err(to_io_error)?;
        Ok(n)
    }
}

impl<'a, W: io::Write, M: Message> io::Write for Tee<'a, W, M> {
    /// Writes to the inner writer and copies the bytes it accepted to the channel.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let written = self.inner.write(buf)?;
        send_to_channel!(self, written, buf);
        Ok(written)
    }

    /// Sends any buffered bytes to the channel, then flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_channel().and_then(|()| self.inner.flush())
    }
}

/// Requests sent from [`FgWriter`] to [`BgWriter`].
enum ParallelWriterMessage {
    /// Bytes to write to the underlying writer; the buffer is sent back afterwards.
    Payload(BytesMut),
    /// Flush the underlying writer.
    Flush,
    /// Stop listening for further messages.
    Finalize,
}

/// Foreground half of the parallel writer.
///
/// Accumulates written bytes in a buffer and hands full buffers over to the [`BgWriter`].
pub(super) struct FgWriter {
    /// Sends payloads and control messages to the background writer.
    sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
    /// Receives reusable buffers, or an error, back from the background writer.
    receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
    /// Buffer currently being filled.
    buffer: BytesMut,
}

impl FgWriter {
    fn new(
        sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
        receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
    ) -> io::Result<Self> {
        let buffer = receiver.recv().map_err(to_io_error)??;
        Ok(Self {
            sender,
            receiver,
            buffer,
        })
    }

    fn exchange_buffer(&mut self) -> io::Result<()> {
        let buffer = std::mem::replace(
            &mut self.buffer,
            self.receiver.recv().map_err(to_io_error)??,
        );
        self.buffer.clear();
        self.sender
            .send(ParallelWriterMessage::Payload(buffer))
            .map_err(to_io_error)
    }

    fn finalize(&mut self) -> io::Result<()> {
        self.exchange_buffer()?;
        self.sender
            .send(ParallelWriterMessage::Finalize)
            .map_err(to_io_error)
    }
}

impl io::Write for FgWriter {
    /// Copies as much of `buf` as fits into the current buffer.
    ///
    /// When the buffer becomes full it is exchanged for a new one. Returns the number of bytes
    /// copied, which may be less than `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let free = self.buffer.capacity() - self.buffer.len();
        let accepted = buf.len().min(free);
        self.buffer.put_slice(&buf[..accepted]);
        if accepted == free {
            self.exchange_buffer()?;
        }
        Ok(accepted)
    }

    /// Sends the current buffer (even if empty) followed by a `Flush` message.
    fn flush(&mut self) -> io::Result<()> {
        self.exchange_buffer()?;
        let request = ParallelWriterMessage::Flush;
        self.sender.send(request).map_err(to_io_error)
    }
}

/// Background half of the parallel writer.
///
/// Receives messages from the [`FgWriter`] and applies them to the underlying writer.
struct BgWriter {
    /// Sends processed buffers, or an error, back to the foreground writer.
    sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
    /// Receives payloads and control messages from the foreground writer.
    receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
}

impl BgWriter {
    fn new(
        sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
        receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
    ) -> Self {
        Self { sender, receiver }
    }

    fn listen<W: io::Write>(&mut self, writer: &mut W) -> io::Result<()> {
        loop {
            match self.receiver.recv().map_err(to_io_error)? {
                ParallelWriterMessage::Payload(buffer) => {
                    if buffer.is_empty() {
                        self.sender.send(Ok(buffer)).map_err(to_io_error)?;
                        continue;
                    }
                    if let Err(e) = writer.write_all(&buffer) {
                        self.sender
                            .send(Err(to_io_error("error occurred while writing")))
                            .map_err(to_io_error)?;
                        return Err(e);
                    }
                    self.sender.send(Ok(buffer)).map_err(to_io_error)?;
                }
                ParallelWriterMessage::Flush => {
                    if let Err(e) = writer.flush() {
                        self.sender
                            .send(Err(to_io_error("error occurred while flushing")))
                            .map_err(to_io_error)?;
                        return Err(e);
                    }
                }
                ParallelWriterMessage::Finalize => break,
            }
        }
        Ok(())
    }
}

/// Runs `f` against a buffered writer whose data is written to `writer` on a background thread.
///
/// `QUEUE_SIZE` buffers of `BUFFER_SIZE` bytes each are allocated up front and cycled between
/// the [`FgWriter`] handed to `f` and the background thread performing the actual writes. Once
/// `f` returns, the remaining data is flushed, the background thread is told to stop and is
/// joined.
///
/// # Errors
///
/// In order of precedence: the error returned by the background thread (or a description of
/// its panic), an error raised while flushing/finalizing the foreground writer, the error
/// returned by `f`.
pub(super) fn write_parallel<W, F, O, E>(writer: &mut W, f: F) -> Result<O, E>
where
    W: io::Write + Send,
    E: From<io::Error>,
    F: FnOnce(&mut FgWriter) -> Result<O, E>,
{
    const BUFFER_SIZE: usize = 1 << 22;
    const QUEUE_SIZE: usize = 3;
    let (sender, receiver) = std::sync::mpsc::channel();
    let (sender_back, receiver_back) = std::sync::mpsc::channel();
    for _ in 0..QUEUE_SIZE {
        sender_back
            .send(Ok(BytesMut::with_capacity(BUFFER_SIZE)))
            .map_err(to_io_error)?;
    }
    let mut bg_writer = BgWriter::new(sender_back, receiver);
    let mut fg_writer = FgWriter::new(sender, receiver_back)?;
    thread::scope(move |s| {
        let handle = s.spawn(move || bg_writer.listen(writer));
        let output = f(&mut fg_writer);
        let shutdown = fg_writer.flush().and_then(|()| fg_writer.finalize());
        // Join before propagating a shutdown failure: the foreground side only fails once the
        // background thread has reported an error or dropped its channel ends, so the join
        // cannot block indefinitely, and the background thread holds the real I/O error while
        // the foreground side only sees a generic notice or a closed channel.
        handle.join().map_err(to_io_error)??;
        shutdown?;
        output
    })
}

#[cfg(test)]
mod tests {
    /// Writes a text one byte per `write_parallel` call and checks that everything reaches the
    /// underlying writer in order.
    #[test]
    fn write_to_buffer() {
        use std::io::Write as _;
        let expected = "We want a shrubbery!".as_bytes();
        let mut sink = Vec::new();
        for &byte in expected.iter() {
            let outcome: Result<(), std::io::Error> = super::write_parallel(&mut sink, |w| {
                assert_eq!(w.write(&[byte]).unwrap(), 1);
                Ok(())
            });
            outcome.unwrap();
        }
        assert_eq!(&sink, expected);
    }
}