use std::{
pin::Pin,
task::{Context, Poll},
};
use bytes::{BufMut, Bytes, BytesMut};
use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
/// Number of bytes (4 KiB) reserved up front for the prebuffer of a new `Prebuffered`.
const INITIAL_CAPACITY: usize = 4096;
/// A reader wrapper that can read ahead from `inner` into an in-memory buffer,
/// bounded by `max_len`, and hands those buffered bytes out first when it is
/// itself read through [`AsyncRead`].
pub struct Prebuffered<R> {
/// The wrapped reader that all bytes ultimately come from.
inner: R,
/// Bytes already read from `inner` that have been neither discarded nor
/// returned to a reader yet.
buf: BytesMut,
/// Upper bound on how many bytes `buf` may hold when buffering more.
max_len: usize,
}
/// Interface of a source that exposes a bounded, inspectable read-ahead buffer.
///
/// Implementors must be `Send + Unpin`. The only implementor visible in this
/// file is [`Prebuffered`]; the per-method notes describe what that
/// implementation does.
pub trait Prebufferable: Send + Unpin {
/// Reports whether the buffer has reached its size limit.
fn is_full(&self) -> bool;
/// Returns the bytes currently held in the buffer.
fn buffer(&self) -> &[u8];
/// Removes the first `n` bytes from the buffer.
fn discard(&mut self, n: usize);
/// Reads more bytes from the underlying source into the buffer.
///
/// Resolves to the number of bytes added. For [`Prebuffered`], reaching the
/// end of the underlying stream is reported as an
/// `io::ErrorKind::UnexpectedEof` error rather than `Ok(0)`.
/// The returned future is required to be `Send`.
fn buffer_more(&mut self) -> impl Future<Output = io::Result<usize>> + Send;
}
/// Implements [`Prebufferable`] for [`Prebuffered`] by forwarding every method
/// to the inherent method of the same name.
///
/// `R: Send` is required on top of the inherent impl's bounds because the
/// trait has `Send` as a supertrait and demands a `Send` future from
/// `buffer_more`.
impl<R: AsyncRead + Unpin + Send> Prebufferable for Prebuffered<R> {
// Each body uses the explicit `Prebuffered::name(self, ..)` path form to call
// the inherent method defined on `Prebuffered<R>`.
fn is_full(&self) -> bool {
Prebuffered::is_full(self)
}
fn buffer(&self) -> &[u8] {
Prebuffered::buffer(self)
}
fn discard(&mut self, n: usize) {
Prebuffered::discard(self, n)
}
fn buffer_more(&mut self) -> impl Future<Output = io::Result<usize>> + Send {
// The future of the inherent `async fn` is returned as-is.
Prebuffered::buffer_more(self)
}
}
impl<R: AsyncRead + Unpin> Prebuffered<R> {
pub(crate) fn new(inner: R, max_len: usize) -> Self {
Self {
inner,
buf: BytesMut::with_capacity(INITIAL_CAPACITY),
max_len,
}
}
#[cfg(test)]
pub(crate) fn unlimited(inner: R) -> Self {
Self::new(inner, usize::MAX)
}
pub(crate) fn buffer(&self) -> &[u8] {
&self.buf[..]
}
#[cfg(test)]
pub(crate) fn len(&self) -> usize {
self.buf.len()
}
pub(crate) fn is_full(&self) -> bool {
self.buf.len() == self.max_len
}
pub(crate) fn discard(&mut self, n: usize) {
let _ = self.buf.split_to(n);
}
pub(crate) async fn buffer_more(&mut self) -> io::Result<usize> {
let max = self.max_len.saturating_sub(self.buf.len());
let n = (&mut self.inner)
.take(max as u64)
.read_buf(&mut self.buf)
.await?;
if n == 0 {
Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"wanted to buffer more but stream closed",
))
} else {
Ok(n)
}
}
pub(crate) fn into_parts(self) -> (Bytes, R) {
(self.buf.freeze(), self.inner)
}
}
/// Reading from a `Prebuffered` yields the prebuffered bytes first and only
/// then falls through to the wrapped reader.
impl<R: AsyncRead + Unpin> AsyncRead for Prebuffered<R> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // A destination without free space cannot receive anything; finish
        // immediately without touching the buffer or the wrapped reader.
        if !out.has_remaining_mut() {
            return Poll::Ready(Ok(()));
        }

        // Nothing is prebuffered: delegate the read to the wrapped reader.
        if self.buf.is_empty() {
            return Pin::new(&mut self.inner).poll_read(cx, out);
        }

        // Hand out as many prebuffered bytes as fit. The wrapped reader is not
        // polled in this call, even if `out` still has room afterwards.
        let count = out.remaining_mut().min(self.buf.len());
        let head = self.buf.split_to(count);
        out.put_slice(&head);
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use tokio::io::AsyncReadExt;

    use super::*;

    /// Builds an in-memory reader over a static byte string.
    fn cursor(data: &'static [u8]) -> Cursor<&'static [u8]> {
        Cursor::new(data)
    }

    /// Without a practical limit, one call buffers the whole in-memory input.
    #[tokio::test]
    async fn buffer_more_unlimited_reads_everything() {
        let mut p = Prebuffered::unlimited(cursor(b"abcdefgh"));
        let n = p.buffer_more().await.unwrap();
        assert_eq!(n, 8);
        assert_eq!(p.buffer(), b"abcdefgh");
    }

    /// With a limit smaller than the input, buffering stops at the limit.
    #[tokio::test]
    async fn buffer_more_respects_max() {
        let mut p = Prebuffered::new(cursor(b"abcdefgh"), 5);
        assert!(!p.is_full());
        let n = p.buffer_more().await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(p.buffer(), b"abcde");
        assert!(p.is_full());
        // A full prebuffer cannot take more bytes, so buffering again fails
        // and leaves the buffered bytes untouched.
        assert!(p.buffer_more().await.is_err());
        assert_eq!(p.buffer(), b"abcde");
    }

    #[tokio::test]
    async fn buffer_more_eof() {
        let mut p = Prebuffered::unlimited(cursor(b""));
        let err = p.buffer_more().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// Discarding exactly the buffered length empties the buffer; the
    /// exhausted source then reports end-of-stream.
    #[tokio::test]
    async fn discard_whole_buffer_then_eof() {
        let mut p = Prebuffered::unlimited(cursor(b"abc"));
        p.buffer_more().await.unwrap();
        p.discard(p.len());
        assert_eq!(p.buffer(), b"");
        let err = p.buffer_more().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn async_read_fallthrough_from_buffer() {
        let mut p = Prebuffered::new(cursor(b"hello world"), 5);
        p.buffer_more().await.unwrap();
        assert_eq!(p.buffer(), b"hello");
        let mut out = Vec::new();
        p.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, b"hello world");
    }

    #[tokio::test]
    async fn async_read_partial_reads_from_buffer_then_inner() {
        let mut p = Prebuffered::new(cursor(b"abcdef"), 4);
        p.buffer_more().await.unwrap();
        assert_eq!(p.buffer(), b"abcd");
        p.discard(2);
        assert_eq!(p.buffer(), b"cd");
        let mut buf = [0u8; 2];
        let n = p.read(&mut buf).await.unwrap();
        assert_eq!(n, 2);
        assert_eq!(&buf, b"cd");
        let mut rest = Vec::new();
        p.read_to_end(&mut rest).await.unwrap();
        assert_eq!(rest, b"ef");
    }

    #[tokio::test]
    async fn buffer_more_does_not_reset_pos() {
        let mut p = Prebuffered::new(cursor(b"abcdefghij"), 4);
        p.buffer_more().await.unwrap();
        assert_eq!(p.buffer(), b"abcd");
        p.discard(3);
        assert_eq!(p.buffer(), b"d");
        // Only three bytes of room remain, so the refill tops the buffer up to
        // four bytes and keeps the undiscarded "d" at the front.
        p.buffer_more().await.unwrap();
        assert_eq!(p.buffer(), b"defg");
        p.discard(1);
        assert_eq!(p.buffer(), b"efg");
        let mut out = Vec::new();
        p.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, b"efghij");
        assert_eq!(p.buffer(), b"");
    }

    #[tokio::test]
    async fn read_without_any_buffering() {
        let mut p = Prebuffered::unlimited(cursor(b"xyz"));
        let mut out = Vec::new();
        p.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, b"xyz");
        assert_eq!(p.buffer(), b"");
    }
}