// nodedb_codec/vector_quant/opq_kmeans.rs
use super::opq_rotation::Xorshift64;

/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Coordinates are paired positionally with `zip`, so when the slices differ
/// in length only the common prefix contributes and the trailing coordinates
/// of the longer slice are ignored. Two empty slices have distance zero.
///
/// The square root is deliberately not taken; callers that only compare
/// distances do not need it.
#[inline]
pub fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}
12
13pub fn lloyd(
17 points: &[Vec<f32>],
18 sub_dim: usize,
19 k: usize,
20 iters: usize,
21 seed: u64,
22) -> Vec<Vec<f32>> {
23 let n = points.len();
24 if n == 0 || k == 0 {
25 return Vec::new();
26 }
27 let k = k.min(n);
28
29 let mut rng = Xorshift64::new(seed.wrapping_add(0x9E3779B97F4A7C15));
30 let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
31 centroids.push(points[0].clone());
32
33 let mut min_dists = vec![f32::MAX; n];
34 for (i, p) in points.iter().enumerate() {
35 min_dists[i] = l2_sq(p, ¢roids[0]);
36 }
37
38 for _ in 1..k {
39 let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
40 let chosen = if total < f64::EPSILON {
41 0usize
42 } else {
43 let target = {
44 let u = (rng.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
45 u * total
46 };
47 let mut acc = 0.0f64;
48 let mut idx = n - 1;
49 for (i, &d) in min_dists.iter().enumerate() {
50 acc += d as f64;
51 if acc >= target {
52 idx = i;
53 break;
54 }
55 }
56 idx
57 };
58 let new_c = points[chosen].clone();
59 for (i, p) in points.iter().enumerate() {
60 let d = l2_sq(p, &new_c);
61 if d < min_dists[i] {
62 min_dists[i] = d;
63 }
64 }
65 centroids.push(new_c);
66 }
67
68 let mut assignments = vec![0usize; n];
69 for _ in 0..iters {
70 let mut changed = false;
71 for (i, p) in points.iter().enumerate() {
72 let best = (0..k)
73 .min_by(|&a, &b| {
74 l2_sq(p, ¢roids[a])
75 .partial_cmp(&l2_sq(p, ¢roids[b]))
76 .unwrap_or(std::cmp::Ordering::Equal)
77 })
78 .unwrap_or(0);
79 if assignments[i] != best {
80 assignments[i] = best;
81 changed = true;
82 }
83 }
84 if !changed {
85 break;
86 }
87 let mut sums = vec![vec![0.0f32; sub_dim]; k];
88 let mut counts = vec![0usize; k];
89 for (i, p) in points.iter().enumerate() {
90 let c = assignments[i];
91 counts[c] += 1;
92 for d in 0..sub_dim {
93 sums[c][d] += p[d];
94 }
95 }
96 for c in 0..k {
97 if counts[c] > 0 {
98 for d in 0..sub_dim {
99 centroids[c][d] = sums[c][d] / counts[c] as f32;
100 }
101 }
102 }
103 }
104 centroids
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight, well-separated blobs must end up with one centroid each.
    #[test]
    fn lloyd_separates_two_clusters() {
        let mut pts: Vec<Vec<f32>> = Vec::new();
        // Ten points on the diagonal next to the origin.
        for i in 0..10 {
            pts.push(vec![i as f32 * 0.01, i as f32 * 0.01]);
        }
        // Ten points on the diagonal next to (10, 10).
        for i in 0..10 {
            pts.push(vec![10.0 + i as f32 * 0.01, 10.0 + i as f32 * 0.01]);
        }
        let centroids = lloyd(&pts, 2, 2, 20, 99);
        assert_eq!(centroids.len(), 2);
        let d = l2_sq(&centroids[0], &centroids[1]).sqrt();
        // The blobs are roughly 14 apart; 5 leaves a generous margin.
        assert!(d > 5.0, "centroids too close: {d}");
    }

    /// Empty input and `k == 0` yield no centroids; `k` is clamped to the
    /// number of points.
    #[test]
    fn lloyd_handles_degenerate_inputs() {
        assert!(lloyd(&[], 2, 3, 5, 1).is_empty());

        let pts = vec![vec![0.0f32, 0.0], vec![1.0, 1.0], vec![5.0, 5.0]];
        assert!(lloyd(&pts, 2, 0, 5, 1).is_empty());
        assert_eq!(lloyd(&pts, 2, 10, 5, 1).len(), 3);
    }
}