use super::chunk::ChunkHeader;
use super::{StreamingConfig, StreamingProgress};
use crate::de::{Decode, DecoderImpl, SliceReader};
use crate::enc::{Encode, EncoderImpl, VecWriter};
use crate::{config, Error, Result};
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "async-tokio")]
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Asynchronous encoder that frames encoded items into length-prefixed
/// chunks (see [`ChunkHeader`]) and writes them to an [`AsyncWrite`] sink.
#[cfg(feature = "async-tokio")]
pub struct AsyncStreamingEncoder<W: AsyncWrite + Unpin> {
    // Destination for chunk headers and payloads.
    writer: W,
    // Controls chunk size and per-item flushing behavior.
    config: StreamingConfig,
    // Encoded-but-unwritten item bytes for the chunk currently being built.
    buffer: alloc::vec::Vec<u8>,
    // Number of items currently encoded into `buffer`.
    items_in_buffer: u32,
    // Running totals (items/bytes/chunks) updated on each flush.
    progress: StreamingProgress,
}
#[cfg(feature = "async-tokio")]
impl<W: AsyncWrite + Unpin> AsyncStreamingEncoder<W> {
    /// Creates an encoder over `writer` using the default configuration.
    pub fn new(writer: W) -> Self {
        Self::with_config(writer, StreamingConfig::default())
    }

    /// Creates an encoder over `writer` with an explicit `config`.
    pub fn with_config(writer: W, config: StreamingConfig) -> Self {
        Self {
            writer,
            config,
            buffer: alloc::vec::Vec::new(),
            items_in_buffer: 0,
            progress: StreamingProgress::default(),
        }
    }

    /// Records the expected total number of items in the progress report.
    pub fn set_estimated_total(&mut self, total: u64) {
        self.progress.estimated_total = Some(total);
    }

    /// Encodes `item` and appends it to the chunk under construction.
    ///
    /// The pending chunk is flushed first when appending would push it past
    /// `config.chunk_size`; a single item larger than the chunk size still
    /// goes out as its own (oversized) chunk. When `config.flush_per_item`
    /// is set, every item is flushed as a chunk of its own.
    pub async fn write_item<T: Encode>(&mut self, item: &T) -> Result<()> {
        let mut encoder = EncoderImpl::new(VecWriter::new(), config::standard());
        item.encode(&mut encoder)?;
        let encoded = encoder.into_writer().into_vec();
        let overflows = self.buffer.len() + encoded.len() > self.config.chunk_size;
        if overflows && !self.buffer.is_empty() {
            self.flush_chunk().await?;
        }
        self.buffer.extend_from_slice(&encoded);
        self.items_in_buffer += 1;
        if self.config.flush_per_item {
            self.flush_chunk().await?;
        }
        Ok(())
    }

    /// Writes every item produced by `items`, in iteration order.
    pub async fn write_all<T: Encode, I: IntoIterator<Item = T>>(
        &mut self,
        items: I,
    ) -> Result<()> {
        for item in items {
            self.write_item(&item).await?;
        }
        Ok(())
    }

    /// Converts an I/O failure into the crate's [`Error::Io`] variant.
    fn io_error(e: std::io::Error) -> Error {
        Error::Io {
            kind: e.kind(),
            message: e.to_string(),
        }
    }

    /// Writes the buffered items as one data chunk (header + payload),
    /// updates progress counters, and resets the buffer. No-op when the
    /// buffer is empty, so no zero-item chunks are ever emitted.
    async fn flush_chunk(&mut self) -> Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let header = ChunkHeader::data(self.buffer.len() as u32, self.items_in_buffer);
        self.writer
            .write_all(&header.to_bytes())
            .await
            .map_err(Self::io_error)?;
        self.writer
            .write_all(&self.buffer)
            .await
            .map_err(Self::io_error)?;
        self.progress.items_processed += self.items_in_buffer as u64;
        self.progress.bytes_processed += self.buffer.len() as u64;
        self.progress.chunks_processed += 1;
        self.buffer.clear();
        self.items_in_buffer = 0;
        Ok(())
    }

    /// Flushes any pending chunk, writes the end-marker chunk, flushes the
    /// underlying writer, and hands the writer back to the caller.
    pub async fn finish(mut self) -> Result<W> {
        self.flush_chunk().await?;
        self.writer
            .write_all(&ChunkHeader::end().to_bytes())
            .await
            .map_err(Self::io_error)?;
        self.writer.flush().await.map_err(Self::io_error)?;
        Ok(self.writer)
    }

    /// Progress counters for the encode side.
    pub fn progress(&self) -> &StreamingProgress {
        &self.progress
    }

    /// Shared access to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }
}
/// Asynchronous decoder that reads length-prefixed chunks produced by
/// `AsyncStreamingEncoder` from an [`AsyncRead`] source and decodes the
/// items they contain.
#[cfg(feature = "async-tokio")]
pub struct AsyncStreamingDecoder<R: AsyncRead + Unpin> {
    // Source of chunk headers and payloads.
    reader: R,
    // The chunk currently being decoded, if any has been loaded.
    current_chunk: Option<ChunkData>,
    // Running totals (items/bytes/chunks) updated as data is consumed.
    progress: StreamingProgress,
    // Set once the end-marker chunk or EOF has been observed.
    finished: bool,
}
/// An in-memory chunk payload being consumed by the decoder.
#[cfg(feature = "async-tokio")]
struct ChunkData {
    // Raw payload bytes of the chunk.
    data: alloc::vec::Vec<u8>,
    // Read position within `data`; advanced as items are decoded.
    offset: usize,
    // Items left to decode from this chunk (from the chunk header).
    items_remaining: u32,
}
#[cfg(feature = "async-tokio")]
impl<R: AsyncRead + Unpin> AsyncStreamingDecoder<R> {
    /// Creates a decoder reading chunked data from `reader`.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            current_chunk: None,
            progress: StreamingProgress::default(),
            finished: false,
        }
    }

    /// Config-taking constructor for API symmetry with the encoder; the
    /// decoder derives everything it needs from the chunk headers, so the
    /// config is unused.
    pub fn with_config(reader: R, _config: StreamingConfig) -> Self {
        Self::new(reader)
    }

    /// Decodes the next item, loading chunks from the reader on demand.
    ///
    /// Returns `Ok(None)` once the end marker (or a clean EOF at a chunk
    /// boundary) has been reached. Data chunks advertising zero items are
    /// skipped rather than terminating iteration early.
    ///
    /// # Errors
    /// Returns `Error::Io` on read failures, plus any decode error from
    /// the item decoder.
    pub async fn read_item<T: Decode>(&mut self) -> Result<Option<T>> {
        loop {
            if self.finished {
                return Ok(None);
            }
            // A new chunk is needed when none is loaded or the current one
            // has been fully consumed.
            let needs_chunk = self
                .current_chunk
                .as_ref()
                .map_or(true, |c| c.items_remaining == 0);
            if needs_chunk && !self.load_next_chunk().await? {
                return Ok(None);
            }
            let chunk = self.current_chunk.as_mut().ok_or(Error::InvalidData {
                message: "no chunk available",
            })?;
            if chunk.items_remaining == 0 {
                // A zero-item data chunk (malformed or empty) must not end
                // the stream; loop back and load the next chunk instead.
                continue;
            }
            let remaining = &chunk.data[chunk.offset..];
            let mut decoder = DecoderImpl::new(SliceReader::new(remaining), config::standard());
            let item = T::decode(&mut decoder)?;
            // The slice reader shrinks as it consumes input; the shrinkage
            // is the byte length of the item just decoded.
            let consumed = remaining.len() - decoder.reader().slice.len();
            chunk.offset += consumed;
            chunk.items_remaining -= 1;
            self.progress.items_processed += 1;
            self.progress.bytes_processed += consumed as u64;
            return Ok(Some(item));
        }
    }

    /// Reads and collects every remaining item in the stream.
    #[cfg(feature = "alloc")]
    pub async fn read_all<T: Decode>(&mut self) -> Result<alloc::vec::Vec<T>> {
        let mut items = alloc::vec::Vec::new();
        while let Some(item) = self.read_item().await? {
            items.push(item);
        }
        Ok(items)
    }

    /// Converts an I/O failure into the crate's [`Error::Io`] variant.
    fn io_error(e: std::io::Error) -> Error {
        Error::Io {
            kind: e.kind(),
            message: e.to_string(),
        }
    }

    /// Reads the next chunk header and payload into `current_chunk`.
    ///
    /// Returns `Ok(false)` — and marks the stream finished — on EOF at a
    /// chunk boundary or when the end-marker chunk is seen; `Ok(true)`
    /// when a data chunk was loaded.
    async fn load_next_chunk(&mut self) -> Result<bool> {
        let mut header_bytes = [0u8; ChunkHeader::SIZE];
        match self.reader.read_exact(&mut header_bytes).await {
            Ok(_) => {}
            // EOF exactly at a chunk boundary is treated as end of stream.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                self.finished = true;
                return Ok(false);
            }
            Err(e) => return Err(Self::io_error(e)),
        }
        let header = ChunkHeader::from_bytes(&header_bytes)?;
        if header.is_end() {
            self.finished = true;
            return Ok(false);
        }
        let mut data = alloc::vec![0u8; header.payload_len as usize];
        self.reader
            .read_exact(&mut data)
            .await
            .map_err(Self::io_error)?;
        self.current_chunk = Some(ChunkData {
            data,
            offset: 0,
            items_remaining: header.item_count,
        });
        self.progress.chunks_processed += 1;
        Ok(true)
    }

    /// Progress counters for the decode side.
    pub fn progress(&self) -> &StreamingProgress {
        &self.progress
    }

    /// `true` once the end marker or EOF has been observed.
    pub fn is_finished(&self) -> bool {
        self.finished
    }

    /// Shared access to the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }
}
/// A cloneable cancellation flag shared between related operations.
///
/// Every token produced by [`CancellationToken::child`] (or `clone`)
/// observes the same underlying flag, so cancelling any one of them
/// cancels the whole family.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    cancelled: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl CancellationToken {
    /// Creates a fresh, not-yet-cancelled token.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks this token (and every token sharing its flag) as cancelled.
    pub fn cancel(&self) {
        use std::sync::atomic::Ordering;
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once `cancel` has been called on any linked token.
    pub fn is_cancelled(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Produces a token that shares this token's cancellation flag.
    pub fn child(&self) -> Self {
        Self {
            cancelled: std::sync::Arc::clone(&self.cancelled),
        }
    }
}
/// An [`AsyncStreamingEncoder`] wrapper that refuses further writes once
/// its [`CancellationToken`] has been cancelled.
#[cfg(feature = "async-tokio")]
pub struct CancellableAsyncEncoder<W: AsyncWrite + Unpin> {
    // The wrapped streaming encoder doing the actual work.
    inner: AsyncStreamingEncoder<W>,
    // Checked before each operation; cancellation aborts with an error.
    token: CancellationToken,
}
#[cfg(feature = "async-tokio")]
impl<W: AsyncWrite + Unpin> CancellableAsyncEncoder<W> {
    /// Wraps `writer` in a streaming encoder governed by `token`.
    pub fn new(writer: W, token: CancellationToken) -> Self {
        Self {
            inner: AsyncStreamingEncoder::new(writer),
            token,
        }
    }

    /// Fails with `Error::Custom` once the token has been cancelled.
    fn ensure_active(&self) -> Result<()> {
        if self.token.is_cancelled() {
            Err(Error::Custom {
                message: "operation cancelled",
            })
        } else {
            Ok(())
        }
    }

    /// Writes `item` unless the operation has been cancelled.
    pub async fn write_item<T: Encode>(&mut self, item: &T) -> Result<()> {
        self.ensure_active()?;
        self.inner.write_item(item).await
    }

    /// Finalizes the stream unless the operation has been cancelled.
    pub async fn finish(self) -> Result<W> {
        self.ensure_active()?;
        self.inner.finish().await
    }

    /// Progress counters from the wrapped encoder.
    pub fn progress(&self) -> &StreamingProgress {
        self.inner.progress()
    }
}
/// An [`AsyncStreamingDecoder`] wrapper that refuses further reads once
/// its [`CancellationToken`] has been cancelled.
#[cfg(feature = "async-tokio")]
pub struct CancellableAsyncDecoder<R: AsyncRead + Unpin> {
    // The wrapped streaming decoder doing the actual work.
    inner: AsyncStreamingDecoder<R>,
    // Checked before each read; cancellation aborts with an error.
    token: CancellationToken,
}
#[cfg(feature = "async-tokio")]
impl<R: AsyncRead + Unpin> CancellableAsyncDecoder<R> {
    /// Wraps `reader` in a streaming decoder governed by `token`.
    pub fn new(reader: R, token: CancellationToken) -> Self {
        Self {
            inner: AsyncStreamingDecoder::new(reader),
            token,
        }
    }

    /// Fails with `Error::Custom` once the token has been cancelled.
    fn ensure_active(&self) -> Result<()> {
        if self.token.is_cancelled() {
            Err(Error::Custom {
                message: "operation cancelled",
            })
        } else {
            Ok(())
        }
    }

    /// Reads the next item unless the operation has been cancelled.
    pub async fn read_item<T: Decode>(&mut self) -> Result<Option<T>> {
        self.ensure_active()?;
        self.inner.read_item().await
    }

    /// Collects all remaining items, checking cancellation before each read.
    #[cfg(feature = "alloc")]
    pub async fn read_all<T: Decode>(&mut self) -> Result<alloc::vec::Vec<T>> {
        let mut items = alloc::vec::Vec::new();
        while let Some(item) = self.read_item().await? {
            items.push(item);
        }
        Ok(items)
    }

    /// Progress counters from the wrapped decoder.
    pub fn progress(&self) -> &StreamingProgress {
        self.inner.progress()
    }

    /// Whether the wrapped decoder has reached the end of the stream.
    pub fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }
}
#[cfg(all(test, feature = "async-tokio"))]
mod tests {
use super::*;
use std::io::Cursor;
#[tokio::test]
async fn test_async_roundtrip() {
let mut buffer = alloc::vec::Vec::new();
{
let cursor = Cursor::new(&mut buffer);
let mut encoder = AsyncStreamingEncoder::new(cursor);
for i in 0..50u32 {
encoder.write_item(&i).await.expect("write failed");
}
encoder.finish().await.expect("finish failed");
}
let cursor = Cursor::new(buffer);
let mut decoder = AsyncStreamingDecoder::new(cursor);
let decoded: alloc::vec::Vec<u32> = decoder.read_all().await.expect("read failed");
let expected: alloc::vec::Vec<u32> = (0..50).collect();
assert_eq!(expected, decoded);
assert!(decoder.is_finished());
}
#[tokio::test]
async fn test_async_item_by_item() {
let mut buffer = alloc::vec::Vec::new();
{
let cursor = Cursor::new(&mut buffer);
let mut encoder = AsyncStreamingEncoder::new(cursor);
encoder.write_item(&1u32).await.expect("write failed");
encoder.write_item(&2u32).await.expect("write failed");
encoder.write_item(&3u32).await.expect("write failed");
encoder.finish().await.expect("finish failed");
}
let cursor = Cursor::new(buffer);
let mut decoder = AsyncStreamingDecoder::new(cursor);
assert_eq!(
decoder.read_item::<u32>().await.expect("read failed"),
Some(1)
);
assert_eq!(
decoder.read_item::<u32>().await.expect("read failed"),
Some(2)
);
assert_eq!(
decoder.read_item::<u32>().await.expect("read failed"),
Some(3)
);
assert_eq!(decoder.read_item::<u32>().await.expect("read failed"), None);
}
#[tokio::test]
async fn test_cancellation() {
let token = CancellationToken::new();
let mut buffer = alloc::vec::Vec::new();
let cursor = Cursor::new(&mut buffer);
let mut encoder = CancellableAsyncEncoder::new(cursor, token.child());
encoder.write_item(&1u32).await.expect("write failed");
encoder.write_item(&2u32).await.expect("write failed");
token.cancel();
let result = encoder.write_item(&3u32).await;
assert!(result.is_err());
}
#[test]
fn test_cancellation_token() {
let token = CancellationToken::new();
assert!(!token.is_cancelled());
let child = token.child();
token.cancel();
assert!(token.is_cancelled());
assert!(child.is_cancelled());
}
#[tokio::test]
async fn test_async_progress_tracking() {
let mut buffer = alloc::vec::Vec::new();
{
let cursor = Cursor::new(&mut buffer);
let mut encoder = AsyncStreamingEncoder::new(cursor);
encoder.set_estimated_total(10);
for i in 0..10u32 {
encoder.write_item(&i).await.expect("write failed");
}
encoder.finish().await.expect("finish failed");
}
let cursor = Cursor::new(buffer);
let mut decoder = AsyncStreamingDecoder::new(cursor);
let _: alloc::vec::Vec<u32> = decoder.read_all().await.expect("read failed");
assert_eq!(decoder.progress().items_processed, 10);
assert!(decoder.progress().chunks_processed >= 1);
}
#[tokio::test]
async fn test_async_large_data() {
let config = StreamingConfig::new().with_chunk_size(1024);
let mut buffer = alloc::vec::Vec::new();
{
let cursor = Cursor::new(&mut buffer);
let mut encoder = AsyncStreamingEncoder::with_config(cursor, config);
for i in 0..1000u32 {
encoder.write_item(&i).await.expect("write failed");
}
encoder.finish().await.expect("finish failed");
}
let cursor = Cursor::new(buffer);
let mut decoder = AsyncStreamingDecoder::new(cursor);
let decoded: alloc::vec::Vec<u32> = decoder.read_all().await.expect("read failed");
let expected: alloc::vec::Vec<u32> = (0..1000).collect();
assert_eq!(expected, decoded);
assert!(decoder.progress().chunks_processed > 1);
}
}