libcoreinst/io/
compress.rs1use anyhow::{Context, Result};
16use flate2::bufread::GzDecoder;
17use std::io::{self, ErrorKind, Read};
18
19use crate::io::{is_zstd_magic, PeekReader, XzStreamDecoder, ZstdStreamDecoder};
20
21enum CompressDecoder<'a, R: Read> {
22 Uncompressed(PeekReader<R>),
23 Gzip(GzDecoder<PeekReader<R>>),
24 Xz(XzStreamDecoder<PeekReader<R>>),
25 Zstd(ZstdStreamDecoder<'a, R>),
26}
27
28pub struct DecompressReader<'a, R: Read> {
29 decoder: CompressDecoder<'a, R>,
30 allow_trailing: bool,
31}
32
33impl<R: Read> DecompressReader<'_, R> {
35 pub fn new(source: PeekReader<R>) -> Result<Self> {
36 Self::new_full(source, false)
37 }
38
39 pub fn for_concatenated(source: PeekReader<R>) -> Result<Self> {
40 Self::new_full(source, true)
41 }
42
43 fn new_full(mut source: PeekReader<R>, allow_trailing: bool) -> Result<Self> {
44 use CompressDecoder::*;
45 let sniff = source.peek(6).context("sniffing input")?;
46 let decoder = if sniff.len() >= 2 && &sniff[0..2] == b"\x1f\x8b" {
47 Gzip(GzDecoder::new(source))
48 } else if sniff.len() >= 6 && &sniff[0..6] == b"\xfd7zXZ\x00" {
49 Xz(XzStreamDecoder::new(source))
50 } else if sniff.len() > 4 && is_zstd_magic(sniff[0..4].try_into().unwrap()) {
51 Zstd(ZstdStreamDecoder::new(source)?)
52 } else {
53 Uncompressed(source)
54 };
55 Ok(Self {
56 decoder,
57 allow_trailing,
58 })
59 }
60
61 pub fn into_inner(self) -> PeekReader<R> {
62 use CompressDecoder::*;
63 match self.decoder {
64 Uncompressed(d) => d,
65 Gzip(d) => d.into_inner(),
66 Xz(d) => d.into_inner(),
67 Zstd(d) => d.into_inner(),
68 }
69 }
70
71 pub fn get_mut(&mut self) -> &mut PeekReader<R> {
72 use CompressDecoder::*;
73 match &mut self.decoder {
74 Uncompressed(d) => d,
75 Gzip(d) => d.get_mut(),
76 Xz(d) => d.get_mut(),
77 Zstd(d) => d.get_mut(),
78 }
79 }
80
81 pub fn compressed(&self) -> bool {
82 use CompressDecoder::*;
83 match &self.decoder {
84 Uncompressed(_) => false,
85 Gzip(_) => true,
86 Xz(_) => true,
87 Zstd(_) => true,
88 }
89 }
90}
91
92impl<R: Read> Read for DecompressReader<'_, R> {
93 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
94 use CompressDecoder::*;
95 let count = match &mut self.decoder {
96 Uncompressed(d) => d.read(buf)?,
97 Gzip(d) => d.read(buf)?,
98 Xz(d) => d.read(buf)?,
99 Zstd(d) => d.read(buf)?,
100 };
101 if count == 0 && !buf.is_empty() && self.compressed() && !self.allow_trailing {
102 if !self.get_mut().peek(1)?.is_empty() {
107 return Err(io::Error::new(
108 ErrorKind::InvalidData,
109 "found trailing data after compressed stream",
110 ));
111 }
112 }
113 Ok(count)
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// For every supported format: truncated input and unexpected trailing
    /// data must be rejected, while `for_concatenated` must decode the stream
    /// fully and leave the trailing bytes in the underlying reader.
    #[test]
    fn test_decompress_reader_trailing_data() {
        test_decompress_reader_trailing_data_one(
            &include_bytes!("../../fixtures/verify/1M.gz")[..],
        );
        test_decompress_reader_trailing_data_one(
            &include_bytes!("../../fixtures/verify/1M.xz")[..],
        );
        test_decompress_reader_trailing_data_one(
            &include_bytes!("../../fixtures/verify/1M.zst")[..],
        );
    }

    fn test_decompress_reader_trailing_data_one(input: &[u8]) {
        let mut input = input.to_vec();

        // Reference decode of the intact stream.
        let mut expected = Vec::new();
        DecompressReader::new(PeekReader::with_capacity(32, &*input))
            .unwrap()
            .read_to_end(&mut expected)
            .unwrap();
        assert!(!expected.is_empty(), "fixture decoded to nothing");

        // Truncated stream must fail.
        let mut scratch = Vec::new();
        DecompressReader::new(PeekReader::with_capacity(32, &input[0..input.len() - 1]))
            .unwrap()
            .read_to_end(&mut scratch)
            .unwrap_err();

        // A trailing byte must fail in strict mode.
        input.push(0);
        scratch.clear();
        DecompressReader::new(PeekReader::with_capacity(32, &*input))
            .unwrap()
            .read_to_end(&mut scratch)
            .unwrap_err();

        // In concatenated mode the stream decodes identically and the
        // trailing byte stays in the underlying reader.
        let mut output = Vec::new();
        let mut reader =
            DecompressReader::for_concatenated(PeekReader::with_capacity(32, &*input)).unwrap();
        reader.read_to_end(&mut output).unwrap();
        // Avoid assert_eq! here: a mismatch would dump megabytes of bytes.
        assert!(
            output == expected,
            "concatenated-mode output differs from reference decode"
        );
        let mut remainder = Vec::new();
        reader.into_inner().read_to_end(&mut remainder).unwrap();
        assert_eq!(&remainder, &[0]);
    }
}