use super::{Error, ErrorKind, Read, Result, Seek, SeekFrom};
/// Adapter around a seekable reader that stops returning data once the inner
/// stream reaches an absolute end offset.
///
/// Created by [`TakeSeekExt::take_seek`]. The limit is stored as a position in
/// the inner stream rather than as a byte count. As a result, seeking through
/// the adapter changes how many bytes remain readable.
#[derive(Debug)]
pub struct TakeSeek<T> {
    /// The wrapped stream.
    inner: T,
    /// Position of `inner` as last observed by this adapter. It is updated by
    /// reads, seeks and `set_limit`, and reported by `stream_position`.
    pos: u64,
    /// Absolute, exclusive offset in `inner` at which reads stop returning data.
    end: u64,
}
impl<T> TakeSeek<T> {
    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &T {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped stream.
    ///
    /// Seeking or reading the wrapped stream through this reference is not
    /// reflected in the position tracked by the adapter (used by `limit` and
    /// `stream_position`).
    pub fn get_mut(&mut self) -> &mut T {
        let Self { inner, .. } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Number of bytes that may still be read before the end offset is hit.
    ///
    /// Yields 0 when the tracked position is at or beyond the end offset,
    /// which can happen after seeking past it.
    pub fn limit(&self) -> u64 {
        if self.pos < self.end {
            self.end - self.pos
        } else {
            0
        }
    }
}
impl<T: Seek> TakeSeek<T> {
    /// Re-anchors the adapter at the inner stream's current position and allows
    /// at most `limit` further bytes to be read from there.
    ///
    /// The end offset saturates at `u64::MAX`, so `u64::MAX` can be passed to
    /// mean "no limit" without overflowing.
    ///
    /// # Panics
    ///
    /// Panics if the inner stream's current position cannot be queried.
    pub fn set_limit(&mut self, limit: u64) {
        let pos = self
            .inner
            .stream_position()
            .expect("cannot get position for `set_limit`");
        self.pos = pos;
        // A plain `pos + limit` panics in debug builds and wraps to a tiny
        // window in release builds when `limit` is close to `u64::MAX`.
        self.end = pos.saturating_add(limit);
    }
}
impl<T: Read> Read for TakeSeek<T> {
    /// Reads into `buf` without returning any byte at or beyond the end offset.
    ///
    /// Once the limit is exhausted this returns `Ok(0)` without calling the
    /// inner reader. The tracked position advances by the number of bytes read.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let remaining = self.limit();
        if remaining == 0 {
            return Ok(0);
        }
        let capacity = buf.len();
        // A limit that does not fit in `usize` is necessarily larger than the
        // buffer, so the whole buffer may be offered to the inner reader.
        let len = usize::try_from(remaining).map_or(capacity, |r| r.min(capacity));
        let (window, _) = buf.split_at_mut(len);
        let count = self.inner.read(window)?;
        self.pos += count as u64;
        Ok(count)
    }
}
impl<T: Seek> Seek for TakeSeek<T> {
fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
let pos = match pos {
SeekFrom::End(end) => {
let inner_end = self.inner.seek(SeekFrom::End(0))?;
match self.end.min(inner_end).checked_add_signed(end) {
Some(pos) => SeekFrom::Start(pos),
None => {
return Err(super::Error::new(
super::ErrorKind::InvalidInput,
"invalid seek to a negative or overflowing position",
))
}
}
}
pos => pos,
};
self.pos = self.inner.seek(pos)?;
Ok(self.pos)
}
fn stream_position(&mut self) -> Result<u64> {
Ok(self.pos)
}
}
/// Extension trait for wrapping a seekable reader in a [`TakeSeek`].
pub trait TakeSeekExt {
    /// Consumes `self` and returns a [`TakeSeek`] whose reads stop `limit`
    /// bytes past the stream's current position.
    ///
    /// # Panics
    ///
    /// The blanket implementation for `Read + Seek` types panics if the current
    /// stream position cannot be queried.
    fn take_seek(self, limit: u64) -> TakeSeek<Self>
    where
        Self: Sized;
}
impl<T: Read + Seek> TakeSeekExt for T {
    /// Captures the current stream position and caps reads at `limit` bytes
    /// past it.
    ///
    /// The end offset saturates at `u64::MAX`, so `u64::MAX` acts as "no limit"
    /// instead of overflowing.
    ///
    /// # Panics
    ///
    /// Panics if the current stream position cannot be queried.
    fn take_seek(mut self, limit: u64) -> TakeSeek<Self>
    where
        Self: Sized,
    {
        let pos = self
            .stream_position()
            .expect("cannot get position for `take_seek`");
        TakeSeek {
            inner: self,
            pos,
            // A plain `pos + limit` overflows for large limits whenever the
            // stream is not at position 0.
            end: pos.saturating_add(limit),
        }
    }
}