use std::mem;
use log::{debug, trace};
use thiserror::Error;
use crate::{coroutines::read::ReadStreamResult, io::StreamIo};
use super::read::{ReadStream, ReadStreamError};
#[derive(Clone, Debug, Error)]
pub enum ReadStreamExactError {
#[error("Unexpected EOF, expected to read {0}/{1} more bytes")]
UnexpectedEof(usize, usize, Vec<u8>),
#[error(transparent)]
Read(#[from] ReadStreamError),
}
/// Outcome of one call to [`ReadStreamExact::resume`].
#[derive(Clone, Debug)]
pub enum ReadStreamExactResult {
    /// The target byte count was reached; carries every collected byte
    /// (can exceed the target if bytes were pre-loaded via `extend`).
    Ok(Vec<u8>),
    /// The coroutine needs the caller to perform this I/O and pass the
    /// completed I/O back to the next `resume` call.
    Io(StreamIo),
    /// Reading failed or EOF was reached too early.
    Err(ReadStreamExactError),
}
/// I/O-free coroutine that keeps reading from a stream until at least
/// `max` bytes have been collected.
#[derive(Debug)]
pub struct ReadStreamExact {
    // Inner coroutine performing each individual read.
    read: ReadStream,
    // Bytes collected so far, including any pushed in through `extend`.
    buffer: Vec<u8>,
    // Target byte count; `resume` completes once `buffer.len() >= max`.
    max: usize,
}
impl ReadStreamExact {
    /// Builds a coroutine collecting `max` bytes, using
    /// [`ReadStream::DEFAULT_CAPACITY`] (clamped to `max`) per read.
    pub fn new(max: usize) -> Self {
        Self::with_capacity(ReadStream::DEFAULT_CAPACITY, max)
    }

    /// Builds a coroutine collecting `max` bytes where each individual read
    /// requests at most `capacity` bytes (never more than `max`).
    pub fn with_capacity(capacity: usize, max: usize) -> Self {
        trace!("init coroutine to read exactly {max} bytes (capacity: {capacity})");
        Self {
            read: ReadStream::with_capacity(capacity.min(max)),
            buffer: Vec::with_capacity(max),
            max,
        }
    }

    /// Pre-loads bytes into the collection buffer.
    ///
    /// These bytes count towards `max` and are part of the final output. If
    /// they already reach `max`, the next `resume` returns them without
    /// requesting any I/O.
    pub fn extend(&mut self, bytes: impl IntoIterator<Item = u8>) {
        Extend::extend(&mut self.buffer, bytes);
    }

    /// Advances the coroutine.
    ///
    /// Pass `None` on the first call, then the completed I/O from the
    /// previous [`ReadStreamExactResult::Io`] on each subsequent call.
    ///
    /// Returns `Ok` with the whole collected buffer once at least `max`
    /// bytes are available. Returns `Io` when the inner read needs the caller
    /// to perform I/O, and `Err` on read failure. EOF before the target is
    /// reached yields [`ReadStreamExactError::UnexpectedEof`] carrying the
    /// partial data. `Ok` and `UnexpectedEof` both drain the internal buffer.
    pub fn resume(&mut self, mut arg: Option<StreamIo>) -> ReadStreamExactResult {
        while self.buffer.len() < self.max {
            let remaining = self.max - self.buffer.len();
            debug!("{remaining} remaining bytes to read");

            // Never ask the inner reader for more than is still missing, so
            // no bytes past the target are consumed from the stream.
            if self.read.capacity() > remaining {
                self.read.truncate(remaining);
            }

            let output = match self.read.resume(arg.take()) {
                ReadStreamResult::Ok(output) => output,
                ReadStreamResult::Io(io) => return ReadStreamExactResult::Io(io),
                ReadStreamResult::Err(err) => return ReadStreamExactResult::Err(err.into()),
                ReadStreamResult::Eof => {
                    let partial = mem::take(&mut self.buffer);
                    let eof = ReadStreamExactError::UnexpectedEof(remaining, self.max, partial);
                    return ReadStreamExactResult::Err(eof);
                }
            };

            self.buffer.extend(output.bytes());
            // Hand the read buffer back so the inner coroutine can reuse it.
            self.read.replace(output.buffer);
        }

        ReadStreamExactResult::Ok(mem::take(&mut self.buffer))
    }
}
#[cfg(test)]
mod tests {
    use std::io::{BufReader, Read};

    use crate::{
        coroutines::read_exact::{ReadStreamExactError, ReadStreamExactResult},
        io::{StreamIo, StreamOutput},
    };

    use super::ReadStreamExact;

    /// Resumes `read` until it yields something other than an I/O request,
    /// serving every pending read from `reader`.
    ///
    /// Panics on any I/O request that is not a pending read.
    fn drive(read: &mut ReadStreamExact, reader: &mut impl Read) -> ReadStreamExactResult {
        let mut arg = None;
        loop {
            match read.resume(arg.take()) {
                ReadStreamExactResult::Io(StreamIo::Read(Err(mut buffer))) => {
                    let bytes_count = reader.read(&mut buffer).unwrap();
                    arg = Some(StreamIo::Read(Ok(StreamOutput { buffer, bytes_count })));
                }
                other @ ReadStreamExactResult::Io(_) => {
                    unreachable!("Unexpected result: {other:?}")
                }
                done => break done,
            }
        }
    }

    /// Unwraps a successful result, panicking on anything else.
    fn expect_ok(result: ReadStreamExactResult) -> Vec<u8> {
        match result {
            ReadStreamExactResult::Ok(output) => output,
            other => unreachable!("Unexpected result: {other:?}"),
        }
    }

    /// Asserts that the next read from `reader` (up to 4 bytes) yields
    /// exactly `expected`, i.e. the coroutine did not over-consume.
    fn assert_leftover(reader: &mut impl Read, expected: &[u8]) {
        let mut leftover = vec![0; 4];
        let bytes_count = reader.read(&mut leftover).unwrap();
        assert_eq!(bytes_count, expected.len());
        assert_eq!(&leftover[..bytes_count], expected);
    }

    #[test]
    fn read_exact_smaller_capacity() {
        let _ = env_logger::try_init();
        let mut reader = BufReader::new("abcdef".as_bytes());
        let mut read = ReadStreamExact::with_capacity(3, 4);

        let output = expect_ok(drive(&mut read, &mut reader));

        assert_eq!(output, b"abcd");
        assert_leftover(&mut reader, b"ef");
    }

    #[test]
    fn read_exact_bigger_capacity() {
        let _ = env_logger::try_init();
        let mut reader = BufReader::new("abcdef".as_bytes());
        let mut read = ReadStreamExact::with_capacity(5, 4);

        let output = expect_ok(drive(&mut read, &mut reader));

        assert_eq!(output, b"abcd");
        assert_leftover(&mut reader, b"ef");
    }

    #[test]
    fn read_exact_0() {
        let _ = env_logger::try_init();
        let mut reader = BufReader::new("abcdef".as_bytes());
        let mut read = ReadStreamExact::with_capacity(5, 0);
        read.extend("123".as_bytes().to_vec());

        let output = expect_ok(drive(&mut read, &mut reader));

        assert_eq!(output, b"123");
    }

    #[test]
    fn read_eof() {
        let _ = env_logger::try_init();
        let mut reader = BufReader::new("abcdef".as_bytes());
        let mut read = ReadStreamExact::new(8);

        match drive(&mut read, &mut reader) {
            ReadStreamExactResult::Err(ReadStreamExactError::UnexpectedEof(2, 8, output)) => {
                assert_eq!(output, b"abcdef")
            }
            other => unreachable!("Unexpected result: {other:?}"),
        }
    }
}