use iqdb_types::{IqdbError, Result};
use crate::rng::SplitMix64;
/// Maximum number of Lloyd iterations `lloyd` runs before returning the current centroids.
pub(crate) const MAX_ITERS: usize = 25;
/// Convergence threshold for `lloyd`: refinement stops once every centroid's shift, as
/// measured by `relative_shift` (distance moved divided by the previous norm floored at 1.0),
/// is below this value.
pub(crate) const REL_TOL: f32 = 1e-4;
/// Trains one product-quantizer sub-codebook: `n_centroids` centroids, each of length
/// `sub_dim`, fitted to the sample `subvectors`.
///
/// Centroids are seeded with k-means++ using a [`SplitMix64`] generator seeded from `seed`,
/// then refined with Lloyd iterations. Given the same `seed` and sample, the result is
/// reproducible.
///
/// # Errors
///
/// - [`IqdbError::InvalidConfig`] if `subvectors` is empty, if `n_centroids` is zero, or if
///   the sample contains fewer than `n_centroids` vectors.
/// - [`IqdbError::DimensionMismatch`] for the first sample vector whose length is not
///   `sub_dim`.
pub(crate) fn train_codebook(
    sub_dim: usize,
    n_centroids: usize,
    seed: u64,
    subvectors: &[&[f32]],
) -> Result<Vec<Vec<f32>>> {
    if subvectors.is_empty() {
        return Err(IqdbError::InvalidConfig {
            reason: "ProductQuantizer codebook training requires a non-empty sample",
        });
    }
    if let Some(v) = subvectors.iter().find(|v| v.len() != sub_dim) {
        return Err(IqdbError::DimensionMismatch {
            expected: sub_dim,
            found: v.len(),
        });
    }
    // k-means++ always emits at least one centroid, so a request for zero would either trip
    // its debug assertion or silently return a one-entry codebook. Reject it up front.
    if n_centroids == 0 {
        return Err(IqdbError::InvalidConfig {
            reason: "ProductQuantizer codebook training requires n_centroids >= 1",
        });
    }
    if subvectors.len() < n_centroids {
        return Err(IqdbError::InvalidConfig {
            reason: "ProductQuantizer codebook training requires sample size >= n_centroids",
        });
    }
    let mut rng = SplitMix64::new(seed);
    // The sample is already a slice of borrowed rows; pass it straight through instead of
    // copying the outer slice into a new Vec.
    let centroids = kmeans_plus_plus(sub_dim, n_centroids, subvectors, &mut rng);
    Ok(lloyd(centroids, subvectors, sub_dim))
}
/// Squared Euclidean distance between two vectors of equal length.
///
/// The sum is accumulated in `f32`, front to back, starting from `+0.0`. Panics if `b` is
/// shorter than `a`; debug builds additionally assert that the lengths match exactly.
#[must_use]
pub(crate) fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "squared_l2 requires same-dim slices");
    // Re-slicing keeps the out-of-bounds panic for a too-short `b` and lets the fold below
    // run without per-element bounds checks.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0_f32, |acc, (&lhs, &rhs)| {
        let diff = lhs - rhs;
        acc + diff * diff
    })
}
/// Returns the index of the centroid closest to `vector` by squared Euclidean distance.
///
/// On ties the lowest index wins. `centroids` must be non-empty: debug builds assert this,
/// and release builds panic on the `centroids[0]` access instead.
#[must_use]
pub(crate) fn assign_to_cluster(centroids: &[Vec<f32>], vector: &[f32]) -> usize {
    debug_assert!(
        !centroids.is_empty(),
        "assign_to_cluster needs at least one centroid"
    );
    let initial = (0_usize, squared_l2(&centroids[0], vector));
    let (winner, _) = centroids
        .iter()
        .enumerate()
        .skip(1)
        .fold(initial, |(best_idx, best_dist), (idx, centroid)| {
            let dist = squared_l2(centroid, vector);
            // Strict `<` keeps the earlier centroid when distances tie.
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        });
    winner
}
/// Chooses initial centroids from `working_set` with k-means++ seeding.
///
/// The first centroid is a uniformly random sample. Each later one is drawn with probability
/// proportional to its squared distance to the nearest centroid chosen so far. If that total
/// weight is `<= 0` (every sample coincides with a chosen centroid), the draw falls back to a
/// uniform pick. All randomness comes from `rng`.
///
/// Expects a non-empty `working_set`. Every row should have length `dim`, which is checked
/// only in debug builds. One centroid is always pushed, so `n_centroids == 0` trips the
/// debug length assertion.
fn kmeans_plus_plus(
    dim: usize,
    n_centroids: usize,
    working_set: &[&[f32]],
    rng: &mut SplitMix64,
) -> Vec<Vec<f32>> {
    let population = working_set.len();
    let seed_point = working_set[rng.next_below(population as u64) as usize];

    // Squared distance from each sample to its closest centroid chosen so far.
    let mut nearest_sq: Vec<f32> = working_set
        .iter()
        .map(|point| squared_l2(seed_point, point))
        .collect();

    let mut chosen: Vec<Vec<f32>> = Vec::with_capacity(n_centroids);
    chosen.push(seed_point.to_vec());

    while chosen.len() < n_centroids {
        // Sum in f64 so many small f32 weights are not lost to rounding.
        let total_weight = nearest_sq
            .iter()
            .fold(0.0_f64, |acc, &w| acc + f64::from(w));
        let pick = if total_weight <= 0.0 {
            rng.next_below(population as u64) as usize
        } else {
            let threshold = rng.next_open_unit() * total_weight;
            let mut cumulative = 0.0_f64;
            nearest_sq
                .iter()
                .position(|&w| {
                    cumulative += f64::from(w);
                    cumulative >= threshold
                })
                // Fallback if no prefix sum reaches the threshold (e.g. NaN weights).
                .unwrap_or(population - 1)
        };

        let newcomer = working_set[pick];
        for (best, point) in nearest_sq.iter_mut().zip(working_set) {
            let dist = squared_l2(newcomer, point);
            if dist < *best {
                *best = dist;
            }
        }
        chosen.push(newcomer.to_vec());
    }

    debug_assert_eq!(chosen.len(), n_centroids);
    debug_assert!(chosen.iter().all(|c| c.len() == dim));
    chosen
}
/// Refines `centroids` over `working_set` with Lloyd's algorithm (batch k-means).
///
/// Each iteration assigns every sample to its nearest centroid (see `assign_to_cluster`).
/// It then replaces each centroid with the mean of its members, accumulated in `f64`. It
/// stops early once every centroid's `relative_shift` is below [`REL_TOL`], and runs at most
/// [`MAX_ITERS`] iterations.
///
/// A centroid that attracts no samples is re-seeded with the sample farthest from its
/// nearest centroid, so the returned codebook keeps `centroids.len()` entries.
///
/// Callers must pass a non-empty `working_set` whose rows, like every centroid, have
/// length `dim`.
fn lloyd(mut centroids: Vec<Vec<f32>>, working_set: &[&[f32]], dim: usize) -> Vec<Vec<f32>> {
    let n_clusters = centroids.len();
    // Per-cluster coordinate sums (f64 to limit precision loss) and member counts, reused
    // across iterations.
    let mut sums: Vec<Vec<f64>> = vec![vec![0.0_f64; dim]; n_clusters];
    let mut counts: Vec<usize> = vec![0_usize; n_clusters];
    for _ in 0..MAX_ITERS {
        for s in &mut sums {
            s.fill(0.0);
        }
        counts.fill(0);

        // Assignment step: accumulate each sample into its nearest cluster.
        for v in working_set {
            let c = assign_to_cluster(&centroids, v);
            for (acc, &x) in sums[c].iter_mut().zip(v.iter()) {
                *acc += f64::from(x);
            }
            counts[c] += 1;
        }

        // Update step: move each centroid to its cluster mean, or re-seed empty clusters.
        let mut max_rel_shift: f32 = 0.0;
        for c in 0..n_clusters {
            let new_centroid: Vec<f32> = if counts[c] == 0 {
                working_set[farthest_sample(&centroids, working_set)].to_vec()
            } else {
                let inv = 1.0_f64 / (counts[c] as f64);
                sums[c].iter().map(|&s| (s * inv) as f32).collect()
            };
            // `f32::max` ignores a NaN argument, matching a `shift > max` comparison.
            max_rel_shift = max_rel_shift.max(relative_shift(&centroids[c], &new_centroid));
            centroids[c] = new_centroid;
        }
        if max_rel_shift < REL_TOL {
            break;
        }
    }
    centroids
}

/// Returns the index of the sample whose squared distance to its nearest centroid is
/// largest (lowest index on ties); used to re-seed an empty cluster.
///
/// `centroids` and `working_set` must both be non-empty.
fn farthest_sample(centroids: &[Vec<f32>], working_set: &[&[f32]]) -> usize {
    let mut best_idx: usize = 0;
    let mut best_dist = f32::NEG_INFINITY;
    for (i, v) in working_set.iter().enumerate() {
        let nearest = centroids
            .iter()
            .skip(1)
            .fold(squared_l2(&centroids[0], v), |closest, c| {
                let d = squared_l2(c, v);
                if d < closest {
                    d
                } else {
                    closest
                }
            });
        if nearest > best_dist {
            best_dist = nearest;
            best_idx = i;
        }
    }
    best_idx
}
/// Euclidean distance from `old` to `new`, divided by the norm of `old` floored at 1.0.
///
/// The floor means a zero or near-zero `old` never divides by zero. Below unit norm the
/// result is simply the absolute distance moved. Panics if `new` is shorter than `old`;
/// debug builds additionally assert that the lengths match.
fn relative_shift(old: &[f32], new: &[f32]) -> f32 {
    debug_assert_eq!(old.len(), new.len());
    let new = &new[..old.len()];
    let (moved_sq, norm_sq) =
        old.iter()
            .zip(new)
            .fold((0.0_f32, 0.0_f32), |(moved, norm), (&before, &after)| {
                let delta = before - after;
                (moved + delta * delta, norm + before * before)
            });
    moved_sq.sqrt() / norm_sq.sqrt().max(1.0)
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Borrows each owned row so the data can be passed where `&[&[f32]]` is expected.
    fn refs(slice: &[Vec<f32>]) -> Vec<&[f32]> {
        slice.iter().map(|v| v.as_slice()).collect()
    }

    #[test]
    fn rejects_empty_sample() {
        let err = train_codebook(2, 3, 0, &[]).unwrap_err();
        match err {
            IqdbError::InvalidConfig { reason } => {
                assert!(reason.contains("non-empty sample"));
            }
            other => panic!("expected InvalidConfig, got {other:?}"),
        }
    }

    #[test]
    fn rejects_sample_smaller_than_n_centroids() {
        let data = vec![vec![0.0_f32, 0.0], vec![1.0, 1.0]];
        let sample = refs(&data);
        let err = train_codebook(2, 5, 0, &sample).unwrap_err();
        match err {
            IqdbError::InvalidConfig { reason } => {
                assert!(reason.contains("sample size >= n_centroids"));
            }
            other => panic!("expected InvalidConfig, got {other:?}"),
        }
    }

    #[test]
    fn rejects_dim_mismatch() {
        let bad = vec![1.0_f32, 2.0, 3.0];
        let sample = vec![bad.as_slice()];
        let err = train_codebook(2, 1, 0, &sample).unwrap_err();
        match err {
            IqdbError::DimensionMismatch { expected, found } => {
                assert_eq!(expected, 2);
                assert_eq!(found, 3);
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn converges_on_two_obvious_clusters() {
        let data: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![0.1, -0.1],
            vec![-0.05, 0.05],
            vec![10.0, 10.0],
            vec![10.1, 9.9],
            vec![9.95, 10.05],
        ];
        let sample = refs(&data);
        let centroids = train_codebook(2, 2, 1, &sample).unwrap();
        assert_eq!(centroids.len(), 2);
        // Exactly one centroid should land near each of the two well-separated blobs.
        let mut near_origin = 0;
        let mut near_ten = 0;
        for c in &centroids {
            if c[0].abs() < 1.0 && c[1].abs() < 1.0 {
                near_origin += 1;
            }
            if (c[0] - 10.0).abs() < 1.0 && (c[1] - 10.0).abs() < 1.0 {
                near_ten += 1;
            }
        }
        assert_eq!(near_origin, 1);
        assert_eq!(near_ten, 1);
    }

    #[test]
    fn same_seed_produces_identical_centroids() {
        let data: Vec<Vec<f32>> = (0..50)
            .map(|i| vec![(i as f32) * 0.1, ((i * 3) as f32) * 0.07])
            .collect();
        let sample = refs(&data);
        let a = train_codebook(2, 4, 1234, &sample).unwrap();
        let b = train_codebook(2, 4, 1234, &sample).unwrap();
        assert_eq!(a, b, "same seed + same data → identical centroids");
    }

    #[test]
    fn different_seeds_can_diverge() {
        let data: Vec<Vec<f32>> = (0..50)
            .map(|i| vec![(i as f32) * 0.1, ((i * 3) as f32) * 0.07])
            .collect();
        let sample = refs(&data);
        let a = train_codebook(2, 4, 1, &sample).unwrap();
        let b = train_codebook(2, 4, 2, &sample).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn returns_requested_number_of_full_width_centroids() {
        let data: Vec<Vec<f32>> = (0..20)
            .map(|i| vec![i as f32, (i % 3) as f32, 1.0])
            .collect();
        let sample = refs(&data);
        let centroids = train_codebook(3, 5, 7, &sample).unwrap();
        assert_eq!(centroids.len(), 5);
        assert!(centroids.iter().all(|c| c.len() == 3));
    }

    #[test]
    fn identical_points_still_yield_requested_centroid_count() {
        // Every sample coincides, so k-means++ falls back to uniform picks and Lloyd has to
        // re-seed the empty clusters; the codebook must still have the requested size.
        let data = vec![vec![1.0_f32, 1.0]; 4];
        let sample = refs(&data);
        let centroids = train_codebook(2, 3, 9, &sample).unwrap();
        assert_eq!(centroids.len(), 3);
        assert!(centroids.iter().all(|c| *c == vec![1.0_f32, 1.0]));
    }

    #[test]
    fn assign_to_cluster_prefers_lowest_index_on_ties() {
        let centroids = vec![vec![1.0_f32, 0.0], vec![-1.0, 0.0], vec![1.0, 0.0]];
        assert_eq!(assign_to_cluster(&centroids, &[0.0, 0.0]), 0);
        assert_eq!(assign_to_cluster(&centroids, &[0.9, 0.0]), 0);
        assert_eq!(assign_to_cluster(&centroids, &[-0.9, 0.0]), 1);
    }

    #[test]
    fn relative_shift_floors_denominator_at_one() {
        assert_eq!(relative_shift(&[0.0, 0.0], &[0.0, 0.5]), 0.5);
        assert_eq!(relative_shift(&[3.0, 4.0], &[0.0, 0.0]), 1.0);
    }
}