1use ailake_core::AilakeError;
6use serde::{Deserialize, Serialize};
7
/// Product-quantization (PQ) codebook.
///
/// A vector of length `num_subvectors * sub_dim` is split into
/// `num_subvectors` contiguous chunks of `sub_dim` components, and each chunk
/// is represented by the index of its nearest centroid in that subspace.
/// Codes are `u8`, so a subspace can address at most 256 centroids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Number of contiguous subspaces each vector is split into.
    pub num_subvectors: usize,
    /// Number of centroids per subspace requested at training time. `train`
    /// clusters with `min(num_centroids, number of training vectors)`, so
    /// `centroids[m]` may be shorter than this.
    pub num_centroids: usize,
    /// Length of each subvector; the full vector dimension is
    /// `num_subvectors * sub_dim`.
    pub sub_dim: usize,
    /// Trained centroids, indexed as `centroids[subspace][code][component]`.
    pub centroids: Vec<Vec<Vec<f32>>>,
}
19
20impl PQCodebook {
21 pub fn train(
23 vectors: &[Vec<f32>],
24 num_subvectors: usize,
25 num_centroids: usize,
26 max_iter: usize,
27 ) -> Result<Self, AilakeError> {
28 Self::train_with_kmeans(vectors, num_subvectors, num_centroids, max_iter, kmeans)
29 }
30
31 pub fn train_with_kmeans<F>(
36 vectors: &[Vec<f32>],
37 num_subvectors: usize,
38 num_centroids: usize,
39 max_iter: usize,
40 kmeans_fn: F,
41 ) -> Result<Self, AilakeError>
42 where
43 F: Fn(&[Vec<f32>], usize, usize) -> Vec<Vec<f32>>,
44 {
45 if vectors.is_empty() {
46 return Err(AilakeError::Catalog(
47 "PQ training requires at least 1 vector".into(),
48 ));
49 }
50 let dim = vectors[0].len();
51 if !dim.is_multiple_of(num_subvectors) {
52 return Err(AilakeError::Catalog(format!(
53 "dim {dim} not divisible by num_subvectors {num_subvectors}"
54 )));
55 }
56 let sub_dim = dim / num_subvectors;
57 let n_train = num_centroids.min(vectors.len());
58
59 let mut centroids = Vec::with_capacity(num_subvectors);
60 for m in 0..num_subvectors {
61 let start = m * sub_dim;
62 let end = start + sub_dim;
63 let sub_vecs: Vec<Vec<f32>> = vectors.iter().map(|v| v[start..end].to_vec()).collect();
64 let sub_centroids = kmeans_fn(&sub_vecs, n_train, max_iter);
65 centroids.push(sub_centroids);
66 }
67
68 Ok(Self {
69 num_subvectors,
70 num_centroids,
71 sub_dim,
72 centroids,
73 })
74 }
75
76 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
78 let mut codes = Vec::with_capacity(self.num_subvectors);
79 for m in 0..self.num_subvectors {
80 let start = m * self.sub_dim;
81 let sub = &vector[start..start + self.sub_dim];
82 let best = self.centroids[m]
83 .iter()
84 .enumerate()
85 .map(|(k, c)| (k, l2_sq(sub, c)))
86 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
87 .map(|(k, _)| k)
88 .unwrap_or(0);
89 codes.push(best as u8);
90 }
91 codes
92 }
93
94 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
96 let mut out = Vec::with_capacity(self.num_subvectors * self.sub_dim);
97 for (m, &code) in codes.iter().enumerate() {
98 out.extend_from_slice(&self.centroids[m][code as usize]);
99 }
100 out
101 }
102
103 pub fn compute_adc_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
108 (0..self.num_subvectors)
109 .map(|m| {
110 let start = m * self.sub_dim;
111 let q_sub = &query[start..start + self.sub_dim];
112 self.centroids[m].iter().map(|c| l2_sq(q_sub, c)).collect()
113 })
114 .collect()
115 }
116
117 pub fn adc_distance(&self, codes: &[u8], table: &[Vec<f32>]) -> f32 {
119 codes
120 .iter()
121 .enumerate()
122 .map(|(m, &c)| table[m][c as usize])
123 .sum()
124 }
125}
126
127fn kmeans(points: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
129 let dim = points[0].len();
130 let mut centroids = kmeans_pp_init(points, k);
131
132 for _ in 0..max_iter {
133 let assignments: Vec<usize> = points
135 .iter()
136 .map(|p| nearest_centroid(p, ¢roids))
137 .collect();
138
139 let mut new_centroids = vec![vec![0.0f32; dim]; k];
141 let mut counts = vec![0usize; k];
142 for (point, &assigned) in points.iter().zip(assignments.iter()) {
143 for (d, &v) in point.iter().enumerate() {
144 new_centroids[assigned][d] += v;
145 }
146 counts[assigned] += 1;
147 }
148 let mut converged = true;
149 for (i, count) in counts.iter().enumerate() {
150 if *count > 0 {
151 let scale = *count as f32;
152 for x in new_centroids[i].iter_mut() {
153 *x /= scale;
154 }
155 } else {
156 new_centroids[i] = centroids[i].clone();
158 }
159 if l2_sq(&new_centroids[i], ¢roids[i]) > 1e-8 {
160 converged = false;
161 }
162 }
163 centroids = new_centroids;
164 if converged {
165 break;
166 }
167 }
168 centroids
169}
170
171fn kmeans_pp_init(points: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
173 let mut centroids = vec![points[0].clone()];
174 let mut rng_state = 0x123456789u64;
175
176 while centroids.len() < k {
177 let dists: Vec<f32> = points
178 .iter()
179 .map(|p| {
180 centroids
181 .iter()
182 .map(|c| l2_sq(p, c))
183 .fold(f32::INFINITY, f32::min)
184 })
185 .collect();
186 let total: f32 = dists.iter().sum();
187 rng_state = rng_state
189 .wrapping_mul(6364136223846793005)
190 .wrapping_add(1442695040888963407);
191 let r = (rng_state >> 33) as f32 / (u32::MAX as f32);
192 let target = r * total;
193 let mut cumsum = 0.0f32;
194 let mut chosen = points.len() - 1;
195 for (i, &d) in dists.iter().enumerate() {
196 cumsum += d;
197 if cumsum >= target {
198 chosen = i;
199 break;
200 }
201 }
202 centroids.push(points[chosen].clone());
203 }
204 centroids
205}
206
207fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
208 centroids
209 .iter()
210 .enumerate()
211 .map(|(i, c)| (i, l2_sq(point, c)))
212 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
213 .map(|(i, _)| i)
214 .unwrap_or(0)
215}
216
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally and pairing stops at the end of the
/// shorter slice, so callers should pass equal-length slices.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
220
221pub fn kmeans_centroids(vectors: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
224 let k_eff = k.min(vectors.len());
225 kmeans(vectors, k_eff, max_iter)
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot vector of length `dim` with `1.0` at position `hot`.
    fn one_hot(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot] = 1.0;
        v
    }

    /// `n` one-hot vectors of length `dim`, cycling the hot position through
    /// `0..dim`.
    fn unit_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| one_hot(dim, i % dim)).collect()
    }

    #[test]
    fn encode_decode_roundtrip_approx() {
        let dim = 8;
        let data = unit_vecs(64, dim);
        let codebook = PQCodebook::train(&data, 2, 4, 50).unwrap();
        for sample in &data {
            let codes = codebook.encode(sample);
            assert_eq!(codes.len(), 2);
            assert_eq!(codebook.decode(&codes).len(), dim);
        }
    }

    #[test]
    fn adc_distance_non_negative() {
        let dim = 8;
        let data = unit_vecs(32, dim);
        let codebook = PQCodebook::train(&data, 2, 4, 50).unwrap();
        let table = codebook.compute_adc_table(&one_hot(dim, 0));
        let all_non_negative = data
            .iter()
            .all(|v| codebook.adc_distance(&codebook.encode(v), &table) >= 0.0);
        assert!(all_non_negative, "ADC distance must be non-negative");
    }

    #[test]
    fn dim_not_divisible_errors() {
        let result = PQCodebook::train(&unit_vecs(16, 9), 4, 4, 10);
        assert!(result.is_err());
    }

    #[test]
    fn nearest_neighbor_rank_preserved() {
        // Two tight clusters: 20 copies of e0 followed by 20 copies of e7.
        let dim = 8;
        let mut data = vec![one_hot(dim, 0); 20];
        data.extend(vec![one_hot(dim, 7); 20]);
        let codebook = PQCodebook::train(&data, 2, 4, 100).unwrap();

        let table_e0 = codebook.compute_adc_table(&one_hot(dim, 0));
        let table_e7 = codebook.compute_adc_table(&one_hot(dim, 7));
        let code_e0 = codebook.encode(&data[0]);
        let code_e7 = codebook.encode(&data[39]);

        // Each query must rank the member of its own cluster ahead of the other.
        assert!(
            codebook.adc_distance(&code_e0, &table_e0)
                < codebook.adc_distance(&code_e7, &table_e0)
        );
        assert!(
            codebook.adc_distance(&code_e7, &table_e7)
                < codebook.adc_distance(&code_e0, &table_e7)
        );
    }
}