1use std::io::{self, Read, Write};
10
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12#[cfg(not(feature = "native"))]
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(all(target_arch = "x86_64", feature = "native"))]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// Magic number written as the first `u32` of a serialized codebook and checked on load.
/// The hex bytes `53 43 42 4B` are the ASCII characters "SCBK".
const CODEBOOK_MAGIC: u32 = 0x5343424B;

/// Number of centroids per subspace used by the `PQConfig` constructors.
pub const DEFAULT_NUM_CENTROIDS: usize = 256;

/// Width (in dimensions) of one subspace block used by `PQConfig::new`.
pub const DEFAULT_DIMS_PER_BLOCK: usize = 2;
35
/// Parameters describing how vectors are split into subspaces and quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
    /// Dimensionality of the input vectors.
    pub dim: usize,
    /// Number of subspaces; encoding emits one `u8` code per subspace.
    pub num_subspaces: usize,
    /// Number of consecutive dimensions in each subspace (returned by `subspace_dim`).
    pub dims_per_block: usize,
    /// Number of centroids per subspace (`with_centroids` caps this at 256 so codes fit in `u8`).
    pub num_centroids: usize,
    /// Seed for the RNG used by the non-`native` trainer. Not written by `write_to`.
    pub seed: u64,
    /// Anisotropic-quantization flag. Stored and serialized; nothing visible in this file
    /// reads it during training or encoding. NOTE(review): presumably consumed elsewhere.
    pub anisotropic: bool,
    /// Anisotropic weighting parameter. Stored and serialized only (see `anisotropic`).
    pub aniso_eta: f32,
    /// Anisotropic threshold parameter. Stored and serialized only (see `anisotropic`).
    pub aniso_threshold: f32,
    /// Whether the `native` trainer should learn an OPQ rotation before quantizing.
    pub use_opq: bool,
    /// Number of OPQ refinement iterations; a rotation is learned only when this is > 0.
    pub opq_iters: usize,
}
60
61impl PQConfig {
62 pub fn new(dim: usize) -> Self {
64 let dims_per_block = DEFAULT_DIMS_PER_BLOCK;
65 let num_subspaces = dim / dims_per_block;
66
67 Self {
68 dim,
69 num_subspaces,
70 dims_per_block,
71 num_centroids: DEFAULT_NUM_CENTROIDS,
72 seed: 42,
73 anisotropic: true,
74 aniso_eta: 10.0,
75 aniso_threshold: 0.2,
76 use_opq: true,
77 opq_iters: 10,
78 }
79 }
80
81 pub fn new_fast(dim: usize) -> Self {
83 let num_subspaces = if dim >= 256 {
84 8
85 } else if dim >= 64 {
86 4
87 } else {
88 2
89 };
90 let dims_per_block = dim / num_subspaces;
91
92 Self {
93 dim,
94 num_subspaces,
95 dims_per_block,
96 num_centroids: DEFAULT_NUM_CENTROIDS,
97 seed: 42,
98 anisotropic: true,
99 aniso_eta: 10.0,
100 aniso_threshold: 0.2,
101 use_opq: false,
102 opq_iters: 0,
103 }
104 }
105
106 pub fn new_balanced(dim: usize) -> Self {
109 let num_subspaces = if dim >= 128 {
110 16
111 } else if dim >= 64 {
112 8
113 } else {
114 4
115 };
116 let dims_per_block = dim / num_subspaces;
117
118 Self {
119 dim,
120 num_subspaces,
121 dims_per_block,
122 num_centroids: DEFAULT_NUM_CENTROIDS,
123 seed: 42,
124 anisotropic: true,
125 aniso_eta: 10.0,
126 aniso_threshold: 0.2,
127 use_opq: false,
128 opq_iters: 0,
129 }
130 }
131
132 pub fn with_dims_per_block(mut self, d: usize) -> Self {
133 assert!(
134 self.dim.is_multiple_of(d),
135 "Dimension must be divisible by dims_per_block"
136 );
137 self.dims_per_block = d;
138 self.num_subspaces = self.dim / d;
139 self
140 }
141
142 pub fn with_subspaces(mut self, m: usize) -> Self {
143 assert!(
144 self.dim.is_multiple_of(m),
145 "Dimension must be divisible by num_subspaces"
146 );
147 self.num_subspaces = m;
148 self.dims_per_block = self.dim / m;
149 self
150 }
151
152 pub fn with_centroids(mut self, k: usize) -> Self {
153 assert!(k <= 256, "Max 256 centroids for u8 codes");
154 self.num_centroids = k;
155 self
156 }
157
158 pub fn with_anisotropic(mut self, enabled: bool, eta: f32) -> Self {
159 self.anisotropic = enabled;
160 self.aniso_eta = eta;
161 self
162 }
163
164 pub fn with_opq(mut self, enabled: bool, iters: usize) -> Self {
165 self.use_opq = enabled;
166 self.opq_iters = iters;
167 self
168 }
169
170 pub fn subspace_dim(&self) -> usize {
172 self.dims_per_block
173 }
174}
175
/// A product-quantized vector: per-subspace centroid indices plus the vector's norm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQVector {
    /// One centroid index per subspace, in subspace order.
    pub codes: Vec<u8>,
    /// L2 norm stored alongside the codes (`PQCodebook::encode` computes it from the
    /// input vector before any centroid subtraction).
    pub norm: f32,
}
184
185impl PQVector {
186 pub fn new(codes: Vec<u8>, norm: f32) -> Self {
187 Self { codes, norm }
188 }
189}
190
191impl QuantizedCode for PQVector {
192 fn size_bytes(&self) -> usize {
193 self.codes.len() + 4 }
195}
196
/// A trained product-quantization codebook.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Configuration the codebook was trained (or loaded) with.
    pub config: PQConfig,
    /// Optional `dim * dim` rotation applied to vectors before quantization, stored flat.
    pub rotation_matrix: Option<Vec<f32>>,
    /// All centroids, stored flat. The centroid for `(subspace, code)` starts at
    /// `subspace * num_centroids * sub_dim + code * sub_dim` and is `sub_dim` floats long.
    pub centroids: Vec<f32>,
    /// Version stamp; `train` sets it to the current Unix time in milliseconds.
    pub version: u64,
    /// Optional L2 norm of every centroid, one entry per `(subspace, code)` pair.
    pub centroid_norms: Option<Vec<f32>>,
}
213
214impl PQCodebook {
215 #[cfg(feature = "native")]
217 pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
218 use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
219
220 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
221 assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
222
223 let m = config.num_subspaces;
224 let k = config.num_centroids;
225 let sub_dim = config.subspace_dim();
226 let n = vectors.len();
227
228 let rotation_matrix = if config.use_opq && config.opq_iters > 0 {
230 Some(Self::learn_opq_rotation(&config, vectors, max_iters))
231 } else {
232 None
233 };
234
235 let rotated_vectors: Vec<Vec<f32>> = if let Some(ref r) = rotation_matrix {
237 vectors
238 .iter()
239 .map(|v| Self::apply_rotation(r, v, config.dim))
240 .collect()
241 } else {
242 vectors.to_vec()
243 };
244
245 let mut centroids = Vec::with_capacity(m * k * sub_dim);
247
248 for subspace_idx in 0..m {
249 let offset = subspace_idx * sub_dim;
250
251 let subdata: Vec<f32> = rotated_vectors
252 .iter()
253 .flat_map(|v| v[offset..offset + sub_dim].iter().copied())
254 .collect();
255
256 let actual_k = k.min(n);
257
258 let kmean: KMeans<f32, 8, _> = KMeans::new(&subdata, n, sub_dim, EuclideanDistance);
259 let result = kmean.kmeans_lloyd(
260 actual_k,
261 max_iters,
262 KMeans::init_kmeanplusplus,
263 &KMeansConfig::default(),
264 );
265
266 let subspace_centroids: Vec<f32> = result
267 .centroids
268 .iter()
269 .flat_map(|c| c.iter().copied())
270 .collect();
271
272 centroids.extend(subspace_centroids);
273
274 while centroids.len() < (subspace_idx + 1) * k * sub_dim {
276 let last_start = centroids.len() - sub_dim;
277 let last: Vec<f32> = centroids[last_start..].to_vec();
278 centroids.extend(last);
279 }
280 }
281
282 let centroid_norms: Vec<f32> = (0..m * k)
284 .map(|i| {
285 let start = i * sub_dim;
286 if start + sub_dim <= centroids.len() {
287 centroids[start..start + sub_dim]
288 .iter()
289 .map(|x| x * x)
290 .sum::<f32>()
291 .sqrt()
292 } else {
293 0.0
294 }
295 })
296 .collect();
297
298 let version = std::time::SystemTime::now()
299 .duration_since(std::time::UNIX_EPOCH)
300 .unwrap_or_default()
301 .as_millis() as u64;
302
303 Self {
304 config,
305 rotation_matrix,
306 centroids,
307 version,
308 centroid_norms: Some(centroid_norms),
309 }
310 }
311
312 #[cfg(not(feature = "native"))]
314 pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
315 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
316 assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
317
318 let m = config.num_subspaces;
319 let k = config.num_centroids;
320 let sub_dim = config.subspace_dim();
321 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
322
323 let rotation_matrix = None;
324 let mut centroids = Vec::with_capacity(m * k * sub_dim);
325
326 for subspace_idx in 0..m {
327 let offset = subspace_idx * sub_dim;
328 let subvectors: Vec<Vec<f32>> = vectors
329 .iter()
330 .map(|v| v[offset..offset + sub_dim].to_vec())
331 .collect();
332
333 let subspace_centroids =
334 Self::train_subspace_scalar(&subvectors, k, sub_dim, max_iters, &mut rng);
335 centroids.extend(subspace_centroids);
336 }
337
338 let centroid_norms: Vec<f32> = (0..m * k)
339 .map(|i| {
340 let start = i * sub_dim;
341 centroids[start..start + sub_dim]
342 .iter()
343 .map(|x| x * x)
344 .sum::<f32>()
345 .sqrt()
346 })
347 .collect();
348
349 let version = std::time::SystemTime::now()
350 .duration_since(std::time::UNIX_EPOCH)
351 .unwrap_or_default()
352 .as_millis() as u64;
353
354 Self {
355 config,
356 rotation_matrix,
357 centroids,
358 version,
359 centroid_norms: Some(centroid_norms),
360 }
361 }
362
363 #[cfg(feature = "native")]
365 fn learn_opq_rotation(config: &PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Vec<f32> {
366 use nalgebra::DMatrix;
367
368 let dim = config.dim;
369 let n = vectors.len();
370
371 let mut rotation = DMatrix::<f32>::identity(dim, dim);
372 let data: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
373 let x = DMatrix::from_row_slice(n, dim, &data);
374
375 for _iter in 0..config.opq_iters.min(max_iters) {
376 let rotated = &x * &rotation;
377 let assignments = Self::compute_pq_assignments(config, &rotated);
378 let reconstructed = Self::reconstruct_from_assignments(config, &rotated, &assignments);
379
380 let xtx_hat = x.transpose() * &reconstructed;
381 let svd = xtx_hat.svd(true, true);
382 if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
383 let new_rotation: DMatrix<f32> = vt.transpose() * u.transpose();
384 rotation = new_rotation;
385 }
386 }
387
388 rotation.iter().copied().collect()
389 }
390
391 #[cfg(feature = "native")]
392 fn compute_pq_assignments(
393 config: &PQConfig,
394 rotated: &nalgebra::DMatrix<f32>,
395 ) -> Vec<Vec<usize>> {
396 use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
397
398 let m = config.num_subspaces;
399 let k = config.num_centroids.min(rotated.nrows());
400 let sub_dim = config.subspace_dim();
401 let n = rotated.nrows();
402
403 let mut all_assignments = vec![vec![0usize; m]; n];
404
405 for subspace_idx in 0..m {
406 let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
407 for row in 0..n {
408 for col in 0..sub_dim {
409 subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
410 }
411 }
412
413 let kmean: KMeans<f32, 8, _> = KMeans::new(&subdata, n, sub_dim, EuclideanDistance);
414 let result =
415 kmean.kmeans_lloyd(k, 5, KMeans::init_kmeanplusplus, &KMeansConfig::default());
416
417 for (i, &assignment) in result.assignments.iter().enumerate() {
418 all_assignments[i][subspace_idx] = assignment;
419 }
420 }
421
422 all_assignments
423 }
424
425 #[cfg(feature = "native")]
426 fn reconstruct_from_assignments(
427 config: &PQConfig,
428 rotated: &nalgebra::DMatrix<f32>,
429 assignments: &[Vec<usize>],
430 ) -> nalgebra::DMatrix<f32> {
431 use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
432
433 let m = config.num_subspaces;
434 let sub_dim = config.subspace_dim();
435 let n = rotated.nrows();
436 let dim = config.dim;
437
438 let mut reconstructed = nalgebra::DMatrix::<f32>::zeros(n, dim);
439
440 for subspace_idx in 0..m {
441 let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
442 for row in 0..n {
443 for col in 0..sub_dim {
444 subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
445 }
446 }
447
448 let k = config.num_centroids.min(n);
449 let kmean: KMeans<f32, 8, _> = KMeans::new(&subdata, n, sub_dim, EuclideanDistance);
450 let result =
451 kmean.kmeans_lloyd(k, 5, KMeans::init_kmeanplusplus, &KMeansConfig::default());
452
453 for (row, assignment) in assignments.iter().enumerate() {
454 let centroid_idx = assignment[subspace_idx];
455 if centroid_idx < k {
456 for (col, &val) in result.centroids[centroid_idx].iter().enumerate() {
457 reconstructed[(row, subspace_idx * sub_dim + col)] = val;
458 }
459 }
460 }
461 }
462
463 reconstructed
464 }
465
466 fn apply_rotation(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
468 let mut result = vec![0.0f32; dim];
469 for i in 0..dim {
470 for j in 0..dim {
471 result[i] += rotation[i * dim + j] * vector[j];
472 }
473 }
474 result
475 }
476
477 #[cfg(not(feature = "native"))]
479 fn train_subspace_scalar(
480 subvectors: &[Vec<f32>],
481 k: usize,
482 sub_dim: usize,
483 max_iters: usize,
484 rng: &mut impl Rng,
485 ) -> Vec<f32> {
486 let actual_k = k.min(subvectors.len());
487 let mut centroids = Self::kmeans_plusplus_init_scalar(subvectors, actual_k, sub_dim, rng);
488
489 for _ in 0..max_iters {
490 let assignments: Vec<usize> = subvectors
491 .iter()
492 .map(|v| Self::find_nearest_scalar(¢roids, v, sub_dim))
493 .collect();
494
495 let mut new_centroids = vec![0.0f32; actual_k * sub_dim];
496 let mut counts = vec![0usize; actual_k];
497
498 for (subvec, &assignment) in subvectors.iter().zip(assignments.iter()) {
499 counts[assignment] += 1;
500 let offset = assignment * sub_dim;
501 for (j, &val) in subvec.iter().enumerate() {
502 new_centroids[offset + j] += val;
503 }
504 }
505
506 for c in 0..actual_k {
507 if counts[c] > 0 {
508 let offset = c * sub_dim;
509 for j in 0..sub_dim {
510 new_centroids[offset + j] /= counts[c] as f32;
511 }
512 }
513 }
514
515 centroids = new_centroids;
516 }
517
518 while centroids.len() < k * sub_dim {
519 let last_start = centroids.len() - sub_dim;
520 let last: Vec<f32> = centroids[last_start..].to_vec();
521 centroids.extend(last);
522 }
523
524 centroids
525 }
526
527 #[cfg(not(feature = "native"))]
528 fn kmeans_plusplus_init_scalar(
529 subvectors: &[Vec<f32>],
530 k: usize,
531 sub_dim: usize,
532 rng: &mut impl Rng,
533 ) -> Vec<f32> {
534 let mut centroids = Vec::with_capacity(k * sub_dim);
535 let first_idx = rng.random_range(0..subvectors.len());
536 centroids.extend_from_slice(&subvectors[first_idx]);
537
538 for _ in 1..k {
539 let distances: Vec<f32> = subvectors
540 .iter()
541 .map(|v| Self::min_dist_to_centroids_scalar(¢roids, v, sub_dim))
542 .collect();
543
544 let total: f32 = distances.iter().sum();
545 let mut r = rng.random::<f32>() * total;
546 let mut chosen_idx = 0;
547 for (i, &d) in distances.iter().enumerate() {
548 r -= d;
549 if r <= 0.0 {
550 chosen_idx = i;
551 break;
552 }
553 }
554 centroids.extend_from_slice(&subvectors[chosen_idx]);
555 }
556
557 centroids
558 }
559
560 #[cfg(not(feature = "native"))]
561 fn min_dist_to_centroids_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> f32 {
562 let num_centroids = centroids.len() / sub_dim;
563 (0..num_centroids)
564 .map(|c| {
565 let offset = c * sub_dim;
566 vector
567 .iter()
568 .zip(¢roids[offset..offset + sub_dim])
569 .map(|(&a, &b)| (a - b) * (a - b))
570 .sum()
571 })
572 .fold(f32::MAX, f32::min)
573 }
574
575 #[cfg(not(feature = "native"))]
576 fn find_nearest_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
577 let num_centroids = centroids.len() / sub_dim;
578 (0..num_centroids)
579 .map(|c| {
580 let offset = c * sub_dim;
581 let dist: f32 = vector
582 .iter()
583 .zip(¢roids[offset..offset + sub_dim])
584 .map(|(&a, &b)| (a - b) * (a - b))
585 .sum();
586 (c, dist)
587 })
588 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
589 .map(|(c, _)| c)
590 .unwrap_or(0)
591 }
592
593 fn find_nearest(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
595 let num_centroids = centroids.len() / sub_dim;
596 let mut best_idx = 0;
597 let mut best_dist = f32::MAX;
598
599 for c in 0..num_centroids {
600 let offset = c * sub_dim;
601 let dist: f32 = vector
602 .iter()
603 .zip(¢roids[offset..offset + sub_dim])
604 .map(|(&a, &b)| (a - b) * (a - b))
605 .sum();
606
607 if dist < best_dist {
608 best_dist = dist;
609 best_idx = c;
610 }
611 }
612
613 best_idx
614 }
615
616 pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> PQVector {
618 let m = self.config.num_subspaces;
619 let k = self.config.num_centroids;
620 let sub_dim = self.config.subspace_dim();
621
622 let residual: Vec<f32> = if let Some(c) = centroid {
624 vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
625 } else {
626 vector.to_vec()
627 };
628
629 let rotated: Vec<f32>;
631 let vec_to_encode = if let Some(ref r) = self.rotation_matrix {
632 rotated = Self::apply_rotation(r, &residual, self.config.dim);
633 &rotated
634 } else {
635 &residual
636 };
637
638 let mut codes = Vec::with_capacity(m);
639
640 for subspace_idx in 0..m {
641 let vec_offset = subspace_idx * sub_dim;
642 let subvec = &vec_to_encode[vec_offset..vec_offset + sub_dim];
643
644 let centroid_base = subspace_idx * k * sub_dim;
645 let centroids_slice = &self.centroids[centroid_base..centroid_base + k * sub_dim];
646
647 let nearest = Self::find_nearest(centroids_slice, subvec, sub_dim);
648 codes.push(nearest as u8);
649 }
650
651 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
652 PQVector::new(codes, norm)
653 }
654
655 pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
657 let m = self.config.num_subspaces;
658 let k = self.config.num_centroids;
659 let sub_dim = self.config.subspace_dim();
660
661 let mut rotated_vector = Vec::with_capacity(self.config.dim);
662
663 for (subspace_idx, &code) in codes.iter().enumerate().take(m) {
664 let centroid_base = subspace_idx * k * sub_dim;
665 let centroid_offset = centroid_base + (code as usize) * sub_dim;
666 rotated_vector
667 .extend_from_slice(&self.centroids[centroid_offset..centroid_offset + sub_dim]);
668 }
669
670 if let Some(ref r) = self.rotation_matrix {
672 Self::apply_rotation_transpose(r, &rotated_vector, self.config.dim)
673 } else {
674 rotated_vector
675 }
676 }
677
678 fn apply_rotation_transpose(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
680 let mut result = vec![0.0f32; dim];
681 for i in 0..dim {
682 for j in 0..dim {
683 result[i] += rotation[j * dim + i] * vector[j];
684 }
685 }
686 result
687 }
688
689 #[inline]
691 pub fn get_centroid(&self, subspace_idx: usize, code: u8) -> &[f32] {
692 let k = self.config.num_centroids;
693 let sub_dim = self.config.subspace_dim();
694 let offset = subspace_idx * k * sub_dim + (code as usize) * sub_dim;
695 &self.centroids[offset..offset + sub_dim]
696 }
697
698 pub fn rotate_query(&self, query: &[f32]) -> Vec<f32> {
700 if let Some(ref r) = self.rotation_matrix {
701 Self::apply_rotation(r, query, self.config.dim)
702 } else {
703 query.to_vec()
704 }
705 }
706
707 pub fn save(&self, path: &std::path::Path) -> io::Result<()> {
709 let mut file = std::fs::File::create(path)?;
710 self.write_to(&mut file)
711 }
712
713 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
715 writer.write_u32::<LittleEndian>(CODEBOOK_MAGIC)?;
716 writer.write_u32::<LittleEndian>(2)?;
717 writer.write_u64::<LittleEndian>(self.version)?;
718 writer.write_u32::<LittleEndian>(self.config.dim as u32)?;
719 writer.write_u32::<LittleEndian>(self.config.num_subspaces as u32)?;
720 writer.write_u32::<LittleEndian>(self.config.dims_per_block as u32)?;
721 writer.write_u32::<LittleEndian>(self.config.num_centroids as u32)?;
722 writer.write_u8(if self.config.anisotropic { 1 } else { 0 })?;
723 writer.write_f32::<LittleEndian>(self.config.aniso_eta)?;
724 writer.write_f32::<LittleEndian>(self.config.aniso_threshold)?;
725 writer.write_u8(if self.config.use_opq { 1 } else { 0 })?;
726 writer.write_u32::<LittleEndian>(self.config.opq_iters as u32)?;
727
728 if let Some(ref rotation) = self.rotation_matrix {
729 writer.write_u8(1)?;
730 for &val in rotation {
731 writer.write_f32::<LittleEndian>(val)?;
732 }
733 } else {
734 writer.write_u8(0)?;
735 }
736
737 for &val in &self.centroids {
738 writer.write_f32::<LittleEndian>(val)?;
739 }
740
741 if let Some(ref norms) = self.centroid_norms {
742 writer.write_u8(1)?;
743 for &val in norms {
744 writer.write_f32::<LittleEndian>(val)?;
745 }
746 } else {
747 writer.write_u8(0)?;
748 }
749
750 Ok(())
751 }
752
753 pub fn load(path: &std::path::Path) -> io::Result<Self> {
755 let data = std::fs::read(path)?;
756 Self::read_from(&mut std::io::Cursor::new(data))
757 }
758
759 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
761 let magic = reader.read_u32::<LittleEndian>()?;
762 if magic != CODEBOOK_MAGIC {
763 return Err(io::Error::new(
764 io::ErrorKind::InvalidData,
765 "Invalid codebook file magic",
766 ));
767 }
768
769 let file_version = reader.read_u32::<LittleEndian>()?;
770 let version = reader.read_u64::<LittleEndian>()?;
771 let dim = reader.read_u32::<LittleEndian>()? as usize;
772 let num_subspaces = reader.read_u32::<LittleEndian>()? as usize;
773
774 let (
775 dims_per_block,
776 num_centroids,
777 anisotropic,
778 aniso_eta,
779 aniso_threshold,
780 use_opq,
781 opq_iters,
782 ) = if file_version >= 2 {
783 let dpb = reader.read_u32::<LittleEndian>()? as usize;
784 let nc = reader.read_u32::<LittleEndian>()? as usize;
785 let aniso = reader.read_u8()? != 0;
786 let eta = reader.read_f32::<LittleEndian>()?;
787 let thresh = reader.read_f32::<LittleEndian>()?;
788 let opq = reader.read_u8()? != 0;
789 let iters = reader.read_u32::<LittleEndian>()? as usize;
790 (dpb, nc, aniso, eta, thresh, opq, iters)
791 } else {
792 let nc = reader.read_u32::<LittleEndian>()? as usize;
793 let aniso = reader.read_u8()? != 0;
794 let thresh = reader.read_f32::<LittleEndian>()?;
795 let dpb = dim / num_subspaces;
796 (dpb, nc, aniso, 10.0, thresh, false, 0)
797 };
798
799 let config = PQConfig {
800 dim,
801 num_subspaces,
802 dims_per_block,
803 num_centroids,
804 seed: 42,
805 anisotropic,
806 aniso_eta,
807 aniso_threshold,
808 use_opq,
809 opq_iters,
810 };
811
812 let rotation_matrix = if file_version >= 2 {
813 let has_rotation = reader.read_u8()? != 0;
814 if has_rotation {
815 let mut rotation = vec![0.0f32; dim * dim];
816 for val in &mut rotation {
817 *val = reader.read_f32::<LittleEndian>()?;
818 }
819 Some(rotation)
820 } else {
821 None
822 }
823 } else {
824 None
825 };
826
827 let centroid_count = num_subspaces * num_centroids * config.subspace_dim();
828 let mut centroids = vec![0.0f32; centroid_count];
829 for val in &mut centroids {
830 *val = reader.read_f32::<LittleEndian>()?;
831 }
832
833 let has_norms = reader.read_u8()? != 0;
834 let centroid_norms = if has_norms {
835 let mut norms = vec![0.0f32; num_subspaces * num_centroids];
836 for val in &mut norms {
837 *val = reader.read_f32::<LittleEndian>()?;
838 }
839 Some(norms)
840 } else {
841 None
842 };
843
844 Ok(Self {
845 config,
846 rotation_matrix,
847 centroids,
848 version,
849 centroid_norms,
850 })
851 }
852
853 pub fn size_bytes(&self) -> usize {
855 let centroids_size = self.centroids.len() * 4;
856 let norms_size = self
857 .centroid_norms
858 .as_ref()
859 .map(|n| n.len() * 4)
860 .unwrap_or(0);
861 let rotation_size = self
862 .rotation_matrix
863 .as_ref()
864 .map(|r| r.len() * 4)
865 .unwrap_or(0);
866 centroids_size + norms_size + rotation_size + 64
867 }
868}
869
870#[derive(Debug, Clone)]
872pub struct DistanceTable {
873 pub distances: Vec<f32>,
875 pub num_subspaces: usize,
877 pub num_centroids: usize,
879}
880
881impl DistanceTable {
882 pub fn build(codebook: &PQCodebook, query: &[f32], centroid: Option<&[f32]>) -> Self {
884 let m = codebook.config.num_subspaces;
885 let k = codebook.config.num_centroids;
886 let sub_dim = codebook.config.subspace_dim();
887
888 let residual: Vec<f32> = if let Some(c) = centroid {
890 query.iter().zip(c).map(|(&v, &c)| v - c).collect()
891 } else {
892 query.to_vec()
893 };
894
895 let rotated_query = codebook.rotate_query(&residual);
897
898 let mut distances = Vec::with_capacity(m * k);
899
900 for subspace_idx in 0..m {
901 let query_offset = subspace_idx * sub_dim;
902 let query_sub = &rotated_query[query_offset..query_offset + sub_dim];
903
904 let centroid_base = subspace_idx * k * sub_dim;
905
906 for centroid_idx in 0..k {
907 let centroid_offset = centroid_base + centroid_idx * sub_dim;
908 let centroid = &codebook.centroids[centroid_offset..centroid_offset + sub_dim];
909
910 let dist: f32 = query_sub
911 .iter()
912 .zip(centroid.iter())
913 .map(|(&a, &b)| (a - b) * (a - b))
914 .sum();
915
916 distances.push(dist);
917 }
918 }
919
920 Self {
921 distances,
922 num_subspaces: m,
923 num_centroids: k,
924 }
925 }
926
927 #[inline]
929 pub fn compute_distance(&self, codes: &[u8]) -> f32 {
930 let k = self.num_centroids;
931 let mut total = 0.0f32;
932
933 for (subspace_idx, &code) in codes.iter().enumerate() {
934 let table_offset = subspace_idx * k + code as usize;
935 total += self.distances[table_offset];
936 }
937
938 total
939 }
940}
941
942impl Quantizer for PQCodebook {
943 type Code = PQVector;
944 type Config = PQConfig;
945 type QueryData = DistanceTable;
946
947 fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
948 self.encode(vector, centroid)
949 }
950
951 fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
952 DistanceTable::build(self, query, centroid)
953 }
954
955 fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
956 query_data.compute_distance(&code.codes)
957 }
958
959 fn decode(&self, code: &Self::Code) -> Option<Vec<f32>> {
960 Some(self.decode(&code.codes))
961 }
962
963 fn size_bytes(&self) -> usize {
964 self.size_bytes()
965 }
966}
967
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// The default configuration uses 2-dimensional blocks.
    #[test]
    fn test_pq_config() {
        let config = PQConfig::new(128);
        assert_eq!(config.dim, 128);
        assert_eq!(config.dims_per_block, 2);
        assert_eq!(config.num_subspaces, 64);
    }

    /// Encoding yields one code per subspace and decoding restores a full-length vector.
    #[test]
    fn test_pq_encode_decode() {
        let dim = 32;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 10);

        let test_vec: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let code = codebook.encode(&test_vec, None);

        assert_eq!(code.codes.len(), 16);

        // Exercise the decode half as well.
        let decoded = codebook.decode(&code.codes);
        assert_eq!(decoded.len(), dim);
        assert!(decoded.iter().all(|x| x.is_finite()));
    }

    /// Without a rotation, the table lookup must equal the squared distance between the
    /// query and the decoded vector.
    #[test]
    fn test_distance_table() {
        let dim = 16;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 5);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let table = DistanceTable::build(&codebook, &query, None);

        let code = codebook.encode(&vectors[0], None);
        let dist = table.compute_distance(&code.codes);

        assert!(dist >= 0.0);

        let decoded = codebook.decode(&code.codes);
        let exact: f32 = query
            .iter()
            .zip(&decoded)
            .map(|(&q, &d)| (q - d) * (q - d))
            .sum();
        assert!((dist - exact).abs() < 1e-3);
    }
}