use std::{
cmp, fmt, io,
io::{BufRead, IoSliceMut, Read, ReadBuf, Seek, SeekFrom},
mem::MaybeUninit,
str::from_utf8,
};
/// A buffered reader whose `N`-byte buffer is stored inline in the struct
/// instead of in a heap allocation (the default `N` is 4096).
///
/// Bookkeeping: bytes `pos..cap` of `buf` are buffered but not yet consumed,
/// and bytes `0..init` are known to be initialized, with `pos <= cap <= init`.
pub struct StackBufReader<R, const N: usize = 4096> {
    /// The wrapped reader that the buffer is refilled from.
    inner: R,
    /// Inline storage; only the first `init` bytes are guaranteed initialized.
    buf: [MaybeUninit<u8>; N],
    /// Index of the next unconsumed byte in `buf`.
    pos: usize,
    /// End (exclusive) of the filled region of `buf`.
    cap: usize,
    /// Number of leading bytes of `buf` that have been initialized so far.
    init: usize,
}
impl<R: Read, const N: usize> StackBufReader<R, N> {
    /// Creates a new reader wrapping `inner` with an empty `N`-byte buffer.
    ///
    /// The buffer is stored inline in the returned value, so no heap allocation
    /// takes place. Its contents start out uninitialized; `init` tracks how much
    /// of it has been initialized by later reads.
    pub fn new(inner: R) -> StackBufReader<R, N> {
        StackBufReader {
            inner,
            // `MaybeUninit<u8>` is `Copy`, so an array repeat expression builds the
            // uninitialized buffer without any `unsafe`.
            buf: [MaybeUninit::uninit(); N],
            pos: 0,
            cap: 0,
            init: 0,
        }
    }
}
impl<R, const N: usize> StackBufReader<R, N> {
    /// Returns a shared reference to the wrapped reader.
    pub fn get_ref(&self) -> &R {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped reader.
    ///
    /// Bytes already pulled into this reader's buffer are not visible through
    /// the returned reference, so reading from it directly skips them.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.inner
    }

    /// Returns the buffered bytes that have not been consumed yet.
    ///
    /// This never reads from the inner reader; the slice is empty once the
    /// buffer has been fully consumed.
    pub fn buffer(&self) -> &[u8] {
        let unread = &self.buf[self.pos..self.cap];
        // SAFETY: `pos <= cap <= init` is maintained by every method that updates
        // these fields, and the first `init` bytes of `buf` are initialized
        // (tracked through `ReadBuf::initialized_len` in `fill_buf`).
        unsafe { MaybeUninit::slice_assume_init_ref(unread) }
    }

    /// Returns the size of the internal buffer in bytes, i.e. the const parameter `N`.
    pub fn capacity(&self) -> usize {
        N
    }

    /// Unwraps this reader and returns the inner reader.
    ///
    /// Any bytes still held in the internal buffer are dropped.
    pub fn into_inner(self) -> R {
        self.inner
    }

    /// Forgets all buffered data so that the next fill reads from the inner reader.
    #[inline]
    fn discard_buffer(&mut self) {
        self.cap = 0;
        self.pos = 0;
    }
}
impl<R: Seek, const N: usize> StackBufReader<R, N> {
    /// Seeks `offset` bytes relative to the current logical position.
    ///
    /// If the target position still lies inside the buffered region, only the
    /// read cursor moves and the buffer is kept. Otherwise this falls back to
    /// `seek(SeekFrom::Current(offset))`, which discards the buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner reader's `seek` in the fallback case.
    pub fn seek_relative(&mut self, offset: i64) -> io::Result<()> {
        let pos = self.pos as u64;
        if offset < 0 {
            // `unsigned_abs` instead of `(-offset) as u64`: negating `i64::MIN`
            // overflows (a panic in debug builds).
            if let Some(new_pos) = pos.checked_sub(offset.unsigned_abs()) {
                self.pos = new_pos as usize;
                return Ok(());
            }
        } else if let Some(new_pos) = pos.checked_add(offset as u64) {
            // Moving forward is only free while we stay within the filled region.
            if new_pos <= self.cap as u64 {
                self.pos = new_pos as usize;
                return Ok(());
            }
        }
        self.seek(SeekFrom::Current(offset)).map(drop)
    }
}
impl<R: Read, const N: usize> Read for StackBufReader<R, N> {
    /// Reads into `buf`, serving from the internal buffer where possible.
    ///
    /// When nothing is buffered and `buf` is at least as large as the internal
    /// buffer, the call goes straight to the inner reader to skip a copy.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let nothing_buffered = self.pos == self.cap;
        if nothing_buffered && buf.len() >= self.capacity() {
            self.discard_buffer();
            return self.inner.read(buf);
        }
        let copied = self.fill_buf()?.read(buf)?;
        self.consume(copied);
        Ok(copied)
    }

    /// Vectored variant of [`read`](Read::read), using the same bypass rule
    /// based on the combined length of `bufs`.
    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
        let wanted = bufs.iter().fold(0usize, |acc, b| acc + b.len());
        let nothing_buffered = self.pos == self.cap;
        if nothing_buffered && wanted >= self.capacity() {
            self.discard_buffer();
            return self.inner.read_vectored(bufs);
        }
        let copied = self.fill_buf()?.read_vectored(bufs)?;
        self.consume(copied);
        Ok(copied)
    }

    /// Reports whether the inner reader has an efficient vectored read.
    fn is_read_vectored(&self) -> bool {
        self.inner.is_read_vectored()
    }

    /// Appends all buffered bytes, then everything the inner reader yields, to `buf`.
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        let drained = {
            let pending = self.buffer();
            buf.extend_from_slice(pending);
            pending.len()
        };
        self.discard_buffer();
        let rest = self.inner.read_to_end(buf)?;
        Ok(drained + rest)
    }

    /// Reads everything remaining and appends it to `buf` as UTF-8.
    ///
    /// An empty `buf` is filled in place; a non-empty one goes through a scratch
    /// `Vec` so `buf` is only extended when the whole remainder validates.
    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
        if !buf.is_empty() {
            let mut scratch = Vec::new();
            self.read_to_end(&mut scratch)?;
            let text = from_utf8(&scratch).map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    "stream did not contain valid UTF-8",
                )
            })?;
            buf.push_str(text);
            return Ok(text.len());
        }
        // SAFETY: `read_to_end` only appends to the vector it is handed, which
        // is the contract `append_to_string` requires.
        unsafe { append_to_string(buf, |bytes| self.read_to_end(bytes)) }
    }

    /// Fills `buf` completely, copying straight from the buffer when it already
    /// holds enough bytes and otherwise looping over `read`.
    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        let wanted = buf.len();
        match self.buffer().get(..wanted) {
            Some(head) => {
                buf.copy_from_slice(head);
                self.consume(wanted);
                Ok(())
            }
            None => default_read_exact(self, buf),
        }
    }

    /// `ReadBuf` variant of [`read`](Read::read), using the same bypass rule
    /// based on the cursor's remaining space.
    fn read_buf(&mut self, buf: &mut ReadBuf<'_>) -> io::Result<()> {
        let nothing_buffered = self.pos == self.cap;
        if nothing_buffered && buf.remaining() >= self.capacity() {
            self.discard_buffer();
            return self.inner.read_buf(buf);
        }
        let before = buf.filled_len();
        self.fill_buf()?.read_buf(buf)?;
        let taken = buf.filled_len() - before;
        self.consume(taken);
        Ok(())
    }
}
/// Drop guard that resets a `Vec<u8>`'s length to `len` when it goes out of scope.
///
/// `append_to_string` uses it so that the string's bytes are cut back to a
/// length known to hold valid UTF-8 on every exit path, including a panic.
struct Guard<'a> {
    buf: &'a mut Vec<u8>,
    len: usize,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // SAFETY: `len` is either the vector's original length or a length that
        // `append_to_string` just read back from the vector. Callers only ever
        // append, so it never exceeds the initialized length.
        unsafe {
            self.buf.set_len(self.len);
        }
    }
}

/// Runs `f` to append bytes to `buf`'s underlying vector. The appended bytes
/// are kept only if they are valid UTF-8.
///
/// If the new bytes validate, `f`'s result is returned unchanged, even when it
/// is an error. If they do not validate, `buf` is truncated back to its original
/// length. The return value is then `f`'s own error if it failed, or an
/// `InvalidData` error otherwise.
///
/// # Safety
///
/// `f` must only append to the vector it receives. It must not overwrite,
/// remove or truncate the bytes that were already present, because those are
/// assumed to still be valid UTF-8 when the guard restores the length.
unsafe fn append_to_string<F>(buf: &mut String, f: F) -> io::Result<usize>
where
    F: FnOnce(&mut Vec<u8>) -> io::Result<usize>,
{
    let original_len = buf.len();
    let mut guard = Guard {
        len: original_len,
        buf: buf.as_mut_vec(),
    };
    let outcome = f(guard.buf);
    let appended_is_utf8 = from_utf8(&guard.buf[original_len..]).is_ok();
    if appended_is_utf8 {
        // Commit the new bytes: the guard now restores the full length.
        guard.len = guard.buf.len();
        return outcome;
    }
    // The guard still holds `original_len`, so dropping it discards the bytes.
    match outcome {
        Ok(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "stream did not contain valid UTF-8",
        )),
        Err(e) => Err(e),
    }
}
/// Fills `buf` completely by calling `this.read` repeatedly.
///
/// `Interrupted` errors are retried. Any other error is returned immediately.
/// If the reader reports end of input (`Ok(0)`) before `buf` is full, an
/// `UnexpectedEof` error is returned; the bytes read so far stay in `buf`.
fn default_read_exact<R: Read + ?Sized>(this: &mut R, mut buf: &mut [u8]) -> io::Result<()> {
    loop {
        if buf.is_empty() {
            return Ok(());
        }
        match this.read(buf) {
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "failed to fill whole buffer",
                ));
            }
            // Advance past the bytes just filled.
            Ok(n) => buf = &mut std::mem::take(&mut buf)[n..],
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
}
impl<R: Read, const N: usize> BufRead for StackBufReader<R, N> {
    /// Returns the unconsumed buffered bytes. If the buffer has been fully
    /// consumed, it is refilled first.
    ///
    /// A refill issues a single `read_buf` call on the inner reader, so the
    /// returned slice may be shorter than the capacity. If the inner reader
    /// produces no bytes, the slice is empty.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        if self.pos >= self.cap {
            debug_assert!(self.pos == self.cap);
            let mut readbuf = ReadBuf::uninit(&mut self.buf);
            // SAFETY: the first `self.init` bytes of `self.buf` were reported as
            // initialized by a previous `read_buf` (via `initialized_len`), so the
            // cursor may be told they are initialized.
            unsafe {
                readbuf.assume_init(self.init);
            }
            self.inner.read_buf(&mut readbuf)?;
            self.cap = readbuf.filled_len();
            self.init = readbuf.initialized_len();
            self.pos = 0;
        }
        Ok(self.buffer())
    }

    /// Marks `amt` buffered bytes as consumed, clamped to what is buffered.
    fn consume(&mut self, amt: usize) {
        // `saturating_add` makes a huge `amt` clamp to `cap` instead of
        // overflowing. Overflow panics in debug builds, and in release it wraps
        // so that `pos` moves backwards and already-consumed bytes are re-yielded.
        self.pos = cmp::min(self.pos.saturating_add(amt), self.cap);
    }
}
impl<R, const N: usize> fmt::Debug for StackBufReader<R, N>
where
    R: fmt::Debug,
{
    /// Formats the inner reader plus the buffer fill level as `unread/capacity`.
    ///
    /// The struct name matches the actual type. It previously said `BufReader`,
    /// a copy-paste leftover that made debug output misleading.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("StackBufReader")
            .field("reader", &self.inner)
            .field("buffer", &format_args!("{}/{}", self.cap - self.pos, N))
            .finish()
    }
}
impl<R: Seek, const N: usize> Seek for StackBufReader<R, N> {
    /// Seeks the inner reader and always discards the internal buffer.
    ///
    /// For `SeekFrom::Current(n)`, `n` is relative to the logical position. That
    /// position trails the inner reader by the number of unread buffered bytes,
    /// so the offset is corrected by that amount. If the corrected offset does
    /// not fit in an `i64`, the inner reader is first moved back by the buffered
    /// amount and then by `n` in a second seek.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let landed = match pos {
            SeekFrom::Current(n) => {
                let unread = (self.cap - self.pos) as i64;
                match n.checked_sub(unread) {
                    Some(adjusted) => self.inner.seek(SeekFrom::Current(adjusted))?,
                    None => {
                        self.inner.seek(SeekFrom::Current(-unread))?;
                        self.discard_buffer();
                        self.inner.seek(SeekFrom::Current(n))?
                    }
                }
            }
            absolute => self.inner.seek(absolute)?,
        };
        self.discard_buffer();
        Ok(landed)
    }

    /// Returns the logical stream position: the inner reader's position minus
    /// the unread buffered bytes.
    ///
    /// # Panics
    ///
    /// Panics if the inner position is smaller than the number of buffered bytes.
    fn stream_position(&mut self) -> io::Result<u64> {
        let unread = (self.cap - self.pos) as u64;
        let inner_pos = self.inner.stream_position()?;
        Ok(inner_pos.checked_sub(unread).expect(
            "overflow when subtracting remaining buffer size from inner stream position",
        ))
    }
}