use futures_channel::mpsc;
use futures_core::Stream;
use std::io::{self, Write};
use std::mem;
/// Synchronous [`Write`] adapter that turns written bytes into a stream of
/// body chunks.
///
/// Bytes accumulate in `buf` and are sent over an unbounded channel when the
/// buffer fills up, on an explicit flush, or when the writer is dropped.
pub(crate) struct BodyWriter<D: From<Vec<u8>> + Send + 'static, E: Send + 'static> {
    /// Sending half of the channel whose receiver is handed out as the body stream.
    sender: mpsc::UnboundedSender<Result<D, E>>,
    /// Bytes written but not yet sent as a chunk.
    buf: Vec<u8>,
}
impl<D, E> BodyWriter<D, E>
where
    D: From<Vec<u8>> + Send + 'static,
    E: Send + 'static,
{
    /// Creates a writer together with the stream its chunks are delivered on.
    ///
    /// The internal buffer is allocated with capacity `cap`. Once it fills,
    /// its contents are sent to the returned stream as one `Ok` chunk.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is zero. A zero-capacity buffer could never accept a
    /// byte, so every `write` would report zero bytes written.
    pub(crate) fn with_chunk_size(
        cap: usize,
    ) -> (Self, Box<dyn Stream<Item = Result<D, E>> + Send>) {
        assert!(cap > 0, "chunk size must be greater than zero");
        let (sender, receiver) = mpsc::unbounded();
        let writer = BodyWriter {
            sender,
            buf: Vec::with_capacity(cap),
        };
        (writer, Box::new(receiver))
    }

    /// Terminates the body by sending `error` to the stream.
    ///
    /// Bytes still sitting in the buffer are discarded first. Otherwise the
    /// flush performed on drop would deliver them as an `Ok` chunk *after* the
    /// error, and the error would no longer be the last item the receiver sees.
    pub(crate) fn abort(&mut self, error: E) {
        self.buf.clear();
        // If the receiver is already gone there is nobody left to notify.
        let _ = self.sender.unbounded_send(Err(error));
    }

    /// Drops any buffered bytes without sending them (test helper).
    #[cfg(test)]
    fn truncate(&mut self) {
        self.buf.clear()
    }
}
impl<D, E> Write for BodyWriter<D, E>
where
    D: From<Vec<u8>> + Send + 'static,
    E: Send + 'static,
{
    /// Copies as much of `buf` as fits into the pending chunk.
    ///
    /// If that fills the chunk, the chunk is flushed immediately. Returns how
    /// many bytes were accepted, which may be fewer than `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let space_left = self.buf.capacity() - self.buf.len();
        let taken = space_left.min(buf.len());
        self.buf.extend_from_slice(&buf[..taken]);
        // `taken == space_left` means the chunk is now exactly full.
        if taken == space_left {
            self.flush()?;
        }
        Ok(taken)
    }

    /// Sends the pending bytes, if any, as one chunk.
    ///
    /// The buffer is replaced with a fresh one of the same capacity.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` if the receiving stream has been dropped.
    fn flush(&mut self) -> io::Result<()> {
        if self.buf.is_empty() {
            return Ok(());
        }
        let cap = self.buf.capacity();
        let chunk = mem::replace(&mut self.buf, Vec::with_capacity(cap));
        self.sender
            .unbounded_send(Ok(D::from(chunk)))
            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "receiver was dropped"))
    }
}
impl<D, E> Drop for BodyWriter<D, E>
where
D: From<Vec<u8>> + Send + 'static,
E: Send + 'static,
{
fn drop(&mut self) {
let _ = self.flush();
}
}
#[cfg(test)]
mod tests {
    use super::BodyWriter;
    use futures_core::Stream;
    use futures_util::{stream::StreamExt, stream::TryStreamExt};
    use std::io::Write;
    use std::pin::Pin;

    type BoxedError = Box<dyn std::error::Error + 'static + Send + Sync>;
    type BodyStream = Box<dyn Stream<Item = Result<Vec<u8>, BoxedError>> + Send>;

    /// Chunk size shared by every test below.
    const CHUNK: usize = 4;

    /// Builds a writer/stream pair that emits `CHUNK`-byte chunks.
    fn pair() -> (BodyWriter<Vec<u8>, BoxedError>, BodyStream) {
        BodyWriter::with_chunk_size(CHUNK)
    }

    /// Drains the stream, concatenating every chunk; panics on an error item.
    async fn to_vec(s: BodyStream) -> Vec<u8> {
        Pin::from(s).try_concat().await.unwrap()
    }

    #[tokio::test]
    async fn small_no_flush() {
        let (mut w, body) = pair();
        assert_eq!(w.write(b"1").unwrap(), 1);
        // Discard the pending byte so the flush on drop has nothing to send.
        w.truncate();
        drop(w);
        assert_eq!(b"", &to_vec(body).await[..]);
    }

    #[tokio::test]
    async fn small_flush() {
        let (mut w, body) = pair();
        assert_eq!(w.write(b"1").unwrap(), 1);
        w.flush().unwrap();
        drop(w);
        assert_eq!(b"1", &to_vec(body).await[..]);
    }

    #[tokio::test]
    async fn chunk_write() {
        let (mut w, body) = pair();
        assert_eq!(w.write(b"1234").unwrap(), CHUNK);
        w.flush().unwrap();
        drop(w);
        assert_eq!(b"1234", &to_vec(body).await[..]);
    }

    #[tokio::test]
    async fn chunk_double_write() {
        let (mut w, body) = pair();
        assert_eq!(w.write(b"1234").unwrap(), CHUNK);
        assert_eq!(w.write(b"5678").unwrap(), CHUNK);
        w.flush().unwrap();
        drop(w);
        assert_eq!(b"12345678", &to_vec(body).await[..]);
    }

    #[tokio::test]
    async fn large_write() {
        let (mut w, body) = pair();
        // Only one chunk's worth is accepted; the rest is left to the caller.
        assert_eq!(w.write(b"123456").unwrap(), CHUNK);
        drop(w);
        assert_eq!(b"1234", &to_vec(body).await[..]);
    }

    #[tokio::test]
    async fn small_large_write() {
        let (mut w, body) = pair();
        assert_eq!(w.write(b"1").unwrap(), 1);
        // Three bytes top up the partially-filled chunk, which is then sent.
        assert_eq!(w.write(b"2345").unwrap(), 3);
        drop(w);
        assert_eq!(b"1234", &to_vec(body).await[..]);
    }

    #[tokio::test]
    async fn abort() {
        let (mut w, body) = pair();
        w.write_all(b"12345").unwrap();
        w.truncate();
        let err: BoxedError = Box::new(std::io::Error::new(std::io::ErrorKind::Other, "asdf"));
        w.abort(err);
        drop(w);
        let items: Vec<Result<Vec<u8>, BoxedError>> = Pin::from(body).collect().await;
        assert_eq!(items.len(), 2);
        assert_eq!(b"1234", &items[0].as_ref().unwrap()[..]);
        assert!(items[1].is_err());
    }
}