1use crate::encode::ZstdEncoder;
31use crate::frame::{ZstdDecoder, decompress_multi_frame};
32use std::io::{self, Read, Write};
33
/// Default flush threshold in bytes (128 KiB).
///
/// Both `ZstdStreamEncoder` constructors use this as the initial
/// `block_size`; once the pending buffer reaches that many bytes it is
/// compressed and written out.
const DEFAULT_BLOCK_SIZE: usize = 128 * 1024;
36
/// Buffering compressor that implements [`Write`] on top of another writer.
///
/// Bytes passed to `write` are accumulated in memory. Whenever the pending
/// buffer reaches `block_size`, on `flush`, and on `finish`, the buffered
/// bytes are compressed in one call and the result is written to the wrapped
/// writer.
///
/// `Debug` is derived so the encoder can be embedded in types that derive
/// `Debug` themselves; it is available whenever `W: Debug`.
#[derive(Debug)]
pub struct ZstdStreamEncoder<W: Write> {
    /// Destination for compressed bytes; `Some` until `finish` takes it out.
    inner: Option<W>,
    /// Uncompressed bytes accepted so far but not yet compressed.
    buffer: Vec<u8>,
    /// Compression level passed to the encoder on every flush.
    level: i32,
    /// Optional dictionary passed to the encoder on every flush.
    dict: Option<Vec<u8>>,
    /// Set once `finish` has flushed the final data.
    finished: bool,
    /// Number of buffered bytes at which `write` triggers compression.
    block_size: usize,
}
65
66impl<W: Write> ZstdStreamEncoder<W> {
67 pub fn new(writer: W, level: i32) -> Self {
74 Self {
75 inner: Some(writer),
76 buffer: Vec::new(),
77 level,
78 dict: None,
79 finished: false,
80 block_size: DEFAULT_BLOCK_SIZE,
81 }
82 }
83
84 pub fn with_dictionary(writer: W, level: i32, dict: Vec<u8>) -> Self {
89 Self {
90 inner: Some(writer),
91 buffer: Vec::new(),
92 level,
93 dict: Some(dict),
94 finished: false,
95 block_size: DEFAULT_BLOCK_SIZE,
96 }
97 }
98
99 pub fn with_block_size(mut self, block_size: usize) -> Self {
104 self.block_size = block_size.max(1);
105 self
106 }
107
108 pub fn finish(mut self) -> io::Result<W> {
118 if !self.finished {
119 self.flush_buffer_unconditional()?;
122 self.finished = true;
123 }
124 self.inner
126 .take()
127 .ok_or_else(|| io::Error::other("inner writer already taken"))
128 }
129
130 fn compress_and_write(&mut self, data: &[u8]) -> io::Result<()> {
132 let mut encoder = ZstdEncoder::new();
133 encoder.set_level(self.level);
134 if let Some(ref dict) = self.dict {
135 encoder.set_dictionary(dict);
136 }
137 let compressed = encoder
138 .compress(data)
139 .map_err(|e| io::Error::other(e.to_string()))?;
140 if let Some(ref mut w) = self.inner {
141 w.write_all(&compressed)?;
142 }
143 Ok(())
144 }
145
146 fn maybe_flush_block(&mut self) -> io::Result<()> {
148 if self.buffer.len() >= self.block_size {
149 let data = std::mem::take(&mut self.buffer);
150 self.compress_and_write(&data)?;
151 }
152 Ok(())
153 }
154
155 fn flush_buffer_unconditional(&mut self) -> io::Result<()> {
157 let data = std::mem::take(&mut self.buffer);
158 self.compress_and_write(&data)
159 }
160
161 pub fn buffered_bytes(&self) -> usize {
163 self.buffer.len()
164 }
165
166 pub fn is_finished(&self) -> bool {
168 self.finished
169 }
170}
171
172impl<W: Write> Write for ZstdStreamEncoder<W> {
173 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
176 if self.finished {
177 return Err(io::Error::other("encoder already finished"));
178 }
179 self.buffer.extend_from_slice(buf);
180 self.maybe_flush_block()?;
181 Ok(buf.len())
182 }
183
184 fn flush(&mut self) -> io::Result<()> {
186 if !self.buffer.is_empty() {
187 let data = std::mem::take(&mut self.buffer);
188 self.compress_and_write(&data)?;
189 }
190 if let Some(ref mut w) = self.inner {
191 w.flush()?;
192 }
193 Ok(())
194 }
195}
196
/// Decompressor that implements [`Read`] on top of another reader.
///
/// On the first read the whole input is pulled from the wrapped reader and
/// decoded into memory; subsequent reads are served from that decoded
/// buffer.
///
/// `Debug` is derived so the decoder can be embedded in types that derive
/// `Debug` themselves; it is available whenever `R: Debug`.
#[derive(Debug)]
pub struct ZstdStreamDecoder<R: Read> {
    /// Source of compressed bytes.
    inner: R,
    /// Decoded bytes produced by the one-shot decode.
    output_buffer: Vec<u8>,
    /// Index into `output_buffer` of the next byte to hand out.
    output_pos: usize,
    /// Set once the input has been read (and decoded, if non-empty).
    finished: bool,
    /// Optional dictionary used when decoding.
    dict: Option<Vec<u8>>,
}
218
219impl<R: Read> ZstdStreamDecoder<R> {
220 pub fn new(reader: R) -> Self {
222 Self {
223 inner: reader,
224 output_buffer: Vec::new(),
225 output_pos: 0,
226 finished: false,
227 dict: None,
228 }
229 }
230
231 pub fn with_dictionary(reader: R, dict: Vec<u8>) -> Self {
236 Self {
237 inner: reader,
238 output_buffer: Vec::new(),
239 output_pos: 0,
240 finished: false,
241 dict: if dict.is_empty() { None } else { Some(dict) },
242 }
243 }
244
245 fn fill_buffer(&mut self) -> io::Result<()> {
250 if self.finished || self.output_pos < self.output_buffer.len() {
251 return Ok(());
252 }
253
254 let mut compressed = Vec::new();
255 self.inner.read_to_end(&mut compressed)?;
256
257 if compressed.is_empty() {
258 self.finished = true;
259 return Ok(());
260 }
261
262 self.output_buffer = if self.dict.is_none() {
268 decompress_multi_frame(&compressed)
269 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
270 } else {
271 let mut decoder = ZstdDecoder::new();
272 if let Some(ref dict) = self.dict {
273 decoder.set_dictionary(dict);
274 }
275 decoder
276 .decode_frame(&compressed)
277 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
278 };
279 self.output_pos = 0;
280 self.finished = true;
281
282 Ok(())
283 }
284
285 pub fn decompressed_size(&self) -> usize {
288 self.output_buffer.len()
289 }
290
291 pub fn is_finished(&self) -> bool {
293 self.finished && self.output_pos >= self.output_buffer.len()
294 }
295}
296
297impl<R: Read> Read for ZstdStreamDecoder<R> {
298 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
303 self.fill_buffer()?;
304
305 let available = self.output_buffer.len() - self.output_pos;
306 if available == 0 {
307 return Ok(0);
308 }
309
310 let to_copy = buf.len().min(available);
311 buf[..to_copy]
312 .copy_from_slice(&self.output_buffer[self.output_pos..self.output_pos + to_copy]);
313 self.output_pos += to_copy;
314 Ok(to_copy)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `data` at level 1 with a single `write_all` call and
    /// returns the bytes collected by the in-memory writer.
    fn compress_once(data: &[u8]) -> Vec<u8> {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Decodes `compressed` to the end with a dictionary-less stream decoder.
    fn decompress_all(compressed: &[u8]) -> Vec<u8> {
        let mut decoder = ZstdStreamDecoder::new(compressed);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).unwrap();
        output
    }

    #[test]
    fn test_stream_encoder_basic() {
        let compressed = compress_once(b"Hello, Zstandard!");
        assert!(!compressed.is_empty());
    }

    #[test]
    fn test_stream_encoder_empty() {
        // Finishing without any write must still produce some output.
        let compressed = ZstdStreamEncoder::new(Vec::new(), 1).finish().unwrap();
        assert!(!compressed.is_empty());
    }

    #[test]
    fn test_stream_roundtrip() {
        let original = b"The quick brown fox jumps over the lazy dog.";

        let compressed = compress_once(original);
        let restored = decompress_all(&compressed);

        assert_eq!(restored, original.as_slice());
    }

    #[test]
    fn test_stream_roundtrip_multiple_writes() {
        let parts: &[&[u8]] = &[b"Hello, ", b"streaming ", b"Zstd!"];

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        for part in parts {
            encoder.write_all(part).unwrap();
        }
        let compressed = encoder.finish().unwrap();

        let restored = decompress_all(&compressed);
        assert_eq!(restored, b"Hello, streaming Zstd!");
    }

    #[test]
    fn test_stream_decoder_small_reads() {
        let original = b"ABCDEFGHIJ";
        let compressed = compress_once(original);

        // Pull the output through a buffer smaller than the payload.
        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut collected = Vec::new();
        let mut chunk = [0u8; 3];
        loop {
            match decoder.read(&mut chunk).unwrap() {
                0 => break,
                n => collected.extend_from_slice(&chunk[..n]),
            }
        }

        assert_eq!(collected, original.as_slice());
    }

    #[test]
    fn test_stream_decoder_empty_input() {
        let empty: &[u8] = &[];
        let mut decoder = ZstdStreamDecoder::new(empty);
        let mut scratch = [0u8; 16];
        let read = decoder.read(&mut scratch).unwrap();
        assert_eq!(read, 0);
    }

    #[test]
    fn test_stream_encoder_with_dictionary() {
        let dict = b"common pattern data".to_vec();
        let mut encoder = ZstdStreamEncoder::with_dictionary(Vec::new(), 1, dict);
        encoder.write_all(b"test data").unwrap();
        let compressed = encoder.finish().unwrap();

        // Decoded with a decoder that has no dictionary configured.
        let restored = decompress_all(&compressed);
        assert_eq!(restored, b"test data");
    }

    #[test]
    fn test_stream_encoder_buffered_bytes() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert_eq!(encoder.buffered_bytes(), 0);

        encoder.write_all(b"12345").unwrap();
        assert_eq!(encoder.buffered_bytes(), 5);

        encoder.write_all(b"67890").unwrap();
        assert_eq!(encoder.buffered_bytes(), 10);
    }

    #[test]
    fn test_stream_encoder_is_finished() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert!(!encoder.is_finished());

        encoder.write_all(b"data").unwrap();
        assert!(!encoder.is_finished());
    }

    #[test]
    fn test_stream_decoder_is_finished() {
        let compressed = compress_once(b"short");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        assert!(!decoder.is_finished());

        let mut sink = Vec::new();
        decoder.read_to_end(&mut sink).unwrap();
        assert!(decoder.is_finished());
    }

    #[test]
    fn test_stream_roundtrip_large_data() {
        let original: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();

        let compressed = compress_once(&original);
        let restored = decompress_all(&compressed);

        assert_eq!(restored, original);
    }
}