orc_format/read/decompress/
mod.rs

1//! Contains [`Decompressor`]
2use std::io::Read;
3
4use fallible_streaming_iterator::FallibleStreamingIterator;
5
6use crate::error::Error;
7use crate::proto::CompressionKind;
8
/// Decodes the 3-byte header that precedes a chunk of a compressed stream.
///
/// The first three bytes of `bytes` are read as a 24-bit little-endian integer.
/// Its least-significant bit tells whether the chunk is stored as-is ("original"),
/// and the remaining 23 bits hold the length of the chunk in bytes.
///
/// Returns `(is_original, length)`.
///
/// # Panics
///
/// Panics if `bytes` holds fewer than 3 bytes.
fn decode_header(bytes: &[u8]) -> (bool, usize) {
    let head = &bytes[..3];
    // assemble the 24-bit little-endian value byte by byte
    let raw = u32::from(head[0]) | (u32::from(head[1]) << 8) | (u32::from(head[2]) << 16);

    let is_original = raw & 1 != 0;
    let length = (raw >> 1) as usize;

    (is_original, length)
}
18
/// The chunk currently exposed by the iterator.
enum State<'a> {
    /// Bytes borrowed directly from the input stream (nothing had to be decompressed).
    Original(&'a [u8]),
    /// Owned buffer holding the decompressed bytes of the chunk.
    Compressed(Vec<u8>),
}
23
24struct DecompressorIter<'a> {
25    stream: &'a [u8],
26    current: Option<State<'a>>, // when we have compression but the value is original
27    compression: CompressionKind,
28    scratch: Vec<u8>,
29}
30
31impl<'a> DecompressorIter<'a> {
32    pub fn new(stream: &'a [u8], compression: CompressionKind, scratch: Vec<u8>) -> Self {
33        Self {
34            stream,
35            current: None,
36            compression,
37            scratch,
38        }
39    }
40
41    pub fn into_inner(self) -> Vec<u8> {
42        match self.current {
43            Some(State::Compressed(some)) => some,
44            _ => self.scratch,
45        }
46    }
47}
48
49impl<'a> FallibleStreamingIterator for DecompressorIter<'a> {
50    type Item = [u8];
51
52    type Error = Error;
53
54    #[inline]
55    fn advance(&mut self) -> Result<(), Self::Error> {
56        if self.stream.is_empty() {
57            self.current = None;
58            return Ok(());
59        }
60        match self.compression {
61            CompressionKind::None => {
62                // todo: take stratch from current State::Compressed for re-use
63                self.current = Some(State::Original(self.stream));
64                self.stream = &[];
65            }
66            CompressionKind::Zlib => {
67                // todo: take stratch from current State::Compressed for re-use
68                let (is_original, length) = decode_header(self.stream);
69                self.stream = &self.stream[3..];
70                let (maybe_compressed, remaining) = self.stream.split_at(length);
71                self.stream = remaining;
72                if is_original {
73                    self.current = Some(State::Original(maybe_compressed));
74                } else {
75                    let mut gz = flate2::read::DeflateDecoder::new(maybe_compressed);
76                    self.scratch.clear();
77                    gz.read_to_end(&mut self.scratch)?;
78                    self.current = Some(State::Compressed(std::mem::take(&mut self.scratch)));
79                }
80            }
81            _ => todo!(),
82        };
83        Ok(())
84    }
85
86    #[inline]
87    fn get(&self) -> Option<&Self::Item> {
88        self.current.as_ref().map(|x| match x {
89            State::Original(x) => *x,
90            State::Compressed(x) => x.as_ref(),
91        })
92    }
93}
94
95/// A [`Read`]er fulfilling the ORC specification of reading compressed data.
96pub struct Decompressor<'a> {
97    decompressor: DecompressorIter<'a>,
98    offset: usize,
99    is_first: bool,
100}
101
102impl<'a> Decompressor<'a> {
103    /// Creates a new [`Decompressor`] that will use `scratch` as a temporary region.
104    pub fn new(stream: &'a [u8], compression: CompressionKind, scratch: Vec<u8>) -> Self {
105        Self {
106            decompressor: DecompressorIter::new(stream, compression, scratch),
107            offset: 0,
108            is_first: true,
109        }
110    }
111
112    /// Returns the internal memory region, so it can be re-used
113    pub fn into_inner(self) -> Vec<u8> {
114        self.decompressor.into_inner()
115    }
116}
117
118impl<'a> std::io::Read for Decompressor<'a> {
119    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
120        if self.is_first {
121            self.is_first = false;
122            self.decompressor.advance().unwrap();
123        }
124        let current = self.decompressor.get();
125        let current = if let Some(current) = current {
126            if current.len() == self.offset {
127                self.decompressor.advance().unwrap();
128                self.offset = 0;
129                let current = self.decompressor.get();
130                if let Some(current) = current {
131                    current
132                } else {
133                    return Ok(0);
134                }
135            } else {
136                &current[self.offset..]
137            }
138        } else {
139            return Ok(0);
140        };
141
142        if current.len() >= buf.len() {
143            buf.copy_from_slice(&current[..buf.len()]);
144            self.offset += buf.len();
145            Ok(buf.len())
146        } else {
147            buf[..current.len()].copy_from_slice(current);
148            self.offset += current.len();
149            Ok(current.len())
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A header announcing 5 bytes stored as-is: `(5 << 1) | 1 = 0x0b`.
    #[test]
    fn decode_uncompressed() {
        let header: &[u8] = &[0x0b, 0x00, 0x00, 0x00];

        let decoded = decode_header(header);
        assert!(decoded.0);
        assert_eq!(decoded.1, 5);
    }

    /// A header announcing 100_000 compressed bytes: `100_000 << 1 = 0x030d40`,
    /// laid out in little-endian order.
    #[test]
    fn decode_compressed() {
        let header: &[u8] = &[0x40, 0x0d, 0x03, 0x00];

        let decoded = decode_header(header);
        assert!(!decoded.0);
        assert_eq!(decoded.1, 100_000);
    }
}
177}