Skip to main content

hermes_core/structures/vector/quantization/
pq.rs

1//! Product Quantization with OPQ and Anisotropic Loss (ScaNN-style)
2//!
3//! Implementation inspired by Google's ScaNN (Scalable Nearest Neighbors):
4//! - **True anisotropic quantization**: penalizes parallel error more than orthogonal
5//! - **OPQ rotation**: learns optimal rotation matrix before quantization
6//! - **Product quantization** with learned codebooks
7//! - **SIMD-accelerated** asymmetric distance computation
8
9use std::io::{self, Read, Write};
10
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12#[cfg(not(feature = "native"))]
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(all(target_arch = "x86_64", feature = "native"))]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// Magic number identifying a serialized codebook.
///
/// It spells "SCBK" (ScaNN CodeBook) when read big-endian. Because the header
/// is written little-endian, the first four bytes of a file are `KBCS`.
const CODEBOOK_MAGIC: u32 = u32::from_be_bytes(*b"SCBK");

/// Default number of centroids per subspace (K): one per possible `u8` code.
pub const DEFAULT_NUM_CENTROIDS: usize = 1 << u8::BITS;

/// Default dimensions per block (ScaNN recommends 2 for best accuracy).
pub const DEFAULT_DIMS_PER_BLOCK: usize = 2;
35
36/// Configuration for Product Quantization with OPQ and Anisotropic Loss
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct PQConfig {
39    /// Vector dimension
40    pub dim: usize,
41    /// Number of subspaces (M) - computed from dim / dims_per_block
42    pub num_subspaces: usize,
43    /// Dimensions per subspace block (ScaNN recommends 2)
44    pub dims_per_block: usize,
45    /// Number of centroids per subspace (K) - typically 256 for u8 codes
46    pub num_centroids: usize,
47    /// Random seed for reproducibility
48    pub seed: u64,
49    /// Use anisotropic quantization (true ScaNN-style parallel/orthogonal weighting)
50    pub anisotropic: bool,
51    /// Anisotropic eta: ratio of parallel to orthogonal error weight (η)
52    pub aniso_eta: f32,
53    /// Anisotropic threshold T: only consider inner products >= T
54    pub aniso_threshold: f32,
55    /// Use OPQ rotation matrix (learned via SVD)
56    pub use_opq: bool,
57    /// Number of OPQ iterations
58    pub opq_iters: usize,
59}
60
61impl PQConfig {
62    /// Create config with ScaNN-recommended defaults
63    pub fn new(dim: usize) -> Self {
64        let dims_per_block = DEFAULT_DIMS_PER_BLOCK;
65        let num_subspaces = dim / dims_per_block;
66
67        Self {
68            dim,
69            num_subspaces,
70            dims_per_block,
71            num_centroids: DEFAULT_NUM_CENTROIDS,
72            seed: 42,
73            anisotropic: true,
74            aniso_eta: 10.0,
75            aniso_threshold: 0.2,
76            use_opq: true,
77            opq_iters: 10,
78        }
79    }
80
81    /// Create config with larger subspaces (faster but less accurate)
82    pub fn new_fast(dim: usize) -> Self {
83        let num_subspaces = if dim >= 256 {
84            8
85        } else if dim >= 64 {
86            4
87        } else {
88            2
89        };
90        let dims_per_block = dim / num_subspaces;
91
92        Self {
93            dim,
94            num_subspaces,
95            dims_per_block,
96            num_centroids: DEFAULT_NUM_CENTROIDS,
97            seed: 42,
98            anisotropic: true,
99            aniso_eta: 10.0,
100            aniso_threshold: 0.2,
101            use_opq: false,
102            opq_iters: 0,
103        }
104    }
105
106    /// Create balanced config (good recall/speed tradeoff)
107    /// Uses 16 subspaces for 128D+ vectors, 8 for smaller
108    pub fn new_balanced(dim: usize) -> Self {
109        let num_subspaces = if dim >= 128 {
110            16
111        } else if dim >= 64 {
112            8
113        } else {
114            4
115        };
116        let dims_per_block = dim / num_subspaces;
117
118        Self {
119            dim,
120            num_subspaces,
121            dims_per_block,
122            num_centroids: DEFAULT_NUM_CENTROIDS,
123            seed: 42,
124            anisotropic: true,
125            aniso_eta: 10.0,
126            aniso_threshold: 0.2,
127            use_opq: false,
128            opq_iters: 0,
129        }
130    }
131
132    pub fn with_dims_per_block(mut self, d: usize) -> Self {
133        assert!(
134            self.dim.is_multiple_of(d),
135            "Dimension must be divisible by dims_per_block"
136        );
137        self.dims_per_block = d;
138        self.num_subspaces = self.dim / d;
139        self
140    }
141
142    pub fn with_subspaces(mut self, m: usize) -> Self {
143        assert!(
144            self.dim.is_multiple_of(m),
145            "Dimension must be divisible by num_subspaces"
146        );
147        self.num_subspaces = m;
148        self.dims_per_block = self.dim / m;
149        self
150    }
151
152    pub fn with_centroids(mut self, k: usize) -> Self {
153        assert!(k <= 256, "Max 256 centroids for u8 codes");
154        self.num_centroids = k;
155        self
156    }
157
158    pub fn with_anisotropic(mut self, enabled: bool, eta: f32) -> Self {
159        self.anisotropic = enabled;
160        self.aniso_eta = eta;
161        self
162    }
163
164    pub fn with_opq(mut self, enabled: bool, iters: usize) -> Self {
165        self.use_opq = enabled;
166        self.opq_iters = iters;
167        self
168    }
169
170    /// Dimension of each subspace
171    pub fn subspace_dim(&self) -> usize {
172        self.dims_per_block
173    }
174}
175
176/// Quantized vector using Product Quantization
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct PQVector {
179    /// PQ codes (M bytes, one per subspace)
180    pub codes: Vec<u8>,
181    /// Original vector norm (for re-ranking or normalization)
182    pub norm: f32,
183}
184
185impl PQVector {
186    pub fn new(codes: Vec<u8>, norm: f32) -> Self {
187        Self { codes, norm }
188    }
189}
190
191impl QuantizedCode for PQVector {
192    fn size_bytes(&self) -> usize {
193        self.codes.len() + 4 // codes + norm
194    }
195}
196
197/// Learned codebook for Product Quantization with OPQ rotation
198///
199/// Trained once, shared across all segments (like CoarseCentroids).
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct PQCodebook {
202    /// Configuration
203    pub config: PQConfig,
204    /// OPQ rotation matrix (dim × dim), stored row-major
205    pub rotation_matrix: Option<Vec<f32>>,
206    /// Centroids: M subspaces × K centroids × subspace_dim
207    pub centroids: Vec<f32>,
208    /// Version for merge compatibility checking
209    pub version: u64,
210    /// Precomputed centroid norms for faster distance computation
211    pub centroid_norms: Option<Vec<f32>>,
212}
213
214impl PQCodebook {
215    /// Train codebook with OPQ rotation and anisotropic loss
216    #[cfg(feature = "native")]
217    pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
218        use kentro::KMeans;
219        use ndarray::Array2;
220
221        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
222        assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
223
224        let m = config.num_subspaces;
225        let k = config.num_centroids;
226        let sub_dim = config.subspace_dim();
227        let n = vectors.len();
228
229        // Step 1: Learn OPQ rotation matrix if enabled
230        let rotation_matrix = if config.use_opq && config.opq_iters > 0 {
231            Some(Self::learn_opq_rotation(&config, vectors, max_iters))
232        } else {
233            None
234        };
235
236        // Step 2: Apply rotation to vectors
237        let rotated_vectors: Vec<Vec<f32>> = if let Some(ref r) = rotation_matrix {
238            vectors
239                .iter()
240                .map(|v| Self::apply_rotation(r, v, config.dim))
241                .collect()
242        } else {
243            vectors.to_vec()
244        };
245
246        // Step 3: Train k-means for each subspace
247        let mut centroids = Vec::with_capacity(m * k * sub_dim);
248
249        for subspace_idx in 0..m {
250            let offset = subspace_idx * sub_dim;
251
252            let subdata: Vec<f32> = rotated_vectors
253                .iter()
254                .flat_map(|v| v[offset..offset + sub_dim].iter().copied())
255                .collect();
256
257            let actual_k = k.min(n);
258
259            let data = Array2::from_shape_vec((n, sub_dim), subdata)
260                .expect("Failed to create subspace array");
261            let mut kmeans = KMeans::new(actual_k)
262                .with_euclidean(true)
263                .with_iterations(max_iters);
264            let _ = kmeans
265                .train(data.view(), None)
266                .expect("K-means training failed");
267
268            let subspace_centroids: Vec<f32> = kmeans
269                .centroids()
270                .expect("No centroids")
271                .iter()
272                .copied()
273                .collect();
274
275            centroids.extend(subspace_centroids);
276
277            // Pad if needed
278            while centroids.len() < (subspace_idx + 1) * k * sub_dim {
279                let last_start = centroids.len() - sub_dim;
280                let last: Vec<f32> = centroids[last_start..].to_vec();
281                centroids.extend(last);
282            }
283        }
284
285        // Precompute centroid norms
286        let centroid_norms: Vec<f32> = (0..m * k)
287            .map(|i| {
288                let start = i * sub_dim;
289                if start + sub_dim <= centroids.len() {
290                    centroids[start..start + sub_dim]
291                        .iter()
292                        .map(|x| x * x)
293                        .sum::<f32>()
294                        .sqrt()
295                } else {
296                    0.0
297                }
298            })
299            .collect();
300
301        let version = std::time::SystemTime::now()
302            .duration_since(std::time::UNIX_EPOCH)
303            .unwrap_or_default()
304            .as_millis() as u64;
305
306        Self {
307            config,
308            rotation_matrix,
309            centroids,
310            version,
311            centroid_norms: Some(centroid_norms),
312        }
313    }
314
315    /// Fallback training for non-native builds (WASM)
316    #[cfg(not(feature = "native"))]
317    pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
318        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
319        assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
320
321        let m = config.num_subspaces;
322        let k = config.num_centroids;
323        let sub_dim = config.subspace_dim();
324        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
325
326        let rotation_matrix = None;
327        let mut centroids = Vec::with_capacity(m * k * sub_dim);
328
329        for subspace_idx in 0..m {
330            let offset = subspace_idx * sub_dim;
331            let subvectors: Vec<Vec<f32>> = vectors
332                .iter()
333                .map(|v| v[offset..offset + sub_dim].to_vec())
334                .collect();
335
336            let subspace_centroids =
337                Self::train_subspace_scalar(&subvectors, k, sub_dim, max_iters, &mut rng);
338            centroids.extend(subspace_centroids);
339        }
340
341        let centroid_norms: Vec<f32> = (0..m * k)
342            .map(|i| {
343                let start = i * sub_dim;
344                centroids[start..start + sub_dim]
345                    .iter()
346                    .map(|x| x * x)
347                    .sum::<f32>()
348                    .sqrt()
349            })
350            .collect();
351
352        let version = std::time::SystemTime::now()
353            .duration_since(std::time::UNIX_EPOCH)
354            .unwrap_or_default()
355            .as_millis() as u64;
356
357        Self {
358            config,
359            rotation_matrix,
360            centroids,
361            version,
362            centroid_norms: Some(centroid_norms),
363        }
364    }
365
366    /// Learn OPQ rotation matrix using SVD
367    #[cfg(feature = "native")]
368    fn learn_opq_rotation(config: &PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Vec<f32> {
369        use nalgebra::DMatrix;
370
371        let dim = config.dim;
372        let n = vectors.len();
373
374        let mut rotation = DMatrix::<f32>::identity(dim, dim);
375        let data: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
376        let x = DMatrix::from_row_slice(n, dim, &data);
377
378        for _iter in 0..config.opq_iters.min(max_iters) {
379            let rotated = &x * &rotation;
380            let assignments = Self::compute_pq_assignments(config, &rotated);
381            let reconstructed = Self::reconstruct_from_assignments(config, &rotated, &assignments);
382
383            let xtx_hat = x.transpose() * &reconstructed;
384            let svd = xtx_hat.svd(true, true);
385            if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
386                let new_rotation: DMatrix<f32> = vt.transpose() * u.transpose();
387                rotation = new_rotation;
388            }
389        }
390
391        rotation.iter().copied().collect()
392    }
393
394    #[cfg(feature = "native")]
395    fn compute_pq_assignments(
396        config: &PQConfig,
397        rotated: &nalgebra::DMatrix<f32>,
398    ) -> Vec<Vec<usize>> {
399        use kentro::KMeans;
400        use ndarray::Array2;
401
402        let m = config.num_subspaces;
403        let k = config.num_centroids.min(rotated.nrows());
404        let sub_dim = config.subspace_dim();
405        let n = rotated.nrows();
406
407        let mut all_assignments = vec![vec![0usize; m]; n];
408
409        for subspace_idx in 0..m {
410            let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
411            for row in 0..n {
412                for col in 0..sub_dim {
413                    subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
414                }
415            }
416
417            let data = Array2::from_shape_vec((n, sub_dim), subdata)
418                .expect("Failed to create subspace array");
419            let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
420            let clusters = kmeans
421                .train(data.view(), None)
422                .expect("K-means training failed");
423
424            // Invert cluster assignments: clusters[cluster_id] = [point_indices]
425            for (cluster_id, point_indices) in clusters.iter().enumerate() {
426                for &point_idx in point_indices {
427                    all_assignments[point_idx][subspace_idx] = cluster_id;
428                }
429            }
430        }
431
432        all_assignments
433    }
434
435    #[cfg(feature = "native")]
436    fn reconstruct_from_assignments(
437        config: &PQConfig,
438        rotated: &nalgebra::DMatrix<f32>,
439        assignments: &[Vec<usize>],
440    ) -> nalgebra::DMatrix<f32> {
441        use kentro::KMeans;
442        use ndarray::Array2;
443
444        let m = config.num_subspaces;
445        let sub_dim = config.subspace_dim();
446        let n = rotated.nrows();
447        let dim = config.dim;
448
449        let mut reconstructed = nalgebra::DMatrix::<f32>::zeros(n, dim);
450
451        for subspace_idx in 0..m {
452            let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
453            for row in 0..n {
454                for col in 0..sub_dim {
455                    subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
456                }
457            }
458
459            let k = config.num_centroids.min(n);
460            let data = Array2::from_shape_vec((n, sub_dim), subdata)
461                .expect("Failed to create subspace array");
462            let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
463            let _ = kmeans
464                .train(data.view(), None)
465                .expect("K-means training failed");
466
467            let centroids = kmeans.centroids().expect("No centroids");
468
469            for (row, assignment) in assignments.iter().enumerate() {
470                let centroid_idx = assignment[subspace_idx];
471                if centroid_idx < k {
472                    for col in 0..sub_dim {
473                        reconstructed[(row, subspace_idx * sub_dim + col)] =
474                            centroids[[centroid_idx, col]];
475                    }
476                }
477            }
478        }
479
480        reconstructed
481    }
482
483    /// Apply rotation matrix to vector
484    fn apply_rotation(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
485        let mut result = vec![0.0f32; dim];
486        for i in 0..dim {
487            for j in 0..dim {
488                result[i] += rotation[i * dim + j] * vector[j];
489            }
490        }
491        result
492    }
493
494    /// Scalar k-means for WASM fallback
495    #[cfg(not(feature = "native"))]
496    fn train_subspace_scalar(
497        subvectors: &[Vec<f32>],
498        k: usize,
499        sub_dim: usize,
500        max_iters: usize,
501        rng: &mut impl Rng,
502    ) -> Vec<f32> {
503        let actual_k = k.min(subvectors.len());
504        let mut centroids = Self::kmeans_plusplus_init_scalar(subvectors, actual_k, sub_dim, rng);
505
506        for _ in 0..max_iters {
507            let assignments: Vec<usize> = subvectors
508                .iter()
509                .map(|v| Self::find_nearest_scalar(&centroids, v, sub_dim))
510                .collect();
511
512            let mut new_centroids = vec![0.0f32; actual_k * sub_dim];
513            let mut counts = vec![0usize; actual_k];
514
515            for (subvec, &assignment) in subvectors.iter().zip(assignments.iter()) {
516                counts[assignment] += 1;
517                let offset = assignment * sub_dim;
518                for (j, &val) in subvec.iter().enumerate() {
519                    new_centroids[offset + j] += val;
520                }
521            }
522
523            for (c, &count) in counts.iter().enumerate().take(actual_k) {
524                if count > 0 {
525                    let offset = c * sub_dim;
526                    for j in 0..sub_dim {
527                        new_centroids[offset + j] /= count as f32;
528                    }
529                }
530            }
531
532            centroids = new_centroids;
533        }
534
535        while centroids.len() < k * sub_dim {
536            let last_start = centroids.len() - sub_dim;
537            let last: Vec<f32> = centroids[last_start..].to_vec();
538            centroids.extend(last);
539        }
540
541        centroids
542    }
543
544    #[cfg(not(feature = "native"))]
545    fn kmeans_plusplus_init_scalar(
546        subvectors: &[Vec<f32>],
547        k: usize,
548        sub_dim: usize,
549        rng: &mut impl Rng,
550    ) -> Vec<f32> {
551        let mut centroids = Vec::with_capacity(k * sub_dim);
552        let first_idx = rng.random_range(0..subvectors.len());
553        centroids.extend_from_slice(&subvectors[first_idx]);
554
555        for _ in 1..k {
556            let distances: Vec<f32> = subvectors
557                .iter()
558                .map(|v| Self::min_dist_to_centroids_scalar(&centroids, v, sub_dim))
559                .collect();
560
561            let total: f32 = distances.iter().sum();
562            let mut r = rng.random::<f32>() * total;
563            let mut chosen_idx = 0;
564            for (i, &d) in distances.iter().enumerate() {
565                r -= d;
566                if r <= 0.0 {
567                    chosen_idx = i;
568                    break;
569                }
570            }
571            centroids.extend_from_slice(&subvectors[chosen_idx]);
572        }
573
574        centroids
575    }
576
577    #[cfg(not(feature = "native"))]
578    fn min_dist_to_centroids_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> f32 {
579        let num_centroids = centroids.len() / sub_dim;
580        (0..num_centroids)
581            .map(|c| {
582                let offset = c * sub_dim;
583                vector
584                    .iter()
585                    .zip(&centroids[offset..offset + sub_dim])
586                    .map(|(&a, &b)| (a - b) * (a - b))
587                    .sum()
588            })
589            .fold(f32::MAX, f32::min)
590    }
591
592    #[cfg(not(feature = "native"))]
593    fn find_nearest_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
594        let num_centroids = centroids.len() / sub_dim;
595        (0..num_centroids)
596            .map(|c| {
597                let offset = c * sub_dim;
598                let dist: f32 = vector
599                    .iter()
600                    .zip(&centroids[offset..offset + sub_dim])
601                    .map(|(&a, &b)| (a - b) * (a - b))
602                    .sum();
603                (c, dist)
604            })
605            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
606            .map(|(c, _)| c)
607            .unwrap_or(0)
608    }
609
610    /// Find nearest centroid index
611    fn find_nearest(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
612        let num_centroids = centroids.len() / sub_dim;
613        let mut best_idx = 0;
614        let mut best_dist = f32::MAX;
615
616        for c in 0..num_centroids {
617            let offset = c * sub_dim;
618            let dist: f32 = vector
619                .iter()
620                .zip(&centroids[offset..offset + sub_dim])
621                .map(|(&a, &b)| (a - b) * (a - b))
622                .sum();
623
624            if dist < best_dist {
625                best_dist = dist;
626                best_idx = c;
627            }
628        }
629
630        best_idx
631    }
632
633    /// Encode a vector to PQ codes
634    pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> PQVector {
635        let m = self.config.num_subspaces;
636        let k = self.config.num_centroids;
637        let sub_dim = self.config.subspace_dim();
638
639        // Compute residual if centroid provided
640        let residual: Vec<f32> = if let Some(c) = centroid {
641            vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
642        } else {
643            vector.to_vec()
644        };
645
646        // Apply rotation if present
647        let rotated: Vec<f32>;
648        let vec_to_encode = if let Some(ref r) = self.rotation_matrix {
649            rotated = Self::apply_rotation(r, &residual, self.config.dim);
650            &rotated
651        } else {
652            &residual
653        };
654
655        let mut codes = Vec::with_capacity(m);
656
657        for subspace_idx in 0..m {
658            let vec_offset = subspace_idx * sub_dim;
659            let subvec = &vec_to_encode[vec_offset..vec_offset + sub_dim];
660
661            let centroid_base = subspace_idx * k * sub_dim;
662            let centroids_slice = &self.centroids[centroid_base..centroid_base + k * sub_dim];
663
664            let nearest = Self::find_nearest(centroids_slice, subvec, sub_dim);
665            codes.push(nearest as u8);
666        }
667
668        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
669        PQVector::new(codes, norm)
670    }
671
672    /// Decode PQ codes back to approximate vector
673    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
674        let m = self.config.num_subspaces;
675        let k = self.config.num_centroids;
676        let sub_dim = self.config.subspace_dim();
677
678        let mut rotated_vector = Vec::with_capacity(self.config.dim);
679
680        for (subspace_idx, &code) in codes.iter().enumerate().take(m) {
681            let centroid_base = subspace_idx * k * sub_dim;
682            let centroid_offset = centroid_base + (code as usize) * sub_dim;
683            rotated_vector
684                .extend_from_slice(&self.centroids[centroid_offset..centroid_offset + sub_dim]);
685        }
686
687        // Apply inverse rotation if present
688        if let Some(ref r) = self.rotation_matrix {
689            Self::apply_rotation_transpose(r, &rotated_vector, self.config.dim)
690        } else {
691            rotated_vector
692        }
693    }
694
695    /// Apply transpose of rotation matrix
696    fn apply_rotation_transpose(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
697        let mut result = vec![0.0f32; dim];
698        for i in 0..dim {
699            for j in 0..dim {
700                result[i] += rotation[j * dim + i] * vector[j];
701            }
702        }
703        result
704    }
705
706    /// Get centroid for a specific subspace and code
707    #[inline]
708    pub fn get_centroid(&self, subspace_idx: usize, code: u8) -> &[f32] {
709        let k = self.config.num_centroids;
710        let sub_dim = self.config.subspace_dim();
711        let offset = subspace_idx * k * sub_dim + (code as usize) * sub_dim;
712        &self.centroids[offset..offset + sub_dim]
713    }
714
715    /// Rotate a query vector
716    pub fn rotate_query(&self, query: &[f32]) -> Vec<f32> {
717        if let Some(ref r) = self.rotation_matrix {
718            Self::apply_rotation(r, query, self.config.dim)
719        } else {
720            query.to_vec()
721        }
722    }
723
724    /// Save to binary file
725    pub fn save(&self, path: &std::path::Path) -> io::Result<()> {
726        let mut file = std::fs::File::create(path)?;
727        self.write_to(&mut file)
728    }
729
730    /// Write to any writer
731    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
732        writer.write_u32::<LittleEndian>(CODEBOOK_MAGIC)?;
733        writer.write_u32::<LittleEndian>(2)?;
734        writer.write_u64::<LittleEndian>(self.version)?;
735        writer.write_u32::<LittleEndian>(self.config.dim as u32)?;
736        writer.write_u32::<LittleEndian>(self.config.num_subspaces as u32)?;
737        writer.write_u32::<LittleEndian>(self.config.dims_per_block as u32)?;
738        writer.write_u32::<LittleEndian>(self.config.num_centroids as u32)?;
739        writer.write_u8(if self.config.anisotropic { 1 } else { 0 })?;
740        writer.write_f32::<LittleEndian>(self.config.aniso_eta)?;
741        writer.write_f32::<LittleEndian>(self.config.aniso_threshold)?;
742        writer.write_u8(if self.config.use_opq { 1 } else { 0 })?;
743        writer.write_u32::<LittleEndian>(self.config.opq_iters as u32)?;
744
745        if let Some(ref rotation) = self.rotation_matrix {
746            writer.write_u8(1)?;
747            for &val in rotation {
748                writer.write_f32::<LittleEndian>(val)?;
749            }
750        } else {
751            writer.write_u8(0)?;
752        }
753
754        for &val in &self.centroids {
755            writer.write_f32::<LittleEndian>(val)?;
756        }
757
758        if let Some(ref norms) = self.centroid_norms {
759            writer.write_u8(1)?;
760            for &val in norms {
761                writer.write_f32::<LittleEndian>(val)?;
762            }
763        } else {
764            writer.write_u8(0)?;
765        }
766
767        Ok(())
768    }
769
770    /// Load from binary file
771    pub fn load(path: &std::path::Path) -> io::Result<Self> {
772        let data = std::fs::read(path)?;
773        Self::read_from(&mut std::io::Cursor::new(data))
774    }
775
776    /// Read from any reader
777    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
778        let magic = reader.read_u32::<LittleEndian>()?;
779        if magic != CODEBOOK_MAGIC {
780            return Err(io::Error::new(
781                io::ErrorKind::InvalidData,
782                "Invalid codebook file magic",
783            ));
784        }
785
786        let file_version = reader.read_u32::<LittleEndian>()?;
787        let version = reader.read_u64::<LittleEndian>()?;
788        let dim = reader.read_u32::<LittleEndian>()? as usize;
789        let num_subspaces = reader.read_u32::<LittleEndian>()? as usize;
790
791        let (
792            dims_per_block,
793            num_centroids,
794            anisotropic,
795            aniso_eta,
796            aniso_threshold,
797            use_opq,
798            opq_iters,
799        ) = if file_version >= 2 {
800            let dpb = reader.read_u32::<LittleEndian>()? as usize;
801            let nc = reader.read_u32::<LittleEndian>()? as usize;
802            let aniso = reader.read_u8()? != 0;
803            let eta = reader.read_f32::<LittleEndian>()?;
804            let thresh = reader.read_f32::<LittleEndian>()?;
805            let opq = reader.read_u8()? != 0;
806            let iters = reader.read_u32::<LittleEndian>()? as usize;
807            (dpb, nc, aniso, eta, thresh, opq, iters)
808        } else {
809            let nc = reader.read_u32::<LittleEndian>()? as usize;
810            let aniso = reader.read_u8()? != 0;
811            let thresh = reader.read_f32::<LittleEndian>()?;
812            let dpb = dim / num_subspaces;
813            (dpb, nc, aniso, 10.0, thresh, false, 0)
814        };
815
816        let config = PQConfig {
817            dim,
818            num_subspaces,
819            dims_per_block,
820            num_centroids,
821            seed: 42,
822            anisotropic,
823            aniso_eta,
824            aniso_threshold,
825            use_opq,
826            opq_iters,
827        };
828
829        let rotation_matrix = if file_version >= 2 {
830            let has_rotation = reader.read_u8()? != 0;
831            if has_rotation {
832                let mut rotation = vec![0.0f32; dim * dim];
833                for val in &mut rotation {
834                    *val = reader.read_f32::<LittleEndian>()?;
835                }
836                Some(rotation)
837            } else {
838                None
839            }
840        } else {
841            None
842        };
843
844        let centroid_count = num_subspaces * num_centroids * config.subspace_dim();
845        let mut centroids = vec![0.0f32; centroid_count];
846        for val in &mut centroids {
847            *val = reader.read_f32::<LittleEndian>()?;
848        }
849
850        let has_norms = reader.read_u8()? != 0;
851        let centroid_norms = if has_norms {
852            let mut norms = vec![0.0f32; num_subspaces * num_centroids];
853            for val in &mut norms {
854                *val = reader.read_f32::<LittleEndian>()?;
855            }
856            Some(norms)
857        } else {
858            None
859        };
860
861        Ok(Self {
862            config,
863            rotation_matrix,
864            centroids,
865            version,
866            centroid_norms,
867        })
868    }
869
870    /// Memory usage in bytes
871    pub fn size_bytes(&self) -> usize {
872        let centroids_size = self.centroids.len() * 4;
873        let norms_size = self
874            .centroid_norms
875            .as_ref()
876            .map(|n| n.len() * 4)
877            .unwrap_or(0);
878        let rotation_size = self
879            .rotation_matrix
880            .as_ref()
881            .map(|r| r.len() * 4)
882            .unwrap_or(0);
883        centroids_size + norms_size + rotation_size + 64
884    }
885}
886
887/// Precomputed distance table for fast asymmetric distance computation
888#[derive(Debug, Clone)]
889pub struct DistanceTable {
890    /// M × K table of squared distances
891    pub distances: Vec<f32>,
892    /// Number of subspaces
893    pub num_subspaces: usize,
894    /// Number of centroids per subspace
895    pub num_centroids: usize,
896}
897
898impl DistanceTable {
899    /// Build distance table for a query vector
900    pub fn build(codebook: &PQCodebook, query: &[f32], centroid: Option<&[f32]>) -> Self {
901        let m = codebook.config.num_subspaces;
902        let k = codebook.config.num_centroids;
903        let sub_dim = codebook.config.subspace_dim();
904
905        // Compute residual if centroid provided
906        let residual: Vec<f32> = if let Some(c) = centroid {
907            query.iter().zip(c).map(|(&v, &c)| v - c).collect()
908        } else {
909            query.to_vec()
910        };
911
912        // Apply rotation if present
913        let rotated_query = codebook.rotate_query(&residual);
914
915        let mut distances = Vec::with_capacity(m * k);
916
917        for subspace_idx in 0..m {
918            let query_offset = subspace_idx * sub_dim;
919            let query_sub = &rotated_query[query_offset..query_offset + sub_dim];
920
921            let centroid_base = subspace_idx * k * sub_dim;
922
923            for centroid_idx in 0..k {
924                let centroid_offset = centroid_base + centroid_idx * sub_dim;
925                let centroid = &codebook.centroids[centroid_offset..centroid_offset + sub_dim];
926
927                let dist: f32 = query_sub
928                    .iter()
929                    .zip(centroid.iter())
930                    .map(|(&a, &b)| (a - b) * (a - b))
931                    .sum();
932
933                distances.push(dist);
934            }
935        }
936
937        Self {
938            distances,
939            num_subspaces: m,
940            num_centroids: k,
941        }
942    }
943
944    /// Compute approximate distance using PQ codes
945    #[inline]
946    pub fn compute_distance(&self, codes: &[u8]) -> f32 {
947        let k = self.num_centroids;
948        let mut total = 0.0f32;
949
950        for (subspace_idx, &code) in codes.iter().enumerate() {
951            let table_offset = subspace_idx * k + code as usize;
952            total += self.distances[table_offset];
953        }
954
955        total
956    }
957}
958
959impl Quantizer for PQCodebook {
960    type Code = PQVector;
961    type Config = PQConfig;
962    type QueryData = DistanceTable;
963
964    fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
965        self.encode(vector, centroid)
966    }
967
968    fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
969        DistanceTable::build(self, query, centroid)
970    }
971
972    fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
973        query_data.compute_distance(&code.codes)
974    }
975
976    fn decode(&self, code: &Self::Code) -> Option<Vec<f32>> {
977        Some(self.decode(&code.codes))
978    }
979
980    fn size_bytes(&self) -> usize {
981        self.size_bytes()
982    }
983}
984
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[test]
    fn test_pq_config() {
        let config = PQConfig::new(128);
        assert_eq!(config.dim, 128);
        assert_eq!(config.dims_per_block, 2);
        assert_eq!(config.num_subspaces, 64);
    }

    #[test]
    fn test_pq_encode_decode() {
        let dim = 32;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 10);

        let test_vec: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        let code = codebook.encode(&test_vec, None);

        assert_eq!(code.codes.len(), 16); // 32 dims / 2 dims_per_block

        // Exercise the decode half of the round trip as well.
        let decoded = codebook.decode(&code.codes);
        assert_eq!(decoded.len(), dim);
        assert!(decoded.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_distance_table() {
        let dim = 16;
        let config = PQConfig::new(dim).with_opq(false, 0);

        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect();

        let codebook = PQCodebook::train(config, &vectors, 5);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let table = DistanceTable::build(&codebook, &query, None);

        // The table is one row of K entries per subspace.
        assert_eq!(table.num_subspaces, codebook.config.num_subspaces);
        assert_eq!(table.num_centroids, codebook.config.num_centroids);
        assert_eq!(
            table.distances.len(),
            table.num_subspaces * table.num_centroids
        );

        let code = codebook.encode(&vectors[0], None);
        let dist = table.compute_distance(&code.codes);

        assert!(dist >= 0.0);

        // With OPQ disabled, the ADC distance must equal the exact squared L2
        // distance between the query and the decoded reconstruction.
        let decoded = codebook.decode(&code.codes);
        let exact: f32 = query
            .iter()
            .zip(&decoded)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            (dist - exact).abs() <= 1e-4 * exact.max(1.0),
            "ADC distance {dist} disagrees with exact distance {exact}"
        );
    }
}