use crate::errors::{Error, ErrorKind};
use std::io::Read;
use std::ops::Range;
#[cfg(unix)]
use std::os::unix::fs::FileExt;
#[cfg(windows)]
use std::os::windows::fs::FileExt;
use std::{rc::Rc, sync::Arc};
/// A byte source that supports reads at an absolute byte offset.
///
/// Implementors provide [`ReaderAt::read_at`]; [`ReaderAt::read_exact_at`] is
/// derived from it.
pub trait ReaderAt {
    /// Reads up to `buf.len()` bytes starting at `offset` into `buf` and returns
    /// how many bytes were read.
    ///
    /// A return value of `0` for a non-empty `buf` is treated as end of data by
    /// the helpers in this module.
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize>;

    /// Fills the whole of `buf` with bytes starting at `offset`.
    ///
    /// Calls [`ReaderAt::read_at`] repeatedly, advancing the offset by the
    /// number of bytes returned each time. Calls that fail with
    /// [`std::io::ErrorKind::Interrupted`] are retried, matching
    /// `std::io::Read::read_exact`.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::UnexpectedEof`] if
    /// `read_at` returns `0` before `buf` is full (`buf` is then only partially
    /// filled), and propagates any other error from `read_at`.
    fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            match self.read_at(&mut buf[read..], offset + read as u64) {
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "failed to fill whole buffer",
                    ));
                }
                Ok(latest) => read += latest,
                // The call was interrupted before transferring data; retry it.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        Ok(())
    }
}
pub(crate) trait ReaderAtExt {
fn try_read_at_least_at(
&self,
buffer: &mut [u8],
size: usize,
offset: u64,
) -> std::io::Result<usize>;
fn read_at_least_at(&self, buffer: &mut [u8], size: usize, offset: u64)
-> Result<usize, Error>;
}
impl<T: ReaderAt> ReaderAtExt for T {
    /// Reads into `buffer` starting at `offset` until at least `size` bytes
    /// (capped at `buffer.len()`) have been read, or until `read_at` returns `0`.
    ///
    /// Each call may fill the rest of `buffer`, so the total can exceed `size`.
    /// Calls that fail with [`std::io::ErrorKind::Interrupted`] are retried.
    /// The returned count is below `size` only when the source ran out of data.
    fn try_read_at_least_at(
        &self,
        buffer: &mut [u8],
        size: usize,
        offset: u64,
    ) -> std::io::Result<usize> {
        let size = size.min(buffer.len());
        let mut pos = 0;
        while pos < size {
            match self.read_at(&mut buffer[pos..], offset + pos as u64) {
                // End of data: report what we have instead of failing.
                Ok(0) => break,
                Ok(read) => pos += read,
                // Interrupted before any data arrived; retry like std's read loops.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        Ok(pos)
    }

    /// Reads at least `size` bytes into `buffer` starting at `offset` and
    /// returns the number of bytes read (which may exceed `size`).
    ///
    /// # Errors
    ///
    /// - `ErrorKind::BufferTooSmall` if `buffer.len() < size`;
    /// - `ErrorKind::Eof` if the source ends before `size` bytes were read;
    /// - any I/O error from `read_at`, converted into [`Error`] via `?`.
    fn read_at_least_at(
        &self,
        buffer: &mut [u8],
        size: usize,
        offset: u64,
    ) -> Result<usize, Error> {
        if buffer.len() < size {
            return Err(Error::from(ErrorKind::BufferTooSmall));
        }
        let read = self.try_read_at_least_at(buffer, size, offset)?;
        if read < size {
            return Err(Error::from(ErrorKind::Eof));
        }
        Ok(read)
    }
}
/// A [`ReaderAt`] backed by a [`std::fs::File`].
///
/// On Unix and Windows the file is used directly through the platform's
/// `FileExt` positioned-read method.
#[cfg(any(unix, windows))]
#[derive(Debug)]
pub struct FileReader(std::fs::File);

/// A [`ReaderAt`] backed by a [`std::fs::File`].
///
/// Without a platform positioned-read method, access goes through a
/// [`MutexReader`], which seeks, reads, and then restores the cursor.
#[cfg(not(any(unix, windows)))]
#[derive(Debug)]
pub struct FileReader(MutexReader<std::fs::File>);
impl FileReader {
    /// Consumes the reader and returns the underlying file.
    #[cfg(any(unix, windows))]
    pub fn into_inner(self) -> std::fs::File {
        self.0
    }

    /// Consumes the reader and returns the underlying file, unwrapping the
    /// [`MutexReader`] used on this platform.
    #[cfg(not(any(unix, windows)))]
    pub fn into_inner(self) -> std::fs::File {
        self.0.into_inner()
    }
}
impl ReaderAt for FileReader {
    /// Delegates to the Unix `FileExt::read_at` positioned read.
    #[cfg(unix)]
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        FileExt::read_at(&self.0, buf, offset)
    }

    /// Delegates to the Windows `FileExt::seek_read` positioned read.
    // NOTE(review): std documents that `seek_read` also moves the file cursor,
    // unlike the Unix path; confirm callers mixing `Seek` and `read_at` are OK.
    #[cfg(windows)]
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        FileExt::seek_read(&self.0, buf, offset)
    }

    /// Delegates to the `MutexReader` seek/read/restore fallback.
    #[cfg(not(any(unix, windows)))]
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        ReaderAt::read_at(&self.0, buf, offset)
    }
}
impl std::io::Seek for FileReader {
    /// Moves the cursor of the backing reader (the file itself on Unix and
    /// Windows, otherwise the wrapping `MutexReader`).
    #[inline]
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        // Fully qualified on purpose: this module imports only `std::io::Read`,
        // and implementing `Seek` here does not bring the trait into scope for
        // method-call syntax on `self.0`.
        std::io::Seek::seek(&mut self.0, pos)
    }
}
impl From<std::fs::File> for FileReader {
    /// Wraps `file` in the backing reader appropriate for the target platform.
    fn from(file: std::fs::File) -> Self {
        // Unix and Windows read the file directly.
        #[cfg(any(unix, windows))]
        let inner = file;
        // Elsewhere, positioned reads go through a mutex-guarded seek/read.
        #[cfg(not(any(unix, windows)))]
        let inner = MutexReader::new(file);
        Self(inner)
    }
}
/// Adapts a `Read + Seek` source into a [`ReaderAt`] by serializing access to it
/// through a [`std::sync::Mutex`].
#[derive(Debug)]
pub struct MutexReader<R>(std::sync::Mutex<R>);

impl<R> MutexReader<R> {
    /// Wraps `inner`.
    pub fn new(inner: R) -> Self {
        Self(std::sync::Mutex::new(inner))
    }

    /// Consumes the wrapper and returns the inner reader.
    ///
    /// If the mutex was poisoned by a panic while it was held, the inner reader
    /// is still returned instead of panicking. Its stream position is then
    /// wherever the panicking operation left it.
    pub fn into_inner(self) -> R {
        self.0
            .into_inner()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl<R> ReaderAt for MutexReader<R>
where
    R: std::io::Read + std::io::Seek,
{
    /// Seeks the inner reader to `offset`, performs a single `read`, then seeks
    /// back to the position it had before the call.
    ///
    /// The lock is held for the whole sequence, so other callers never observe
    /// the temporary position. A poisoned lock is recovered rather than causing
    /// a panic, because this method always seeks to an absolute offset before
    /// reading.
    ///
    /// # Errors
    ///
    /// Errors from querying, setting, or restoring the position are returned.
    /// If restoring fails, that error replaces the result of the read.
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        let mut inner = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let original_position = inner.stream_position()?;
        inner.seek(std::io::SeekFrom::Start(offset))?;
        let result = inner.read(buf);
        // Restore even when the read failed so sequential `Read`/`Seek` users
        // keep their cursor.
        inner.seek(std::io::SeekFrom::Start(original_position))?;
        result
    }
}
impl<R> std::io::Read for MutexReader<R>
where
    R: std::io::Read,
{
    /// Reads from the inner reader at its current position.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // `&mut self` already guarantees exclusive access, so no locking is
        // needed. A poisoned mutex still yields the inner reader instead of
        // panicking.
        self.0
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .read(buf)
    }
}
impl<R> std::io::Seek for MutexReader<R>
where
    R: std::io::Seek,
{
    /// Seeks the inner reader.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        // `&mut self` already guarantees exclusive access, so no locking is
        // needed. A poisoned mutex still yields the inner reader instead of
        // panicking.
        self.0
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .seek(pos)
    }
}
/// Shared references forward to the referent, including unsized referents such
/// as `&dyn ReaderAt` (matching the `Arc`/`Rc`/`Box` impls).
impl<T: ReaderAt + ?Sized> ReaderAt for &'_ T {
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        (**self).read_at(buf, offset)
    }
}

/// Mutable references forward to the referent, including unsized referents
/// such as `&mut dyn ReaderAt`.
impl<T: ReaderAt + ?Sized> ReaderAt for &'_ mut T {
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        (**self).read_at(buf, offset)
    }
}
/// Byte slices are read directly from memory.
impl ReaderAt for &[u8] {
    /// Copies up to `buf.len()` bytes starting at `offset`. Offsets at or past
    /// the end of the slice read `0` bytes.
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        // Clamp in `u64` before narrowing. Casting `offset` to `usize` first
        // truncates on 32-bit targets and would read from the wrong position.
        // The clamped value is at most `self.len()`, so the cast is lossless.
        let start = offset.min(self.len() as u64) as usize;
        let remaining = &self[start..];
        let len = remaining.len().min(buf.len());
        buf[..len].copy_from_slice(&remaining[..len]);
        Ok(len)
    }
}
impl<R> ReaderAt for std::io::Cursor<R>
where
R: AsRef<[u8]>,
{
#[inline]
fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
let data = self.get_ref().as_ref();
data.read_at(buf, offset)
}
}
/// Owned byte vectors are read like the slice they hold.
impl ReaderAt for Vec<u8> {
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        let bytes: &[u8] = self;
        bytes.read_at(buf, offset)
    }
}
/// Atomically reference-counted readers forward to the shared pointee.
impl<T: ReaderAt + ?Sized> ReaderAt for Arc<T> {
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        T::read_at(self, buf, offset)
    }
}
/// Reference-counted readers forward to the shared pointee.
impl<T: ReaderAt + ?Sized> ReaderAt for Rc<T> {
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        T::read_at(self, buf, offset)
    }
}
/// Boxed readers, including `Box<dyn ReaderAt>`, forward to the owned pointee.
impl<T: ReaderAt + ?Sized> ReaderAt for Box<T> {
    #[inline]
    fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
        T::read_at(self, buf, offset)
    }
}
/// A sequential reader over the byte range `[start, end)` of an underlying
/// source.
///
/// The current position starts at the range start and advances as data is
/// read; reads never go past the range end.
#[derive(Debug, Clone)]
pub struct RangeReader<R> {
    archive: R,
    offset: u64,
    end_offset: u64,
}

impl<R> RangeReader<R> {
    /// Creates a reader over `range` of `archive`.
    ///
    /// A range whose start lies past its end is treated as empty.
    #[inline]
    pub fn new(archive: R, range: Range<u64>) -> Self {
        Self {
            archive,
            offset: range.start,
            end_offset: range.end,
        }
    }

    /// Absolute offset of the next byte to be read.
    #[inline]
    pub fn position(&self) -> u64 {
        self.offset
    }

    /// Number of bytes left before the end of the range; `0` once the position
    /// has reached or passed it.
    #[inline]
    pub fn remaining(&self) -> u64 {
        // Saturate: a reversed range (start > end) would otherwise underflow,
        // panicking in debug builds and wrapping to a huge length in release.
        self.end_offset.saturating_sub(self.offset)
    }

    /// Absolute offset one past the last byte of the range.
    #[inline]
    pub fn end_offset(&self) -> u64 {
        self.end_offset
    }

    /// Borrows the underlying source.
    #[inline]
    pub fn get_ref(&self) -> &R {
        &self.archive
    }

    /// Consumes the reader and returns the underlying source.
    #[inline]
    pub fn into_inner(self) -> R {
        self.archive
    }
}
impl<R> Read for RangeReader<R>
where
    R: ReaderAt,
{
    /// Reads up to `buf.len()` bytes from the current position, never past the
    /// end of the range, and advances the position by the number of bytes
    /// read. Returns `Ok(0)` once the range is exhausted or `buf` is empty.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Compare in `u64` and narrow afterwards. `remaining() as usize` would
        // truncate on 32-bit targets (e.g. 4 GiB left -> 0, a false EOF). The
        // minimum never exceeds `buf.len()`, so the final cast is lossless.
        let read_size = (buf.len() as u64).min(self.remaining()) as usize;
        if read_size == 0 {
            return Ok(0);
        }
        let read = self.archive.read_at(&mut buf[..read_size], self.offset)?;
        self.offset += read as u64;
        Ok(read)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Read, Seek, SeekFrom};

    const TEST_DATA: &[u8] = b"Hello, World! This is test data for ReaderAt implementations.";

    /// Checks `read_at` on a reader that exposes exactly `TEST_DATA`
    /// (`data_len` bytes). Covers a read at the start, one in the middle, one
    /// exactly at the end, and a short read straddling the end.
    fn test_reader_at_impl<R: ReaderAt>(reader: R, data_len: usize) {
        let mut buf = [0u8; 5];
        assert_eq!(reader.read_at(&mut buf, 0).unwrap(), 5);
        assert_eq!(&buf, b"Hello");

        buf.fill(0);
        assert_eq!(reader.read_at(&mut buf, 7).unwrap(), 5);
        assert_eq!(&buf, b"World");

        buf.fill(0);
        let bytes_read = reader.read_at(&mut buf, data_len as u64).unwrap();
        assert_eq!(bytes_read, 0);

        buf.fill(0);
        let bytes_read = reader.read_at(&mut buf, (data_len - 3) as u64).unwrap();
        assert_eq!(bytes_read, 3);
        assert_eq!(&buf[..3], &TEST_DATA[data_len - 3..]);
    }

    #[test]
    fn test_smart_pointer_implementations() {
        let data = TEST_DATA.to_vec();

        let arc_reader = Arc::new(data.clone());
        test_reader_at_impl(&*arc_reader, data.len());
        test_reader_at_impl(arc_reader, data.len());

        let rc_reader = Rc::new(data.clone());
        test_reader_at_impl(&*rc_reader, data.len());
        test_reader_at_impl(rc_reader, data.len());

        let box_reader = Box::new(data.clone());
        test_reader_at_impl(&*box_reader, data.len());
        test_reader_at_impl(box_reader, data.len());
    }

    #[test]
    fn test_reference_implementations() {
        let mut data = TEST_DATA.to_vec();
        let data_len = data.len();
        test_reader_at_impl(&data, data_len);
        test_reader_at_impl(&mut data, data_len);
    }

    #[test]
    fn test_byte_slice_implementation() {
        let data = TEST_DATA;
        test_reader_at_impl(data, data.len());
    }

    #[test]
    fn test_cursor_implementation() {
        let data = TEST_DATA.to_vec();
        let cursor = Cursor::new(data.clone());
        test_reader_at_impl(&cursor, data.len());
    }

    #[test]
    fn test_vec_implementation() {
        let data = TEST_DATA.to_vec();
        let data_len = data.len();
        // Pass the `Vec` by value so its own impl is exercised, not the
        // blanket `&T` forwarding impl.
        test_reader_at_impl(data, data_len);
    }

    #[test]
    fn test_read_exact_at() {
        let data = TEST_DATA;
        let mut buf = [0u8; 5];
        data.read_exact_at(&mut buf, 7).unwrap();
        assert_eq!(&buf, b"World");

        let mut too_long = [0u8; 8];
        let err = data
            .read_exact_at(&mut too_long, (TEST_DATA.len() - 4) as u64)
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_read_at_least_at() {
        let data: &[u8] = b"0123456789";
        let mut buf = [0u8; 8];

        // One read_at call may deliver more than `size` bytes.
        assert_eq!(data.try_read_at_least_at(&mut buf, 3, 6).unwrap(), 4);
        assert_eq!(&buf[..4], b"6789");
        // Running out of data is not an error for the `try_` variant.
        assert_eq!(data.try_read_at_least_at(&mut buf, 8, 6).unwrap(), 4);

        assert_eq!(data.read_at_least_at(&mut buf, 4, 0).ok(), Some(8));
        assert!(data.read_at_least_at(&mut buf, 8, 6).is_err());
        assert!(data.read_at_least_at(&mut buf[..2], 4, 0).is_err());
    }

    #[test]
    fn test_mutex_reader_restores_position() {
        let mut reader = MutexReader::new(Cursor::new(TEST_DATA.to_vec()));
        reader.seek(SeekFrom::Start(3)).unwrap();

        test_reader_at_impl(&reader, TEST_DATA.len());

        // Positioned reads must not disturb the sequential cursor.
        let mut buf = [0u8; 2];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"lo");
        assert_eq!(reader.into_inner().position(), 5);
    }

    #[test]
    fn test_range_reader_basic() {
        let data = b"Hello, World! This is test data.";
        let mut range_reader = RangeReader::new(data.as_slice(), 7..13);
        let mut buffer = [0u8; 10];
        let bytes_read = range_reader.read(&mut buffer).unwrap();
        assert_eq!(bytes_read, 6);
        assert_eq!(&buffer[..bytes_read], b"World!");
    }

    #[test]
    fn test_range_reader_multiple_reads() {
        let data = b"0123456789";
        let mut range_reader = RangeReader::new(data.as_slice(), 2..8);
        let mut buffer = [0u8; 3];

        let bytes_read1 = range_reader.read(&mut buffer).unwrap();
        assert_eq!(bytes_read1, 3);
        assert_eq!(&buffer[..bytes_read1], b"234");
        assert_eq!(range_reader.position(), 5);

        let bytes_read2 = range_reader.read(&mut buffer).unwrap();
        assert_eq!(bytes_read2, 3);
        assert_eq!(&buffer[..bytes_read2], b"567");
        assert_eq!(range_reader.position(), 8);

        let bytes_read3 = range_reader.read(&mut buffer).unwrap();
        assert_eq!(bytes_read3, 0);
    }

    #[test]
    fn test_range_reader_empty_range() {
        let data = b"Hello, World!";
        let mut range_reader = RangeReader::new(data.as_slice(), 5..5);
        let mut buffer = [0u8; 10];
        let bytes_read = range_reader.read(&mut buffer).unwrap();
        assert_eq!(bytes_read, 0);
        assert_eq!(range_reader.remaining(), 0);
    }

    #[test]
    fn test_range_reader_get_ref_and_into_inner() {
        let data = b"Hello, World!";
        let range_reader = RangeReader::new(data.as_slice(), 0..5);
        assert_eq!(range_reader.get_ref(), &data.as_slice());
        let inner = range_reader.into_inner();
        assert_eq!(inner, data.as_slice());
    }

    #[test]
    fn test_range_reader_clone() {
        let data = b"Hello, World!";
        let range_reader = RangeReader::new(data.as_slice(), 0..5);
        let cloned = range_reader.clone();
        assert_eq!(range_reader.position(), cloned.position());
        assert_eq!(range_reader.remaining(), cloned.remaining());
    }

    #[test]
    fn test_range_reader_range_exceeds_data() {
        let data = b"Hello";

        // Range starts inside the data and runs past its end.
        let mut reader1 = RangeReader::new(data.as_slice(), 3..10);
        let mut buf1 = [0u8; 10];
        let read1 = reader1.read(&mut buf1).unwrap();
        assert_eq!(read1, 2);
        assert_eq!(&buf1[..read1], b"lo");

        // Range starts exactly at the end of the data.
        let mut reader2 = RangeReader::new(data.as_slice(), 5..10);
        let mut buf2 = [0u8; 10];
        let read2 = reader2.read(&mut buf2).unwrap();
        assert_eq!(read2, 0);

        // Range lies entirely past the end of the data.
        let mut reader3 = RangeReader::new(data.as_slice(), 10..20);
        let mut buf3 = [0u8; 10];
        let read3 = reader3.read(&mut buf3).unwrap();
        assert_eq!(read3, 0);
    }
}