1use crate::error::{Error, Result};
43use bytes::{Bytes, BytesMut};
44use std::io::Cursor;
45use std::pin::Pin;
46use std::task::{Context, Poll};
47use tokio::io::{AsyncRead, ReadBuf};
48
49pub use crate::compression::CompressionAlgorithm;
51
/// Default number of bytes requested from the inner reader per chunk (64 KiB).
const DEFAULT_BUFFER_SIZE: usize = 64 << 10;
54
55pub struct CompressingStream<R: AsyncRead + Unpin> {
84 reader: R,
85 algorithm: CompressionAlgorithm,
86 level: u8,
87 buffer: BytesMut,
88 compressed_buffer: Cursor<Vec<u8>>,
89 stats: StreamingStats,
90 finished: bool,
91 buffer_size: usize,
92}
93
94impl<R: AsyncRead + Unpin> CompressingStream<R> {
95 pub fn new(reader: R, algorithm: CompressionAlgorithm, level: u8) -> Result<Self> {
107 if level > 9 {
108 return Err(Error::InvalidInput(format!(
109 "compression level must be 0-9, got {}",
110 level
111 )));
112 }
113
114 Ok(Self {
115 reader,
116 algorithm,
117 level,
118 buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
119 compressed_buffer: Cursor::new(Vec::new()),
120 stats: StreamingStats::default(),
121 finished: false,
122 buffer_size: DEFAULT_BUFFER_SIZE,
123 })
124 }
125
126 pub fn with_buffer_size(
135 reader: R,
136 algorithm: CompressionAlgorithm,
137 level: u8,
138 buffer_size: usize,
139 ) -> Result<Self> {
140 if level > 9 {
141 return Err(Error::InvalidInput(format!(
142 "compression level must be 0-9, got {}",
143 level
144 )));
145 }
146
147 Ok(Self {
148 reader,
149 algorithm,
150 level,
151 buffer: BytesMut::with_capacity(buffer_size),
152 compressed_buffer: Cursor::new(Vec::new()),
153 stats: StreamingStats::default(),
154 finished: false,
155 buffer_size,
156 })
157 }
158
159 pub fn stats(&self) -> &StreamingStats {
161 &self.stats
162 }
163}
164
165impl<R: AsyncRead + Unpin> AsyncRead for CompressingStream<R> {
166 fn poll_read(
167 mut self: Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 buf: &mut ReadBuf<'_>,
170 ) -> Poll<std::io::Result<()>> {
171 let pos = self.compressed_buffer.position() as usize;
173 let available = self.compressed_buffer.get_ref().len() - pos;
174
175 if available > 0 {
176 let to_copy = available.min(buf.remaining());
177 buf.put_slice(&self.compressed_buffer.get_ref()[pos..pos + to_copy]);
178 self.compressed_buffer.set_position((pos + to_copy) as u64);
179 return Poll::Ready(Ok(()));
180 }
181
182 if self.finished {
184 return Poll::Ready(Ok(()));
185 }
186
187 let this = &mut *self;
190
191 this.buffer.resize(this.buffer_size, 0);
192 let mut read_buf = ReadBuf::new(&mut this.buffer[..]);
193
194 match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
195 Poll::Ready(Ok(())) => {
196 let n = read_buf.filled().len();
197
198 if n == 0 {
199 this.finished = true;
200 return Poll::Ready(Ok(()));
201 }
202
203 this.stats.bytes_read += n as u64;
204
205 let data = Bytes::from(this.buffer[..n].to_vec());
207 let compressed =
208 match crate::compression::compress(&data, this.algorithm, this.level) {
209 Ok(c) => c,
210 Err(e) => return Poll::Ready(Err(std::io::Error::other(e.to_string()))),
211 };
212
213 this.stats.bytes_written += compressed.len() as u64;
214 this.compressed_buffer = Cursor::new(compressed.to_vec());
215
216 let pos = this.compressed_buffer.position() as usize;
218 let available = this.compressed_buffer.get_ref().len() - pos;
219
220 if available > 0 {
221 let to_copy = available.min(buf.remaining());
222 buf.put_slice(&this.compressed_buffer.get_ref()[pos..pos + to_copy]);
223 this.compressed_buffer.set_position((pos + to_copy) as u64);
224 }
225
226 Poll::Ready(Ok(()))
227 }
228 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
229 Poll::Pending => Poll::Pending,
230 }
231 }
232}
233
234pub struct DecompressingStream<R: AsyncRead + Unpin> {
269 reader: R,
270 algorithm: CompressionAlgorithm,
271 buffer: BytesMut,
272 decompressed_buffer: Cursor<Vec<u8>>,
273 stats: StreamingStats,
274 finished: bool,
275 buffer_size: usize,
276}
277
278impl<R: AsyncRead + Unpin> DecompressingStream<R> {
279 pub fn new(reader: R, algorithm: CompressionAlgorithm) -> Result<Self> {
286 Ok(Self {
287 reader,
288 algorithm,
289 buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
290 decompressed_buffer: Cursor::new(Vec::new()),
291 stats: StreamingStats::default(),
292 finished: false,
293 buffer_size: DEFAULT_BUFFER_SIZE,
294 })
295 }
296
297 pub fn with_buffer_size(
305 reader: R,
306 algorithm: CompressionAlgorithm,
307 buffer_size: usize,
308 ) -> Result<Self> {
309 Ok(Self {
310 reader,
311 algorithm,
312 buffer: BytesMut::with_capacity(buffer_size),
313 decompressed_buffer: Cursor::new(Vec::new()),
314 stats: StreamingStats::default(),
315 finished: false,
316 buffer_size,
317 })
318 }
319
320 pub fn stats(&self) -> &StreamingStats {
322 &self.stats
323 }
324}
325
326impl<R: AsyncRead + Unpin> AsyncRead for DecompressingStream<R> {
327 fn poll_read(
328 mut self: Pin<&mut Self>,
329 cx: &mut Context<'_>,
330 buf: &mut ReadBuf<'_>,
331 ) -> Poll<std::io::Result<()>> {
332 let pos = self.decompressed_buffer.position() as usize;
334 let available = self.decompressed_buffer.get_ref().len() - pos;
335
336 if available > 0 {
337 let to_copy = available.min(buf.remaining());
338 buf.put_slice(&self.decompressed_buffer.get_ref()[pos..pos + to_copy]);
339 self.decompressed_buffer
340 .set_position((pos + to_copy) as u64);
341 return Poll::Ready(Ok(()));
342 }
343
344 if self.finished {
346 return Poll::Ready(Ok(()));
347 }
348
349 let this = &mut *self;
352
353 this.buffer.resize(this.buffer_size, 0);
354 let mut read_buf = ReadBuf::new(&mut this.buffer[..]);
355
356 match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
357 Poll::Ready(Ok(())) => {
358 let n = read_buf.filled().len();
359
360 if n == 0 {
361 this.finished = true;
362 return Poll::Ready(Ok(()));
363 }
364
365 this.stats.bytes_read += n as u64;
366
367 let data = Bytes::from(this.buffer[..n].to_vec());
369 let decompressed = match crate::compression::decompress(&data, this.algorithm) {
370 Ok(d) => d,
371 Err(e) => return Poll::Ready(Err(std::io::Error::other(e.to_string()))),
372 };
373
374 this.stats.bytes_written += decompressed.len() as u64;
375 this.decompressed_buffer = Cursor::new(decompressed.to_vec());
376
377 let pos = this.decompressed_buffer.position() as usize;
379 let available = this.decompressed_buffer.get_ref().len() - pos;
380
381 if available > 0 {
382 let to_copy = available.min(buf.remaining());
383 buf.put_slice(&this.decompressed_buffer.get_ref()[pos..pos + to_copy]);
384 this.decompressed_buffer
385 .set_position((pos + to_copy) as u64);
386 }
387
388 Poll::Ready(Ok(()))
389 }
390 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
391 Poll::Pending => Poll::Pending,
392 }
393 }
394}
395
/// Byte counters collected by the compressing and decompressing stream adapters.
///
/// `bytes_read` counts bytes taken from the wrapped reader. `bytes_written`
/// counts bytes produced by compression or decompression.
#[derive(Default, Clone, Debug)]
pub struct StreamingStats {
    /// Bytes consumed from the inner reader.
    pub bytes_read: u64,
    /// Bytes produced by the compression or decompression step.
    pub bytes_written: u64,
}
404
405impl StreamingStats {
406 pub fn compression_ratio(&self) -> f64 {
410 if self.bytes_read == 0 {
411 1.0
412 } else {
413 self.bytes_written as f64 / self.bytes_read as f64
414 }
415 }
416
417 pub fn bytes_saved(&self) -> i64 {
419 self.bytes_read as i64 - self.bytes_written as i64
420 }
421
422 pub fn savings_percent(&self) -> f64 {
424 if self.bytes_read == 0 {
425 0.0
426 } else {
427 (self.bytes_saved() as f64 / self.bytes_read as f64) * 100.0
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn test_compressing_stream_zstd() {
        let data = b"Hello, world! ".repeat(100);
        let cursor = std::io::Cursor::new(data.clone());

        let mut stream = CompressingStream::new(cursor, CompressionAlgorithm::Zstd, 3).unwrap();

        let mut compressed = Vec::new();
        stream.read_to_end(&mut compressed).await.unwrap();

        assert!(compressed.len() < data.len());
        let stats = stream.stats();
        assert_eq!(stats.bytes_read, data.len() as u64);
        assert!(stats.compression_ratio() < 1.0);
    }

    #[tokio::test]
    async fn test_compressing_stream_lz4() {
        let data = b"Test data for LZ4 compression! ".repeat(100);
        let cursor = std::io::Cursor::new(data.clone());

        let mut stream = CompressingStream::new(cursor, CompressionAlgorithm::Lz4, 5).unwrap();

        let mut compressed = Vec::new();
        stream.read_to_end(&mut compressed).await.unwrap();

        assert!(compressed.len() < data.len());
    }

    #[tokio::test]
    async fn test_compressing_stream_none() {
        let data = b"No compression applied".repeat(10);
        let cursor = std::io::Cursor::new(data.clone());

        let mut stream = CompressingStream::new(cursor, CompressionAlgorithm::None, 0).unwrap();

        let mut output = Vec::new();
        stream.read_to_end(&mut output).await.unwrap();

        assert_eq!(output, data);
        let stats = stream.stats();
        assert_eq!(stats.compression_ratio(), 1.0);
    }

    #[tokio::test]
    async fn test_decompressing_stream_roundtrip() {
        // The payload is small enough that the compressed bytes arrive in a single
        // read of the decompressor's inner reader.
        let original = b"Roundtrip test data! ".repeat(100);

        let cursor = std::io::Cursor::new(original.clone());
        let mut compressor = CompressingStream::new(cursor, CompressionAlgorithm::Zstd, 5).unwrap();

        let mut compressed = Vec::new();
        compressor.read_to_end(&mut compressed).await.unwrap();

        let cursor = std::io::Cursor::new(compressed);
        let mut decompressor =
            DecompressingStream::new(cursor, CompressionAlgorithm::Zstd).unwrap();

        let mut decompressed = Vec::new();
        decompressor.read_to_end(&mut decompressed).await.unwrap();

        assert_eq!(original, decompressed.as_slice());
    }

    #[tokio::test]
    async fn test_streaming_stats() {
        let data = vec![0u8; 10000];
        let cursor = std::io::Cursor::new(data.clone());

        let mut stream = CompressingStream::new(cursor, CompressionAlgorithm::Zstd, 6).unwrap();

        let mut compressed = Vec::new();
        stream.read_to_end(&mut compressed).await.unwrap();

        let stats = stream.stats();
        assert_eq!(stats.bytes_read, 10000);
        assert!(stats.bytes_written < 10000);
        assert!(stats.compression_ratio() < 1.0);
        assert!(stats.bytes_saved() > 0);
        assert!(stats.savings_percent() > 0.0);
    }

    #[tokio::test]
    async fn test_custom_buffer_size() {
        let data = b"Custom buffer size test".repeat(50);
        let cursor = std::io::Cursor::new(data.clone());

        let mut stream =
            CompressingStream::with_buffer_size(cursor, CompressionAlgorithm::Lz4, 3, 1024)
                .unwrap();

        let mut compressed = Vec::new();
        stream.read_to_end(&mut compressed).await.unwrap();

        assert!(compressed.len() < data.len());
    }

    #[tokio::test]
    async fn test_invalid_compression_level() {
        let data = b"test";
        let cursor = std::io::Cursor::new(data);

        let result = CompressingStream::new(cursor, CompressionAlgorithm::Zstd, 10);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let data: Vec<u8> = vec![];
        let cursor = std::io::Cursor::new(data);

        let mut stream = CompressingStream::new(cursor, CompressionAlgorithm::Zstd, 3).unwrap();

        let mut compressed = Vec::new();
        stream.read_to_end(&mut compressed).await.unwrap();

        let stats = stream.stats();
        assert_eq!(stats.bytes_read, 0);
        assert_eq!(stats.bytes_written, 0);
    }

    #[tokio::test]
    async fn test_large_data_streaming() {
        let data = vec![42u8; 1024 * 1024];
        let cursor = std::io::Cursor::new(data.clone());

        let mut stream = CompressingStream::new(cursor, CompressionAlgorithm::Zstd, 9).unwrap();

        let mut compressed = Vec::new();
        stream.read_to_end(&mut compressed).await.unwrap();

        assert!(compressed.len() < data.len() / 10);

        let stats = stream.stats();
        assert_eq!(stats.bytes_read, 1024 * 1024);
        assert!(stats.compression_ratio() < 0.1);
    }

    #[tokio::test]
    async fn test_decompression_stats() {
        let original = vec![1u8; 5000];

        let cursor = std::io::Cursor::new(original.clone());
        let mut compressor = CompressingStream::new(cursor, CompressionAlgorithm::Lz4, 5).unwrap();

        let mut compressed = Vec::new();
        compressor.read_to_end(&mut compressed).await.unwrap();

        let cursor = std::io::Cursor::new(compressed.clone());
        let mut decompressor = DecompressingStream::new(cursor, CompressionAlgorithm::Lz4).unwrap();

        let mut decompressed = Vec::new();
        decompressor.read_to_end(&mut decompressed).await.unwrap();

        let stats = decompressor.stats();
        assert_eq!(stats.bytes_read, compressed.len() as u64);
        assert_eq!(stats.bytes_written, original.len() as u64);
        assert!(stats.compression_ratio() > 1.0);
    }

    #[test]
    fn test_stats_with_no_input() {
        let stats = StreamingStats::default();
        assert_eq!(stats.compression_ratio(), 1.0);
        assert_eq!(stats.bytes_saved(), 0);
        assert_eq!(stats.savings_percent(), 0.0);
    }

    #[test]
    fn test_stats_when_output_grows() {
        // Decompression (or incompressible input) yields more output than input,
        // so savings are negative.
        let stats = StreamingStats {
            bytes_read: 100,
            bytes_written: 150,
        };
        assert_eq!(stats.compression_ratio(), 1.5);
        assert_eq!(stats.bytes_saved(), -50);
        assert_eq!(stats.savings_percent(), -50.0);
    }
}