1use arrow_buffer::ArrowNativeType;
5use arrow_schema::DataType;
6use snafu::location;
7use std::{
8 io::{Cursor, Write},
9 str::FromStr,
10};
11
12use lance_core::{Error, Result};
13
14use crate::{
15 buffer::LanceBuffer,
16 data::{BlockInfo, DataBlock, OpaqueBlock, VariableWidthBlock},
17 decoder::VariablePerValueDecompressor,
18 encoder::{ArrayEncoder, EncodedArray, PerValueCompressor, PerValueDataBlock},
19 format::{pb, ProtobufUtils},
20};
21
22#[derive(Debug, Clone, Copy, PartialEq)]
23pub struct CompressionConfig {
24 pub(crate) scheme: CompressionScheme,
25 pub(crate) level: Option<i32>,
26}
27
28impl CompressionConfig {
29 pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
30 Self { scheme, level }
31 }
32}
33
34impl Default for CompressionConfig {
35 fn default() -> Self {
36 Self {
37 scheme: CompressionScheme::Lz4,
38 level: Some(0),
39 }
40 }
41}
42
/// The compression algorithms this module knows by name.
///
/// The textual form of each variant is defined by the `Display`
/// implementation (`"none"`, `"fsst"`, `"zstd"`, `"lz4"`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompressionScheme {
    /// No compression.
    None,
    /// FSST.
    Fsst,
    /// Zstandard.
    Zstd,
    /// LZ4.
    Lz4,
}
50
51impl std::fmt::Display for CompressionScheme {
52 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
53 let scheme_str = match self {
54 Self::Fsst => "fsst",
55 Self::Zstd => "zstd",
56 Self::None => "none",
57 Self::Lz4 => "lz4",
58 };
59 write!(f, "{}", scheme_str)
60 }
61}
62
63impl FromStr for CompressionScheme {
64 type Err = Error;
65
66 fn from_str(s: &str) -> Result<Self> {
67 match s {
68 "none" => Ok(Self::None),
69 "zstd" => Ok(Self::Zstd),
70 _ => Err(Error::invalid_input(
71 format!("Unknown compression scheme: {}", s),
72 location!(),
73 )),
74 }
75 }
76}
77
78pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
79 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
80 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
81 fn name(&self) -> &str;
82}
83
/// Zstandard-backed buffer compressor.
///
/// The derived `Default` uses compression level `0`.
#[derive(Debug, Default)]
pub struct ZstdBufferCompressor {
    /// Level handed to the zstd encoder on every `compress` call.
    compression_level: i32,
}

impl ZstdBufferCompressor {
    /// Creates a compressor that encodes at `compression_level`.
    pub fn new(compression_level: i32) -> Self {
        ZstdBufferCompressor { compression_level }
    }
}
94
95impl BufferCompressor for ZstdBufferCompressor {
96 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
97 let mut encoder = zstd::Encoder::new(output_buf, self.compression_level)?;
98 encoder.write_all(input_buf)?;
99 match encoder.finish() {
100 Ok(_) => Ok(()),
101 Err(e) => Err(e.into()),
102 }
103 }
104
105 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
106 let source = Cursor::new(input_buf);
107 zstd::stream::copy_decode(source, output_buf)?;
108 Ok(())
109 }
110
111 fn name(&self) -> &str {
112 "zstd"
113 }
114}
115
116#[derive(Debug, Default)]
117pub struct Lz4BufferCompressor {}
118
119impl BufferCompressor for Lz4BufferCompressor {
120 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
121 lz4::block::compress_to_buffer(input_buf, None, true, output_buf)
122 .map_err(|err| Error::Internal {
123 message: format!("LZ4 compression error: {}", err),
124 location: location!(),
125 })
126 .map(|_| ())
127 }
128
129 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
130 lz4::block::decompress_to_buffer(input_buf, None, output_buf)
131 .map_err(|err| Error::Internal {
132 message: format!("LZ4 decompression error: {}", err),
133 location: location!(),
134 })
135 .map(|_| ())
136 }
137
138 fn name(&self) -> &str {
139 "zstd"
140 }
141}
142
143#[derive(Debug, Default)]
144pub struct NoopBufferCompressor {}
145
146impl BufferCompressor for NoopBufferCompressor {
147 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
148 output_buf.extend_from_slice(input_buf);
149 Ok(())
150 }
151
152 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
153 output_buf.extend_from_slice(input_buf);
154 Ok(())
155 }
156
157 fn name(&self) -> &str {
158 "none"
159 }
160}
161
162pub struct GeneralBufferCompressor {}
163
164impl GeneralBufferCompressor {
165 pub fn get_compressor(compression_config: CompressionConfig) -> Box<dyn BufferCompressor> {
166 match compression_config.scheme {
167 CompressionScheme::Fsst => unimplemented!(),
169 CompressionScheme::Zstd => Box::new(ZstdBufferCompressor::new(
170 compression_config.level.unwrap_or(0),
171 )),
172 CompressionScheme::Lz4 => Box::new(Lz4BufferCompressor::default()),
173 CompressionScheme::None => Box::new(NoopBufferCompressor {}),
174 }
175 }
176}
177
178#[derive(Debug)]
180pub struct CompressedBufferEncoder {
181 compressor: Box<dyn BufferCompressor>,
182}
183
184impl Default for CompressedBufferEncoder {
185 fn default() -> Self {
186 Self {
187 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
188 scheme: CompressionScheme::Zstd,
189 level: Some(0),
190 }),
191 }
192 }
193}
194
195impl CompressedBufferEncoder {
196 pub fn new(compression_config: CompressionConfig) -> Self {
197 let compressor = GeneralBufferCompressor::get_compressor(compression_config);
198 Self { compressor }
199 }
200
201 pub fn from_scheme(scheme: &str) -> Result<Self> {
202 let scheme = CompressionScheme::from_str(scheme)?;
203 Ok(Self {
204 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
205 scheme,
206 level: Some(0),
207 }),
208 })
209 }
210}
211
212impl ArrayEncoder for CompressedBufferEncoder {
213 fn encode(
214 &self,
215 data: DataBlock,
216 _data_type: &DataType,
217 buffer_index: &mut u32,
218 ) -> Result<EncodedArray> {
219 let uncompressed_data = data.as_fixed_width().unwrap();
220
221 let mut compressed_buf = Vec::with_capacity(uncompressed_data.data.len());
222 self.compressor
223 .compress(&uncompressed_data.data, &mut compressed_buf)?;
224
225 let compressed_data = DataBlock::Opaque(OpaqueBlock {
226 buffers: vec![compressed_buf.into()],
227 num_values: uncompressed_data.num_values,
228 block_info: BlockInfo::new(),
229 });
230
231 let comp_buf_index = *buffer_index;
232 *buffer_index += 1;
233
234 let encoding = ProtobufUtils::flat_encoding(
235 uncompressed_data.bits_per_value,
236 comp_buf_index,
237 Some(CompressionConfig::new(CompressionScheme::Zstd, None)),
238 );
239
240 Ok(EncodedArray {
241 data: compressed_data,
242 encoding,
243 })
244 }
245}
246
247impl CompressedBufferEncoder {
248 pub fn per_value_compress<T: ArrowNativeType>(
249 &self,
250 data: &[u8],
251 offsets: &[T],
252 compressed: &mut Vec<u8>,
253 ) -> Result<LanceBuffer> {
254 let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
255 new_offsets.push(T::from_usize(0).unwrap());
256
257 for off in offsets.windows(2) {
258 let start = off[0].as_usize();
259 let end = off[1].as_usize();
260 self.compressor.compress(&data[start..end], compressed)?;
261 new_offsets.push(T::from_usize(compressed.len()).unwrap());
262 }
263
264 Ok(LanceBuffer::reinterpret_vec(new_offsets))
265 }
266
267 pub fn per_value_decompress<T: ArrowNativeType>(
268 &self,
269 data: &[u8],
270 offsets: &[T],
271 decompressed: &mut Vec<u8>,
272 ) -> Result<LanceBuffer> {
273 let mut new_offsets: Vec<T> = Vec::with_capacity(offsets.len());
274 new_offsets.push(T::from_usize(0).unwrap());
275
276 for off in offsets.windows(2) {
277 let start = off[0].as_usize();
278 let end = off[1].as_usize();
279 self.compressor
280 .decompress(&data[start..end], decompressed)?;
281 new_offsets.push(T::from_usize(decompressed.len()).unwrap());
282 }
283
284 Ok(LanceBuffer::reinterpret_vec(new_offsets))
285 }
286}
287
288impl PerValueCompressor for CompressedBufferEncoder {
289 fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> {
290 let data_type = data.name();
291 let mut data = data.as_variable_width().ok_or(Error::Internal {
292 message: format!(
293 "Attempt to use CompressedBufferEncoder on data of type {}",
294 data_type
295 ),
296 location: location!(),
297 })?;
298
299 let data_bytes = &data.data;
300 let mut compressed = Vec::with_capacity(data_bytes.len());
301
302 let new_offsets = match data.bits_per_offset {
303 32 => self.per_value_compress::<u32>(
304 data_bytes,
305 &data.offsets.borrow_to_typed_slice::<u32>(),
306 &mut compressed,
307 )?,
308 64 => self.per_value_compress::<u64>(
309 data_bytes,
310 &data.offsets.borrow_to_typed_slice::<u64>(),
311 &mut compressed,
312 )?,
313 _ => unreachable!(),
314 };
315
316 let compressed = PerValueDataBlock::Variable(VariableWidthBlock {
317 bits_per_offset: data.bits_per_offset,
318 data: LanceBuffer::from(compressed),
319 offsets: new_offsets,
320 num_values: data.num_values,
321 block_info: BlockInfo::new(),
322 });
323
324 let encoding = ProtobufUtils::block(self.compressor.name());
325
326 Ok((compressed, encoding))
327 }
328}
329
330impl VariablePerValueDecompressor for CompressedBufferEncoder {
331 fn decompress(&self, mut data: VariableWidthBlock) -> Result<DataBlock> {
332 let data_bytes = &data.data;
333 let mut decompressed = Vec::with_capacity(data_bytes.len() * 2);
334
335 let new_offsets = match data.bits_per_offset {
336 32 => self.per_value_decompress(
337 data_bytes,
338 &data.offsets.borrow_to_typed_slice::<u32>(),
339 &mut decompressed,
340 )?,
341 64 => self.per_value_decompress(
342 data_bytes,
343 &data.offsets.borrow_to_typed_slice::<u32>(),
344 &mut decompressed,
345 )?,
346 _ => unreachable!(),
347 };
348 Ok(DataBlock::VariableWidth(VariableWidthBlock {
349 bits_per_offset: data.bits_per_offset,
350 data: LanceBuffer::from(decompressed),
351 offsets: new_offsets,
352 num_values: data.num_values,
353 block_info: BlockInfo::new(),
354 }))
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::LanceBuffer;
    use crate::data::FixedWidthDataBlock;
    use arrow_schema::DataType;
    use std::str::FromStr;

    /// Known scheme names parse to their variants.
    #[test]
    fn test_compression_scheme_from_str() {
        assert_eq!(
            CompressionScheme::from_str("none").unwrap(),
            CompressionScheme::None
        );
        assert_eq!(
            CompressionScheme::from_str("zstd").unwrap(),
            CompressionScheme::Zstd
        );
    }

    /// Unknown scheme names are rejected with an error.
    #[test]
    fn test_compression_scheme_from_str_invalid() {
        assert!(CompressionScheme::from_str("invalid").is_err());
    }

    /// The default encoder turns a fixed-width block into one opaque buffer.
    #[test]
    fn test_compressed_buffer_encoder() {
        let encoder = CompressedBufferEncoder::default();
        // Typed as i64 so the buffer really holds 8 values of 64 bits each,
        // matching `bits_per_value` and `num_values`; untyped literals would
        // default to i32 and produce only half the bytes.
        let values: Vec<i64> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let data = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: 64,
            data: LanceBuffer::reinterpret_vec(values),
            num_values: 8,
            block_info: BlockInfo::new(),
        });

        let mut buffer_index = 0;
        let encoded_array_result = encoder.encode(data, &DataType::Int64, &mut buffer_index);
        assert!(encoded_array_result.is_ok(), "{:?}", encoded_array_result);
        let encoded_array = encoded_array_result.unwrap();
        assert_eq!(encoded_array.data.num_values(), 8);
        let buffers = encoded_array.data.into_buffers();
        assert_eq!(buffers.len(), 1);
        assert!(buffers[0].len() < 64 * 8);
    }
}