#[cfg(doc)]
use std::io::ErrorKind;
use tokio::io::AsyncReadExt;
use crate::error::MessageReadError;
/// Buffered wrapper around an async reader that can look ahead ("peek") at
/// upcoming bytes without consuming them.
///
/// `BUFFER_SIZE` (default 280) is the capacity of the internal buffer and
/// therefore bounds how many bytes a single peek or read can return.
pub struct AsyncPeekReader<R, const BUFFER_SIZE: usize = 280> {
/// Backing storage; the unread bytes live in `buffer[cursor..top]`.
buffer: [u8; BUFFER_SIZE],
/// Index of the first unread byte in `buffer`.
cursor: usize,
/// Index one past the last valid byte in `buffer`.
top: usize,
/// Underlying source from which `buffer` is filled.
reader: R,
}
impl<R: AsyncReadExt + Unpin, const BUFFER_SIZE: usize> AsyncPeekReader<R, BUFFER_SIZE> {
    /// Wraps `reader` in a new peek reader with an empty buffer.
    pub fn new(reader: R) -> Self {
        Self {
            buffer: [0; BUFFER_SIZE],
            cursor: 0,
            top: 0,
            reader,
        }
    }

    /// Returns the next `amount` bytes without consuming them.
    ///
    /// Bytes that are already buffered are reused; only the missing ones are
    /// read from the underlying reader. A subsequent `peek_exact` or
    /// `read_exact` sees the same bytes again.
    ///
    /// # Errors
    /// Propagates I/O errors from the underlying reader, converted into
    /// [`MessageReadError`]. This includes [`ErrorKind::UnexpectedEof`] when
    /// the stream ends before `amount` bytes are available.
    ///
    /// # Panics
    /// Panics if `amount` is greater than `BUFFER_SIZE`.
    pub async fn peek_exact(&mut self, amount: usize) -> Result<&[u8], MessageReadError> {
        self.fetch(amount, false).await
    }

    /// Returns the next `amount` bytes and consumes them.
    ///
    /// # Errors
    /// Propagates I/O errors from the underlying reader, converted into
    /// [`MessageReadError`]. This includes [`ErrorKind::UnexpectedEof`] when
    /// the stream ends before `amount` bytes are available.
    ///
    /// # Panics
    /// Panics if `amount` is greater than `BUFFER_SIZE`.
    pub async fn read_exact(&mut self, amount: usize) -> Result<&[u8], MessageReadError> {
        self.fetch(amount, true).await
    }

    /// Reads and consumes a single byte.
    ///
    /// # Errors
    /// Same as [`Self::read_exact`].
    pub async fn read_u8(&mut self) -> Result<u8, MessageReadError> {
        let buf = self.read_exact(1).await?;
        Ok(buf[0])
    }

    /// Consumes up to `amount` bytes that are already buffered, without
    /// touching the underlying reader.
    ///
    /// Returns the number of bytes actually consumed, which is `amount`
    /// clamped to the number of currently buffered bytes.
    pub fn consume(&mut self, amount: usize) -> usize {
        let amount = amount.min(self.top - self.cursor);
        self.cursor += amount;
        amount
    }

    /// Returns a shared reference to the underlying reader.
    pub fn reader_ref(&self) -> &R {
        &self.reader
    }

    /// Returns a mutable reference to the underlying reader.
    ///
    /// Bytes read directly through this reference bypass the internal buffer
    /// and are not visible to later peeks/reads on `self`.
    pub fn reader_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Ensures at least `amount` bytes are buffered, then returns them.
    /// When `consume` is true, the cursor is advanced past the returned bytes.
    async fn fetch(&mut self, amount: usize, consume: bool) -> Result<&[u8], MessageReadError> {
        // The result is always one contiguous slice of `buffer`, so a request
        // can never exceed its capacity. Checking `amount` here gives a clear
        // panic instead of a slice-index panic further down.
        assert!(
            amount <= BUFFER_SIZE,
            "requested {} bytes, but the peek buffer only holds {}",
            amount,
            BUFFER_SIZE
        );

        let buffered = self.top - self.cursor;
        if buffered < amount {
            let missing = amount - buffered;

            // Not enough free space after `top`: slide the unread bytes to the
            // front. Because `amount <= BUFFER_SIZE`, there is room afterwards.
            if BUFFER_SIZE - self.top < missing {
                self.buffer.copy_within(self.cursor..self.top, 0);
                self.cursor = 0;
                self.top = buffered;
            }

            // Read straight into the spare capacity instead of going through a
            // temporary array held across the `.await`. `top` only advances
            // once the read fully succeeds, so on error the previously
            // buffered bytes are left intact.
            let end = self.top + missing;
            self.reader
                .read_exact(&mut self.buffer[self.top..end])
                .await?;
            self.top = end;
        }

        let start = self.cursor;
        if consume {
            self.cursor += amount;
        }
        Ok(&self.buffer[start..start + amount])
    }
}