1use ailake_core::AilakeError;
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
/// Product-quantization (PQ) codebook.
///
/// A vector is split into `num_subvectors` contiguous slices of length `sub_dim`;
/// each slice is encoded as the `u8` index of its nearest centroid in the matching
/// per-subspace codebook.
///
/// NOTE(review): field order may be significant for positional serde formats
/// (e.g. bincode); avoid reordering fields without a data migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Number of contiguous slices each vector is split into.
    pub num_subvectors: usize,
    /// Centroid count requested at training time. A subspace can hold fewer
    /// centroids when fewer training vectors than this were available.
    pub num_centroids: usize,
    /// Length of each slice, i.e. vector dimension / `num_subvectors`.
    pub sub_dim: usize,
    /// `centroids[m][k]` is centroid `k` of subspace `m`; each is expected to have
    /// length `sub_dim`.
    pub centroids: Vec<Vec<Vec<f32>>>,
}
20
21impl PQCodebook {
22 pub fn train(
24 vectors: &[Vec<f32>],
25 num_subvectors: usize,
26 num_centroids: usize,
27 max_iter: usize,
28 ) -> Result<Self, AilakeError> {
29 Self::train_with_kmeans(vectors, num_subvectors, num_centroids, max_iter, kmeans)
30 }
31
32 pub fn train_with_kmeans<F>(
37 vectors: &[Vec<f32>],
38 num_subvectors: usize,
39 num_centroids: usize,
40 max_iter: usize,
41 kmeans_fn: F,
42 ) -> Result<Self, AilakeError>
43 where
44 F: Fn(&[Vec<f32>], usize, usize) -> Vec<Vec<f32>>,
45 {
46 if vectors.is_empty() {
47 return Err(AilakeError::Catalog(
48 "PQ training requires at least 1 vector".into(),
49 ));
50 }
51 let dim = vectors[0].len();
52 if !dim.is_multiple_of(num_subvectors) {
53 return Err(AilakeError::Catalog(format!(
54 "dim {dim} not divisible by num_subvectors {num_subvectors}"
55 )));
56 }
57 let sub_dim = dim / num_subvectors;
58 let n_train = num_centroids.min(vectors.len());
59
60 let mut centroids = Vec::with_capacity(num_subvectors);
61 for m in 0..num_subvectors {
62 let start = m * sub_dim;
63 let end = start + sub_dim;
64 let sub_vecs: Vec<Vec<f32>> = vectors.iter().map(|v| v[start..end].to_vec()).collect();
65 let sub_centroids = kmeans_fn(&sub_vecs, n_train, max_iter);
66 centroids.push(sub_centroids);
67 }
68
69 Ok(Self {
70 num_subvectors,
71 num_centroids,
72 sub_dim,
73 centroids,
74 })
75 }
76
77 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
79 let mut codes = Vec::with_capacity(self.num_subvectors);
80 for m in 0..self.num_subvectors {
81 let start = m * self.sub_dim;
82 let sub = &vector[start..start + self.sub_dim];
83 let best = self.centroids[m]
84 .iter()
85 .enumerate()
86 .map(|(k, c)| (k, l2_sq(sub, c)))
87 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
88 .map(|(k, _)| k)
89 .unwrap_or(0);
90 codes.push(best as u8);
91 }
92 codes
93 }
94
95 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
97 let mut out = Vec::with_capacity(self.num_subvectors * self.sub_dim);
98 for (m, &code) in codes.iter().enumerate() {
99 out.extend_from_slice(&self.centroids[m][code as usize]);
100 }
101 out
102 }
103
104 pub fn compute_adc_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
109 (0..self.num_subvectors)
110 .map(|m| {
111 let start = m * self.sub_dim;
112 let q_sub = &query[start..start + self.sub_dim];
113 self.centroids[m].iter().map(|c| l2_sq(q_sub, c)).collect()
114 })
115 .collect()
116 }
117
118 pub fn adc_distance(&self, codes: &[u8], table: &[Vec<f32>]) -> f32 {
120 codes
121 .iter()
122 .enumerate()
123 .map(|(m, &c)| table[m][c as usize])
124 .sum()
125 }
126}
127
128fn kmeans(points: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
130 let dim = points[0].len();
131 let mut centroids = kmeans_pp_init(points, k);
132
133 for _ in 0..max_iter {
134 let assignments: Vec<usize> = points
136 .par_iter()
137 .map(|p| nearest_centroid(p, ¢roids))
138 .collect();
139
140 let mut new_centroids = vec![vec![0.0f32; dim]; k];
142 let mut counts = vec![0usize; k];
143 for (point, &assigned) in points.iter().zip(assignments.iter()) {
144 for (d, &v) in point.iter().enumerate() {
145 new_centroids[assigned][d] += v;
146 }
147 counts[assigned] += 1;
148 }
149 let mut converged = true;
150 for (i, count) in counts.iter().enumerate() {
151 if *count > 0 {
152 let scale = *count as f32;
153 for x in new_centroids[i].iter_mut() {
154 *x /= scale;
155 }
156 } else {
157 new_centroids[i] = centroids[i].clone();
159 }
160 if l2_sq(&new_centroids[i], ¢roids[i]) > 1e-8 {
161 converged = false;
162 }
163 }
164 centroids = new_centroids;
165 if converged {
166 break;
167 }
168 }
169 centroids
170}
171
172fn kmeans_pp_init(points: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
174 let mut centroids = Vec::with_capacity(k);
175 let mut rng_state = 0x123456789u64;
176
177 centroids.push(points[0].clone());
178 let mut min_dists: Vec<f32> = points.par_iter().map(|p| l2_sq(p, ¢roids[0])).collect();
180
181 while centroids.len() < k {
182 let total: f32 = min_dists.iter().sum();
183 rng_state = rng_state
184 .wrapping_mul(6364136223846793005)
185 .wrapping_add(1442695040888963407);
186 let r = (rng_state >> 33) as f32 / (u32::MAX as f32);
187 let target = r * total;
188 let mut cumsum = 0.0f32;
189 let mut chosen = points.len() - 1;
190 for (i, &d) in min_dists.iter().enumerate() {
191 cumsum += d;
192 if cumsum >= target {
193 chosen = i;
194 break;
195 }
196 }
197 let new_centroid = points[chosen].clone();
198 points
200 .par_iter()
201 .zip(min_dists.par_iter_mut())
202 .for_each(|(p, min_d)| {
203 let d = l2_sq(p, &new_centroid);
204 if d < *min_d {
205 *min_d = d;
206 }
207 });
208 centroids.push(new_centroid);
209 }
210 centroids
211}
212
213fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
214 centroids
215 .iter()
216 .enumerate()
217 .map(|(i, c)| (i, l2_sq(point, c)))
218 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
219 .map(|(i, _)| i)
220 .unwrap_or(0)
221}
222
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared: pairing stops at the end of the
/// shorter slice.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
226
227pub fn kmeans_centroids(vectors: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
230 let k_eff = k.min(vectors.len());
231 kmeans(vectors, k_eff, max_iter)
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// A one-hot vector of length `dim` with index `hot` set to 1.0.
    fn one_hot(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot] = 1.0;
        v
    }

    /// `n` one-hot vectors of length `dim`; vector `i` is hot at index `i % dim`.
    fn unit_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| one_hot(dim, i % dim)).collect()
    }

    #[test]
    fn encode_decode_roundtrip_approx() {
        const DIM: usize = 8;
        let vecs = unit_vecs(64, DIM);
        let cb = PQCodebook::train(&vecs, 2, 4, 50).expect("training should succeed");
        for codes in vecs.iter().map(|v| cb.encode(v)) {
            assert_eq!(codes.len(), 2);
            assert_eq!(cb.decode(&codes).len(), DIM);
        }
    }

    #[test]
    fn adc_distance_non_negative() {
        const DIM: usize = 8;
        let vecs = unit_vecs(32, DIM);
        let cb = PQCodebook::train(&vecs, 2, 4, 50).expect("training should succeed");
        let table = cb.compute_adc_table(&one_hot(DIM, 0));
        vecs.iter()
            .map(|v| cb.adc_distance(&cb.encode(v), &table))
            .for_each(|dist| assert!(dist >= 0.0, "ADC distance must be non-negative"));
    }

    #[test]
    fn dim_not_divisible_errors() {
        let vecs = unit_vecs(16, 9);
        let result = PQCodebook::train(&vecs, 4, 4, 10);
        assert!(result.is_err());
    }

    #[test]
    fn nearest_neighbor_rank_preserved() {
        const DIM: usize = 8;
        // Two well-separated clusters: 20 copies of e0 followed by 20 copies of e7.
        let mut vecs = vec![one_hot(DIM, 0); 20];
        vecs.extend(vec![one_hot(DIM, 7); 20]);

        let cb = PQCodebook::train(&vecs, 2, 4, 100).expect("training should succeed");
        let table_e0 = cb.compute_adc_table(&one_hot(DIM, 0));
        let table_e7 = cb.compute_adc_table(&one_hot(DIM, 7));
        let code_e0 = cb.encode(&vecs[0]);
        let code_e7 = cb.encode(&vecs[39]);

        assert!(cb.adc_distance(&code_e0, &table_e0) < cb.adc_distance(&code_e7, &table_e0));
        assert!(cb.adc_distance(&code_e7, &table_e7) < cb.adc_distance(&code_e0, &table_e7));
    }
}