use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::ExecutionCtx;
use vortex_compressor::CascadingCompressor;
use vortex_compressor::ctx::CompressorContext;
use vortex_compressor::estimate::CompressionEstimate;
use vortex_compressor::estimate::EstimateVerdict;
use vortex_compressor::scheme::Scheme;
use vortex_compressor::stats::ArrayAndStats;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use crate::encodings::turboquant::MAX_CENTROIDS;
use crate::encodings::turboquant::TurboQuantConfig;
use crate::encodings::turboquant::tq_validate_vector_dtype;
use crate::encodings::turboquant::turboquant_encode;
/// Compression scheme that applies TurboQuant encoding to arrays.
///
/// This is a stateless marker type: all behaviour lives in its [`Scheme`] implementation, which
/// accepts canonical extension arrays whose dtype passes [`tq_validate_vector_dtype`] and encodes
/// them with [`turboquant_encode`] using [`TurboQuantConfig::default`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TurboQuantScheme;
impl Scheme for TurboQuantScheme {
    /// Returns the identifier under which this scheme is registered.
    fn scheme_name(&self) -> &'static str {
        "vortex.tensor.turboquant"
    }

    /// Reports whether `canonical` is an extension array whose dtype passes
    /// [`tq_validate_vector_dtype`]. Every other canonical form is rejected.
    fn matches(&self, canonical: &Canonical) -> bool {
        match canonical {
            Canonical::Extension(ext) => tq_validate_vector_dtype(ext.dtype()).is_ok(),
            _ => false,
        }
    }

    /// Produces a ratio verdict from the array's element width, vector dimension and length.
    ///
    /// The estimate is purely analytical (see [`estimate_compression_ratio`]); neither context
    /// argument is consulted. Dtype validation and the bit-width narrowing go through
    /// `vortex_expect`, i.e. a failure is treated as a broken invariant rather than a
    /// recoverable error (presumably unreachable once `matches` has accepted the array).
    fn expected_compression_ratio(
        &self,
        data: &ArrayAndStats,
        _compress_ctx: CompressorContext,
        _exec_ctx: &mut ExecutionCtx,
    ) -> CompressionEstimate {
        let array = data.array();
        let num_vectors = array.len();

        let metadata =
            tq_validate_vector_dtype(array.dtype()).vortex_expect("invalid dtype for TurboQuant");

        // The estimator takes the element width as a `u8`, so narrow it explicitly.
        let bit_width = u8::try_from(metadata.element_ptype().bit_width())
            .vortex_expect("invalid bit width for TurboQuant");

        let ratio = estimate_compression_ratio(bit_width, metadata.dimensions(), num_vectors);
        CompressionEstimate::Verdict(EstimateVerdict::Ratio(ratio))
    }

    /// Encodes the array with [`turboquant_encode`] under the default [`TurboQuantConfig`].
    ///
    /// The cascading compressor and compressor context are not used.
    fn compress(
        &self,
        _compressor: &CascadingCompressor,
        data: &ArrayAndStats,
        _compress_ctx: CompressorContext,
        exec_ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let input = data.array().clone();
        let config = TurboQuantConfig::default();
        turboquant_encode(input, &config, exec_ctx)
    }
}
/// Estimates the TurboQuant compression ratio as `uncompressed bits / compressed bits`.
///
/// The size model uses [`TurboQuantConfig::default`] and charges:
///
/// * per vector: one norm stored at the element width, plus `config.bit_width` bits for every
///   coordinate after `dimensions` has been rounded up to the next power of two;
/// * once per array: a centroid table of `2^config.bit_width` entries at 32 bits each.
///
/// # Arguments
///
/// * `element_bit_width` - width in bits of one uncompressed vector element.
/// * `dimensions` - number of elements per vector.
/// * `num_vectors` - number of vectors in the array.
///
/// # Returns
///
/// The estimated ratio. With `num_vectors == 0` the numerator is zero and the result is `0.0`;
/// the denominator is never zero because the centroid table always contributes.
///
/// All bit counts are accumulated in `u128` so that large inputs cannot overflow `usize`
/// (notably on 32-bit targets). Results are identical to plain `usize` arithmetic whenever
/// that arithmetic would not have overflowed.
fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vectors: usize) -> f64 {
    // Bits charged per centroid table entry (presumably `f32` centroids; confirm against encoder).
    const CENTROID_ENTRY_BITS: u128 = 32;

    let config = TurboQuantConfig::default();

    // Widen before rounding: `u32::next_power_of_two` overflows for dimensions above 2^31.
    let padded_dim = u128::from(u64::from(dimensions).next_power_of_two());
    let element_bits = u128::from(element_bit_width);
    // `usize` has no `From` conversion into `u128`; these casts are lossless widenings.
    let num_vectors = num_vectors as u128;
    let code_bits = usize::from(config.bit_width) as u128;

    let uncompressed_size_bits = element_bits * u128::from(dimensions) * num_vectors;

    // The per-vector norm is stored at the same width as the original elements.
    let norm_bits = element_bits;
    let compressed_bits_per_vector = code_bits * padded_dim;
    let total_bits_per_vector = norm_bits + compressed_bits_per_vector;

    let num_centroids = 1usize << config.bit_width;
    debug_assert!(num_centroids <= MAX_CENTROIDS);
    let overhead_bits = num_centroids as u128 * CENTROID_ENTRY_BITS;

    let compressed_size_bits = total_bits_per_vector * num_vectors + overhead_bits;
    uncompressed_size_bits as f64 / compressed_size_bits as f64
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// The estimate for common embedding shapes falls strictly inside a plausible band.
    #[rstest]
    #[case::f32_768d(32, 768, 1000, 2.5, 4.5)]
    #[case::f32_1024d(32, 1024, 1000, 3.5, 5.0)]
    #[case::f32_1536d(32, 1536, 1000, 2.5, 4.5)]
    #[case::f32_128d(32, 128, 1000, 3.0, 5.0)]
    #[case::f64_768d(64, 768, 1000, 5.0, 9.0)]
    #[case::f16_768d(16, 768, 1000, 1.2, 2.5)]
    fn compression_ratio_in_expected_range(
        #[case] element_bit_width: u8,
        #[case] dim: u32,
        #[case] num_vectors: usize,
        #[case] min_ratio: f64,
        #[case] max_ratio: f64,
    ) {
        let ratio = estimate_compression_ratio(element_bit_width, dim, num_vectors);
        let within_bounds = min_ratio < ratio && ratio < max_ratio;
        assert!(
            within_bounds,
            "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \
             {element_bit_width}-bit elements, dim={dim}, n={num_vectors}"
        );
    }

    /// Even for small arrays the estimate must report an actual size reduction.
    #[rstest]
    #[case(32, 128, 100)]
    #[case(32, 768, 10)]
    #[case(64, 256, 50)]
    fn ratio_always_greater_than_one(
        #[case] element_bit_width: u8,
        #[case] dim: u32,
        #[case] num_vectors: usize,
    ) {
        let ratio = estimate_compression_ratio(element_bit_width, dim, num_vectors);
        assert!(
            1.0 < ratio,
            "ratio {ratio:.4} <= 1.0 for {element_bit_width}-bit, dim={dim}, n={num_vectors}"
        );
    }

    /// Recomputes the size model by hand for a single vector and requires an exact match, so the
    /// norm is verified to be charged at the element width.
    #[rstest]
    #[case(16)]
    #[case(32)]
    #[case(64)]
    fn ratio_accounts_for_norm_storage_width(#[case] element_bit_width: u8) {
        let config = TurboQuantConfig::default();
        let dim = 128u32;
        let num_vectors = 1usize;

        // Compressed side: norm + quantized codes over the padded dimension + centroid table.
        let norm_bits = usize::from(element_bit_width);
        let code_bits = usize::from(config.bit_width) * (dim.next_power_of_two() as usize);
        let centroid_table_bits = (1usize << config.bit_width) * 32;
        let compressed_bits = norm_bits + code_bits + centroid_table_bits;

        // Uncompressed side: every element at its full width.
        let uncompressed_bits = usize::from(element_bit_width) * dim as usize * num_vectors;

        let expected = uncompressed_bits as f64 / compressed_bits as f64;
        let actual = estimate_compression_ratio(element_bit_width, dim, num_vectors);
        assert_eq!(actual, expected);
    }

    /// A power-of-two dimension needs no padding, so it should compress better than 768-d,
    /// which is padded up to 1024.
    #[test]
    fn power_of_two_has_better_ratio() {
        let [ratio_768, ratio_1024] =
            [768u32, 1024u32].map(|dim| estimate_compression_ratio(32, dim, 1000));
        assert!(
            ratio_768 < ratio_1024,
            "1024-d ratio ({ratio_1024:.2}) should exceed 768-d ({ratio_768:.2})"
        );
    }
}