Skip to main content

lance_encoding/encodings/physical/
block.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Encodings based on traditional block compression schemes
5//!
6//! Traditional compressors take in a buffer and return a smaller buffer.  All encoding
7//! description is shoved into the compressed buffer and the entire buffer is needed to
8//! decompress any of the data.
9//!
10//! These encodings are not transparent, which limits our ability to use them.  In addition
11//! they are often quite expensive in CPU terms.
12//!
13//! However, they are effective and useful for some cases.  For example, when working with large
14//! variable length values (e.g. source code files) they can be very effective.
15//!
//! The module introduces the [`BufferCompressor`] trait which describes the interface for a
//! traditional block compressor.  It is implemented for the most common compression schemes
//! (zstd, lz4, etc).
19//!
20//! There is not yet a mini-block variant of this compressor (but could easily be one) and the
21//! full zip variant works by applying compression on a per-value basis (which allows it to be
22//! transparent).
23
24use arrow_buffer::ArrowNativeType;
25use lance_core::{Error, Result};
26use snafu::location;
27
28use std::str::FromStr;
29
30use crate::compression::{BlockCompressor, BlockDecompressor};
31use crate::encodings::physical::binary::{BinaryBlockDecompressor, VariableEncoder};
32use crate::format::{
33    pb21::{self, CompressiveEncoding},
34    ProtobufUtils21,
35};
36use crate::{
37    buffer::LanceBuffer,
38    compression::VariablePerValueDecompressor,
39    data::{BlockInfo, DataBlock, VariableWidthBlock},
40    encodings::logical::primitive::fullzip::{PerValueCompressor, PerValueDataBlock},
41};
42
/// Configuration for a general-purpose buffer compressor.
///
/// Pairs a [`CompressionScheme`] with an optional compression level.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CompressionConfig {
    /// The compression algorithm to use.
    pub(crate) scheme: CompressionScheme,
    /// The requested compression level, or `None` if no level was requested.
    pub(crate) level: Option<i32>,
}
48
impl CompressionConfig {
    /// Creates a configuration from a scheme and an optional compression level.
    pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
        Self { scheme, level }
    }
}
54
55impl Default for CompressionConfig {
56    fn default() -> Self {
57        Self {
58            scheme: CompressionScheme::Lz4,
59            level: Some(0),
60        }
61    }
62}
63
/// The compression algorithms that can be named in a [`CompressionConfig`].
///
/// The textual names used by `Display` / `FromStr` are `"none"`, `"fsst"`,
/// `"zstd"` and `"lz4"`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompressionScheme {
    /// No compression; data is copied through unchanged.
    None,
    /// FSST; it has its own compression path and cannot be used as a general
    /// buffer compressor.
    Fsst,
    /// Zstandard block compression.
    Zstd,
    /// LZ4 block compression.
    Lz4,
}
71
72impl TryFrom<CompressionScheme> for pb21::CompressionScheme {
73    type Error = Error;
74
75    fn try_from(scheme: CompressionScheme) -> Result<Self> {
76        match scheme {
77            CompressionScheme::Lz4 => Ok(Self::CompressionAlgorithmLz4),
78            CompressionScheme::Zstd => Ok(Self::CompressionAlgorithmZstd),
79            _ => Err(Error::invalid_input(
80                format!("Unsupported compression scheme: {:?}", scheme),
81                location!(),
82            )),
83        }
84    }
85}
86
87impl TryFrom<pb21::CompressionScheme> for CompressionScheme {
88    type Error = Error;
89
90    fn try_from(scheme: pb21::CompressionScheme) -> Result<Self> {
91        match scheme {
92            pb21::CompressionScheme::CompressionAlgorithmLz4 => Ok(Self::Lz4),
93            pb21::CompressionScheme::CompressionAlgorithmZstd => Ok(Self::Zstd),
94            _ => Err(Error::invalid_input(
95                format!("Unsupported compression scheme: {:?}", scheme),
96                location!(),
97            )),
98        }
99    }
100}
101
102impl std::fmt::Display for CompressionScheme {
103    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
104        let scheme_str = match self {
105            Self::Fsst => "fsst",
106            Self::Zstd => "zstd",
107            Self::None => "none",
108            Self::Lz4 => "lz4",
109        };
110        write!(f, "{}", scheme_str)
111    }
112}
113
114impl FromStr for CompressionScheme {
115    type Err = Error;
116
117    fn from_str(s: &str) -> Result<Self> {
118        match s {
119            "none" => Ok(Self::None),
120            "fsst" => Ok(Self::Fsst),
121            "zstd" => Ok(Self::Zstd),
122            "lz4" => Ok(Self::Lz4),
123            _ => Err(Error::invalid_input(
124                format!("Unknown compression scheme: {}", s),
125                location!(),
126            )),
127        }
128    }
129}
130
/// Interface for a traditional block compressor that works on whole byte buffers.
///
/// The implementations in this file append their result to `output_buf`
/// rather than overwriting it, so a single output buffer can accumulate the
/// result of several calls.
pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
    /// Compresses `input_buf`, appending the compressed bytes to `output_buf`.
    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
    /// Decompresses `input_buf`, appending the decompressed bytes to `output_buf`.
    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
    /// Returns the scheme and level this compressor was configured with.
    fn config(&self) -> CompressionConfig;
}
136
#[cfg(feature = "zstd")]
mod zstd {
    use std::io::Cursor;
    use std::sync::{Mutex, OnceLock};

    use super::*;

    use ::zstd::bulk::{decompress_to_buffer, Compressor};
    use ::zstd::stream::copy_decode;

    /// Size in bytes of the little-endian `u64` uncompressed-length prefix written
    /// in front of every buffer produced by [`ZstdBufferCompressor::compress`].
    const LENGTH_PREFIX_SIZE: usize = 8;

    /// A zstd buffer compressor that lazily creates and reuses compression contexts.
    ///
    /// The compression context is cached to enable reuse across chunks within a
    /// page. It is lazily initialized to prevent it from getting initialized on
    /// decode-only codepaths.
    ///
    /// Reuse is not implemented for decompression, only for compression:
    /// * The single-threaded benefit of reuse was negligible when measured.
    /// * Decompressors can get shared across threads, leading to mutex
    ///   contention if the same strategy is used as for compression here. This
    ///   should be mitigable with pooling but we can skip the complexity until a
    ///   need is demonstrated. The multithreaded decode benchmark effectively
    ///   demonstrates this scenario.
    pub struct ZstdBufferCompressor {
        compression_level: i32,
        // The creation error is stored as a `String` so that a failed
        // initialization can be reported again on every later call.
        compressor: OnceLock<std::result::Result<Mutex<Compressor<'static>>, String>>,
    }

    impl std::fmt::Debug for ZstdBufferCompressor {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("ZstdBufferCompressor")
                .field("compression_level", &self.compression_level)
                .finish()
        }
    }

    impl ZstdBufferCompressor {
        /// Creates a compressor for the given level; no zstd context is
        /// allocated until the first call to `compress`.
        pub fn new(compression_level: i32) -> Self {
            Self {
                compression_level,
                compressor: OnceLock::new(),
            }
        }

        /// Returns the cached compression context, creating it on first use.
        fn get_compressor(&self) -> Result<&Mutex<Compressor<'static>>> {
            self.compressor
                .get_or_init(|| {
                    Compressor::new(self.compression_level)
                        .map(Mutex::new)
                        .map_err(|e| e.to_string())
                })
                .as_ref()
                .map_err(|e| Error::Internal {
                    message: format!("Failed to create zstd compressor: {}", e),
                    location: location!(),
                })
        }

        /// Decides whether `input_buf` is a bare zstd frame (`true`) or the
        /// length-prefixed format written by `compress` (`false`).
        ///
        /// See <https://datatracker.ietf.org/doc/html/rfc8878>.
        fn is_raw_stream_format(&self, input_buf: &[u8]) -> bool {
            // See RFC 8878, section 3.1.1. Zstandard Frames, which defines the magic number.
            const ZSTD_MAGIC_NUMBER: u32 = 0xFD2FB528;
            // The frame header descriptor (FHD) is the byte right after the magic number.
            const FHD_BYTE_INDEX: usize = 4;
            // Per RFC 8878, section 3.1.1.1.1.4 'Reserved Bit', bit number 3 of the FHD
            // MUST be 0 in a valid frame (bit number 4 is the separate 'Unused Bit').
            const FHD_RESERVED_BIT_MASK: u8 = 0b0000_1000;

            if input_buf.len() < LENGTH_PREFIX_SIZE {
                // Can't be the length prefixed format if it is shorter than the prefix.
                return true;
            }

            // Read the first 4 bytes as the magic number.
            let mut magic_buf = [0u8; 4];
            magic_buf.copy_from_slice(&input_buf[..4]);
            if u32::from_le_bytes(magic_buf) != ZSTD_MAGIC_NUMBER {
                // Doesn't start with the magic number, so it can't be the raw stream format.
                return false;
            }

            // The buffer starts like a zstd frame. If the reserved bit is set it is NOT a
            // valid frame, so it must be the length prefixed format where the length
            // coincidentally started with the magic number. Otherwise treat it as a raw
            // stream.
            (input_buf[FHD_BYTE_INDEX] & FHD_RESERVED_BIT_MASK) == 0
        }

        /// Decompresses a buffer in the `[u64 LE uncompressed length][zstd data]`
        /// format, appending the result to `output_buf`.
        ///
        /// # Errors
        ///
        /// Returns an error if the prefix is missing, the declared length does not
        /// fit in memory addressing, or zstd fails to decode. On error `output_buf`
        /// is restored to the length it had on entry.
        fn decompress_length_prefixed_zstd(
            &self,
            input_buf: &[u8],
            output_buf: &mut Vec<u8>,
        ) -> Result<()> {
            if input_buf.len() < LENGTH_PREFIX_SIZE {
                return Err(Error::Internal {
                    message: "Zstd compressed data too short for length prefix".to_string(),
                    location: location!(),
                });
            }
            let mut len_buf = [0u8; LENGTH_PREFIX_SIZE];
            len_buf.copy_from_slice(&input_buf[..LENGTH_PREFIX_SIZE]);
            let declared_len = u64::from_le_bytes(len_buf);

            // Avoid silently truncating the declared length on targets where
            // `usize` is narrower than 64 bits, and avoid overflowing the new length.
            let start = output_buf.len();
            let end = usize::try_from(declared_len)
                .ok()
                .and_then(|len| start.checked_add(len))
                .ok_or_else(|| Error::Internal {
                    message: format!(
                        "Zstd uncompressed length {} is too large for this platform",
                        declared_len
                    ),
                    location: location!(),
                })?;
            output_buf.resize(end, 0);

            let compressed_data = &input_buf[LENGTH_PREFIX_SIZE..];
            match decompress_to_buffer(compressed_data, &mut output_buf[start..]) {
                Ok(written) => {
                    // Keep only the bytes zstd actually produced so that a wrong
                    // prefix cannot leave zero padding in the output.
                    output_buf.truncate(start + written);
                    Ok(())
                }
                Err(e) => {
                    output_buf.truncate(start);
                    Err(e.into())
                }
            }
        }
    }

    impl BufferCompressor for ZstdBufferCompressor {
        /// Appends `[u64 LE uncompressed length][zstd data]` to `output_buf`.
        ///
        /// On error `output_buf` is restored to the length it had on entry.
        fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            // Acquire the context before touching the output so that a failure here
            // leaves `output_buf` untouched.
            let mut compressor = self.get_compressor()?.lock().map_err(|e| Error::Internal {
                message: format!("Zstd compressor lock poisoned: {}", e),
                location: location!(),
            })?;

            let original_len = output_buf.len();
            output_buf.extend_from_slice(&(input_buf.len() as u64).to_le_bytes());

            let max_compressed_size = ::zstd::zstd_safe::compress_bound(input_buf.len());
            let start_pos = output_buf.len();
            output_buf.resize(start_pos + max_compressed_size, 0);

            match compressor.compress_to_buffer(input_buf, &mut output_buf[start_pos..]) {
                Ok(compressed_size) => {
                    output_buf.truncate(start_pos + compressed_size);
                    Ok(())
                }
                Err(e) => {
                    output_buf.truncate(original_len);
                    Err(Error::Internal {
                        message: format!("Zstd compression error: {}", e),
                        location: location!(),
                    })
                }
            }
        }

        /// Appends the decompressed form of `input_buf` to `output_buf`.
        ///
        /// Accepts both the length-prefixed format written by `compress` and a
        /// bare zstd frame. An empty input decompresses to nothing.
        fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            if input_buf.is_empty() {
                return Ok(());
            }

            if self.is_raw_stream_format(input_buf) {
                let start = output_buf.len();
                if let Err(e) = copy_decode(Cursor::new(input_buf), &mut *output_buf) {
                    // Drop any partially decoded bytes.
                    output_buf.truncate(start);
                    return Err(e.into());
                }
            } else {
                self.decompress_length_prefixed_zstd(input_buf, output_buf)?;
            }

            Ok(())
        }

        fn config(&self) -> CompressionConfig {
            CompressionConfig {
                scheme: CompressionScheme::Zstd,
                level: Some(self.compression_level),
            }
        }
    }
}
297
#[cfg(feature = "lz4")]
mod lz4 {
    use super::*;

    /// Number of bytes LZ4 uses for the uncompressed size it prepends to the
    /// compressed data when `prepend_size` is enabled.
    const SIZE_PREFIX_LEN: usize = 4;

    /// A stateless LZ4 block compressor that stores the uncompressed size in
    /// front of the compressed data.
    #[derive(Debug, Default)]
    pub struct Lz4BufferCompressor {}

    impl BufferCompressor for Lz4BufferCompressor {
        /// Appends the size-prefixed LZ4 block for `input_buf` to `output_buf`.
        fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            let base = output_buf.len();

            // Reserve the worst-case compressed size plus room for the size header.
            let bound = ::lz4::block::compress_bound(input_buf.len())?;
            output_buf.resize(base + bound + SIZE_PREFIX_LEN, 0);

            let written =
                ::lz4::block::compress_to_buffer(input_buf, None, true, &mut output_buf[base..])
                    .map_err(|err| Error::Internal {
                        message: format!("LZ4 compression error: {}", err),
                        location: location!(),
                    })?;

            // Drop the unused tail of the reservation.
            output_buf.truncate(base + written);
            Ok(())
        }

        /// Appends the decompressed form of a size-prefixed LZ4 block to `output_buf`.
        ///
        /// # Errors
        ///
        /// Returns an internal error if `input_buf` is shorter than the size
        /// prefix or if LZ4 fails to decode the block.
        fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
            // The first 4 bytes hold the uncompressed size (little-endian), which
            // tells us exactly how much space to reserve.
            let Some(prefix) = input_buf.get(..SIZE_PREFIX_LEN) else {
                return Err(Error::Internal {
                    message: "LZ4 compressed data too short".to_string(),
                    location: location!(),
                });
            };
            let mut size_bytes = [0u8; SIZE_PREFIX_LEN];
            size_bytes.copy_from_slice(prefix);
            let expected_len = u32::from_le_bytes(size_bytes) as usize;

            let base = output_buf.len();
            output_buf.resize(base + expected_len, 0);

            // Decode straight into the freshly reserved tail of the output.
            let produced =
                ::lz4::block::decompress_to_buffer(input_buf, None, &mut output_buf[base..])
                    .map_err(|err| Error::Internal {
                        message: format!("LZ4 decompression error: {}", err),
                        location: location!(),
                    })?;

            // Should equal `expected_len`; truncate to what was actually produced.
            output_buf.truncate(base + produced);

            Ok(())
        }

        fn config(&self) -> CompressionConfig {
            CompressionConfig {
                scheme: CompressionScheme::Lz4,
                level: None,
            }
        }
    }
}
374
375#[derive(Debug, Default)]
376pub struct NoopBufferCompressor {}
377
378impl BufferCompressor for NoopBufferCompressor {
379    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
380        output_buf.extend_from_slice(input_buf);
381        Ok(())
382    }
383
384    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
385        output_buf.extend_from_slice(input_buf);
386        Ok(())
387    }
388
389    fn config(&self) -> CompressionConfig {
390        CompressionConfig {
391            scheme: CompressionScheme::None,
392            level: None,
393        }
394    }
395}
396
397pub struct GeneralBufferCompressor {}
398
399impl GeneralBufferCompressor {
400    pub fn get_compressor(
401        compression_config: CompressionConfig,
402    ) -> Result<Box<dyn BufferCompressor>> {
403        match compression_config.scheme {
404            // FSST has its own compression path and isn't implemented as a generic buffer compressor
405            CompressionScheme::Fsst => Err(Error::InvalidInput {
406                source: "fsst is not usable as a general buffer compressor".into(),
407                location: location!(),
408            }),
409            CompressionScheme::Zstd => {
410                #[cfg(feature = "zstd")]
411                {
412                    Ok(Box::new(zstd::ZstdBufferCompressor::new(
413                        compression_config.level.unwrap_or(0),
414                    )))
415                }
416                #[cfg(not(feature = "zstd"))]
417                {
418                    Err(Error::InvalidInput {
419                        source: "package was not built with zstd support".into(),
420                        location: location!(),
421                    })
422                }
423            }
424            CompressionScheme::Lz4 => {
425                #[cfg(feature = "lz4")]
426                {
427                    Ok(Box::new(lz4::Lz4BufferCompressor::default()))
428                }
429                #[cfg(not(feature = "lz4"))]
430                {
431                    Err(Error::InvalidInput {
432                        source: "package was not built with lz4 support".into(),
433                        location: location!(),
434                    })
435                }
436            }
437            CompressionScheme::None => Ok(Box::new(NoopBufferCompressor {})),
438        }
439    }
440}
441
/// A block decompressor that first undoes general-purpose compression (LZ4/Zstd)
/// and then delegates the decompressed bytes to an inner block decompressor.
#[derive(Debug)]
pub struct GeneralBlockDecompressor {
    /// Decompressor applied to the bytes after general-purpose decompression.
    inner: Box<dyn BlockDecompressor>,
    /// The general-purpose codec; only its `decompress` side is used here.
    compressor: Box<dyn BufferCompressor>,
}
449
450impl GeneralBlockDecompressor {
451    pub fn try_new(
452        inner: Box<dyn BlockDecompressor>,
453        compression: CompressionConfig,
454    ) -> Result<Self> {
455        let compressor = GeneralBufferCompressor::get_compressor(compression)?;
456        Ok(Self { inner, compressor })
457    }
458}
459
460impl BlockDecompressor for GeneralBlockDecompressor {
461    fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result<DataBlock> {
462        let mut decompressed = Vec::new();
463        self.compressor.decompress(&data, &mut decompressed)?;
464        self.inner
465            .decompress(LanceBuffer::from(decompressed), num_values)
466    }
467}
468
/// An encoder which uses generic compression, such as zstd/lz4, to encode buffers.
#[derive(Debug)]
pub struct CompressedBufferEncoder {
    /// The underlying buffer codec used for both compression and decompression.
    pub(crate) compressor: Box<dyn BufferCompressor>,
}
474
475impl Default for CompressedBufferEncoder {
476    fn default() -> Self {
477        // Pick zstd if available, otherwise lz4, otherwise none
478        #[cfg(feature = "zstd")]
479        let (scheme, level) = (CompressionScheme::Zstd, Some(0));
480        #[cfg(all(feature = "lz4", not(feature = "zstd")))]
481        let (scheme, level) = (CompressionScheme::Lz4, None);
482        #[cfg(not(any(feature = "zstd", feature = "lz4")))]
483        let (scheme, level) = (CompressionScheme::None, None);
484
485        let compressor =
486            GeneralBufferCompressor::get_compressor(CompressionConfig { scheme, level }).unwrap();
487        Self { compressor }
488    }
489}
490
491impl CompressedBufferEncoder {
492    pub fn try_new(compression_config: CompressionConfig) -> Result<Self> {
493        let compressor = GeneralBufferCompressor::get_compressor(compression_config)?;
494        Ok(Self { compressor })
495    }
496
497    pub fn from_scheme(scheme: pb21::CompressionScheme) -> Result<Self> {
498        let scheme = CompressionScheme::try_from(scheme)?;
499        Ok(Self {
500            compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
501                scheme,
502                level: Some(0),
503            })?,
504        })
505    }
506}
507
508impl CompressedBufferEncoder {
509    pub fn per_value_compress<T: ArrowNativeType>(
510        &self,
511        data: &[u8],
512        offsets: &[T],
513        compressed: &mut Vec<u8>,
514    ) -> Result<LanceBuffer> {
515        let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
516        new_offsets.push(T::from_usize(0).unwrap());
517
518        for off in offsets.windows(2) {
519            let start = off[0].as_usize();
520            let end = off[1].as_usize();
521            self.compressor.compress(&data[start..end], compressed)?;
522            new_offsets.push(T::from_usize(compressed.len()).unwrap());
523        }
524
525        Ok(LanceBuffer::reinterpret_vec(new_offsets))
526    }
527
528    pub fn per_value_decompress<T: ArrowNativeType>(
529        &self,
530        data: &[u8],
531        offsets: &[T],
532        decompressed: &mut Vec<u8>,
533    ) -> Result<LanceBuffer> {
534        let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
535        new_offsets.push(T::from_usize(0).unwrap());
536
537        for off in offsets.windows(2) {
538            let start = off[0].as_usize();
539            let end = off[1].as_usize();
540            self.compressor
541                .decompress(&data[start..end], decompressed)?;
542            new_offsets.push(T::from_usize(decompressed.len()).unwrap());
543        }
544
545        Ok(LanceBuffer::reinterpret_vec(new_offsets))
546    }
547}
548
549impl PerValueCompressor for CompressedBufferEncoder {
550    fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, CompressiveEncoding)> {
551        let data_type = data.name();
552        let data = data.as_variable_width().ok_or(Error::Internal {
553            message: format!(
554                "Attempt to use CompressedBufferEncoder on data of type {}",
555                data_type
556            ),
557            location: location!(),
558        })?;
559
560        let data_bytes = &data.data;
561        let mut compressed = Vec::with_capacity(data_bytes.len());
562
563        let new_offsets = match data.bits_per_offset {
564            32 => self.per_value_compress::<u32>(
565                data_bytes,
566                &data.offsets.borrow_to_typed_slice::<u32>(),
567                &mut compressed,
568            )?,
569            64 => self.per_value_compress::<u64>(
570                data_bytes,
571                &data.offsets.borrow_to_typed_slice::<u64>(),
572                &mut compressed,
573            )?,
574            _ => unreachable!(),
575        };
576
577        let compressed = PerValueDataBlock::Variable(VariableWidthBlock {
578            bits_per_offset: data.bits_per_offset,
579            data: LanceBuffer::from(compressed),
580            offsets: new_offsets,
581            num_values: data.num_values,
582            block_info: BlockInfo::new(),
583        });
584
585        // TODO: Support setting the level
586        // TODO: Support underlying compression of data (e.g. defer to binary encoding for offset bitpacking)
587        let encoding = ProtobufUtils21::wrapped(
588            self.compressor.config(),
589            ProtobufUtils21::variable(
590                ProtobufUtils21::flat(data.bits_per_offset as u64, None),
591                None,
592            ),
593        )?;
594
595        Ok((compressed, encoding))
596    }
597}
598
599impl VariablePerValueDecompressor for CompressedBufferEncoder {
600    fn decompress(&self, data: VariableWidthBlock) -> Result<DataBlock> {
601        let data_bytes = &data.data;
602        let mut decompressed = Vec::with_capacity(data_bytes.len() * 2);
603
604        let new_offsets = match data.bits_per_offset {
605            32 => self.per_value_decompress(
606                data_bytes,
607                &data.offsets.borrow_to_typed_slice::<u32>(),
608                &mut decompressed,
609            )?,
610            64 => self.per_value_decompress(
611                data_bytes,
612                &data.offsets.borrow_to_typed_slice::<u64>(),
613                &mut decompressed,
614            )?,
615            _ => unreachable!(),
616        };
617        Ok(DataBlock::VariableWidth(VariableWidthBlock {
618            bits_per_offset: data.bits_per_offset,
619            data: LanceBuffer::from(decompressed),
620            offsets: new_offsets,
621            num_values: data.num_values,
622            block_info: BlockInfo::new(),
623        }))
624    }
625}
626
627impl BlockCompressor for CompressedBufferEncoder {
628    fn compress(&self, data: DataBlock) -> Result<LanceBuffer> {
629        let encoded = match data {
630            DataBlock::FixedWidth(fixed_width) => fixed_width.data,
631            DataBlock::VariableWidth(variable_width) => {
632                // Wrap VariableEncoder to handle the encoding
633                let encoder = VariableEncoder::default();
634                BlockCompressor::compress(&encoder, DataBlock::VariableWidth(variable_width))?
635            }
636            _ => {
637                return Err(Error::InvalidInput {
638                    source: "Unsupported data block type".into(),
639                    location: location!(),
640                })
641            }
642        };
643
644        let mut compressed = Vec::new();
645        self.compressor.compress(&encoded, &mut compressed)?;
646        Ok(LanceBuffer::from(compressed))
647    }
648}
649
650impl BlockDecompressor for CompressedBufferEncoder {
651    fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result<DataBlock> {
652        let mut decompressed = Vec::new();
653        self.compressor.decompress(&data, &mut decompressed)?;
654
655        // Delegate to BinaryBlockDecompressor which handles the inline metadata
656        let inner_decoder = BinaryBlockDecompressor::default();
657        inner_decoder.decompress(LanceBuffer::from(decompressed), num_values)
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_compression_scheme_from_str() {
        assert_eq!(
            CompressionScheme::from_str("none").unwrap(),
            CompressionScheme::None
        );
        assert_eq!(
            CompressionScheme::from_str("zstd").unwrap(),
            CompressionScheme::Zstd
        );
    }

    #[test]
    fn test_compression_scheme_from_str_invalid() {
        assert!(CompressionScheme::from_str("invalid").is_err());
    }

    #[cfg(feature = "zstd")]
    mod zstd {
        use std::io::Write;

        use super::*;

        // Imported here (not at the top of `tests`) so the test module still
        // builds when the `zstd` feature is disabled.
        use crate::encodings::physical::block::zstd::ZstdBufferCompressor;

        #[test]
        fn test_compress_zstd_with_length_prefixed() {
            let compressor = ZstdBufferCompressor::new(0);
            let input_data = b"Hello, world!";
            let mut compressed_data = Vec::new();

            compressor
                .compress(input_data, &mut compressed_data)
                .unwrap();
            let mut decompressed_data = Vec::new();
            compressor
                .decompress(&compressed_data, &mut decompressed_data)
                .unwrap();
            assert_eq!(input_data, decompressed_data.as_slice());
        }

        #[test]
        fn test_zstd_compress_decompress_multiple_times() {
            let compressor = ZstdBufferCompressor::new(0);
            let (input_data_1, input_data_2) = (b"Hello ", b"World");
            let mut compressed_data = Vec::new();

            compressor
                .compress(input_data_1, &mut compressed_data)
                .unwrap();
            let compressed_length_1 = compressed_data.len();

            compressor
                .compress(input_data_2, &mut compressed_data)
                .unwrap();

            let mut decompressed_data = Vec::new();
            compressor
                .decompress(
                    &compressed_data[..compressed_length_1],
                    &mut decompressed_data,
                )
                .unwrap();

            compressor
                .decompress(
                    &compressed_data[compressed_length_1..],
                    &mut decompressed_data,
                )
                .unwrap();

            // the output should contain both input_data_1 and input_data_2
            assert_eq!(
                decompressed_data.len(),
                input_data_1.len() + input_data_2.len()
            );
            assert_eq!(
                &decompressed_data[..input_data_1.len()],
                input_data_1,
                "First part of decompressed data should match input_1"
            );
            assert_eq!(
                &decompressed_data[input_data_1.len()..],
                input_data_2,
                "Second part of decompressed data should match input_2"
            );
        }

        #[test]
        fn test_compress_zstd_raw_stream_format_and_decompress_with_length_prefixed() {
            let compressor = ZstdBufferCompressor::new(0);
            let input_data = b"Hello, world!";
            let mut compressed_data = Vec::new();

            // compress using raw stream format
            let mut encoder = ::zstd::Encoder::new(&mut compressed_data, 0).unwrap();
            encoder.write_all(input_data).unwrap();
            encoder.finish().expect("failed to encode data with zstd");

            // decompress with the compressor, which must detect the raw stream format
            let mut decompressed_data = Vec::new();
            compressor
                .decompress(&compressed_data, &mut decompressed_data)
                .unwrap();
            assert_eq!(input_data, decompressed_data.as_slice());
        }
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use std::{collections::HashMap, sync::Arc};

        use arrow_schema::{DataType, Field};
        use lance_datagen::array::{binary_prefix_plus_counter, utf8_prefix_plus_counter};

        use super::*;

        use crate::constants::DICT_SIZE_RATIO_META_KEY;
        use crate::{
            constants::{
                COMPRESSION_META_KEY, DICT_DIVISOR_META_KEY, STRUCTURAL_ENCODING_FULLZIP,
                STRUCTURAL_ENCODING_META_KEY,
            },
            encodings::physical::block::lz4::Lz4BufferCompressor,
            testing::{check_round_trip_encoding_generated, FnArrayGeneratorProvider, TestCases},
            version::LanceFileVersion,
        };

        #[test]
        fn test_lz4_compress_decompress() {
            let compressor = Lz4BufferCompressor::default();
            let input_data = b"Hello, world!";
            let mut compressed_data = Vec::new();

            compressor
                .compress(input_data, &mut compressed_data)
                .unwrap();
            let mut decompressed_data = Vec::new();
            compressor
                .decompress(&compressed_data, &mut decompressed_data)
                .unwrap();
            assert_eq!(input_data, decompressed_data.as_slice());
        }

        #[test_log::test(tokio::test)]
        async fn test_lz4_compress_round_trip() {
            for data_type in &[
                DataType::Utf8,
                DataType::LargeUtf8,
                DataType::Binary,
                DataType::LargeBinary,
            ] {
                let field = Field::new("", data_type.clone(), false);
                let mut field_meta = HashMap::new();
                field_meta.insert(COMPRESSION_META_KEY.to_string(), "lz4".to_string());
                // Some bad cardinality estimation causes us to use dictionary encoding currently
                // which causes the expected encoding check to fail.
                field_meta.insert(DICT_DIVISOR_META_KEY.to_string(), "100000".to_string());
                field_meta.insert(DICT_SIZE_RATIO_META_KEY.to_string(), "0.0001".to_string());
                // Also disable size-based dictionary encoding
                field_meta.insert(
                    STRUCTURAL_ENCODING_META_KEY.to_string(),
                    STRUCTURAL_ENCODING_FULLZIP.to_string(),
                );
                let field = field.with_metadata(field_meta);
                let test_cases = TestCases::basic()
                    // Need to use large pages as small pages might be too small to compress
                    .with_page_sizes(vec![1024 * 1024])
                    // The field is configured for lz4 above, so that is the encoding to expect.
                    .with_expected_encoding("lz4")
                    .with_min_file_version(LanceFileVersion::V2_1);

                // Can't use the default random provider because random data isn't compressible
                // and we will fallback to uncompressed encoding
                let datagen = Box::new(FnArrayGeneratorProvider::new(move || match data_type {
                    DataType::Utf8 => utf8_prefix_plus_counter("compressme", false),
                    DataType::Binary => {
                        binary_prefix_plus_counter(Arc::from(b"compressme".to_owned()), false)
                    }
                    DataType::LargeUtf8 => utf8_prefix_plus_counter("compressme", true),
                    DataType::LargeBinary => {
                        binary_prefix_plus_counter(Arc::from(b"compressme".to_owned()), true)
                    }
                    _ => panic!("Unsupported data type: {:?}", data_type),
                }));

                check_round_trip_encoding_generated(field, datagen, test_cases).await;
            }
        }
    }
}