use std::sync::LazyLock;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_utils::aliases::dash_map::DashMap;
use crate::encodings::turboquant::TurboQuant;
/// Upper bound on the number of refinement passes `max_lloyd_centroids` runs
/// when the convergence test never triggers.
const MAX_ITERATIONS: usize = 200;
/// `max_lloyd_centroids` stops early once the largest per-centroid movement in
/// a pass drops below this value.
const CONVERGENCE_EPSILON: f64 = 1e-12;
/// Number of equal-width sub-intervals `mean_between_centroids` uses for its
/// trapezoidal sums (it evaluates `INTEGRATION_POINTS + 1` sample points).
const INTEGRATION_POINTS: usize = 1000;
/// Lazily created map from `(dimension, bit_width)` to the centroids computed
/// for that pair, consulted and filled by `get_centroids`.
///
/// NOTE(review): nothing in this file removes entries, so the map only grows.
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default);
/// Returns the quantization centroids for the given `dimension` and
/// `bit_width`, computing them on first request and serving a copy from
/// `CENTROID_CACHE` afterwards.
///
/// The returned vector has `2^bit_width` entries.
///
/// # Errors
///
/// Fails when `bit_width` is outside `1..=TurboQuant::MAX_BIT_WIDTH` or when
/// `dimension` is below `TurboQuant::MIN_DIMENSION`.
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
    vortex_ensure!(
        (1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
        "TurboQuant bit_width must be 1-{}, got {bit_width}",
        TurboQuant::MAX_BIT_WIDTH
    );
    vortex_ensure!(
        dimension >= TurboQuant::MIN_DIMENSION,
        "TurboQuant dimension must be >= {}, got {dimension}",
        TurboQuant::MIN_DIMENSION
    );

    let key = (dimension, bit_width);

    // Copy the cached vector out inside the closure so the map guard is
    // released before anything else touches the cache.
    let cached = CENTROID_CACHE.get(&key).map(|entry| entry.clone());
    if let Some(hit) = cached {
        return Ok(hit);
    }

    // Cache miss: compute once, remember a copy, hand the original back.
    let computed = max_lloyd_centroids(dimension, bit_width);
    CENTROID_CACHE.insert(key, computed.clone());
    Ok(computed)
}
/// An exponent of the form `numerator / 2`, split into its floored integer
/// part and a flag saying whether an extra `1/2` remains.
#[derive(Clone, Copy, Debug)]
struct HalfIntExponent {
    /// `floor(numerator / 2)`.
    int_part: i32,
    /// `true` when the numerator is odd, i.e. the exponent is `int_part + 1/2`.
    has_half: bool,
}

impl HalfIntExponent {
    /// Builds the exponent `numerator / 2`.
    ///
    /// Euclidean division is used so that negative numerators floor towards
    /// negative infinity and the leftover half is always non-negative.
    fn from_numerator(numerator: i32) -> Self {
        Self {
            int_part: numerator.div_euclid(2),
            has_half: numerator.rem_euclid(2) == 1,
        }
    }
}
/// Computes `2^bit_width` scalar quantization levels on `[-1, 1]` by
/// alternating two steps until the levels stop moving:
///
/// 1. place a cell edge halfway between each pair of neighbouring levels, and
/// 2. move every level to the weighted mean of its cell, where the weight is
///    `(1 - x^2)^((dimension - 3) / 2)` (see `pdf_unnormalized`).
///
/// The loop ends after `MAX_ITERATIONS` passes or as soon as the largest
/// movement in a pass is below `CONVERGENCE_EPSILON`. The result is narrowed
/// to `f32` and ordered from the lowest level to the highest.
///
/// NOTE(review): the weight is presumably the marginal density of one
/// coordinate of a uniformly random unit vector — confirm against the
/// TurboQuant design notes.
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
    debug_assert!((1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width));

    let level_count = 1usize << bit_width;
    // NOTE(review): `dimension as i32` wraps for dimensions above `i32::MAX`.
    let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3);

    // Seed with the midpoints of `level_count` equal-width cells over [-1, 1].
    let mut levels: Vec<f64> = Vec::with_capacity(level_count);
    for cell in 0..level_count {
        levels.push(-1.0 + (2.0 * (cell as f64) + 1.0) / (level_count as f64));
    }

    let mut edges = vec![0.0f64; level_count + 1];
    for _ in 0..MAX_ITERATIONS {
        // The outer edges are pinned to the support; interior edges sit
        // halfway between neighbouring levels.
        edges[0] = -1.0;
        edges[level_count] = 1.0;
        for (edge, pair) in edges[1..level_count].iter_mut().zip(levels.windows(2)) {
            *edge = (pair[0] + pair[1]) / 2.0;
        }

        // Re-centre every level inside its (now fixed) cell and track the
        // largest displacement of this pass.
        let mut largest_shift = 0.0f64;
        for (level, cell) in levels.iter_mut().zip(edges.windows(2)) {
            let updated = mean_between_centroids(cell[0], cell[1], exponent);
            largest_shift = largest_shift.max((updated - *level).abs());
            *level = updated;
        }

        if largest_shift < CONVERGENCE_EPSILON {
            break;
        }
    }

    #[expect(
        clippy::cast_possible_truncation,
        reason = "all values are in [-1, 1] so this just loses precision"
    )]
    levels.into_iter().map(|level| level as f32).collect()
}
fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 {
if (hi - lo).abs() < 1e-15 {
return (lo + hi) / 2.0;
}
let dx = (hi - lo) / INTEGRATION_POINTS as f64;
let mut numerator = 0.0;
let mut denominator = 0.0;
for step in 0..=INTEGRATION_POINTS {
let x_val = lo + (step as f64) * dx;
let weight = pdf_unnormalized(x_val, exponent);
let trap_weight = if step == 0 || step == INTEGRATION_POINTS {
0.5
} else {
1.0
};
numerator += trap_weight * x_val * weight;
denominator += trap_weight * weight;
}
if denominator.abs() < 1e-30 {
(lo + hi) / 2.0
} else {
numerator / denominator
}
}
#[inline]
fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 {
let base = (1.0 - x_val * x_val).max(0.0);
if exponent.has_half {
base.powi(exponent.int_part) * base.sqrt()
} else {
base.powi(exponent.int_part)
}
}
/// Returns the midpoint between every pair of adjacent `centroids`.
///
/// The output has one element fewer than the input, and is empty when the
/// input has fewer than two centroids.
pub fn compute_centroid_boundaries(centroids: &[f32]) -> Vec<f32> {
    let upper_neighbours = centroids.iter().skip(1);
    centroids
        .iter()
        .zip(upper_neighbours)
        .map(|(lower, upper)| (lower + upper) * 0.5)
        .collect()
}
/// Returns the index of the centroid whose cell contains `value`.
///
/// `boundaries` are the cell edges between consecutive centroids (as produced
/// by `compute_centroid_boundaries`), sorted in ascending order. The result is
/// the number of boundaries strictly below `value`, so a value equal to a
/// boundary is assigned to the lower cell.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics when `boundaries` is not
/// sorted or holds more than 255 entries. With more entries the index could
/// reach 256, which does not fit the `u8` return type.
#[inline]
pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
    debug_assert!(
        boundaries.windows(2).all(|w| w[0] <= w[1]),
        "boundaries must be sorted"
    );
    // `partition_point` can return `boundaries.len()`, so the length itself
    // must be representable as a `u8`.
    debug_assert!(
        boundaries.len() <= usize::from(u8::MAX),
        "at most 255 boundaries are supported, got {}",
        boundaries.len()
    );
    #[expect(
        clippy::cast_possible_truncation,
        reason = "num_centroids <= 256 and partition_point will return at most 255"
    )]
    (boundaries.partition_point(|&b| b < value) as u8)
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use super::*;

    /// `get_centroids` returns exactly `2^bits` centroids.
    #[rstest]
    #[case(128, 1, 2)]
    #[case(128, 2, 4)]
    #[case(128, 3, 8)]
    #[case(128, 4, 16)]
    #[case(768, 2, 4)]
    #[case(1536, 3, 8)]
    fn centroids_have_correct_count(
        #[case] dim: u32,
        #[case] bits: u8,
        #[case] expected: usize,
    ) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        assert_eq!(centroids.len(), expected);
        Ok(())
    }

    /// Centroids come back in strictly increasing order.
    #[rstest]
    #[case(128, 1)]
    #[case(128, 2)]
    #[case(128, 3)]
    #[case(128, 4)]
    #[case(768, 2)]
    fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        for window in centroids.windows(2) {
            assert!(
                window[0] < window[1],
                "centroids not sorted: {:?}",
                centroids
            );
        }
        Ok(())
    }

    /// Centroids mirror each other around zero.
    #[rstest]
    #[case(128, 1)]
    #[case(128, 2)]
    #[case(256, 2)]
    #[case(768, 2)]
    fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        let count = centroids.len();
        for idx in 0..count / 2 {
            let diff = (centroids[idx] + centroids[count - 1 - idx]).abs();
            assert!(
                diff < 1e-5,
                "centroids not symmetric: c[{idx}]={}, c[{}]={}",
                centroids[idx],
                count - 1 - idx,
                centroids[count - 1 - idx]
            );
        }
        Ok(())
    }

    /// Every centroid lies inside `[-1, 1]`.
    #[rstest]
    #[case(128, 1)]
    #[case(128, 4)]
    fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
        let centroids = get_centroids(dim, bits)?;
        for &val in &centroids {
            assert!(
                (-1.0..=1.0).contains(&val),
                "centroid out of [-1, 1]: {val}",
            );
        }
        Ok(())
    }

    /// Two requests for the same parameters return equal centroids.
    #[test]
    fn centroids_cached() -> VortexResult<()> {
        let c1 = get_centroids(128, 2)?;
        let c2 = get_centroids(128, 2)?;
        assert_eq!(c1, c2);
        Ok(())
    }

    /// The extremes map to the first and last cell, and every centroid maps
    /// to its own index.
    #[test]
    fn find_nearest_basic() -> VortexResult<()> {
        let centroids = get_centroids(128, 2)?;
        let boundaries = compute_centroid_boundaries(&centroids);
        assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
        #[expect(clippy::cast_possible_truncation)]
        let last_idx = (centroids.len() - 1) as u8;
        assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx);
        for (idx, &cv) in centroids.iter().enumerate() {
            #[expect(clippy::cast_possible_truncation)]
            let expected = idx as u8;
            assert_eq!(find_nearest_centroid(cv, &boundaries), expected);
        }
        Ok(())
    }

    /// Out-of-range bit widths and too-small dimensions are rejected.
    #[test]
    fn rejects_invalid_params() {
        assert!(get_centroids(128, 0).is_err());
        assert!(get_centroids(128, 9).is_err());
        assert!(get_centroids(1, 2).is_err());
        assert!(get_centroids(127, 2).is_err());
    }
}