1use std::io::{self, Read, Write};
10
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12#[cfg(not(feature = "native"))]
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(all(target_arch = "x86_64", feature = "native"))]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// Magic number written at the start of every serialized codebook stream
/// (reads as "SCBK" when the hex literal is spelled out big-endian).
const CODEBOOK_MAGIC: u32 = 0x5343_424B;

/// Default number of centroids per subspace; 256 is the most a `u8` code can index.
pub const DEFAULT_NUM_CENTROIDS: usize = 256;

/// Default number of consecutive dimensions grouped into one PQ subspace.
pub const DEFAULT_DIMS_PER_BLOCK: usize = 2;
35
/// Configuration for product-quantization (PQ) training and encoding.
///
/// A vector of `dim` components is split into `num_subspaces` consecutive
/// slices of `dims_per_block` components; each slice is quantized against its
/// own codebook of `num_centroids` entries and stored as one `u8` code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
    /// Dimensionality of the input vectors.
    pub dim: usize,
    /// Number of subspaces (one `u8` code per subspace).
    pub num_subspaces: usize,
    /// Number of consecutive dimensions in each subspace.
    pub dims_per_block: usize,
    /// Centroids per subspace codebook (`with_centroids` caps this at 256).
    pub num_centroids: usize,
    /// Seed for the non-native k-means trainer's RNG; `write_to` does not persist it.
    pub seed: u64,
    /// Anisotropic-quantization flag.
    /// NOTE(review): in this file it is only stored and (de)serialized; training
    /// and encoding here do not read it — confirm whether another module does.
    pub anisotropic: bool,
    /// Anisotropic weighting parameter (same NOTE as `anisotropic`).
    pub aniso_eta: f32,
    /// Anisotropic threshold (same NOTE as `anisotropic`).
    pub aniso_threshold: f32,
    /// Whether native training learns an OPQ rotation before k-means.
    pub use_opq: bool,
    /// Number of OPQ alternation rounds (training additionally caps it by `max_iters`).
    pub opq_iters: usize,
}
60
61impl PQConfig {
62 pub fn new(dim: usize) -> Self {
64 let dims_per_block = DEFAULT_DIMS_PER_BLOCK;
65 let num_subspaces = dim / dims_per_block;
66
67 Self {
68 dim,
69 num_subspaces,
70 dims_per_block,
71 num_centroids: DEFAULT_NUM_CENTROIDS,
72 seed: 42,
73 anisotropic: true,
74 aniso_eta: 10.0,
75 aniso_threshold: 0.2,
76 use_opq: true,
77 opq_iters: 10,
78 }
79 }
80
81 pub fn new_fast(dim: usize) -> Self {
83 let num_subspaces = if dim >= 256 {
84 8
85 } else if dim >= 64 {
86 4
87 } else {
88 2
89 };
90 let dims_per_block = dim / num_subspaces;
91
92 Self {
93 dim,
94 num_subspaces,
95 dims_per_block,
96 num_centroids: DEFAULT_NUM_CENTROIDS,
97 seed: 42,
98 anisotropic: true,
99 aniso_eta: 10.0,
100 aniso_threshold: 0.2,
101 use_opq: false,
102 opq_iters: 0,
103 }
104 }
105
106 pub fn new_balanced(dim: usize) -> Self {
109 let num_subspaces = if dim >= 128 {
110 16
111 } else if dim >= 64 {
112 8
113 } else {
114 4
115 };
116 let dims_per_block = dim / num_subspaces;
117
118 Self {
119 dim,
120 num_subspaces,
121 dims_per_block,
122 num_centroids: DEFAULT_NUM_CENTROIDS,
123 seed: 42,
124 anisotropic: true,
125 aniso_eta: 10.0,
126 aniso_threshold: 0.2,
127 use_opq: false,
128 opq_iters: 0,
129 }
130 }
131
132 pub fn with_dims_per_block(mut self, d: usize) -> Self {
133 assert!(
134 self.dim.is_multiple_of(d),
135 "Dimension must be divisible by dims_per_block"
136 );
137 self.dims_per_block = d;
138 self.num_subspaces = self.dim / d;
139 self
140 }
141
142 pub fn with_subspaces(mut self, m: usize) -> Self {
143 assert!(
144 self.dim.is_multiple_of(m),
145 "Dimension must be divisible by num_subspaces"
146 );
147 self.num_subspaces = m;
148 self.dims_per_block = self.dim / m;
149 self
150 }
151
152 pub fn with_centroids(mut self, k: usize) -> Self {
153 assert!(k <= 256, "Max 256 centroids for u8 codes");
154 self.num_centroids = k;
155 self
156 }
157
158 pub fn with_anisotropic(mut self, enabled: bool, eta: f32) -> Self {
159 self.anisotropic = enabled;
160 self.aniso_eta = eta;
161 self
162 }
163
164 pub fn with_opq(mut self, enabled: bool, iters: usize) -> Self {
165 self.use_opq = enabled;
166 self.opq_iters = iters;
167 self
168 }
169
170 pub fn subspace_dim(&self) -> usize {
172 self.dims_per_block
173 }
174}
175
/// A product-quantized vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQVector {
    /// One centroid index per subspace.
    pub codes: Vec<u8>,
    /// A norm stored alongside the codes. When produced by `PQCodebook::encode`
    /// this is the L2 norm of the original input vector (not of the residual).
    pub norm: f32,
}
184
185impl PQVector {
186 pub fn new(codes: Vec<u8>, norm: f32) -> Self {
187 Self { codes, norm }
188 }
189}
190
191impl QuantizedCode for PQVector {
192 fn size_bytes(&self) -> usize {
193 self.codes.len() + 4 }
195}
196
/// A trained product-quantization codebook.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Configuration the codebook was trained with (or read back from disk).
    pub config: PQConfig,
    /// Optional `dim * dim` rotation, flattened; `apply_rotation` reads row `i`
    /// as `rotation[i * dim..(i + 1) * dim]`. Only native training with OPQ sets it.
    pub rotation_matrix: Option<Vec<f32>>,
    /// Centroid table, subspace-major: centroid `c` of subspace `s` starts at
    /// `(s * num_centroids + c) * subspace_dim`.
    pub centroids: Vec<f32>,
    /// Version stamp; training sets it to the Unix time in milliseconds.
    pub version: u64,
    /// L2 norm of every centroid (`num_subspaces * num_centroids` entries), if present.
    pub centroid_norms: Option<Vec<f32>>,
}
213
214impl PQCodebook {
215 #[cfg(feature = "native")]
217 pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
218 use kentro::KMeans;
219 use ndarray::Array2;
220
221 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
222 assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
223
224 let m = config.num_subspaces;
225 let k = config.num_centroids;
226 let sub_dim = config.subspace_dim();
227 let n = vectors.len();
228
229 let rotation_matrix = if config.use_opq && config.opq_iters > 0 {
231 Some(Self::learn_opq_rotation(&config, vectors, max_iters))
232 } else {
233 None
234 };
235
236 let rotated_vectors: Vec<Vec<f32>> = if let Some(ref r) = rotation_matrix {
238 vectors
239 .iter()
240 .map(|v| Self::apply_rotation(r, v, config.dim))
241 .collect()
242 } else {
243 vectors.to_vec()
244 };
245
246 let mut centroids = Vec::with_capacity(m * k * sub_dim);
248
249 for subspace_idx in 0..m {
250 let offset = subspace_idx * sub_dim;
251
252 let subdata: Vec<f32> = rotated_vectors
253 .iter()
254 .flat_map(|v| v[offset..offset + sub_dim].iter().copied())
255 .collect();
256
257 let actual_k = k.min(n);
258
259 let data = Array2::from_shape_vec((n, sub_dim), subdata)
260 .expect("Failed to create subspace array");
261 let mut kmeans = KMeans::new(actual_k)
262 .with_euclidean(true)
263 .with_iterations(max_iters);
264 let _ = kmeans
265 .train(data.view(), None)
266 .expect("K-means training failed");
267
268 let subspace_centroids: Vec<f32> = kmeans
269 .centroids()
270 .expect("No centroids")
271 .iter()
272 .copied()
273 .collect();
274
275 centroids.extend(subspace_centroids);
276
277 while centroids.len() < (subspace_idx + 1) * k * sub_dim {
279 let last_start = centroids.len() - sub_dim;
280 let last: Vec<f32> = centroids[last_start..].to_vec();
281 centroids.extend(last);
282 }
283 }
284
285 let centroid_norms: Vec<f32> = (0..m * k)
287 .map(|i| {
288 let start = i * sub_dim;
289 if start + sub_dim <= centroids.len() {
290 centroids[start..start + sub_dim]
291 .iter()
292 .map(|x| x * x)
293 .sum::<f32>()
294 .sqrt()
295 } else {
296 0.0
297 }
298 })
299 .collect();
300
301 let version = std::time::SystemTime::now()
302 .duration_since(std::time::UNIX_EPOCH)
303 .unwrap_or_default()
304 .as_millis() as u64;
305
306 Self {
307 config,
308 rotation_matrix,
309 centroids,
310 version,
311 centroid_norms: Some(centroid_norms),
312 }
313 }
314
315 #[cfg(not(feature = "native"))]
317 pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
318 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
319 assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
320
321 let m = config.num_subspaces;
322 let k = config.num_centroids;
323 let sub_dim = config.subspace_dim();
324 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
325
326 let rotation_matrix = None;
327 let mut centroids = Vec::with_capacity(m * k * sub_dim);
328
329 for subspace_idx in 0..m {
330 let offset = subspace_idx * sub_dim;
331 let subvectors: Vec<Vec<f32>> = vectors
332 .iter()
333 .map(|v| v[offset..offset + sub_dim].to_vec())
334 .collect();
335
336 let subspace_centroids =
337 Self::train_subspace_scalar(&subvectors, k, sub_dim, max_iters, &mut rng);
338 centroids.extend(subspace_centroids);
339 }
340
341 let centroid_norms: Vec<f32> = (0..m * k)
342 .map(|i| {
343 let start = i * sub_dim;
344 centroids[start..start + sub_dim]
345 .iter()
346 .map(|x| x * x)
347 .sum::<f32>()
348 .sqrt()
349 })
350 .collect();
351
352 let version = std::time::SystemTime::now()
353 .duration_since(std::time::UNIX_EPOCH)
354 .unwrap_or_default()
355 .as_millis() as u64;
356
357 Self {
358 config,
359 rotation_matrix,
360 centroids,
361 version,
362 centroid_norms: Some(centroid_norms),
363 }
364 }
365
366 #[cfg(feature = "native")]
368 fn learn_opq_rotation(config: &PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Vec<f32> {
369 use nalgebra::DMatrix;
370
371 let dim = config.dim;
372 let n = vectors.len();
373
374 let mut rotation = DMatrix::<f32>::identity(dim, dim);
375 let data: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
376 let x = DMatrix::from_row_slice(n, dim, &data);
377
378 for _iter in 0..config.opq_iters.min(max_iters) {
379 let rotated = &x * &rotation;
380 let assignments = Self::compute_pq_assignments(config, &rotated);
381 let reconstructed = Self::reconstruct_from_assignments(config, &rotated, &assignments);
382
383 let xtx_hat = x.transpose() * &reconstructed;
384 let svd = xtx_hat.svd(true, true);
385 if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
386 let new_rotation: DMatrix<f32> = vt.transpose() * u.transpose();
387 rotation = new_rotation;
388 }
389 }
390
391 rotation.iter().copied().collect()
392 }
393
394 #[cfg(feature = "native")]
395 fn compute_pq_assignments(
396 config: &PQConfig,
397 rotated: &nalgebra::DMatrix<f32>,
398 ) -> Vec<Vec<usize>> {
399 use kentro::KMeans;
400 use ndarray::Array2;
401
402 let m = config.num_subspaces;
403 let k = config.num_centroids.min(rotated.nrows());
404 let sub_dim = config.subspace_dim();
405 let n = rotated.nrows();
406
407 let mut all_assignments = vec![vec![0usize; m]; n];
408
409 for subspace_idx in 0..m {
410 let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
411 for row in 0..n {
412 for col in 0..sub_dim {
413 subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
414 }
415 }
416
417 let data = Array2::from_shape_vec((n, sub_dim), subdata)
418 .expect("Failed to create subspace array");
419 let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
420 let clusters = kmeans
421 .train(data.view(), None)
422 .expect("K-means training failed");
423
424 for (cluster_id, point_indices) in clusters.iter().enumerate() {
426 for &point_idx in point_indices {
427 all_assignments[point_idx][subspace_idx] = cluster_id;
428 }
429 }
430 }
431
432 all_assignments
433 }
434
435 #[cfg(feature = "native")]
436 fn reconstruct_from_assignments(
437 config: &PQConfig,
438 rotated: &nalgebra::DMatrix<f32>,
439 assignments: &[Vec<usize>],
440 ) -> nalgebra::DMatrix<f32> {
441 use kentro::KMeans;
442 use ndarray::Array2;
443
444 let m = config.num_subspaces;
445 let sub_dim = config.subspace_dim();
446 let n = rotated.nrows();
447 let dim = config.dim;
448
449 let mut reconstructed = nalgebra::DMatrix::<f32>::zeros(n, dim);
450
451 for subspace_idx in 0..m {
452 let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
453 for row in 0..n {
454 for col in 0..sub_dim {
455 subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
456 }
457 }
458
459 let k = config.num_centroids.min(n);
460 let data = Array2::from_shape_vec((n, sub_dim), subdata)
461 .expect("Failed to create subspace array");
462 let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
463 let _ = kmeans
464 .train(data.view(), None)
465 .expect("K-means training failed");
466
467 let centroids = kmeans.centroids().expect("No centroids");
468
469 for (row, assignment) in assignments.iter().enumerate() {
470 let centroid_idx = assignment[subspace_idx];
471 if centroid_idx < k {
472 for col in 0..sub_dim {
473 reconstructed[(row, subspace_idx * sub_dim + col)] =
474 centroids[[centroid_idx, col]];
475 }
476 }
477 }
478 }
479
480 reconstructed
481 }
482
483 fn apply_rotation(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
485 let mut result = vec![0.0f32; dim];
486 for i in 0..dim {
487 result[i] = crate::structures::simd::dot_product_f32(
488 &rotation[i * dim..(i + 1) * dim],
489 vector,
490 dim,
491 );
492 }
493 result
494 }
495
496 #[cfg(not(feature = "native"))]
498 fn train_subspace_scalar(
499 subvectors: &[Vec<f32>],
500 k: usize,
501 sub_dim: usize,
502 max_iters: usize,
503 rng: &mut impl Rng,
504 ) -> Vec<f32> {
505 let actual_k = k.min(subvectors.len());
506 let mut centroids = Self::kmeans_plusplus_init_scalar(subvectors, actual_k, sub_dim, rng);
507
508 for _ in 0..max_iters {
509 let assignments: Vec<usize> = subvectors
510 .iter()
511 .map(|v| Self::find_nearest_scalar(¢roids, v, sub_dim))
512 .collect();
513
514 let mut new_centroids = vec![0.0f32; actual_k * sub_dim];
515 let mut counts = vec![0usize; actual_k];
516
517 for (subvec, &assignment) in subvectors.iter().zip(assignments.iter()) {
518 counts[assignment] += 1;
519 let offset = assignment * sub_dim;
520 for (j, &val) in subvec.iter().enumerate() {
521 new_centroids[offset + j] += val;
522 }
523 }
524
525 for (c, &count) in counts.iter().enumerate().take(actual_k) {
526 if count > 0 {
527 let offset = c * sub_dim;
528 for j in 0..sub_dim {
529 new_centroids[offset + j] /= count as f32;
530 }
531 }
532 }
533
534 centroids = new_centroids;
535 }
536
537 while centroids.len() < k * sub_dim {
538 let last_start = centroids.len() - sub_dim;
539 let last: Vec<f32> = centroids[last_start..].to_vec();
540 centroids.extend(last);
541 }
542
543 centroids
544 }
545
546 #[cfg(not(feature = "native"))]
547 fn kmeans_plusplus_init_scalar(
548 subvectors: &[Vec<f32>],
549 k: usize,
550 sub_dim: usize,
551 rng: &mut impl Rng,
552 ) -> Vec<f32> {
553 let mut centroids = Vec::with_capacity(k * sub_dim);
554 let first_idx = rng.random_range(0..subvectors.len());
555 centroids.extend_from_slice(&subvectors[first_idx]);
556
557 for _ in 1..k {
558 let distances: Vec<f32> = subvectors
559 .iter()
560 .map(|v| Self::min_dist_to_centroids_scalar(¢roids, v, sub_dim))
561 .collect();
562
563 let total: f32 = distances.iter().sum();
564 let mut r = rng.random::<f32>() * total;
565 let mut chosen_idx = 0;
566 for (i, &d) in distances.iter().enumerate() {
567 r -= d;
568 if r <= 0.0 {
569 chosen_idx = i;
570 break;
571 }
572 }
573 centroids.extend_from_slice(&subvectors[chosen_idx]);
574 }
575
576 centroids
577 }
578
579 #[cfg(not(feature = "native"))]
580 fn min_dist_to_centroids_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> f32 {
581 let num_centroids = centroids.len() / sub_dim;
582 (0..num_centroids)
583 .map(|c| {
584 let offset = c * sub_dim;
585 vector
586 .iter()
587 .zip(¢roids[offset..offset + sub_dim])
588 .map(|(&a, &b)| (a - b) * (a - b))
589 .sum()
590 })
591 .fold(f32::MAX, f32::min)
592 }
593
594 #[cfg(not(feature = "native"))]
595 fn find_nearest_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
596 let num_centroids = centroids.len() / sub_dim;
597 (0..num_centroids)
598 .map(|c| {
599 let offset = c * sub_dim;
600 let dist: f32 = vector
601 .iter()
602 .zip(¢roids[offset..offset + sub_dim])
603 .map(|(&a, &b)| (a - b) * (a - b))
604 .sum();
605 (c, dist)
606 })
607 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
608 .map(|(c, _)| c)
609 .unwrap_or(0)
610 }
611
612 fn find_nearest(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
614 let num_centroids = centroids.len() / sub_dim;
615 let mut best_idx = 0;
616 let mut best_dist = f32::MAX;
617
618 for c in 0..num_centroids {
619 let offset = c * sub_dim;
620 let dist: f32 = vector
621 .iter()
622 .zip(¢roids[offset..offset + sub_dim])
623 .map(|(&a, &b)| (a - b) * (a - b))
624 .sum();
625
626 if dist < best_dist {
627 best_dist = dist;
628 best_idx = c;
629 }
630 }
631
632 best_idx
633 }
634
635 pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> PQVector {
637 let m = self.config.num_subspaces;
638 let k = self.config.num_centroids;
639 let sub_dim = self.config.subspace_dim();
640
641 let residual: Vec<f32> = if let Some(c) = centroid {
643 vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
644 } else {
645 vector.to_vec()
646 };
647
648 let rotated: Vec<f32>;
650 let vec_to_encode = if let Some(ref r) = self.rotation_matrix {
651 rotated = Self::apply_rotation(r, &residual, self.config.dim);
652 &rotated
653 } else {
654 &residual
655 };
656
657 let mut codes = Vec::with_capacity(m);
658
659 for subspace_idx in 0..m {
660 let vec_offset = subspace_idx * sub_dim;
661 let subvec = &vec_to_encode[vec_offset..vec_offset + sub_dim];
662
663 let centroid_base = subspace_idx * k * sub_dim;
664 let centroids_slice = &self.centroids[centroid_base..centroid_base + k * sub_dim];
665
666 let nearest = Self::find_nearest(centroids_slice, subvec, sub_dim);
667 codes.push(nearest as u8);
668 }
669
670 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
671 PQVector::new(codes, norm)
672 }
673
674 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
676 let m = self.config.num_subspaces;
677 let k = self.config.num_centroids;
678 let sub_dim = self.config.subspace_dim();
679
680 let mut rotated_vector = Vec::with_capacity(self.config.dim);
681
682 for (subspace_idx, &code) in codes.iter().enumerate().take(m) {
683 let centroid_base = subspace_idx * k * sub_dim;
684 let centroid_offset = centroid_base + (code as usize) * sub_dim;
685 rotated_vector
686 .extend_from_slice(&self.centroids[centroid_offset..centroid_offset + sub_dim]);
687 }
688
689 if let Some(ref r) = self.rotation_matrix {
691 Self::apply_rotation_transpose(r, &rotated_vector, self.config.dim)
692 } else {
693 rotated_vector
694 }
695 }
696
697 fn apply_rotation_transpose(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
699 let mut result = vec![0.0f32; dim];
700 for i in 0..dim {
701 for j in 0..dim {
702 result[i] += rotation[j * dim + i] * vector[j];
703 }
704 }
705 result
706 }
707
708 #[inline]
710 pub fn get_centroid(&self, subspace_idx: usize, code: u8) -> &[f32] {
711 let k = self.config.num_centroids;
712 let sub_dim = self.config.subspace_dim();
713 let offset = subspace_idx * k * sub_dim + (code as usize) * sub_dim;
714 &self.centroids[offset..offset + sub_dim]
715 }
716
717 pub fn rotate_query(&self, query: &[f32]) -> Vec<f32> {
719 if let Some(ref r) = self.rotation_matrix {
720 Self::apply_rotation(r, query, self.config.dim)
721 } else {
722 query.to_vec()
723 }
724 }
725
726 pub fn save(&self, path: &std::path::Path) -> io::Result<()> {
728 let mut file = std::fs::File::create(path)?;
729 self.write_to(&mut file)
730 }
731
732 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
734 writer.write_u32::<LittleEndian>(CODEBOOK_MAGIC)?;
735 writer.write_u32::<LittleEndian>(2)?;
736 writer.write_u64::<LittleEndian>(self.version)?;
737 writer.write_u32::<LittleEndian>(self.config.dim as u32)?;
738 writer.write_u32::<LittleEndian>(self.config.num_subspaces as u32)?;
739 writer.write_u32::<LittleEndian>(self.config.dims_per_block as u32)?;
740 writer.write_u32::<LittleEndian>(self.config.num_centroids as u32)?;
741 writer.write_u8(if self.config.anisotropic { 1 } else { 0 })?;
742 writer.write_f32::<LittleEndian>(self.config.aniso_eta)?;
743 writer.write_f32::<LittleEndian>(self.config.aniso_threshold)?;
744 writer.write_u8(if self.config.use_opq { 1 } else { 0 })?;
745 writer.write_u32::<LittleEndian>(self.config.opq_iters as u32)?;
746
747 if let Some(ref rotation) = self.rotation_matrix {
748 writer.write_u8(1)?;
749 for &val in rotation {
750 writer.write_f32::<LittleEndian>(val)?;
751 }
752 } else {
753 writer.write_u8(0)?;
754 }
755
756 for &val in &self.centroids {
757 writer.write_f32::<LittleEndian>(val)?;
758 }
759
760 if let Some(ref norms) = self.centroid_norms {
761 writer.write_u8(1)?;
762 for &val in norms {
763 writer.write_f32::<LittleEndian>(val)?;
764 }
765 } else {
766 writer.write_u8(0)?;
767 }
768
769 Ok(())
770 }
771
772 pub fn load(path: &std::path::Path) -> io::Result<Self> {
774 let data = std::fs::read(path)?;
775 Self::read_from(&mut std::io::Cursor::new(data))
776 }
777
778 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
780 let magic = reader.read_u32::<LittleEndian>()?;
781 if magic != CODEBOOK_MAGIC {
782 return Err(io::Error::new(
783 io::ErrorKind::InvalidData,
784 "Invalid codebook file magic",
785 ));
786 }
787
788 let file_version = reader.read_u32::<LittleEndian>()?;
789 let version = reader.read_u64::<LittleEndian>()?;
790 let dim = reader.read_u32::<LittleEndian>()? as usize;
791 let num_subspaces = reader.read_u32::<LittleEndian>()? as usize;
792
793 let (
794 dims_per_block,
795 num_centroids,
796 anisotropic,
797 aniso_eta,
798 aniso_threshold,
799 use_opq,
800 opq_iters,
801 ) = if file_version >= 2 {
802 let dpb = reader.read_u32::<LittleEndian>()? as usize;
803 let nc = reader.read_u32::<LittleEndian>()? as usize;
804 let aniso = reader.read_u8()? != 0;
805 let eta = reader.read_f32::<LittleEndian>()?;
806 let thresh = reader.read_f32::<LittleEndian>()?;
807 let opq = reader.read_u8()? != 0;
808 let iters = reader.read_u32::<LittleEndian>()? as usize;
809 (dpb, nc, aniso, eta, thresh, opq, iters)
810 } else {
811 let nc = reader.read_u32::<LittleEndian>()? as usize;
812 let aniso = reader.read_u8()? != 0;
813 let thresh = reader.read_f32::<LittleEndian>()?;
814 let dpb = dim / num_subspaces;
815 (dpb, nc, aniso, 10.0, thresh, false, 0)
816 };
817
818 let config = PQConfig {
819 dim,
820 num_subspaces,
821 dims_per_block,
822 num_centroids,
823 seed: 42,
824 anisotropic,
825 aniso_eta,
826 aniso_threshold,
827 use_opq,
828 opq_iters,
829 };
830
831 let rotation_matrix = if file_version >= 2 {
832 let has_rotation = reader.read_u8()? != 0;
833 if has_rotation {
834 let mut rotation = vec![0.0f32; dim * dim];
835 for val in &mut rotation {
836 *val = reader.read_f32::<LittleEndian>()?;
837 }
838 Some(rotation)
839 } else {
840 None
841 }
842 } else {
843 None
844 };
845
846 let centroid_count = num_subspaces * num_centroids * config.subspace_dim();
847 let mut centroids = vec![0.0f32; centroid_count];
848 for val in &mut centroids {
849 *val = reader.read_f32::<LittleEndian>()?;
850 }
851
852 let has_norms = reader.read_u8()? != 0;
853 let centroid_norms = if has_norms {
854 let mut norms = vec![0.0f32; num_subspaces * num_centroids];
855 for val in &mut norms {
856 *val = reader.read_f32::<LittleEndian>()?;
857 }
858 Some(norms)
859 } else {
860 None
861 };
862
863 Ok(Self {
864 config,
865 rotation_matrix,
866 centroids,
867 version,
868 centroid_norms,
869 })
870 }
871
872 pub fn size_bytes(&self) -> usize {
874 let centroids_size = self.centroids.len() * 4;
875 let norms_size = self
876 .centroid_norms
877 .as_ref()
878 .map(|n| n.len() * 4)
879 .unwrap_or(0);
880 let rotation_size = self
881 .rotation_matrix
882 .as_ref()
883 .map(|r| r.len() * 4)
884 .unwrap_or(0);
885 centroids_size + norms_size + rotation_size + 64
886 }
887
888 pub fn estimated_memory_bytes(&self) -> usize {
890 self.size_bytes()
891 }
892}
893
/// Precomputed per-subspace distances from one query to every centroid.
#[derive(Debug, Clone)]
pub struct DistanceTable {
    /// Subspace-major table: entry `s * num_centroids + c` is the squared L2
    /// distance between query subspace `s` and centroid `c` of that subspace.
    pub distances: Vec<f32>,
    /// Number of subspaces covered by the table.
    pub num_subspaces: usize,
    /// Number of centroids per subspace (row stride of `distances`).
    pub num_centroids: usize,
}
904
905impl DistanceTable {
906 pub fn build(codebook: &PQCodebook, query: &[f32], centroid: Option<&[f32]>) -> Self {
908 let m = codebook.config.num_subspaces;
909 let k = codebook.config.num_centroids;
910 let sub_dim = codebook.config.subspace_dim();
911
912 let residual: Vec<f32> = if let Some(c) = centroid {
914 query.iter().zip(c).map(|(&v, &c)| v - c).collect()
915 } else {
916 query.to_vec()
917 };
918
919 let rotated_query = codebook.rotate_query(&residual);
921
922 let mut distances = Vec::with_capacity(m * k);
923
924 for subspace_idx in 0..m {
925 let query_offset = subspace_idx * sub_dim;
926 let query_sub = &rotated_query[query_offset..query_offset + sub_dim];
927
928 let centroid_base = subspace_idx * k * sub_dim;
929
930 for centroid_idx in 0..k {
931 let centroid_offset = centroid_base + centroid_idx * sub_dim;
932 let centroid = &codebook.centroids[centroid_offset..centroid_offset + sub_dim];
933
934 let dist: f32 = query_sub
935 .iter()
936 .zip(centroid.iter())
937 .map(|(&a, &b)| (a - b) * (a - b))
938 .sum();
939
940 distances.push(dist);
941 }
942 }
943
944 Self {
945 distances,
946 num_subspaces: m,
947 num_centroids: k,
948 }
949 }
950
951 #[inline]
953 pub fn compute_distance(&self, codes: &[u8]) -> f32 {
954 let k = self.num_centroids;
955 let mut total = 0.0f32;
956
957 for (subspace_idx, &code) in codes.iter().enumerate() {
958 let table_offset = subspace_idx * k + code as usize;
959 total += self.distances[table_offset];
960 }
961
962 total
963 }
964}
965
/// `Quantizer` adapter: every method forwards to the inherent `PQCodebook` /
/// `DistanceTable` implementation.
impl Quantizer for PQCodebook {
    type Code = PQVector;
    type Config = PQConfig;
    type QueryData = DistanceTable;

    // `self.encode` resolves to the inherent method (inherent methods take
    // precedence over trait methods), so this does not recurse.
    fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
        self.encode(vector, centroid)
    }

    /// Precomputes the per-subspace distance table for `query`.
    fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
        DistanceTable::build(self, query, centroid)
    }

    /// Table-lookup distance between a prepared query and an encoded vector.
    fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
        query_data.compute_distance(&code.codes)
    }

    // Always `Some`: PQ codes can always be reconstructed (inherent `decode`).
    fn decode(&self, code: &Self::Code) -> Option<Vec<f32>> {
        Some(self.decode(&code.codes))
    }

    // Delegates to the inherent `size_bytes`.
    fn size_bytes(&self) -> usize {
        self.size_bytes()
    }
}
991
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[test]
    fn test_pq_config() {
        let config = PQConfig::new(128);
        assert_eq!(config.dim, 128);
        assert_eq!(config.dims_per_block, 2);
        assert_eq!(config.num_subspaces, 64);
    }

    #[test]
    fn test_pq_encode_decode() {
        let dim = 32;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 10);

        let test_vec: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let code = codebook.encode(&test_vec, None);

        // 32 dims / 2 dims per block = 16 subspaces.
        assert_eq!(code.codes.len(), 16);

        let decoded = codebook.decode(&code.codes);
        assert_eq!(decoded.len(), dim);
    }

    #[test]
    fn test_distance_table() {
        let dim = 16;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 5);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let table = DistanceTable::build(&codebook, &query, None);

        let code = codebook.encode(&vectors[0], None);
        let dist = table.compute_distance(&code.codes);

        assert!(dist >= 0.0);

        // Without a rotation, the table distance is exactly the squared distance
        // between the query and the decoded (reconstructed) vector.
        let decoded = codebook.decode(&code.codes);
        let expected: f32 = query
            .iter()
            .zip(&decoded)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!((dist - expected).abs() <= 1e-4 * expected.max(1.0));
    }

    #[test]
    fn test_find_nearest_picks_closest_and_first_on_tie() {
        let centroids = [0.0, 0.0, 1.0, 1.0];
        assert_eq!(PQCodebook::find_nearest(&centroids, &[0.9, 0.8], 2), 1);
        assert_eq!(PQCodebook::find_nearest(&centroids, &[0.1, 0.0], 2), 0);

        let duplicates = [1.0, 1.0, 1.0, 1.0];
        assert_eq!(PQCodebook::find_nearest(&duplicates, &[0.0, 0.0], 2), 0);
    }

    #[test]
    fn test_apply_rotation_transpose() {
        // Flattened rows [[1, 2], [3, 4]]; the transpose product sums down columns.
        let rotation = [1.0, 2.0, 3.0, 4.0];
        let result = PQCodebook::apply_rotation_transpose(&rotation, &[1.0, 1.0], 2);
        assert_eq!(result, vec![4.0, 6.0]);
    }

    #[test]
    fn test_codebook_write_read_roundtrip() {
        let dim = 16;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        let vectors: Vec<Vec<f32>> = (0..40)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect();
        let codebook = PQCodebook::train(config, &vectors, 5);

        let mut bytes = Vec::new();
        codebook
            .write_to(&mut bytes)
            .expect("writing to a Vec cannot fail");
        let loaded = PQCodebook::read_from(&mut std::io::Cursor::new(bytes))
            .expect("round-trip read failed");

        assert_eq!(loaded.version, codebook.version);
        assert_eq!(loaded.config.dim, codebook.config.dim);
        assert_eq!(loaded.config.num_subspaces, codebook.config.num_subspaces);
        assert_eq!(loaded.config.num_centroids, codebook.config.num_centroids);
        assert_eq!(loaded.rotation_matrix, codebook.rotation_matrix);
        assert_eq!(loaded.centroids, codebook.centroids);
        assert_eq!(loaded.centroid_norms, codebook.centroid_norms);
        assert_eq!(
            loaded.encode(&vectors[3], None).codes,
            codebook.encode(&vectors[3], None).codes
        );
    }
}