//! Product quantization (PQ): compresses vectors by splitting them into
//! subspaces and quantizing each subspace to one of a small set of learned
//! centroids, so an `f32` vector is stored as one `u8` code per subspace.

use crate::DiskAnnError;
use rand::prelude::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};

/// Configuration for training a [`ProductQuantizer`].
#[derive(Clone, Copy, Debug)]
pub struct PQConfig {
    /// Number of subspaces (M); the vector dimension must be divisible by this.
    pub num_subspaces: usize,
    /// Number of centroids per subspace (K); at most 256, since codes are stored as `u8`.
    pub num_centroids: usize,
    /// Number of Lloyd iterations to run per subspace during training.
    pub kmeans_iterations: usize,
    /// Maximum number of vectors sampled for training; `0` means use all vectors.
    pub training_sample_size: usize,
}

impl Default for PQConfig {
    fn default() -> Self {
        Self {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 20,
            training_sample_size: 50_000,
        }
    }
}

/// A trained product quantizer holding one codebook per subspace.
#[derive(Serialize, Deserialize, Clone)]
pub struct ProductQuantizer {
    /// Dimension of the original vectors.
    dim: usize,
    /// Number of subspaces (M).
    num_subspaces: usize,
    /// Number of centroids per subspace (K).
    num_centroids: usize,
    /// Dimension of each subspace (`dim / num_subspaces`).
    subspace_dim: usize,
    /// Flattened codebooks, laid out as `[subspace][centroid][component]`,
    /// i.e. `num_subspaces * num_centroids * subspace_dim` floats.
    codebooks: Vec<f32>,
}

impl ProductQuantizer {
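    /// Trains a quantizer by running k-means independently on each subspace
    /// (subspaces are processed in parallel via rayon). Fails if `vectors` is
    /// empty or if the vector dimension is not divisible by
    /// `config.num_subspaces`.
    ///
    /// A minimal usage sketch (illustrative only; `data` is assumed to be a
    /// `Vec<Vec<f32>>` of vectors whose dimension is divisible by 8):
    ///
    /// ```ignore
    /// let config = PQConfig::default(); // 8 subspaces x 256 centroids
    /// let pq = ProductQuantizer::train(&data, config)?;
    /// // Each vector compresses to one u8 per subspace (8 bytes here).
    /// let codes: Vec<Vec<u8>> = pq.encode_batch(&data);
    /// assert_eq!(codes[0].len(), 8);
    /// ```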
    pub fn train(vectors: &[Vec<f32>], config: PQConfig) -> Result<Self, DiskAnnError> {
        if vectors.is_empty() {
            return Err(DiskAnnError::IndexError("No vectors to train on".into()));
        }

        let dim = vectors[0].len();
        if dim % config.num_subspaces != 0 {
            return Err(DiskAnnError::IndexError(format!(
                "Dimension {} not divisible by num_subspaces {}",
                dim, config.num_subspaces
            )));
        }

        let subspace_dim = dim / config.num_subspaces;

        // Optionally subsample the training set to bound k-means cost.
        let training_vectors: Vec<&Vec<f32>> = if config.training_sample_size > 0
            && vectors.len() > config.training_sample_size
        {
            let mut rng = thread_rng();
            vectors
                .choose_multiple(&mut rng, config.training_sample_size)
                .collect()
        } else {
            vectors.iter().collect()
        };

        // Train one codebook per subspace, in parallel.
        let codebooks_per_subspace: Vec<Vec<f32>> = (0..config.num_subspaces)
            .into_par_iter()
            .map(|m| {
                let start = m * subspace_dim;
                let end = start + subspace_dim;

                let subspace_vectors: Vec<Vec<f32>> = training_vectors
                    .iter()
                    .map(|v| v[start..end].to_vec())
                    .collect();

                kmeans(
                    &subspace_vectors,
                    config.num_centroids,
                    config.kmeans_iterations,
                )
            })
            .collect();

        // Flatten the per-subspace codebooks into one contiguous buffer.
        let mut codebooks =
            Vec::with_capacity(config.num_subspaces * config.num_centroids * subspace_dim);
        for cb in &codebooks_per_subspace {
            codebooks.extend_from_slice(cb);
        }

        Ok(Self {
            dim,
            num_subspaces: config.num_subspaces,
            num_centroids: config.num_centroids,
            subspace_dim,
            codebooks,
        })
    }

    /// Encodes a vector as one centroid index (`u8`) per subspace.
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");

        let mut codes = Vec::with_capacity(self.num_subspaces);

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let subvec = &vector[start..end];

            // Find the nearest centroid in this subspace.
            let mut best_centroid = 0u8;
            let mut best_dist = f32::MAX;

            for k in 0..self.num_centroids {
                let centroid = self.get_centroid(m, k);
                let dist = l2_distance(subvec, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_centroid = k as u8;
                }
            }

            codes.push(best_centroid);
        }

        codes
    }

    /// Encodes a batch of vectors in parallel.
    pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
        vectors.par_iter().map(|v| self.encode(v)).collect()
    }

    /// Asymmetric distance: squared L2 distance between an uncompressed query
    /// and the reconstruction implied by `codes`, summed over subspaces.
    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        assert_eq!(query.len(), self.dim, "Query dimension mismatch");
        assert_eq!(codes.len(), self.num_subspaces, "Code length mismatch");

        let mut total_dist = 0.0f32;

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];

            let centroid_id = codes[m] as usize;
            let centroid = self.get_centroid(m, centroid_id);

            total_dist += l2_distance(query_sub, centroid);
        }

        total_dist
    }

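    /// Precomputes the squared L2 distance from each query subvector to every
    /// centroid, stored subspace-major (`num_subspaces * num_centroids`
    /// entries). Building the table once costs `M * K` subvector distances;
    /// afterwards each candidate is scored with just `M` table lookups via
    /// `distance_with_table`.
    ///
    /// A query-loop sketch (illustrative; assumes `pq`, `query`, and encoded
    /// candidates `all_codes: &[Vec<u8>]` already exist):
    ///
    /// ```ignore
    /// let table = pq.create_distance_table(&query);
    /// let mut best = (usize::MAX, f32::MAX);
    /// for (i, codes) in all_codes.iter().enumerate() {
    ///     let d = pq.distance_with_table(&table, codes);
    ///     if d < best.1 {
    ///         best = (i, d);
    ///     }
    /// }
    /// ```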
    pub fn create_distance_table(&self, query: &[f32]) -> Vec<f32> {
        assert_eq!(query.len(), self.dim);

        let mut table = Vec::with_capacity(self.num_subspaces * self.num_centroids);

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];

            for k in 0..self.num_centroids {
                let centroid = self.get_centroid(m, k);
                table.push(l2_distance(query_sub, centroid));
            }
        }

        table
    }

    /// Scores a code using a table from [`Self::create_distance_table`]:
    /// one lookup per subspace, summed.
    #[inline]
    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
        let mut dist = 0.0f32;
        for (m, &code) in codes.iter().enumerate() {
            let idx = m * self.num_centroids + code as usize;
            dist += table[idx];
        }
        dist
    }

    /// Reconstructs an approximate vector by concatenating the selected
    /// centroid of each subspace.
    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
        assert_eq!(codes.len(), self.num_subspaces);

        let mut vector = Vec::with_capacity(self.dim);

        for (m, &code) in codes.iter().enumerate() {
            let centroid = self.get_centroid(m, code as usize);
            vector.extend_from_slice(centroid);
        }

        vector
    }

    /// Returns centroid `k` of subspace `m` as a slice into the flat codebook.
    #[inline]
    fn get_centroid(&self, m: usize, k: usize) -> &[f32] {
        let offset = (m * self.num_centroids + k) * self.subspace_dim;
        &self.codebooks[offset..offset + self.subspace_dim]
    }

    /// Serializes the quantizer to `path` with bincode.
    pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
        let file = File::create(path)?;
        let writer = BufWriter::new(file);
        bincode::serialize_into(writer, self)?;
        Ok(())
    }

    /// Loads a quantizer previously written by [`Self::save`].
    pub fn load(path: &str) -> Result<Self, DiskAnnError> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let pq: Self = bincode::deserialize_from(reader)?;
        Ok(pq)
    }

    /// Summary statistics (sizes and compression ratio) for this quantizer.
    pub fn stats(&self) -> PQStats {
        PQStats {
            dim: self.dim,
            num_subspaces: self.num_subspaces,
            num_centroids: self.num_centroids,
            subspace_dim: self.subspace_dim,
            codebook_size_bytes: self.codebooks.len() * 4,
            code_size_bytes: self.num_subspaces,
            compression_ratio: (self.dim * 4) as f32 / self.num_subspaces as f32,
        }
    }
}

/// Size and compression statistics reported by [`ProductQuantizer::stats`].
#[derive(Debug, Clone)]
pub struct PQStats {
    pub dim: usize,
    pub num_subspaces: usize,
    pub num_centroids: usize,
    pub subspace_dim: usize,
    pub codebook_size_bytes: usize,
    pub code_size_bytes: usize,
    pub compression_ratio: f32,
}

impl std::fmt::Display for PQStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Product Quantizer Stats:")?;
        writeln!(f, " Original dimension: {}", self.dim)?;
        writeln!(f, " Subspaces (M): {}", self.num_subspaces)?;
        writeln!(f, " Centroids per subspace (K): {}", self.num_centroids)?;
        writeln!(f, " Subspace dimension: {}", self.subspace_dim)?;
        writeln!(f, " Codebook size: {} bytes", self.codebook_size_bytes)?;
        writeln!(f, " Compressed code size: {} bytes", self.code_size_bytes)?;
        writeln!(f, " Compression ratio: {:.1}x", self.compression_ratio)
    }
}

/// Runs k-means (k-means++ seeding followed by Lloyd iterations) on a set of
/// vectors and returns the centroids as a flat `k * dim` buffer.
fn kmeans(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<f32> {
    if vectors.is_empty() {
        // Degenerate case: no data to cluster (callers guard against this).
        return vec![0.0; k];
    }

    let dim = vectors[0].len();
    let n = vectors.len();
    // Can't have more distinct centroids than points.
    let effective_k = k.min(n);

    let mut centroids = Vec::with_capacity(k * dim);
    let mut rng = thread_rng();

    // k-means++ seeding: pick the first centroid uniformly at random.
    let first = rng.gen_range(0..n);
    centroids.extend_from_slice(&vectors[first]);

    // Pick each remaining seed with probability proportional to its squared
    // distance from the nearest centroid chosen so far.
    for _ in 1..effective_k {
        let num_current = centroids.len() / dim;
        let distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                let mut min_dist = f32::MAX;
                for c in 0..num_current {
                    let centroid = &centroids[c * dim..(c + 1) * dim];
                    let d = l2_distance(v, centroid);
                    min_dist = min_dist.min(d);
                }
                min_dist
            })
            .collect();

        let total: f32 = distances.iter().sum();
        if total == 0.0 {
            // All points coincide with existing centroids; pick any point.
            let idx = rng.gen_range(0..n);
            centroids.extend_from_slice(&vectors[idx]);
        } else {
            let threshold = rng.r#gen::<f32>() * total;
            let mut cumsum = 0.0;
            let mut picked = false;
            for (i, &d) in distances.iter().enumerate() {
                cumsum += d;
                if cumsum >= threshold {
                    centroids.extend_from_slice(&vectors[i]);
                    picked = true;
                    break;
                }
            }
            if !picked {
                // Floating-point round-off: fall back to the last point.
                centroids.extend_from_slice(&vectors[n - 1]);
            }
        }
    }

    // If there were fewer points than requested centroids, pad by repeating seeds.
    while centroids.len() < k * dim {
        let idx = (centroids.len() / dim) % effective_k;
        let centroid = centroids[idx * dim..(idx + 1) * dim].to_vec();
        centroids.extend_from_slice(&centroid);
    }
    centroids.truncate(k * dim);

    // Lloyd iterations: assign each point to its nearest centroid, then move
    // each centroid to the mean of its assigned points.
    for _ in 0..iterations {
        let assignments: Vec<usize> = vectors
            .par_iter()
            .map(|v| {
                let mut best_c = 0;
                let mut best_dist = f32::MAX;
                for c in 0..k {
                    let centroid = &centroids[c * dim..(c + 1) * dim];
                    let d = l2_distance(v, centroid);
                    if d < best_dist {
                        best_dist = d;
                        best_c = c;
                    }
                }
                best_c
            })
            .collect();

        let mut new_centroids = vec![0.0f32; k * dim];
        let mut counts = vec![0usize; k];

        for (i, &c) in assignments.iter().enumerate() {
            counts[c] += 1;
            for (j, &val) in vectors[i].iter().enumerate() {
                new_centroids[c * dim + j] += val;
            }
        }

        for c in 0..k {
            if counts[c] > 0 {
                for j in 0..dim {
                    new_centroids[c * dim + j] /= counts[c] as f32;
                }
            } else {
                // Re-seed empty clusters with a random point.
                let idx = rng.gen_range(0..n);
                for j in 0..dim {
                    new_centroids[c * dim + j] = vectors[idx][j];
                }
            }
        }

        centroids = new_centroids;
    }

    centroids
}

/// Squared Euclidean (L2) distance between two equal-length slices.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>()).collect())
            .collect()
    }

    #[test]
    fn test_pq_encode_decode() {
        let vectors = random_vectors(1000, 64, 42);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let original = &vectors[0];
        let codes = pq.encode(original);
        let decoded = pq.decode(&codes);

        assert_eq!(decoded.len(), original.len());

        let dist = l2_distance(original, &decoded);
        assert!(
            dist < original.len() as f32 * 0.1,
            "Reconstruction error too high: {dist}"
        );
    }

    #[test]
    fn test_pq_asymmetric_distance() {
        let vectors = random_vectors(500, 32, 123);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let target = &vectors[100];

        let codes = pq.encode(target);

        let asym_dist = pq.asymmetric_distance(query, &codes);
        let decoded = pq.decode(&codes);
        let exact_dist = l2_distance(query, &decoded);

        assert!(
            (asym_dist - exact_dist).abs() < 1e-5,
            "asym={asym_dist}, exact={exact_dist}"
        );
    }

    #[test]
    fn test_pq_distance_table() {
        let vectors = random_vectors(500, 32, 456);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let table = pq.create_distance_table(query);

        for target in vectors.iter().take(10) {
            let codes = pq.encode(target);
            let direct = pq.asymmetric_distance(query, &codes);
            let table_dist = pq.distance_with_table(&table, &codes);

            assert!(
                (direct - table_dist).abs() < 1e-5,
                "direct={direct}, table={table_dist}"
            );
        }
    }

    #[test]
    fn test_pq_batch_encode() {
        let vectors = random_vectors(100, 64, 789);
        let config = PQConfig::default();

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes = pq.encode_batch(&vectors);

        assert_eq!(codes.len(), vectors.len());
        for code in &codes {
            assert_eq!(code.len(), config.num_subspaces);
        }
    }

    #[test]
    fn test_pq_save_load() {
        let vectors = random_vectors(200, 64, 111);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 128,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes_before = pq.encode(&vectors[0]);

        let path = "test_pq.bin";
        pq.save(path).unwrap();

        let pq_loaded = ProductQuantizer::load(path).unwrap();
        let codes_after = pq_loaded.encode(&vectors[0]);

        assert_eq!(codes_before, codes_after);

        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_pq_stats() {
        let vectors = random_vectors(100, 128, 222);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let stats = pq.stats();

        assert_eq!(stats.dim, 128);
        assert_eq!(stats.num_subspaces, 8);
        assert_eq!(stats.num_centroids, 256);
        assert_eq!(stats.subspace_dim, 16);
        assert_eq!(stats.code_size_bytes, 8);
        assert!(stats.compression_ratio > 50.0);

        println!("{}", stats);
    }

    #[test]
    fn test_pq_preserves_ordering() {
        let vectors = random_vectors(500, 64, 333);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 15,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];

        // Exact distances from the query to every other vector.
        let mut true_dists: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| (i, l2_distance(query, v)))
            .collect();
        true_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Approximate distances via the PQ lookup table.
        let table = pq.create_distance_table(query);
        let codes: Vec<Vec<u8>> = vectors.iter().map(|v| pq.encode(v)).collect();

        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, c)| (i, pq.distance_with_table(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // PQ should mostly preserve the true nearest-neighbor ordering.
        let true_top10: std::collections::HashSet<_> =
            true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let pq_top10: std::collections::HashSet<_> =
            pq_dists.iter().take(10).map(|(i, _)| *i).collect();

        let recall: f32 = true_top10.intersection(&pq_top10).count() as f32 / 10.0;
        assert!(
            recall >= 0.5,
            "PQ recall@10 too low: {recall}. Expected >= 0.5"
        );
    }
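
    // Sanity-check sketch for the private `kmeans` helper: with two
    // well-separated 1-D clusters, k-means++ seeding plus Lloyd iterations
    // should place one centroid near each cluster. The toy data and
    // tolerances here are illustrative assumptions.
    #[test]
    fn test_kmeans_toy_clusters() {
        let mut data: Vec<Vec<f32>> = Vec::new();
        for i in 0..50 {
            data.push(vec![(i as f32) * 0.001]); // cluster near 0.0
            data.push(vec![10.0 + (i as f32) * 0.001]); // cluster near 10.0
        }

        let centroids = kmeans(&data, 2, 10);
        assert_eq!(centroids.len(), 2); // k * dim = 2 * 1

        let (lo, hi) = if centroids[0] < centroids[1] {
            (centroids[0], centroids[1])
        } else {
            (centroids[1], centroids[0])
        };
        assert!(lo < 1.0, "low centroid drifted: {lo}");
        assert!(hi > 9.0, "high centroid drifted: {hi}");
    }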
}