use std::{
io::{Read, Seek, SeekFrom},
sync::{Arc, Mutex},
};
use super::determine_stream_len;
/// State shared by every clone of a `CloneableSeekableReader`: the real reader plus a
/// little bookkeeping used to avoid redundant seeks and repeated length lookups.
struct Inner<R: Read + Seek> {
    /// Offset at which `r`'s cursor is believed to sit; reads at this offset skip the seek.
    pos: u64,
    /// Total stream length, filled in lazily the first time it is asked for.
    len: Option<u64>,
    /// The wrapped reader itself.
    r: R,
}
impl<R: Read + Seek> Inner<R> {
    /// Wraps `r` with empty bookkeeping.
    ///
    /// Position tracking starts at offset 0, so `r` is expected to be at the start of
    /// the stream. NOTE(review): confirm no caller hands over an already-advanced reader.
    fn new(r: R) -> Self {
        Self {
            r,
            pos: 0,
            len: None,
        }
    }

    /// Returns the total stream length, computing it on first use and caching it.
    ///
    /// # Errors
    /// Propagates I/O errors from `determine_stream_len` or from re-reading the cursor
    /// position afterwards; nothing is cached when an error is returned.
    fn len(&mut self) -> std::io::Result<u64> {
        if let Some(len) = self.len {
            return Ok(len);
        }
        let len = determine_stream_len(&mut self.r)?;
        // `determine_stream_len` gets mutable access to the reader and may leave its
        // cursor somewhere other than `self.pos`; re-sync so `read_at` never skips a
        // seek it actually needs.
        self.pos = self.r.stream_position()?;
        self.len = Some(len);
        Ok(len)
    }

    /// Reads up to `buf.len()` bytes starting at absolute `offset` and returns how many
    /// bytes were read, with the same meaning as `Read::read`.
    ///
    /// The seek is skipped when the tracked cursor (`self.pos`) already equals `offset`,
    /// which is the common case for one handle reading sequentially.
    ///
    /// # Errors
    /// Propagates errors from seeking or reading the underlying reader.
    fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> {
        if offset != self.pos {
            // Record where the seek landed; otherwise `pos` drifts from the real cursor
            // and a later read at the stale value would skip a seek it needed.
            self.pos = self.r.seek(SeekFrom::Start(offset))?;
        }
        let bytes_read = self.r.read(buf)?;
        self.pos += bytes_read as u64;
        Ok(bytes_read)
    }
}
/// A `Read + Seek` handle onto a reader that is shared between all of its clones.
///
/// Each handle carries its own cursor; the underlying reader sits behind a single
/// `Arc<Mutex<..>>`, so handles take turns using it.
pub(crate) struct CloneableSeekableReader<R: Read + Seek> {
    /// This handle's own offset from the start of the stream.
    pos: u64,
    /// Reader and bookkeeping shared by every clone of this handle.
    inner: Arc<Mutex<Inner<R>>>,
}
impl<R: Read + Seek> Clone for CloneableSeekableReader<R> {
    /// Creates another handle onto the same shared reader, starting at this handle's
    /// current position; from then on the two positions move independently.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self {
            pos: self.pos,
            inner: shared,
        }
    }
}
impl<R: Read + Seek> CloneableSeekableReader<R> {
    /// Takes ownership of `r` and returns the first handle onto it, positioned at
    /// offset 0. Further handles are made with [`Clone::clone`].
    pub(crate) fn new(r: R) -> Self {
        let inner = Arc::new(Mutex::new(Inner::new(r)));
        Self { inner, pos: 0 }
    }
}
impl<R: Read + Seek> Read for CloneableSeekableReader<R> {
    /// Reads from this handle's own position in the shared stream and advances that
    /// position by the number of bytes read; other clones' positions are unaffected.
    ///
    /// # Errors
    /// Propagates errors from the shared reader (leaving the position unchanged), and
    /// returns `ErrorKind::InvalidInput` if advancing the position would overflow `u64`.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned, i.e. a previous holder panicked while
    /// holding it.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // The guard is a temporary, so the shared reader is unlocked as soon as
        // `read_at` returns.
        let bytes_read = self.inner.lock().unwrap().read_at(self.pos, buf)?;
        self.pos = self
            .pos
            .checked_add(bytes_read as u64)
            // Built lazily: the boxed error is only allocated on overflow, not per read.
            .ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::InvalidInput, "Read too far forward")
            })?;
        Ok(bytes_read)
    }
}
impl<R: Read + Seek> Seek for CloneableSeekableReader<R> {
    /// Moves only this handle's cursor and returns the new absolute position.
    ///
    /// The shared reader is consulted only for `SeekFrom::End` (to learn the stream
    /// length); the real seek on it is deferred to the next read. Positions past the
    /// end of the stream are allowed, as with `File` and `Cursor`.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidInput`, leaving the cursor unchanged, when the target
    /// lies before the start of the stream or beyond `u64::MAX`. Errors from
    /// determining the stream length are propagated.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned.
    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
        // Used from `ok_or_else`, so the boxed error is only built on failure.
        fn invalid_input(msg: &'static str) -> std::io::Error {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, msg)
        }

        let new_pos = match pos {
            SeekFrom::Start(offset) => offset,
            SeekFrom::End(offset_from_end) => {
                let file_len = self.inner.lock().unwrap().len()?;
                // `checked_add_signed` handles both directions and every `i64`
                // (including `i64::MIN`) without negating the offset.
                file_len.checked_add_signed(offset_from_end).ok_or_else(|| {
                    invalid_input(if offset_from_end < 0 {
                        "Seek too far backwards"
                    } else {
                        "Seek too far forward from end"
                    })
                })?
            }
            SeekFrom::Current(offset_from_pos) => {
                self.pos.checked_add_signed(offset_from_pos).ok_or_else(|| {
                    invalid_input(if offset_from_pos < 0 {
                        "Seek too far backward from current pos"
                    } else {
                        "Seek too far forward from current pos"
                    })
                })?
            }
        };
        self.pos = new_pos;
        Ok(new_pos)
    }
}
#[cfg(test)]
mod test {
    use super::CloneableSeekableReader;
    use std::io::{Cursor, Read, Seek, SeekFrom};
    use test_log::test;

    /// A ten-byte stream whose byte at offset `i` has value `i`, so reads are easy to
    /// check by eye.
    fn ten_bytes() -> Cursor<Vec<u8>> {
        Cursor::new((0u8..10).collect())
    }

    /// Fills a two-byte buffer from `reader` with `read_exact` and returns it.
    fn take_two(reader: &mut impl Read) -> std::io::Result<[u8; 2]> {
        let mut pair = [0u8; 2];
        reader.read_exact(&mut pair)?;
        Ok(pair)
    }

    #[test]
    fn test_cloneable_seekable_reader() -> std::io::Result<()> {
        let mut reader = CloneableSeekableReader::new(ten_bytes());

        assert_eq!(take_two(&mut reader)?, [0, 1]);

        // Rewinding brings the handle back to the first two bytes.
        reader.rewind()?;
        assert_eq!(take_two(&mut reader)?, [0, 1]);

        // Querying the position must not disturb it.
        reader.stream_position()?;
        assert_eq!(take_two(&mut reader)?, [2, 3]);

        reader.seek(SeekFrom::End(-2))?;
        assert_eq!(take_two(&mut reader)?, [8, 9]);

        // Nothing is left after the final two bytes.
        assert!(take_two(&mut reader).is_err());
        Ok(())
    }

    #[test]
    fn test_cloned_independent_positions() -> std::io::Result<()> {
        let mut first = CloneableSeekableReader::new(ten_bytes());
        let mut second = first.clone();

        assert_eq!(take_two(&mut first)?, [0, 1]);
        // The clone starts where `first` was when it was cloned.
        assert_eq!(take_two(&mut second)?, [0, 1]);
        assert_eq!(take_two(&mut first)?, [2, 3]);

        second.seek(SeekFrom::End(-2))?;
        assert_eq!(take_two(&mut second)?, [8, 9]);

        // Seeking and reading through the clone must not move `first`.
        assert_eq!(take_two(&mut first)?, [4, 5]);
        Ok(())
    }
}