use std::collections::VecDeque;
use async_stream::try_stream;
use tokio::io::AsyncReadExt;
use tokio_stream::Stream;
use crate::{deserialize::Deserializable, types::stream::ScannedItem, Error, GenError, GenResult};
/// Something that can write itself into an asynchronous writer.
///
/// The method is declared through `async_trait`, so implementations must also
/// carry the `#[async_trait::async_trait]` attribute.
#[async_trait::async_trait]
pub trait WriteIntoAsync {
/// Writes `self` into `writer`.
///
/// `writer` is a `Send + Unpin` trait object, so implementations are not
/// generic over the concrete sink type.
///
/// # Errors
///
/// Implementations report failures as `std::io::Error`.
async fn write_into(
&self,
writer: &mut (dyn tokio::io::AsyncWrite + Send + Unpin),
) -> std::io::Result<()>;
}
/// Decodes the next item from `reader`, using `buffer` to hold bytes that have
/// been read but not yet consumed across calls.
///
/// Returns:
/// - `Ok(None)` when the reader hits EOF with nothing buffered;
/// - `Ok(Some(ScannedItem::Unknown(rest, NotEnoughBytes)))` when the reader hits
///   EOF in the middle of an item (all buffered bytes are handed back);
/// - otherwise `Ok(Some(item))` for the item decoded from the front of `buffer`.
///
/// Every `Ok(Some(_))` removes at least one byte from `buffer`, so a caller that
/// loops on this function (like [`iter_stream`]) always makes progress, even on
/// malformed input.
///
/// # Errors
///
/// Returns an error if reading from `reader` fails or if
/// `ScannedItem::required_bytes` rejects the buffered data.
async fn next<R: AsyncReadExt + Unpin>(
    reader: &mut R,
    buffer: &mut VecDeque<u8>,
) -> GenResult<Option<ScannedItem>> {
    let mut required_bytes = required_bytes_for(buffer)?;
    // One scratch buffer for the whole call instead of a fresh array per read.
    let mut read_buffer = [0u8; 8192];
    while buffer.len() < required_bytes {
        let bytes_read = reader
            .read(&mut read_buffer)
            .await
            .map_err(GenError::from)?;
        if bytes_read == 0 {
            if buffer.is_empty() {
                return Ok(None);
            }
            // EOF inside an item: surface the truncated bytes rather than drop them.
            return Ok(Some(ScannedItem::Unknown(
                buffer.drain(..).collect(),
                Error::NotEnoughBytes.into(),
            )));
        }
        buffer.extend(&read_buffer[..bytes_read]);
        // Re-evaluate with the new data; the answer can change as more of the
        // item arrives.
        required_bytes = required_bytes_for(buffer)?;
    }

    let buf = buffer.make_contiguous();
    let item = match ScannedItem::deserialize(buf) {
        Ok((ScannedItem::Event(event), used)) => {
            let consumed = consumed_len(used, buffer.len());
            buffer.drain(..consumed);
            ScannedItem::Event(event)
        }
        Ok((ScannedItem::Bytes(_), _)) => {
            // Non-packet data: return everything up to the next signature byte.
            // The first byte is always taken, so a signature byte that did not
            // decode as a packet cannot stall the stream with empty items.
            let end = buffer
                .iter()
                .skip(1)
                .position(|&byte| byte == crate::constants::V2_SIGNATURE)
                .map_or(buffer.len(), |offset| offset + 1);
            ScannedItem::Bytes(buffer.drain(..end).collect())
        }
        Ok((ScannedItem::Unknown(data, e), used)) => {
            let consumed = consumed_len(used, buffer.len());
            buffer.drain(..consumed);
            ScannedItem::Unknown(data, e)
        }
        Err(e) => {
            // Undecodable item: skip the span its header claimed to occupy.
            let skip = consumed_len(required_bytes, buffer.len());
            ScannedItem::Unknown(buffer.drain(..skip).collect(), e)
        }
    };
    Ok(Some(item))
}

/// Returns how many bytes `ScannedItem::required_bytes` needs buffered before
/// the item at the front of `buffer` can be decoded, never less than one.
///
/// The floor of one makes `next` attempt a read whenever `buffer` is empty, so
/// it can tell EOF apart from "nothing to decode yet".
fn required_bytes_for(buffer: &mut VecDeque<u8>) -> GenResult<usize> {
    let needed = ScannedItem::required_bytes(buffer.make_contiguous()).map_err(GenError::from)?;
    Ok(needed.max(1))
}

/// Clamps a consumed-byte count to at least one and at most `available`, so
/// decoding always advances and never drains past the end of the buffer.
fn consumed_len(used: usize, available: usize) -> usize {
    used.max(1).min(available)
}
/// Turns an asynchronous byte source into a stream of [`ScannedItem`]s.
///
/// Items are decoded one at a time by `next`. The stream finishes once the
/// reader is exhausted and no buffered bytes remain. If reading or length
/// detection fails, the error is yielded and the stream ends.
pub fn iter_stream<R: AsyncReadExt + Unpin>(
    mut reader: R,
) -> impl Stream<Item = GenResult<ScannedItem>> {
    // 4 MiB reserved up front for bytes read but not yet decoded.
    const INITIAL_CAPACITY: usize = 4 * 1024 * 1024;

    try_stream! {
        let mut pending = VecDeque::<u8>::with_capacity(INITIAL_CAPACITY);
        loop {
            match next(&mut reader, &mut pending).await? {
                Some(item) => {
                    yield item;
                }
                None => break,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;

    use crate::{
        io::r#async::iter_stream,
        serialize::Serializable,
        types::{event::Event, stream::ScannedItem, teststatus::TestStatus},
    };

    /// Back-to-back serialized events decode as exactly one `Event` each.
    #[tokio::test]
    async fn test_iter_stream() {
        let events = vec![
            Event::new(TestStatus::Success).test_id("foo").build(),
            Event::new(TestStatus::Success).test_id("bar").build(),
            Event::new(TestStatus::Success).test_id("baz").build(),
        ];
        let mut buf = Vec::new();
        for event in events {
            event.serialize(&mut buf).unwrap();
        }
        let stream = iter_stream(&buf[..]);
        let results = stream
            .collect::<Result<Vec<ScannedItem>, _>>()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        // A count check alone would pass if packets were misread as raw bytes.
        for item in &results {
            assert!(
                matches!(item, ScannedItem::Event(_)),
                "Expected Event, got {:?}",
                item
            );
        }
    }

    /// An empty reader yields no items at all.
    #[tokio::test]
    async fn test_empty_stream() {
        let data: &[u8] = &[];
        let items = iter_stream(data)
            .collect::<Result<Vec<_>, _>>()
            .await
            .unwrap();
        assert!(items.is_empty(), "Expected no items, got {:?}", items);
    }

    /// Non-packet bytes around an event are surfaced as `Bytes` items.
    #[tokio::test]
    async fn test_stream_with_invalid_utf8() {
        let event = Event::new(TestStatus::Success).test_id("test").build();
        let mut buffer = Vec::new();
        buffer.extend_from_slice(&[0xFF, 0xFE, 0xFD]);
        event.serialize(&mut buffer).unwrap();
        buffer.extend_from_slice(&[0x80, 0x81]);
        let stream = iter_stream(&buffer[..]);
        let items: Vec<_> = stream.collect::<Result<Vec<_>, _>>().await.unwrap();
        assert_eq!(items.len(), 3);
        match &items[0] {
            ScannedItem::Bytes(bytes) => assert_eq!(bytes, &[0xFF, 0xFE, 0xFD]),
            _ => panic!("Expected Bytes, got {:?}", items[0]),
        }
        assert!(matches!(items[1], ScannedItem::Event(_)));
        match &items[2] {
            ScannedItem::Bytes(bytes) => assert_eq!(bytes, &[0x80, 0x81]),
            _ => panic!("Expected Bytes, got {:?}", items[2]),
        }
    }

    /// A malformed packet must not make the stream spin forever.
    #[tokio::test]
    async fn test_no_infinite_loop_on_malformed_stream() {
        let data: &[u8] =
            b"\xb3\x29\x00\x16test1\x20\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xb3";
        let stream = iter_stream(data);
        let items: Vec<_> = stream
            .take(101)
            .collect::<Result<Vec<_>, _>>()
            .await
            .unwrap();
        assert!(
            items.len() <= 10,
            "Expected few iterations, got {}",
            items.len()
        );
    }
}