libcoreinst/io/
compress.rs

1// Copyright 2019 CoreOS, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{Context, Result};
16use flate2::bufread::GzDecoder;
17use std::io::{self, ErrorKind, Read};
18
19use crate::io::{is_zstd_magic, PeekReader, XzStreamDecoder, ZstdStreamDecoder};
20
21enum CompressDecoder<'a, R: Read> {
22    Uncompressed(PeekReader<R>),
23    Gzip(GzDecoder<PeekReader<R>>),
24    Xz(XzStreamDecoder<PeekReader<R>>),
25    Zstd(ZstdStreamDecoder<'a, R>),
26}
27
28pub struct DecompressReader<'a, R: Read> {
29    decoder: CompressDecoder<'a, R>,
30    allow_trailing: bool,
31}
32
33/// Format-sniffing decompressor
34impl<R: Read> DecompressReader<'_, R> {
35    pub fn new(source: PeekReader<R>) -> Result<Self> {
36        Self::new_full(source, false)
37    }
38
39    pub fn for_concatenated(source: PeekReader<R>) -> Result<Self> {
40        Self::new_full(source, true)
41    }
42
43    fn new_full(mut source: PeekReader<R>, allow_trailing: bool) -> Result<Self> {
44        use CompressDecoder::*;
45        let sniff = source.peek(6).context("sniffing input")?;
46        let decoder = if sniff.len() >= 2 && &sniff[0..2] == b"\x1f\x8b" {
47            Gzip(GzDecoder::new(source))
48        } else if sniff.len() >= 6 && &sniff[0..6] == b"\xfd7zXZ\x00" {
49            Xz(XzStreamDecoder::new(source))
50        } else if sniff.len() > 4 && is_zstd_magic(sniff[0..4].try_into().unwrap()) {
51            Zstd(ZstdStreamDecoder::new(source)?)
52        } else {
53            Uncompressed(source)
54        };
55        Ok(Self {
56            decoder,
57            allow_trailing,
58        })
59    }
60
61    pub fn into_inner(self) -> PeekReader<R> {
62        use CompressDecoder::*;
63        match self.decoder {
64            Uncompressed(d) => d,
65            Gzip(d) => d.into_inner(),
66            Xz(d) => d.into_inner(),
67            Zstd(d) => d.into_inner(),
68        }
69    }
70
71    pub fn get_mut(&mut self) -> &mut PeekReader<R> {
72        use CompressDecoder::*;
73        match &mut self.decoder {
74            Uncompressed(d) => d,
75            Gzip(d) => d.get_mut(),
76            Xz(d) => d.get_mut(),
77            Zstd(d) => d.get_mut(),
78        }
79    }
80
81    pub fn compressed(&self) -> bool {
82        use CompressDecoder::*;
83        match &self.decoder {
84            Uncompressed(_) => false,
85            Gzip(_) => true,
86            Xz(_) => true,
87            Zstd(_) => true,
88        }
89    }
90}
91
92impl<R: Read> Read for DecompressReader<'_, R> {
93    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
94        use CompressDecoder::*;
95        let count = match &mut self.decoder {
96            Uncompressed(d) => d.read(buf)?,
97            Gzip(d) => d.read(buf)?,
98            Xz(d) => d.read(buf)?,
99            Zstd(d) => d.read(buf)?,
100        };
101        if count == 0 && !buf.is_empty() && self.compressed() && !self.allow_trailing {
102            // Decompressors stop reading as soon as they encounter the
103            // compression trailer, so they don't notice trailing data,
104            // which indicates something wrong with the input.  Look for
105            // one more byte, and fail if there is one.
106            if !self.get_mut().peek(1)?.is_empty() {
107                return Err(io::Error::new(
108                    ErrorKind::InvalidData,
109                    "found trailing data after compressed stream",
110                ));
111            }
112        }
113        Ok(count)
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// DecompressReader must reject a compressed stream that has been
    /// truncated or has extra data appended, for every supported format.
    #[test]
    fn test_decompress_reader_trailing_data() {
        let fixtures: [&[u8]; 3] = [
            &include_bytes!("../../fixtures/verify/1M.gz")[..],
            &include_bytes!("../../fixtures/verify/1M.xz")[..],
            &include_bytes!("../../fixtures/verify/1M.zst")[..],
        ];
        for fixture in fixtures {
            test_decompress_reader_trailing_data_one(fixture);
        }
    }

    /// Run the truncation / trailing-garbage checks against one fixture.
    fn test_decompress_reader_trailing_data_one(input: &[u8]) {
        /// Decode `data` in strict mode into `sink` using a small peek buffer.
        fn strict_decode(data: &[u8], sink: &mut Vec<u8>) -> io::Result<usize> {
            DecompressReader::new(PeekReader::with_capacity(32, data))
                .unwrap()
                .read_to_end(sink)
        }

        let mut input = input.to_vec();
        let mut output = Vec::new();

        // the pristine stream decodes cleanly
        strict_decode(&input, &mut output).unwrap();

        // a stream missing its final byte must be rejected
        strict_decode(&input[..input.len() - 1], &mut output).unwrap_err();

        // a stream followed by garbage must be rejected
        input.push(0);
        strict_decode(&input, &mut output).unwrap_err();

        // concatenated mode accepts the garbage and leaves it unread
        let mut reader =
            DecompressReader::for_concatenated(PeekReader::with_capacity(32, &*input)).unwrap();
        reader.read_to_end(&mut output).unwrap();
        let mut remainder = Vec::new();
        reader.into_inner().read_to_end(&mut remainder).unwrap();
        assert_eq!(remainder, [0]);
    }
}