use std::{
io::{self, Write as _},
mem, thread,
};
use bytes::{BufMut as _, BytesMut};
/// Where a [`ChannelWriter`] obtains empty buffers from.
pub(crate) enum Source {
    /// Allocate a fresh buffer with this capacity (in bytes) every time one is needed.
    New(usize),
    /// Reuse buffers received on this channel; each one is cleared before reuse.
    Channel(tokio::sync::mpsc::Receiver<BytesMut>),
}
/// Error types used by the channel-backed readers and writers in this file.
pub mod error {
    /// Converts any `Debug` value into an [`std::io::Error`] of kind
    /// [`std::io::ErrorKind::Other`] whose message is the value's `Debug` output.
    pub(super) fn to_io_error<E: std::fmt::Debug>(e: E) -> std::io::Error {
        std::io::Error::other(format!("{e:?}"))
    }

    /// A channel used to exchange buffers was closed.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ChannelClosedError {
        /// The channel closed while a write was exchanging buffers.
        Write,
        /// The channel closed while a flush was exchanging buffers.
        Flush,
        /// The channel closed without a more specific context.
        Unknown,
    }

    impl std::fmt::Display for ChannelClosedError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::Write => write!(f, "channel closed while sending (write)"),
                Self::Flush => write!(f, "channel closed while sending (flush)"),
                Self::Unknown => write!(f, "channel closed"),
            }
        }
    }

    impl std::error::Error for ChannelClosedError {}

    impl From<ChannelClosedError> for std::io::Error {
        /// Wraps the error itself (kind `Other`), so its message is unchanged but callers
        /// can still recover the typed value via `get_ref`/`into_inner` and a downcast.
        fn from(value: ChannelClosedError) -> Self {
            Self::other(value)
        }
    }

    /// Obtaining a fresh buffer failed because the buffer source channel is closed.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct BufferExchangeError;

    impl std::fmt::Display for BufferExchangeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "buffer exchange failed")
        }
    }

    impl std::error::Error for BufferExchangeError {}

    impl From<BufferExchangeError> for std::io::Error {
        /// Wraps the error itself (kind `Other`), preserving it for downcasting.
        fn from(value: BufferExchangeError) -> Self {
            Self::other(value)
        }
    }
}
impl Source {
fn get(&mut self) -> Result<BytesMut, error::ChannelClosedError> {
match self {
Source::Channel(receiver) => {
let mut buffer = receiver
.blocking_recv()
.ok_or(error::ChannelClosedError::Unknown)?;
buffer.clear();
Ok(buffer)
}
Source::New(size) => Ok(BytesMut::with_capacity(*size)),
}
}
}
pub(crate) struct ChannelWriter {
source: Source,
sink: tokio::sync::mpsc::Sender<BytesMut>,
buffer: BytesMut,
}
impl ChannelWriter {
pub(crate) fn new(
mut source: Source,
sink: tokio::sync::mpsc::Sender<BytesMut>,
) -> Result<Self, error::ChannelClosedError> {
let buffer = source.get()?;
Ok(Self {
source,
sink,
buffer,
})
}
}
impl io::Write for ChannelWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let remaining_capacity = self.buffer.capacity() - self.buffer.len();
let bytes_to_copy = std::cmp::min(buf.len(), remaining_capacity);
self.buffer.put_slice(&buf[..bytes_to_copy]);
if remaining_capacity - bytes_to_copy == 0 {
self.sink
.blocking_send(mem::replace(
&mut self.buffer,
self.source
.get()
.map_err(|_| error::ChannelClosedError::Write)?,
))
.map_err(|_| error::ChannelClosedError::Write)?;
}
Ok(bytes_to_copy)
}
fn flush(&mut self) -> io::Result<()> {
if !self.buffer.is_empty() {
self.sink
.blocking_send(mem::replace(
&mut self.buffer,
self.source
.get()
.map_err(|_| error::ChannelClosedError::Flush)?,
))
.map_err(|_| error::ChannelClosedError::Flush)?;
}
Ok(())
}
}
/// Conversion from a buffer of teed bytes into the message type a [`Tee`] sends.
pub(crate) trait Message {
    /// Wraps `buffer` (possibly empty, e.g. at end of input) into a message.
    fn from_bytes(buffer: BytesMut) -> Self;
}
/// Adapter that passes data through to `inner` while also copying every byte it moves
/// into buffers, which are sent as `M` messages on `sink`.
///
/// Empty buffers are received from `source`.
pub(crate) struct Tee<'a, T, M> {
    inner: T,
    source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
    sink: &'a tokio::sync::mpsc::Sender<M>,
    buffer: BytesMut,
}

impl<'a, R, M: Message> Tee<'a, R, M> {
    /// Wraps `reader`, taking the first buffer from `source`.
    ///
    /// # Errors
    ///
    /// Returns [`error::BufferExchangeError`] if `source` yields no buffer.
    pub(crate) fn new(
        reader: R,
        source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
        sink: &'a tokio::sync::mpsc::Sender<M>,
    ) -> Result<Self, error::BufferExchangeError> {
        let buffer = Self::get_new_buffer(source)?;
        Ok(Self {
            inner: reader,
            source,
            sink,
            buffer,
        })
    }

    /// Receives a buffer from `source` and clears it for reuse.
    #[inline]
    fn get_new_buffer(
        source: &mut tokio::sync::mpsc::Receiver<BytesMut>,
    ) -> Result<BytesMut, error::BufferExchangeError> {
        let mut buffer = source.blocking_recv().ok_or(error::BufferExchangeError)?;
        buffer.clear();
        Ok(buffer)
    }

    /// Sends the partially filled buffer, if any, to `sink` and installs a fresh one.
    ///
    /// The filled buffer is sent before a replacement is requested. If `source` is
    /// refilled by the consumer (presumably — not visible here), requesting first could
    /// block forever on the very buffer that is still waiting to be sent.
    ///
    /// # Errors
    ///
    /// Fails if `sink` is closed or `source` yields no buffer.
    pub(crate) fn flush_channel(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let filled = mem::take(&mut self.buffer);
        self.sink
            .blocking_send(M::from_bytes(filled))
            .map_err(error::to_io_error)?;
        self.buffer = Self::get_new_buffer(self.source)?;
        Ok(())
    }
}
/// Copies the first `n_bytes` of `buffer` into the tee's buffers, sending each buffer
/// to the tee's sink as soon as it is full.
///
/// A full buffer is sent before a replacement is requested from the tee's source, so a
/// source pool refilled by the consumer (presumably — not visible here) cannot deadlock.
/// Empty buffers (e.g. the zero-capacity placeholder left after end of input) are
/// replaced without being sent.
///
/// # Panics
///
/// Panics if `n_bytes > buffer.len()`.
fn send_to_channel<'a, R, M>(
    tee: &mut Tee<'a, R, M>,
    n_bytes: usize,
    buffer: &[u8],
) -> io::Result<()>
where
    M: Message,
{
    let mut pending = &buffer[..n_bytes];
    while !pending.is_empty() {
        let remaining_capacity = tee.buffer.capacity() - tee.buffer.len();
        let (head, tail) = pending.split_at(pending.len().min(remaining_capacity));
        tee.buffer.put_slice(head);
        if tee.buffer.len() == tee.buffer.capacity() {
            if !tee.buffer.is_empty() {
                let filled = mem::take(&mut tee.buffer);
                tee.sink
                    .blocking_send(M::from_bytes(filled))
                    .map_err(error::to_io_error)?;
            }
            tee.buffer = Tee::<'a, R, M>::get_new_buffer(tee.source)?;
        }
        pending = tail;
    }
    Ok(())
}
impl<R: io::Read, M: Message> io::Read for Tee<'_, R, M> {
    /// Reads from the wrapped reader and tees the bytes that were read.
    ///
    /// At end of input, the remaining buffered bytes (possibly none) are sent as a final
    /// message. A zero-length `buf` also makes the inner reader return 0, but that is not
    /// end of input (see `io::Read::read`), so nothing is sent in that case.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.inner.read(buf)?;
        if n > 0 {
            send_to_channel(self, n, buf)?;
        } else if !buf.is_empty() {
            let tail = mem::take(&mut self.buffer);
            self.sink
                .blocking_send(M::from_bytes(tail))
                .map_err(error::to_io_error)?;
        }
        Ok(n)
    }
}
impl<W: io::Write, M: Message> io::Write for Tee<'_, W, M> {
    /// Writes to the wrapped writer first, then tees exactly the bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = self.inner.write(buf)?;
        send_to_channel(self, accepted, buf).map(|()| accepted)
    }

    /// Sends any partially filled buffer to the channel, then flushes the wrapped writer.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_channel().and_then(|()| self.inner.flush())
    }
}
/// Messages the foreground [`FgWriter`] sends to the background [`BgWriter`].
enum ParallelWriterMessage {
    /// Bytes to write; the drained buffer is sent back to the foreground for reuse.
    Payload(BytesMut),
    /// Flush the underlying writer.
    Flush,
    /// Leave the background listening loop.
    Finalize,
}
pub(super) struct FgWriter {
sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
buffer: BytesMut,
}
impl FgWriter {
fn new(
sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
) -> io::Result<Self> {
let buffer = receiver.recv().map_err(error::to_io_error)??;
Ok(Self {
sender,
receiver,
buffer,
})
}
fn exchange_buffer(&mut self) -> io::Result<()> {
let buffer = std::mem::replace(
&mut self.buffer,
self.receiver.recv().map_err(error::to_io_error)??,
);
self.buffer.clear();
self.sender
.send(ParallelWriterMessage::Payload(buffer))
.map_err(error::to_io_error)
}
fn finalize(&mut self) -> io::Result<()> {
self.exchange_buffer()?;
self.sender
.send(ParallelWriterMessage::Finalize)
.map_err(error::to_io_error)
}
}
impl io::Write for FgWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let remaining_capacity = self.buffer.capacity() - self.buffer.len();
let bytes_to_copy = std::cmp::min(buf.len(), remaining_capacity);
self.buffer.put_slice(&buf[..bytes_to_copy]);
if remaining_capacity - bytes_to_copy == 0 {
self.exchange_buffer()?;
}
Ok(bytes_to_copy)
}
fn flush(&mut self) -> io::Result<()> {
self.exchange_buffer()?;
self.sender
.send(ParallelWriterMessage::Flush)
.map_err(error::to_io_error)
}
}
struct BgWriter {
sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
}
impl BgWriter {
fn new(
sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
) -> Self {
Self { sender, receiver }
}
fn listen<W: io::Write>(&mut self, writer: &mut W) -> io::Result<()> {
loop {
let msg = self.receiver.recv().map_err(error::to_io_error)?;
match msg {
ParallelWriterMessage::Payload(buffer) => {
if buffer.is_empty() {
self.sender.send(Ok(buffer)).map_err(error::to_io_error)?;
continue;
}
if let Err(e) = writer.write_all(&buffer) {
self.sender
.send(Err(error::to_io_error("error occurred while writing")))
.map_err(error::to_io_error)?;
return Err(e);
}
self.sender.send(Ok(buffer)).map_err(error::to_io_error)?;
}
ParallelWriterMessage::Flush => {
if let Err(e) = writer.flush() {
self.sender
.send(Err(error::to_io_error("error occurred while flushing")))
.map_err(error::to_io_error)?;
return Err(e);
}
}
ParallelWriterMessage::Finalize => break,
}
}
Ok(())
}
}
pub(super) fn write_parallel<W, F, O, E>(writer: &mut W, f: F) -> Result<O, E>
where
W: io::Write + Send,
E: From<io::Error>,
F: FnOnce(&mut FgWriter) -> Result<O, E>,
{
const BUFFER_SIZE: usize = 1 << 22;
const QUEUE_SIZE: usize = 3;
let (sender, receiver) = std::sync::mpsc::channel();
let (sender_back, receiver_back) = std::sync::mpsc::channel();
for _ in 0..QUEUE_SIZE {
sender_back
.send(Ok(BytesMut::with_capacity(BUFFER_SIZE)))
.map_err(error::to_io_error)?;
}
let mut bg_writer = BgWriter::new(sender_back, receiver);
let mut fg_writer = FgWriter::new(sender, receiver_back)?;
thread::scope(move |s| {
let handle = s.spawn(move || bg_writer.listen(writer));
let output = f(&mut fg_writer);
fg_writer.flush()?;
fg_writer.finalize()?;
handle.join().map_err(error::to_io_error)??;
output
})
}
#[cfg(test)]
mod tests {
    use std::io::Write as _;

    /// Each `write_parallel` call appends one byte, so the output matches the input in order.
    #[test]
    fn write_to_buffer() {
        let text: &[u8] = b"We want a shrubbery!";
        let mut output: Vec<u8> = Vec::new();
        for &byte in text {
            super::write_parallel(&mut output, |w| -> Result<(), std::io::Error> {
                assert_eq!(w.write(&[byte]).unwrap(), 1);
                Ok(())
            })
            .unwrap();
        }
        assert_eq!(output, text);
    }
}