Skip to main content

arrow_ipc/
compression.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Length-prefix value signalling that the bytes following the prefix are stored uncompressed.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix that precedes a compressed buffer body.
const LENGTH_OF_PREFIX_DATA: i64 = std::mem::size_of::<i64>() as i64;
24
/// Additional context that may be needed for compression.
///
/// With the `zstd` feature enabled this holds a zstd compressor that is reused across
/// compression calls, avoiding the cost of initialising a new zstd context for every call.
/// Without that feature the struct carries no state.
pub struct CompressionContext {
    #[cfg(feature = "zstd")]
    compressor: zstd::bulk::Compressor<'static>,
}

// Without the `zstd` feature this impl is equivalent to `#[derive(Default)]`, which
// `clippy::derivable_impls` would suggest; it is written by hand so that, with the feature
// enabled, the compression level is spelled out explicitly.
#[allow(clippy::derivable_impls)]
impl Default for CompressionContext {
    fn default() -> Self {
        Self {
            // `Compressor::new` only fails for an invalid compression level, and the default
            // level is valid, so this `expect` cannot fire.
            #[cfg(feature = "zstd")]
            compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
                .expect("can use default compression level"),
        }
    }
}

impl std::fmt::Debug for CompressionContext {
    /// Formats as the struct name, plus a fixed placeholder string for the zstd compressor
    /// when the `zstd` feature is enabled.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompressionContext");
        #[cfg(feature = "zstd")]
        {
            builder.field("compressor", &"zstd::bulk::Compressor");
        }
        builder.finish()
    }
}
59
/// Additional context that may be needed for decompression.
///
/// With the `zstd` feature enabled this holds a zstd decompressor that is reused across
/// decompression calls, avoiding the cost of initialising a new zstd context for every call.
/// Without that feature the struct carries no state.
pub struct DecompressionContext {
    #[cfg(feature = "zstd")]
    decompressor: zstd::bulk::Decompressor<'static>,
}

impl DecompressionContext {
    /// Creates a fresh context; equivalent to [`DecompressionContext::default`].
    pub(crate) fn new() -> Self {
        Self::default()
    }
}

// Without the `zstd` feature this impl is equivalent to `#[derive(Default)]`, which
// `clippy::derivable_impls` would suggest; it is written by hand so the zstd decompressor is
// constructed explicitly when the feature is enabled.
#[allow(clippy::derivable_impls)]
impl Default for DecompressionContext {
    fn default() -> Self {
        Self {
            #[cfg(feature = "zstd")]
            decompressor: zstd::bulk::Decompressor::new().expect("can create zstd decompressor"),
        }
    }
}

impl std::fmt::Debug for DecompressionContext {
    /// Formats as the struct name, plus a fixed placeholder string for the zstd decompressor
    /// when the `zstd` feature is enabled.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DecompressionContext");
        #[cfg(feature = "zstd")]
        {
            builder.field("decompressor", &"zstd::bulk::Decompressor");
        }
        builder.finish()
    }
}
96
/// A compression algorithm used to compress and decompress the buffers of an IPC stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionCodec {
    /// LZ4 frame format; requires the `lz4` feature.
    Lz4Frame,
    /// Zstandard; requires the `zstd` feature.
    Zstd,
}
103
104impl TryFrom<CompressionType> for CompressionCodec {
105    type Error = ArrowError;
106
107    fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
108        match compression_type {
109            CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
110            CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
111            other_type => Err(ArrowError::NotYetImplemented(format!(
112                "compression type {other_type:?} not supported "
113            ))),
114        }
115    }
116}
117
118impl CompressionCodec {
119    /// Compresses the data in `input` to `output` and appends the
120    /// data using the specified compression mechanism.
121    ///
122    /// returns the number of bytes written to the stream
123    ///
124    /// Writes this format to output:
125    /// ```text
126    /// [8 bytes]:         uncompressed length
127    /// [remaining bytes]: compressed data stream
128    /// ```
129    pub(crate) fn compress_to_vec(
130        &self,
131        input: &[u8],
132        output: &mut Vec<u8>,
133        context: &mut CompressionContext,
134    ) -> Result<usize, ArrowError> {
135        let uncompressed_data_len = input.len();
136        let original_output_len = output.len();
137
138        if input.is_empty() {
139            // empty input, nothing to do
140        } else {
141            // write compressed data directly into the output buffer
142            output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
143            self.compress(input, output, context)?;
144
145            let compression_len = output.len() - original_output_len;
146            if compression_len > uncompressed_data_len {
147                // length of compressed data was larger than
148                // uncompressed data, use the uncompressed data with
149                // length -1 to indicate that we don't compress the
150                // data
151                output.truncate(original_output_len);
152                output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
153                output.extend_from_slice(input);
154            }
155        }
156        Ok(output.len() - original_output_len)
157    }
158
159    /// Decompresses the input into a [`Buffer`]
160    ///
161    /// The input should look like:
162    /// ```text
163    /// [8 bytes]:         uncompressed length
164    /// [remaining bytes]: compressed data stream
165    /// ```
166    pub(crate) fn decompress_to_buffer(
167        &self,
168        input: &Buffer,
169        context: &mut DecompressionContext,
170    ) -> Result<Buffer, ArrowError> {
171        // read the first 8 bytes to determine if the data is
172        // compressed
173        let decompressed_length = read_uncompressed_size(input);
174        let buffer = if decompressed_length == 0 {
175            // empty
176            Buffer::from([])
177        } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
178            // no compression
179            input.slice(LENGTH_OF_PREFIX_DATA as usize)
180        } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
181            // decompress data using the codec
182            let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
183            let v = self.decompress(input_data, decompressed_length as _, context)?;
184            Buffer::from_vec(v)
185        } else {
186            return Err(ArrowError::IpcError(format!(
187                "Invalid uncompressed length: {decompressed_length}"
188            )));
189        };
190        Ok(buffer)
191    }
192
193    /// Compress the data in input buffer and write to output buffer
194    /// using the specified compression
195    fn compress(
196        &self,
197        input: &[u8],
198        output: &mut Vec<u8>,
199        context: &mut CompressionContext,
200    ) -> Result<(), ArrowError> {
201        match self {
202            CompressionCodec::Lz4Frame => compress_lz4(input, output),
203            CompressionCodec::Zstd => compress_zstd(input, output, context),
204        }
205    }
206
207    /// Decompress the data in input buffer and write to output buffer
208    /// using the specified compression
209    fn decompress(
210        &self,
211        input: &[u8],
212        decompressed_size: usize,
213        context: &mut DecompressionContext,
214    ) -> Result<Vec<u8>, ArrowError> {
215        let ret = match self {
216            CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
217            CompressionCodec::Zstd => decompress_zstd(input, decompressed_size, context)?,
218        };
219        if ret.len() != decompressed_size {
220            return Err(ArrowError::IpcError(format!(
221                "Expected compressed length of {decompressed_size} got {}",
222                ret.len()
223            )));
224        }
225        Ok(ret)
226    }
227}
228
229#[cfg(feature = "lz4")]
230fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
231    use std::io::Write;
232    let mut encoder = lz4_flex::frame::FrameEncoder::new(output);
233    encoder.write_all(input)?;
234    encoder
235        .finish()
236        .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
237    Ok(())
238}
239
240#[cfg(not(feature = "lz4"))]
241#[allow(clippy::ptr_arg)]
242fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
243    Err(ArrowError::InvalidArgumentError(
244        "lz4 IPC compression requires the lz4 feature".to_string(),
245    ))
246}
247
248#[cfg(feature = "lz4")]
249fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
250    use std::io::Read;
251    let mut output = Vec::with_capacity(decompressed_size);
252    lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?;
253    Ok(output)
254}
255
256#[cfg(not(feature = "lz4"))]
257#[allow(clippy::ptr_arg)]
258fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
259    Err(ArrowError::InvalidArgumentError(
260        "lz4 IPC decompression requires the lz4 feature".to_string(),
261    ))
262}
263
264#[cfg(feature = "zstd")]
265fn compress_zstd(
266    input: &[u8],
267    output: &mut Vec<u8>,
268    context: &mut CompressionContext,
269) -> Result<(), ArrowError> {
270    let result = context.compressor.compress(input)?;
271    output.extend_from_slice(&result);
272    Ok(())
273}
274
275#[cfg(not(feature = "zstd"))]
276#[allow(clippy::ptr_arg)]
277fn compress_zstd(
278    _input: &[u8],
279    _output: &mut Vec<u8>,
280    _context: &mut CompressionContext,
281) -> Result<(), ArrowError> {
282    Err(ArrowError::InvalidArgumentError(
283        "zstd IPC compression requires the zstd feature".to_string(),
284    ))
285}
286
287#[cfg(feature = "zstd")]
288fn decompress_zstd(
289    input: &[u8],
290    decompressed_size: usize,
291    context: &mut DecompressionContext,
292) -> Result<Vec<u8>, ArrowError> {
293    let output = context.decompressor.decompress(input, decompressed_size)?;
294    Ok(output)
295}
296
297#[cfg(not(feature = "zstd"))]
298#[allow(clippy::ptr_arg)]
299fn decompress_zstd(
300    _input: &[u8],
301    _decompressed_size: usize,
302    _context: &mut DecompressionContext,
303) -> Result<Vec<u8>, ArrowError> {
304    Err(ArrowError::InvalidArgumentError(
305        "zstd IPC decompression requires the zstd feature".to_string(),
306    ))
307}
308
/// Reads the 8-byte little-endian `i64` length prefix at the start of `buffer`.
///
/// The returned value is interpreted as:
/// * `LENGTH_NO_COMPRESSED_DATA` (`-1`): the data that follows is not compressed
/// * `0`: there is no data
/// * a positive number: the uncompressed length of the data that follows
///
/// # Panics
///
/// Panics if `buffer` is shorter than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
320
#[cfg(test)]
mod tests {
    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let input_bytes = b"hello lz4";
        let codec = super::CompressionCodec::Lz4Frame;
        let mut output_bytes: Vec<u8> = Vec::new();
        codec
            .compress(input_bytes, &mut output_bytes, &mut Default::default())
            .unwrap();
        let result = codec
            .decompress(
                output_bytes.as_slice(),
                input_bytes.len(),
                &mut Default::default(),
            )
            .unwrap();
        assert_eq!(input_bytes, result.as_slice());
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let input_bytes = b"hello zstd";
        let codec = super::CompressionCodec::Zstd;
        let mut output_bytes: Vec<u8> = Vec::new();
        codec
            .compress(input_bytes, &mut output_bytes, &mut Default::default())
            .unwrap();
        let result = codec
            .decompress(
                output_bytes.as_slice(),
                input_bytes.len(),
                &mut Default::default(),
            )
            .unwrap();
        assert_eq!(input_bytes, result.as_slice());
    }

    /// Runs `input` through the framed IPC API (`compress_to_vec` followed by
    /// `decompress_to_buffer`). Returns the framed bytes and the recovered payload.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn framed_roundtrip(codec: super::CompressionCodec, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
        let mut framed = Vec::new();
        let written = codec
            .compress_to_vec(input, &mut framed, &mut Default::default())
            .unwrap();
        assert_eq!(written, framed.len());

        let buffer = arrow_buffer::Buffer::from_vec(framed.clone());
        let decoded = codec
            .decompress_to_buffer(&buffer, &mut Default::default())
            .unwrap();
        (framed, decoded.as_slice().to_vec())
    }

    /// Checks both storage modes of the framed format for `codec`:
    /// * A highly compressible payload must be stored compressed behind its real length.
    /// * `tiny` (where codec overhead outweighs any saving) must be stored raw behind a `-1`
    ///   length.
    ///
    /// Also checks that empty input appends nothing.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn check_framed_format(codec: super::CompressionCodec, tiny: &[u8]) {
        let compressible = vec![7u8; 4096];
        let (framed, decoded) = framed_roundtrip(codec, &compressible);
        assert_eq!(&framed[..8], &4096i64.to_le_bytes()[..]);
        assert!(framed.len() < compressible.len());
        assert_eq!(decoded, compressible);

        let (framed, decoded) = framed_roundtrip(codec, tiny);
        assert_eq!(&framed[..8], &(-1i64).to_le_bytes()[..]);
        assert_eq!(&framed[8..], tiny);
        assert_eq!(decoded.as_slice(), tiny);

        // empty input writes nothing and leaves existing output untouched
        let empty: &[u8] = &[];
        let mut output = vec![1u8, 2, 3];
        let written = codec
            .compress_to_vec(empty, &mut output, &mut Default::default())
            .unwrap();
        assert_eq!(written, 0);
        assert_eq!(output, [1u8, 2, 3]);
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_framed_format() {
        check_framed_format(super::CompressionCodec::Lz4Frame, b"hello lz4");
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_framed_format() {
        check_framed_format(super::CompressionCodec::Zstd, b"hello zstd");
    }
}