use std::io;
use std::io::{Read, Seek, SeekFrom};
/// Adapter around a seekable reader that tracks a caller-visible byte position
/// separately from the sector-granular position of the wrapped reader.
pub struct SectorReader<R: Read + Seek> {
    /// The wrapped reader that all seeks and reads are forwarded to.
    inner: R,
    /// Size of one sector in bytes; used as the alignment unit.
    sector_size: usize,
    /// Byte position as seen by the caller (not necessarily sector-aligned).
    stream_position: u64,
    /// Scratch buffer that receives the sector-sized reads.
    temp_buf: Vec<u8>,
}
impl<R> SectorReader<R>
where
    R: Read + Seek,
{
    /// Creates a `SectorReader` around `inner`, using `sector_size` (in bytes)
    /// as the alignment unit.
    ///
    /// The reader starts with a stream position of 0 and an empty scratch
    /// buffer. `inner` is not touched by this constructor.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::Other`] if `sector_size` is
    /// not a power of two. Zero is not a power of two and is rejected as well,
    /// which also guarantees the divisions in the alignment helpers are safe.
    pub fn new(inner: R, sector_size: usize) -> io::Result<Self> {
        if !sector_size.is_power_of_two() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "sector_size is not a power of two",
            ));
        }

        Ok(Self {
            inner,
            sector_size,
            stream_position: 0,
            temp_buf: Vec::new(),
        })
    }

    /// Rounds `n` down to the nearest multiple of the sector size.
    fn align_down_to_sector_size(&self, n: u64) -> u64 {
        n / self.sector_size as u64 * self.sector_size as u64
    }

    /// Rounds `n` up to the nearest multiple of the sector size.
    ///
    /// A value that is already a multiple of the sector size (including 0) is
    /// returned unchanged, so callers never request an extra, unneeded sector.
    fn align_up_to_sector_size(&self, n: u64) -> u64 {
        let aligned_down = self.align_down_to_sector_size(n);
        if aligned_down == n {
            n
        } else {
            aligned_down + self.sector_size as u64
        }
    }
}
impl<R> Read for SectorReader<R>
where
R: Read + Seek,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let aligned_position = self.align_down_to_sector_size(self.stream_position);
let start = (self.stream_position - aligned_position) as usize;
let end = start + buf.len();
let aligned_bytes_to_read = self.align_up_to_sector_size(end as u64) as usize;
self.temp_buf.resize(aligned_bytes_to_read, 0);
self.inner.read_exact(&mut self.temp_buf)?;
buf.copy_from_slice(&self.temp_buf[start..end]);
self.stream_position += buf.len() as u64;
Ok(buf.len())
}
}
impl<R> Seek for SectorReader<R>
where
    R: Read + Seek,
{
    /// Moves the stream position to `pos` and positions the inner reader at the
    /// start of the sector containing it.
    ///
    /// Returns the new (unaligned) stream position.
    ///
    /// # Errors
    ///
    /// * `SeekFrom::End` is not supported and yields an `ErrorKind::Other` error.
    /// * A relative seek that would go below zero or overflow `u64` yields an
    ///   `ErrorKind::InvalidInput` error.
    /// * Errors from seeking the inner reader are propagated.
    ///
    /// The stream position is only updated after the inner seek has succeeded.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let requested = match pos {
            SeekFrom::Start(offset) => Some(offset),
            SeekFrom::Current(delta) if delta >= 0 => {
                self.stream_position.checked_add(delta as u64)
            }
            // `wrapping_neg` followed by the cast yields the magnitude of
            // `delta`, including for `i64::MIN`.
            SeekFrom::Current(delta) => {
                self.stream_position.checked_sub(delta.wrapping_neg() as u64)
            }
            SeekFrom::End(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "SeekFrom::End is unsupported for SectorReader",
                ));
            }
        };

        let target = requested.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })?;

        // The inner reader is only ever positioned on sector boundaries.
        let sector_start = self.align_down_to_sector_size(target);
        self.inner.seek(SeekFrom::Start(sector_start))?;

        self.stream_position = target;
        Ok(target)
    }
}