use std::borrow::Cow;
use std::io::{self, Cursor, Read, Write};
/// Owner of an input that is either an in-memory byte slice or an arbitrary reader.
///
/// Reader-backed input is wrapped so that every byte pulled from it is captured. Each call to
/// [`Handle::borrow_mut`] rewinds to the first captured byte, so the input can be inspected
/// repeatedly before it is finally consumed.
pub(crate) struct Handle<'i>(Source<'i>);

/// Backing storage of a [`Handle`].
enum Source<'i> {
    /// Input that was already fully in memory.
    Slice(&'i [u8]),
    /// Streaming input whose consumed bytes are recorded for replay.
    Reader(GuardedCaptureReader<Box<dyn Read + 'i>>),
}

impl<'i> Handle<'i> {
    /// Wraps an in-memory byte slice.
    pub(crate) fn from_slice(b: &'i [u8]) -> Handle<'i> {
        Self(Source::Slice(b))
    }

    /// Wraps a reader; bytes pulled from it are captured so they can be replayed.
    pub(crate) fn from_reader<R: Read + 'i>(r: R) -> Handle<'i> {
        let boxed: Box<dyn Read + 'i> = Box::new(r);
        Self(Source::Reader(GuardedCaptureReader::new(boxed)))
    }

    /// Borrows the input, positioned at its beginning.
    ///
    /// Yields [`Ref::Slice`] when the whole input is already in memory. That is the case for
    /// a slice source, or for a reader whose source has reported end of input, in which case
    /// the captured bytes are the whole input. Otherwise yields [`Ref::Reader`].
    pub(crate) fn borrow_mut(&mut self) -> Ref<'i, '_> {
        let guarded = match &mut self.0 {
            Source::Slice(bytes) => return Ref::Slice(*bytes),
            Source::Reader(guarded) => guarded,
        };
        let reader = guarded.rewind_and_borrow_mut();
        if reader.is_source_eof() {
            Ref::Slice(reader.captured())
        } else {
            Ref::Reader(reader)
        }
    }
}
impl<'i> TryFrom<Handle<'i>> for Cow<'i, [u8]> {
type Error = io::Error;
fn try_from(handle: Handle<'i>) -> io::Result<Cow<'i, [u8]>> {
match handle.0 {
Source::Slice(b) => Ok(Cow::Borrowed(b)),
Source::Reader(r) => {
let mut r = r.rewind_and_take();
r.capture_to_end()?;
let (cursor, _) = r.into_inner();
Ok(Cow::Owned(cursor.into_inner()))
}
}
}
}
/// Input ready for final consumption: either complete bytes or a stream.
pub(crate) enum Input<'i> {
    /// The entire input, borrowed or owned.
    Slice(Cow<'i, [u8]>),
    /// A stream of the input. When built from a [`Handle`], it starts at the beginning of the
    /// input, replaying any previously captured bytes before the rest of the source.
    Reader(Box<dyn Read + 'i>),
}

impl<'i> From<Handle<'i>> for Input<'i> {
    /// Converts a handle without reading any further data from its source.
    ///
    /// A reader whose source already reported end of input becomes [`Input::Slice`] with the
    /// captured bytes. A reader with nothing captured hands back the source directly. Any
    /// other reader replays its captured prefix and then continues with the source.
    fn from(handle: Handle<'i>) -> Self {
        let guarded = match handle.0 {
            Source::Slice(bytes) => return Input::Slice(Cow::Borrowed(bytes)),
            Source::Reader(guarded) => guarded,
        };
        let capture = guarded.rewind_and_take();
        let exhausted = capture.is_source_eof();
        let (buffered, source) = capture.into_inner();
        match (exhausted, buffered.get_ref().is_empty()) {
            (true, _) => Input::Slice(Cow::Owned(buffered.into_inner())),
            (false, true) => Input::Reader(source),
            // The fused wrapper lets the captured buffer be released once it has been
            // fully replayed.
            (false, false) => Input::Reader(Box::new(FusedReader::new(buffered).chain(source))),
        }
    }
}
/// A rewound view of a [`Handle`]'s input, produced by [`Handle::borrow_mut`].
pub(crate) enum Ref<'i: 'h, 'h> {
    /// The whole input is available in memory.
    Slice(&'h [u8]),
    /// The input is still streaming. Bytes read through this reader are captured, so the
    /// next borrow of the handle replays them.
    Reader(&'h mut CaptureReader<Box<dyn Read + 'i>>),
}

impl<'i: 'h, 'h> Ref<'i, 'h> {
    /// Returns the leading bytes of the input, reading from the source if needed.
    ///
    /// For reader input, the source is read until `size_hint` bytes are captured in total
    /// (fewer if the source ends first). Everything captured so far is then returned, which
    /// can exceed `size_hint` if more was captured earlier. For slice input, the whole slice
    /// is returned regardless of `size_hint`.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from reading the source.
    pub(crate) fn prefix(&mut self, size_hint: usize) -> io::Result<&[u8]> {
        let reader = match self {
            Ref::Slice(bytes) => return Ok(*bytes),
            Ref::Reader(reader) => reader,
        };
        reader.capture_up_to_size(size_hint)?;
        Ok(reader.captured())
    }
}
/// Reader adapter that stops polling its inner reader, and drops it, once it reports end of
/// input.
///
/// After the inner reader returns zero bytes for a non-empty buffer, it is released (freeing
/// whatever it owns), and every later call returns `Ok(0)`.
struct FusedReader<R: Read>(Option<R>);

impl<R: Read> FusedReader<R> {
    /// Wraps `r`, which is polled until it first reports end of input.
    fn new(r: R) -> FusedReader<R> {
        Self(Some(r))
    }
}

impl<R: Read> Read for FusedReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let Some(inner) = self.0.as_mut() else {
            return Ok(0);
        };
        let count = inner.read(buf)?;
        // A zero-length read into an empty buffer says nothing about end of input, so only a
        // zero-length read into a non-empty buffer retires the inner reader.
        if count == 0 && !buf.is_empty() {
            self.0 = None;
        }
        Ok(count)
    }
}
/// Owning wrapper that hands out its [`CaptureReader`] only after rewinding it.
///
/// The rewind happens when access is requested, not when a previous borrow ends. As a result,
/// every consumer starts from the first captured byte even if an earlier borrow was leaked.
struct GuardedCaptureReader<R: Read>(CaptureReader<R>);

impl<R: Read> GuardedCaptureReader<R> {
    /// Wraps `r` in a fresh capture reader with nothing captured yet.
    fn new(r: R) -> Self {
        GuardedCaptureReader(CaptureReader::new(r))
    }

    /// Rewinds the capture to its first byte and lends it out.
    fn rewind_and_borrow_mut(&mut self) -> &mut CaptureReader<R> {
        let GuardedCaptureReader(inner) = self;
        inner.rewind();
        inner
    }

    /// Rewinds the capture to its first byte and gives up ownership of it.
    fn rewind_and_take(self) -> CaptureReader<R> {
        let GuardedCaptureReader(mut inner) = self;
        inner.rewind();
        inner
    }
}
/// Reader adapter that records every byte it pulls from its source, so the bytes can be
/// replayed after [`CaptureReader::rewind`].
///
/// Reads are served from the captured bytes first and only then from the source. Bytes
/// obtained from the source are appended to the capture. Once the source has reported end of
/// input, it is treated as exhausted and is never read again.
pub(crate) struct CaptureReader<R>
where
    R: Read,
{
    /// Everything captured from `source` so far. The cursor position is how much of it the
    /// `Read` impl has already replayed since the last rewind.
    prefix: Cursor<Vec<u8>>,
    /// The wrapped reader.
    source: R,
    /// Set once `source` has returned end of input.
    source_eof: bool,
}

impl<R> CaptureReader<R>
where
    R: Read,
{
    /// Wraps `source` with an empty capture buffer.
    fn new(source: R) -> Self {
        Self {
            prefix: Cursor::new(vec![]),
            source,
            source_eof: false,
        }
    }

    /// Returns every byte captured so far, regardless of the replay position.
    fn captured(&self) -> &[u8] {
        self.prefix.get_ref()
    }

    /// Number of captured bytes the `Read` impl has not replayed yet.
    fn captured_unread_size(&self) -> usize {
        // The position is only ever reset to 0 or advanced by reading/appending, so it is
        // bounded by the buffer length, which fits in `usize`. The cast therefore cannot
        // truncate.
        #[allow(clippy::cast_possible_truncation)]
        let offset = self.prefix.position() as usize;
        self.prefix.get_ref().len().saturating_sub(offset)
    }

    /// Moves the replay position back to the first captured byte.
    fn rewind(&mut self) {
        self.prefix.set_position(0);
    }

    /// Reads the source until end of input, appending everything to the capture.
    ///
    /// Does nothing once the source is exhausted. The replay position is not moved.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the source. Bytes read before the error stay captured.
    fn capture_to_end(&mut self) -> io::Result<()> {
        if !self.source_eof {
            self.source.read_to_end(self.prefix.get_mut())?;
            self.source_eof = true;
        }
        Ok(())
    }

    /// Reads from the source until `size` bytes are captured in total, or until the source
    /// ends, whichever comes first.
    ///
    /// Returns immediately when enough is already captured or the source is exhausted. The
    /// replay position is not moved.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the source.
    fn capture_up_to_size(&mut self, size: usize) -> io::Result<()> {
        let needed = size.saturating_sub(self.prefix.get_ref().len());
        // An exhausted source is never polled again. Some sources (e.g. an interactive
        // terminal) would block waiting for more input if read after reporting EOF.
        if needed == 0 || self.source_eof {
            return Ok(());
        }
        let mut take = self.source.by_ref().take(needed as u64);
        take.read_to_end(self.prefix.get_mut())?;
        // `read_to_end` stops at the first zero-length read, so a leftover limit means the
        // source ran out before supplying `needed` bytes.
        if take.limit() > 0 {
            self.source_eof = true;
        }
        Ok(())
    }

    /// Whether the source has reported end of input.
    fn is_source_eof(&self) -> bool {
        self.source_eof
    }

    /// Splits the reader into its capture buffer and the (partially consumed) source.
    fn into_inner(self) -> (Cursor<Vec<u8>>, R) {
        (self.prefix, self.source)
    }
}

impl<R> Read for CaptureReader<R>
where
    R: Read,
{
    /// Replays unread captured bytes first. Once those run out, reads from the source (unless
    /// it is exhausted) and records what it returns.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let prefix_size = std::cmp::min(buf.len(), self.captured_unread_size());
        self.prefix.read_exact(&mut buf[..prefix_size])?;
        // Stop if captured bytes remain, if the caller's buffer is full, or if the source
        // already reported EOF. EOF is treated as final, so the source is not polled again.
        if self.captured_unread_size() > 0 || prefix_size == buf.len() || self.source_eof {
            return Ok(prefix_size);
        }
        let buf = &mut buf[prefix_size..];
        let source_size = self.source.read(buf)?;
        // The cursor sits at the end of the capture here, so this appends the new bytes
        // for later replays.
        self.prefix.write_all(&buf[..source_size])?;
        // `buf` is non-empty at this point, so a zero-length read means end of input.
        self.source_eof = source_size == 0;
        Ok(prefix_size + source_size)
    }
}
/// Unit tests for the handle/input conversions and the capturing reader.
#[cfg(test)]
mod tests {
    use super::{CaptureReader, Handle, Input, Ref};
    use std::borrow::Cow;
    use std::io::{self, Cursor, Read};

    /// Sample input shared by the tests.
    const DATA: &str = "abcdefghij";
    /// Half of `DATA`'s length, used to consume only part of the input.
    const HALF: usize = DATA.len() / 2;

    /// Each `borrow_mut` starts again at the beginning of the input. Converting the handle
    /// into an `Input` afterwards still yields the whole input.
    #[test]
    fn input_borrow_mut_rewind() {
        let mut handle = Handle::from_reader(DATA.as_bytes());
        let mut buf = vec![];
        let mut input_ref = handle.borrow_mut();
        // The source has not reached EOF yet, so a reader view is expected.
        match input_ref {
            Ref::Slice(_) => unreachable!(),
            Ref::Reader(ref mut r) => r.take(HALF as u64).read_to_end(&mut buf).unwrap(),
        };
        assert_eq!(std::str::from_utf8(&buf), Ok(&DATA[..HALF]));
        buf.clear();
        // Leak the first borrow instead of letting it end normally. The next `borrow_mut`
        // must still start from the beginning, showing that the rewind happens when a new
        // borrow is taken. Clippy's `forget_non_drop` would flag forgetting a value that
        // needs no drop, hence the allowance.
        #[allow(clippy::forget_non_drop)]
        std::mem::forget(input_ref);
        match handle.borrow_mut() {
            Ref::Slice(_) => unreachable!(),
            Ref::Reader(r) => r.take(HALF as u64).read_to_end(&mut buf).unwrap(),
        };
        // The second borrow yields the first half again, not the second half.
        assert_eq!(std::str::from_utf8(&buf), Ok(&DATA[..HALF]));
        buf.clear();
        // The final reader replays the captured prefix followed by the rest of the source.
        let mut r = match handle.into() {
            Input::Slice(_) => unreachable!(),
            Input::Reader(r) => r,
        };
        assert!(matches!(r.read_to_end(&mut buf), Ok(len) if len == DATA.len()));
        assert_eq!(std::str::from_utf8(&buf), Ok(DATA));
    }

    /// Partially consuming a reader-backed handle loses no data: converting it into a `Cow`
    /// yields the full input.
    #[test]
    fn input_into_cow() {
        let mut handle = Handle::from_reader(DATA.as_bytes());
        match handle.borrow_mut() {
            Ref::Slice(_) => unreachable!(),
            Ref::Reader(r) => io::copy(&mut r.take(HALF as u64), &mut io::sink()).unwrap(),
        };
        let buf: Cow<'_, [u8]> = handle.try_into().unwrap();
        assert_eq!(std::str::from_utf8(&buf), Ok(DATA));
    }

    /// Reading straight through captures everything and marks the source exhausted.
    #[test]
    fn capture_reader_straight_read() {
        let mut r = CaptureReader::new(Cursor::new(String::from(DATA)));
        assert_eq!(io::read_to_string(&mut r).unwrap(), DATA);
        assert!(r.is_source_eof());
        let (cursor, _) = r.into_inner();
        assert!(matches!(std::str::from_utf8(cursor.get_ref()), Ok(DATA)));
    }

    /// After a partial read and a rewind, reading returns the captured bytes followed by
    /// the rest of the source.
    #[test]
    fn capture_reader_rewind() {
        let mut r = CaptureReader::new(Cursor::new(String::from(DATA)));
        let mut tmp = [0; HALF];
        assert!(matches!(r.read_exact(&mut tmp), Ok(())));
        assert_eq!(std::str::from_utf8(&tmp), Ok(&DATA[..HALF]));
        assert_eq!(std::str::from_utf8(r.captured()), Ok(&DATA[..HALF]));
        assert!(!r.is_source_eof());
        r.rewind();
        assert_eq!(io::read_to_string(&mut r).unwrap(), DATA);
        assert_eq!(r.captured(), DATA.as_bytes());
        assert!(r.is_source_eof());
    }

    /// `capture_to_end` captures the whole source and marks it exhausted.
    #[test]
    fn capture_reader_to_end() {
        let mut r = CaptureReader::new(Cursor::new(String::from(DATA)));
        assert!(r.capture_to_end().is_ok());
        assert_eq!(std::str::from_utf8(r.captured()), Ok(DATA));
        assert!(r.is_source_eof());
    }

    /// `capture_up_to_size` captures exactly the requested amount when the source has more.
    #[test]
    fn capture_reader_up_to() {
        let mut r = CaptureReader::new(Cursor::new(String::from(DATA)));
        assert!(r.capture_up_to_size(HALF).is_ok());
        assert_eq!(std::str::from_utf8(r.captured()), Ok(&DATA[..HALF]));
        assert!(!r.is_source_eof());
    }
}