#![expect(clippy::cast_sign_loss)]
use std::{
io::{BufReader, Cursor, Read, Seek, SeekFrom},
sync::{Arc, Mutex},
};
/// A source that can report its total length.
///
/// [`CloneableSeekableReader`] uses this to resolve `SeekFrom::End` offsets. The method is
/// infallible, so implementations must deal with any error (e.g. failing to read file
/// metadata) themselves.
pub(crate) trait HasLength {
    /// Returns the total length of the underlying data, in bytes.
    fn len(&self) -> u64;
}
/// A `Read + Seek` wrapper that can be cloned cheaply, with every clone sharing one
/// underlying reader.
///
/// The reader lives behind an `Arc<Mutex<_>>`. Each handle keeps its own logical
/// position and re-seeks the shared reader to that position (while holding the lock)
/// before every read, so clones do not disturb one another's position.
pub(crate) struct CloneableSeekableReader<R: Read + Seek + HasLength> {
    /// The shared underlying reader.
    file: Arc<Mutex<R>>,
    /// This handle's logical read position, independent of other clones.
    pos: u64,
    /// Total length of the underlying reader, queried lazily the first time it is
    /// needed (for `SeekFrom::End`) and cached afterwards; it is never refreshed, so it
    /// goes stale if the underlying data changes size.
    file_length: Option<u64>,
}
impl<R: Read + Seek + HasLength> Clone for CloneableSeekableReader<R> {
    /// Creates another handle onto the same underlying reader.
    ///
    /// Only the `Arc` is cloned, so `R` itself need not be `Clone`. The new handle starts
    /// at this handle's current position and inherits its cached length, if any.
    fn clone(&self) -> Self {
        Self {
            file: Arc::clone(&self.file),
            // The remaining fields are `Copy`, so they are copied out of `*self`.
            ..*self
        }
    }
}
impl<R: Read + Seek + HasLength> CloneableSeekableReader<R> {
    /// Wraps `file` in a new shared handle positioned at offset 0.
    ///
    /// The length is not queried here; it is fetched on demand by
    /// `ascertain_file_length`.
    pub(crate) fn new(file: R) -> Self {
        let shared = Arc::new(Mutex::new(file));
        Self {
            file: shared,
            pos: 0,
            file_length: None,
        }
    }

    /// Returns the total length of the underlying reader.
    ///
    /// On first use this queries [`HasLength::len`]; the answer is cached and returned
    /// unchanged by every later call.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn ascertain_file_length(&mut self) -> u64 {
        if let Some(cached) = self.file_length {
            return cached;
        }
        let measured = self.file.lock().unwrap().len();
        self.file_length = Some(measured);
        measured
    }
}
impl<R: Read + Seek + HasLength> Read for CloneableSeekableReader<R> {
    /// Reads from the shared reader starting at this handle's own position.
    ///
    /// The lock is held across both the seek and the read, so no other clone can move
    /// the shared cursor in between. This handle's position advances only when the read
    /// succeeds.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut guard = self
            .file
            .lock()
            .expect("Unable to get underlying file");
        guard.seek(SeekFrom::Start(self.pos))?;
        let count = guard.read(buf)?;
        self.pos += count as u64;
        Ok(count)
    }
}
impl<R: Read + Seek + HasLength> Seek for CloneableSeekableReader<R> {
    /// Moves this handle's logical position and returns the new absolute offset.
    ///
    /// Only this handle's own position changes; the shared reader is repositioned by the
    /// next `read`. `SeekFrom::End` queries the total length through [`HasLength`] the
    /// first time it is needed and caches it.
    ///
    /// As with [`std::io::Seek`], seeking beyond the end is allowed.
    ///
    /// # Errors
    ///
    /// Returns an error of kind `InvalidInput`, leaving the position unchanged, if the
    /// target would lie before offset 0 or does not fit in a `u64`.
    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
        let new_pos = match pos {
            SeekFrom::Start(offset) => offset,
            SeekFrom::End(offset_from_end) => {
                let file_len = self.ascertain_file_length();
                apply_seek_offset(file_len, offset_from_end)?
            }
            SeekFrom::Current(offset_from_pos) => apply_seek_offset(self.pos, offset_from_pos)?,
        };
        self.pos = new_pos;
        Ok(new_pos)
    }
}

/// Offsets `base` by the signed `delta`.
///
/// Returns an `InvalidInput` error instead of wrapping or panicking when the result would
/// fall below zero or exceed `u64::MAX`.
fn apply_seek_offset(base: u64, delta: i64) -> std::io::Result<u64> {
    let target = if delta >= 0 {
        // `delta` is non-negative here, so this cast cannot lose the sign.
        base.checked_add(delta as u64)
    } else {
        // `unsigned_abs` is defined even for `i64::MIN`, where `-delta` would overflow.
        base.checked_sub(delta.unsigned_abs())
    };
    target.ok_or_else(|| {
        let message = if delta < 0 {
            "Seek too far backwards"
        } else {
            "Seek position overflows u64"
        };
        std::io::Error::new(std::io::ErrorKind::InvalidInput, message)
    })
}
impl<R: HasLength> HasLength for BufReader<R> {
    /// Delegates to the wrapped reader.
    ///
    /// Bytes already pulled into the `BufReader`'s internal buffer do not change the
    /// reported length.
    fn len(&self) -> u64 {
        let inner: &R = self.get_ref();
        R::len(inner)
    }
}
#[expect(clippy::disallowed_types)]
impl HasLength for std::fs::File {
    /// Returns the file's size as reported by its metadata.
    ///
    /// # Panics
    ///
    /// Panics if the file's metadata cannot be read.
    fn len(&self) -> u64 {
        let metadata = self.metadata().unwrap();
        metadata.len()
    }
}
impl HasLength for fs_err::File {
    /// Returns the file's size as reported by its metadata.
    ///
    /// # Panics
    ///
    /// Panics if the file's metadata cannot be read.
    fn len(&self) -> u64 {
        let metadata = self.metadata().unwrap();
        metadata.len()
    }
}
impl HasLength for Cursor<Vec<u8>> {
    /// Returns the number of bytes in the backing vector, regardless of the cursor's
    /// current position.
    fn len(&self) -> u64 {
        let bytes: &[u8] = self.get_ref();
        bytes.len() as u64
    }
}
impl HasLength for Cursor<&Vec<u8>> {
    /// Returns the number of bytes in the borrowed vector, regardless of the cursor's
    /// current position.
    fn len(&self) -> u64 {
        let bytes: &[u8] = self.get_ref();
        bytes.len() as u64
    }
}
#[cfg(test)]
mod test {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use super::CloneableSeekableReader;

    #[test]
    fn test_cloneable_seekable_reader() {
        let data: Vec<u8> = (0..10).collect();
        let mut reader = CloneableSeekableReader::new(Cursor::new(data));
        let mut pair = [0u8; 2];

        // A fresh reader starts at offset 0.
        assert!(reader.read_exact(&mut pair).is_ok());
        assert_eq!(pair, [0, 1]);

        // Rewinding to the start replays the same bytes.
        assert!(reader.seek(SeekFrom::Start(0)).is_ok());
        assert!(reader.read_exact(&mut pair).is_ok());
        assert_eq!(pair, [0, 1]);

        // Querying the position must not move it.
        assert!(reader.stream_position().is_ok());
        assert!(reader.read_exact(&mut pair).is_ok());
        assert_eq!(pair, [2, 3]);

        // Seeking relative to the end lands on the final two bytes; after reading them
        // the stream is exhausted.
        assert!(reader.seek(SeekFrom::End(-2)).is_ok());
        assert!(reader.read_exact(&mut pair).is_ok());
        assert_eq!(pair, [8, 9]);
        assert!(reader.read_exact(&mut pair).is_err());
    }
}