1use ailake_core::AilakeError;
5use serde::{Deserialize, Serialize};
6
/// A product-quantization (PQ) codebook.
///
/// A vector of length `num_subvectors * sub_dim` is split into `num_subvectors`
/// contiguous slices of `sub_dim` values. Each slice is represented by the index
/// of a centroid in the matching subspace, stored as one `u8` code per slice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Number of contiguous slices (subspaces) each vector is split into;
    /// also the number of codes per encoded vector.
    pub num_subvectors: usize,
    /// Number of centroids requested at training time. Training asks the
    /// clustering routine for at most `vectors.len()` centroids, so a subspace
    /// may hold fewer than this; `centroids[m].len()` is the actual count.
    pub num_centroids: usize,
    /// Length of each slice, i.e. the input dimension divided by `num_subvectors`.
    pub sub_dim: usize,
    /// `centroids[m][k]` is centroid `k` of subspace `m`.
    pub centroids: Vec<Vec<Vec<f32>>>,
}
18
19impl PQCodebook {
20 pub fn train(
22 vectors: &[Vec<f32>],
23 num_subvectors: usize,
24 num_centroids: usize,
25 max_iter: usize,
26 ) -> Result<Self, AilakeError> {
27 Self::train_with_kmeans(vectors, num_subvectors, num_centroids, max_iter, kmeans)
28 }
29
30 pub fn train_with_kmeans<F>(
35 vectors: &[Vec<f32>],
36 num_subvectors: usize,
37 num_centroids: usize,
38 max_iter: usize,
39 kmeans_fn: F,
40 ) -> Result<Self, AilakeError>
41 where
42 F: Fn(&[Vec<f32>], usize, usize) -> Vec<Vec<f32>>,
43 {
44 if vectors.is_empty() {
45 return Err(AilakeError::Catalog(
46 "PQ training requires at least 1 vector".into(),
47 ));
48 }
49 let dim = vectors[0].len();
50 if !dim.is_multiple_of(num_subvectors) {
51 return Err(AilakeError::Catalog(format!(
52 "dim {dim} not divisible by num_subvectors {num_subvectors}"
53 )));
54 }
55 let sub_dim = dim / num_subvectors;
56 let n_train = num_centroids.min(vectors.len());
57
58 let mut centroids = Vec::with_capacity(num_subvectors);
59 for m in 0..num_subvectors {
60 let start = m * sub_dim;
61 let end = start + sub_dim;
62 let sub_vecs: Vec<Vec<f32>> = vectors.iter().map(|v| v[start..end].to_vec()).collect();
63 let sub_centroids = kmeans_fn(&sub_vecs, n_train, max_iter);
64 centroids.push(sub_centroids);
65 }
66
67 Ok(Self {
68 num_subvectors,
69 num_centroids,
70 sub_dim,
71 centroids,
72 })
73 }
74
75 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
77 let mut codes = Vec::with_capacity(self.num_subvectors);
78 for m in 0..self.num_subvectors {
79 let start = m * self.sub_dim;
80 let sub = &vector[start..start + self.sub_dim];
81 let best = self.centroids[m]
82 .iter()
83 .enumerate()
84 .map(|(k, c)| (k, l2_sq(sub, c)))
85 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
86 .map(|(k, _)| k)
87 .unwrap_or(0);
88 codes.push(best as u8);
89 }
90 codes
91 }
92
93 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
95 let mut out = Vec::with_capacity(self.num_subvectors * self.sub_dim);
96 for (m, &code) in codes.iter().enumerate() {
97 out.extend_from_slice(&self.centroids[m][code as usize]);
98 }
99 out
100 }
101
102 pub fn compute_adc_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
107 (0..self.num_subvectors)
108 .map(|m| {
109 let start = m * self.sub_dim;
110 let q_sub = &query[start..start + self.sub_dim];
111 self.centroids[m].iter().map(|c| l2_sq(q_sub, c)).collect()
112 })
113 .collect()
114 }
115
116 pub fn adc_distance(&self, codes: &[u8], table: &[Vec<f32>]) -> f32 {
118 codes
119 .iter()
120 .enumerate()
121 .map(|(m, &c)| table[m][c as usize])
122 .sum()
123 }
124}
125
126fn kmeans(points: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
128 let dim = points[0].len();
129 let mut centroids = kmeans_pp_init(points, k);
130
131 for _ in 0..max_iter {
132 let assignments: Vec<usize> = points
134 .iter()
135 .map(|p| nearest_centroid(p, ¢roids))
136 .collect();
137
138 let mut new_centroids = vec![vec![0.0f32; dim]; k];
140 let mut counts = vec![0usize; k];
141 for (point, &assigned) in points.iter().zip(assignments.iter()) {
142 for (d, &v) in point.iter().enumerate() {
143 new_centroids[assigned][d] += v;
144 }
145 counts[assigned] += 1;
146 }
147 let mut converged = true;
148 for (i, count) in counts.iter().enumerate() {
149 if *count > 0 {
150 let scale = *count as f32;
151 for x in new_centroids[i].iter_mut() {
152 *x /= scale;
153 }
154 } else {
155 new_centroids[i] = centroids[i].clone();
157 }
158 if l2_sq(&new_centroids[i], ¢roids[i]) > 1e-8 {
159 converged = false;
160 }
161 }
162 centroids = new_centroids;
163 if converged {
164 break;
165 }
166 }
167 centroids
168}
169
170fn kmeans_pp_init(points: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
172 let mut centroids = vec![points[0].clone()];
173 let mut rng_state = 0x123456789u64;
174
175 while centroids.len() < k {
176 let dists: Vec<f32> = points
177 .iter()
178 .map(|p| {
179 centroids
180 .iter()
181 .map(|c| l2_sq(p, c))
182 .fold(f32::INFINITY, f32::min)
183 })
184 .collect();
185 let total: f32 = dists.iter().sum();
186 rng_state = rng_state
188 .wrapping_mul(6364136223846793005)
189 .wrapping_add(1442695040888963407);
190 let r = (rng_state >> 33) as f32 / (u32::MAX as f32);
191 let target = r * total;
192 let mut cumsum = 0.0f32;
193 let mut chosen = points.len() - 1;
194 for (i, &d) in dists.iter().enumerate() {
195 cumsum += d;
196 if cumsum >= target {
197 chosen = i;
198 break;
199 }
200 }
201 centroids.push(points[chosen].clone());
202 }
203 centroids
204}
205
206fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
207 centroids
208 .iter()
209 .enumerate()
210 .map(|(i, c)| (i, l2_sq(point, c)))
211 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
212 .map(|(i, _)| i)
213 .unwrap_or(0)
214}
215
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared: extra trailing elements of the
/// longer slice are ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&x, &y)| {
        let diff = x - y;
        diff * diff
    });
    squared_diffs.sum()
}
219
220pub fn kmeans_centroids(vectors: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
223 let k_eff = k.min(vectors.len());
224 kmeans(vectors, k_eff, max_iter)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector of length `dim` with 1.0 at index `hot` and zeros elsewhere.
    fn one_hot(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot] = 1.0;
        v
    }

    /// `count` one-hot vectors of length `dim`; vector `i` is hot at `i % dim`.
    fn cyclic_one_hots(count: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(count);
        for i in 0..count {
            out.push(one_hot(dim, i % dim));
        }
        out
    }

    #[test]
    fn encode_decode_roundtrip_approx() {
        const DIM: usize = 8;
        let data = cyclic_one_hots(64, DIM);
        let codebook = PQCodebook::train(&data, 2, 4, 50).unwrap();
        for item in data.iter() {
            let encoded = codebook.encode(item);
            assert_eq!(encoded.len(), 2);
            assert_eq!(codebook.decode(&encoded).len(), DIM);
        }
    }

    #[test]
    fn adc_distance_non_negative() {
        const DIM: usize = 8;
        let data = cyclic_one_hots(32, DIM);
        let codebook = PQCodebook::train(&data, 2, 4, 50).unwrap();
        let lookup = codebook.compute_adc_table(&one_hot(DIM, 0));
        for item in data.iter() {
            let score = codebook.adc_distance(&codebook.encode(item), &lookup);
            assert!(score >= 0.0, "ADC distance must be non-negative");
        }
    }

    #[test]
    fn dim_not_divisible_errors() {
        let data = cyclic_one_hots(16, 9);
        let outcome = PQCodebook::train(&data, 4, 4, 10);
        assert!(outcome.is_err());
    }

    #[test]
    fn nearest_neighbor_rank_preserved() {
        const DIM: usize = 8;
        // Two tight clusters: 20 copies of e0 followed by 20 copies of e7.
        let data: Vec<Vec<f32>> = std::iter::repeat(one_hot(DIM, 0))
            .take(20)
            .chain(std::iter::repeat(one_hot(DIM, 7)).take(20))
            .collect();
        let codebook = PQCodebook::train(&data, 2, 4, 100).unwrap();
        let near_first = codebook.compute_adc_table(&one_hot(DIM, 0));
        let near_last = codebook.compute_adc_table(&one_hot(DIM, 7));
        let first_code = codebook.encode(&data[0]);
        let last_code = codebook.encode(&data[39]);
        assert!(
            codebook.adc_distance(&first_code, &near_first)
                < codebook.adc_distance(&last_code, &near_first)
        );
        assert!(
            codebook.adc_distance(&last_code, &near_last)
                < codebook.adc_distance(&first_code, &near_last)
        );
    }
}