use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::IntoArray;
use vortex_array::arrays::Extension;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_compressor::CascadingCompressor;
use vortex_compressor::ctx::CompressorContext;
use vortex_compressor::estimate::CompressionEstimate;
use vortex_compressor::scheme::Scheme;
use vortex_compressor::stats::ArrayAndStats;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantConfig;
use crate::encodings::turboquant::turboquant_encode_unchecked;
use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
// Submodules of this scheme, exposed only to the parent module via `pub(super)`.
// NOTE(review): judging by the names these hold the encode and decode halves of the
// TurboQuant implementation — confirm against the module files.
pub(super) mod compress;
pub(super) mod decompress;
/// Stateless marker type for the TurboQuant compression scheme.
///
/// The type carries no fields, so every value is interchangeable; it is cheap to copy
/// and compares equal to any other instance.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TurboQuantScheme;
/// [`Scheme`] implementation that applies TurboQuant to extension arrays whose dtype is
/// accepted by [`TurboQuant::validate_dtype`].
impl Scheme for TurboQuantScheme {
    /// Returns the fixed identifier of this scheme.
    fn scheme_name(&self) -> &'static str {
        "vortex.tensor.turboquant"
    }

    /// Returns `true` only for canonical extension arrays whose dtype passes
    /// [`TurboQuant::validate_dtype`]; every other canonical form is rejected.
    fn matches(&self, canonical: &Canonical) -> bool {
        match canonical {
            Canonical::Extension(ext) => TurboQuant::validate_dtype(ext.dtype()).is_ok(),
            _ => false,
        }
    }

    /// Estimates the compression ratio for `data` from its dtype and length alone.
    ///
    /// # Panics
    ///
    /// Panics if the array's dtype fails [`TurboQuant::validate_dtype`], or if the element
    /// bit width does not fit in a `u8`.
    fn expected_compression_ratio(
        &self,
        data: &mut ArrayAndStats,
        _ctx: CompressorContext,
    ) -> CompressionEstimate {
        let array = data.array();
        let metadata = TurboQuant::validate_dtype(array.dtype())
            .vortex_expect("invalid dtype for TurboQuant");
        let bits_per_element = u8::try_from(metadata.element_ptype().bit_width())
            .vortex_expect("invalid bit width for TurboQuant");

        let ratio =
            estimate_compression_ratio(bits_per_element, metadata.dimensions(), array.len());
        CompressionEstimate::Ratio(ratio)
    }

    /// Compresses `data` by normalizing it and TurboQuant-encoding the normalized child.
    ///
    /// The input is first rewritten with [`normalize_as_l2_denorm`]; child 0 of that result
    /// is then encoded with the default [`TurboQuantConfig`], and the encoded array is
    /// paired with child 1 in a new [`L2Denorm`] array of the same length.
    ///
    /// # Errors
    ///
    /// Propagates any error from normalization, encoding, or building the final array.
    ///
    /// # Panics
    ///
    /// Panics if the input array, or the normalized child, is not an extension array.
    fn compress(
        &self,
        compressor: &CascadingCompressor,
        data: &mut ArrayAndStats,
        _ctx: CompressorContext,
    ) -> VortexResult<ArrayRef> {
        let source = data
            .array()
            .as_opt::<Extension>()
            .vortex_expect("expected an extension array");
        let mut exec_ctx = compressor.execution_ctx();

        let denormalized =
            normalize_as_l2_denorm(&ApproxOptions::Exact, source.as_ref().clone(), &mut exec_ctx)?;
        let row_count = denormalized.len();
        let unit_vectors = denormalized.child_at(0).clone();
        let norms = denormalized.child_at(1).clone();

        let unit_ext = unit_vectors
            .as_opt::<Extension>()
            .vortex_expect("normalized child should be an Extension array");

        let config = TurboQuantConfig::default();
        // SAFETY: NOTE(review) — the contract of `turboquant_encode_unchecked` is not visible
        // in this file. The input is child 0 of the array just produced by
        // `normalize_as_l2_denorm`; confirm that this satisfies the function's preconditions.
        let encoded = unsafe { turboquant_encode_unchecked(unit_ext, &config, &mut exec_ctx) }?;

        // SAFETY: NOTE(review) — the contract of `new_array_unchecked` is not visible in this
        // file. `norms` and `row_count` come from the same `denormalized` array as the encoded
        // input; confirm that this satisfies the constructor's preconditions.
        let rebuilt = unsafe {
            L2Denorm::new_array_unchecked(&ApproxOptions::Exact, encoded, norms, row_count)
        }?;
        Ok(rebuilt.into_array())
    }
}
/// Estimates the TurboQuant compression ratio (uncompressed bits / compressed bits) for
/// `num_vectors` vectors of `dimensions` elements, each `bits_per_element` bits wide.
///
/// The model uses the default [`TurboQuantConfig`]:
/// - each vector costs a fixed 32 bits plus `bit_width` bits per padded dimension, where the
///   dimension is padded up to the next power of two;
/// - a shared overhead of 32 bits per centroid (`2^bit_width` centroids) plus `num_rounds`
///   bits per padded dimension is paid once.
///
/// NOTE(review): the fixed 32 bits per vector and 32 bits per centroid presumably correspond
/// to `f32` values — confirm against the encoder.
///
/// Returns `0.0` when `num_vectors` is zero, because the shared overhead is always non-zero.
fn estimate_compression_ratio(bits_per_element: u8, dimensions: u32, num_vectors: usize) -> f64 {
    let config = TurboQuantConfig::default();

    // Widen before padding: `u32::next_power_of_two` overflows (debug panic, wraps to 0 in
    // release) for dimensions above 2^31, whereas the `u64` computation cannot.
    let padded_dim = u64::from(dimensions).next_power_of_two();

    let num_centroids = 1usize << config.bit_width;
    debug_assert!(num_centroids <= TurboQuant::MAX_CENTROIDS);

    let compressed_bits_per_vector = 32 + (config.bit_width as u64) * padded_dim;
    let overhead_bits = (num_centroids as u64) * 32 + (config.num_rounds as u64) * padded_dim;
    let uncompressed_bits_per_vector = u64::from(bits_per_element) * u64::from(dimensions);

    // Scale by the vector count in floating point so very large inputs cannot overflow an
    // integer product. For realistic sizes every intermediate is exactly representable, so
    // the result matches pure integer arithmetic.
    let vector_count = num_vectors as f64;
    let compressed_size_bits =
        compressed_bits_per_vector as f64 * vector_count + overhead_bits as f64;
    let uncompressed_size_bits = uncompressed_bits_per_vector as f64 * vector_count;

    uncompressed_size_bits / compressed_size_bits
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// The estimated ratio must fall strictly between the given bounds for a range of
    /// element widths and vector dimensions.
    #[rstest]
    #[case::f32_768d(32, 768, 1000, 2.5, 4.0)]
    #[case::f32_1024d(32, 1024, 1000, 3.5, 5.0)]
    #[case::f32_1536d(32, 1536, 1000, 2.5, 4.0)]
    #[case::f32_128d(32, 128, 1000, 3.0, 5.0)]
    #[case::f64_768d(64, 768, 1000, 5.0, 7.0)]
    #[case::f16_768d(16, 768, 1000, 1.2, 2.0)]
    fn compression_ratio_in_expected_range(
        #[case] bits_per_element: u8,
        #[case] dim: u32,
        #[case] num_vectors: usize,
        #[case] min_ratio: f64,
        #[case] max_ratio: f64,
    ) {
        let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors);

        let above_min = ratio > min_ratio;
        let below_max = ratio < max_ratio;
        assert!(
            above_min && below_max,
            "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \
             {bits_per_element}-bit elements, dim={dim}, n={num_vectors}"
        );
    }

    /// Even for small batches the estimate must report an actual size reduction.
    #[rstest]
    #[case(32, 128, 100)]
    #[case(32, 768, 10)]
    #[case(64, 256, 50)]
    fn ratio_always_greater_than_one(
        #[case] bits_per_element: u8,
        #[case] dim: u32,
        #[case] num_vectors: usize,
    ) {
        let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors);

        let shrinks = ratio > 1.0;
        assert!(
            shrinks,
            "ratio {ratio:.4} <= 1.0 for {bits_per_element}-bit, dim={dim}, n={num_vectors}"
        );
    }

    /// 768 pads up to 1024, so both inputs compress to the same size while the 1024-d input
    /// starts out larger; its ratio must therefore be higher.
    #[test]
    fn power_of_two_has_better_ratio() {
        let [ratio_768, ratio_1024] =
            [768, 1024].map(|dim| estimate_compression_ratio(32, dim, 1000));

        assert!(
            ratio_768 < ratio_1024,
            "1024-d ratio ({ratio_1024:.2}) should exceed 768-d ({ratio_768:.2})"
        );
    }
}