Skip to main content

nodedb_codec/vector_quant/
opq_kmeans.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Lloyd's k-means for OPQ codebook training.
4
5use super::opq_rotation::Xorshift64;
6
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally. The inputs are expected to have the
/// same length; if they do not, the surplus tail of the longer slice is
/// ignored rather than reported.
#[inline]
pub fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum()
}
12
13/// Lloyd's k-means clustering.
14///
15/// Returns `k` centroids of length `sub_dim`, initialized via k-means++.
16pub fn lloyd(
17    points: &[Vec<f32>],
18    sub_dim: usize,
19    k: usize,
20    iters: usize,
21    seed: u64,
22) -> Vec<Vec<f32>> {
23    let n = points.len();
24    if n == 0 || k == 0 {
25        return Vec::new();
26    }
27    let k = k.min(n);
28
29    let mut rng = Xorshift64::new(seed.wrapping_add(0x9E3779B97F4A7C15));
30    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
31    centroids.push(points[0].clone());
32
33    let mut min_dists = vec![f32::MAX; n];
34    for (i, p) in points.iter().enumerate() {
35        min_dists[i] = l2_sq(p, &centroids[0]);
36    }
37
38    for _ in 1..k {
39        let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
40        let chosen = if total < f64::EPSILON {
41            0usize
42        } else {
43            let target = {
44                let u = (rng.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
45                u * total
46            };
47            let mut acc = 0.0f64;
48            let mut idx = n - 1;
49            for (i, &d) in min_dists.iter().enumerate() {
50                acc += d as f64;
51                if acc >= target {
52                    idx = i;
53                    break;
54                }
55            }
56            idx
57        };
58        let new_c = points[chosen].clone();
59        for (i, p) in points.iter().enumerate() {
60            let d = l2_sq(p, &new_c);
61            if d < min_dists[i] {
62                min_dists[i] = d;
63            }
64        }
65        centroids.push(new_c);
66    }
67
68    let mut assignments = vec![0usize; n];
69    for _ in 0..iters {
70        let mut changed = false;
71        for (i, p) in points.iter().enumerate() {
72            let best = (0..k)
73                .min_by(|&a, &b| {
74                    l2_sq(p, &centroids[a])
75                        .partial_cmp(&l2_sq(p, &centroids[b]))
76                        .unwrap_or(std::cmp::Ordering::Equal)
77                })
78                .unwrap_or(0);
79            if assignments[i] != best {
80                assignments[i] = best;
81                changed = true;
82            }
83        }
84        if !changed {
85            break;
86        }
87        let mut sums = vec![vec![0.0f32; sub_dim]; k];
88        let mut counts = vec![0usize; k];
89        for (i, p) in points.iter().enumerate() {
90            let c = assignments[i];
91            counts[c] += 1;
92            for d in 0..sub_dim {
93                sums[c][d] += p[d];
94            }
95        }
96        for c in 0..k {
97            if counts[c] > 0 {
98                for d in 0..sub_dim {
99                    centroids[c][d] = sums[c][d] / counts[c] as f32;
100                }
101            }
102        }
103    }
104    centroids
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Two compact, well-separated groups should each end up with a centroid.
    #[test]
    fn lloyd_separates_two_clusters() {
        // Ten points hugging the origin, then ten hugging (10, 10).
        let mut samples: Vec<Vec<f32>> = Vec::with_capacity(20);
        samples.extend((0..10).map(|i| {
            let coord = i as f32 * 0.01;
            vec![coord, coord]
        }));
        samples.extend((0..10).map(|i| {
            let coord = 10.0 + i as f32 * 0.01;
            vec![coord, coord]
        }));

        let centroids = lloyd(&samples, 2, 2, 20, 99);
        assert_eq!(centroids.len(), 2);

        // One centroid per group means they sit more than 5.0 apart.
        let d = l2_sq(&centroids[0], &centroids[1]).sqrt();
        assert!(d > 5.0, "centroids too close: {d}");
    }
}