use std::mem;
use log::trace;
use thiserror::Error;
use crate::io::StreamIo;
use super::read::{ReadStream, ReadStreamError, ReadStreamResult};
/// Errors that can be reported while reading a stream until EOF.
#[derive(Clone, Debug, Error)]
pub enum ReadStreamToEndError {
/// The inner read coroutine failed. The variant is transparent: its
/// message and source are those of the wrapped [`ReadStreamError`].
#[error(transparent)]
Read(#[from] ReadStreamError),
}
/// Outcome of one call to [`ReadStreamToEnd::resume`].
#[derive(Clone, Debug)]
pub enum ReadStreamToEndResult {
/// EOF was reached; holds every byte accumulated by the coroutine.
Ok(Vec<u8>),
/// The inner read coroutine asked for I/O. The request is handed back
/// to the caller unchanged; the caller is expected to fulfil it and
/// pass the outcome to the next `resume` call.
Io(StreamIo),
/// The inner read coroutine reported an error.
Err(ReadStreamToEndError),
}
/// Coroutine that keeps reading from a stream until EOF and collects
/// everything that was read into a single byte vector.
#[derive(Debug)]
pub struct ReadStreamToEnd {
// Inner coroutine that produces one chunk of bytes per successful read.
read: ReadStream,
// Bytes gathered so far; handed to the caller once EOF is reached.
buffer: Vec<u8>,
}
impl ReadStreamToEnd {
pub fn new() -> Self {
Self::with_capacity(ReadStream::DEFAULT_CAPACITY)
}
pub fn with_capacity(capacity: usize) -> Self {
trace!("init coroutine to read until EOF (capacity: {capacity})");
let read = ReadStream::with_capacity(capacity);
let buffer = Vec::with_capacity(capacity);
Self { read, buffer }
}
pub fn extend(&mut self, bytes: impl IntoIterator<Item = u8>) {
self.buffer.extend(bytes);
}
pub fn resume(&mut self, mut arg: Option<StreamIo>) -> ReadStreamToEndResult {
loop {
let output = match self.read.resume(arg.take()) {
ReadStreamResult::Ok(output) => output,
ReadStreamResult::Err(err) => break ReadStreamToEndResult::Err(err.into()),
ReadStreamResult::Io(io) => break ReadStreamToEndResult::Io(io),
ReadStreamResult::Eof => {
let buffer = mem::take(&mut self.buffer);
break ReadStreamToEndResult::Ok(buffer);
}
};
self.buffer.extend(output.bytes());
self.read.replace(output.buffer);
}
}
}
impl Default for ReadStreamToEnd {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use std::io::{BufReader, Read as _};

    use crate::{
        coroutines::read_to_end::ReadStreamToEndResult,
        io::{StreamIo, StreamOutput},
    };

    use super::ReadStreamToEnd;

    /// Drives the coroutine against an in-memory reader and checks that
    /// every byte of the input is returned once EOF is reached.
    #[test]
    fn read_to_end() {
        let _ = env_logger::try_init();

        let mut source = BufReader::new("abcdef".as_bytes());
        // NOTE(review): capacity 4 is presumably meant to be smaller than
        // the 6-byte input so that several reads are needed; confirm
        // against `ReadStream`.
        let mut coroutine = ReadStreamToEnd::with_capacity(4);
        let mut arg = None;

        let collected = loop {
            // Either the coroutine is done, or it hands out a buffer that
            // must be filled from the reader.
            let mut buffer = match coroutine.resume(arg.take()) {
                ReadStreamToEndResult::Ok(bytes) => break bytes,
                ReadStreamToEndResult::Io(StreamIo::Read(Err(buffer))) => buffer,
                other => unreachable!("Unexpected result: {other:?}"),
            };

            let bytes_count = source.read(&mut buffer).unwrap();
            arg = Some(StreamIo::Read(Ok(StreamOutput {
                buffer,
                bytes_count,
            })));
        };

        assert_eq!(collected, b"abcdef");
    }
}