#![forbid(unsafe_code)]
#![deny(
clippy::dbg_macro,
missing_copy_implementations,
rustdoc::missing_crate_level_docs,
missing_debug_implementations,
nonstandard_style,
unused_qualifications
)]
#![warn(missing_docs)]
use futures_lite::AsyncRead;
use std::{
error::Error,
fmt::Display,
io::{ErrorKind, Result},
pin::Pin,
task::{ready, Context, Poll},
};
pin_project_lite::pin_project! {
/// Wrapper around a reader that tracks a budget of bytes which may still be read.
///
/// Every successful read through the [`AsyncRead`] implementation lowers the budget by the
/// number of bytes produced. Once the budget is zero, further reads fail with
/// [`LengthLimitExceeded`] (converted into a [`std::io::Error`]) without polling the inner
/// reader, so a source whose length equals the limit exactly also gets that error on the
/// read following its last byte.
#[derive(Debug, Clone, Copy)]
pub struct LengthLimit<T> {
// The wrapped reader. Pinned structurally so it can be polled through `Pin<&mut Self>`.
#[pin]
reader: T,
// Bytes that may still be read before reads start failing.
bytes_remaining: usize,
}
}
impl<T: AsyncRead> LengthLimit<T> {
    /// Wraps `reader` so that at most `max_bytes` bytes can be read through the result.
    pub fn new(reader: T, max_bytes: usize) -> Self {
        // The whole budget is still available on a freshly built wrapper.
        let bytes_remaining = max_bytes;
        Self {
            reader,
            bytes_remaining,
        }
    }

    /// Returns how many bytes may still be read before the limit is reached.
    pub fn bytes_remaining(&self) -> usize {
        let Self {
            bytes_remaining, ..
        } = self;
        *bytes_remaining
    }

    /// Consumes the wrapper and hands back the inner reader, discarding the remaining budget.
    pub fn into_inner(self) -> T {
        let Self { reader, .. } = self;
        reader
    }
}
// Shared access to the wrapped reader; borrowing it does not touch the byte budget.
impl<T> AsRef<T> for LengthLimit<T> {
    fn as_ref(&self) -> &T {
        let Self { reader, .. } = self;
        reader
    }
}
/// Error reported when a read is attempted on a [`LengthLimit`] whose byte budget is spent.
///
/// It carries no data. It converts into a [`std::io::Error`] of kind
/// [`ErrorKind::InvalidData`], from which it can be recovered again by downcasting the
/// inner error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthLimitExceeded;

impl Display for LengthLimitExceeded {
    /// Writes the fixed message `Length limit exceeded`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "Length limit exceeded")
    }
}

// Marker impl only: there is no underlying cause to expose, so the trait defaults suffice.
impl Error for LengthLimitExceeded {}

impl From<LengthLimitExceeded> for std::io::Error {
    /// Wraps the marker error in an I/O error tagged [`ErrorKind::InvalidData`].
    fn from(exceeded: LengthLimitExceeded) -> Self {
        std::io::Error::new(ErrorKind::InvalidData, exceeded)
    }
}
impl<T: AsyncRead> AsyncRead for LengthLimit<T> {
    /// Reads from the inner reader without ever handing out more bytes than the budget allows.
    ///
    /// # Errors
    ///
    /// Fails with [`LengthLimitExceeded`] (as an I/O error) when the budget is already zero,
    /// and forwards any error returned by the inner reader unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize>> {
        let this = self.project();
        let budget = *this.bytes_remaining;

        // A spent budget is reported before the inner reader is consulted at all.
        if budget == 0 {
            return Poll::Ready(Err(LengthLimitExceeded.into()));
        }

        // Expose at most `budget` bytes of the caller's buffer to the inner reader.
        let window_len = buf.len().min(budget);
        let window = &mut buf[..window_len];

        match ready!(this.reader.poll_read(cx, window)) {
            Ok(read) => {
                // Saturating guards against an inner reader reporting more than the window.
                *this.bytes_remaining = budget.saturating_sub(read);
                Poll::Ready(Ok(read))
            }
            Err(error) => Poll::Ready(Err(error)),
        }
    }
}
/// Extension methods that wrap a reader in a [`LengthLimit`].
///
/// The unit helpers use binary multiples (1 kb = 1024 bytes) and each delegates to the next
/// smaller unit. Multiplications saturate at [`usize::MAX`]: a request too large to express
/// in bytes yields the largest representable limit instead of panicking (debug builds) or
/// silently wrapping around to a much smaller limit (release builds).
pub trait LengthLimitExt: Sized + AsyncRead {
    /// Limits the reader to at most `max_bytes` bytes.
    fn limit_bytes(self, max_bytes: usize) -> LengthLimit<Self> {
        LengthLimit::new(self, max_bytes)
    }

    /// Limits the reader to at most `max_kb` × 1024 bytes, saturating on overflow.
    fn limit_kb(self, max_kb: usize) -> LengthLimit<Self> {
        self.limit_bytes(max_kb.saturating_mul(1024))
    }

    /// Limits the reader to at most `max_mb` × 1024 kb, saturating on overflow.
    fn limit_mb(self, max_mb: usize) -> LengthLimit<Self> {
        self.limit_kb(max_mb.saturating_mul(1024))
    }

    /// Limits the reader to at most `max_gb` × 1024 mb, saturating on overflow.
    fn limit_gb(self, max_gb: usize) -> LengthLimit<Self> {
        self.limit_mb(max_gb.saturating_mul(1024))
    }
}
// Blanket implementation: every `Unpin` reader gets the `limit_*` helpers for free.
impl<T: AsyncRead + Unpin> LengthLimitExt for T {}