use std::io::{
self, copy, sink, BufReader, Chain, Cursor, ErrorKind, IoSlice, Read, Seek, SeekFrom, Write,
};
/// A stream that can report its length in bytes without being read.
///
/// The implementations in this module report the size of all the data behind the stream,
/// not just the bytes that remain to be read.
pub trait StreamLen {
    /// Returns the stream's length in bytes.
    fn stream_len(&self) -> u64;
}
/// Buffering does not change how much data a reader holds, so a [`BufReader`] reports the
/// length of the reader it wraps.
impl<R: StreamLen> StreamLen for BufReader<R> {
    fn stream_len(&self) -> u64 {
        let inner: &R = self.get_ref();
        inner.stream_len()
    }
}
/// Length of an in-memory prefix chained in front of another stream.
///
/// The whole buffer behind the [`Cursor`] is counted, whatever the cursor's current position,
/// and the length of the second stream is added to it.
impl<A: AsRef<[u8]>, B: StreamLen> StreamLen for Chain<Cursor<A>, B> {
    fn stream_len(&self) -> u64 {
        let (prefix, rest) = self.get_ref();
        let prefix_len = prefix.get_ref().as_ref().len() as u64;
        prefix_len + rest.stream_len()
    }
}
/// A stream that can report its current byte offset through a shared reference.
pub trait StreamPosition {
    /// Returns the current position, in bytes, measured from the start of the stream.
    fn stream_position(&self) -> usize;
}
/// A [`Write`] adapter that forwards every write to `writer` and keeps a running count of the
/// bytes the inner writer accepted. The count is exposed through [`StreamPosition`].
///
/// The `T: Write` bound is placed on the `impl` blocks, not on the struct. That way the type can
/// be named (and `Debug`-formatted) without repeating the bound.
#[derive(Debug)]
pub struct StreamPositionTracker<T> {
    /// Destination for every written byte.
    writer: T,
    /// Total number of bytes the inner writer has reported as written so far.
    pos: usize,
}
impl<T: Write> StreamPositionTracker<T> {
pub fn new(writer: T) -> Self {
Self { writer, pos: 0 }
}
}
/// Forwards every write to the inner writer and adds the byte count it reports to the position.
///
/// If the inner writer returns an error, the error is propagated and the position is unchanged.
impl<T: Write> Write for StreamPositionTracker<T> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.writer.write(buf).map(|count| {
            self.pos += count;
            count
        })
    }

    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        // Forwarded explicitly so the inner writer's own vectored implementation is used,
        // instead of the default that writes only the first non-empty slice.
        self.writer.write_vectored(bufs).map(|count| {
            self.pos += count;
            count
        })
    }

    fn flush(&mut self) -> io::Result<()> {
        // Flushing does not pass any new bytes through the adapter, so the position stays put.
        self.writer.flush()
    }
}
/// The position is the running total of bytes the inner writer has accepted.
impl<T: Write> StreamPosition for StreamPositionTracker<T> {
    fn stream_position(&self) -> usize {
        let Self { pos, .. } = self;
        *pos
    }
}
/// A [`Read`] adapter that gives any reader forward-only [`Seek`] support by reading and
/// discarding bytes.
///
/// Seeking backwards, or relative to the end of the stream, is rejected. The `T: Read` bound is
/// placed on the `impl` blocks, not on the struct, so the type can be named without repeating it.
#[derive(Debug)]
pub struct ForwardOnlySeeker<T> {
    /// Underlying source of bytes.
    reader: T,
    /// Number of bytes consumed from `reader` so far, both read and skipped.
    pos: u64,
}
impl<T: Read> ForwardOnlySeeker<T> {
pub fn new(reader: T) -> Self {
Self { reader, pos: 0 }
}
}
/// Reads pass straight through to the inner reader. The position advances by the number of
/// bytes actually read; a failed read leaves it unchanged.
impl<T: Read> Read for ForwardOnlySeeker<T> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.reader.read(buf).map(|count| {
            self.pos += count as u64;
            count
        })
    }
}
/// Forward-only seeking, implemented by reading and discarding bytes.
///
/// * `SeekFrom::Start(n)` skips ahead to absolute offset `n`. An `n` below the current position
///   is an error.
/// * `SeekFrom::Current(d)` skips `d` bytes. A negative `d` is an error.
/// * `SeekFrom::End(_)` is always an error.
///
/// All rejections use [`ErrorKind::InvalidInput`]. If the inner reader reaches end-of-stream,
/// skipping stops early. The returned value is then the position actually reached, which can be
/// smaller than the requested target.
impl<T: Read> Seek for ForwardOnlySeeker<T> {
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Turn each accepted request into a number of bytes to skip. `None` means the request
        // would move backwards.
        let skip = match pos {
            SeekFrom::Start(target) => target.checked_sub(self.pos),
            SeekFrom::Current(delta) => u64::try_from(delta).ok(),
            SeekFrom::End(_) => {
                return Err(io::Error::new(
                    ErrorKind::InvalidInput,
                    "Cannot seek from end",
                ));
            }
        };
        let skip = skip.ok_or_else(|| {
            io::Error::new(ErrorKind::InvalidInput, "Only seeking forward allowed")
        })?;
        {
            // Read through `self`, not `self.reader`, so that `self.pos` advances with every
            // discarded byte.
            let mut limited = self.by_ref().take(skip);
            copy(&mut limited, &mut sink())?;
        }
        Ok(self.pos)
    }

    fn stream_position(&mut self) -> io::Result<u64> {
        // Every read already updates the tracked position, so no I/O is needed here.
        Ok(self.pos)
    }
}
#[cfg(test)]
mod test {
    use rstest::rstest;

    use super::*;

    /// Minimal stream with a declared length and no content. It exercises the `StreamLen`
    /// impls without needing real data.
    struct FixedLen(u64);

    impl StreamLen for FixedLen {
        fn stream_len(&self) -> u64 {
            self.0
        }
    }

    impl Read for FixedLen {
        fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
            Ok(0)
        }
    }

    #[test]
    fn test_write_cursor() {
        let mut buf = Vec::new();
        let mut writer = StreamPositionTracker::new(&mut buf);
        writer.write_all(b"hello").unwrap();
        assert_eq!(writer.stream_position(), 5);
        writer.write_all(b" world").unwrap();
        assert_eq!(writer.stream_position(), 11);
        writer.write_all(b"!").unwrap();
        assert_eq!(writer.stream_position(), 12);
        assert_eq!(buf, b"hello world!");
    }

    #[test]
    fn test_write_vectored_cursor() {
        let mut buf = Vec::new();
        let mut writer = StreamPositionTracker::new(&mut buf);
        let write_vector = [
            IoSlice::new(b"hello"),
            IoSlice::new(b" world"),
            IoSlice::new(b"!"),
        ];
        let bytes_written = writer.write_vectored(&write_vector).unwrap();
        assert_eq!(bytes_written, 12);
        assert_eq!(writer.stream_position(), 12);
        assert_eq!(buf, b"hello world!");
    }

    #[test]
    fn test_stream_len_buf_reader_delegates() {
        let reader = BufReader::new(FixedLen(42));
        assert_eq!(reader.stream_len(), 42);
    }

    #[test]
    fn test_stream_len_chain_counts_whole_prefix() {
        let mut chained = Cursor::new(b"abc".to_vec()).chain(FixedLen(10));
        assert_eq!(chained.stream_len(), 13);
        // Consuming part of the prefix must not change the reported length.
        let mut byte = [0u8; 1];
        chained.read_exact(&mut byte).unwrap();
        assert_eq!(chained.stream_len(), 13);
        // A BufReader over the chain reports the chain's length.
        assert_eq!(BufReader::new(chained).stream_len(), 13);
    }

    #[test]
    fn test_forward_seeker_stream() {
        let mut input_stream = b"Hello world".as_ref();
        let mut reader = ForwardOnlySeeker::new(&mut input_stream);
        let mut out_buf = [0u8; 5];
        reader.read_exact(&mut out_buf).unwrap();
        assert_eq!(&out_buf, b"Hello");
        reader.seek(SeekFrom::Current(1)).unwrap();
        reader.read_exact(&mut out_buf).unwrap();
        assert_eq!(&out_buf, b"world");
    }

    #[test]
    fn test_forward_seeker_seek_from_start() {
        let mut reader = ForwardOnlySeeker::new(b"Hello world".as_ref());
        assert_eq!(reader.seek(SeekFrom::Start(6)).unwrap(), 6);
        // Seeking to the current position is a no-op, not an error.
        assert_eq!(reader.seek(SeekFrom::Start(6)).unwrap(), 6);
        let mut out_buf = [0u8; 5];
        reader.read_exact(&mut out_buf).unwrap();
        assert_eq!(&out_buf, b"world");
        assert_eq!(reader.stream_position().unwrap(), 11);
    }

    #[test]
    fn test_forward_seeker_seek_past_end_stops_at_eof() {
        let mut reader = ForwardOnlySeeker::new(b"Hello".as_ref());
        assert_eq!(reader.seek(SeekFrom::Current(100)).unwrap(), 5);
    }

    #[rstest]
    #[case(SeekFrom::End(0))]
    #[case(SeekFrom::Current(-1))]
    #[case(SeekFrom::Start(0))]
    fn test_forward_seeker_seek_fail(#[case] seek: SeekFrom) {
        let mut reader = ForwardOnlySeeker::new(b"Hello world".as_ref());
        reader.seek(SeekFrom::Start(1)).unwrap();
        assert!(reader.seek(seek).is_err());
    }
}