// sett 0.4.0
//
// Rust port of sett (data compression, encryption and transfer tool).
// Documentation
//!  Low-level I/O utilities.

use std::{
    io::{self, Write as _},
    mem, thread,
};

use bytes::{BufMut as _, BytesMut};

/// Origin of the buffers handed out by `Source::get`.
pub(crate) enum Source {
    /// Buffers are received from a channel; each one is cleared before it is handed out.
    Channel(tokio::sync::mpsc::Receiver<BytesMut>),
    /// A new buffer with the given capacity (in bytes) is allocated on every request.
    New(usize),
}

/// Error namespace
pub mod error {
    /// Converts any debuggable value into an [`std::io::Error`] of kind `Other`.
    ///
    /// The message is the `Debug` representation of `e`, so a string argument
    /// ends up surrounded by quotes.
    pub(super) fn to_io_error<E: std::fmt::Debug>(e: E) -> std::io::Error {
        std::io::Error::other(format!("{e:?}"))
    }

    /// Error occurring when a channel closed unexpectedly.
    #[derive(Debug)]
    pub enum ChannelClosedError {
        /// While writing
        Write,
        /// While flushing
        Flush,
        /// No context
        Unknown,
    }

    impl std::fmt::Display for ChannelClosedError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(match self {
                Self::Write => "channel closed while sending (write)",
                Self::Flush => "channel closed while sending (flush)",
                Self::Unknown => "channel closed",
            })
        }
    }
    impl std::error::Error for ChannelClosedError {}

    impl From<ChannelClosedError> for std::io::Error {
        /// Wraps the error itself (not just its message) so that callers can
        /// still recover the typed error via `get_ref`/`downcast`.
        fn from(value: ChannelClosedError) -> Self {
            Self::other(value)
        }
    }

    /// Error occurring during a buffer exchange.
    #[derive(Debug)]
    pub struct BufferExchangeError;

    impl std::fmt::Display for BufferExchangeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("buffer exchange failed")
        }
    }
    impl std::error::Error for BufferExchangeError {}

    impl From<BufferExchangeError> for std::io::Error {
        /// Wraps the error itself (not just its message) so that callers can
        /// still recover the typed error via `get_ref`/`downcast`.
        fn from(value: BufferExchangeError) -> Self {
            Self::other(value)
        }
    }
}

impl Source {
    fn get(&mut self) -> Result<BytesMut, error::ChannelClosedError> {
        match self {
            Source::Channel(receiver) => {
                let mut buffer = receiver
                    .blocking_recv()
                    .ok_or(error::ChannelClosedError::Unknown)?;
                buffer.clear();
                Ok(buffer)
            }
            Source::New(size) => Ok(BytesMut::with_capacity(*size)),
        }
    }
}

/// Writer that accumulates bytes in a buffer and sends the buffer to a channel.
pub(crate) struct ChannelWriter {
    /// Provider of the replacement buffer used after the current one has been sent.
    source: Source,
    /// Channel receiving the filled (or flushed) buffers.
    sink: tokio::sync::mpsc::Sender<BytesMut>,
    /// Buffer currently being filled.
    buffer: BytesMut,
}

impl ChannelWriter {
    /// Builds a writer sending to `sink`, taking its first buffer from `source`.
    ///
    /// # Errors
    ///
    /// Fails with the error reported by `source` when no initial buffer can be
    /// obtained.
    pub(crate) fn new(
        mut source: Source,
        sink: tokio::sync::mpsc::Sender<BytesMut>,
    ) -> Result<Self, error::ChannelClosedError> {
        source.get().map(|buffer| Self {
            source,
            sink,
            buffer,
        })
    }
}

impl io::Write for ChannelWriter {
    /// Copies as much of `buf` as fits into the current buffer.
    ///
    /// When the copy leaves no free capacity, the buffer is sent to the sink
    /// and replaced by one obtained from the source. The return value is the
    /// number of bytes taken from `buf`, which can be smaller than `buf.len()`.
    ///
    /// # Errors
    ///
    /// Yields [`error::ChannelClosedError::Write`] (as an `io::Error`) when
    /// either the source or the sink is closed.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // NOTE: `BufMut::remaining_mut` ignores the buffer's capacity, hence the free space is
        // computed by hand.
        let free = self.buffer.capacity() - self.buffer.len();
        let taken = buf.len().min(free);
        self.buffer.put_slice(&buf[..taken]);
        if taken == free {
            // Fetch the replacement first: if that fails nothing is sent.
            let fresh = self
                .source
                .get()
                .map_err(|_| error::ChannelClosedError::Write)?;
            let filled = mem::replace(&mut self.buffer, fresh);
            self.sink
                .blocking_send(filled)
                .map_err(|_| error::ChannelClosedError::Write)?;
        }
        Ok(taken)
    }

    /// Sends the current buffer to the sink unless it is empty.
    ///
    /// # Errors
    ///
    /// Yields [`error::ChannelClosedError::Flush`] (as an `io::Error`) when
    /// either the source or the sink is closed.
    fn flush(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let fresh = self
            .source
            .get()
            .map_err(|_| error::ChannelClosedError::Flush)?;
        let pending = mem::replace(&mut self.buffer, fresh);
        self.sink
            .blocking_send(pending)
            .map_err(|_| error::ChannelClosedError::Flush)?;
        Ok(())
    }
}

/// Message type that can be built from a byte buffer.
///
/// Used as the item type of the channel a [`Tee`] sends its copies to.
pub(crate) trait Message {
    /// Builds a message from `bytes`.
    fn from_bytes(bytes: BytesMut) -> Self;
}

/// Splits the inner reader/writer.
///
/// Passes reads/writes to the inner type and sends a copy of read/written bytes to a channel.
/// Receives empty buffers from a source channel to avoid repetitive allocations.
pub(crate) struct Tee<'a, T, M> {
    /// Wrapped reader or writer that performs the actual I/O.
    inner: T,
    /// Channel delivering the buffers used to collect the copied bytes.
    source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
    /// Channel receiving the collected bytes, wrapped into messages of type `M`.
    sink: &'a tokio::sync::mpsc::Sender<M>,
    /// Buffer currently collecting copied bytes.
    buffer: BytesMut,
}

impl<'a, R, M: Message> Tee<'a, R, M> {
    pub(crate) fn new(
        reader: R,
        source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
        sink: &'a tokio::sync::mpsc::Sender<M>,
    ) -> Result<Self, error::BufferExchangeError> {
        let buffer = Self::get_new_buffer(source)?;
        Ok(Self {
            inner: reader,
            source,
            sink,
            buffer,
        })
    }

    #[inline]
    fn get_new_buffer(
        source: &mut tokio::sync::mpsc::Receiver<BytesMut>,
    ) -> Result<BytesMut, error::BufferExchangeError> {
        let mut buffer = source.blocking_recv().ok_or(error::BufferExchangeError)?;
        buffer.clear();
        Ok(buffer)
    }

    pub(crate) fn flush_channel(&mut self) -> io::Result<()> {
        if !self.buffer.is_empty() {
            self.sink
                .blocking_send(M::from_bytes(mem::replace(
                    &mut self.buffer,
                    Self::get_new_buffer(self.source)?,
                )))
                .map_err(error::to_io_error)?;
        }
        Ok(())
    }
}

/// Copies the first `n_bytes` bytes of `buffer` into the tee's collection buffer.
///
/// The bytes are copied in chunks that fit the free capacity of the collection
/// buffer. Each time the collection buffer reaches its capacity it is sent to
/// the sink and replaced by a buffer from the tee's source channel.
///
/// # Errors
///
/// Fails when no replacement buffer can be obtained or when sending to the
/// sink fails.
fn send_to_channel<'a, R, M>(
    tee: &mut Tee<'a, R, M>,
    n_bytes: usize,
    buffer: &[u8],
) -> io::Result<()>
where
    M: Message,
{
    let mut copied = 0;
    while copied < n_bytes {
        let free = tee.buffer.capacity() - tee.buffer.len();
        let chunk_end = copied + (n_bytes - copied).min(free);
        tee.buffer.put_slice(&buffer[copied..chunk_end]);
        if tee.buffer.capacity() == tee.buffer.len() {
            let fresh = Tee::<'a, R, M>::get_new_buffer(tee.source)?;
            let filled = mem::replace(&mut tee.buffer, fresh);
            tee.sink
                .blocking_send(M::from_bytes(filled))
                .map_err(error::to_io_error)?;
        }
        copied = chunk_end;
    }
    Ok(())
}

impl<R: io::Read, M: Message> io::Read for Tee<'_, R, M> {
    /// Reads from the inner reader and forwards a copy of the bytes read.
    ///
    /// When the inner reader returns `0`, the collection buffer is sent to the
    /// sink as is (even when empty) and replaced by an empty `BytesMut`.
    /// Otherwise the bytes read are handed to `send_to_channel`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.inner.read(buf)?;
        if n != 0 {
            send_to_channel(self, n, buf)?;
            return Ok(n);
        }
        let remainder = mem::take(&mut self.buffer);
        self.sink
            .blocking_send(M::from_bytes(remainder))
            .map_err(error::to_io_error)?;
        Ok(n)
    }
}

impl<W: io::Write, M: Message> io::Write for Tee<'_, W, M> {
    /// Writes to the inner writer and forwards a copy of the bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = self.inner.write(buf)?;
        send_to_channel(self, accepted, buf).map(|()| accepted)
    }

    /// Sends the collected bytes to the sink, then flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_channel().and_then(|()| self.inner.flush())
    }
}

/// Messages sent from an [`FgWriter`] to the `BgWriter` listening loop.
enum ParallelWriterMessage {
    /// Buffer to write to the underlying writer; the buffer is sent back afterwards.
    Payload(BytesMut),
    /// Request to flush the underlying writer.
    Flush,
    /// Request to stop listening.
    Finalize,
}

/// Foreground half of the parallel writer.
///
/// Collects written bytes in a buffer and exchanges buffers with the
/// background half over a pair of channels.
pub(super) struct FgWriter {
    /// Channel carrying payloads and control messages to the background half.
    sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
    /// Channel on which buffers (or a failure notice) come back.
    receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
    /// Buffer currently being filled.
    buffer: BytesMut,
}

impl FgWriter {
    fn new(
        sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
        receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
    ) -> io::Result<Self> {
        let buffer = receiver.recv().map_err(error::to_io_error)??;
        Ok(Self {
            sender,
            receiver,
            buffer,
        })
    }

    fn exchange_buffer(&mut self) -> io::Result<()> {
        let buffer = std::mem::replace(
            &mut self.buffer,
            self.receiver.recv().map_err(error::to_io_error)??,
        );
        self.buffer.clear();
        self.sender
            .send(ParallelWriterMessage::Payload(buffer))
            .map_err(error::to_io_error)
    }

    fn finalize(&mut self) -> io::Result<()> {
        self.exchange_buffer()?;
        self.sender
            .send(ParallelWriterMessage::Finalize)
            .map_err(error::to_io_error)
    }
}

impl io::Write for FgWriter {
    /// Copies as much of `buf` as fits into the current buffer.
    ///
    /// When the copy leaves no free capacity, the buffer is exchanged for a
    /// new one. Returns the number of bytes taken from `buf`, which can be
    /// smaller than `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let free = self.buffer.capacity() - self.buffer.len();
        let taken = buf.len().min(free);
        self.buffer.put_slice(&buf[..taken]);
        if taken == free {
            self.exchange_buffer()?;
        }
        Ok(taken)
    }

    /// Sends the current buffer (even when empty) followed by the `Flush` message.
    fn flush(&mut self) -> io::Result<()> {
        self.exchange_buffer().and_then(|()| {
            self.sender
                .send(ParallelWriterMessage::Flush)
                .map_err(error::to_io_error)
        })
    }
}

/// Background half of the parallel writer.
///
/// Receives messages from an [`FgWriter`] and sends buffers (or a failure
/// notice) back to it.
struct BgWriter {
    /// Channel on which buffers or failure notices are sent back.
    sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
    /// Channel delivering payloads and control messages.
    receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
}

impl BgWriter {
    fn new(
        sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
        receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
    ) -> Self {
        Self { sender, receiver }
    }

    fn listen<W: io::Write>(&mut self, writer: &mut W) -> io::Result<()> {
        loop {
            let msg = self.receiver.recv().map_err(error::to_io_error)?;
            match msg {
                ParallelWriterMessage::Payload(buffer) => {
                    if buffer.is_empty() {
                        self.sender.send(Ok(buffer)).map_err(error::to_io_error)?;
                        continue;
                    }
                    if let Err(e) = writer.write_all(&buffer) {
                        self.sender
                            .send(Err(error::to_io_error("error occurred while writing")))
                            .map_err(error::to_io_error)?;
                        return Err(e);
                    }
                    self.sender.send(Ok(buffer)).map_err(error::to_io_error)?;
                }
                ParallelWriterMessage::Flush => {
                    if let Err(e) = writer.flush() {
                        self.sender
                            .send(Err(error::to_io_error("error occurred while flushing")))
                            .map_err(error::to_io_error)?;
                        return Err(e);
                    }
                }
                ParallelWriterMessage::Finalize => break,
            }
        }
        Ok(())
    }
}

/// Runs `f` against a buffered writer whose output is written to `writer` on a background thread.
///
/// `QUEUE_SIZE` buffers of `BUFFER_SIZE` bytes are allocated up front and
/// exchanged between the foreground writer handed to `f` and the background
/// thread. After `f` returns (successfully or not), the foreground writer is
/// flushed and finalized, and the background thread is joined.
///
/// # Errors
///
/// Errors are reported in order of causality:
///
/// 1. a failure (or panic) of the background thread, which holds the actual
///    I/O error of `writer`,
/// 2. the error returned by `f`,
/// 3. a failure while flushing/finalizing the foreground writer.
///
/// Also fails when the initial buffers cannot be queued or received.
pub(super) fn write_parallel<W, F, O, E>(writer: &mut W, f: F) -> Result<O, E>
where
    W: io::Write + Send,
    E: From<io::Error>,
    F: FnOnce(&mut FgWriter) -> Result<O, E>,
{
    const BUFFER_SIZE: usize = 1 << 22;
    const QUEUE_SIZE: usize = 3;
    let (sender, receiver) = std::sync::mpsc::channel();
    let (sender_back, receiver_back) = std::sync::mpsc::channel();
    for _ in 0..QUEUE_SIZE {
        sender_back
            .send(Ok(BytesMut::with_capacity(BUFFER_SIZE)))
            .map_err(error::to_io_error)?;
    }
    let mut bg_writer = BgWriter::new(sender_back, receiver);
    let mut fg_writer = FgWriter::new(sender, receiver_back)?;
    thread::scope(move |s| {
        let handle = s.spawn(move || bg_writer.listen(writer));
        let output = f(&mut fg_writer);
        // Do not return early here: the outcome is only reported after the
        // background thread has been joined, so that its error is not masked.
        let shutdown = fg_writer.flush().and_then(|()| fg_writer.finalize());
        // Closing the channels guarantees that the background thread stops
        // even when `Finalize` could not be delivered.
        drop(fg_writer);
        handle.join().map_err(error::to_io_error)??;
        let value = output?;
        shutdown?;
        Ok(value)
    })
}

#[cfg(test)]
mod tests {
    use std::io::Write as _;

    /// Writes a text one byte per `write_parallel` call and checks the collected output.
    #[test]
    fn write_to_buffer() {
        let expected = "We want a shrubbery!".as_bytes();
        let mut collected = Vec::new();
        for &byte in expected {
            let outcome = super::write_parallel(&mut collected, |w| -> Result<(), std::io::Error> {
                assert_eq!(w.write(&[byte]).unwrap(), 1);
                Ok(())
            });
            outcome.unwrap();
        }
        assert_eq!(&collected, expected);
    }
}