1use arrow_buffer::ArrowNativeType;
25use lance_core::{Error, Result};
26use snafu::location;
27
28use std::io::Cursor;
29use std::{io::Write, str::FromStr};
30use zstd::bulk::decompress_to_buffer;
31use zstd::stream::copy_decode;
32
33use crate::{
34 buffer::LanceBuffer,
35 compression::VariablePerValueDecompressor,
36 data::{BlockInfo, DataBlock, VariableWidthBlock},
37 encodings::logical::primitive::fullzip::{PerValueCompressor, PerValueDataBlock},
38 format::{pb, ProtobufUtils},
39};
40
41#[derive(Debug, Clone, Copy, PartialEq)]
42pub struct CompressionConfig {
43 pub(crate) scheme: CompressionScheme,
44 pub(crate) level: Option<i32>,
45}
46
47impl CompressionConfig {
48 pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
49 Self { scheme, level }
50 }
51}
52
53impl Default for CompressionConfig {
54 fn default() -> Self {
55 Self {
56 scheme: CompressionScheme::Lz4,
57 level: Some(0),
58 }
59 }
60}
61
/// The compression schemes that can be named in a [`CompressionConfig`].
///
/// The textual names used by the `Display` and `FromStr` implementations are
/// `"none"`, `"fsst"`, `"zstd"` and `"lz4"`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompressionScheme {
    /// Store the bytes unchanged.
    None,
    /// FSST; `GeneralBufferCompressor::get_compressor` has no buffer
    /// compressor for this scheme.
    Fsst,
    /// Zstandard.
    Zstd,
    /// LZ4.
    Lz4,
}
69
70impl std::fmt::Display for CompressionScheme {
71 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
72 let scheme_str = match self {
73 Self::Fsst => "fsst",
74 Self::Zstd => "zstd",
75 Self::None => "none",
76 Self::Lz4 => "lz4",
77 };
78 write!(f, "{}", scheme_str)
79 }
80}
81
82impl FromStr for CompressionScheme {
83 type Err = Error;
84
85 fn from_str(s: &str) -> Result<Self> {
86 match s {
87 "none" => Ok(Self::None),
88 "fsst" => Ok(Self::Fsst),
89 "zstd" => Ok(Self::Zstd),
90 "lz4" => Ok(Self::Lz4),
91 _ => Err(Error::invalid_input(
92 format!("Unknown compression scheme: {}", s),
93 location!(),
94 )),
95 }
96 }
97}
98
99pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
100 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
101 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
102 fn name(&self) -> &str;
103}
104
/// A [`BufferCompressor`] backed by zstd.
#[derive(Debug, Default)]
pub struct ZstdBufferCompressor {
    /// The level handed to the zstd encoder when compressing.
    compression_level: i32,
}
109
110impl ZstdBufferCompressor {
111 pub fn new(compression_level: i32) -> Self {
112 Self { compression_level }
113 }
114
115 fn is_raw_stream_format(&self, input_buf: &[u8]) -> bool {
117 if input_buf.len() < 8 {
118 return true; }
120 let mut magic_buf = [0u8; 4];
122 magic_buf.copy_from_slice(&input_buf[..4]);
123 let magic = u32::from_le_bytes(magic_buf);
124
125 const ZSTD_MAGIC_NUMBER: u32 = 0xFD2FB528;
127 if magic == ZSTD_MAGIC_NUMBER {
128 const FHD_BYTE_INDEX: usize = 4;
132 let fhd_byte = input_buf[FHD_BYTE_INDEX];
133 const FHD_RESERVED_BIT_MASK: u8 = 0b0001_0000;
134 let reserved_bit = fhd_byte & FHD_RESERVED_BIT_MASK;
135
136 if reserved_bit != 0 {
137 false
141 } else {
142 true
145 }
146 } else {
147 false
149 }
150 }
151
152 fn decompress_length_prefixed_zstd(
153 &self,
154 input_buf: &[u8],
155 output_buf: &mut Vec<u8>,
156 ) -> Result<()> {
157 const LENGTH_PREFIX_SIZE: usize = 8;
158 let mut len_buf = [0u8; LENGTH_PREFIX_SIZE];
159 len_buf.copy_from_slice(&input_buf[..LENGTH_PREFIX_SIZE]);
160
161 let uncompressed_len = u64::from_le_bytes(len_buf) as usize;
162
163 let start = output_buf.len();
164 output_buf.resize(start + uncompressed_len, 0);
165
166 let compressed_data = &input_buf[LENGTH_PREFIX_SIZE..];
167 decompress_to_buffer(compressed_data, &mut output_buf[start..])?;
168 Ok(())
169 }
170}
171
172impl BufferCompressor for ZstdBufferCompressor {
173 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
174 output_buf.write_all(&(input_buf.len() as u64).to_le_bytes())?;
175 let mut encoder = zstd::stream::Encoder::new(output_buf, self.compression_level)?;
176
177 encoder.write_all(input_buf)?;
178 match encoder.finish() {
179 Ok(_) => Ok(()),
180 Err(e) => Err(e.into()),
181 }
182 }
183
184 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
185 if input_buf.is_empty() {
186 return Ok(());
187 }
188
189 let is_raw_stream_format = self.is_raw_stream_format(input_buf);
190 if is_raw_stream_format {
191 copy_decode(Cursor::new(input_buf), output_buf)?;
192 } else {
193 self.decompress_length_prefixed_zstd(input_buf, output_buf)?;
194 }
195
196 Ok(())
197 }
198
199 fn name(&self) -> &str {
200 "zstd"
201 }
202}
203
/// A stateless [`BufferCompressor`] backed by the LZ4 block format.
#[derive(Debug, Default)]
pub struct Lz4BufferCompressor {}
206
207impl BufferCompressor for Lz4BufferCompressor {
208 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
209 let start_pos = output_buf.len();
211
212 let max_size = lz4::block::compress_bound(input_buf.len())?;
214 output_buf.resize(start_pos + max_size + 4, 0);
216
217 let compressed_size =
218 lz4::block::compress_to_buffer(input_buf, None, true, &mut output_buf[start_pos..])
219 .map_err(|err| Error::Internal {
220 message: format!("LZ4 compression error: {}", err),
221 location: location!(),
222 })?;
223
224 output_buf.truncate(start_pos + compressed_size);
226 Ok(())
227 }
228
229 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
230 if input_buf.len() < 4 {
233 return Err(Error::Internal {
234 message: "LZ4 compressed data too short".to_string(),
235 location: location!(),
236 });
237 }
238
239 let uncompressed_size =
241 u32::from_le_bytes([input_buf[0], input_buf[1], input_buf[2], input_buf[3]]) as usize;
242
243 let start_pos = output_buf.len();
245
246 output_buf.resize(start_pos + uncompressed_size, 0);
248
249 let decompressed_size =
251 lz4::block::decompress_to_buffer(input_buf, None, &mut output_buf[start_pos..])
252 .map_err(|err| Error::Internal {
253 message: format!("LZ4 decompression error: {}", err),
254 location: location!(),
255 })?;
256
257 output_buf.truncate(start_pos + decompressed_size);
259
260 Ok(())
261 }
262
263 fn name(&self) -> &str {
264 "lz4"
265 }
266}
267
/// A [`BufferCompressor`] that performs no compression at all.
#[derive(Debug, Default)]
pub struct NoopBufferCompressor {}
270
271impl BufferCompressor for NoopBufferCompressor {
272 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
273 output_buf.extend_from_slice(input_buf);
274 Ok(())
275 }
276
277 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
278 output_buf.extend_from_slice(input_buf);
279 Ok(())
280 }
281
282 fn name(&self) -> &str {
283 "none"
284 }
285}
286
/// Factory namespace for turning a [`CompressionConfig`] into a boxed
/// [`BufferCompressor`].
pub struct GeneralBufferCompressor {}
288
289impl GeneralBufferCompressor {
290 pub fn get_compressor(compression_config: CompressionConfig) -> Box<dyn BufferCompressor> {
291 match compression_config.scheme {
292 CompressionScheme::Fsst => unimplemented!(),
294 CompressionScheme::Zstd => Box::new(ZstdBufferCompressor::new(
295 compression_config.level.unwrap_or(0),
296 )),
297 CompressionScheme::Lz4 => Box::new(Lz4BufferCompressor::default()),
298 CompressionScheme::None => Box::new(NoopBufferCompressor {}),
299 }
300 }
301}
302
303#[derive(Debug)]
305pub struct CompressedBufferEncoder {
306 pub(crate) compressor: Box<dyn BufferCompressor>,
307}
308
309impl Default for CompressedBufferEncoder {
310 fn default() -> Self {
311 Self {
312 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
313 scheme: CompressionScheme::Zstd,
314 level: Some(0),
315 }),
316 }
317 }
318}
319
320impl CompressedBufferEncoder {
321 pub fn new(compression_config: CompressionConfig) -> Self {
322 let compressor = GeneralBufferCompressor::get_compressor(compression_config);
323 Self { compressor }
324 }
325
326 pub fn from_scheme(scheme: &str) -> Result<Self> {
327 let scheme = CompressionScheme::from_str(scheme)?;
328 Ok(Self {
329 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
330 scheme,
331 level: Some(0),
332 }),
333 })
334 }
335}
336
337impl CompressedBufferEncoder {
338 pub fn per_value_compress<T: ArrowNativeType>(
339 &self,
340 data: &[u8],
341 offsets: &[T],
342 compressed: &mut Vec<u8>,
343 ) -> Result<LanceBuffer> {
344 let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
345 new_offsets.push(T::from_usize(0).unwrap());
346
347 for off in offsets.windows(2) {
348 let start = off[0].as_usize();
349 let end = off[1].as_usize();
350 self.compressor.compress(&data[start..end], compressed)?;
351 new_offsets.push(T::from_usize(compressed.len()).unwrap());
352 }
353
354 Ok(LanceBuffer::reinterpret_vec(new_offsets))
355 }
356
357 pub fn per_value_decompress<T: ArrowNativeType>(
358 &self,
359 data: &[u8],
360 offsets: &[T],
361 decompressed: &mut Vec<u8>,
362 ) -> Result<LanceBuffer> {
363 let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
364 new_offsets.push(T::from_usize(0).unwrap());
365
366 for off in offsets.windows(2) {
367 let start = off[0].as_usize();
368 let end = off[1].as_usize();
369 self.compressor
370 .decompress(&data[start..end], decompressed)?;
371 new_offsets.push(T::from_usize(decompressed.len()).unwrap());
372 }
373
374 Ok(LanceBuffer::reinterpret_vec(new_offsets))
375 }
376}
377
378impl PerValueCompressor for CompressedBufferEncoder {
379 fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> {
380 let data_type = data.name();
381 let mut data = data.as_variable_width().ok_or(Error::Internal {
382 message: format!(
383 "Attempt to use CompressedBufferEncoder on data of type {}",
384 data_type
385 ),
386 location: location!(),
387 })?;
388
389 let data_bytes = &data.data;
390 let mut compressed = Vec::with_capacity(data_bytes.len());
391
392 let new_offsets = match data.bits_per_offset {
393 32 => self.per_value_compress::<u32>(
394 data_bytes,
395 &data.offsets.borrow_to_typed_slice::<u32>(),
396 &mut compressed,
397 )?,
398 64 => self.per_value_compress::<u64>(
399 data_bytes,
400 &data.offsets.borrow_to_typed_slice::<u64>(),
401 &mut compressed,
402 )?,
403 _ => unreachable!(),
404 };
405
406 let compressed = PerValueDataBlock::Variable(VariableWidthBlock {
407 bits_per_offset: data.bits_per_offset,
408 data: LanceBuffer::from(compressed),
409 offsets: new_offsets,
410 num_values: data.num_values,
411 block_info: BlockInfo::new(),
412 });
413
414 let encoding = ProtobufUtils::block(self.compressor.name());
415
416 Ok((compressed, encoding))
417 }
418}
419
420impl VariablePerValueDecompressor for CompressedBufferEncoder {
421 fn decompress(&self, mut data: VariableWidthBlock) -> Result<DataBlock> {
422 let data_bytes = &data.data;
423 let mut decompressed = Vec::with_capacity(data_bytes.len() * 2);
424
425 let new_offsets = match data.bits_per_offset {
426 32 => self.per_value_decompress(
427 data_bytes,
428 &data.offsets.borrow_to_typed_slice::<u32>(),
429 &mut decompressed,
430 )?,
431 64 => self.per_value_decompress(
432 data_bytes,
433 &data.offsets.borrow_to_typed_slice::<u32>(),
434 &mut decompressed,
435 )?,
436 _ => unreachable!(),
437 };
438 Ok(DataBlock::VariableWidth(VariableWidthBlock {
439 bits_per_offset: data.bits_per_offset,
440 data: LanceBuffer::from(decompressed),
441 offsets: new_offsets,
442 num_values: data.num_values,
443 block_info: BlockInfo::new(),
444 }))
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Decompresses `compressed` with `compressor` into a fresh buffer.
    fn decompress_all(compressor: &ZstdBufferCompressor, compressed: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        compressor.decompress(compressed, &mut out).unwrap();
        out
    }

    #[test]
    fn test_compression_scheme_from_str() {
        let cases = [
            ("none", CompressionScheme::None),
            ("zstd", CompressionScheme::Zstd),
        ];
        for &(text, expected) in &cases {
            assert_eq!(CompressionScheme::from_str(text).unwrap(), expected);
        }
    }

    #[test]
    fn test_compression_scheme_from_str_invalid() {
        assert!(CompressionScheme::from_str("invalid").is_err());
    }

    #[test]
    fn test_compress_zstd_with_length_prefixed() {
        let compressor = ZstdBufferCompressor::new(0);
        let input_data = b"Hello, world!";

        let mut compressed_data = Vec::new();
        compressor
            .compress(input_data, &mut compressed_data)
            .unwrap();

        let decompressed_data = decompress_all(&compressor, &compressed_data);
        assert_eq!(input_data, decompressed_data.as_slice());
    }

    #[test]
    fn test_zstd_compress_decompress_multiple_times() {
        let compressor = ZstdBufferCompressor::new(0);
        let (input_data_1, input_data_2) = (b"Hello ", b"World");

        // Both values are compressed back to back into one buffer.
        let mut compressed_data = Vec::new();
        compressor
            .compress(input_data_1, &mut compressed_data)
            .unwrap();
        let split_at = compressed_data.len();
        compressor
            .compress(input_data_2, &mut compressed_data)
            .unwrap();

        // Both values are decompressed back to back into one buffer.
        let (first, second) = compressed_data.split_at(split_at);
        let mut decompressed_data = Vec::new();
        compressor.decompress(first, &mut decompressed_data).unwrap();
        compressor
            .decompress(second, &mut decompressed_data)
            .unwrap();

        assert_eq!(
            decompressed_data.len(),
            input_data_1.len() + input_data_2.len()
        );
        assert_eq!(
            &decompressed_data[..input_data_1.len()],
            input_data_1,
            "First part of decompressed data should match input_1"
        );
        assert_eq!(
            &decompressed_data[input_data_1.len()..],
            input_data_2,
            "Second part of decompressed data should match input_2"
        );
    }

    #[test]
    fn test_compress_zstd_raw_stream_format_and_decompress_with_length_prefixed() {
        let compressor = ZstdBufferCompressor::new(0);
        let input_data = b"Hello, world!";

        // Produce a bare zstd frame without the length prefix.
        let mut compressed_data = Vec::new();
        let mut encoder = zstd::Encoder::new(&mut compressed_data, 0).unwrap();
        encoder.write_all(input_data).unwrap();
        encoder.finish().expect("failed to encode data with zstd");

        // The compressor must still recognise and decode it.
        let decompressed_data = decompress_all(&compressor, &compressed_data);
        assert_eq!(input_data, decompressed_data.as_slice());
    }
}