orc_rust/
compression.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Modified from https://github.com/DataEngineeringLabs/orc-format/blob/416490db0214fc51d53289253c0ee91f7fc9bc17/src/read/decompress/mod.rs
19//! Related code for handling decompression of ORC files.
20
21use std::io::Read;
22
23use bytes::{Bytes, BytesMut};
24use fallible_streaming_iterator::FallibleStreamingIterator;
25use snafu::ResultExt;
26
27use crate::error::{self, OrcError, Result};
28use crate::proto::{self, CompressionKind};
29
/// Compression block size used when the file metadata does not provide one.
/// The ORC spec states the default is 256 KiB.
const DEFAULT_COMPRESSION_BLOCK_SIZE: u64 = 256 << 10;
32
33#[derive(Clone, Copy, Debug)]
34pub struct Compression {
35    compression_type: CompressionType,
36    /// No compression chunk will decompress to larger than this size.
37    /// Use to size the scratch buffer appropriately.
38    max_decompressed_block_size: usize,
39}
40
41impl std::fmt::Display for Compression {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        write!(
44            f,
45            "{} ({} byte max block size)",
46            self.compression_type, self.max_decompressed_block_size
47        )
48    }
49}
50
51impl Compression {
52    pub fn compression_type(&self) -> CompressionType {
53        self.compression_type
54    }
55
56    pub(crate) fn from_proto(
57        kind: proto::CompressionKind,
58        compression_block_size: Option<u64>,
59    ) -> Option<Self> {
60        let max_decompressed_block_size =
61            compression_block_size.unwrap_or(DEFAULT_COMPRESSION_BLOCK_SIZE) as usize;
62        match kind {
63            CompressionKind::None => None,
64            CompressionKind::Zlib => Some(Self {
65                compression_type: CompressionType::Zlib,
66                max_decompressed_block_size,
67            }),
68            CompressionKind::Snappy => Some(Self {
69                compression_type: CompressionType::Snappy,
70                max_decompressed_block_size,
71            }),
72            CompressionKind::Lzo => Some(Self {
73                compression_type: CompressionType::Lzo,
74                max_decompressed_block_size,
75            }),
76            CompressionKind::Lz4 => Some(Self {
77                compression_type: CompressionType::Lz4,
78                max_decompressed_block_size,
79            }),
80            CompressionKind::Zstd => Some(Self {
81                compression_type: CompressionType::Zstd,
82                max_decompressed_block_size,
83            }),
84        }
85    }
86}
87
/// Codecs an ORC file may use for its compressed chunks.
#[derive(Clone, Copy, Debug)]
pub enum CompressionType {
    Zlib,
    Snappy,
    Lzo,
    Lz4,
    Zstd,
}

impl std::fmt::Display for CompressionType {
    /// Writes the bare variant name (identical to the `Debug` output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            CompressionType::Zlib => "Zlib",
            CompressionType::Snappy => "Snappy",
            CompressionType::Lzo => "Lzo",
            CompressionType::Lz4 => "Lz4",
            CompressionType::Zstd => "Zstd",
        };
        f.write_str(name)
    }
}
102
/// Indicates length of block and whether it's compressed or not.
#[derive(Debug, PartialEq, Eq)]
enum CompressionHeader {
    /// The chunk that follows holds this many bytes stored as-is.
    Original(u32),
    /// The chunk that follows holds this many bytes of compressed data.
    Compressed(u32),
}

/// ORC files are compressed in blocks, with a 3 byte header at the start
/// of these blocks indicating the length of the block and whether it's
/// compressed or not.
///
/// The three bytes form a little-endian 24-bit value: bit 0 is set when the
/// chunk is stored as-is ("original"), and the upper 23 bits are the chunk
/// length in bytes.
fn decode_header(bytes: [u8; 3]) -> CompressionHeader {
    let [low, mid, high] = bytes;
    let packed = u32::from(low) | (u32::from(mid) << 8) | (u32::from(high) << 16);
    let length = packed >> 1;
    if (packed & 1) == 0 {
        CompressionHeader::Compressed(length)
    } else {
        CompressionHeader::Original(length)
    }
}
124
125pub(crate) trait DecompressorVariant: Send {
126    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()>;
127}
128
/// Decompressor for [`CompressionType::Zlib`] chunks.
#[derive(Debug, Clone, Copy)]
struct Zlib;

/// Decompressor for [`CompressionType::Zstd`] chunks.
#[derive(Debug, Clone, Copy)]
struct Zstd;

/// Decompressor for [`CompressionType::Snappy`] chunks.
#[derive(Debug, Clone, Copy)]
struct Snappy;

/// Decompressor for [`CompressionType::Lzo`] chunks.
#[derive(Debug, Clone, Copy)]
struct Lzo;

/// Decompressor for [`CompressionType::Lz4`] chunks.
#[derive(Debug, Clone, Copy)]
struct Lz4 {
    /// Size handed to the LZ4 block decoder as the output bound; LZ4 block
    /// decoding needs this up front.
    max_decompressed_block_size: usize,
}
141
142impl DecompressorVariant for Zlib {
143    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
144        let mut gz = flate2::read::DeflateDecoder::new(compressed_bytes);
145        scratch.clear();
146        gz.read_to_end(scratch).context(error::IoSnafu)?;
147        Ok(())
148    }
149}
150
151impl DecompressorVariant for Zstd {
152    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
153        let mut reader =
154            zstd::Decoder::new(compressed_bytes).context(error::BuildZstdDecoderSnafu)?;
155        scratch.clear();
156        reader.read_to_end(scratch).context(error::IoSnafu)?;
157        Ok(())
158    }
159}
160
161impl DecompressorVariant for Snappy {
162    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
163        let len =
164            snap::raw::decompress_len(compressed_bytes).context(error::BuildSnappyDecoderSnafu)?;
165        scratch.resize(len, 0);
166        let mut decoder = snap::raw::Decoder::new();
167        decoder
168            .decompress(compressed_bytes, scratch)
169            .context(error::BuildSnappyDecoderSnafu)?;
170        Ok(())
171    }
172}
173
174impl DecompressorVariant for Lzo {
175    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
176        let decompressed = lzokay_native::decompress_all(compressed_bytes, None)
177            .context(error::BuildLzoDecoderSnafu)?;
178        // TODO: better way to utilize scratch here
179        scratch.clear();
180        scratch.extend(decompressed);
181        Ok(())
182    }
183}
184
185impl DecompressorVariant for Lz4 {
186    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
187        let decompressed =
188            lz4_flex::block::decompress(compressed_bytes, self.max_decompressed_block_size)
189                .context(error::BuildLz4DecoderSnafu)?;
190        // TODO: better way to utilize scratch here
191        scratch.clear();
192        scratch.extend(decompressed);
193        Ok(())
194    }
195}
196
197// TODO: push this earlier so we don't check this variant each time
198fn get_decompressor_variant(
199    Compression {
200        compression_type,
201        max_decompressed_block_size,
202    }: Compression,
203) -> Box<dyn DecompressorVariant> {
204    match compression_type {
205        CompressionType::Zlib => Box::new(Zlib),
206        CompressionType::Snappy => Box::new(Snappy),
207        CompressionType::Lzo => Box::new(Lzo),
208        CompressionType::Lz4 => Box::new(Lz4 {
209            max_decompressed_block_size,
210        }),
211        CompressionType::Zstd => Box::new(Zstd),
212    }
213}
214
215enum State {
216    Original(Bytes),
217    Compressed(Vec<u8>),
218}
219
220struct DecompressorIter {
221    stream: BytesMut,
222    current: Option<State>, // when we have compression but the value is original
223    compression: Option<Box<dyn DecompressorVariant>>,
224    scratch: Vec<u8>,
225}
226
227impl DecompressorIter {
228    fn new(stream: Bytes, compression: Option<Compression>, scratch: Vec<u8>) -> Self {
229        Self {
230            stream: BytesMut::from(stream.as_ref()),
231            current: None,
232            compression: compression.map(get_decompressor_variant),
233            scratch,
234        }
235    }
236}
237
238impl FallibleStreamingIterator for DecompressorIter {
239    type Item = [u8];
240
241    type Error = OrcError;
242
243    #[inline]
244    fn advance(&mut self) -> Result<(), Self::Error> {
245        if self.stream.is_empty() {
246            self.current = None;
247            return Ok(());
248        }
249
250        match &self.compression {
251            Some(compression) => {
252                // TODO: take stratch from current State::Compressed for re-use
253                let header = self.stream.split_to(3);
254                let header = [header[0], header[1], header[2]];
255                match decode_header(header) {
256                    CompressionHeader::Original(length) => {
257                        let original = self.stream.split_to(length as usize);
258                        self.current = Some(State::Original(original.into()));
259                    }
260                    CompressionHeader::Compressed(length) => {
261                        let compressed = self.stream.split_to(length as usize);
262                        compression.decompress_block(&compressed, &mut self.scratch)?;
263                        self.current = Some(State::Compressed(std::mem::take(&mut self.scratch)));
264                    }
265                };
266                Ok(())
267            }
268            None => {
269                // TODO: take stratch from current State::Compressed for re-use
270                self.current = Some(State::Original(self.stream.clone().into()));
271                self.stream.clear();
272                Ok(())
273            }
274        }
275    }
276
277    #[inline]
278    fn get(&self) -> Option<&Self::Item> {
279        self.current.as_ref().map(|x| match x {
280            State::Original(x) => x.as_ref(),
281            State::Compressed(x) => x.as_ref(),
282        })
283    }
284}
285
286/// A [`Read`]er fulfilling the ORC specification of reading compressed data.
287pub struct Decompressor {
288    decompressor: DecompressorIter,
289    offset: usize,
290    is_first: bool,
291}
292
293impl Decompressor {
294    /// Creates a new [`Decompressor`] that will use `scratch` as a temporary region.
295    pub fn new(stream: Bytes, compression: Option<Compression>, scratch: Vec<u8>) -> Self {
296        Self {
297            decompressor: DecompressorIter::new(stream, compression, scratch),
298            offset: 0,
299            is_first: true,
300        }
301    }
302
303    // TODO: remove need for this upstream
304    pub fn empty() -> Self {
305        Self {
306            decompressor: DecompressorIter::new(Bytes::new(), None, vec![]),
307            offset: 0,
308            is_first: true,
309        }
310    }
311}
312
313impl std::io::Read for Decompressor {
314    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
315        if self.is_first {
316            self.is_first = false;
317            self.decompressor.advance().unwrap();
318        }
319        let current = self.decompressor.get();
320        let current = if let Some(current) = current {
321            if current.len() == self.offset {
322                self.decompressor.advance().unwrap();
323                self.offset = 0;
324                let current = self.decompressor.get();
325                if let Some(current) = current {
326                    current
327                } else {
328                    return Ok(0);
329                }
330            } else {
331                &current[self.offset..]
332            }
333        } else {
334            return Ok(0);
335        };
336
337        if current.len() >= buf.len() {
338            buf.copy_from_slice(&current[..buf.len()]);
339            self.offset += buf.len();
340            Ok(buf.len())
341        } else {
342            buf[..current.len()].copy_from_slice(current);
343            self.offset += current.len();
344            Ok(current.len())
345        }
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_uncompressed() {
        // 5 stored as-is: (5 << 1) | 1 = 0b1011 -> [0x0b, 0x00, 0x00]
        let header = [0b1011, 0, 0];
        assert_eq!(decode_header(header), CompressionHeader::Original(5));
    }

    #[test]
    fn decode_compressed() {
        // 100_000 compressed: 100_000 << 1 = 0x03_0d40 -> [0x40, 0x0d, 0x03]
        let header = [0b0100_0000, 0b0000_1101, 0b0000_0011];
        assert_eq!(decode_header(header), CompressionHeader::Compressed(100_000));
    }
}