Skip to main content

walker_common/compression/
detecting.rs

1use bytes::Bytes;
2use std::collections::HashSet;
3
4#[derive(Debug, thiserror::Error)]
5pub enum Error<'a> {
6    #[error("unsupported compression: {0}")]
7    Unsupported(&'a str),
8    #[error(transparent)]
9    Io(#[from] std::io::Error),
10}
11
/// A compression format which can be detected and decompressed.
///
/// Apart from [`Compression::None`], a variant is only available if the crate feature backing
/// it is enabled.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
#[non_exhaustive]
pub enum Compression {
    /// No compression: decompressing passes the data through unchanged.
    None,
    /// bzip2, available with the `bzip2` or `bzip2-rs` feature.
    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    Bzip2,
    /// xz, available with the `lzma` feature.
    #[cfg(feature = "lzma")]
    Xz,
    /// gzip, available with the `flate2` feature.
    #[cfg(feature = "flate2")]
    Gzip,
}
23
/// Options controlling decompression.
///
/// The default value applies no size limit.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct DecompressionOptions {
    /// The maximum decompressed payload size.
    ///
    /// If the size of the uncompressed payload exceeds this limit, an error is returned
    /// instead. Zero means unlimited.
    pub limit: usize,
}
33
34impl DecompressionOptions {
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Set the limit of the maximum uncompressed payload size.
40    pub fn limit(mut self, limit: usize) -> Self {
41        self.limit = limit;
42        self
43    }
44}
45
46impl Compression {
47    /// Perform decompression.
48    ///
49    /// Returns the original data for [`Compression::None`].
50    pub fn decompress(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
51        Ok(self.decompress_opt(&data)?.unwrap_or(data))
52    }
53
54    /// Perform decompression.
55    ///
56    /// Returns the original data for [`Compression::None`].
57    pub fn decompress_with(
58        &self,
59        data: Bytes,
60        opts: &DecompressionOptions,
61    ) -> Result<Bytes, std::io::Error> {
62        Ok(self.decompress_opt_with(&data, opts)?.unwrap_or(data))
63    }
64
65    /// Perform decompression.
66    ///
67    /// Returns `None` for [`Compression::None`]
68    pub fn decompress_opt(&self, data: &[u8]) -> Result<Option<Bytes>, std::io::Error> {
69        self.decompress_opt_with(data, &Default::default())
70    }
71
72    /// Perform decompression.
73    ///
74    /// Returns `None` for [`Compression::None`]
75    pub fn decompress_opt_with(
76        &self,
77        #[allow(unused_variables)] data: &[u8],
78        #[allow(unused_variables)] opts: &DecompressionOptions,
79    ) -> Result<Option<Bytes>, std::io::Error> {
80        match self {
81            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
82            Compression::Bzip2 => super::decompress_bzip2_with(data, opts).map(Some),
83            #[cfg(feature = "lzma")]
84            Compression::Xz => super::decompress_xz_with(data, opts).map(Some),
85            #[cfg(feature = "flate2")]
86            Compression::Gzip => super::decompress_gzip_with(data, opts).map(Some),
87            Compression::None => Ok(None),
88        }
89    }
90}
91
/// Detects the compression format of a payload, by its file name extension and/or its
/// leading magic bytes.
///
/// Intended to be built with struct update syntax, e.g.
/// `Detector { file_name: Some("data.json.xz"), ..Default::default() }`.
#[derive(Clone, Debug, Default)]
pub struct Detector<'a> {
    /// File name of the payload, used for detection by extension. `None` skips that step.
    pub file_name: Option<&'a str>,

    /// Disable detection by magic bytes.
    pub disable_magic: bool,

    /// File name extensions (without the leading dot, e.g. `"json"`) which are not reported
    /// as an error when `fail_unknown_file_extension` is set.
    pub ignore_file_extensions: HashSet<&'a str>,

    /// If a file name is present but its extension is neither a known compression format nor
    /// listed in `ignore_file_extensions`, report it as an error instead of continuing with
    /// magic byte detection.
    pub fail_unknown_file_extension: bool,
}
105
106impl<'a> Detector<'a> {
107    /// Detect and decompress in a single step.
108    pub fn decompress(&self, data: Bytes) -> Result<Bytes, Error<'a>> {
109        self.decompress_with(data, &Default::default())
110    }
111
112    /// Detect and decompress in a single step.
113    pub fn decompress_with(
114        &self,
115        data: Bytes,
116        opts: &DecompressionOptions,
117    ) -> Result<Bytes, Error<'a>> {
118        let compression = self.detect(&data)?;
119        Ok(compression.decompress_with(data, opts)?)
120    }
121
122    pub fn detect(&self, #[allow(unused)] data: &[u8]) -> Result<Compression, Error<'a>> {
123        // detect by file name extension
124
125        if let Some(file_name) = self.file_name {
126            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
127            if file_name.ends_with(".bz2") {
128                return Ok(Compression::Bzip2);
129            }
130            #[cfg(feature = "lzma")]
131            if file_name.ends_with(".xz") {
132                return Ok(Compression::Xz);
133            }
134            #[cfg(feature = "flate2")]
135            if file_name.ends_with(".gz") {
136                return Ok(Compression::Gzip);
137            }
138            if self.fail_unknown_file_extension
139                && let Some((_, ext)) = file_name.rsplit_once('.')
140                && !self.ignore_file_extensions.contains(ext)
141            {
142                return Err(Error::Unsupported(ext));
143            }
144        }
145
146        // magic bytes
147
148        if !self.disable_magic {
149            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
150            if data.starts_with(b"BZh") {
151                return Ok(Compression::Bzip2);
152            }
153            #[cfg(feature = "lzma")]
154            if data.starts_with(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00]) {
155                return Ok(Compression::Xz);
156            }
157            #[cfg(feature = "flate2")]
158            if data.starts_with(&[0x1F, 0x8B, 0x08]) {
159                // NOTE: Byte #3 (0x08) is the compression format, which means "deflate" and is the
160                // only one supported right now. Having additional compression formats, we'd need
161                // to extend this check, or drop the 3rd byte.
162                return Ok(Compression::Gzip);
163            }
164        }
165
166        // done
167
168        Ok(Compression::None)
169    }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    /// Detect by file name only, with magic byte detection disabled.
    fn detect(name: &str) -> Compression {
        Detector {
            file_name: Some(name),
            disable_magic: true,
            ..Default::default()
        }
        .detect(&[])
        .unwrap()
    }

    #[test]
    fn by_name_none() {
        assert_eq!(detect("foo.bar.json"), Compression::None);
    }

    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    #[test]
    fn by_name_bzip2() {
        assert_eq!(detect("foo.bar.bz2"), Compression::Bzip2);
    }

    #[cfg(feature = "lzma")]
    #[test]
    fn by_name_xz() {
        assert_eq!(detect("foo.bar.xz"), Compression::Xz);
    }

    #[cfg(feature = "flate2")]
    #[test]
    fn by_name_gzip() {
        assert_eq!(detect("foo.bar.gz"), Compression::Gzip);
    }

    #[test]
    fn unknown_extension_fails() {
        let detector = Detector {
            file_name: Some("foo.bar.json"),
            fail_unknown_file_extension: true,
            ..Default::default()
        };
        assert!(matches!(
            detector.detect(&[]),
            Err(Error::Unsupported("json"))
        ));
    }

    #[test]
    fn ignored_extension_passes() {
        let detector = Detector {
            file_name: Some("foo.bar.json"),
            fail_unknown_file_extension: true,
            ignore_file_extensions: HashSet::from(["json"]),
            ..Default::default()
        };
        assert_eq!(detector.detect(&[]).unwrap(), Compression::None);
    }

    #[test]
    fn no_extension_passes() {
        let detector = Detector {
            file_name: Some("README"),
            fail_unknown_file_extension: true,
            ..Default::default()
        };
        assert_eq!(detector.detect(&[]).unwrap(), Compression::None);
    }

    #[test]
    fn by_magic_none() {
        assert_eq!(Detector::default().detect(b"{}").unwrap(), Compression::None);
    }

    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    #[test]
    fn by_magic_bzip2() {
        assert_eq!(
            Detector::default().detect(b"BZh91AY&SY").unwrap(),
            Compression::Bzip2
        );
    }

    #[cfg(feature = "lzma")]
    #[test]
    fn by_magic_xz() {
        assert_eq!(
            Detector::default()
                .detect(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00, 0x00])
                .unwrap(),
            Compression::Xz
        );
    }

    #[cfg(feature = "flate2")]
    #[test]
    fn by_magic_gzip() {
        assert_eq!(
            Detector::default()
                .detect(&[0x1F, 0x8B, 0x08, 0x00])
                .unwrap(),
            Compression::Gzip
        );
    }

    #[test]
    fn magic_disabled() {
        let detector = Detector {
            disable_magic: true,
            ..Default::default()
        };
        assert_eq!(
            detector.detect(&[0x1F, 0x8B, 0x08, 0x00]).unwrap(),
            Compression::None
        );
    }

    #[test]
    fn default() {
        // we're not interested in running this, just ensuring we can use the Default ergonomically
        let _result = Detector::default().decompress(Bytes::from_static(b"foo"));

        let detector = Detector::default();
        let _result = detector.decompress(Bytes::from_static(b"foo"));
    }
}