libcoreinst/io/
zstd.rs

1// Copyright 2022 Red Hat
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Implementation of an API similar to zstd::stream::read::Decoder using
16// zstd::stream::raw::Decoder.  We need this because read::Decoder returns
17// io::ErrorKind::Other if there's trailing data after a zstd stream, which
18// can't be disambiguated from an actual error.  By using the low-level API,
19// we can check zstd::stream::raw::Status.remaining to see whether the
20// decoder thinks it's at the end of a frame, check the upcoming bytes for
21// the magic number of another frame, and decide whether we're done.  The
22// raw decoder always stops at frame boundaries, so this is reliable.  If
23// done, return Ok(0) and allow the caller to decide what it wants to do
24// about trailing data.
25
26use anyhow::{Context, Result};
27use bytes::{Buf, BytesMut};
28use std::io::{self, BufRead, Error, ErrorKind, Read};
29use zstd::stream::raw::{Decoder, Operation};
30use zstd::zstd_safe::{MAGICNUMBER, MAGIC_SKIPPABLE_MASK, MAGIC_SKIPPABLE_START};
31
32use crate::io::PeekReader;
33
/// A `Read` adapter that decompresses consecutive zstd frames from `source`.
///
/// At each frame boundary the upcoming bytes are checked for a zstd (or
/// skippable-frame) magic number; if they are absent, `read` returns `Ok(0)`
/// and the trailing bytes stay in `source`, reachable through `get_mut()` or
/// `into_inner()`, so the caller can decide what to do with them.
pub struct ZstdStreamDecoder<'a, R: Read> {
    /// Compressed input; peeked at frame boundaries to look for another frame.
    source: PeekReader<R>,
    /// Decompressed bytes produced by `decoder` but not yet returned to the caller.
    buf: BytesMut,
    /// Low-level zstd streaming decoder.
    decoder: Decoder<'a>,
    /// True when the next input byte is expected to begin a new frame:
    /// initially, and after the decoder reports a finished frame.
    start_of_frame: bool,
}
40
41impl<R: Read> ZstdStreamDecoder<'_, R> {
42    pub fn new(source: PeekReader<R>) -> Result<Self> {
43        Ok(Self {
44            source,
45            buf: BytesMut::new(),
46            decoder: Decoder::new().context("creating zstd decoder")?,
47            start_of_frame: true,
48        })
49    }
50
51    pub fn get_mut(&mut self) -> &mut PeekReader<R> {
52        &mut self.source
53    }
54
55    pub fn into_inner(self) -> PeekReader<R> {
56        self.source
57    }
58}
59
60impl<R: Read> Read for ZstdStreamDecoder<'_, R> {
61    fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
62        if out.is_empty() {
63            return Ok(0);
64        }
65        loop {
66            if !self.buf.is_empty() {
67                let count = self.buf.len().min(out.len());
68                self.buf.copy_to_slice(&mut out[..count]);
69                return Ok(count);
70            }
71            if self.start_of_frame {
72                let peek = self.source.peek(4)?;
73                if peek.len() < 4 || !is_zstd_magic(peek[0..4].try_into().unwrap()) {
74                    // end of compressed data
75                    return Ok(0);
76                }
77                self.start_of_frame = false;
78            }
79            let in_ = self.source.fill_buf()?;
80            if in_.is_empty() {
81                return Err(Error::new(
82                    ErrorKind::UnexpectedEof,
83                    "premature EOF reading zstd frame",
84                ));
85            }
86            // unfortunately we have to initialize to 0 for safety
87            // BUFFER_SIZE is very large; use a smaller buffer to avoid
88            // unneeded reinitialization
89            self.buf.resize(16384, 0);
90            let status = self.decoder.run_on_buffers(in_, &mut self.buf)?;
91            self.source.consume(status.bytes_read);
92            self.buf.truncate(status.bytes_written);
93            if status.remaining == 0 {
94                self.start_of_frame = true;
95            }
96        }
97    }
98}
99
100pub fn is_zstd_magic(buf: [u8; 4]) -> bool {
101    let val = u32::from_le_bytes(buf);
102    val == MAGICNUMBER || val & MAGIC_SKIPPABLE_MASK == MAGIC_SKIPPABLE_START
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Reading one byte at a time through a 1-byte-capacity `PeekReader` must
    /// reproduce the reference decompression and leave the appended non-zstd
    /// bytes readable from the inner reader.
    #[test]
    fn small_decode() {
        let mut input = include_bytes!("../../fixtures/verify/1M.zst").to_vec();
        let expected = zstd::stream::decode_all(input.as_slice()).unwrap();
        input.extend_from_slice(b"abcdefg");

        let reader = PeekReader::with_capacity(1, input.as_slice());
        let mut decoder = ZstdStreamDecoder::new(reader).unwrap();
        let mut actual = Vec::new();
        let mut byte = [0u8; 1];
        loop {
            match decoder.read(&mut byte).unwrap() {
                0 => break,
                1 => actual.extend_from_slice(&byte),
                _ => unreachable!(),
            }
        }
        assert_eq!(&actual, &expected);

        let mut trailing = Vec::new();
        decoder.into_inner().read_to_end(&mut trailing).unwrap();
        assert_eq!(&trailing, b"abcdefg");
    }
}