1use anyhow::{Context, Result};
27use bytes::{Buf, BytesMut};
28use std::io::{self, BufRead, Error, ErrorKind, Read};
29use zstd::stream::raw::{Decoder, Operation};
30use zstd::zstd_safe::{MAGICNUMBER, MAGIC_SKIPPABLE_MASK, MAGIC_SKIPPABLE_START};
31
32use crate::io::PeekReader;
33
34pub struct ZstdStreamDecoder<'a, R: Read> {
35 source: PeekReader<R>,
36 buf: BytesMut,
37 decoder: Decoder<'a>,
38 start_of_frame: bool,
39}
40
41impl<R: Read> ZstdStreamDecoder<'_, R> {
42 pub fn new(source: PeekReader<R>) -> Result<Self> {
43 Ok(Self {
44 source,
45 buf: BytesMut::new(),
46 decoder: Decoder::new().context("creating zstd decoder")?,
47 start_of_frame: true,
48 })
49 }
50
51 pub fn get_mut(&mut self) -> &mut PeekReader<R> {
52 &mut self.source
53 }
54
55 pub fn into_inner(self) -> PeekReader<R> {
56 self.source
57 }
58}
59
60impl<R: Read> Read for ZstdStreamDecoder<'_, R> {
61 fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
62 if out.is_empty() {
63 return Ok(0);
64 }
65 loop {
66 if !self.buf.is_empty() {
67 let count = self.buf.len().min(out.len());
68 self.buf.copy_to_slice(&mut out[..count]);
69 return Ok(count);
70 }
71 if self.start_of_frame {
72 let peek = self.source.peek(4)?;
73 if peek.len() < 4 || !is_zstd_magic(peek[0..4].try_into().unwrap()) {
74 return Ok(0);
76 }
77 self.start_of_frame = false;
78 }
79 let in_ = self.source.fill_buf()?;
80 if in_.is_empty() {
81 return Err(Error::new(
82 ErrorKind::UnexpectedEof,
83 "premature EOF reading zstd frame",
84 ));
85 }
86 self.buf.resize(16384, 0);
90 let status = self.decoder.run_on_buffers(in_, &mut self.buf)?;
91 self.source.consume(status.bytes_read);
92 self.buf.truncate(status.bytes_written);
93 if status.remaining == 0 {
94 self.start_of_frame = true;
95 }
96 }
97 }
98}
99
100pub fn is_zstd_magic(buf: [u8; 4]) -> bool {
101 let val = u32::from_le_bytes(buf);
102 val == MAGICNUMBER || val & MAGIC_SKIPPABLE_MASK == MAGIC_SKIPPABLE_START
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a fixture one byte at a time through a one-byte `PeekReader`,
    /// then checks that the trailing non-zstd bytes are still readable from
    /// the underlying reader.
    #[test]
    fn small_decode() {
        let frame: &[u8] = include_bytes!("../../fixtures/verify/1M.zst");
        let expected = zstd::stream::decode_all(frame).unwrap();

        // Compressed fixture followed by trailing bytes that are not zstd data.
        let mut input = frame.to_vec();
        input.extend_from_slice(b"abcdefg");

        let mut decoder =
            ZstdStreamDecoder::new(PeekReader::with_capacity(1, &*input)).unwrap();
        let mut decoded = Vec::new();
        let mut byte = [0u8];
        loop {
            let n = decoder.read(&mut byte).unwrap();
            if n == 0 {
                break;
            }
            if n != 1 {
                unreachable!();
            }
            decoded.push(byte[0]);
        }
        assert_eq!(&decoded, &expected);

        let mut trailing = Vec::new();
        decoder.into_inner().read_to_end(&mut trailing).unwrap();
        assert_eq!(&trailing, b"abcdefg");
    }
}