use {
std::{
io::{Error, ErrorKind, StdinLock},
path::Path,
},
crate::Result,
};
#[cfg(not(feature="tokio"))]
use std::{
fs::{self, File},
io::{self, BufRead, BufReader, Read, Stdin},
};
#[cfg(feature="tokio")]
use {
core::{
pin::Pin,
task::{Context, Poll},
},
tokio::{
fs::{self, File},
io::{self, AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader, ReadBuf, Stdin},
},
};
#[cfg(test)]
mod tests;
// Reads everything that is left in `$self` (a `&mut Limit<R>`) into a new `Vec<u8>`.
//
// - `$capacity`: an `Option<usize>` pre-allocation hint. `None` means no pre-allocation. `Some(n)` is capped
//   at the remaining allowance (`limit - read`) when that allowance fits into `usize`.
// - Returns early with an error if more data has already been read than the limit allows.
//
// The data is pulled through the `Read`/`AsyncRead` implementation of `Limit` itself, NOT through
// `$self.inner`: going to the inner reader directly would skip both the byte accounting and the limit check.
// Fully-qualified trait calls are required because the inherent `Limit::read_to_end(Option<usize>)` would
// otherwise shadow the trait method of the same name.
macro_rules! read_to_end { ($self: ident, $capacity: ident) => {{
    let mut result = Vec::with_capacity(match $capacity {
        None => usize::MIN,
        Some(capacity) => match $self.limit.checked_sub($self.read) {
            // Already past the limit: nothing more may be read.
            None => return Err(err!()),
            Some(max) => match usize::try_from(max) {
                Ok(max) => max.min(capacity),
                // Remaining allowance does not fit into usize, so the hint cannot exceed it.
                Err(_) => capacity,
            },
        },
    });
    #[cfg(not(feature="tokio"))]
    let outcome = async_call!(Read::read_to_end(&mut *$self, &mut result));
    #[cfg(feature="tokio")]
    let outcome = async_call!(AsyncReadExt::read_to_end(&mut *$self, &mut result));
    outcome?;
    Ok(result)
}}}
// Wraps `$stream` into a fresh `Limit` configured with `$limit`, then drains it via the inherent
// `read_to_end($capacity)` method. Evaluates to whatever that call returns.
macro_rules! read { ($stream: ident, $limit: ident, $capacity: ident) => {{
    let mut limited = Self::new($stream, $limit);
    async_call!(limited.read_to_end($capacity))
}}}
// Reads the stream through `read!` and converts the collected bytes into a `String`.
// Bytes that are not valid UTF-8 produce an `ErrorKind::InvalidInput` error carrying the UTF-8 error text.
macro_rules! read_to_string { ($stream: ident, $limit: ident, $capacity: ident) => {{
    let bytes = read!($stream, $limit, $capacity)?;
    String::from_utf8(bytes)
        .map_err(|err| Error::new(ErrorKind::InvalidInput, __!("{}", err.utf8_error())))
}}}
// Opens the file at `$path` wrapped in a `Limit`, after checking its size.
//
// The size reported by the file's metadata is compared against `$limit` first; a larger file makes the
// enclosing function return a "File too large" error without opening the file.
macro_rules! open_file { ($path: ident, $limit: ident) => {{
    let path = $path.as_ref();
    let size = async_call!(fs::metadata(path))?.len();
    if size > $limit {
        return Err(err!("File too large: {size} (limit: {limit})", limit=$limit));
    }
    Ok(Self::new(async_call!(File::open(path))?, $limit))
}}}
// Reads the whole file at `$path` into a `Vec<u8>`.
//
// The file is opened through `Self::open_file` (so the size check of that function applies) and then
// consumed via a `BufReader`, chunk by chunk, until the reader reports an empty buffer.
// The output vector is pre-sized from the file length reported by a metadata query.
macro_rules! read_file { ($path: ident, $limit: ident) => {{
    let path = $path.as_ref();
    let mut reader = BufReader::new(async_call!(Self::open_file(path, $limit))?);
    let expected = async_call!(fs::metadata(path))?.len();
    let mut result = Vec::with_capacity(usize::try_from(expected).map_err(|_| err!())?);
    loop {
        let chunk = async_call!(reader.fill_buf())?;
        let count = chunk.len();
        // An empty buffer means end of file.
        if count == 0 {
            break;
        }
        result.extend_from_slice(chunk);
        // Tell the reader that the whole chunk has been used, so the next fill returns fresh data.
        reader.consume(count);
    }
    Result::Ok(result)
}}}
// Reads the file at `$path` through `read_file!` and converts its bytes into a `String`.
// Content that is not valid UTF-8 produces an `ErrorKind::InvalidInput` error carrying the UTF-8 error text.
macro_rules! read_file_to_string { ($path: ident, $limit: ident) => {{
    let bytes = read_file!($path, $limit)?;
    String::from_utf8(bytes)
        .map_err(|err| Error::new(ErrorKind::InvalidInput, __!("{}", err.utf8_error())))
}}}
/// A wrapper around a reader which keeps track of how many bytes have been read, together with the maximum
/// number of bytes that may be read.
#[derive(Debug)]
pub struct Limit<R> {
    /// The wrapped reader.
    inner: R,
    /// Number of bytes read so far.
    read: u64,
    /// Maximum number of bytes allowed.
    limit: u64,
}

impl<R> Limit<R> {
    /// Makes a new wrapper around `inner` which allows at most `limit` bytes.
    ///
    /// The counter of bytes already read starts at zero.
    pub fn new(inner: R, limit: u64) -> Self {
        Self { inner, read: 0, limit }
    }
}
/// Blocking helpers, available when the `tokio` feature is disabled.
#[cfg(not(feature="tokio"))]
#[doc(cfg(not(feature="tokio")))]
impl<R> Limit<R> where R: Read {
/// Reads the rest of the stream into a new vector.
///
/// `capacity` is an optional pre-allocation hint for the resulting vector. When given, it is capped at the
/// remaining allowance (limit minus bytes already read), provided that allowance fits into `usize`.
///
/// # Errors
///
/// Returns an error if more bytes have already been read than the limit allows (when `capacity` is given),
/// or if reading fails. The body is the expansion of the `read_to_end!` macro.
pub fn read_to_end(&mut self, capacity: Option<usize>) -> Result<Vec<u8>> {
read_to_end!(self, capacity)
}
/// Wraps `stream` in a new [`Limit`] configured with `limit`, then calls
/// [`read_to_end()`](Self::read_to_end) with `capacity`.
pub fn read(stream: R, limit: u64, capacity: Option<usize>) -> Result<Vec<u8>> {
read!(stream, limit, capacity)
}
/// Same as [`read()`](Self::read), then converts the bytes into a [`String`].
///
/// # Errors
///
/// In addition to the errors of [`read()`](Self::read), data which is not valid UTF-8 yields an error of
/// kind [`ErrorKind::InvalidInput`].
pub fn read_to_string(stream: R, limit: u64, capacity: Option<usize>) -> Result<String> {
read_to_string!(stream, limit, capacity)
}
}
/// Asynchronous helpers, available when the `tokio` feature is enabled.
#[cfg(feature="tokio")]
#[doc(cfg(feature="tokio"))]
impl<R> Limit<R> where R: Unpin + AsyncRead {
/// Reads the rest of the stream into a new vector.
///
/// `capacity` is an optional pre-allocation hint for the resulting vector. When given, it is capped at the
/// remaining allowance (limit minus bytes already read), provided that allowance fits into `usize`.
///
/// # Errors
///
/// Returns an error if more bytes have already been read than the limit allows (when `capacity` is given),
/// or if reading fails. The body is the expansion of the `read_to_end!` macro.
pub async fn read_to_end(&mut self, capacity: Option<usize>) -> Result<Vec<u8>> {
read_to_end!(self, capacity)
}
/// Wraps `stream` in a new [`Limit`] configured with `limit`, then awaits
/// [`read_to_end()`](Self::read_to_end) with `capacity`.
pub async fn read(stream: R, limit: u64, capacity: Option<usize>) -> Result<Vec<u8>> {
read!(stream, limit, capacity)
}
/// Same as [`read()`](Self::read), then converts the bytes into a [`String`].
///
/// # Errors
///
/// In addition to the errors of [`read()`](Self::read), data which is not valid UTF-8 yields an error of
/// kind [`ErrorKind::InvalidInput`].
pub async fn read_to_string(stream: R, limit: u64, capacity: Option<usize>) -> Result<String> {
read_to_string!(stream, limit, capacity)
}
}
/// File helpers. Each function exists in a blocking form (without the `tokio` feature) and an asynchronous
/// form (with the `tokio` feature); both forms expand the same macro.
impl Limit<File> {
/// Opens the file at `path` and wraps it in a [`Limit`] configured with `limit`.
///
/// # Errors
///
/// Returns an error if the metadata of the file cannot be queried, if the file size reported by the metadata
/// is larger than `limit`, or if the file cannot be opened.
#[cfg(not(feature="tokio"))]
#[doc(cfg(not(feature="tokio")))]
pub fn open_file<P>(path: P, limit: u64) -> Result<Self> where P: AsRef<Path> {
open_file!(path, limit)
}
/// Opens the file at `path` and wraps it in a [`Limit`] configured with `limit`.
///
/// # Errors
///
/// Returns an error if the metadata of the file cannot be queried, if the file size reported by the metadata
/// is larger than `limit`, or if the file cannot be opened.
#[cfg(feature="tokio")]
#[doc(cfg(feature="tokio"))]
pub async fn open_file<P>(path: P, limit: u64) -> Result<Self> where P: AsRef<Path> {
open_file!(path, limit)
}
/// Reads the whole file at `path` into a vector.
///
/// The file is opened via [`open_file()`](Self::open_file) and read through a buffered reader until no more
/// data is returned. The vector is pre-allocated using the file size reported by the metadata.
///
/// # Errors
///
/// Returns an error if [`open_file()`](Self::open_file) fails, if the file size does not fit into `usize`,
/// or if reading fails.
#[cfg(not(feature="tokio"))]
#[doc(cfg(not(feature="tokio")))]
pub fn read_file<P>(path: P, limit: u64) -> Result<Vec<u8>> where P: AsRef<Path> {
read_file!(path, limit)
}
/// Reads the whole file at `path` into a vector.
///
/// The file is opened via [`open_file()`](Self::open_file) and read through a buffered reader until no more
/// data is returned. The vector is pre-allocated using the file size reported by the metadata.
///
/// # Errors
///
/// Returns an error if [`open_file()`](Self::open_file) fails, if the file size does not fit into `usize`,
/// or if reading fails.
#[cfg(feature="tokio")]
#[doc(cfg(feature="tokio"))]
pub async fn read_file<P>(path: P, limit: u64) -> Result<Vec<u8>> where P: AsRef<Path> {
read_file!(path, limit)
}
/// Reads the whole file at `path` (see [`read_file()`](Self::read_file)) and converts it into a [`String`].
///
/// # Errors
///
/// In addition to the errors of [`read_file()`](Self::read_file), content which is not valid UTF-8 yields
/// an error of kind [`ErrorKind::InvalidInput`].
#[cfg(not(feature="tokio"))]
#[doc(cfg(not(feature="tokio")))]
pub fn read_file_to_string<P>(path: P, limit: u64) -> Result<String> where P: AsRef<Path> {
read_file_to_string!(path, limit)
}
/// Reads the whole file at `path` (see [`read_file()`](Self::read_file)) and converts it into a [`String`].
///
/// # Errors
///
/// In addition to the errors of [`read_file()`](Self::read_file), content which is not valid UTF-8 yields
/// an error of kind [`ErrorKind::InvalidInput`].
#[cfg(feature="tokio")]
#[doc(cfg(feature="tokio"))]
pub async fn read_file_to_string<P>(path: P, limit: u64) -> Result<String> where P: AsRef<Path> {
read_file_to_string!(path, limit)
}
}
impl Limit<Stdin> {
pub fn stdin(limit: u64) -> Self {
Self::new(io::stdin(), limit)
}
}
impl Limit<StdinLock<'_>> {
pub fn lock_stdin(limit: u64) -> Self {
Self::new(std::io::stdin().lock(), limit)
}
}
// Guard used by the reader implementations: if `$self` has read more bytes than its limit allows, makes the
// enclosing function return an error immediately.
//
// With the `tokio` feature the error is wrapped in `Poll::Ready`, as the enclosing function is a poll function.
macro_rules! check { ($self: ident) => {
    if $self.limit < $self.read {
        let failure = Err(err!("Limit: {limit}, already read: {read}", limit=$self.limit, read=$self.read));
        #[cfg(feature="tokio")]
        let failure = Poll::Ready(failure);
        return failure;
    }
}}
// Adds `$more` (a byte count) to the `read` counter of `$self`.
//
// The enclosing function returns an error if `$more` cannot be converted into `u64`, or if the addition
// overflows. With the `tokio` feature the overflow error is wrapped in `Poll::Ready`.
macro_rules! update_read { ($self: ident, $more: ident) => {{
    let extra = u64::try_from($more).map_err(|_| err!("Cannot convert {more} into u64", more=$more))?;
    match $self.read.checked_add(extra) {
        Some(total) => $self.read = total,
        None => {
            let failure = Err(err!("Failed: {read} + {more}", read=$self.read, more=$more));
            #[cfg(feature="tokio")]
            let failure = Poll::Ready(failure);
            return failure;
        },
    };
}}}
#[cfg(not(feature="tokio"))]
#[doc(cfg(not(feature="tokio")))]
impl<R> Read for Limit<R> where R: Read {
    /// Reads from the wrapped reader into `buf` and adds the number of bytes read to the counter.
    ///
    /// # Errors
    ///
    /// Returns an error if the counter is over the limit before or after this read, if the wrapped reader
    /// fails, or if the counter cannot be updated.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        check!(self);
        let count = self.inner.read(buf)?;
        update_read!(self, count);
        // Check again: this read may have pushed the counter over the limit.
        check!(self);
        Ok(count)
    }
}
#[cfg(feature="tokio")]
#[doc(cfg(feature="tokio"))]
impl<R> AsyncRead for Limit<R> where R: Unpin + AsyncRead {
    /// Polls the wrapped reader and adds the number of bytes it appended to `buf` to the counter.
    ///
    /// The number of new bytes is the difference of the filled length of `buf` after and before the poll.
    /// Returns `Poll::Ready` with an error if the counter is over the limit before or after this read, if the
    /// wrapped reader fails, or if the counter cannot be updated. A pending inner poll is passed through.
    fn poll_read(self: Pin<&mut Self>, context: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
        let this = self.get_mut();
        check!(this);
        let before = buf.filled().len();
        let count = match AsyncRead::poll_read(Pin::new(&mut this.inner), context, buf) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => buf.filled().len().checked_sub(before).ok_or_else(|| err!())?,
        };
        update_read!(this, count);
        // Check again: this read may have pushed the counter over the limit.
        check!(this);
        Poll::Ready(Ok(()))
    }
}