1use arrow_buffer::ArrowNativeType;
25use lance_core::{Error, Result};
26use snafu::location;
27
28use std::io::Cursor;
29use std::{io::Write, str::FromStr};
30use zstd::bulk::decompress_to_buffer;
31use zstd::stream::copy_decode;
32
33use crate::{
34 buffer::LanceBuffer,
35 compression::VariablePerValueDecompressor,
36 data::{BlockInfo, DataBlock, VariableWidthBlock},
37 encodings::logical::primitive::fullzip::{PerValueCompressor, PerValueDataBlock},
38 format::{pb, ProtobufUtils},
39};
40
41#[derive(Debug, Clone, Copy, PartialEq)]
42pub struct CompressionConfig {
43 pub(crate) scheme: CompressionScheme,
44 pub(crate) level: Option<i32>,
45}
46
47impl CompressionConfig {
48 pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
49 Self { scheme, level }
50 }
51}
52
53impl Default for CompressionConfig {
54 fn default() -> Self {
55 Self {
56 scheme: CompressionScheme::Lz4,
57 level: Some(0),
58 }
59 }
60}
61
/// The compression schemes that can be named in a [`CompressionConfig`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompressionScheme {
    /// No compression.
    None,
    /// FSST.
    Fsst,
    /// Zstandard.
    Zstd,
    /// LZ4.
    Lz4,
}

impl std::fmt::Display for CompressionScheme {
    /// Writes the lowercase name of the scheme (`"none"`, `"fsst"`, `"zstd"`
    /// or `"lz4"`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match self {
            Self::None => "none",
            Self::Fsst => "fsst",
            Self::Zstd => "zstd",
            Self::Lz4 => "lz4",
        })
    }
}
81
82impl FromStr for CompressionScheme {
83 type Err = Error;
84
85 fn from_str(s: &str) -> Result<Self> {
86 match s {
87 "none" => Ok(Self::None),
88 "zstd" => Ok(Self::Zstd),
89 "lz4" => Ok(Self::Lz4),
90 _ => Err(Error::invalid_input(
91 format!("Unknown compression scheme: {}", s),
92 location!(),
93 )),
94 }
95 }
96}
97
98pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
99 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
100 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
101 fn name(&self) -> &str;
102}
103
104#[derive(Debug, Default)]
105pub struct ZstdBufferCompressor {
106 compression_level: i32,
107}
108
109impl ZstdBufferCompressor {
110 pub fn new(compression_level: i32) -> Self {
111 Self { compression_level }
112 }
113
114 fn is_raw_stream_format(&self, input_buf: &[u8]) -> bool {
116 if input_buf.len() < 8 {
117 return true; }
119 let mut magic_buf = [0u8; 4];
121 magic_buf.copy_from_slice(&input_buf[..4]);
122 let magic = u32::from_le_bytes(magic_buf);
123
124 const ZSTD_MAGIC_NUMBER: u32 = 0xFD2FB528;
126 if magic == ZSTD_MAGIC_NUMBER {
127 const FHD_BYTE_INDEX: usize = 4;
131 let fhd_byte = input_buf[FHD_BYTE_INDEX];
132 const FHD_RESERVED_BIT_MASK: u8 = 0b0001_0000;
133 let reserved_bit = fhd_byte & FHD_RESERVED_BIT_MASK;
134
135 if reserved_bit != 0 {
136 false
140 } else {
141 true
144 }
145 } else {
146 false
148 }
149 }
150
151 fn decompress_length_prefixed_zstd(
152 &self,
153 input_buf: &[u8],
154 output_buf: &mut Vec<u8>,
155 ) -> Result<()> {
156 const LENGTH_PREFIX_SIZE: usize = 8;
157 let mut len_buf = [0u8; LENGTH_PREFIX_SIZE];
158 len_buf.copy_from_slice(&input_buf[..LENGTH_PREFIX_SIZE]);
159
160 let uncompressed_len = u64::from_le_bytes(len_buf) as usize;
161
162 let start = output_buf.len();
163 output_buf.resize(start + uncompressed_len, 0);
164
165 let compressed_data = &input_buf[LENGTH_PREFIX_SIZE..];
166 decompress_to_buffer(compressed_data, &mut output_buf[start..])?;
167 Ok(())
168 }
169}
170
171impl BufferCompressor for ZstdBufferCompressor {
172 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
173 output_buf.write_all(&(input_buf.len() as u64).to_le_bytes())?;
174 let mut encoder = zstd::stream::Encoder::new(output_buf, self.compression_level)?;
175
176 encoder.write_all(input_buf)?;
177 match encoder.finish() {
178 Ok(_) => Ok(()),
179 Err(e) => Err(e.into()),
180 }
181 }
182
183 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
184 if input_buf.is_empty() {
185 return Ok(());
186 }
187
188 let is_raw_stream_format = self.is_raw_stream_format(input_buf);
189 if is_raw_stream_format {
190 copy_decode(Cursor::new(input_buf), output_buf)?;
191 } else {
192 self.decompress_length_prefixed_zstd(input_buf, output_buf)?;
193 }
194
195 Ok(())
196 }
197
198 fn name(&self) -> &str {
199 "zstd"
200 }
201}
202
203#[derive(Debug, Default)]
204pub struct Lz4BufferCompressor {}
205
206impl BufferCompressor for Lz4BufferCompressor {
207 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
208 let start_pos = output_buf.len();
210
211 let max_size = lz4::block::compress_bound(input_buf.len())?;
213 output_buf.resize(start_pos + max_size + 4, 0);
215
216 let compressed_size =
217 lz4::block::compress_to_buffer(input_buf, None, true, &mut output_buf[start_pos..])
218 .map_err(|err| Error::Internal {
219 message: format!("LZ4 compression error: {}", err),
220 location: location!(),
221 })?;
222
223 output_buf.truncate(start_pos + compressed_size);
225 Ok(())
226 }
227
228 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
229 if input_buf.len() < 4 {
232 return Err(Error::Internal {
233 message: "LZ4 compressed data too short".to_string(),
234 location: location!(),
235 });
236 }
237
238 let uncompressed_size =
240 u32::from_le_bytes([input_buf[0], input_buf[1], input_buf[2], input_buf[3]]) as usize;
241
242 let start_pos = output_buf.len();
244
245 output_buf.resize(start_pos + uncompressed_size, 0);
247
248 let decompressed_size =
250 lz4::block::decompress_to_buffer(input_buf, None, &mut output_buf[start_pos..])
251 .map_err(|err| Error::Internal {
252 message: format!("LZ4 decompression error: {}", err),
253 location: location!(),
254 })?;
255
256 output_buf.truncate(start_pos + decompressed_size);
258
259 Ok(())
260 }
261
262 fn name(&self) -> &str {
263 "lz4"
264 }
265}
266
267#[derive(Debug, Default)]
268pub struct NoopBufferCompressor {}
269
270impl BufferCompressor for NoopBufferCompressor {
271 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
272 output_buf.extend_from_slice(input_buf);
273 Ok(())
274 }
275
276 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
277 output_buf.extend_from_slice(input_buf);
278 Ok(())
279 }
280
281 fn name(&self) -> &str {
282 "none"
283 }
284}
285
286pub struct GeneralBufferCompressor {}
287
288impl GeneralBufferCompressor {
289 pub fn get_compressor(compression_config: CompressionConfig) -> Box<dyn BufferCompressor> {
290 match compression_config.scheme {
291 CompressionScheme::Fsst => unimplemented!(),
293 CompressionScheme::Zstd => Box::new(ZstdBufferCompressor::new(
294 compression_config.level.unwrap_or(0),
295 )),
296 CompressionScheme::Lz4 => Box::new(Lz4BufferCompressor::default()),
297 CompressionScheme::None => Box::new(NoopBufferCompressor {}),
298 }
299 }
300}
301
302#[derive(Debug)]
304pub struct CompressedBufferEncoder {
305 pub(crate) compressor: Box<dyn BufferCompressor>,
306}
307
308impl Default for CompressedBufferEncoder {
309 fn default() -> Self {
310 Self {
311 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
312 scheme: CompressionScheme::Zstd,
313 level: Some(0),
314 }),
315 }
316 }
317}
318
319impl CompressedBufferEncoder {
320 pub fn new(compression_config: CompressionConfig) -> Self {
321 let compressor = GeneralBufferCompressor::get_compressor(compression_config);
322 Self { compressor }
323 }
324
325 pub fn from_scheme(scheme: &str) -> Result<Self> {
326 let scheme = CompressionScheme::from_str(scheme)?;
327 Ok(Self {
328 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
329 scheme,
330 level: Some(0),
331 }),
332 })
333 }
334}
335
336impl CompressedBufferEncoder {
337 pub fn per_value_compress<T: ArrowNativeType>(
338 &self,
339 data: &[u8],
340 offsets: &[T],
341 compressed: &mut Vec<u8>,
342 ) -> Result<LanceBuffer> {
343 let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
344 new_offsets.push(T::from_usize(0).unwrap());
345
346 for off in offsets.windows(2) {
347 let start = off[0].as_usize();
348 let end = off[1].as_usize();
349 self.compressor.compress(&data[start..end], compressed)?;
350 new_offsets.push(T::from_usize(compressed.len()).unwrap());
351 }
352
353 Ok(LanceBuffer::reinterpret_vec(new_offsets))
354 }
355
356 pub fn per_value_decompress<T: ArrowNativeType>(
357 &self,
358 data: &[u8],
359 offsets: &[T],
360 decompressed: &mut Vec<u8>,
361 ) -> Result<LanceBuffer> {
362 let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
363 new_offsets.push(T::from_usize(0).unwrap());
364
365 for off in offsets.windows(2) {
366 let start = off[0].as_usize();
367 let end = off[1].as_usize();
368 self.compressor
369 .decompress(&data[start..end], decompressed)?;
370 new_offsets.push(T::from_usize(decompressed.len()).unwrap());
371 }
372
373 Ok(LanceBuffer::reinterpret_vec(new_offsets))
374 }
375}
376
377impl PerValueCompressor for CompressedBufferEncoder {
378 fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> {
379 let data_type = data.name();
380 let mut data = data.as_variable_width().ok_or(Error::Internal {
381 message: format!(
382 "Attempt to use CompressedBufferEncoder on data of type {}",
383 data_type
384 ),
385 location: location!(),
386 })?;
387
388 let data_bytes = &data.data;
389 let mut compressed = Vec::with_capacity(data_bytes.len());
390
391 let new_offsets = match data.bits_per_offset {
392 32 => self.per_value_compress::<u32>(
393 data_bytes,
394 &data.offsets.borrow_to_typed_slice::<u32>(),
395 &mut compressed,
396 )?,
397 64 => self.per_value_compress::<u64>(
398 data_bytes,
399 &data.offsets.borrow_to_typed_slice::<u64>(),
400 &mut compressed,
401 )?,
402 _ => unreachable!(),
403 };
404
405 let compressed = PerValueDataBlock::Variable(VariableWidthBlock {
406 bits_per_offset: data.bits_per_offset,
407 data: LanceBuffer::from(compressed),
408 offsets: new_offsets,
409 num_values: data.num_values,
410 block_info: BlockInfo::new(),
411 });
412
413 let encoding = ProtobufUtils::block(self.compressor.name());
414
415 Ok((compressed, encoding))
416 }
417}
418
419impl VariablePerValueDecompressor for CompressedBufferEncoder {
420 fn decompress(&self, mut data: VariableWidthBlock) -> Result<DataBlock> {
421 let data_bytes = &data.data;
422 let mut decompressed = Vec::with_capacity(data_bytes.len() * 2);
423
424 let new_offsets = match data.bits_per_offset {
425 32 => self.per_value_decompress(
426 data_bytes,
427 &data.offsets.borrow_to_typed_slice::<u32>(),
428 &mut decompressed,
429 )?,
430 64 => self.per_value_decompress(
431 data_bytes,
432 &data.offsets.borrow_to_typed_slice::<u32>(),
433 &mut decompressed,
434 )?,
435 _ => unreachable!(),
436 };
437 Ok(DataBlock::VariableWidth(VariableWidthBlock {
438 bits_per_offset: data.bits_per_offset,
439 data: LanceBuffer::from(decompressed),
440 offsets: new_offsets,
441 num_values: data.num_values,
442 block_info: BlockInfo::new(),
443 }))
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Known scheme names parse to the matching variant.
    #[test]
    fn test_compression_scheme_from_str() {
        let none = CompressionScheme::from_str("none").unwrap();
        assert_eq!(none, CompressionScheme::None);

        let zstd = CompressionScheme::from_str("zstd").unwrap();
        assert_eq!(zstd, CompressionScheme::Zstd);
    }

    /// Unknown scheme names are rejected.
    #[test]
    fn test_compression_scheme_from_str_invalid() {
        let parsed = CompressionScheme::from_str("invalid");
        assert!(parsed.is_err());
    }

    /// Data written by `compress` (length-prefixed layout) round-trips.
    #[test]
    fn test_compress_zstd_with_length_prefixed() {
        let compressor = ZstdBufferCompressor::new(0);
        let original = b"Hello, world!";

        let mut encoded = Vec::new();
        compressor.compress(original, &mut encoded).unwrap();

        let mut decoded = Vec::new();
        compressor.decompress(&encoded, &mut decoded).unwrap();

        assert_eq!(original, decoded.as_slice());
    }

    /// Both `compress` and `decompress` append to their output buffers, so
    /// several values can share one buffer.
    #[test]
    fn test_zstd_compress_decompress_multiple_times() {
        let compressor = ZstdBufferCompressor::new(0);
        let (first, second) = (b"Hello ", b"World");

        let mut encoded = Vec::new();
        compressor.compress(first, &mut encoded).unwrap();
        let split_at = encoded.len();
        compressor.compress(second, &mut encoded).unwrap();

        let mut decoded = Vec::new();
        compressor
            .decompress(&encoded[..split_at], &mut decoded)
            .unwrap();
        compressor
            .decompress(&encoded[split_at..], &mut decoded)
            .unwrap();

        assert_eq!(decoded.len(), first.len() + second.len());
        assert_eq!(
            &decoded[..first.len()],
            first,
            "First part of decompressed data should match input_1"
        );
        assert_eq!(
            &decoded[first.len()..],
            second,
            "Second part of decompressed data should match input_2"
        );
    }

    /// A bare Zstandard frame (no length prefix) is still decodable.
    #[test]
    fn test_compress_zstd_raw_stream_format_and_decompress_with_length_prefixed() {
        let compressor = ZstdBufferCompressor::new(0);
        let original = b"Hello, world!";

        // Encode with the plain zstd encoder, bypassing the length prefix.
        let mut encoded = Vec::new();
        let mut encoder = zstd::Encoder::new(&mut encoded, 0).unwrap();
        encoder.write_all(original).unwrap();
        encoder.finish().expect("failed to encode data with zstd");

        let mut decoded = Vec::new();
        compressor.decompress(&encoded, &mut decoded).unwrap();

        assert_eq!(original, decoded.as_slice());
    }
}