1use std::io::Read;
22
23use bytes::{Bytes, BytesMut};
24use fallible_streaming_iterator::FallibleStreamingIterator;
25use snafu::ResultExt;
26
27use crate::error::{self, OrcError, Result};
28use crate::proto::{self, CompressionKind};
29
/// Chunk size (256 KiB) assumed when no compression block size is supplied
/// to `Compression::from_proto`.
const DEFAULT_COMPRESSION_BLOCK_SIZE: u64 = 256 * 1024;
32
33#[derive(Clone, Copy, Debug)]
34pub struct Compression {
35 compression_type: CompressionType,
36 max_decompressed_block_size: usize,
39}
40
41impl std::fmt::Display for Compression {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(
44 f,
45 "{} ({} byte max block size)",
46 self.compression_type, self.max_decompressed_block_size
47 )
48 }
49}
50
51impl Compression {
52 pub fn compression_type(&self) -> CompressionType {
53 self.compression_type
54 }
55
56 pub(crate) fn from_proto(
57 kind: proto::CompressionKind,
58 compression_block_size: Option<u64>,
59 ) -> Option<Self> {
60 let max_decompressed_block_size =
61 compression_block_size.unwrap_or(DEFAULT_COMPRESSION_BLOCK_SIZE) as usize;
62 match kind {
63 CompressionKind::None => None,
64 CompressionKind::Zlib => Some(Self {
65 compression_type: CompressionType::Zlib,
66 max_decompressed_block_size,
67 }),
68 CompressionKind::Snappy => Some(Self {
69 compression_type: CompressionType::Snappy,
70 max_decompressed_block_size,
71 }),
72 CompressionKind::Lzo => Some(Self {
73 compression_type: CompressionType::Lzo,
74 max_decompressed_block_size,
75 }),
76 CompressionKind::Lz4 => Some(Self {
77 compression_type: CompressionType::Lz4,
78 max_decompressed_block_size,
79 }),
80 CompressionKind::Zstd => Some(Self {
81 compression_type: CompressionType::Zstd,
82 max_decompressed_block_size,
83 }),
84 }
85 }
86}
87
/// Compression codec used for the chunks of a stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CompressionType {
    Zlib,
    Snappy,
    Lzo,
    Lz4,
    Zstd,
}

impl std::fmt::Display for CompressionType {
    /// Writes the codec name (e.g. `Zstd`), honouring width/alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Spelled out rather than delegating to `Debug`, so the user-facing
        // name is decoupled from the derived debug representation, and so
        // `pad` can apply any requested width/alignment.
        let name = match self {
            CompressionType::Zlib => "Zlib",
            CompressionType::Snappy => "Snappy",
            CompressionType::Lzo => "Lzo",
            CompressionType::Lz4 => "Lz4",
            CompressionType::Zstd => "Zstd",
        };
        f.pad(name)
    }
}
102
/// A decoded 3-byte chunk header: the chunk length in bytes, and whether the
/// chunk is stored as-is (`Original`) or must be decompressed (`Compressed`).
#[derive(Debug, PartialEq, Eq)]
enum CompressionHeader {
    Original(u32),
    Compressed(u32),
}

/// Decodes a 3-byte chunk header.
///
/// The bytes form a little-endian 24-bit value. Its lowest bit is the
/// "original" flag and its upper 23 bits are the chunk length.
fn decode_header(bytes: [u8; 3]) -> CompressionHeader {
    let [low, mid, high] = bytes;
    let word = u32::from_le_bytes([low, mid, high, 0]);
    let length = word >> 1;
    if word & 1 == 0 {
        CompressionHeader::Compressed(length)
    } else {
        CompressionHeader::Original(length)
    }
}
124
125pub(crate) trait DecompressorVariant: Send {
126 fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()>;
127}
128
/// Zlib codec: decoded as raw DEFLATE data.
#[derive(Debug, Clone, Copy)]
struct Zlib;

/// Zstandard codec.
#[derive(Debug, Clone, Copy)]
struct Zstd;

/// Snappy codec (raw, unframed format).
#[derive(Debug, Clone, Copy)]
struct Snappy;

/// LZO codec.
#[derive(Debug, Clone, Copy)]
struct Lzo;

/// LZ4 codec. lz4_flex's block decoder needs an output-size bound, which this
/// carries.
#[derive(Debug, Clone, Copy)]
struct Lz4 {
    /// Largest number of bytes one block may decompress to.
    max_decompressed_block_size: usize,
}
141
142impl DecompressorVariant for Zlib {
143 fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
144 let mut gz = flate2::read::DeflateDecoder::new(compressed_bytes);
145 scratch.clear();
146 gz.read_to_end(scratch).context(error::IoSnafu)?;
147 Ok(())
148 }
149}
150
151impl DecompressorVariant for Zstd {
152 fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
153 let mut reader =
154 zstd::Decoder::new(compressed_bytes).context(error::BuildZstdDecoderSnafu)?;
155 scratch.clear();
156 reader.read_to_end(scratch).context(error::IoSnafu)?;
157 Ok(())
158 }
159}
160
161impl DecompressorVariant for Snappy {
162 fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
163 let len =
164 snap::raw::decompress_len(compressed_bytes).context(error::BuildSnappyDecoderSnafu)?;
165 scratch.resize(len, 0);
166 let mut decoder = snap::raw::Decoder::new();
167 decoder
168 .decompress(compressed_bytes, scratch)
169 .context(error::BuildSnappyDecoderSnafu)?;
170 Ok(())
171 }
172}
173
174impl DecompressorVariant for Lzo {
175 fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
176 let decompressed = lzokay_native::decompress_all(compressed_bytes, None)
177 .context(error::BuildLzoDecoderSnafu)?;
178 scratch.clear();
180 scratch.extend(decompressed);
181 Ok(())
182 }
183}
184
185impl DecompressorVariant for Lz4 {
186 fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
187 let decompressed =
188 lz4_flex::block::decompress(compressed_bytes, self.max_decompressed_block_size)
189 .context(error::BuildLz4DecoderSnafu)?;
190 scratch.clear();
192 scratch.extend(decompressed);
193 Ok(())
194 }
195}
196
197fn get_decompressor_variant(
199 Compression {
200 compression_type,
201 max_decompressed_block_size,
202 }: Compression,
203) -> Box<dyn DecompressorVariant> {
204 match compression_type {
205 CompressionType::Zlib => Box::new(Zlib),
206 CompressionType::Snappy => Box::new(Snappy),
207 CompressionType::Lzo => Box::new(Lzo),
208 CompressionType::Lz4 => Box::new(Lz4 {
209 max_decompressed_block_size,
210 }),
211 CompressionType::Zstd => Box::new(Zstd),
212 }
213}
214
215enum State {
216 Original(Bytes),
217 Compressed(Vec<u8>),
218}
219
220struct DecompressorIter {
221 stream: BytesMut,
222 current: Option<State>, compression: Option<Box<dyn DecompressorVariant>>,
224 scratch: Vec<u8>,
225}
226
227impl DecompressorIter {
228 fn new(stream: Bytes, compression: Option<Compression>, scratch: Vec<u8>) -> Self {
229 Self {
230 stream: BytesMut::from(stream.as_ref()),
231 current: None,
232 compression: compression.map(get_decompressor_variant),
233 scratch,
234 }
235 }
236}
237
238impl FallibleStreamingIterator for DecompressorIter {
239 type Item = [u8];
240
241 type Error = OrcError;
242
243 #[inline]
244 fn advance(&mut self) -> Result<(), Self::Error> {
245 if self.stream.is_empty() {
246 self.current = None;
247 return Ok(());
248 }
249
250 match &self.compression {
251 Some(compression) => {
252 let header = self.stream.split_to(3);
254 let header = [header[0], header[1], header[2]];
255 match decode_header(header) {
256 CompressionHeader::Original(length) => {
257 let original = self.stream.split_to(length as usize);
258 self.current = Some(State::Original(original.into()));
259 }
260 CompressionHeader::Compressed(length) => {
261 let compressed = self.stream.split_to(length as usize);
262 compression.decompress_block(&compressed, &mut self.scratch)?;
263 self.current = Some(State::Compressed(std::mem::take(&mut self.scratch)));
264 }
265 };
266 Ok(())
267 }
268 None => {
269 self.current = Some(State::Original(self.stream.clone().into()));
271 self.stream.clear();
272 Ok(())
273 }
274 }
275 }
276
277 #[inline]
278 fn get(&self) -> Option<&Self::Item> {
279 self.current.as_ref().map(|x| match x {
280 State::Original(x) => x.as_ref(),
281 State::Compressed(x) => x.as_ref(),
282 })
283 }
284}
285
286pub struct Decompressor {
288 decompressor: DecompressorIter,
289 offset: usize,
290 is_first: bool,
291}
292
293impl Decompressor {
294 pub fn new(stream: Bytes, compression: Option<Compression>, scratch: Vec<u8>) -> Self {
296 Self {
297 decompressor: DecompressorIter::new(stream, compression, scratch),
298 offset: 0,
299 is_first: true,
300 }
301 }
302
303 pub fn empty() -> Self {
305 Self {
306 decompressor: DecompressorIter::new(Bytes::new(), None, vec![]),
307 offset: 0,
308 is_first: true,
309 }
310 }
311}
312
313impl std::io::Read for Decompressor {
314 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
315 if self.is_first {
316 self.is_first = false;
317 self.decompressor.advance().unwrap();
318 }
319 let current = self.decompressor.get();
320 let current = if let Some(current) = current {
321 if current.len() == self.offset {
322 self.decompressor.advance().unwrap();
323 self.offset = 0;
324 let current = self.decompressor.get();
325 if let Some(current) = current {
326 current
327 } else {
328 return Ok(0);
329 }
330 } else {
331 ¤t[self.offset..]
332 }
333 } else {
334 return Ok(0);
335 };
336
337 if current.len() >= buf.len() {
338 buf.copy_from_slice(¤t[..buf.len()]);
339 self.offset += buf.len();
340 Ok(buf.len())
341 } else {
342 buf[..current.len()].copy_from_slice(current);
343 self.offset += current.len();
344 Ok(current.len())
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_uncompressed() {
        // 0b1011 == (5 << 1) | 1: length 5 with the "original" flag set.
        let bytes = [0b1011, 0, 0];

        let expected = CompressionHeader::Original(5);
        let actual = decode_header(bytes);
        assert_eq!(expected, actual);
    }

    #[test]
    fn decode_compressed() {
        // Little-endian 0x03_0D_40 == 200_000 == 100_000 << 1, flag bit clear.
        let bytes = [0b0100_0000, 0b0000_1101, 0b0000_0011];
        let expected = CompressionHeader::Compressed(100_000);
        let actual = decode_header(bytes);
        assert_eq!(expected, actual);
    }

    #[test]
    fn decode_zero_length() {
        assert_eq!(decode_header([0, 0, 0]), CompressionHeader::Compressed(0));
        assert_eq!(decode_header([1, 0, 0]), CompressionHeader::Original(0));
    }

    #[test]
    fn decode_max_length() {
        // All 23 length bits set: the largest chunk a header can describe.
        assert_eq!(
            decode_header([0xfe, 0xff, 0xff]),
            CompressionHeader::Compressed(0x7f_ffff)
        );
        assert_eq!(
            decode_header([0xff, 0xff, 0xff]),
            CompressionHeader::Original(0x7f_ffff)
        );
    }
}