Skip to main content

lance_encoding/encodings/physical/
block.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Encodings based on traditional block compression schemes
5//!
6//! Traditional compressors take in a buffer and return a smaller buffer.  All encoding
7//! description is shoved into the compressed buffer and the entire buffer is needed to
8//! decompress any of the data.
9//!
10//! These encodings are not transparent, which limits our ability to use them.  In addition
11//! they are often quite expensive in CPU terms.
12//!
13//! However, they are effective and useful for some cases.  For example, when working with large
14//! variable length values (e.g. source code files) they can be very effective.
15//!
//! This module introduces the [`BufferCompressor`] trait, which describes the interface for a
//! traditional block compressor.  It is implemented for the most common compression schemes
//! (zstd, lz4, etc.).
19//!
//! There is not yet a mini-block variant of this compressor (though one could easily be added),
//! and the full-zip variant works by applying compression on a per-value basis (which allows it
//! to be transparent).
23
24use arrow_buffer::ArrowNativeType;
25use lance_core::{Error, Result};
26
27use std::str::FromStr;
28
29use crate::compression::{BlockCompressor, BlockDecompressor};
30use crate::encodings::physical::binary::{BinaryBlockDecompressor, VariableEncoder};
31use crate::format::{
32    ProtobufUtils21,
33    pb21::{self, CompressiveEncoding},
34};
35use crate::{
36    buffer::LanceBuffer,
37    compression::VariablePerValueDecompressor,
38    data::{BlockInfo, DataBlock, VariableWidthBlock},
39    encodings::logical::primitive::fullzip::{PerValueCompressor, PerValueDataBlock},
40};
41
42#[derive(Debug, Clone, Copy, PartialEq)]
43pub struct CompressionConfig {
44    pub(crate) scheme: CompressionScheme,
45    pub(crate) level: Option<i32>,
46}
47
48impl CompressionConfig {
49    pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
50        Self { scheme, level }
51    }
52}
53
54impl Default for CompressionConfig {
55    fn default() -> Self {
56        Self {
57            scheme: CompressionScheme::Lz4,
58            level: Some(0),
59        }
60    }
61}
62
/// The general-purpose compression algorithms known to this module.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompressionScheme {
    /// No compression: buffers are passed through unchanged.
    None,
    /// FSST string compression. It has its own encoding path and cannot be used as a general
    /// buffer compressor.
    Fsst,
    /// Zstandard compression.
    Zstd,
    /// LZ4 block compression.
    Lz4,
}
70
71impl TryFrom<CompressionScheme> for pb21::CompressionScheme {
72    type Error = Error;
73
74    fn try_from(scheme: CompressionScheme) -> Result<Self> {
75        match scheme {
76            CompressionScheme::Lz4 => Ok(Self::CompressionAlgorithmLz4),
77            CompressionScheme::Zstd => Ok(Self::CompressionAlgorithmZstd),
78            _ => Err(Error::invalid_input(format!(
79                "Unsupported compression scheme: {:?}",
80                scheme
81            ))),
82        }
83    }
84}
85
86impl TryFrom<pb21::CompressionScheme> for CompressionScheme {
87    type Error = Error;
88
89    fn try_from(scheme: pb21::CompressionScheme) -> Result<Self> {
90        match scheme {
91            pb21::CompressionScheme::CompressionAlgorithmLz4 => Ok(Self::Lz4),
92            pb21::CompressionScheme::CompressionAlgorithmZstd => Ok(Self::Zstd),
93            _ => Err(Error::invalid_input(format!(
94                "Unsupported compression scheme: {:?}",
95                scheme
96            ))),
97        }
98    }
99}
100
101impl std::fmt::Display for CompressionScheme {
102    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
103        let scheme_str = match self {
104            Self::Fsst => "fsst",
105            Self::Zstd => "zstd",
106            Self::None => "none",
107            Self::Lz4 => "lz4",
108        };
109        write!(f, "{}", scheme_str)
110    }
111}
112
113impl FromStr for CompressionScheme {
114    type Err = Error;
115
116    fn from_str(s: &str) -> Result<Self> {
117        match s {
118            "none" => Ok(Self::None),
119            "fsst" => Ok(Self::Fsst),
120            "zstd" => Ok(Self::Zstd),
121            "lz4" => Ok(Self::Lz4),
122            _ => Err(Error::invalid_input(format!(
123                "Unknown compression scheme: {}",
124                s
125            ))),
126        }
127    }
128}
129
/// A traditional whole-buffer compressor (see the module documentation).
///
/// The implementations in this module append to `output_buf` rather than replacing its
/// contents, which lets callers pack many compressed units into one buffer.
pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
    /// Compresses all of `input_buf`, appending the compressed bytes to `output_buf`.
    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
    /// Decompresses one compressed unit (typically the output of a single `compress` call),
    /// appending the decompressed bytes to `output_buf`.
    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
    /// Returns the scheme (and level, if any) this compressor reports for itself.
    fn config(&self) -> CompressionConfig;
}
135
#[cfg(feature = "zstd")]
mod zstd {
    use std::io::Cursor;
    use std::sync::{Mutex, OnceLock};

    use super::*;

    use ::zstd::bulk::{Compressor, decompress_to_buffer};
    use ::zstd::stream::copy_decode;

    /// Size of the little-endian `u64` uncompressed-length prefix that
    /// [`ZstdBufferCompressor::compress`] writes in front of each compressed frame.
    const LENGTH_PREFIX_SIZE: usize = 8;

    /// A zstd buffer compressor that lazily creates and reuses compression contexts.
    ///
    /// The compression context is cached to enable reuse across chunks within a
    /// page. It is lazily initialized to prevent it from getting initialized on
    /// decode-only codepaths.
    ///
    /// Reuse is not implemented for decompression, only for compression:
    /// * The single-threaded benefit of reuse was negligible when measured.
    /// * Decompressors can get shared across threads, leading to mutex
    ///   contention if the same strategy is used as for compression here. This
    ///   should be mitigable with pooling but we can skip the complexity until a
    ///   need is demonstrated. The multithreaded decode benchmark effectively
    ///   demonstrates this scenario.
    ///
    /// Output format: an 8-byte little-endian uncompressed length followed by a zstd frame.
    /// `decompress` also accepts bare zstd frames (no length prefix).
    pub struct ZstdBufferCompressor {
        compression_level: i32,
        compressor: OnceLock<std::result::Result<Mutex<Compressor<'static>>, String>>,
    }

    impl std::fmt::Debug for ZstdBufferCompressor {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("ZstdBufferCompressor")
                .field("compression_level", &self.compression_level)
                .finish()
        }
    }

    impl ZstdBufferCompressor {
        /// Creates a compressor for `compression_level`. The zstd context is created on first
        /// use.
        pub fn new(compression_level: i32) -> Self {
            Self {
                compression_level,
                compressor: OnceLock::new(),
            }
        }

        /// Returns the cached compression context, creating it on the first call.
        ///
        /// A creation failure is cached as well, so later calls report the same error.
        fn get_compressor(&self) -> Result<&Mutex<Compressor<'static>>> {
            self.compressor
                .get_or_init(|| {
                    Compressor::new(self.compression_level)
                        .map(Mutex::new)
                        .map_err(|e| e.to_string())
                })
                .as_ref()
                .map_err(|e| Error::internal(format!("Failed to create zstd compressor: {}", e)))
        }

        /// Compresses `input_buf` into `dest` using the cached context and returns the number
        /// of bytes written.
        ///
        /// A poisoned context mutex is reported as an error instead of panicking.
        fn compress_into(&self, input_buf: &[u8], dest: &mut [u8]) -> Result<usize> {
            let mut compressor = self
                .get_compressor()?
                .lock()
                .map_err(|_| Error::internal("Zstd compressor mutex is poisoned".to_string()))?;
            compressor
                .compress_to_buffer(input_buf, dest)
                .map_err(|e| Error::internal(format!("Zstd compression error: {}", e)))
        }

        /// Distinguishes a bare zstd frame from the length-prefixed layout written by
        /// `compress`.
        ///
        /// Buffers shorter than the 8-byte length prefix are treated as raw frames. Longer
        /// buffers count as raw only if they start with the zstd magic number *and* the Frame
        /// Header Descriptor's Reserved_Bit is clear. This is a heuristic: a length prefix whose
        /// bytes happen to satisfy both checks is misclassified as a raw frame.
        ///
        /// See <https://datatracker.ietf.org/doc/html/rfc8878>.
        fn is_raw_stream_format(&self, input_buf: &[u8]) -> bool {
            // RFC 8878, section 3.1.1: every Zstandard frame starts with this magic number.
            const ZSTD_MAGIC_NUMBER: u32 = 0xFD2FB528;
            // RFC 8878, section 3.1.1.1.1: the Frame_Header_Descriptor byte follows the magic
            // number. Bit 3 is the Reserved_Bit, which MUST be 0 in a valid frame. Bit 4 is
            // the Unused_Bit, which decoders must not interpret, so it is not checked here.
            const FHD_BYTE_INDEX: usize = 4;
            const FHD_RESERVED_BIT_MASK: u8 = 0b0000_1000;

            if input_buf.len() < LENGTH_PREFIX_SIZE {
                // Too short to hold a length prefix, so it can only be a raw frame.
                return true;
            }
            let magic =
                u32::from_le_bytes([input_buf[0], input_buf[1], input_buf[2], input_buf[3]]);
            magic == ZSTD_MAGIC_NUMBER
                && (input_buf[FHD_BYTE_INDEX] & FHD_RESERVED_BIT_MASK) == 0
        }

        /// Decodes a bare zstd stream, appending the result to `output_buf`.
        ///
        /// On failure `output_buf` is restored to its original length.
        fn decompress_raw_stream(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            let start = output_buf.len();
            if let Err(err) = copy_decode(Cursor::new(input_buf), &mut *output_buf) {
                output_buf.truncate(start);
                return Err(err.into());
            }
            Ok(())
        }

        /// Decodes the `[u64 LE uncompressed length][zstd frame]` layout written by
        /// `compress`, appending the result to `output_buf`.
        ///
        /// The output is trimmed to the number of bytes actually produced. On failure
        /// `output_buf` is restored to its original length.
        fn decompress_length_prefixed_zstd(
            &self,
            input_buf: &[u8],
            output_buf: &mut Vec<u8>,
        ) -> Result<()> {
            let Some(len_prefix) = input_buf.get(..LENGTH_PREFIX_SIZE) else {
                return Err(Error::internal(
                    "Zstd length-prefixed data too short".to_string(),
                ));
            };
            let mut len_bytes = [0u8; LENGTH_PREFIX_SIZE];
            len_bytes.copy_from_slice(len_prefix);
            let uncompressed_len = usize::try_from(u64::from_le_bytes(len_bytes))
                .map_err(|_| {
                    Error::internal("Zstd uncompressed length does not fit in usize".to_string())
                })?;
            let compressed_data = &input_buf[LENGTH_PREFIX_SIZE..];

            let start = output_buf.len();
            let end = start.checked_add(uncompressed_len).ok_or_else(|| {
                Error::internal("Zstd uncompressed length overflows the output buffer".to_string())
            })?;
            output_buf.resize(end, 0);

            match decompress_to_buffer(compressed_data, &mut output_buf[start..]) {
                Ok(written) => {
                    // Never expose zero padding if the frame was shorter than its prefix claimed.
                    output_buf.truncate(start + written);
                    Ok(())
                }
                Err(err) => {
                    output_buf.truncate(start);
                    Err(err.into())
                }
            }
        }
    }

    impl BufferCompressor for ZstdBufferCompressor {
        /// Appends `[u64 LE input length][zstd frame]` to `output_buf`.
        ///
        /// On failure `output_buf` is restored to its original length.
        fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            let original_len = output_buf.len();
            output_buf.extend_from_slice(&(input_buf.len() as u64).to_le_bytes());

            let start_pos = output_buf.len();
            let max_compressed_size = ::zstd::zstd_safe::compress_bound(input_buf.len());
            output_buf.resize(start_pos + max_compressed_size, 0);

            match self.compress_into(input_buf, &mut output_buf[start_pos..]) {
                Ok(compressed_size) => {
                    output_buf.truncate(start_pos + compressed_size);
                    Ok(())
                }
                Err(err) => {
                    // Drop the length prefix and scratch space written above.
                    output_buf.truncate(original_len);
                    Err(err)
                }
            }
        }

        /// Decompresses either a length-prefixed frame written by `compress` or a bare zstd
        /// frame. Empty input decodes to nothing.
        fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            if input_buf.is_empty() {
                return Ok(());
            }

            if self.is_raw_stream_format(input_buf) {
                self.decompress_raw_stream(input_buf, output_buf)
            } else {
                self.decompress_length_prefixed_zstd(input_buf, output_buf)
            }
        }

        fn config(&self) -> CompressionConfig {
            CompressionConfig {
                scheme: CompressionScheme::Zstd,
                level: Some(self.compression_level),
            }
        }
    }
}
290
#[cfg(feature = "lz4")]
mod lz4 {
    use super::*;

    /// Size of the little-endian `u32` uncompressed-size header that the lz4 block API writes
    /// when `prepend_size` is `true`.
    const SIZE_PREFIX_LEN: usize = 4;

    /// An LZ4 block compressor.
    ///
    /// Each compressed unit is an lz4 block preceded by a 4-byte little-endian uncompressed
    /// size. The compression level is not configurable.
    #[derive(Debug, Default)]
    pub struct Lz4BufferCompressor {}

    impl BufferCompressor for Lz4BufferCompressor {
        /// Appends `[u32 LE input length][lz4 block]` to `output_buf`.
        ///
        /// On failure `output_buf` is restored to its original length.
        fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            let start_pos = output_buf.len();

            // Reserve the worst-case block size plus room for the size header.
            let max_size = ::lz4::block::compress_bound(input_buf.len())?;
            output_buf.resize(start_pos + max_size + SIZE_PREFIX_LEN, 0);

            match ::lz4::block::compress_to_buffer(
                input_buf,
                None,
                true,
                &mut output_buf[start_pos..],
            ) {
                Ok(compressed_size) => {
                    // The returned size already covers the prepended size header.
                    output_buf.truncate(start_pos + compressed_size);
                    Ok(())
                }
                Err(err) => {
                    output_buf.truncate(start_pos);
                    Err(Error::internal(format!("LZ4 compression error: {}", err)))
                }
            }
        }

        /// Decompresses one unit written by `compress`, appending the result to `output_buf`.
        ///
        /// Empty input decodes to nothing, matching the zstd compressor. Input with a partial
        /// size header is an error. On failure `output_buf` is restored to its original length.
        fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            if input_buf.is_empty() {
                return Ok(());
            }

            let Some(header) = input_buf.get(..SIZE_PREFIX_LEN) else {
                return Err(Error::internal("LZ4 compressed data too short".to_string()));
            };
            // With `prepend_size`, the first 4 bytes hold the uncompressed size (little-endian).
            let uncompressed_size =
                u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;

            let start_pos = output_buf.len();
            output_buf.resize(start_pos + uncompressed_size, 0);

            // Passing `None` makes the lz4 block API read (and skip) the size header itself.
            match ::lz4::block::decompress_to_buffer(input_buf, None, &mut output_buf[start_pos..])
            {
                Ok(decompressed_size) => {
                    // Should equal `uncompressed_size`; trim defensively anyway.
                    output_buf.truncate(start_pos + decompressed_size);
                    Ok(())
                }
                Err(err) => {
                    output_buf.truncate(start_pos);
                    Err(Error::internal(format!("LZ4 decompression error: {}", err)))
                }
            }
        }

        fn config(&self) -> CompressionConfig {
            CompressionConfig {
                scheme: CompressionScheme::Lz4,
                level: None,
            }
        }
    }
}
358
359#[derive(Debug, Default)]
360pub struct NoopBufferCompressor {}
361
362impl BufferCompressor for NoopBufferCompressor {
363    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
364        output_buf.extend_from_slice(input_buf);
365        Ok(())
366    }
367
368    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
369        output_buf.extend_from_slice(input_buf);
370        Ok(())
371    }
372
373    fn config(&self) -> CompressionConfig {
374        CompressionConfig {
375            scheme: CompressionScheme::None,
376            level: None,
377        }
378    }
379}
380
381pub struct GeneralBufferCompressor {}
382
383impl GeneralBufferCompressor {
384    pub fn get_compressor(
385        compression_config: CompressionConfig,
386    ) -> Result<Box<dyn BufferCompressor>> {
387        match compression_config.scheme {
388            // FSST has its own compression path and isn't implemented as a generic buffer compressor
389            CompressionScheme::Fsst => Err(Error::invalid_input_source(
390                "fsst is not usable as a general buffer compressor".into(),
391            )),
392            CompressionScheme::Zstd => {
393                #[cfg(feature = "zstd")]
394                {
395                    Ok(Box::new(zstd::ZstdBufferCompressor::new(
396                        compression_config.level.unwrap_or(0),
397                    )))
398                }
399                #[cfg(not(feature = "zstd"))]
400                {
401                    Err(Error::invalid_input_source(
402                        "package was not built with zstd support".into(),
403                    ))
404                }
405            }
406            CompressionScheme::Lz4 => {
407                #[cfg(feature = "lz4")]
408                {
409                    Ok(Box::new(lz4::Lz4BufferCompressor::default()))
410                }
411                #[cfg(not(feature = "lz4"))]
412                {
413                    Err(Error::invalid_input_source(
414                        "package was not built with lz4 support".into(),
415                    ))
416                }
417            }
418            CompressionScheme::None => Ok(Box::new(NoopBufferCompressor {})),
419        }
420    }
421}
422
423/// A block decompressor that first applies general-purpose compression (LZ4/Zstd)
424/// before delegating to an inner block decompressor.
425#[derive(Debug)]
426pub struct GeneralBlockDecompressor {
427    inner: Box<dyn BlockDecompressor>,
428    compressor: Box<dyn BufferCompressor>,
429}
430
431impl GeneralBlockDecompressor {
432    pub fn try_new(
433        inner: Box<dyn BlockDecompressor>,
434        compression: CompressionConfig,
435    ) -> Result<Self> {
436        let compressor = GeneralBufferCompressor::get_compressor(compression)?;
437        Ok(Self { inner, compressor })
438    }
439}
440
441impl BlockDecompressor for GeneralBlockDecompressor {
442    fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result<DataBlock> {
443        let mut decompressed = Vec::new();
444        self.compressor.decompress(&data, &mut decompressed)?;
445        self.inner
446            .decompress(LanceBuffer::from(decompressed), num_values)
447    }
448}
449
450// An encoder which uses generic compression, such as zstd/lz4 to encode buffers
451#[derive(Debug)]
452pub struct CompressedBufferEncoder {
453    pub(crate) compressor: Box<dyn BufferCompressor>,
454}
455
456impl Default for CompressedBufferEncoder {
457    fn default() -> Self {
458        // Pick zstd if available, otherwise lz4, otherwise none
459        #[cfg(feature = "zstd")]
460        let (scheme, level) = (CompressionScheme::Zstd, Some(0));
461        #[cfg(all(feature = "lz4", not(feature = "zstd")))]
462        let (scheme, level) = (CompressionScheme::Lz4, None);
463        #[cfg(not(any(feature = "zstd", feature = "lz4")))]
464        let (scheme, level) = (CompressionScheme::None, None);
465
466        let compressor =
467            GeneralBufferCompressor::get_compressor(CompressionConfig { scheme, level }).unwrap();
468        Self { compressor }
469    }
470}
471
472impl CompressedBufferEncoder {
473    pub fn try_new(compression_config: CompressionConfig) -> Result<Self> {
474        let compressor = GeneralBufferCompressor::get_compressor(compression_config)?;
475        Ok(Self { compressor })
476    }
477
478    pub fn from_scheme(scheme: pb21::CompressionScheme) -> Result<Self> {
479        let scheme = CompressionScheme::try_from(scheme)?;
480        Ok(Self {
481            compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
482                scheme,
483                level: Some(0),
484            })?,
485        })
486    }
487}
488
489impl CompressedBufferEncoder {
490    pub fn per_value_compress<T: ArrowNativeType>(
491        &self,
492        data: &[u8],
493        offsets: &[T],
494        compressed: &mut Vec<u8>,
495    ) -> Result<LanceBuffer> {
496        let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
497        new_offsets.push(T::from_usize(0).unwrap());
498
499        for off in offsets.windows(2) {
500            let start = off[0].as_usize();
501            let end = off[1].as_usize();
502            self.compressor.compress(&data[start..end], compressed)?;
503            new_offsets.push(T::from_usize(compressed.len()).unwrap());
504        }
505
506        Ok(LanceBuffer::reinterpret_vec(new_offsets))
507    }
508
509    pub fn per_value_decompress<T: ArrowNativeType>(
510        &self,
511        data: &[u8],
512        offsets: &[T],
513        decompressed: &mut Vec<u8>,
514    ) -> Result<LanceBuffer> {
515        let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
516        new_offsets.push(T::from_usize(0).unwrap());
517
518        for off in offsets.windows(2) {
519            let start = off[0].as_usize();
520            let end = off[1].as_usize();
521            self.compressor
522                .decompress(&data[start..end], decompressed)?;
523            new_offsets.push(T::from_usize(decompressed.len()).unwrap());
524        }
525
526        Ok(LanceBuffer::reinterpret_vec(new_offsets))
527    }
528}
529
530impl PerValueCompressor for CompressedBufferEncoder {
531    fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, CompressiveEncoding)> {
532        let data_type = data.name();
533        let data = data.as_variable_width().ok_or(Error::internal(format!(
534            "Attempt to use CompressedBufferEncoder on data of type {}",
535            data_type
536        )))?;
537
538        let data_bytes = &data.data;
539        let mut compressed = Vec::with_capacity(data_bytes.len());
540
541        let new_offsets = match data.bits_per_offset {
542            32 => self.per_value_compress::<u32>(
543                data_bytes,
544                &data.offsets.borrow_to_typed_slice::<u32>(),
545                &mut compressed,
546            )?,
547            64 => self.per_value_compress::<u64>(
548                data_bytes,
549                &data.offsets.borrow_to_typed_slice::<u64>(),
550                &mut compressed,
551            )?,
552            _ => unreachable!(),
553        };
554
555        let compressed = PerValueDataBlock::Variable(VariableWidthBlock {
556            bits_per_offset: data.bits_per_offset,
557            data: LanceBuffer::from(compressed),
558            offsets: new_offsets,
559            num_values: data.num_values,
560            block_info: BlockInfo::new(),
561        });
562
563        // TODO: Support setting the level
564        // TODO: Support underlying compression of data (e.g. defer to binary encoding for offset bitpacking)
565        let encoding = ProtobufUtils21::wrapped(
566            self.compressor.config(),
567            ProtobufUtils21::variable(
568                ProtobufUtils21::flat(data.bits_per_offset as u64, None),
569                None,
570            ),
571        )?;
572
573        Ok((compressed, encoding))
574    }
575}
576
577impl VariablePerValueDecompressor for CompressedBufferEncoder {
578    fn decompress(&self, data: VariableWidthBlock) -> Result<DataBlock> {
579        let data_bytes = &data.data;
580        let mut decompressed = Vec::with_capacity(data_bytes.len() * 2);
581
582        let new_offsets = match data.bits_per_offset {
583            32 => self.per_value_decompress(
584                data_bytes,
585                &data.offsets.borrow_to_typed_slice::<u32>(),
586                &mut decompressed,
587            )?,
588            64 => self.per_value_decompress(
589                data_bytes,
590                &data.offsets.borrow_to_typed_slice::<u64>(),
591                &mut decompressed,
592            )?,
593            _ => unreachable!(),
594        };
595        Ok(DataBlock::VariableWidth(VariableWidthBlock {
596            bits_per_offset: data.bits_per_offset,
597            data: LanceBuffer::from(decompressed),
598            offsets: new_offsets,
599            num_values: data.num_values,
600            block_info: BlockInfo::new(),
601        }))
602    }
603}
604
605impl BlockCompressor for CompressedBufferEncoder {
606    fn compress(&self, data: DataBlock) -> Result<LanceBuffer> {
607        let encoded = match data {
608            DataBlock::FixedWidth(fixed_width) => fixed_width.data,
609            DataBlock::VariableWidth(variable_width) => {
610                // Wrap VariableEncoder to handle the encoding
611                let encoder = VariableEncoder::default();
612                BlockCompressor::compress(&encoder, DataBlock::VariableWidth(variable_width))?
613            }
614            _ => {
615                return Err(Error::invalid_input_source(
616                    "Unsupported data block type".into(),
617                ));
618            }
619        };
620
621        let mut compressed = Vec::new();
622        self.compressor.compress(&encoded, &mut compressed)?;
623        Ok(LanceBuffer::from(compressed))
624    }
625}
626
627impl BlockDecompressor for CompressedBufferEncoder {
628    fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result<DataBlock> {
629        let mut decompressed = Vec::new();
630        self.compressor.decompress(&data, &mut decompressed)?;
631
632        // Delegate to BinaryBlockDecompressor which handles the inline metadata
633        let inner_decoder = BinaryBlockDecompressor::default();
634        inner_decoder.decompress(LanceBuffer::from(decompressed), num_values)
635    }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_compression_scheme_from_str() {
        assert_eq!(
            CompressionScheme::from_str("none").unwrap(),
            CompressionScheme::None
        );
        assert_eq!(
            CompressionScheme::from_str("zstd").unwrap(),
            CompressionScheme::Zstd
        );
        assert_eq!(
            CompressionScheme::from_str("lz4").unwrap(),
            CompressionScheme::Lz4
        );
        assert_eq!(
            CompressionScheme::from_str("fsst").unwrap(),
            CompressionScheme::Fsst
        );
    }

    #[test]
    fn test_compression_scheme_from_str_invalid() {
        assert!(CompressionScheme::from_str("invalid").is_err());
    }

    #[test]
    fn test_compression_scheme_display_round_trip() {
        for scheme in [
            CompressionScheme::None,
            CompressionScheme::Fsst,
            CompressionScheme::Zstd,
            CompressionScheme::Lz4,
        ] {
            assert_eq!(
                CompressionScheme::from_str(&scheme.to_string()).unwrap(),
                scheme
            );
        }
    }

    #[cfg(feature = "zstd")]
    mod zstd {
        use std::io::Write;

        use super::*;

        // Imported here rather than at the top of `tests`: the zstd compressor only exists
        // when the `zstd` feature is enabled.
        use crate::encodings::physical::block::zstd::ZstdBufferCompressor;

        #[test]
        fn test_compress_zstd_with_length_prefixed() {
            let compressor = ZstdBufferCompressor::new(0);
            let input_data = b"Hello, world!";
            let mut compressed_data = Vec::new();

            compressor
                .compress(input_data, &mut compressed_data)
                .unwrap();
            let mut decompressed_data = Vec::new();
            compressor
                .decompress(&compressed_data, &mut decompressed_data)
                .unwrap();
            assert_eq!(input_data, decompressed_data.as_slice());
        }

        #[test]
        fn test_zstd_compress_decompress_multiple_times() {
            let compressor = ZstdBufferCompressor::new(0);
            let (input_data_1, input_data_2) = (b"Hello ", b"World");
            let mut compressed_data = Vec::new();

            compressor
                .compress(input_data_1, &mut compressed_data)
                .unwrap();
            let compressed_length_1 = compressed_data.len();

            compressor
                .compress(input_data_2, &mut compressed_data)
                .unwrap();

            let mut decompressed_data = Vec::new();
            compressor
                .decompress(
                    &compressed_data[..compressed_length_1],
                    &mut decompressed_data,
                )
                .unwrap();

            compressor
                .decompress(
                    &compressed_data[compressed_length_1..],
                    &mut decompressed_data,
                )
                .unwrap();

            // the output should contain both input_data_1 and input_data_2
            assert_eq!(
                decompressed_data.len(),
                input_data_1.len() + input_data_2.len()
            );
            assert_eq!(
                &decompressed_data[..input_data_1.len()],
                input_data_1,
                "First part of decompressed data should match input_1"
            );
            assert_eq!(
                &decompressed_data[input_data_1.len()..],
                input_data_2,
                "Second part of decompressed data should match input_2"
            );
        }

        #[test]
        fn test_compress_zstd_raw_stream_format_and_decompress_with_length_prefixed() {
            let compressor = ZstdBufferCompressor::new(0);
            let input_data = b"Hello, world!";
            let mut compressed_data = Vec::new();

            // compress using raw stream format
            let mut encoder = ::zstd::Encoder::new(&mut compressed_data, 0).unwrap();
            encoder.write_all(input_data).unwrap();
            encoder.finish().expect("failed to encode data with zstd");

            // decompress using length prefixed format
            let mut decompressed_data = Vec::new();
            compressor
                .decompress(&compressed_data, &mut decompressed_data)
                .unwrap();
            assert_eq!(input_data, decompressed_data.as_slice());
        }
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use std::{collections::HashMap, sync::Arc};

        use arrow_schema::{DataType, Field};
        use lance_datagen::array::{binary_prefix_plus_counter, utf8_prefix_plus_counter};

        use super::*;

        use crate::constants::DICT_SIZE_RATIO_META_KEY;
        use crate::{
            constants::{
                COMPRESSION_META_KEY, DICT_DIVISOR_META_KEY, STRUCTURAL_ENCODING_FULLZIP,
                STRUCTURAL_ENCODING_META_KEY,
            },
            encodings::physical::block::lz4::Lz4BufferCompressor,
            testing::{FnArrayGeneratorProvider, TestCases, check_round_trip_encoding_generated},
            version::LanceFileVersion,
        };

        #[test]
        fn test_lz4_compress_decompress() {
            let compressor = Lz4BufferCompressor::default();
            let input_data = b"Hello, world!";
            let mut compressed_data = Vec::new();

            compressor
                .compress(input_data, &mut compressed_data)
                .unwrap();
            let mut decompressed_data = Vec::new();
            compressor
                .decompress(&compressed_data, &mut decompressed_data)
                .unwrap();
            assert_eq!(input_data, decompressed_data.as_slice());
        }

        #[test_log::test(tokio::test)]
        async fn test_lz4_compress_round_trip() {
            for data_type in &[
                DataType::Utf8,
                DataType::LargeUtf8,
                DataType::Binary,
                DataType::LargeBinary,
            ] {
                let field = Field::new("", data_type.clone(), false);
                let mut field_meta = HashMap::new();
                field_meta.insert(COMPRESSION_META_KEY.to_string(), "lz4".to_string());
                // Some bad cardinality estimation causes us to use dictionary encoding currently
                // which causes the expected encoding check to fail.
                field_meta.insert(DICT_DIVISOR_META_KEY.to_string(), "100000".to_string());
                field_meta.insert(DICT_SIZE_RATIO_META_KEY.to_string(), "0.0001".to_string());
                // Also disable size-based dictionary encoding
                field_meta.insert(
                    STRUCTURAL_ENCODING_META_KEY.to_string(),
                    STRUCTURAL_ENCODING_FULLZIP.to_string(),
                );
                let field = field.with_metadata(field_meta);
                // NOTE(review): this LZ4 test expects the "zstd" encoding name; confirm whether
                // that is how the test harness labels general compression or "lz4" was intended.
                let test_cases = TestCases::basic()
                    // Need to use large pages as small pages might be too small to compress
                    .with_page_sizes(vec![1024 * 1024])
                    .with_expected_encoding("zstd")
                    .with_min_file_version(LanceFileVersion::V2_1);

                // Can't use the default random provider because random data isn't compressible
                // and we will fallback to uncompressed encoding
                let datagen = Box::new(FnArrayGeneratorProvider::new(move || match data_type {
                    DataType::Utf8 => utf8_prefix_plus_counter("compressme", false),
                    DataType::Binary => {
                        binary_prefix_plus_counter(Arc::from(b"compressme".to_owned()), false)
                    }
                    DataType::LargeUtf8 => utf8_prefix_plus_counter("compressme", true),
                    DataType::LargeBinary => {
                        binary_prefix_plus_counter(Arc::from(b"compressme".to_owned()), true)
                    }
                    _ => panic!("Unsupported data type: {:?}", data_type),
                }));

                check_round_trip_encoding_generated(field, datagen, test_cases).await;
            }
        }
    }
}