#[cfg(doc)]
use std::io::ErrorKind;
use tokio::io::AsyncReadExt;
use crate::error::MessageReadError;
/// Async reader adapter that can hand out upcoming bytes without consuming them.
///
/// Bytes pulled from the wrapped reader are kept in a fixed-size inline buffer of
/// `BUFFER_SIZE` bytes (280 by default). A single peek or read may therefore request at
/// most `BUFFER_SIZE` bytes.
pub struct AsyncPeekReader<R, const BUFFER_SIZE: usize = 280> {
    /// The wrapped byte source.
    reader: R,
    /// Backing storage. Only `buffer[cursor..top]` holds live, not-yet-consumed bytes.
    buffer: [u8; BUFFER_SIZE],
    /// Index of the next unconsumed byte in `buffer`.
    cursor: usize,
    /// One past the last valid byte in `buffer`.
    top: usize,
}
impl<R: AsyncReadExt + Unpin, const BUFFER_SIZE: usize> AsyncPeekReader<R, BUFFER_SIZE> {
    /// Wraps `reader` in a peek reader whose buffer starts out empty.
    pub fn new(reader: R) -> Self {
        Self {
            buffer: [0; BUFFER_SIZE],
            cursor: 0,
            top: 0,
            reader,
        }
    }

    /// Returns the next `amount` bytes of the stream without consuming them.
    ///
    /// A later `peek_exact` or `read_exact` sees the same bytes again. Only the bytes that
    /// are not already buffered are requested from the underlying reader.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the reader fails, or an
    /// [`ErrorKind::UnexpectedEof`] error if the stream ends before `amount` bytes are
    /// available. The error is converted into a [`MessageReadError`]. Bytes received before
    /// the error stay buffered.
    ///
    /// # Panics
    ///
    /// Panics if `amount` is greater than `BUFFER_SIZE`.
    pub async fn peek_exact(&mut self, amount: usize) -> Result<&[u8], MessageReadError> {
        self.fetch(amount, false).await
    }

    /// Returns the next `amount` bytes of the stream and consumes them.
    ///
    /// # Errors
    ///
    /// Same as [`Self::peek_exact`]. On error nothing is consumed.
    ///
    /// # Panics
    ///
    /// Panics if `amount` is greater than `BUFFER_SIZE`.
    pub async fn read_exact(&mut self, amount: usize) -> Result<&[u8], MessageReadError> {
        self.fetch(amount, true).await
    }

    /// Reads and consumes a single byte.
    ///
    /// # Errors
    ///
    /// Same as [`Self::read_exact`].
    pub async fn read_u8(&mut self) -> Result<u8, MessageReadError> {
        let buf = self.read_exact(1).await?;
        Ok(buf[0])
    }

    /// Discards up to `amount` already-buffered bytes and returns how many were discarded.
    ///
    /// This never reads from the underlying reader. It can only drop bytes that earlier
    /// peeks pulled into the buffer.
    pub fn consume(&mut self, amount: usize) -> usize {
        let amount = amount.min(self.top - self.cursor);
        self.cursor += amount;
        amount
    }

    /// Returns a shared reference to the underlying reader.
    ///
    /// Only `&self` is needed, since the reference is read-only.
    pub fn reader_ref(&self) -> &R {
        &self.reader
    }

    /// Returns a mutable reference to the underlying reader.
    ///
    /// Reading through it bypasses this wrapper. Bytes already pulled into the internal
    /// buffer stay there and are still returned first by later peeks/reads.
    pub fn reader_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Makes sure at least `amount` bytes are buffered and returns exactly those bytes.
    /// When `consume` is true, it also advances the cursor past them.
    async fn fetch(&mut self, amount: usize, consume: bool) -> Result<&[u8], MessageReadError> {
        // Kept as a bare `assert!` so the panic message starts with "assertion failed".
        assert!(BUFFER_SIZE >= amount);
        let buffered = self.top - self.cursor;
        if buffered < amount {
            // Not enough room after the current window: slide the unconsumed bytes to the
            // front. Afterwards `cursor + amount == amount <= BUFFER_SIZE`, so the fill fits.
            if self.cursor + amount > self.buffer.len() {
                self.buffer.copy_within(self.cursor..self.top, 0);
                self.cursor = 0;
                self.top = buffered;
            }
            let target = self.cursor + amount;
            // Read incrementally and commit every chunk to `top` immediately. This replaces
            // a single `read_exact`, which is not cancel-safe: with it, bytes already received
            // would be lost if this future were dropped or the reader errored midway,
            // desynchronising the stream.
            while self.top < target {
                let n = self.reader.read(&mut self.buffer[self.top..target]).await?;
                if n == 0 {
                    // Same kind and message as tokio's `read_exact` reports on early EOF.
                    let eof = std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "early eof");
                    return Err(eof.into());
                }
                self.top += n;
            }
        }
        let result = &self.buffer[self.cursor..self.cursor + amount];
        if consume {
            self.cursor += amount;
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[should_panic(expected = "assertion failed")]
    async fn test_peek_exact_panics_when_amount_exceeds_buffer_size() {
        let data = b"abcd";
        let mut reader = AsyncPeekReader::<_, 4>::new(&data[..]);
        let _ = reader.peek_exact(5).await;
    }

    #[tokio::test]
    #[should_panic(expected = "assertion failed")]
    async fn test_read_exact_panics_when_amount_exceeds_buffer_size() {
        let data = b"abcd";
        let mut reader = AsyncPeekReader::<_, 4>::new(&data[..]);
        let _ = reader.read_exact(5).await;
    }

    /// Peeking leaves bytes in place; reading advances past them.
    #[tokio::test]
    async fn test_peek_does_not_consume_but_read_does() {
        let data = b"abcd";
        let mut reader = AsyncPeekReader::<_, 4>::new(&data[..]);
        assert_eq!(reader.peek_exact(2).await.ok(), Some(&b"ab"[..]));
        assert_eq!(reader.peek_exact(3).await.ok(), Some(&b"abc"[..]));
        assert_eq!(reader.read_exact(2).await.ok(), Some(&b"ab"[..]));
        assert_eq!(reader.read_u8().await.ok(), Some(b'c'));
        assert_eq!(reader.read_u8().await.ok(), Some(b'd'));
    }

    /// Leftover bytes are slid to the front when the tail of the buffer is too short.
    #[tokio::test]
    async fn test_buffer_is_compacted_when_tail_space_runs_out() {
        let data = b"abcdefgh";
        let mut reader = AsyncPeekReader::<_, 4>::new(&data[..]);
        assert_eq!(reader.peek_exact(3).await.ok(), Some(&b"abc"[..]));
        assert_eq!(reader.read_exact(2).await.ok(), Some(&b"ab"[..]));
        assert_eq!(reader.peek_exact(3).await.ok(), Some(&b"cde"[..]));
        assert_eq!(reader.read_exact(3).await.ok(), Some(&b"cde"[..]));
        assert_eq!(reader.read_exact(3).await.ok(), Some(&b"fgh"[..]));
    }

    /// `consume` only drops what is already buffered and reports how much it dropped.
    #[tokio::test]
    async fn test_consume_only_discards_buffered_bytes() {
        let data = b"abcd";
        let mut reader = AsyncPeekReader::<_, 4>::new(&data[..]);
        assert_eq!(reader.consume(1), 0);
        assert_eq!(reader.peek_exact(3).await.ok(), Some(&b"abc"[..]));
        assert_eq!(reader.consume(5), 3);
        assert_eq!(reader.read_u8().await.ok(), Some(b'd'));
    }

    /// Asking for more bytes than the stream holds yields an error, not a short slice.
    #[tokio::test]
    async fn test_reading_past_end_of_stream_is_an_error() {
        let data = b"ab";
        let mut reader = AsyncPeekReader::<_, 4>::new(&data[..]);
        assert!(reader.read_exact(3).await.is_err());
    }
}