use std::{
io::{self, Write as _},
mem, thread,
};
use anyhow::Context as _;
use bytes::{BufMut as _, BytesMut};
/// Wraps any `Debug` value in an [`io::Error`] of kind [`io::ErrorKind::Other`].
///
/// The error message is the value's `Debug` rendering, so a string argument
/// appears with surrounding quotes.
fn to_io_error<E: std::fmt::Debug>(err: E) -> io::Error {
    let message = format!("{err:?}");
    io::Error::new(io::ErrorKind::Other, message)
}
/// Supplier of empty buffers for a [`ChannelWriter`].
pub(crate) enum Source {
    /// Reuse buffers received over this channel; each is cleared before being handed out.
    Channel(tokio::sync::mpsc::Receiver<BytesMut>),
    /// Allocate a brand-new buffer with this capacity on every request.
    New(usize),
}
impl Source {
    /// Produces an empty buffer for the caller to fill.
    ///
    /// For [`Source::Channel`] this waits on `blocking_recv`; if no buffer is
    /// received, an error with the context "channel closed" is returned.
    /// [`Source::New`] never fails.
    fn get(&mut self) -> anyhow::Result<BytesMut> {
        let received = match self {
            Self::New(capacity) => return Ok(BytesMut::with_capacity(*capacity)),
            Self::Channel(receiver) => receiver.blocking_recv(),
        };
        let mut recycled = received.context("channel closed")?;
        recycled.clear();
        Ok(recycled)
    }
}
/// `io::Write` adapter that accumulates bytes in `BytesMut` buffers and sends
/// each full (or explicitly flushed) buffer over a tokio channel.
pub(crate) struct ChannelWriter {
    /// Where replacement buffers come from.
    source: Source,
    /// Destination of filled buffers.
    sink: tokio::sync::mpsc::Sender<BytesMut>,
    /// Buffer currently being filled.
    buffer: BytesMut,
}
impl ChannelWriter {
    /// Creates a writer, taking its first buffer from `source`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `source` if no initial buffer can be obtained.
    pub(crate) fn new(
        mut source: Source,
        sink: tokio::sync::mpsc::Sender<BytesMut>,
    ) -> anyhow::Result<Self> {
        source.get().map(|buffer| Self {
            source,
            sink,
            buffer,
        })
    }
}
impl ChannelWriter {
    /// Swaps a fresh buffer from `source` into place and sends the previous one
    /// to `sink`, attaching `what` as error context if the send fails.
    ///
    /// The replacement is fetched before the swap, so a failure to obtain it
    /// leaves the current buffer untouched.
    fn ship_buffer(&mut self, what: &'static str) -> io::Result<()> {
        let fresh = self.source.get().map_err(to_io_error)?;
        let filled = mem::replace(&mut self.buffer, fresh);
        self.sink
            .blocking_send(filled)
            .context(what)
            .map_err(to_io_error)
    }
}
impl io::Write for ChannelWriter {
    /// Copies as much of `buf` as fits in the current buffer and returns that
    /// count; a buffer that becomes exactly full is sent immediately.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let room = self.buffer.capacity() - self.buffer.len();
        let chunk = &buf[..buf.len().min(room)];
        self.buffer.put_slice(chunk);
        if chunk.len() == room {
            self.ship_buffer("channel closed while sending (write)")?;
        }
        Ok(chunk.len())
    }

    /// Sends the partially filled buffer, if it holds any bytes.
    fn flush(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        self.ship_buffer("channel closed while sending (flush)")
    }
}
/// Conversion from a filled byte buffer into the message type carried by a
/// [`Tee`]'s sink channel.
pub(crate) trait Message {
    /// Builds a message that takes ownership of `buffer`.
    fn from_bytes(buffer: BytesMut) -> Self;
}
/// Pass-through wrapper around a reader or writer (`inner`) that also copies
/// every transferred byte into `BytesMut` buffers, sending each one over
/// `sink` as an `M` message. Empty buffers are taken from `source`.
pub(crate) struct Tee<'a, T, M> {
    /// The wrapped reader or writer.
    inner: T,
    /// Channel supplying buffers to fill; each is cleared before use.
    source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
    /// Channel receiving the teed data.
    sink: &'a tokio::sync::mpsc::Sender<M>,
    /// Buffer currently accumulating teed bytes.
    buffer: BytesMut,
}
impl<'a, R, M: Message> Tee<'a, R, M> {
    /// Wraps `reader`, taking the first buffer from `source`.
    ///
    /// # Errors
    ///
    /// Fails with "buffer exchange failed" if `source` yields no buffer.
    pub(crate) fn new(
        reader: R,
        source: &'a mut tokio::sync::mpsc::Receiver<BytesMut>,
        sink: &'a tokio::sync::mpsc::Sender<M>,
    ) -> anyhow::Result<Self> {
        let first = Self::get_new_buffer(source)?;
        Ok(Self {
            buffer: first,
            sink,
            source,
            inner: reader,
        })
    }

    /// Receives a buffer from `source` (via `blocking_recv`) and clears it.
    #[inline]
    fn get_new_buffer(
        source: &mut tokio::sync::mpsc::Receiver<BytesMut>,
    ) -> anyhow::Result<BytesMut> {
        source
            .blocking_recv()
            .map(|mut recycled| {
                recycled.clear();
                recycled
            })
            .context("buffer exchange failed")
    }

    /// Sends the partially filled buffer, if non-empty, and replaces it with
    /// a fresh one from `source`.
    ///
    /// The replacement is obtained before the swap, so if `source` fails the
    /// pending bytes stay in place.
    pub(crate) fn flush_channel(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let fresh = Self::get_new_buffer(self.source).map_err(to_io_error)?;
        let pending = mem::replace(&mut self.buffer, fresh);
        self.sink
            .blocking_send(M::from_bytes(pending))
            .map_err(to_io_error)
    }
}
macro_rules! send_to_channel {
($self:expr, $n_bytes:expr, $buffer:expr) => {{
let mut index = 0;
while index < $n_bytes {
let remaining_capacity = $self.buffer.capacity() - $self.buffer.len();
let bytes_to_copy = std::cmp::min($n_bytes - index, remaining_capacity);
$self
.buffer
.put_slice(&$buffer[index..index + bytes_to_copy]);
if $self.buffer.len() == $self.buffer.capacity() {
$self
.sink
.blocking_send(M::from_bytes(mem::replace(
&mut $self.buffer,
Self::get_new_buffer($self.source).map_err(to_io_error)?,
)))
.map_err(to_io_error)?;
}
index += bytes_to_copy;
}
}};
}
impl<'a, R: io::Read, M: Message> io::Read for Tee<'a, R, M> {
    /// Reads from the wrapped reader and tees whatever was read into the channel.
    ///
    /// `Ok(0)` from the inner reader is treated as end-of-stream only when
    /// `buf` is non-empty. `io::Read::read` also returns `Ok(0)` for a
    /// zero-length `buf`, which says nothing about EOF and must not emit the
    /// final message or discard the buffer being filled.
    ///
    /// At end-of-stream, the current buffer (possibly empty) is sent as a
    /// message. It is replaced by an unallocated `BytesMut`, so no further
    /// buffer is requested from `source`.
    ///
    /// NOTE(review): each additional `Ok(0)` read on a non-empty `buf` sends
    /// another (empty) message. Callers that keep reading past EOF emit one
    /// per call, so confirm that the consumer tolerates this.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.inner.read(buf)?;
        if n > 0 {
            send_to_channel!(self, n, buf);
        } else if !buf.is_empty() {
            let last = mem::replace(&mut self.buffer, BytesMut::new());
            self.sink
                .blocking_send(M::from_bytes(last))
                .map_err(to_io_error)?;
        }
        Ok(n)
    }
}
impl<'a, W: io::Write, M: Message> io::Write for Tee<'a, W, M> {
    /// Writes to the wrapped writer, then tees exactly the bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = self.inner.write(buf)?;
        send_to_channel!(self, accepted, buf);
        Ok(accepted)
    }

    /// Sends any pending teed bytes first, then flushes the wrapped writer.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_channel().and_then(|()| self.inner.flush())
    }
}
enum ParallelWriterMessage {
Payload(BytesMut),
Flush,
Finalize,
}
pub(super) struct FgWriter {
sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
buffer: BytesMut,
}
impl FgWriter {
fn new(
sender: std::sync::mpsc::Sender<ParallelWriterMessage>,
receiver: std::sync::mpsc::Receiver<io::Result<BytesMut>>,
) -> io::Result<Self> {
let buffer = receiver.recv().map_err(to_io_error)??;
Ok(Self {
sender,
receiver,
buffer,
})
}
fn exchange_buffer(&mut self) -> io::Result<()> {
let buffer = std::mem::replace(
&mut self.buffer,
self.receiver.recv().map_err(to_io_error)??,
);
self.buffer.clear();
self.sender
.send(ParallelWriterMessage::Payload(buffer))
.map_err(to_io_error)
}
fn finalize(&mut self) -> io::Result<()> {
self.exchange_buffer()?;
self.sender
.send(ParallelWriterMessage::Finalize)
.map_err(to_io_error)
}
}
impl io::Write for FgWriter {
    /// Copies as much of `buf` as fits into the current buffer and returns that
    /// count. A buffer that becomes exactly full is exchanged immediately.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let room = self.buffer.capacity() - self.buffer.len();
        let taken = room.min(buf.len());
        self.buffer.put_slice(&buf[..taken]);
        if taken == room {
            self.exchange_buffer()?;
        }
        Ok(taken)
    }

    /// Ships the current buffer and asks the background thread to flush.
    fn flush(&mut self) -> io::Result<()> {
        self.exchange_buffer()?;
        self.sender
            .send(ParallelWriterMessage::Flush)
            .map_err(to_io_error)
    }
}
/// Background half of [`write_parallel`]. It writes payload buffers to the real
/// writer and sends each buffer back (or an error) so it can be reused.
struct BgWriter {
    /// Returns written buffers, or an error, to the foreground.
    sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
    /// Payloads and control messages from the foreground.
    receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
}
impl BgWriter {
    fn new(
        sender: std::sync::mpsc::Sender<io::Result<BytesMut>>,
        receiver: std::sync::mpsc::Receiver<ParallelWriterMessage>,
    ) -> Self {
        Self { sender, receiver }
    }

    /// Serves foreground messages until `Finalize` arrives.
    ///
    /// Each non-empty `Payload` is written with `write_all`, and every payload
    /// buffer (empty or not) is then sent back. `Flush` flushes `writer`.
    ///
    /// # Errors
    ///
    /// Returns an error if the foreground disconnects, a buffer cannot be sent
    /// back, or `writer` fails. When `writer` fails, the foreground is first
    /// sent a copy carrying the same [`io::ErrorKind`] and message, so its
    /// pending `recv` sees the real cause. The original error is then returned.
    fn listen<W: io::Write>(&mut self, writer: &mut W) -> io::Result<()> {
        loop {
            match self.receiver.recv().map_err(to_io_error)? {
                ParallelWriterMessage::Payload(buffer) => {
                    if !buffer.is_empty() {
                        if let Err(e) = writer.write_all(&buffer) {
                            return Err(self.report(e, "error occurred while writing"));
                        }
                    }
                    self.sender.send(Ok(buffer)).map_err(to_io_error)?;
                }
                ParallelWriterMessage::Flush => {
                    if let Err(e) = writer.flush() {
                        return Err(self.report(e, "error occurred while flushing"));
                    }
                }
                ParallelWriterMessage::Finalize => return Ok(()),
            }
        }
    }

    /// Forwards a copy of `error` to the foreground and returns the original.
    ///
    /// `io::Error` is not `Clone`, so the copy is rebuilt from its kind and
    /// message, prefixed with `action`. The send is best-effort. If the
    /// foreground has already hung up there is nobody to notify, and the
    /// original error is still what the background thread returns.
    fn report(&self, error: io::Error, action: &str) -> io::Error {
        let copy = io::Error::new(error.kind(), format!("{action}: {error}"));
        let _ = self.sender.send(Err(copy));
        error
    }
}
/// Runs `f` against a [`FgWriter`] while a scoped background thread writes the
/// produced buffers into `writer`, so producing and writing bytes overlap.
///
/// `QUEUE_SIZE` (3) buffers of `BUFFER_SIZE` (`1 << 22` bytes) circulate
/// between the two threads. After `f` returns, the foreground writer is
/// flushed and finalized and the background thread is joined.
///
/// # Errors
///
/// In order of precedence:
/// 1. an error returned by (or a panic of) the background thread. This is the
///    original error from `writer`, not the copy forwarded to the foreground;
/// 2. an error while flushing or finalizing the foreground writer;
/// 3. the result of `f` itself.
pub(super) fn write_parallel<W, F, O, E>(writer: &mut W, f: F) -> Result<O, E>
where
    W: io::Write + Send,
    E: From<io::Error>,
    F: FnOnce(&mut FgWriter) -> Result<O, E>,
{
    const BUFFER_SIZE: usize = 1 << 22;
    const QUEUE_SIZE: usize = 3;
    let (sender, receiver) = std::sync::mpsc::channel();
    let (sender_back, receiver_back) = std::sync::mpsc::channel();
    // Pre-load the return channel with the buffers that will circulate.
    for _ in 0..QUEUE_SIZE {
        sender_back
            .send(Ok(BytesMut::with_capacity(BUFFER_SIZE)))
            .map_err(to_io_error)?;
    }
    let mut bg_writer = BgWriter::new(sender_back, receiver);
    let mut fg_writer = FgWriter::new(sender, receiver_back)?;
    thread::scope(move |s| {
        let handle = s.spawn(move || bg_writer.listen(writer));
        let output = f(&mut fg_writer);
        let shutdown = fg_writer.flush().and_then(|()| fg_writer.finalize());
        // Join before inspecting `shutdown`, so that a background failure
        // surfaces as its original error rather than the forwarded copy.
        // `shutdown` can only fail after the background thread has sent an
        // error or dropped its channel end (i.e. stopped), so this join cannot
        // hang. `fg_writer` stays alive here because the background thread
        // still sends buffers back to it while draining.
        handle.join().map_err(to_io_error)??;
        shutdown?;
        output
    })
}
#[cfg(test)]
mod tests {
    use std::io::Write as _;

    /// Writes the text one byte per `write_parallel` session. Every session
    /// appends to the same `Vec`, so the result must equal the original text.
    #[test]
    fn write_to_buffer() {
        let text: &[u8] = "We want a shrubbery!".as_bytes();
        let mut output: Vec<u8> = Vec::new();
        text.iter().copied().for_each(|byte| {
            super::write_parallel(&mut output, |w| -> Result<(), std::io::Error> {
                assert_eq!(w.write(&[byte]).unwrap(), 1);
                Ok(())
            })
            .unwrap();
        });
        assert_eq!(&output, text);
    }
}