1use super::chunk::ChunkHeader;
7use super::{StreamingConfig, StreamingProgress};
8use crate::de::{Decode, DecoderImpl, SliceReader};
9use crate::enc::{Encode, EncoderImpl, VecWriter};
10use crate::{config, Error, Result};
11
12#[cfg(feature = "alloc")]
13extern crate alloc;
14
15#[cfg(feature = "async-tokio")]
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
17
/// Encodes items into an async writer, grouping their encoded bytes into framed chunks.
///
/// Each chunk is a [`ChunkHeader`] (payload length and item count) followed by the payload;
/// the stream is closed with an end-marker header by `finish`.
#[cfg(feature = "async-tokio")]
pub struct AsyncStreamingEncoder<W: AsyncWrite + Unpin> {
    /// Destination that receives chunk headers and payloads.
    writer: W,
    /// Encoded bytes of the items that have not been written out as a chunk yet.
    buffer: alloc::vec::Vec<u8>,
    /// Number of items whose encoded bytes currently sit in `buffer`.
    items_in_buffer: u32,
    /// Chunking behaviour (`chunk_size`, `flush_per_item`).
    config: StreamingConfig,
    /// Counters exposed through `progress()`.
    progress: StreamingProgress,
}
49
#[cfg(feature = "async-tokio")]
impl<W: AsyncWrite + Unpin> AsyncStreamingEncoder<W> {
    /// Creates an encoder writing to `writer` with [`StreamingConfig::default`].
    pub fn new(writer: W) -> Self {
        Self::with_config(writer, StreamingConfig::default())
    }

    /// Creates an encoder writing to `writer` with the given chunking configuration.
    pub fn with_config(writer: W, config: StreamingConfig) -> Self {
        Self {
            writer,
            config,
            buffer: alloc::vec::Vec::new(),
            items_in_buffer: 0,
            progress: StreamingProgress::default(),
        }
    }

    /// Records the expected total number of items in the progress counters.
    ///
    /// The value is only stored in the progress report's `estimated_total`; it does not limit
    /// how many items may be written.
    pub fn set_estimated_total(&mut self, total: u64) {
        self.progress.estimated_total = Some(total);
    }

    /// Encodes `item` and appends it to the pending chunk.
    ///
    /// The pending chunk is written out first if adding the item would push it past
    /// `chunk_size` (an item larger than `chunk_size` still ends up in a chunk of its own), and
    /// again after the item when `flush_per_item` is enabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the item fails to encode or a chunk cannot be written.
    pub async fn write_item<T: Encode>(&mut self, item: &T) -> Result<()> {
        let mut encoder = EncoderImpl::new(VecWriter::new(), config::standard());
        item.encode(&mut encoder)?;
        let item_bytes = encoder.into_writer().into_vec();

        let exceeds_chunk = !self.buffer.is_empty()
            && self.buffer.len() + item_bytes.len() > self.config.chunk_size;
        // Zero-sized items never grow `buffer`, so the byte-size check alone would let the
        // per-chunk item counter run until it overflows; start a new chunk once it is saturated.
        let counter_full = self.items_in_buffer == u32::MAX;
        if exceeds_chunk || counter_full {
            self.flush_chunk().await?;
        }

        self.buffer.extend_from_slice(&item_bytes);
        self.items_in_buffer += 1;

        if self.config.flush_per_item {
            self.flush_chunk().await?;
        }

        Ok(())
    }

    /// Writes every item produced by `items`, in order, via [`Self::write_item`].
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error from `write_item`.
    pub async fn write_all<T: Encode, I: IntoIterator<Item = T>>(
        &mut self,
        items: I,
    ) -> Result<()> {
        for item in items {
            self.write_item(&item).await?;
        }
        Ok(())
    }

    /// Writes the buffered items as one data chunk (header, then payload) and resets the buffer.
    ///
    /// Does nothing when no items are buffered.
    async fn flush_chunk(&mut self) -> Result<()> {
        // Test the item count rather than the byte count: zero-sized items encode to no bytes
        // but still have to be announced in a chunk header, or the decoder never sees them.
        // NOTE(review): relies on `ChunkHeader::is_end()` telling the end marker apart from a
        // data chunk with an empty payload — confirm in the `chunk` module.
        if self.items_in_buffer == 0 {
            return Ok(());
        }

        // The header stores the payload length as a u32. Refuse to write a truncated length,
        // which would desynchronise the framing of every chunk that follows.
        if self.buffer.len() > u32::MAX as usize {
            return Err(Error::InvalidData {
                message: "chunk payload exceeds u32::MAX bytes",
            });
        }
        let header = ChunkHeader::data(self.buffer.len() as u32, self.items_in_buffer);

        self.writer
            .write_all(&header.to_bytes())
            .await
            .map_err(Self::io_error)?;
        self.writer
            .write_all(&self.buffer)
            .await
            .map_err(Self::io_error)?;

        self.progress.items_processed += u64::from(self.items_in_buffer);
        self.progress.bytes_processed += self.buffer.len() as u64;
        self.progress.chunks_processed += 1;

        self.buffer.clear();
        self.items_in_buffer = 0;

        Ok(())
    }

    /// Flushes buffered items, writes the end-of-stream marker, flushes the writer and
    /// returns it.
    ///
    /// # Errors
    ///
    /// Returns an error if the last chunk, the end marker, or the final flush cannot be written.
    pub async fn finish(mut self) -> Result<W> {
        self.flush_chunk().await?;

        let end_header = ChunkHeader::end();
        self.writer
            .write_all(&end_header.to_bytes())
            .await
            .map_err(Self::io_error)?;
        self.writer.flush().await.map_err(Self::io_error)?;

        Ok(self.writer)
    }

    /// Returns the counters for items, bytes and chunks written so far.
    pub fn progress(&self) -> &StreamingProgress {
        &self.progress
    }

    /// Returns a shared reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Converts an I/O error from the underlying writer into this crate's error type.
    fn io_error(e: std::io::Error) -> Error {
        Error::Io {
            kind: e.kind(),
            message: e.to_string(),
        }
    }
}
181
/// Reads items back from an async reader containing chunks framed by [`AsyncStreamingEncoder`].
#[cfg(feature = "async-tokio")]
pub struct AsyncStreamingDecoder<R: AsyncRead + Unpin> {
    /// Source of chunk headers and payloads.
    reader: R,
    /// Set once the end marker (or EOF while reading a header) has been seen.
    finished: bool,
    /// Chunk currently being decoded, if one has been loaded.
    current_chunk: Option<ChunkData>,
    /// Counters exposed through `progress()`.
    progress: StreamingProgress,
}
210
/// A chunk payload held in memory together with the decoder's position inside it.
#[cfg(feature = "async-tokio")]
struct ChunkData {
    /// Items announced by the chunk header that have not been decoded yet.
    items_remaining: u32,
    /// Byte offset of the next undecoded item within `data`.
    offset: usize,
    /// Raw payload bytes as read from the stream.
    data: alloc::vec::Vec<u8>,
}
217
#[cfg(feature = "async-tokio")]
impl<R: AsyncRead + Unpin> AsyncStreamingDecoder<R> {
    /// Upper bound on the buffer reserved up front for a chunk payload. The payload length is
    /// read from the (untrusted) stream, so anything larger grows only as bytes actually arrive.
    const MAX_PAYLOAD_PREALLOC: usize = 64 * 1024;

    /// Creates a decoder reading chunk-framed items from `reader`.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            current_chunk: None,
            progress: StreamingProgress::default(),
            finished: false,
        }
    }

    /// Creates a decoder reading from `reader`.
    ///
    /// The configuration is currently ignored: every chunk header carries its own payload
    /// length and item count. The parameter is presumably kept for symmetry with the encoder.
    pub fn with_config(reader: R, _config: StreamingConfig) -> Self {
        Self::new(reader)
    }

    /// Decodes the next item from the stream.
    ///
    /// Returns `Ok(None)` once the end marker has been read, or when the reader hits EOF while
    /// reading a chunk header. Chunks that announce zero items are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the reader fails, a chunk header cannot be parsed, a payload is cut
    /// short, or the item cannot be decoded.
    pub async fn read_item<T: Decode>(&mut self) -> Result<Option<T>> {
        if self.finished {
            return Ok(None);
        }

        // Keep loading until the current chunk still has items; a chunk announcing zero items
        // must not be mistaken for the end of the stream.
        while !matches!(&self.current_chunk, Some(chunk) if chunk.items_remaining > 0) {
            if !self.load_next_chunk().await? {
                return Ok(None);
            }
        }

        let chunk = self.current_chunk.as_mut().ok_or(Error::InvalidData {
            message: "no chunk available",
        })?;

        // Decode from the unread tail of the payload, then measure how many bytes the item used.
        let reader = SliceReader::new(&chunk.data[chunk.offset..]);
        let mut decoder = DecoderImpl::new(reader, config::standard());
        let item = T::decode(&mut decoder)?;

        let bytes_consumed = chunk.data[chunk.offset..].len() - decoder.reader().slice.len();
        chunk.offset += bytes_consumed;
        chunk.items_remaining -= 1;

        self.progress.items_processed += 1;
        self.progress.bytes_processed += bytes_consumed as u64;

        Ok(Some(item))
    }

    /// Decodes every remaining item until the end of the stream.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error from [`Self::read_item`].
    #[cfg(feature = "alloc")]
    pub async fn read_all<T: Decode>(&mut self) -> Result<alloc::vec::Vec<T>> {
        let mut items = alloc::vec::Vec::new();
        while let Some(item) = self.read_item().await? {
            items.push(item);
        }
        Ok(items)
    }

    /// Reads the next chunk header and, for a data chunk, its payload.
    ///
    /// Returns `Ok(false)` and marks the decoder finished on the end marker or on EOF while
    /// reading the header. Note that a partially read header also surfaces as EOF and is treated
    /// as a clean end.
    async fn load_next_chunk(&mut self) -> Result<bool> {
        let mut header_bytes = [0u8; ChunkHeader::SIZE];
        match self.reader.read_exact(&mut header_bytes).await {
            Ok(_) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                self.finished = true;
                return Ok(false);
            }
            Err(e) => return Err(Self::io_error(e)),
        }

        let header = ChunkHeader::from_bytes(&header_bytes)?;
        if header.is_end() {
            self.finished = true;
            return Ok(false);
        }

        // Reserve at most MAX_PAYLOAD_PREALLOC up front instead of trusting the declared length;
        // `take` caps the read at exactly `payload_len` bytes.
        let payload_len = header.payload_len as usize;
        let mut data =
            alloc::vec::Vec::<u8>::with_capacity(payload_len.min(Self::MAX_PAYLOAD_PREALLOC));
        let bytes_read = (&mut self.reader)
            .take(payload_len as u64)
            .read_to_end(&mut data)
            .await
            .map_err(Self::io_error)?;
        if bytes_read != payload_len {
            // The stream ended inside the payload; report it as `read_exact` would.
            return Err(Self::io_error(std::io::Error::from(
                std::io::ErrorKind::UnexpectedEof,
            )));
        }

        self.current_chunk = Some(ChunkData {
            data,
            offset: 0,
            items_remaining: header.item_count,
        });
        self.progress.chunks_processed += 1;

        Ok(true)
    }

    /// Returns the counters for items, bytes and chunks read so far.
    pub fn progress(&self) -> &StreamingProgress {
        &self.progress
    }

    /// Returns `true` once the end of the stream has been reached.
    pub fn is_finished(&self) -> bool {
        self.finished
    }

    /// Returns a shared reference to the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// Converts an I/O error from the underlying reader into this crate's error type.
    fn io_error(e: std::io::Error) -> Error {
        Error::Io {
            kind: e.kind(),
            message: e.to_string(),
        }
    }
}
355
/// A cloneable, thread-safe cancellation flag used by the cancellable encoder/decoder wrappers.
///
/// Clones share one flag, so cancelling any clone cancels them all. A token obtained from
/// [`CancellationToken::child`] also observes cancellation of its parent and every further
/// ancestor. Cancelling the child leaves the parent and its other children untouched.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    /// Cancellation state owned by this token and shared with its clones.
    cancelled: std::sync::Arc<std::sync::atomic::AtomicBool>,
    /// Token this one was derived from via `child`; its cancellation propagates downwards.
    parent: Option<std::sync::Arc<CancellationToken>>,
}

impl CancellationToken {
    /// Creates a token that is not cancelled and has no parent.
    pub fn new() -> Self {
        Self {
            cancelled: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
            parent: None,
        }
    }

    /// Cancels this token, its clones and all of its descendants (but not its ancestors).
    pub fn cancel(&self) {
        self.cancelled
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` if this token or any of its ancestors has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        std::iter::successors(Some(self), |token| token.parent.as_deref())
            .any(|token| token.cancelled.load(std::sync::atomic::Ordering::SeqCst))
    }

    /// Creates a child token.
    ///
    /// The child reports cancelled as soon as this token (or one of its ancestors) is
    /// cancelled. It has its own flag, so cancelling the child does not affect this token.
    pub fn child(&self) -> Self {
        Self {
            cancelled: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
            parent: Some(std::sync::Arc::new(self.clone())),
        }
    }
}
390
/// An [`AsyncStreamingEncoder`] that refuses further writes once its token is cancelled.
#[cfg(feature = "async-tokio")]
pub struct CancellableAsyncEncoder<W: AsyncWrite + Unpin> {
    /// Token checked before every write and before finishing.
    token: CancellationToken,
    /// The encoder doing the actual work.
    inner: AsyncStreamingEncoder<W>,
}
397
#[cfg(feature = "async-tokio")]
impl<W: AsyncWrite + Unpin> CancellableAsyncEncoder<W> {
    /// Wraps a default-configured [`AsyncStreamingEncoder`] over `writer`, guarded by `token`.
    pub fn new(writer: W, token: CancellationToken) -> Self {
        let inner = AsyncStreamingEncoder::new(writer);
        Self { inner, token }
    }

    /// Fails with `Error::Custom` when the token has been cancelled.
    fn check_cancelled(&self) -> Result<()> {
        if self.token.is_cancelled() {
            Err(Error::Custom {
                message: "operation cancelled",
            })
        } else {
            Ok(())
        }
    }

    /// Writes `item` unless the token has been cancelled.
    ///
    /// # Errors
    ///
    /// Returns `Error::Custom` if cancelled, otherwise whatever the inner encoder returns.
    pub async fn write_item<T: Encode>(&mut self, item: &T) -> Result<()> {
        self.check_cancelled()?;
        self.inner.write_item(item).await
    }

    /// Finishes the stream and returns the writer, unless the token has been cancelled.
    ///
    /// When cancelled, no end marker is written and the writer is dropped.
    ///
    /// # Errors
    ///
    /// Returns `Error::Custom` if cancelled, otherwise whatever the inner encoder returns.
    pub async fn finish(self) -> Result<W> {
        self.check_cancelled()?;
        self.inner.finish().await
    }

    /// Returns the inner encoder's progress counters.
    pub fn progress(&self) -> &StreamingProgress {
        self.inner.progress()
    }
}
433
/// An [`AsyncStreamingDecoder`] that refuses further reads once its token is cancelled.
#[cfg(feature = "async-tokio")]
pub struct CancellableAsyncDecoder<R: AsyncRead + Unpin> {
    /// Token checked before every read.
    token: CancellationToken,
    /// The decoder doing the actual work.
    inner: AsyncStreamingDecoder<R>,
}
440
#[cfg(feature = "async-tokio")]
impl<R: AsyncRead + Unpin> CancellableAsyncDecoder<R> {
    /// Wraps a fresh [`AsyncStreamingDecoder`] over `reader`, guarded by `token`.
    pub fn new(reader: R, token: CancellationToken) -> Self {
        let inner = AsyncStreamingDecoder::new(reader);
        Self { inner, token }
    }

    /// Fails with `Error::Custom` when the token has been cancelled.
    fn check_cancelled(&self) -> Result<()> {
        if !self.token.is_cancelled() {
            return Ok(());
        }
        Err(Error::Custom {
            message: "operation cancelled",
        })
    }

    /// Reads the next item unless the token has been cancelled.
    ///
    /// # Errors
    ///
    /// Returns `Error::Custom` if cancelled, otherwise whatever the inner decoder returns.
    pub async fn read_item<T: Decode>(&mut self) -> Result<Option<T>> {
        self.check_cancelled()?;
        self.inner.read_item().await
    }

    /// Reads every remaining item, re-checking the token before each one.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error from [`Self::read_item`], including cancellation.
    #[cfg(feature = "alloc")]
    pub async fn read_all<T: Decode>(&mut self) -> Result<alloc::vec::Vec<T>> {
        let mut collected = alloc::vec::Vec::new();
        loop {
            match self.read_item::<T>().await? {
                Some(item) => collected.push(item),
                None => return Ok(collected),
            }
        }
    }

    /// Returns the inner decoder's progress counters.
    pub fn progress(&self) -> &StreamingProgress {
        self.inner.progress()
    }

    /// Returns `true` once the inner decoder has reached the end of the stream.
    pub fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }
}
481
#[cfg(all(test, feature = "async-tokio"))]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Everything written by the encoder comes back from the decoder, in order.
    #[tokio::test]
    async fn test_async_roundtrip() {
        let mut encoded = alloc::vec::Vec::new();
        let mut sink = AsyncStreamingEncoder::new(Cursor::new(&mut encoded));
        sink.write_all(0..50u32).await.expect("write failed");
        sink.finish().await.expect("finish failed");

        let mut source = AsyncStreamingDecoder::new(Cursor::new(encoded));
        let decoded: alloc::vec::Vec<u32> = source.read_all().await.expect("read failed");

        assert_eq!(decoded, (0..50).collect::<alloc::vec::Vec<u32>>());
        assert!(source.is_finished());
    }

    /// Items can be read one at a time, with `None` after the last one.
    #[tokio::test]
    async fn test_async_item_by_item() {
        let mut encoded = alloc::vec::Vec::new();
        let mut sink = AsyncStreamingEncoder::new(Cursor::new(&mut encoded));
        for value in &[1u32, 2, 3] {
            sink.write_item(value).await.expect("write failed");
        }
        sink.finish().await.expect("finish failed");

        let mut source = AsyncStreamingDecoder::new(Cursor::new(encoded));
        for expected in &[Some(1u32), Some(2), Some(3), None] {
            let got = source.read_item::<u32>().await.expect("read failed");
            assert_eq!(got, *expected);
        }
    }

    /// Cancelling the parent token makes later writes through a child-guarded encoder fail.
    #[tokio::test]
    async fn test_cancellation() {
        let parent = CancellationToken::new();

        let mut encoded = alloc::vec::Vec::new();
        let mut sink = CancellableAsyncEncoder::new(Cursor::new(&mut encoded), parent.child());

        for value in &[1u32, 2] {
            sink.write_item(value).await.expect("write failed");
        }

        parent.cancel();

        let outcome = sink.write_item(&3u32).await;
        assert!(outcome.is_err());
    }

    /// A child token observes its parent's cancellation.
    #[test]
    fn test_cancellation_token() {
        let parent = CancellationToken::new();
        assert!(!parent.is_cancelled());

        let derived = parent.child();
        parent.cancel();

        assert!(parent.is_cancelled());
        assert!(derived.is_cancelled());
    }

    /// The decoder counts every item and at least one chunk.
    #[tokio::test]
    async fn test_async_progress_tracking() {
        let mut encoded = alloc::vec::Vec::new();
        let mut sink = AsyncStreamingEncoder::new(Cursor::new(&mut encoded));
        sink.set_estimated_total(10);
        sink.write_all(0..10u32).await.expect("write failed");
        sink.finish().await.expect("finish failed");

        let mut source = AsyncStreamingDecoder::new(Cursor::new(encoded));
        let _: alloc::vec::Vec<u32> = source.read_all().await.expect("read failed");

        let stats = source.progress();
        assert_eq!(stats.items_processed, 10);
        assert!(stats.chunks_processed >= 1);
    }

    /// A small chunk size splits a large stream across several chunks without losing items.
    #[tokio::test]
    async fn test_async_large_data() {
        let settings = StreamingConfig::new().with_chunk_size(1024);

        let mut encoded = alloc::vec::Vec::new();
        let mut sink = AsyncStreamingEncoder::with_config(Cursor::new(&mut encoded), settings);
        sink.write_all(0..1000u32).await.expect("write failed");
        sink.finish().await.expect("finish failed");

        let mut source = AsyncStreamingDecoder::new(Cursor::new(encoded));
        let decoded: alloc::vec::Vec<u32> = source.read_all().await.expect("read failed");

        assert_eq!(decoded, (0..1000).collect::<alloc::vec::Vec<u32>>());
        assert!(source.progress().chunks_processed > 1);
    }
}