use core::{
cmp,
io::{BorrowedBuf, BorrowedCursor},
};
use crate::{BufRead, Error, IoBuf, Read, Result, Seek, SeekFrom};
/// Adapter that limits how many bytes can be read from `inner`.
///
/// It also exposes the limited range as a seekable window of `len` bytes when
/// the inner value implements `Seek`. The current position inside that window
/// is `len - limit`.
#[derive(Debug)]
pub struct Take<T> {
// The wrapped value that reads, buffering and seeks are forwarded to.
inner: T,
// Total length of the window. It is set to the limit by `new` and `set_limit`,
// and reads or seeks never change it.
len: u64,
// Bytes that may still be read before EOF is reported. Reads decrease it,
// and seeks move it within `0..=len`.
limit: u64,
}
impl<T> Take<T> {
pub(crate) fn new(inner: T, limit: u64) -> Self {
Take {
inner,
len: limit,
limit,
}
}
pub fn limit(&self) -> u64 {
self.limit
}
pub fn position(&self) -> u64 {
self.len - self.limit
}
pub fn set_limit(&mut self, limit: u64) {
self.len = limit;
self.limit = limit;
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn get_ref(&self) -> &T {
&self.inner
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
}
// Limits reads from the inner reader to the bytes that remain in `limit`.
// Every byte delivered to the caller is subtracted from `limit`.
impl<T: Read> Read for Take<T> {
/// Reads at most `min(buf.len(), limit)` bytes into `buf` and subtracts the
/// number of bytes read from the remaining limit.
///
/// Returns `Ok(0)` without calling the inner reader once the limit is exhausted.
///
/// # Panics
///
/// Panics if the inner reader reports more bytes than the remaining limit.
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
if self.limit == 0 {
return Ok(0);
}
// Restrict the destination so the inner reader is never offered more than
// the limit. The value is at most `buf.len()`, so casting back to `usize`
// is lossless.
let max = cmp::min(buf.len() as u64, self.limit) as usize;
let n = self.inner.read(&mut buf[..max])?;
// Reject a misbehaving inner reader before the subtraction below could
// underflow.
assert!(n as u64 <= self.limit, "number of read bytes exceeds limit");
self.limit -= n as u64;
Ok(n)
}
/// Reads into the unfilled part of `buf`, filling at most `limit` bytes.
///
/// The number of bytes actually filled is subtracted from the remaining
/// limit. Returns `Ok(())` without calling the inner reader once the limit is
/// exhausted. If the inner reader returns an error, any bytes it filled before
/// failing are still accounted for, and then the error is returned.
fn read_buf(&mut self, mut buf: BorrowedCursor<'_>) -> Result<()> {
if self.limit == 0 {
return Ok(());
}
if self.limit < buf.capacity() as u64 {
// The cursor has more room than the limit allows, so the inner reader only
// sees a view of the first `limit` bytes. This condition also guarantees
// that `self.limit` fits in `usize`.
let limit = self.limit as usize;
// Already-initialized bytes of the cursor that fall inside that view.
#[cfg(borrowedbuf_init)]
let extra_init = cmp::min(limit, buf.init_mut().len());
// SAFETY: `ibuf` is used only as the backing storage of `sliced_buf`. No
// uninitialized data can be written through it, so no initialized byte of
// `buf` is de-initialized.
let ibuf = unsafe { &mut buf.as_mut()[..limit] };
let mut sliced_buf: BorrowedBuf<'_> = ibuf.into();
// SAFETY: the first `extra_init` bytes of `ibuf` are initialized bytes at
// the start of `buf`'s unfilled region (`extra_init <= buf.init_mut().len()`).
#[cfg(borrowedbuf_init)]
unsafe {
sliced_buf.set_init(extra_init);
}
let mut cursor = sliced_buf.unfilled();
let result = self.inner.read_buf(cursor.reborrow());
#[cfg(borrowedbuf_init)]
let new_init = cursor.init_mut().len();
let filled = sliced_buf.len();
// `cursor`, `sliced_buf` and `ibuf` are not used past this point. Their
// borrows of `buf` end here, so `buf` can be updated below.
// SAFETY: `sliced_buf` aliases the start of `buf`'s unfilled region. Its
// first `filled` bytes were filled (and therefore initialized) by the inner
// reader. The `new_init` bytes after them were reported initialized by
// `cursor`.
#[cfg(borrowedbuf_init)]
unsafe {
buf.advance_unchecked(filled);
buf.set_init(new_init);
}
// When `borrowedbuf_init` is not set, only the filled count is transferred.
// NOTE(review): this presumably targets a toolchain whose `BorrowedCursor`
// has no separate init tracking; confirm against the build configuration.
// SAFETY: the first `filled` bytes of `buf`'s unfilled region were written
// by the inner reader through `sliced_buf`.
#[cfg(not(borrowedbuf_init))]
unsafe {
buf.advance(filled);
}
self.limit -= filled as u64;
result
} else {
// The cursor's capacity is no larger than the limit, so the inner reader may
// fill it directly. Charge the limit for however many bytes it wrote.
let written = buf.written();
let result = self.inner.read_buf(buf.reborrow());
self.limit -= (buf.written() - written) as u64;
result
}
}
}
impl<T: BufRead> BufRead for Take<T> {
fn fill_buf(&mut self) -> Result<&[u8]> {
if self.limit == 0 {
return Ok(&[]);
}
let buf = self.inner.fill_buf()?;
let cap = cmp::min(buf.len() as u64, self.limit) as usize;
Ok(&buf[..cap])
}
fn consume(&mut self, amt: usize) {
let amt = cmp::min(amt as u64, self.limit) as usize;
self.limit -= amt as u64;
self.inner.consume(amt);
}
}
impl<T: Seek> Seek for Take<T> {
    /// Moves to `pos` inside the `len`-byte window exposed by this adapter and
    /// returns the new position.
    ///
    /// Offset 0 is where the window begins, `SeekFrom::Current` counts from
    /// `position()`, and `SeekFrom::End` counts from `len`. The inner stream is
    /// moved only with `seek_relative`. A distance that does not fit in `i64`
    /// is covered in `i64::MAX` / `i64::MIN` sized steps.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidInput` if the target overflows `u64` or lies past
    /// `len`. Errors from the inner stream are propagated. Steps that already
    /// succeeded stay reflected in the remaining limit.
    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
        let target = match pos {
            SeekFrom::Start(v) => Some(v),
            SeekFrom::Current(v) => self.position().checked_add_signed(v),
            SeekFrom::End(v) => self.len.checked_add_signed(v),
        };
        let Some(target) = target.filter(|&t| t <= self.len) else {
            return Err(Error::InvalidInput);
        };
        while self.position() != target {
            let here = self.position();
            // Take the exact remaining distance when it fits in `i64`.
            // Otherwise move as far as possible in the right direction and
            // loop again.
            let step = match target.checked_signed_diff(here) {
                Some(exact) => exact,
                None if target > here => i64::MAX,
                None => i64::MIN,
            };
            self.inner.seek_relative(step)?;
            // `step as u64` is the two's-complement bit pattern. Wrapping
            // subtraction therefore grows `limit` for negative steps and
            // shrinks it for positive ones.
            self.limit = self.limit.wrapping_sub(step as u64);
        }
        Ok(target)
    }

    /// Reports the window length rather than the inner stream's length.
    fn stream_len(&mut self) -> Result<u64> {
        Ok(self.len)
    }

    /// Reports the position inside the window, without consulting the inner
    /// stream.
    fn stream_position(&mut self) -> Result<u64> {
        Ok(self.position())
    }

    /// Moves by `offset` bytes relative to the current window position.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidInput` if the resulting position would be
    /// negative, overflow `u64`, or lie past `len`. Errors from the inner
    /// stream are propagated, and the limit is then left unchanged.
    fn seek_relative(&mut self, offset: i64) -> Result<()> {
        let target = self.position().checked_add_signed(offset);
        if !matches!(target, Some(p) if p <= self.len) {
            return Err(Error::InvalidInput);
        }
        self.inner.seek_relative(offset)?;
        // The bounds check above keeps the wrapped result within `0..=len`.
        self.limit = self.limit.wrapping_sub(offset as u64);
        Ok(())
    }
}
impl<T: IoBuf> IoBuf for Take<T> {
    /// Returns how many bytes can still be obtained through this adapter.
    ///
    /// This is the smaller of what the inner value reports and the remaining
    /// limit.
    fn remaining(&self) -> usize {
        // Compare in `u64`, as `read` and `fill_buf` do. Casting `self.limit`
        // to `usize` would truncate a limit above `usize::MAX` on 32-/16-bit
        // targets (e.g. 2^32 + 5 would become 5). The minimum is at most
        // `inner.remaining()`, so narrowing it back to `usize` is lossless.
        cmp::min(self.inner.remaining() as u64, self.limit) as usize
    }
}