use fsst::Compressor;
use fsst::Symbol;
use vortex_array::IntoArray;
use vortex_array::accessor::ArrayAccessor;
use vortex_array::arrays::varbin::builder::VarBinBuilder;
use vortex_array::dtype::DType;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use crate::FSST;
use crate::FSSTArray;
/// Compress every value exposed by `strings` using an already-trained FSST `compressor`.
///
/// This borrows an iterator from the accessor and hands it to [`fsst_compress_iter`]. `len`
/// is forwarded unchanged as the capacity hint for the output builders. A copy of `dtype`
/// becomes the logical type of the returned [`FSSTArray`].
pub fn fsst_compress<A: ArrayAccessor<[u8]>>(
    strings: A,
    len: usize,
    dtype: &DType,
    compressor: &Compressor,
) -> FSSTArray {
    strings.with_iterator(|values| {
        let owned_dtype = dtype.clone();
        fsst_compress_iter(values, len, owned_dtype, compressor)
    })
}
/// Train an FSST [`Compressor`] on the values exposed by `array`.
///
/// Null entries are ignored; only present byte strings are used as training samples.
pub fn fsst_train_compressor<A: ArrayAccessor<[u8]>>(array: &A) -> Compressor {
    array.with_iterator(|values| fsst_train_compressor_iter(values))
}
/// Train a [`Compressor`] from a stream of optional byte strings.
///
/// `None` entries are skipped; every `Some` slice is passed to `Compressor::train`
/// as a training sample, in iteration order.
fn fsst_train_compressor_iter<'a, I>(iter: I) -> Compressor
where
    I: Iterator<Item = Option<&'a [u8]>>,
{
    // `flatten` over `Option` items yields only the present values, dropping nulls.
    let samples: Vec<&'a [u8]> = iter.flatten().collect();
    Compressor::train(&samples)
}
/// Initial capacity, in bytes (1 MiB), of the scratch buffer that `fsst_compress_iter`
/// compresses each value into. The buffer is grown on demand for larger values.
const DEFAULT_BUFFER_LEN: usize = 1 << 20;
/// Compress a stream of optional byte strings with an already-trained FSST `compressor`.
///
/// * `iter` yields one entry per output row; a `None` entry becomes a null row, recorded
///   with an uncompressed length of `0`.
/// * `len` is only used as a capacity hint for the codes builder and for the
///   uncompressed-lengths buffer.
/// * `dtype` is the logical type of the returned array. The compressed codes are stored as
///   `DType::Binary` with the same nullability as `dtype`.
///
/// The compressor's symbol table and symbol lengths are copied into the result, so the
/// returned [`FSSTArray`] does not borrow from `compressor`.
///
/// # Panics
///
/// Panics if any value is longer than `i32::MAX` bytes, or if `FSST::try_new` rejects the
/// assembled parts.
pub fn fsst_compress_iter<'a, I>(
    iter: I,
    len: usize,
    dtype: DType,
    compressor: &Compressor,
) -> FSSTArray
where
    I: Iterator<Item = Option<&'a [u8]>>,
{
    // Scratch buffer reused across values; each value's compressed bytes are copied out of
    // it into `builder` before the next value is processed.
    let mut buffer: Vec<u8> = Vec::with_capacity(DEFAULT_BUFFER_LEN);
    let mut builder = VarBinBuilder::<i32>::with_capacity(len);
    let mut uncompressed_lengths: BufferMut<i32> = BufferMut::with_capacity(len);

    for string in iter {
        let Some(s) = string else {
            builder.append_null();
            uncompressed_lengths.push(0);
            continue;
        };

        // Checked before computing `required_capacity`, which also guarantees that
        // `2 * s.len() + 7` below cannot overflow `usize`.
        let uncompressed_len =
            i32::try_from(s.len()).vortex_expect("string length must fit in i32");
        uncompressed_lengths.push(uncompressed_len);

        // FSST's worst case encodes every input byte as an escape code plus the literal byte
        // (2 bytes per input byte); the extra 7 bytes are headroom beyond that bound.
        let required_capacity = 2 * s.len() + 7;
        // Discard the previous value's bytes first, so that `reserve` (which is relative to
        // the current length) guarantees `capacity >= required_capacity` directly.
        buffer.clear();
        buffer.reserve(required_capacity);

        // SAFETY: `compress_into` requires the caller to provide an output buffer large
        // enough for the encoded data. `buffer` was just reserved to hold at least
        // `2 * s.len() + 7` bytes, which covers the all-escapes worst case for `s`.
        unsafe { compressor.compress_into(s, &mut buffer) };
        builder.append_value(&buffer);
    }

    let codes = builder.finish(DType::Binary(dtype.nullability()));
    let symbols: Buffer<Symbol> = Buffer::copy_from(compressor.symbol_table());
    let symbol_lengths: Buffer<u8> = Buffer::<u8>::copy_from(compressor.symbol_lengths());
    let uncompressed_lengths = uncompressed_lengths.into_array();

    FSST::try_new(dtype, symbols, symbol_lengths, codes, uncompressed_lengths)
        .vortex_expect("FSST parts must be valid")
}
#[cfg(test)]
mod tests {
    use fsst::CompressorBuilder;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;

    use crate::compress::DEFAULT_BUFFER_LEN;
    use crate::fsst_compress_iter;

    /// A single value ten times larger than the initial scratch buffer must round-trip
    /// through compression and scalar decoding unchanged.
    #[test]
    fn test_large_string() {
        let input_len = 10 * DEFAULT_BUFFER_LEN;
        let input: String = "abc".chars().cycle().take(input_len).collect();

        let compressor = CompressorBuilder::default().build();
        let array = fsst_compress_iter(
            std::iter::once(Some(input.as_bytes())),
            1,
            DType::Utf8(Nullability::NonNullable),
            &compressor,
        );

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let actual = array.execute_scalar(0, &mut ctx).unwrap();
        assert_eq!(actual, Scalar::utf8(input, Nullability::NonNullable));
    }
}