1use crate::DiskAnnError;
39use rand::prelude::*;
40use rayon::prelude::*;
41use serde::{Deserialize, Serialize};
42use std::fs::File;
43use std::io::{BufReader, BufWriter};
44
/// Hyper-parameters for [`ProductQuantizer::train`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PQConfig {
    /// Number of subspaces (M) each vector is split into. Must be non-zero and evenly
    /// divide the vector dimension; it is also the number of `u8` codes per encoded vector.
    pub num_subspaces: usize,
    /// Number of centroids (K) in each subspace codebook; at most 256 because codes are `u8`.
    pub num_centroids: usize,
    /// Number of Lloyd iterations k-means runs for each subspace.
    pub kmeans_iterations: usize,
    /// Upper bound on the number of vectors randomly sampled for training; `0` means
    /// "train on every vector".
    pub training_sample_size: usize,
}
58
59impl Default for PQConfig {
60 fn default() -> Self {
61 Self {
62 num_subspaces: 8,
63 num_centroids: 256,
64 kmeans_iterations: 20,
65 training_sample_size: 50_000,
66 }
67 }
68}
69
/// A trained product quantizer: one k-means codebook per subspace, used to compress a
/// vector into `num_subspaces` one-byte codes and to estimate squared-L2 distances.
///
/// Persisted with bincode in `save`/`load`. Bincode encodes struct fields positionally,
/// so reordering or inserting fields here would make previously saved files unreadable.
#[derive(Serialize, Deserialize, Clone)]
pub struct ProductQuantizer {
    /// Full (uncompressed) vector dimension.
    dim: usize,
    /// Number of subspaces (M); equal to the code length in bytes.
    num_subspaces: usize,
    /// Centroids per subspace (K); `train` rejects values above 256.
    num_centroids: usize,
    /// Components per subspace, `dim / num_subspaces`.
    subspace_dim: usize,
    /// Flat codebook storage of `num_subspaces * num_centroids * subspace_dim` floats:
    /// centroid `k` of subspace `m` starts at `(m * num_centroids + k) * subspace_dim`.
    codebooks: Vec<f32>,
}
85
86impl ProductQuantizer {
87 pub fn train(vectors: &[Vec<f32>], config: PQConfig) -> Result<Self, DiskAnnError> {
89 if vectors.is_empty() {
90 return Err(DiskAnnError::IndexError("No vectors to train on".into()));
91 }
92
93 let dim = vectors[0].len();
94 if config.num_centroids > 256 {
95 return Err(DiskAnnError::IndexError(format!(
96 "num_centroids {} exceeds 256 (codes are u8)",
97 config.num_centroids
98 )));
99 }
100 if dim % config.num_subspaces != 0 {
101 return Err(DiskAnnError::IndexError(format!(
102 "Dimension {} not divisible by num_subspaces {}",
103 dim, config.num_subspaces
104 )));
105 }
106
107 let subspace_dim = dim / config.num_subspaces;
108
109 let training_vectors: Vec<&Vec<f32>> = if config.training_sample_size > 0
111 && vectors.len() > config.training_sample_size
112 {
113 let mut rng = thread_rng();
114 vectors
115 .choose_multiple(&mut rng, config.training_sample_size)
116 .collect()
117 } else {
118 vectors.iter().collect()
119 };
120
121 let codebooks_per_subspace: Vec<Vec<f32>> = (0..config.num_subspaces)
123 .into_par_iter()
124 .map(|m| {
125 let start = m * subspace_dim;
127 let end = start + subspace_dim;
128
129 let subspace_vectors: Vec<Vec<f32>> = training_vectors
130 .iter()
131 .map(|v| v[start..end].to_vec())
132 .collect();
133
134 kmeans(
136 &subspace_vectors,
137 config.num_centroids,
138 config.kmeans_iterations,
139 )
140 })
141 .collect();
142
143 let mut codebooks =
145 Vec::with_capacity(config.num_subspaces * config.num_centroids * subspace_dim);
146 for cb in &codebooks_per_subspace {
147 codebooks.extend_from_slice(cb);
148 }
149
150 Ok(Self {
151 dim,
152 num_subspaces: config.num_subspaces,
153 num_centroids: config.num_centroids,
154 subspace_dim,
155 codebooks,
156 })
157 }
158
159 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
161 assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");
162
163 let mut codes = Vec::with_capacity(self.num_subspaces);
164
165 for m in 0..self.num_subspaces {
166 let start = m * self.subspace_dim;
167 let end = start + self.subspace_dim;
168 let subvec = &vector[start..end];
169
170 let mut best_centroid = 0u8;
172 let mut best_dist = f32::MAX;
173
174 for k in 0..self.num_centroids {
175 let centroid = self.get_centroid(m, k);
176 let dist = l2_distance(subvec, centroid);
177 if dist < best_dist {
178 best_dist = dist;
179 best_centroid = k as u8;
180 }
181 }
182
183 codes.push(best_centroid);
184 }
185
186 codes
187 }
188
189 pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
191 vectors.par_iter().map(|v| self.encode(v)).collect()
192 }
193
194 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
197 assert_eq!(query.len(), self.dim, "Query dimension mismatch");
198 assert_eq!(codes.len(), self.num_subspaces, "Code length mismatch");
199
200 let mut total_dist = 0.0f32;
201
202 for m in 0..self.num_subspaces {
203 let start = m * self.subspace_dim;
204 let end = start + self.subspace_dim;
205 let query_sub = &query[start..end];
206
207 let centroid_id = codes[m] as usize;
208 let centroid = self.get_centroid(m, centroid_id);
209
210 total_dist += l2_distance(query_sub, centroid);
211 }
212
213 total_dist
214 }
215
216 pub fn create_distance_table(&self, query: &[f32]) -> Vec<f32> {
220 assert_eq!(query.len(), self.dim);
221
222 let mut table = Vec::with_capacity(self.num_subspaces * self.num_centroids);
223
224 for m in 0..self.num_subspaces {
225 let start = m * self.subspace_dim;
226 let end = start + self.subspace_dim;
227 let query_sub = &query[start..end];
228
229 for k in 0..self.num_centroids {
230 let centroid = self.get_centroid(m, k);
231 table.push(l2_distance(query_sub, centroid));
232 }
233 }
234
235 table
236 }
237
238 #[inline]
240 pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
241 let mut dist = 0.0f32;
242 for (m, &code) in codes.iter().enumerate() {
243 let idx = m * self.num_centroids + code as usize;
244 dist += table[idx];
245 }
246 dist
247 }
248
249 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
251 assert_eq!(codes.len(), self.num_subspaces);
252
253 let mut vector = Vec::with_capacity(self.dim);
254
255 for (m, &code) in codes.iter().enumerate() {
256 let centroid = self.get_centroid(m, code as usize);
257 vector.extend_from_slice(centroid);
258 }
259
260 vector
261 }
262
263 #[inline]
265 fn get_centroid(&self, m: usize, k: usize) -> &[f32] {
266 let offset = (m * self.num_centroids + k) * self.subspace_dim;
267 &self.codebooks[offset..offset + self.subspace_dim]
268 }
269
270 pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
272 let file = File::create(path)?;
273 let writer = BufWriter::new(file);
274 bincode::serialize_into(writer, self)?;
275 Ok(())
276 }
277
278 pub fn load(path: &str) -> Result<Self, DiskAnnError> {
280 let file = File::open(path)?;
281 let reader = BufReader::new(file);
282 let pq: Self = bincode::deserialize_from(reader)?;
283 Ok(pq)
284 }
285
286 pub fn stats(&self) -> PQStats {
288 PQStats {
289 dim: self.dim,
290 num_subspaces: self.num_subspaces,
291 num_centroids: self.num_centroids,
292 subspace_dim: self.subspace_dim,
293 codebook_size_bytes: self.codebooks.len() * 4,
294 code_size_bytes: self.num_subspaces,
295 compression_ratio: (self.dim * 4) as f32 / self.num_subspaces as f32,
296 }
297 }
298}
299
/// Shape and memory-footprint summary of a [`ProductQuantizer`], as returned by
/// `ProductQuantizer::stats`.
#[derive(Debug, Clone)]
pub struct PQStats {
    /// Original (uncompressed) vector dimension.
    pub dim: usize,
    /// Number of subspaces (M).
    pub num_subspaces: usize,
    /// Centroids per subspace (K).
    pub num_centroids: usize,
    /// Components per subspace.
    pub subspace_dim: usize,
    /// Codebook storage size in bytes (4 bytes per stored `f32`).
    pub codebook_size_bytes: usize,
    /// Bytes per encoded vector (one `u8` code per subspace).
    pub code_size_bytes: usize,
    /// Size of an uncompressed `f32` vector in bytes divided by the number of subspaces.
    pub compression_ratio: f32,
}
311
312impl std::fmt::Display for PQStats {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 writeln!(f, "Product Quantizer Stats:")?;
315 writeln!(f, " Original dimension: {}", self.dim)?;
316 writeln!(f, " Subspaces (M): {}", self.num_subspaces)?;
317 writeln!(f, " Centroids per subspace (K): {}", self.num_centroids)?;
318 writeln!(f, " Subspace dimension: {}", self.subspace_dim)?;
319 writeln!(f, " Codebook size: {} bytes", self.codebook_size_bytes)?;
320 writeln!(f, " Compressed code size: {} bytes", self.code_size_bytes)?;
321 writeln!(f, " Compression ratio: {:.1}x", self.compression_ratio)
322 }
323}
324
325fn kmeans(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<f32> {
329 if vectors.is_empty() {
330 return vec![0.0; k * 1]; }
332
333 let dim = vectors[0].len();
334 let n = vectors.len();
335 let effective_k = k.min(n); let mut centroids = Vec::with_capacity(k * dim);
339 let mut rng = thread_rng();
340
341 let first = rng.gen_range(0..n);
343 centroids.extend_from_slice(&vectors[first]);
344
345 for _ in 1..effective_k {
348 let num_current = centroids.len() / dim;
349 let distances: Vec<f32> = vectors
350 .iter()
351 .map(|v| {
352 let mut min_dist = f32::MAX;
353 for c in 0..num_current {
354 let centroid = ¢roids[c * dim..(c + 1) * dim];
355 let d = l2_distance(v, centroid);
356 min_dist = min_dist.min(d);
357 }
358 min_dist
359 })
360 .collect();
361
362 let total: f32 = distances.iter().sum();
364 if total == 0.0 {
365 let idx = rng.gen_range(0..n);
367 centroids.extend_from_slice(&vectors[idx]);
368 } else {
369 let threshold = rng.r#gen::<f32>() * total;
370 let mut cumsum = 0.0;
371 let mut picked = false;
372 for (i, &d) in distances.iter().enumerate() {
373 cumsum += d;
374 if cumsum >= threshold {
375 centroids.extend_from_slice(&vectors[i]);
376 picked = true;
377 break;
378 }
379 }
380 if !picked {
382 centroids.extend_from_slice(&vectors[n - 1]);
383 }
384 }
385 }
386
387 while centroids.len() < k * dim {
389 let idx = (centroids.len() / dim) % effective_k;
391 let centroid = centroids[idx * dim..(idx + 1) * dim].to_vec();
392 centroids.extend_from_slice(¢roid);
393 }
394 centroids.truncate(k * dim);
395
396 let mut assignments: Vec<usize>;
398
399 for _ in 0..iterations {
400 assignments = vectors
402 .par_iter()
403 .map(|v| {
404 let mut best_c = 0;
405 let mut best_dist = f32::MAX;
406 for c in 0..k {
407 let centroid = ¢roids[c * dim..(c + 1) * dim];
408 let d = l2_distance(v, centroid);
409 if d < best_dist {
410 best_dist = d;
411 best_c = c;
412 }
413 }
414 best_c
415 })
416 .collect();
417
418 let mut new_centroids = vec![0.0f32; k * dim];
420 let mut counts = vec![0usize; k];
421
422 for (i, &c) in assignments.iter().enumerate() {
423 counts[c] += 1;
424 for (j, &val) in vectors[i].iter().enumerate() {
425 new_centroids[c * dim + j] += val;
426 }
427 }
428
429 for c in 0..k {
431 if counts[c] > 0 {
432 for j in 0..dim {
433 new_centroids[c * dim + j] /= counts[c] as f32;
434 }
435 } else {
436 let idx = rng.gen_range(0..n);
438 for j in 0..dim {
439 new_centroids[c * dim + j] = vectors[idx][j];
440 }
441 }
442 }
443
444 centroids = new_centroids;
445 }
446
447 centroids
448}
449
/// Squared Euclidean distance between `a` and `b`; no square root is taken.
///
/// Only the overlapping prefix is compared. If the slices differ in length, the trailing
/// elements of the longer one are ignored.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta * delta
    });
    squared_diffs.sum()
}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` reproducible vectors of length `dim` with components drawn from a
    /// seeded `StdRng`.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>()).collect())
            .collect()
    }

    /// Deletes the wrapped file on drop, so temp files are cleaned up even when an
    /// assertion panics mid-test.
    struct RemoveOnDrop(std::path::PathBuf);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_pq_encode_decode() {
        let vectors = random_vectors(1000, 64, 42);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let original = &vectors[0];
        let codes = pq.encode(original);
        let decoded = pq.decode(&codes);

        assert_eq!(decoded.len(), original.len());

        let dist = l2_distance(original, &decoded);
        assert!(
            dist < original.len() as f32 * 0.1,
            "Reconstruction error too high: {dist}"
        );
    }

    #[test]
    fn test_pq_asymmetric_distance() {
        let vectors = random_vectors(500, 32, 123);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let target = &vectors[100];

        let codes = pq.encode(target);

        let asym_dist = pq.asymmetric_distance(query, &codes);
        let decoded = pq.decode(&codes);
        let exact_dist = l2_distance(query, &decoded);

        assert!(
            (asym_dist - exact_dist).abs() < 1e-5,
            "asym={asym_dist}, exact={exact_dist}"
        );
    }

    #[test]
    fn test_pq_distance_table() {
        let vectors = random_vectors(500, 32, 456);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let table = pq.create_distance_table(query);

        for target in vectors.iter().take(10) {
            let codes = pq.encode(target);
            let direct = pq.asymmetric_distance(query, &codes);
            let table_dist = pq.distance_with_table(&table, &codes);

            assert!(
                (direct - table_dist).abs() < 1e-5,
                "direct={direct}, table={table_dist}"
            );
        }
    }

    #[test]
    fn test_pq_batch_encode() {
        let vectors = random_vectors(100, 64, 789);
        let config = PQConfig::default();

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes = pq.encode_batch(&vectors);

        assert_eq!(codes.len(), vectors.len());
        for code in &codes {
            assert_eq!(code.len(), config.num_subspaces);
        }
    }

    #[test]
    fn test_pq_save_load() {
        let vectors = random_vectors(200, 64, 111);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 128,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes_before = pq.encode(&vectors[0]);

        // Use a per-process path under the OS temp dir. Concurrent test runs then cannot
        // clobber each other, and nothing is written into the working directory.
        let path_buf =
            std::env::temp_dir().join(format!("test_pq_{}.bin", std::process::id()));
        let _cleanup = RemoveOnDrop(path_buf.clone());
        let path = path_buf.to_str().expect("temp path should be valid UTF-8");

        pq.save(path).unwrap();

        let pq_loaded = ProductQuantizer::load(path).unwrap();
        let codes_after = pq_loaded.encode(&vectors[0]);

        assert_eq!(codes_before, codes_after);
    }

    #[test]
    fn test_pq_stats() {
        let vectors = random_vectors(100, 128, 222);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let stats = pq.stats();

        assert_eq!(stats.dim, 128);
        assert_eq!(stats.num_subspaces, 8);
        assert_eq!(stats.num_centroids, 256);
        assert_eq!(stats.subspace_dim, 16);
        assert_eq!(stats.code_size_bytes, 8);
        assert!(stats.compression_ratio > 50.0);
        println!("{}", stats);
    }

    #[test]
    fn test_pq_preserves_ordering() {
        let vectors = random_vectors(500, 64, 333);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 15,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];

        let mut true_dists: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| (i, l2_distance(query, v)))
            .collect();
        true_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        let table = pq.create_distance_table(query);
        let codes: Vec<Vec<u8>> = vectors.iter().map(|v| pq.encode(v)).collect();

        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, c)| (i, pq.distance_with_table(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        let true_top10: std::collections::HashSet<_> =
            true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let pq_top10: std::collections::HashSet<_> =
            pq_dists.iter().take(10).map(|(i, _)| *i).collect();

        let recall: f32 = true_top10.intersection(&pq_top10).count() as f32 / 10.0;
        assert!(
            recall >= 0.5,
            "PQ recall@10 too low: {recall}. Expected >= 0.5"
        );
    }
}