use std::{
cmp, io,
ops::DerefMut,
pin::Pin,
task::{Context, Poll},
};
use super::{Buffer, DefaultBuffer};
use ::tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
mod peek;
pub use peek::*;
mod peek_buf;
pub use peek_buf::*;
mod peek_exact;
pub use peek_exact::*;
mod peek_to_end;
pub use peek_to_end::*;
mod peek_to_string;
pub use peek_to_string::*;
mod fill_peek_buf;
pub use fill_peek_buf::*;
/// Asynchronous "peek": look at upcoming bytes of a reader without consuming them.
///
/// This file implements it for [`AsyncPeekable`] (which buffers peeked bytes internally), for
/// `Box<T>` / `&mut T` wrappers of an `Unpin` peeker, and for `Pin<P>`.
pub trait AsyncPeek: ::tokio::io::AsyncRead {
    /// Attempts to copy upcoming bytes into `buf` without consuming them.
    ///
    /// For [`AsyncPeekable`], the peeked bytes are retained in its internal buffer and are
    /// returned again by later `poll_read` calls. Other implementors are expected to honour the
    /// same "does not consume" contract — NOTE(review): not enforceable by the trait itself.
    fn poll_peek(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>>;
}
/// Expands to a `poll_peek` that forwards to the pointee of an `Unpin` pointer type.
///
/// Intended for `impl AsyncPeek for Box<T>` / `&mut T` where `T: AsyncPeek + Unpin`: the
/// pointer itself is always `Unpin`, so it can be unwrapped from `Pin<&mut Self>` and the
/// pointee re-pinned with `Pin::new`.
macro_rules! deref_async_peek {
    () => {
        #[cfg_attr(not(tarpaulin), inline(always))]
        fn poll_peek(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // The pointer type (`Box<T>` / `&mut T`) is `Unpin`, so unwrapping is safe.
            let pointer = Pin::get_mut(self);
            // Double deref: pointer -> pointee `T`; `T: Unpin` lets us pin it again cheaply.
            let pointee = &mut **pointer;
            Pin::new(pointee).poll_peek(cx, buf)
        }
    };
}
// Forward `AsyncPeek` through an owning (`Box`) or borrowing (`&mut`) pointer to an `Unpin`
// peeker. The `AsyncRead` supertrait for these pointer types comes from tokio's own impls.
impl<T: ?Sized + AsyncPeek + Unpin> AsyncPeek for Box<T> {
    deref_async_peek!();
}
impl<T: ?Sized + AsyncPeek + Unpin> AsyncPeek for &mut T {
    deref_async_peek!();
}
impl<P> AsyncPeek for Pin<P>
where
    P: DerefMut + Unpin,
    P::Target: AsyncPeek,
{
    /// Peeks through a pinned pointer into its (already pinned) target.
    ///
    /// `Pin<P>` is `Unpin` because `P: Unpin`, so the outer `Pin<&mut Self>` can be unwrapped;
    /// the target is then reached as `Pin<&mut P::Target>` without any unsafe code.
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn poll_peek(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let outer: &mut Pin<P> = Pin::get_mut(self);
        let target: Pin<&mut P::Target> = Pin::as_mut(outer);
        <P::Target as AsyncPeek>::poll_peek(target, cx, buf)
    }
}
pin_project_lite::pin_project! {
    /// Reader adapter that lets callers inspect upcoming bytes before reading them.
    ///
    /// Bytes obtained through [`AsyncPeek::poll_peek`] are stored in an internal buffer and are
    /// handed out again by this type's [`AsyncRead`] implementation before any new data is pulled
    /// from the wrapped reader. When the wrapped value also implements `AsyncWrite`, writes,
    /// flushes and shutdowns are forwarded to it unchanged.
    #[derive(Debug)]
    pub struct AsyncPeekable<R, B = DefaultBuffer> {
        // The wrapped reader; structurally pinned so it can be polled via `Pin<&mut Self>`.
        #[pin]
        reader: R,
        // Bytes that have been peeked but not yet returned by a read.
        buffer: B,
        // Capacity given at construction, if any; `consume` uses it to size the replacement
        // buffer (`None` means the replacement comes from `B::new()`).
        buf_cap: Option<usize>,
    }
}
impl<R> From<R> for AsyncPeekable<R> {
#[cfg_attr(not(tarpaulin), inline(always))]
fn from(reader: R) -> Self {
Self::new(reader)
}
}
impl<R> From<(usize, R)> for AsyncPeekable<R> {
#[cfg_attr(not(tarpaulin), inline(always))]
fn from((cap, reader): (usize, R)) -> Self {
Self::with_capacity(reader, cap)
}
}
impl<R: AsyncRead, B: Buffer> AsyncRead for AsyncPeekable<R, B> {
    /// Reads from the peek buffer first, then from the wrapped reader.
    ///
    /// * Peek buffer empty: the call is forwarded to the inner reader unchanged.
    /// * `buf` has less room than the peek buffer holds: only that many bytes are copied out and
    ///   removed; the remainder stays buffered.
    /// * `buf` has exactly enough room: the whole peek buffer is copied out and cleared.
    /// * `buf` has more room: the whole peek buffer is copied out and the inner reader is polled
    ///   once to top up. A `Pending` inner poll still completes with the buffered bytes. On an
    ///   inner error, `buf` is rewound and the peek buffer is kept, so those bytes are delivered
    ///   by a later read.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();
        let pending = this.buffer.len();
        if pending == 0 {
            return this.reader.poll_read(cx, buf);
        }

        let room = buf.remaining();
        if room < pending {
            // Hand over a prefix only; the rest of the peeked bytes wait for the next call.
            buf.put_slice(&this.buffer.as_slice()[..room]);
            this.buffer.consume(..room);
            return Poll::Ready(Ok(()));
        }

        let start = buf.filled().len();
        buf.put_slice(this.buffer.as_slice());
        if room == pending {
            this.buffer.clear();
            return Poll::Ready(Ok(()));
        }

        // Spare room after the peeked bytes: opportunistically fill it from the inner reader.
        if let Poll::Ready(Err(e)) = this.reader.poll_read(cx, buf) {
            // Undo our copy so the caller sees nothing alongside the error; the peeked bytes
            // remain buffered.
            buf.set_filled(start);
            return Poll::Ready(Err(e));
        }
        // Ready(Ok) or Pending: either way the peeked bytes were delivered.
        this.buffer.clear();
        Poll::Ready(Ok(()))
    }
}
/// Write-side pass-through: every operation is forwarded to the wrapped value untouched.
///
/// The peek buffer only concerns the read side, so it is never consulted here.
impl<W: AsyncWrite, B> AsyncWrite for AsyncPeekable<W, B> {
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().reader.poll_flush(cx)
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().reader.poll_shutdown(cx)
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().reader.poll_write(cx, buf)
    }

    /// Forwarded so the inner writer's vectored implementation is used; the trait default would
    /// write only the first non-empty slice.
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().reader.poll_write_vectored(cx, bufs)
    }

    /// Reports the inner writer's capability instead of the trait default (`false`).
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn is_write_vectored(&self) -> bool {
        self.reader.is_write_vectored()
    }
}
impl<R: AsyncRead, B: Buffer> AsyncPeek for AsyncPeekable<R, B> {
    /// Peeks upcoming bytes into `buf`, remembering anything newly read so reads return it again.
    ///
    /// * If the peek buffer already holds at least `buf.remaining()` bytes, that prefix is copied
    ///   into `buf` and the inner reader is not polled.
    /// * Otherwise, any buffered bytes are copied into `buf` first and the inner reader is polled
    ///   once. Newly read bytes are appended to the peek buffer. If bytes were already buffered, a
    ///   `Pending` inner poll still completes with those bytes; with an empty buffer, `Pending` is
    ///   propagated.
    ///
    /// # Errors
    ///
    /// Returns the inner reader's error, or the (converted) error from growing the peek buffer.
    /// In both cases `buf` is rewound to its original fill level so no bytes accompany an `Err`.
    /// If growing the buffer fails, the bytes just pulled from the inner reader cannot be retained.
    fn poll_peek(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();
        let buffer_len = this.buffer.len();
        let orig_filled = buf.filled().len();

        if buffer_len > 0 {
            let available = buf.remaining();
            if available <= buffer_len {
                // The caller cannot hold more than what is already buffered.
                buf.put_slice(&this.buffer.as_slice()[..available]);
                return Poll::Ready(Ok(()));
            }
            buf.put_slice(this.buffer.as_slice());
        }

        // Everything from here on in `buf` is fresh data from the inner reader.
        let cur = buf.filled().len();
        match this.reader.poll_read(cx, buf) {
            Poll::Ready(Ok(())) => {
                let recorded = this.buffer.extend_from_slice(&buf.filled()[cur..]);
                if recorded.is_err() {
                    // Match the inner-error path below: no bytes alongside an `Err`.
                    buf.set_filled(orig_filled);
                }
                recorded?;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => {
                buf.set_filled(orig_filled);
                Poll::Ready(Err(e))
            }
            // Already-buffered bytes were delivered, so this peek is complete.
            Poll::Pending if buffer_len > 0 => Poll::Ready(Ok(())),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl<R> AsyncPeekable<R> {
    /// Wraps `reader`, using a buffer from `DefaultBuffer::new()` for peeked bytes.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn new(reader: R) -> Self {
        let buffer = DefaultBuffer::new();
        Self::construct(reader, buffer, None)
    }

    /// Wraps `reader`, using a buffer from `DefaultBuffer::with_capacity(capacity)`.
    ///
    /// The capacity is remembered so that `consume` sizes its replacement buffer the same way.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn with_capacity(reader: R, capacity: usize) -> Self {
        let buffer = DefaultBuffer::with_capacity(capacity);
        let remembered = Some(capacity);
        Self::construct(reader, buffer, remembered)
    }
}
impl<R, B> AsyncPeekable<R, B> {
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn with_buffer(reader: R) -> Self
where
B: Buffer,
{
Self::construct(reader, B::new(), None)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn with_capacity_and_buffer(reader: R, capacity: usize) -> Self
where
B: Buffer,
{
Self::construct(reader, B::with_capacity(capacity), Some(capacity))
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn construct(reader: R, buffer: B, buf_cap: Option<usize>) -> Self {
Self {
reader,
buffer,
buf_cap,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn consume(&mut self) -> B
where
B: Buffer,
{
let buf = match self.buf_cap {
Some(capacity) => B::with_capacity(capacity),
None => B::new(),
};
core::mem::replace(&mut self.buffer, buf)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn consume_in_place(&mut self)
where
B: Buffer,
{
self.buffer.clear();
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn get_mut(&mut self) -> (&[u8], &mut R)
where
B: Buffer,
{
(self.buffer.as_slice(), &mut self.reader)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn get_ref(&self) -> (&[u8], &R)
where
B: Buffer,
{
(self.buffer.as_slice(), &self.reader)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn into_components(self) -> (B, R) {
(self.buffer, self.reader)
}
}
/// Extension methods that wrap any [`AsyncRead`] in an [`AsyncPeekable`].
///
/// Implemented for every `AsyncRead` type by the blanket impl below.
pub trait AsyncPeekExt: AsyncRead {
    /// Same as [`AsyncPeekable::new`].
    fn peekable(self) -> AsyncPeekable<Self>
    where
        Self: Sized,
    {
        AsyncPeekable::new(self)
    }

    /// Same as [`AsyncPeekable::with_capacity`].
    fn peekable_with_capacity(self, capacity: usize) -> AsyncPeekable<Self>
    where
        Self: Sized,
    {
        AsyncPeekable::with_capacity(self, capacity)
    }

    /// Same as [`AsyncPeekable::with_buffer`], with a caller-chosen buffer type `B`.
    fn peekable_with_buffer<B>(self) -> AsyncPeekable<Self, B>
    where
        Self: Sized,
        B: Buffer,
    {
        AsyncPeekable::<Self, B>::with_buffer(self)
    }

    /// Same as [`AsyncPeekable::with_capacity_and_buffer`], with a caller-chosen buffer type `B`.
    fn peekable_with_capacity_and_buffer<B>(self, capacity: usize) -> AsyncPeekable<Self, B>
    where
        Self: Sized,
        B: Buffer,
    {
        AsyncPeekable::<Self, B>::with_capacity_and_buffer(self, capacity)
    }
}

impl<R: AsyncRead> AsyncPeekExt for R {}
// Convenience methods: each one forwards to the same-named free function re-exported from its
// submodule (`peek`, `peek_buf`, ...) and returns that function's future.
impl<R: AsyncRead + Unpin, BUF: Buffer> AsyncPeekable<R, BUF> {
    /// Creates a [`Peek`] future over `buf` by delegating to the module-level `peek` function.
    ///
    /// See the `peek` module for what the future resolves to.
    pub fn peek<'a>(&'a mut self, buf: &'a mut [u8]) -> Peek<'a, R, BUF>
    where
        Self: Unpin,
    {
        peek(self, buf)
    }
    /// Creates a [`PeekBuf`] future that peeks into any `bytes::BufMut` destination.
    ///
    /// Delegates to the module-level `peek_buf` function.
    pub fn peek_buf<'a, B>(&'a mut self, buf: &'a mut B) -> PeekBuf<'a, R, B, BUF>
    where
        Self: Unpin,
        B: bytes::BufMut + ?Sized,
    {
        peek_buf(self, buf)
    }
    /// Creates a [`PeekExact`] future that fills all of `buf` with upcoming bytes.
    ///
    /// As exercised by this file's tests, peeked bytes are not consumed: a later peek or read
    /// starts from the same position. Delegates to the module-level `peek_exact` function.
    pub fn peek_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> PeekExact<'a, R, BUF>
    where
        Self: Unpin,
    {
        peek_exact(self, buf)
    }
    /// Creates a [`PeekToEnd`] future that appends peeked bytes to `buf`.
    ///
    /// Delegates to the module-level `peek_to_end` function — NOTE(review): presumably the
    /// non-consuming analogue of `read_to_end`; confirm in that module.
    pub fn peek_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> PeekToEnd<'a, R, BUF>
    where
        Self: Unpin,
    {
        peek_to_end(self, buf)
    }
    /// Creates a [`PeekToString`] future that appends peeked text to `dst`.
    ///
    /// Delegates to the module-level `peek_to_string` function — NOTE(review): presumably the
    /// non-consuming analogue of `read_to_string`; confirm UTF-8 handling in that module.
    pub fn peek_to_string<'a>(&'a mut self, dst: &'a mut String) -> PeekToString<'a, R, BUF>
    where
        Self: Unpin,
    {
        peek_to_string(self, dst)
    }
    /// Creates a [`FillPeekBuf`] future by delegating to the module-level `fill_peek_buf`.
    ///
    /// Unlike its siblings this method has no `Self: Unpin` bound.
    pub fn fill_peek_buf(&mut self) -> FillPeekBuf<'_, R, BUF> {
        fill_peek_buf(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::AsyncReadExt;

    /// Source bytes shared by every scenario below.
    const DATA: [u8; 9] = [1, 2, 3, 4, 5, 6, 7, 8, 9];

    /// Peeks 2 bytes, then 4 bytes, then reads 4 bytes.
    ///
    /// Every step must start at the beginning of `DATA`, proving that peeking never consumes.
    async fn peek_peek_then_read<R, B>(mut peekable: AsyncPeekable<R, B>)
    where
        R: AsyncRead + Unpin,
        B: Buffer,
        AsyncPeekable<R, B>: Unpin,
    {
        let mut short_peek = [0; 2];
        peekable.peek_exact(&mut short_peek).await.unwrap();
        assert_eq!(short_peek, [1, 2]);

        let mut long_peek = [0; 4];
        peekable.peek_exact(&mut long_peek).await.unwrap();
        assert_eq!(long_peek, [1, 2, 3, 4]);

        let mut read_back = [0; 4];
        peekable.read_exact(&mut read_back).await.unwrap();
        assert_eq!(read_back, [1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_peek_exact_peek_exact_read_exact() {
        peek_peek_then_read(Cursor::new(DATA).peekable()).await;
    }

    #[tokio::test]
    async fn test_peek_exact_peek_exact_read_exact_1() {
        peek_peek_then_read(Cursor::new(DATA).peekable_with_buffer::<Vec<u8>>()).await;
    }

    #[tokio::test]
    async fn test_peek_exact_peek_exact_read_exact_2() {
        let peekable = Cursor::new(DATA).peekable_with_capacity_and_buffer::<Vec<u8>>(24);
        peek_peek_then_read(peekable).await;
    }
}