// Source: hermes_core/structures/vector/quantization/pq.rs
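//! Product Quantization with OPQ and Anisotropic Loss (ScaNN-style)
//!
//! Implementation inspired by Google's ScaNN (Scalable Nearest Neighbors):
//! - **Anisotropic quantization settings** (`anisotropic`, `aniso_eta`, `aniso_threshold`),
//!   intended to penalize parallel error more than orthogonal error. Note: `PQCodebook::train`
//!   currently fits centroids with plain Euclidean k-means and does not read these settings.
//! - **OPQ rotation**: learns an orthogonal rotation matrix before quantization (native builds)
//! - **Product quantization** with learned per-subspace codebooks
//! - **SIMD-accelerated** asymmetric distance computation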
8
9use std::io::{self, Read, Write};
10
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12#[cfg(not(feature = "native"))]
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(all(target_arch = "x86_64", feature = "native"))]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// Magic number identifying a serialized `PQCodebook`.
///
/// Read as big-endian hex the value spells `"SCBK"` ("ScaNN CodeBook"). The codebook writer
/// emits it little-endian, so the first four bytes of a file are `KBCS`.
const CODEBOOK_MAGIC: u32 = 0x5343424B;

/// Default number of centroids per subspace (K).
///
/// Codes are stored as `u8`, so K may be at most 256; the default uses the full code range.
pub const DEFAULT_NUM_CENTROIDS: usize = 256;

/// Default dimensions per subspace block (ScaNN recommends 2 for best accuracy).
pub const DEFAULT_DIMS_PER_BLOCK: usize = 2;
35
36/// Configuration for Product Quantization with OPQ and Anisotropic Loss
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct PQConfig {
39    /// Vector dimension
40    pub dim: usize,
41    /// Number of subspaces (M) - computed from dim / dims_per_block
42    pub num_subspaces: usize,
43    /// Dimensions per subspace block (ScaNN recommends 2)
44    pub dims_per_block: usize,
45    /// Number of centroids per subspace (K) - typically 256 for u8 codes
46    pub num_centroids: usize,
47    /// Random seed for reproducibility
48    pub seed: u64,
49    /// Use anisotropic quantization (true ScaNN-style parallel/orthogonal weighting)
50    pub anisotropic: bool,
51    /// Anisotropic eta: ratio of parallel to orthogonal error weight (η)
52    pub aniso_eta: f32,
53    /// Anisotropic threshold T: only consider inner products >= T
54    pub aniso_threshold: f32,
55    /// Use OPQ rotation matrix (learned via SVD)
56    pub use_opq: bool,
57    /// Number of OPQ iterations
58    pub opq_iters: usize,
59}
60
61impl PQConfig {
62    /// Create config with ScaNN-recommended defaults
63    pub fn new(dim: usize) -> Self {
64        let dims_per_block = DEFAULT_DIMS_PER_BLOCK;
65        let num_subspaces = dim / dims_per_block;
66
67        Self {
68            dim,
69            num_subspaces,
70            dims_per_block,
71            num_centroids: DEFAULT_NUM_CENTROIDS,
72            seed: 42,
73            anisotropic: true,
74            aniso_eta: 10.0,
75            aniso_threshold: 0.2,
76            use_opq: true,
77            opq_iters: 10,
78        }
79    }
80
81    /// Create config with larger subspaces (faster but less accurate)
82    pub fn new_fast(dim: usize) -> Self {
83        let num_subspaces = if dim >= 256 {
84            8
85        } else if dim >= 64 {
86            4
87        } else {
88            2
89        };
90        let dims_per_block = dim / num_subspaces;
91
92        Self {
93            dim,
94            num_subspaces,
95            dims_per_block,
96            num_centroids: DEFAULT_NUM_CENTROIDS,
97            seed: 42,
98            anisotropic: true,
99            aniso_eta: 10.0,
100            aniso_threshold: 0.2,
101            use_opq: false,
102            opq_iters: 0,
103        }
104    }
105
106    /// Create balanced config (good recall/speed tradeoff)
107    /// Uses 16 subspaces for 128D+ vectors, 8 for smaller
108    pub fn new_balanced(dim: usize) -> Self {
109        let num_subspaces = if dim >= 128 {
110            16
111        } else if dim >= 64 {
112            8
113        } else {
114            4
115        };
116        let dims_per_block = dim / num_subspaces;
117
118        Self {
119            dim,
120            num_subspaces,
121            dims_per_block,
122            num_centroids: DEFAULT_NUM_CENTROIDS,
123            seed: 42,
124            anisotropic: true,
125            aniso_eta: 10.0,
126            aniso_threshold: 0.2,
127            use_opq: false,
128            opq_iters: 0,
129        }
130    }
131
132    pub fn with_dims_per_block(mut self, d: usize) -> Self {
133        assert!(
134            self.dim.is_multiple_of(d),
135            "Dimension must be divisible by dims_per_block"
136        );
137        self.dims_per_block = d;
138        self.num_subspaces = self.dim / d;
139        self
140    }
141
142    pub fn with_subspaces(mut self, m: usize) -> Self {
143        assert!(
144            self.dim.is_multiple_of(m),
145            "Dimension must be divisible by num_subspaces"
146        );
147        self.num_subspaces = m;
148        self.dims_per_block = self.dim / m;
149        self
150    }
151
152    pub fn with_centroids(mut self, k: usize) -> Self {
153        assert!(k <= 256, "Max 256 centroids for u8 codes");
154        self.num_centroids = k;
155        self
156    }
157
158    pub fn with_anisotropic(mut self, enabled: bool, eta: f32) -> Self {
159        self.anisotropic = enabled;
160        self.aniso_eta = eta;
161        self
162    }
163
164    pub fn with_opq(mut self, enabled: bool, iters: usize) -> Self {
165        self.use_opq = enabled;
166        self.opq_iters = iters;
167        self
168    }
169
170    /// Dimension of each subspace
171    pub fn subspace_dim(&self) -> usize {
172        self.dims_per_block
173    }
174}
175
176/// Quantized vector using Product Quantization
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct PQVector {
179    /// PQ codes (M bytes, one per subspace)
180    pub codes: Vec<u8>,
181    /// Original vector norm (for re-ranking or normalization)
182    pub norm: f32,
183}
184
185impl PQVector {
186    pub fn new(codes: Vec<u8>, norm: f32) -> Self {
187        Self { codes, norm }
188    }
189}
190
191impl QuantizedCode for PQVector {
192    fn size_bytes(&self) -> usize {
193        self.codes.len() + 4 // codes + norm
194    }
195}
196
197/// Learned codebook for Product Quantization with OPQ rotation
198///
199/// Trained once, shared across all segments (like CoarseCentroids).
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct PQCodebook {
202    /// Configuration
203    pub config: PQConfig,
204    /// OPQ rotation matrix (dim × dim), stored row-major
205    pub rotation_matrix: Option<Vec<f32>>,
206    /// Centroids: M subspaces × K centroids × subspace_dim
207    pub centroids: Vec<f32>,
208    /// Version for merge compatibility checking
209    pub version: u64,
210    /// Precomputed centroid norms for faster distance computation
211    pub centroid_norms: Option<Vec<f32>>,
212}
213
214impl PQCodebook {
215    /// Train codebook with OPQ rotation and anisotropic loss
216    #[cfg(feature = "native")]
217    pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
218        use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
219
220        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
221        assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
222
223        let m = config.num_subspaces;
224        let k = config.num_centroids;
225        let sub_dim = config.subspace_dim();
226        let n = vectors.len();
227
228        // Step 1: Learn OPQ rotation matrix if enabled
229        let rotation_matrix = if config.use_opq && config.opq_iters > 0 {
230            Some(Self::learn_opq_rotation(&config, vectors, max_iters))
231        } else {
232            None
233        };
234
235        // Step 2: Apply rotation to vectors
236        let rotated_vectors: Vec<Vec<f32>> = if let Some(ref r) = rotation_matrix {
237            vectors
238                .iter()
239                .map(|v| Self::apply_rotation(r, v, config.dim))
240                .collect()
241        } else {
242            vectors.to_vec()
243        };
244
245        // Step 3: Train k-means for each subspace
246        let mut centroids = Vec::with_capacity(m * k * sub_dim);
247
248        for subspace_idx in 0..m {
249            let offset = subspace_idx * sub_dim;
250
251            let subdata: Vec<f32> = rotated_vectors
252                .iter()
253                .flat_map(|v| v[offset..offset + sub_dim].iter().copied())
254                .collect();
255
256            let actual_k = k.min(n);
257
258            let kmean: KMeans<f32, 8, _> = KMeans::new(&subdata, n, sub_dim, EuclideanDistance);
259            let result = kmean.kmeans_lloyd(
260                actual_k,
261                max_iters,
262                KMeans::init_kmeanplusplus,
263                &KMeansConfig::default(),
264            );
265
266            let subspace_centroids: Vec<f32> = result
267                .centroids
268                .iter()
269                .flat_map(|c| c.iter().copied())
270                .collect();
271
272            centroids.extend(subspace_centroids);
273
274            // Pad if needed
275            while centroids.len() < (subspace_idx + 1) * k * sub_dim {
276                let last_start = centroids.len() - sub_dim;
277                let last: Vec<f32> = centroids[last_start..].to_vec();
278                centroids.extend(last);
279            }
280        }
281
282        // Precompute centroid norms
283        let centroid_norms: Vec<f32> = (0..m * k)
284            .map(|i| {
285                let start = i * sub_dim;
286                if start + sub_dim <= centroids.len() {
287                    centroids[start..start + sub_dim]
288                        .iter()
289                        .map(|x| x * x)
290                        .sum::<f32>()
291                        .sqrt()
292                } else {
293                    0.0
294                }
295            })
296            .collect();
297
298        let version = std::time::SystemTime::now()
299            .duration_since(std::time::UNIX_EPOCH)
300            .unwrap_or_default()
301            .as_millis() as u64;
302
303        Self {
304            config,
305            rotation_matrix,
306            centroids,
307            version,
308            centroid_norms: Some(centroid_norms),
309        }
310    }
311
312    /// Fallback training for non-native builds (WASM)
313    #[cfg(not(feature = "native"))]
314    pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
315        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
316        assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
317
318        let m = config.num_subspaces;
319        let k = config.num_centroids;
320        let sub_dim = config.subspace_dim();
321        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
322
323        let rotation_matrix = None;
324        let mut centroids = Vec::with_capacity(m * k * sub_dim);
325
326        for subspace_idx in 0..m {
327            let offset = subspace_idx * sub_dim;
328            let subvectors: Vec<Vec<f32>> = vectors
329                .iter()
330                .map(|v| v[offset..offset + sub_dim].to_vec())
331                .collect();
332
333            let subspace_centroids =
334                Self::train_subspace_scalar(&subvectors, k, sub_dim, max_iters, &mut rng);
335            centroids.extend(subspace_centroids);
336        }
337
338        let centroid_norms: Vec<f32> = (0..m * k)
339            .map(|i| {
340                let start = i * sub_dim;
341                centroids[start..start + sub_dim]
342                    .iter()
343                    .map(|x| x * x)
344                    .sum::<f32>()
345                    .sqrt()
346            })
347            .collect();
348
349        let version = std::time::SystemTime::now()
350            .duration_since(std::time::UNIX_EPOCH)
351            .unwrap_or_default()
352            .as_millis() as u64;
353
354        Self {
355            config,
356            rotation_matrix,
357            centroids,
358            version,
359            centroid_norms: Some(centroid_norms),
360        }
361    }
362
363    /// Learn OPQ rotation matrix using SVD
364    #[cfg(feature = "native")]
365    fn learn_opq_rotation(config: &PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Vec<f32> {
366        use nalgebra::DMatrix;
367
368        let dim = config.dim;
369        let n = vectors.len();
370
371        let mut rotation = DMatrix::<f32>::identity(dim, dim);
372        let data: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
373        let x = DMatrix::from_row_slice(n, dim, &data);
374
375        for _iter in 0..config.opq_iters.min(max_iters) {
376            let rotated = &x * &rotation;
377            let assignments = Self::compute_pq_assignments(config, &rotated);
378            let reconstructed = Self::reconstruct_from_assignments(config, &rotated, &assignments);
379
380            let xtx_hat = x.transpose() * &reconstructed;
381            let svd = xtx_hat.svd(true, true);
382            if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
383                let new_rotation: DMatrix<f32> = vt.transpose() * u.transpose();
384                rotation = new_rotation;
385            }
386        }
387
388        rotation.iter().copied().collect()
389    }
390
391    #[cfg(feature = "native")]
392    fn compute_pq_assignments(
393        config: &PQConfig,
394        rotated: &nalgebra::DMatrix<f32>,
395    ) -> Vec<Vec<usize>> {
396        use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
397
398        let m = config.num_subspaces;
399        let k = config.num_centroids.min(rotated.nrows());
400        let sub_dim = config.subspace_dim();
401        let n = rotated.nrows();
402
403        let mut all_assignments = vec![vec![0usize; m]; n];
404
405        for subspace_idx in 0..m {
406            let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
407            for row in 0..n {
408                for col in 0..sub_dim {
409                    subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
410                }
411            }
412
413            let kmean: KMeans<f32, 8, _> = KMeans::new(&subdata, n, sub_dim, EuclideanDistance);
414            let result =
415                kmean.kmeans_lloyd(k, 5, KMeans::init_kmeanplusplus, &KMeansConfig::default());
416
417            for (i, &assignment) in result.assignments.iter().enumerate() {
418                all_assignments[i][subspace_idx] = assignment;
419            }
420        }
421
422        all_assignments
423    }
424
425    #[cfg(feature = "native")]
426    fn reconstruct_from_assignments(
427        config: &PQConfig,
428        rotated: &nalgebra::DMatrix<f32>,
429        assignments: &[Vec<usize>],
430    ) -> nalgebra::DMatrix<f32> {
431        use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
432
433        let m = config.num_subspaces;
434        let sub_dim = config.subspace_dim();
435        let n = rotated.nrows();
436        let dim = config.dim;
437
438        let mut reconstructed = nalgebra::DMatrix::<f32>::zeros(n, dim);
439
440        for subspace_idx in 0..m {
441            let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
442            for row in 0..n {
443                for col in 0..sub_dim {
444                    subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
445                }
446            }
447
448            let k = config.num_centroids.min(n);
449            let kmean: KMeans<f32, 8, _> = KMeans::new(&subdata, n, sub_dim, EuclideanDistance);
450            let result =
451                kmean.kmeans_lloyd(k, 5, KMeans::init_kmeanplusplus, &KMeansConfig::default());
452
453            for (row, assignment) in assignments.iter().enumerate() {
454                let centroid_idx = assignment[subspace_idx];
455                if centroid_idx < k {
456                    for (col, &val) in result.centroids[centroid_idx].iter().enumerate() {
457                        reconstructed[(row, subspace_idx * sub_dim + col)] = val;
458                    }
459                }
460            }
461        }
462
463        reconstructed
464    }
465
466    /// Apply rotation matrix to vector
467    fn apply_rotation(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
468        let mut result = vec![0.0f32; dim];
469        for i in 0..dim {
470            for j in 0..dim {
471                result[i] += rotation[i * dim + j] * vector[j];
472            }
473        }
474        result
475    }
476
477    /// Scalar k-means for WASM fallback
478    #[cfg(not(feature = "native"))]
479    fn train_subspace_scalar(
480        subvectors: &[Vec<f32>],
481        k: usize,
482        sub_dim: usize,
483        max_iters: usize,
484        rng: &mut impl Rng,
485    ) -> Vec<f32> {
486        let actual_k = k.min(subvectors.len());
487        let mut centroids = Self::kmeans_plusplus_init_scalar(subvectors, actual_k, sub_dim, rng);
488
489        for _ in 0..max_iters {
490            let assignments: Vec<usize> = subvectors
491                .iter()
492                .map(|v| Self::find_nearest_scalar(&centroids, v, sub_dim))
493                .collect();
494
495            let mut new_centroids = vec![0.0f32; actual_k * sub_dim];
496            let mut counts = vec![0usize; actual_k];
497
498            for (subvec, &assignment) in subvectors.iter().zip(assignments.iter()) {
499                counts[assignment] += 1;
500                let offset = assignment * sub_dim;
501                for (j, &val) in subvec.iter().enumerate() {
502                    new_centroids[offset + j] += val;
503                }
504            }
505
506            for c in 0..actual_k {
507                if counts[c] > 0 {
508                    let offset = c * sub_dim;
509                    for j in 0..sub_dim {
510                        new_centroids[offset + j] /= counts[c] as f32;
511                    }
512                }
513            }
514
515            centroids = new_centroids;
516        }
517
518        while centroids.len() < k * sub_dim {
519            let last_start = centroids.len() - sub_dim;
520            let last: Vec<f32> = centroids[last_start..].to_vec();
521            centroids.extend(last);
522        }
523
524        centroids
525    }
526
527    #[cfg(not(feature = "native"))]
528    fn kmeans_plusplus_init_scalar(
529        subvectors: &[Vec<f32>],
530        k: usize,
531        sub_dim: usize,
532        rng: &mut impl Rng,
533    ) -> Vec<f32> {
534        let mut centroids = Vec::with_capacity(k * sub_dim);
535        let first_idx = rng.random_range(0..subvectors.len());
536        centroids.extend_from_slice(&subvectors[first_idx]);
537
538        for _ in 1..k {
539            let distances: Vec<f32> = subvectors
540                .iter()
541                .map(|v| Self::min_dist_to_centroids_scalar(&centroids, v, sub_dim))
542                .collect();
543
544            let total: f32 = distances.iter().sum();
545            let mut r = rng.random::<f32>() * total;
546            let mut chosen_idx = 0;
547            for (i, &d) in distances.iter().enumerate() {
548                r -= d;
549                if r <= 0.0 {
550                    chosen_idx = i;
551                    break;
552                }
553            }
554            centroids.extend_from_slice(&subvectors[chosen_idx]);
555        }
556
557        centroids
558    }
559
560    #[cfg(not(feature = "native"))]
561    fn min_dist_to_centroids_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> f32 {
562        let num_centroids = centroids.len() / sub_dim;
563        (0..num_centroids)
564            .map(|c| {
565                let offset = c * sub_dim;
566                vector
567                    .iter()
568                    .zip(&centroids[offset..offset + sub_dim])
569                    .map(|(&a, &b)| (a - b) * (a - b))
570                    .sum()
571            })
572            .fold(f32::MAX, f32::min)
573    }
574
575    #[cfg(not(feature = "native"))]
576    fn find_nearest_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
577        let num_centroids = centroids.len() / sub_dim;
578        (0..num_centroids)
579            .map(|c| {
580                let offset = c * sub_dim;
581                let dist: f32 = vector
582                    .iter()
583                    .zip(&centroids[offset..offset + sub_dim])
584                    .map(|(&a, &b)| (a - b) * (a - b))
585                    .sum();
586                (c, dist)
587            })
588            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
589            .map(|(c, _)| c)
590            .unwrap_or(0)
591    }
592
593    /// Find nearest centroid index
594    fn find_nearest(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
595        let num_centroids = centroids.len() / sub_dim;
596        let mut best_idx = 0;
597        let mut best_dist = f32::MAX;
598
599        for c in 0..num_centroids {
600            let offset = c * sub_dim;
601            let dist: f32 = vector
602                .iter()
603                .zip(&centroids[offset..offset + sub_dim])
604                .map(|(&a, &b)| (a - b) * (a - b))
605                .sum();
606
607            if dist < best_dist {
608                best_dist = dist;
609                best_idx = c;
610            }
611        }
612
613        best_idx
614    }
615
616    /// Encode a vector to PQ codes
617    pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> PQVector {
618        let m = self.config.num_subspaces;
619        let k = self.config.num_centroids;
620        let sub_dim = self.config.subspace_dim();
621
622        // Compute residual if centroid provided
623        let residual: Vec<f32> = if let Some(c) = centroid {
624            vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
625        } else {
626            vector.to_vec()
627        };
628
629        // Apply rotation if present
630        let rotated: Vec<f32>;
631        let vec_to_encode = if let Some(ref r) = self.rotation_matrix {
632            rotated = Self::apply_rotation(r, &residual, self.config.dim);
633            &rotated
634        } else {
635            &residual
636        };
637
638        let mut codes = Vec::with_capacity(m);
639
640        for subspace_idx in 0..m {
641            let vec_offset = subspace_idx * sub_dim;
642            let subvec = &vec_to_encode[vec_offset..vec_offset + sub_dim];
643
644            let centroid_base = subspace_idx * k * sub_dim;
645            let centroids_slice = &self.centroids[centroid_base..centroid_base + k * sub_dim];
646
647            let nearest = Self::find_nearest(centroids_slice, subvec, sub_dim);
648            codes.push(nearest as u8);
649        }
650
651        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
652        PQVector::new(codes, norm)
653    }
654
655    /// Decode PQ codes back to approximate vector
656    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
657        let m = self.config.num_subspaces;
658        let k = self.config.num_centroids;
659        let sub_dim = self.config.subspace_dim();
660
661        let mut rotated_vector = Vec::with_capacity(self.config.dim);
662
663        for (subspace_idx, &code) in codes.iter().enumerate().take(m) {
664            let centroid_base = subspace_idx * k * sub_dim;
665            let centroid_offset = centroid_base + (code as usize) * sub_dim;
666            rotated_vector
667                .extend_from_slice(&self.centroids[centroid_offset..centroid_offset + sub_dim]);
668        }
669
670        // Apply inverse rotation if present
671        if let Some(ref r) = self.rotation_matrix {
672            Self::apply_rotation_transpose(r, &rotated_vector, self.config.dim)
673        } else {
674            rotated_vector
675        }
676    }
677
678    /// Apply transpose of rotation matrix
679    fn apply_rotation_transpose(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
680        let mut result = vec![0.0f32; dim];
681        for i in 0..dim {
682            for j in 0..dim {
683                result[i] += rotation[j * dim + i] * vector[j];
684            }
685        }
686        result
687    }
688
689    /// Get centroid for a specific subspace and code
690    #[inline]
691    pub fn get_centroid(&self, subspace_idx: usize, code: u8) -> &[f32] {
692        let k = self.config.num_centroids;
693        let sub_dim = self.config.subspace_dim();
694        let offset = subspace_idx * k * sub_dim + (code as usize) * sub_dim;
695        &self.centroids[offset..offset + sub_dim]
696    }
697
698    /// Rotate a query vector
699    pub fn rotate_query(&self, query: &[f32]) -> Vec<f32> {
700        if let Some(ref r) = self.rotation_matrix {
701            Self::apply_rotation(r, query, self.config.dim)
702        } else {
703            query.to_vec()
704        }
705    }
706
707    /// Save to binary file
708    pub fn save(&self, path: &std::path::Path) -> io::Result<()> {
709        let mut file = std::fs::File::create(path)?;
710        self.write_to(&mut file)
711    }
712
713    /// Write to any writer
714    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
715        writer.write_u32::<LittleEndian>(CODEBOOK_MAGIC)?;
716        writer.write_u32::<LittleEndian>(2)?;
717        writer.write_u64::<LittleEndian>(self.version)?;
718        writer.write_u32::<LittleEndian>(self.config.dim as u32)?;
719        writer.write_u32::<LittleEndian>(self.config.num_subspaces as u32)?;
720        writer.write_u32::<LittleEndian>(self.config.dims_per_block as u32)?;
721        writer.write_u32::<LittleEndian>(self.config.num_centroids as u32)?;
722        writer.write_u8(if self.config.anisotropic { 1 } else { 0 })?;
723        writer.write_f32::<LittleEndian>(self.config.aniso_eta)?;
724        writer.write_f32::<LittleEndian>(self.config.aniso_threshold)?;
725        writer.write_u8(if self.config.use_opq { 1 } else { 0 })?;
726        writer.write_u32::<LittleEndian>(self.config.opq_iters as u32)?;
727
728        if let Some(ref rotation) = self.rotation_matrix {
729            writer.write_u8(1)?;
730            for &val in rotation {
731                writer.write_f32::<LittleEndian>(val)?;
732            }
733        } else {
734            writer.write_u8(0)?;
735        }
736
737        for &val in &self.centroids {
738            writer.write_f32::<LittleEndian>(val)?;
739        }
740
741        if let Some(ref norms) = self.centroid_norms {
742            writer.write_u8(1)?;
743            for &val in norms {
744                writer.write_f32::<LittleEndian>(val)?;
745            }
746        } else {
747            writer.write_u8(0)?;
748        }
749
750        Ok(())
751    }
752
753    /// Load from binary file
754    pub fn load(path: &std::path::Path) -> io::Result<Self> {
755        let data = std::fs::read(path)?;
756        Self::read_from(&mut std::io::Cursor::new(data))
757    }
758
759    /// Read from any reader
760    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
761        let magic = reader.read_u32::<LittleEndian>()?;
762        if magic != CODEBOOK_MAGIC {
763            return Err(io::Error::new(
764                io::ErrorKind::InvalidData,
765                "Invalid codebook file magic",
766            ));
767        }
768
769        let file_version = reader.read_u32::<LittleEndian>()?;
770        let version = reader.read_u64::<LittleEndian>()?;
771        let dim = reader.read_u32::<LittleEndian>()? as usize;
772        let num_subspaces = reader.read_u32::<LittleEndian>()? as usize;
773
774        let (
775            dims_per_block,
776            num_centroids,
777            anisotropic,
778            aniso_eta,
779            aniso_threshold,
780            use_opq,
781            opq_iters,
782        ) = if file_version >= 2 {
783            let dpb = reader.read_u32::<LittleEndian>()? as usize;
784            let nc = reader.read_u32::<LittleEndian>()? as usize;
785            let aniso = reader.read_u8()? != 0;
786            let eta = reader.read_f32::<LittleEndian>()?;
787            let thresh = reader.read_f32::<LittleEndian>()?;
788            let opq = reader.read_u8()? != 0;
789            let iters = reader.read_u32::<LittleEndian>()? as usize;
790            (dpb, nc, aniso, eta, thresh, opq, iters)
791        } else {
792            let nc = reader.read_u32::<LittleEndian>()? as usize;
793            let aniso = reader.read_u8()? != 0;
794            let thresh = reader.read_f32::<LittleEndian>()?;
795            let dpb = dim / num_subspaces;
796            (dpb, nc, aniso, 10.0, thresh, false, 0)
797        };
798
799        let config = PQConfig {
800            dim,
801            num_subspaces,
802            dims_per_block,
803            num_centroids,
804            seed: 42,
805            anisotropic,
806            aniso_eta,
807            aniso_threshold,
808            use_opq,
809            opq_iters,
810        };
811
812        let rotation_matrix = if file_version >= 2 {
813            let has_rotation = reader.read_u8()? != 0;
814            if has_rotation {
815                let mut rotation = vec![0.0f32; dim * dim];
816                for val in &mut rotation {
817                    *val = reader.read_f32::<LittleEndian>()?;
818                }
819                Some(rotation)
820            } else {
821                None
822            }
823        } else {
824            None
825        };
826
827        let centroid_count = num_subspaces * num_centroids * config.subspace_dim();
828        let mut centroids = vec![0.0f32; centroid_count];
829        for val in &mut centroids {
830            *val = reader.read_f32::<LittleEndian>()?;
831        }
832
833        let has_norms = reader.read_u8()? != 0;
834        let centroid_norms = if has_norms {
835            let mut norms = vec![0.0f32; num_subspaces * num_centroids];
836            for val in &mut norms {
837                *val = reader.read_f32::<LittleEndian>()?;
838            }
839            Some(norms)
840        } else {
841            None
842        };
843
844        Ok(Self {
845            config,
846            rotation_matrix,
847            centroids,
848            version,
849            centroid_norms,
850        })
851    }
852
853    /// Memory usage in bytes
854    pub fn size_bytes(&self) -> usize {
855        let centroids_size = self.centroids.len() * 4;
856        let norms_size = self
857            .centroid_norms
858            .as_ref()
859            .map(|n| n.len() * 4)
860            .unwrap_or(0);
861        let rotation_size = self
862            .rotation_matrix
863            .as_ref()
864            .map(|r| r.len() * 4)
865            .unwrap_or(0);
866        centroids_size + norms_size + rotation_size + 64
867    }
868}
869
870/// Precomputed distance table for fast asymmetric distance computation
871#[derive(Debug, Clone)]
872pub struct DistanceTable {
873    /// M × K table of squared distances
874    pub distances: Vec<f32>,
875    /// Number of subspaces
876    pub num_subspaces: usize,
877    /// Number of centroids per subspace
878    pub num_centroids: usize,
879}
880
881impl DistanceTable {
882    /// Build distance table for a query vector
883    pub fn build(codebook: &PQCodebook, query: &[f32], centroid: Option<&[f32]>) -> Self {
884        let m = codebook.config.num_subspaces;
885        let k = codebook.config.num_centroids;
886        let sub_dim = codebook.config.subspace_dim();
887
888        // Compute residual if centroid provided
889        let residual: Vec<f32> = if let Some(c) = centroid {
890            query.iter().zip(c).map(|(&v, &c)| v - c).collect()
891        } else {
892            query.to_vec()
893        };
894
895        // Apply rotation if present
896        let rotated_query = codebook.rotate_query(&residual);
897
898        let mut distances = Vec::with_capacity(m * k);
899
900        for subspace_idx in 0..m {
901            let query_offset = subspace_idx * sub_dim;
902            let query_sub = &rotated_query[query_offset..query_offset + sub_dim];
903
904            let centroid_base = subspace_idx * k * sub_dim;
905
906            for centroid_idx in 0..k {
907                let centroid_offset = centroid_base + centroid_idx * sub_dim;
908                let centroid = &codebook.centroids[centroid_offset..centroid_offset + sub_dim];
909
910                let dist: f32 = query_sub
911                    .iter()
912                    .zip(centroid.iter())
913                    .map(|(&a, &b)| (a - b) * (a - b))
914                    .sum();
915
916                distances.push(dist);
917            }
918        }
919
920        Self {
921            distances,
922            num_subspaces: m,
923            num_centroids: k,
924        }
925    }
926
927    /// Compute approximate distance using PQ codes
928    #[inline]
929    pub fn compute_distance(&self, codes: &[u8]) -> f32 {
930        let k = self.num_centroids;
931        let mut total = 0.0f32;
932
933        for (subspace_idx, &code) in codes.iter().enumerate() {
934            let table_offset = subspace_idx * k + code as usize;
935            total += self.distances[table_offset];
936        }
937
938        total
939    }
940}
941
942impl Quantizer for PQCodebook {
943    type Code = PQVector;
944    type Config = PQConfig;
945    type QueryData = DistanceTable;
946
947    fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
948        self.encode(vector, centroid)
949    }
950
951    fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
952        DistanceTable::build(self, query, centroid)
953    }
954
955    fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
956        query_data.compute_distance(&code.codes)
957    }
958
959    fn decode(&self, code: &Self::Code) -> Option<Vec<f32>> {
960        Some(self.decode(&code.codes))
961    }
962
963    fn size_bytes(&self) -> usize {
964        self.size_bytes()
965    }
966}
967
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[test]
    fn test_pq_config() {
        let config = PQConfig::new(128);
        assert_eq!(config.dim, 128);
        assert_eq!(config.dims_per_block, 2);
        assert_eq!(config.num_subspaces, 64);
    }

    #[test]
    fn test_pq_encode_decode() {
        let dim = 32;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 10);

        let test_vec: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let code = codebook.encode(&test_vec, None);

        assert_eq!(code.codes.len(), 16); // 32 dims / 2 dims_per_block

        // Round trip: decoding must yield a full-dimension, finite reconstruction.
        let decoded = codebook.decode(&code.codes);
        assert_eq!(decoded.len(), dim);
        assert!(decoded.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_distance_table() {
        let dim = 16;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 5);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let table = DistanceTable::build(&codebook, &query, None);

        let code = codebook.encode(&vectors[0], None);
        let dist = table.compute_distance(&code.codes);

        assert!(dist >= 0.0);

        // The table lookup should agree with the exact squared distance between
        // the query and the decoded vector. This assumes `with_opq(false, 0)`
        // leaves no rotation (or that decode undoes an orthogonal one).
        let decoded = codebook.decode(&code.codes);
        let exact: f32 = query
            .iter()
            .zip(&decoded)
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum();
        assert!(
            (dist - exact).abs() <= 1e-3 * exact.max(1.0),
            "table distance {dist} disagrees with exact distance {exact}"
        );
    }
}