use crate::readers::prepend_reader::AsyncPrependReader;
use tokio::io::{AsyncRead, ReadBuf};
use std::pin::Pin;
use std::task::{Context, Poll};
pub struct AsyncDelimiterReader<R: AsyncRead + Unpin> {
inner: AsyncPrependReader<R>,
delimiter: Vec<u8>,
matched: bool,
}
impl<R: AsyncRead + Unpin> AsyncDelimiterReader <R> {
pub fn new(inner: R, delimiter: &[u8]) -> Self {
Self::with_owned_delimiter(inner, delimiter.to_vec())
}
pub fn with_owned_delimiter(inner: R, delimiter: Vec<u8>) -> Self {
if delimiter.len() == 0 {
panic!("Delimiter byte length must be non-zero.");
}
Self {
inner: AsyncPrependReader::new(inner),
delimiter,
matched: false,
}
}
pub fn matched(&self) -> bool {
self.matched
}
pub fn reset(&mut self) {
self.matched = false;
}
pub fn into_inner(self) -> R {
self.inner.into_inner()
}
pub fn get_mut(&mut self) -> &mut R {
self.inner.get_mut()
}
pub fn buffer(&self) -> &[u8] {
self.inner.buffer()
}
}
impl<R: AsyncRead + Unpin> AsyncRead for AsyncDelimiterReader <R> {
fn poll_read(mut self: Pin<&mut Self>, c: &mut Context<'_>, b: &mut ReadBuf<'_>) -> Poll<tokio::io::Result<()>> {
if self.matched {
return Poll::Ready(Ok(()));
}
let (prev_len, new_len, poll) = fill_buffer(&mut self, c, b);
if new_len - prev_len == 0 {
return poll;
}
let read_slice = &b.filled()[prev_len..new_len];
if let Some((full_match, match_index)) = match_delimiter(&mut self, read_slice) {
let read_slice_len = read_slice.len();
let index = match_index + self.delimiter.len();
if read_slice_len >= index {
let actual_read_slice = &read_slice[index..];
self.inner.prepend(actual_read_slice);
b.set_filled(match_index);
} else if !full_match {
let partial_delimiter = &read_slice[match_index..];
self.inner.prepend(partial_delimiter);
b.set_filled(match_index);
if match_index == 0 {
c.waker().wake_by_ref();
return Poll::Pending;
}
}
if full_match {
self.matched = true;
return Poll::Ready(Ok(()));
}
}
poll
}
}
fn fill_buffer<W: AsyncRead + Unpin>(
reader: &mut Pin<&mut AsyncDelimiterReader<W>>, c: &mut Context<'_>, b: &mut ReadBuf<'_>
) -> (usize, usize, Poll<tokio::io::Result<()>>) {
let prev_len = b.filled().len();
let poll = Pin::new(&mut reader.inner).poll_read(c, b);
let new_len = b.filled().len();
(prev_len, new_len, poll)
}
fn match_delimiter<W: AsyncRead + Unpin>(reader: &mut AsyncDelimiterReader<W>, buf: &[u8]) -> Option<(bool, usize)> {
'outer: for index in 0..buf.len() {
for (delim_index, delim_byte) in reader.delimiter.iter().enumerate() {
if index + delim_index >= buf.len() && delim_index != 0 {
return Some((false, index));
}
if buf[index + delim_index] != *delim_byte {
continue 'outer;
}
}
return Some((true, index));
}
None
}
#[cfg(test)]
#[tokio::test]
async fn foo_bar() {
use std::io::Cursor;
use tokio::io::AsyncReadExt;
let data = b"Foo. Bar. Foo. Foo. Bar. Foo. Bar.";
let delimiter = b"Bar.";
let cursor = Cursor::new(data);
let mut buffer = String::new();
let mut reader = AsyncDelimiterReader::new(cursor, delimiter);
reader.read_to_string(&mut buffer).await.expect("failed to read");
assert_eq!(buffer, "Foo. ");
assert!(reader.matched());
reader.reset();
buffer.clear();
reader.read_to_string(&mut buffer).await.expect("failed to read");
assert_eq!(buffer, " Foo. Foo. ");
assert!(reader.matched());
reader.reset();
buffer.clear();
reader.read_to_string(&mut buffer).await.expect("failed to read");
assert_eq!(buffer, " Foo. ");
assert!(reader.matched());
reader.reset();
buffer.clear();
reader.read_to_string(&mut buffer).await.expect("failed to read");
assert_eq!(buffer, "");
assert!(!reader.matched());
}