1use std::io::{self, Read, Write};
10
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12#[cfg(not(feature = "native"))]
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(all(target_arch = "x86_64", feature = "native"))]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// File signature written at the start of every serialized codebook.
///
/// Read most-significant byte first, the value spells `SCBK` in ASCII.
const CODEBOOK_MAGIC: u32 = 0x5343_424B;

/// Default number of centroids per subspace; 256 is the most a `u8` code can address.
pub const DEFAULT_NUM_CENTROIDS: usize = 256;

/// Default number of vector dimensions grouped into one subspace by [`PQConfig::new`].
pub const DEFAULT_DIMS_PER_BLOCK: usize = 2;
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct PQConfig {
39 pub dim: usize,
41 pub num_subspaces: usize,
43 pub dims_per_block: usize,
45 pub num_centroids: usize,
47 pub seed: u64,
49 pub anisotropic: bool,
51 pub aniso_eta: f32,
53 pub aniso_threshold: f32,
55 pub use_opq: bool,
57 pub opq_iters: usize,
59}
60
61impl PQConfig {
62 pub fn new(dim: usize) -> Self {
64 let dims_per_block = DEFAULT_DIMS_PER_BLOCK;
65 let num_subspaces = dim / dims_per_block;
66
67 Self {
68 dim,
69 num_subspaces,
70 dims_per_block,
71 num_centroids: DEFAULT_NUM_CENTROIDS,
72 seed: 42,
73 anisotropic: true,
74 aniso_eta: 10.0,
75 aniso_threshold: 0.2,
76 use_opq: true,
77 opq_iters: 10,
78 }
79 }
80
81 pub fn new_fast(dim: usize) -> Self {
83 let num_subspaces = if dim >= 256 {
84 8
85 } else if dim >= 64 {
86 4
87 } else {
88 2
89 };
90 let dims_per_block = dim / num_subspaces;
91
92 Self {
93 dim,
94 num_subspaces,
95 dims_per_block,
96 num_centroids: DEFAULT_NUM_CENTROIDS,
97 seed: 42,
98 anisotropic: true,
99 aniso_eta: 10.0,
100 aniso_threshold: 0.2,
101 use_opq: false,
102 opq_iters: 0,
103 }
104 }
105
106 pub fn new_balanced(dim: usize) -> Self {
109 let num_subspaces = if dim >= 128 {
110 16
111 } else if dim >= 64 {
112 8
113 } else {
114 4
115 };
116 let dims_per_block = dim / num_subspaces;
117
118 Self {
119 dim,
120 num_subspaces,
121 dims_per_block,
122 num_centroids: DEFAULT_NUM_CENTROIDS,
123 seed: 42,
124 anisotropic: true,
125 aniso_eta: 10.0,
126 aniso_threshold: 0.2,
127 use_opq: false,
128 opq_iters: 0,
129 }
130 }
131
132 pub fn with_dims_per_block(mut self, d: usize) -> Self {
133 assert!(
134 self.dim.is_multiple_of(d),
135 "Dimension must be divisible by dims_per_block"
136 );
137 self.dims_per_block = d;
138 self.num_subspaces = self.dim / d;
139 self
140 }
141
142 pub fn with_subspaces(mut self, m: usize) -> Self {
143 assert!(
144 self.dim.is_multiple_of(m),
145 "Dimension must be divisible by num_subspaces"
146 );
147 self.num_subspaces = m;
148 self.dims_per_block = self.dim / m;
149 self
150 }
151
152 pub fn with_centroids(mut self, k: usize) -> Self {
153 assert!(k <= 256, "Max 256 centroids for u8 codes");
154 self.num_centroids = k;
155 self
156 }
157
158 pub fn with_anisotropic(mut self, enabled: bool, eta: f32) -> Self {
159 self.anisotropic = enabled;
160 self.aniso_eta = eta;
161 self
162 }
163
164 pub fn with_opq(mut self, enabled: bool, iters: usize) -> Self {
165 self.use_opq = enabled;
166 self.opq_iters = iters;
167 self
168 }
169
170 pub fn subspace_dim(&self) -> usize {
172 self.dims_per_block
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct PQVector {
179 pub codes: Vec<u8>,
181 pub norm: f32,
183}
184
185impl PQVector {
186 pub fn new(codes: Vec<u8>, norm: f32) -> Self {
187 Self { codes, norm }
188 }
189}
190
191impl QuantizedCode for PQVector {
192 fn size_bytes(&self) -> usize {
193 self.codes.len() + 4 }
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct PQCodebook {
202 pub config: PQConfig,
204 pub rotation_matrix: Option<Vec<f32>>,
206 pub centroids: Vec<f32>,
208 pub version: u64,
210 pub centroid_norms: Option<Vec<f32>>,
212}
213
214impl PQCodebook {
215 #[cfg(feature = "native")]
217 pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
218 use kentro::KMeans;
219 use ndarray::Array2;
220
221 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
222 assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
223
224 let m = config.num_subspaces;
225 let k = config.num_centroids;
226 let sub_dim = config.subspace_dim();
227 let n = vectors.len();
228
229 let rotation_matrix = if config.use_opq && config.opq_iters > 0 {
231 Some(Self::learn_opq_rotation(&config, vectors, max_iters))
232 } else {
233 None
234 };
235
236 let rotated_vectors: Vec<Vec<f32>> = if let Some(ref r) = rotation_matrix {
238 vectors
239 .iter()
240 .map(|v| Self::apply_rotation(r, v, config.dim))
241 .collect()
242 } else {
243 vectors.to_vec()
244 };
245
246 let mut centroids = Vec::with_capacity(m * k * sub_dim);
248
249 for subspace_idx in 0..m {
250 let offset = subspace_idx * sub_dim;
251
252 let subdata: Vec<f32> = rotated_vectors
253 .iter()
254 .flat_map(|v| v[offset..offset + sub_dim].iter().copied())
255 .collect();
256
257 let actual_k = k.min(n);
258
259 let data = Array2::from_shape_vec((n, sub_dim), subdata)
260 .expect("Failed to create subspace array");
261 let mut kmeans = KMeans::new(actual_k)
262 .with_euclidean(true)
263 .with_iterations(max_iters);
264 let _ = kmeans
265 .train(data.view(), None)
266 .expect("K-means training failed");
267
268 let subspace_centroids: Vec<f32> = kmeans
269 .centroids()
270 .expect("No centroids")
271 .iter()
272 .copied()
273 .collect();
274
275 centroids.extend(subspace_centroids);
276
277 while centroids.len() < (subspace_idx + 1) * k * sub_dim {
279 let last_start = centroids.len() - sub_dim;
280 let last: Vec<f32> = centroids[last_start..].to_vec();
281 centroids.extend(last);
282 }
283 }
284
285 let centroid_norms: Vec<f32> = (0..m * k)
287 .map(|i| {
288 let start = i * sub_dim;
289 if start + sub_dim <= centroids.len() {
290 centroids[start..start + sub_dim]
291 .iter()
292 .map(|x| x * x)
293 .sum::<f32>()
294 .sqrt()
295 } else {
296 0.0
297 }
298 })
299 .collect();
300
301 let version = std::time::SystemTime::now()
302 .duration_since(std::time::UNIX_EPOCH)
303 .unwrap_or_default()
304 .as_millis() as u64;
305
306 Self {
307 config,
308 rotation_matrix,
309 centroids,
310 version,
311 centroid_norms: Some(centroid_norms),
312 }
313 }
314
315 #[cfg(not(feature = "native"))]
317 pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
318 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
319 assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
320
321 let m = config.num_subspaces;
322 let k = config.num_centroids;
323 let sub_dim = config.subspace_dim();
324 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
325
326 let rotation_matrix = None;
327 let mut centroids = Vec::with_capacity(m * k * sub_dim);
328
329 for subspace_idx in 0..m {
330 let offset = subspace_idx * sub_dim;
331 let subvectors: Vec<Vec<f32>> = vectors
332 .iter()
333 .map(|v| v[offset..offset + sub_dim].to_vec())
334 .collect();
335
336 let subspace_centroids =
337 Self::train_subspace_scalar(&subvectors, k, sub_dim, max_iters, &mut rng);
338 centroids.extend(subspace_centroids);
339 }
340
341 let centroid_norms: Vec<f32> = (0..m * k)
342 .map(|i| {
343 let start = i * sub_dim;
344 centroids[start..start + sub_dim]
345 .iter()
346 .map(|x| x * x)
347 .sum::<f32>()
348 .sqrt()
349 })
350 .collect();
351
352 let version = std::time::SystemTime::now()
353 .duration_since(std::time::UNIX_EPOCH)
354 .unwrap_or_default()
355 .as_millis() as u64;
356
357 Self {
358 config,
359 rotation_matrix,
360 centroids,
361 version,
362 centroid_norms: Some(centroid_norms),
363 }
364 }
365
366 #[cfg(feature = "native")]
368 fn learn_opq_rotation(config: &PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Vec<f32> {
369 use nalgebra::DMatrix;
370
371 let dim = config.dim;
372 let n = vectors.len();
373
374 let mut rotation = DMatrix::<f32>::identity(dim, dim);
375 let data: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
376 let x = DMatrix::from_row_slice(n, dim, &data);
377
378 for _iter in 0..config.opq_iters.min(max_iters) {
379 let rotated = &x * &rotation;
380 let assignments = Self::compute_pq_assignments(config, &rotated);
381 let reconstructed = Self::reconstruct_from_assignments(config, &rotated, &assignments);
382
383 let xtx_hat = x.transpose() * &reconstructed;
384 let svd = xtx_hat.svd(true, true);
385 if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
386 let new_rotation: DMatrix<f32> = vt.transpose() * u.transpose();
387 rotation = new_rotation;
388 }
389 }
390
391 rotation.iter().copied().collect()
392 }
393
394 #[cfg(feature = "native")]
395 fn compute_pq_assignments(
396 config: &PQConfig,
397 rotated: &nalgebra::DMatrix<f32>,
398 ) -> Vec<Vec<usize>> {
399 use kentro::KMeans;
400 use ndarray::Array2;
401
402 let m = config.num_subspaces;
403 let k = config.num_centroids.min(rotated.nrows());
404 let sub_dim = config.subspace_dim();
405 let n = rotated.nrows();
406
407 let mut all_assignments = vec![vec![0usize; m]; n];
408
409 for subspace_idx in 0..m {
410 let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
411 for row in 0..n {
412 for col in 0..sub_dim {
413 subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
414 }
415 }
416
417 let data = Array2::from_shape_vec((n, sub_dim), subdata)
418 .expect("Failed to create subspace array");
419 let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
420 let clusters = kmeans
421 .train(data.view(), None)
422 .expect("K-means training failed");
423
424 for (cluster_id, point_indices) in clusters.iter().enumerate() {
426 for &point_idx in point_indices {
427 all_assignments[point_idx][subspace_idx] = cluster_id;
428 }
429 }
430 }
431
432 all_assignments
433 }
434
435 #[cfg(feature = "native")]
436 fn reconstruct_from_assignments(
437 config: &PQConfig,
438 rotated: &nalgebra::DMatrix<f32>,
439 assignments: &[Vec<usize>],
440 ) -> nalgebra::DMatrix<f32> {
441 use kentro::KMeans;
442 use ndarray::Array2;
443
444 let m = config.num_subspaces;
445 let sub_dim = config.subspace_dim();
446 let n = rotated.nrows();
447 let dim = config.dim;
448
449 let mut reconstructed = nalgebra::DMatrix::<f32>::zeros(n, dim);
450
451 for subspace_idx in 0..m {
452 let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
453 for row in 0..n {
454 for col in 0..sub_dim {
455 subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
456 }
457 }
458
459 let k = config.num_centroids.min(n);
460 let data = Array2::from_shape_vec((n, sub_dim), subdata)
461 .expect("Failed to create subspace array");
462 let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
463 let _ = kmeans
464 .train(data.view(), None)
465 .expect("K-means training failed");
466
467 let centroids = kmeans.centroids().expect("No centroids");
468
469 for (row, assignment) in assignments.iter().enumerate() {
470 let centroid_idx = assignment[subspace_idx];
471 if centroid_idx < k {
472 for col in 0..sub_dim {
473 reconstructed[(row, subspace_idx * sub_dim + col)] =
474 centroids[[centroid_idx, col]];
475 }
476 }
477 }
478 }
479
480 reconstructed
481 }
482
483 fn apply_rotation(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
485 let mut result = vec![0.0f32; dim];
486 for i in 0..dim {
487 for j in 0..dim {
488 result[i] += rotation[i * dim + j] * vector[j];
489 }
490 }
491 result
492 }
493
494 #[cfg(not(feature = "native"))]
496 fn train_subspace_scalar(
497 subvectors: &[Vec<f32>],
498 k: usize,
499 sub_dim: usize,
500 max_iters: usize,
501 rng: &mut impl Rng,
502 ) -> Vec<f32> {
503 let actual_k = k.min(subvectors.len());
504 let mut centroids = Self::kmeans_plusplus_init_scalar(subvectors, actual_k, sub_dim, rng);
505
506 for _ in 0..max_iters {
507 let assignments: Vec<usize> = subvectors
508 .iter()
509 .map(|v| Self::find_nearest_scalar(¢roids, v, sub_dim))
510 .collect();
511
512 let mut new_centroids = vec![0.0f32; actual_k * sub_dim];
513 let mut counts = vec![0usize; actual_k];
514
515 for (subvec, &assignment) in subvectors.iter().zip(assignments.iter()) {
516 counts[assignment] += 1;
517 let offset = assignment * sub_dim;
518 for (j, &val) in subvec.iter().enumerate() {
519 new_centroids[offset + j] += val;
520 }
521 }
522
523 for (c, &count) in counts.iter().enumerate().take(actual_k) {
524 if count > 0 {
525 let offset = c * sub_dim;
526 for j in 0..sub_dim {
527 new_centroids[offset + j] /= count as f32;
528 }
529 }
530 }
531
532 centroids = new_centroids;
533 }
534
535 while centroids.len() < k * sub_dim {
536 let last_start = centroids.len() - sub_dim;
537 let last: Vec<f32> = centroids[last_start..].to_vec();
538 centroids.extend(last);
539 }
540
541 centroids
542 }
543
544 #[cfg(not(feature = "native"))]
545 fn kmeans_plusplus_init_scalar(
546 subvectors: &[Vec<f32>],
547 k: usize,
548 sub_dim: usize,
549 rng: &mut impl Rng,
550 ) -> Vec<f32> {
551 let mut centroids = Vec::with_capacity(k * sub_dim);
552 let first_idx = rng.random_range(0..subvectors.len());
553 centroids.extend_from_slice(&subvectors[first_idx]);
554
555 for _ in 1..k {
556 let distances: Vec<f32> = subvectors
557 .iter()
558 .map(|v| Self::min_dist_to_centroids_scalar(¢roids, v, sub_dim))
559 .collect();
560
561 let total: f32 = distances.iter().sum();
562 let mut r = rng.random::<f32>() * total;
563 let mut chosen_idx = 0;
564 for (i, &d) in distances.iter().enumerate() {
565 r -= d;
566 if r <= 0.0 {
567 chosen_idx = i;
568 break;
569 }
570 }
571 centroids.extend_from_slice(&subvectors[chosen_idx]);
572 }
573
574 centroids
575 }
576
577 #[cfg(not(feature = "native"))]
578 fn min_dist_to_centroids_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> f32 {
579 let num_centroids = centroids.len() / sub_dim;
580 (0..num_centroids)
581 .map(|c| {
582 let offset = c * sub_dim;
583 vector
584 .iter()
585 .zip(¢roids[offset..offset + sub_dim])
586 .map(|(&a, &b)| (a - b) * (a - b))
587 .sum()
588 })
589 .fold(f32::MAX, f32::min)
590 }
591
592 #[cfg(not(feature = "native"))]
593 fn find_nearest_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
594 let num_centroids = centroids.len() / sub_dim;
595 (0..num_centroids)
596 .map(|c| {
597 let offset = c * sub_dim;
598 let dist: f32 = vector
599 .iter()
600 .zip(¢roids[offset..offset + sub_dim])
601 .map(|(&a, &b)| (a - b) * (a - b))
602 .sum();
603 (c, dist)
604 })
605 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
606 .map(|(c, _)| c)
607 .unwrap_or(0)
608 }
609
610 fn find_nearest(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
612 let num_centroids = centroids.len() / sub_dim;
613 let mut best_idx = 0;
614 let mut best_dist = f32::MAX;
615
616 for c in 0..num_centroids {
617 let offset = c * sub_dim;
618 let dist: f32 = vector
619 .iter()
620 .zip(¢roids[offset..offset + sub_dim])
621 .map(|(&a, &b)| (a - b) * (a - b))
622 .sum();
623
624 if dist < best_dist {
625 best_dist = dist;
626 best_idx = c;
627 }
628 }
629
630 best_idx
631 }
632
633 pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> PQVector {
635 let m = self.config.num_subspaces;
636 let k = self.config.num_centroids;
637 let sub_dim = self.config.subspace_dim();
638
639 let residual: Vec<f32> = if let Some(c) = centroid {
641 vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
642 } else {
643 vector.to_vec()
644 };
645
646 let rotated: Vec<f32>;
648 let vec_to_encode = if let Some(ref r) = self.rotation_matrix {
649 rotated = Self::apply_rotation(r, &residual, self.config.dim);
650 &rotated
651 } else {
652 &residual
653 };
654
655 let mut codes = Vec::with_capacity(m);
656
657 for subspace_idx in 0..m {
658 let vec_offset = subspace_idx * sub_dim;
659 let subvec = &vec_to_encode[vec_offset..vec_offset + sub_dim];
660
661 let centroid_base = subspace_idx * k * sub_dim;
662 let centroids_slice = &self.centroids[centroid_base..centroid_base + k * sub_dim];
663
664 let nearest = Self::find_nearest(centroids_slice, subvec, sub_dim);
665 codes.push(nearest as u8);
666 }
667
668 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
669 PQVector::new(codes, norm)
670 }
671
672 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
674 let m = self.config.num_subspaces;
675 let k = self.config.num_centroids;
676 let sub_dim = self.config.subspace_dim();
677
678 let mut rotated_vector = Vec::with_capacity(self.config.dim);
679
680 for (subspace_idx, &code) in codes.iter().enumerate().take(m) {
681 let centroid_base = subspace_idx * k * sub_dim;
682 let centroid_offset = centroid_base + (code as usize) * sub_dim;
683 rotated_vector
684 .extend_from_slice(&self.centroids[centroid_offset..centroid_offset + sub_dim]);
685 }
686
687 if let Some(ref r) = self.rotation_matrix {
689 Self::apply_rotation_transpose(r, &rotated_vector, self.config.dim)
690 } else {
691 rotated_vector
692 }
693 }
694
695 fn apply_rotation_transpose(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
697 let mut result = vec![0.0f32; dim];
698 for i in 0..dim {
699 for j in 0..dim {
700 result[i] += rotation[j * dim + i] * vector[j];
701 }
702 }
703 result
704 }
705
706 #[inline]
708 pub fn get_centroid(&self, subspace_idx: usize, code: u8) -> &[f32] {
709 let k = self.config.num_centroids;
710 let sub_dim = self.config.subspace_dim();
711 let offset = subspace_idx * k * sub_dim + (code as usize) * sub_dim;
712 &self.centroids[offset..offset + sub_dim]
713 }
714
715 pub fn rotate_query(&self, query: &[f32]) -> Vec<f32> {
717 if let Some(ref r) = self.rotation_matrix {
718 Self::apply_rotation(r, query, self.config.dim)
719 } else {
720 query.to_vec()
721 }
722 }
723
724 pub fn save(&self, path: &std::path::Path) -> io::Result<()> {
726 let mut file = std::fs::File::create(path)?;
727 self.write_to(&mut file)
728 }
729
730 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
732 writer.write_u32::<LittleEndian>(CODEBOOK_MAGIC)?;
733 writer.write_u32::<LittleEndian>(2)?;
734 writer.write_u64::<LittleEndian>(self.version)?;
735 writer.write_u32::<LittleEndian>(self.config.dim as u32)?;
736 writer.write_u32::<LittleEndian>(self.config.num_subspaces as u32)?;
737 writer.write_u32::<LittleEndian>(self.config.dims_per_block as u32)?;
738 writer.write_u32::<LittleEndian>(self.config.num_centroids as u32)?;
739 writer.write_u8(if self.config.anisotropic { 1 } else { 0 })?;
740 writer.write_f32::<LittleEndian>(self.config.aniso_eta)?;
741 writer.write_f32::<LittleEndian>(self.config.aniso_threshold)?;
742 writer.write_u8(if self.config.use_opq { 1 } else { 0 })?;
743 writer.write_u32::<LittleEndian>(self.config.opq_iters as u32)?;
744
745 if let Some(ref rotation) = self.rotation_matrix {
746 writer.write_u8(1)?;
747 for &val in rotation {
748 writer.write_f32::<LittleEndian>(val)?;
749 }
750 } else {
751 writer.write_u8(0)?;
752 }
753
754 for &val in &self.centroids {
755 writer.write_f32::<LittleEndian>(val)?;
756 }
757
758 if let Some(ref norms) = self.centroid_norms {
759 writer.write_u8(1)?;
760 for &val in norms {
761 writer.write_f32::<LittleEndian>(val)?;
762 }
763 } else {
764 writer.write_u8(0)?;
765 }
766
767 Ok(())
768 }
769
770 pub fn load(path: &std::path::Path) -> io::Result<Self> {
772 let data = std::fs::read(path)?;
773 Self::read_from(&mut std::io::Cursor::new(data))
774 }
775
776 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
778 let magic = reader.read_u32::<LittleEndian>()?;
779 if magic != CODEBOOK_MAGIC {
780 return Err(io::Error::new(
781 io::ErrorKind::InvalidData,
782 "Invalid codebook file magic",
783 ));
784 }
785
786 let file_version = reader.read_u32::<LittleEndian>()?;
787 let version = reader.read_u64::<LittleEndian>()?;
788 let dim = reader.read_u32::<LittleEndian>()? as usize;
789 let num_subspaces = reader.read_u32::<LittleEndian>()? as usize;
790
791 let (
792 dims_per_block,
793 num_centroids,
794 anisotropic,
795 aniso_eta,
796 aniso_threshold,
797 use_opq,
798 opq_iters,
799 ) = if file_version >= 2 {
800 let dpb = reader.read_u32::<LittleEndian>()? as usize;
801 let nc = reader.read_u32::<LittleEndian>()? as usize;
802 let aniso = reader.read_u8()? != 0;
803 let eta = reader.read_f32::<LittleEndian>()?;
804 let thresh = reader.read_f32::<LittleEndian>()?;
805 let opq = reader.read_u8()? != 0;
806 let iters = reader.read_u32::<LittleEndian>()? as usize;
807 (dpb, nc, aniso, eta, thresh, opq, iters)
808 } else {
809 let nc = reader.read_u32::<LittleEndian>()? as usize;
810 let aniso = reader.read_u8()? != 0;
811 let thresh = reader.read_f32::<LittleEndian>()?;
812 let dpb = dim / num_subspaces;
813 (dpb, nc, aniso, 10.0, thresh, false, 0)
814 };
815
816 let config = PQConfig {
817 dim,
818 num_subspaces,
819 dims_per_block,
820 num_centroids,
821 seed: 42,
822 anisotropic,
823 aniso_eta,
824 aniso_threshold,
825 use_opq,
826 opq_iters,
827 };
828
829 let rotation_matrix = if file_version >= 2 {
830 let has_rotation = reader.read_u8()? != 0;
831 if has_rotation {
832 let mut rotation = vec![0.0f32; dim * dim];
833 for val in &mut rotation {
834 *val = reader.read_f32::<LittleEndian>()?;
835 }
836 Some(rotation)
837 } else {
838 None
839 }
840 } else {
841 None
842 };
843
844 let centroid_count = num_subspaces * num_centroids * config.subspace_dim();
845 let mut centroids = vec![0.0f32; centroid_count];
846 for val in &mut centroids {
847 *val = reader.read_f32::<LittleEndian>()?;
848 }
849
850 let has_norms = reader.read_u8()? != 0;
851 let centroid_norms = if has_norms {
852 let mut norms = vec![0.0f32; num_subspaces * num_centroids];
853 for val in &mut norms {
854 *val = reader.read_f32::<LittleEndian>()?;
855 }
856 Some(norms)
857 } else {
858 None
859 };
860
861 Ok(Self {
862 config,
863 rotation_matrix,
864 centroids,
865 version,
866 centroid_norms,
867 })
868 }
869
870 pub fn size_bytes(&self) -> usize {
872 let centroids_size = self.centroids.len() * 4;
873 let norms_size = self
874 .centroid_norms
875 .as_ref()
876 .map(|n| n.len() * 4)
877 .unwrap_or(0);
878 let rotation_size = self
879 .rotation_matrix
880 .as_ref()
881 .map(|r| r.len() * 4)
882 .unwrap_or(0);
883 centroids_size + norms_size + rotation_size + 64
884 }
885}
886
/// Precomputed query-to-centroid distances for asymmetric distance computation.
#[derive(Debug, Clone)]
pub struct DistanceTable {
    /// Squared Euclidean distance from each (rotated) query sub-vector to every centroid,
    /// laid out as `[subspace][centroid]`.
    pub distances: Vec<f32>,
    /// Number of subspaces covered by the table.
    pub num_subspaces: usize,
    /// Number of centroids per subspace (row length of the table).
    pub num_centroids: usize,
}
897
898impl DistanceTable {
899 pub fn build(codebook: &PQCodebook, query: &[f32], centroid: Option<&[f32]>) -> Self {
901 let m = codebook.config.num_subspaces;
902 let k = codebook.config.num_centroids;
903 let sub_dim = codebook.config.subspace_dim();
904
905 let residual: Vec<f32> = if let Some(c) = centroid {
907 query.iter().zip(c).map(|(&v, &c)| v - c).collect()
908 } else {
909 query.to_vec()
910 };
911
912 let rotated_query = codebook.rotate_query(&residual);
914
915 let mut distances = Vec::with_capacity(m * k);
916
917 for subspace_idx in 0..m {
918 let query_offset = subspace_idx * sub_dim;
919 let query_sub = &rotated_query[query_offset..query_offset + sub_dim];
920
921 let centroid_base = subspace_idx * k * sub_dim;
922
923 for centroid_idx in 0..k {
924 let centroid_offset = centroid_base + centroid_idx * sub_dim;
925 let centroid = &codebook.centroids[centroid_offset..centroid_offset + sub_dim];
926
927 let dist: f32 = query_sub
928 .iter()
929 .zip(centroid.iter())
930 .map(|(&a, &b)| (a - b) * (a - b))
931 .sum();
932
933 distances.push(dist);
934 }
935 }
936
937 Self {
938 distances,
939 num_subspaces: m,
940 num_centroids: k,
941 }
942 }
943
944 #[inline]
946 pub fn compute_distance(&self, codes: &[u8]) -> f32 {
947 let k = self.num_centroids;
948 let mut total = 0.0f32;
949
950 for (subspace_idx, &code) in codes.iter().enumerate() {
951 let table_offset = subspace_idx * k + code as usize;
952 total += self.distances[table_offset];
953 }
954
955 total
956 }
957}
958
959impl Quantizer for PQCodebook {
960 type Code = PQVector;
961 type Config = PQConfig;
962 type QueryData = DistanceTable;
963
964 fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
965 self.encode(vector, centroid)
966 }
967
968 fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
969 DistanceTable::build(self, query, centroid)
970 }
971
972 fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
973 query_data.compute_distance(&code.codes)
974 }
975
976 fn decode(&self, code: &Self::Code) -> Option<Vec<f32>> {
977 Some(self.decode(&code.codes))
978 }
979
980 fn size_bytes(&self) -> usize {
981 self.size_bytes()
982 }
983}
984
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Draws `count` vectors of length `dim` with components uniform in `[shift, 1 + shift)`.
    fn random_vectors(rng: &mut StdRng, count: usize, dim: usize, shift: f32) -> Vec<Vec<f32>> {
        (0..count)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() + shift).collect())
            .collect()
    }

    /// Squared Euclidean distance between two equally long slices.
    fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&x, &y)| (x - y) * (x - y)).sum()
    }

    #[test]
    fn test_pq_config() {
        let config = PQConfig::new(128);
        assert_eq!(config.dim, 128);
        assert_eq!(config.dims_per_block, 2);
        assert_eq!(config.num_subspaces, 64);
    }

    #[test]
    fn test_pq_encode_decode() {
        let dim = 32;
        let config = PQConfig::new(dim).with_opq(false, 0);
        let mut rng = StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, 100, dim, -0.5);

        let codebook = PQCodebook::train(config, &vectors, 10);

        let test_vec: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let code = codebook.encode(&test_vec, None);

        // 32 dims / 2 dims per block = 16 subspaces, one code each.
        assert_eq!(code.codes.len(), 16);
        let expected_norm = test_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((code.norm - expected_norm).abs() < 1e-5);

        // Without a rotation, decoding is the concatenation of the selected centroids.
        let decoded = codebook.decode(&code.codes);
        assert_eq!(decoded.len(), dim);
        let sub_dim = codebook.config.subspace_dim();
        for (s, &c) in code.codes.iter().enumerate() {
            assert_eq!(
                &decoded[s * sub_dim..(s + 1) * sub_dim],
                codebook.get_centroid(s, c)
            );
        }
    }

    #[test]
    fn test_distance_table() {
        let dim = 16;
        let config = PQConfig::new(dim).with_opq(false, 0);
        let mut rng = StdRng::seed_from_u64(123);
        let vectors = random_vectors(&mut rng, 50, dim, 0.0);

        let codebook = PQCodebook::train(config, &vectors, 5);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let table = DistanceTable::build(&codebook, &query, None);
        assert_eq!(
            table.distances.len(),
            table.num_subspaces * table.num_centroids
        );

        let code = codebook.encode(&vectors[0], None);
        let dist = table.compute_distance(&code.codes);
        assert!(dist >= 0.0);

        // Without a rotation, the table distance equals the distance to the reconstruction.
        let expected = squared_l2(&query, &codebook.decode(&code.codes));
        assert!(
            (dist - expected).abs() <= 1e-4 * expected.max(1.0),
            "table distance {dist} vs reconstruction distance {expected}"
        );
    }

    #[test]
    fn test_codebook_roundtrip() {
        let dim = 8;
        let config = PQConfig::new(dim).with_opq(false, 0).with_centroids(16);
        let mut rng = StdRng::seed_from_u64(7);
        let vectors = random_vectors(&mut rng, 40, dim, 0.0);
        let codebook = PQCodebook::train(config, &vectors, 5);

        let mut bytes = Vec::new();
        codebook
            .write_to(&mut bytes)
            .expect("writing to a Vec cannot fail");
        let restored = PQCodebook::read_from(&mut std::io::Cursor::new(bytes.as_slice()))
            .expect("a freshly written codebook must read back");

        assert_eq!(restored.version, codebook.version);
        assert_eq!(restored.config.dim, codebook.config.dim);
        assert_eq!(restored.config.num_subspaces, codebook.config.num_subspaces);
        assert_eq!(restored.config.dims_per_block, codebook.config.dims_per_block);
        assert_eq!(restored.config.num_centroids, codebook.config.num_centroids);
        assert_eq!(restored.config.use_opq, codebook.config.use_opq);
        assert_eq!(restored.centroids, codebook.centroids);
        assert_eq!(restored.centroid_norms, codebook.centroid_norms);
        assert_eq!(restored.rotation_matrix, codebook.rotation_matrix);
        assert_eq!(
            restored.encode(&vectors[3], None).codes,
            codebook.encode(&vectors[3], None).codes
        );
    }

    #[test]
    fn test_read_from_rejects_bad_magic() {
        let err = PQCodebook::read_from(&mut std::io::Cursor::new(vec![0u8; 64])).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_from_rejects_truncated_input() {
        let dim = 8;
        let config = PQConfig::new(dim).with_opq(false, 0).with_centroids(4);
        let mut rng = StdRng::seed_from_u64(9);
        let vectors = random_vectors(&mut rng, 20, dim, 0.0);
        let codebook = PQCodebook::train(config, &vectors, 3);

        let mut bytes = Vec::new();
        codebook.write_to(&mut bytes).unwrap();
        bytes.truncate(bytes.len() / 2);

        assert!(PQCodebook::read_from(&mut std::io::Cursor::new(bytes.as_slice())).is_err());
    }
}