Skip to main content

hermes_core/structures/vector/quantization/
pq.rs

1//! Product Quantization with OPQ and Anisotropic Loss (ScaNN-style)
2//!
3//! Implementation inspired by Google's ScaNN (Scalable Nearest Neighbors):
4//! - **True anisotropic quantization**: penalizes parallel error more than orthogonal
5//! - **OPQ rotation**: learns optimal rotation matrix before quantization
6//! - **Product quantization** with learned codebooks
7//! - **SIMD-accelerated** asymmetric distance computation
8
9use std::io::{self, Read, Write};
10
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12#[cfg(not(feature = "native"))]
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(all(target_arch = "x86_64", feature = "native"))]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// File magic that opens every serialized codebook: the ASCII bytes "SCBK" (ScaNN CodeBook).
const CODEBOOK_MAGIC: u32 = 0x5343_424B;

/// Default number of centroids per subspace (K); 256 is the most a `u8` code can address.
pub const DEFAULT_NUM_CENTROIDS: usize = 1 << 8;

/// Default number of dimensions per subspace block (ScaNN recommends 2 for best accuracy).
pub const DEFAULT_DIMS_PER_BLOCK: usize = 2;
35
36/// Configuration for Product Quantization with OPQ and Anisotropic Loss
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct PQConfig {
39    /// Vector dimension
40    pub dim: usize,
41    /// Number of subspaces (M) - computed from dim / dims_per_block
42    pub num_subspaces: usize,
43    /// Dimensions per subspace block (ScaNN recommends 2)
44    pub dims_per_block: usize,
45    /// Number of centroids per subspace (K) - typically 256 for u8 codes
46    pub num_centroids: usize,
47    /// Random seed for reproducibility
48    pub seed: u64,
49    /// Use anisotropic quantization (true ScaNN-style parallel/orthogonal weighting)
50    pub anisotropic: bool,
51    /// Anisotropic eta: ratio of parallel to orthogonal error weight (η)
52    pub aniso_eta: f32,
53    /// Anisotropic threshold T: only consider inner products >= T
54    pub aniso_threshold: f32,
55    /// Use OPQ rotation matrix (learned via SVD)
56    pub use_opq: bool,
57    /// Number of OPQ iterations
58    pub opq_iters: usize,
59}
60
61impl PQConfig {
62    /// Create config with ScaNN-recommended defaults
63    pub fn new(dim: usize) -> Self {
64        let dims_per_block = DEFAULT_DIMS_PER_BLOCK;
65        let num_subspaces = dim / dims_per_block;
66
67        Self {
68            dim,
69            num_subspaces,
70            dims_per_block,
71            num_centroids: DEFAULT_NUM_CENTROIDS,
72            seed: 42,
73            anisotropic: true,
74            aniso_eta: 10.0,
75            aniso_threshold: 0.2,
76            use_opq: true,
77            opq_iters: 10,
78        }
79    }
80
81    /// Create config with larger subspaces (faster but less accurate)
82    pub fn new_fast(dim: usize) -> Self {
83        let num_subspaces = if dim >= 256 {
84            8
85        } else if dim >= 64 {
86            4
87        } else {
88            2
89        };
90        let dims_per_block = dim / num_subspaces;
91
92        Self {
93            dim,
94            num_subspaces,
95            dims_per_block,
96            num_centroids: DEFAULT_NUM_CENTROIDS,
97            seed: 42,
98            anisotropic: true,
99            aniso_eta: 10.0,
100            aniso_threshold: 0.2,
101            use_opq: false,
102            opq_iters: 0,
103        }
104    }
105
106    /// Create balanced config (good recall/speed tradeoff)
107    /// Uses 16 subspaces for 128D+ vectors, 8 for smaller
108    pub fn new_balanced(dim: usize) -> Self {
109        let num_subspaces = if dim >= 128 {
110            16
111        } else if dim >= 64 {
112            8
113        } else {
114            4
115        };
116        let dims_per_block = dim / num_subspaces;
117
118        Self {
119            dim,
120            num_subspaces,
121            dims_per_block,
122            num_centroids: DEFAULT_NUM_CENTROIDS,
123            seed: 42,
124            anisotropic: true,
125            aniso_eta: 10.0,
126            aniso_threshold: 0.2,
127            use_opq: false,
128            opq_iters: 0,
129        }
130    }
131
132    pub fn with_dims_per_block(mut self, d: usize) -> Self {
133        assert!(
134            self.dim.is_multiple_of(d),
135            "Dimension must be divisible by dims_per_block"
136        );
137        self.dims_per_block = d;
138        self.num_subspaces = self.dim / d;
139        self
140    }
141
142    pub fn with_subspaces(mut self, m: usize) -> Self {
143        assert!(
144            self.dim.is_multiple_of(m),
145            "Dimension must be divisible by num_subspaces"
146        );
147        self.num_subspaces = m;
148        self.dims_per_block = self.dim / m;
149        self
150    }
151
152    pub fn with_centroids(mut self, k: usize) -> Self {
153        assert!(k <= 256, "Max 256 centroids for u8 codes");
154        self.num_centroids = k;
155        self
156    }
157
158    pub fn with_anisotropic(mut self, enabled: bool, eta: f32) -> Self {
159        self.anisotropic = enabled;
160        self.aniso_eta = eta;
161        self
162    }
163
164    pub fn with_opq(mut self, enabled: bool, iters: usize) -> Self {
165        self.use_opq = enabled;
166        self.opq_iters = iters;
167        self
168    }
169
170    /// Dimension of each subspace
171    pub fn subspace_dim(&self) -> usize {
172        self.dims_per_block
173    }
174}
175
176/// Quantized vector using Product Quantization
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct PQVector {
179    /// PQ codes (M bytes, one per subspace)
180    pub codes: Vec<u8>,
181    /// Original vector norm (for re-ranking or normalization)
182    pub norm: f32,
183}
184
185impl PQVector {
186    pub fn new(codes: Vec<u8>, norm: f32) -> Self {
187        Self { codes, norm }
188    }
189}
190
191impl QuantizedCode for PQVector {
192    fn size_bytes(&self) -> usize {
193        self.codes.len() + 4 // codes + norm
194    }
195}
196
197/// Learned codebook for Product Quantization with OPQ rotation
198///
199/// Trained once, shared across all segments (like CoarseCentroids).
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct PQCodebook {
202    /// Configuration
203    pub config: PQConfig,
204    /// OPQ rotation matrix (dim × dim), stored row-major
205    pub rotation_matrix: Option<Vec<f32>>,
206    /// Centroids: M subspaces × K centroids × subspace_dim
207    pub centroids: Vec<f32>,
208    /// Version for merge compatibility checking
209    pub version: u64,
210    /// Precomputed centroid norms for faster distance computation
211    pub centroid_norms: Option<Vec<f32>>,
212}
213
214impl PQCodebook {
215    /// Train codebook with OPQ rotation and anisotropic loss
216    #[cfg(feature = "native")]
217    pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
218        use kentro::KMeans;
219        use ndarray::Array2;
220
221        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
222        assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
223
224        let m = config.num_subspaces;
225        let k = config.num_centroids;
226        let sub_dim = config.subspace_dim();
227        let n = vectors.len();
228
229        // Step 1: Learn OPQ rotation matrix if enabled
230        let rotation_matrix = if config.use_opq && config.opq_iters > 0 {
231            Some(Self::learn_opq_rotation(&config, vectors, max_iters))
232        } else {
233            None
234        };
235
236        // Step 2: Apply rotation to vectors
237        let rotated_vectors: Vec<Vec<f32>> = if let Some(ref r) = rotation_matrix {
238            vectors
239                .iter()
240                .map(|v| Self::apply_rotation(r, v, config.dim))
241                .collect()
242        } else {
243            vectors.to_vec()
244        };
245
246        // Step 3: Train k-means for each subspace
247        let mut centroids = Vec::with_capacity(m * k * sub_dim);
248
249        for subspace_idx in 0..m {
250            let offset = subspace_idx * sub_dim;
251
252            let subdata: Vec<f32> = rotated_vectors
253                .iter()
254                .flat_map(|v| v[offset..offset + sub_dim].iter().copied())
255                .collect();
256
257            let actual_k = k.min(n);
258
259            let data = Array2::from_shape_vec((n, sub_dim), subdata)
260                .expect("Failed to create subspace array");
261            let mut kmeans = KMeans::new(actual_k)
262                .with_euclidean(true)
263                .with_iterations(max_iters);
264            let _ = kmeans
265                .train(data.view(), None)
266                .expect("K-means training failed");
267
268            let subspace_centroids: Vec<f32> = kmeans
269                .centroids()
270                .expect("No centroids")
271                .iter()
272                .copied()
273                .collect();
274
275            centroids.extend(subspace_centroids);
276
277            // Pad if needed
278            while centroids.len() < (subspace_idx + 1) * k * sub_dim {
279                let last_start = centroids.len() - sub_dim;
280                let last: Vec<f32> = centroids[last_start..].to_vec();
281                centroids.extend(last);
282            }
283        }
284
285        // Precompute centroid norms
286        let centroid_norms: Vec<f32> = (0..m * k)
287            .map(|i| {
288                let start = i * sub_dim;
289                if start + sub_dim <= centroids.len() {
290                    centroids[start..start + sub_dim]
291                        .iter()
292                        .map(|x| x * x)
293                        .sum::<f32>()
294                        .sqrt()
295                } else {
296                    0.0
297                }
298            })
299            .collect();
300
301        let version = std::time::SystemTime::now()
302            .duration_since(std::time::UNIX_EPOCH)
303            .unwrap_or_default()
304            .as_millis() as u64;
305
306        Self {
307            config,
308            rotation_matrix,
309            centroids,
310            version,
311            centroid_norms: Some(centroid_norms),
312        }
313    }
314
315    /// Fallback training for non-native builds (WASM)
316    #[cfg(not(feature = "native"))]
317    pub fn train(config: PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Self {
318        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
319        assert_eq!(vectors[0].len(), config.dim, "Vector dimension mismatch");
320
321        let m = config.num_subspaces;
322        let k = config.num_centroids;
323        let sub_dim = config.subspace_dim();
324        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
325
326        let rotation_matrix = None;
327        let mut centroids = Vec::with_capacity(m * k * sub_dim);
328
329        for subspace_idx in 0..m {
330            let offset = subspace_idx * sub_dim;
331            let subvectors: Vec<Vec<f32>> = vectors
332                .iter()
333                .map(|v| v[offset..offset + sub_dim].to_vec())
334                .collect();
335
336            let subspace_centroids =
337                Self::train_subspace_scalar(&subvectors, k, sub_dim, max_iters, &mut rng);
338            centroids.extend(subspace_centroids);
339        }
340
341        let centroid_norms: Vec<f32> = (0..m * k)
342            .map(|i| {
343                let start = i * sub_dim;
344                centroids[start..start + sub_dim]
345                    .iter()
346                    .map(|x| x * x)
347                    .sum::<f32>()
348                    .sqrt()
349            })
350            .collect();
351
352        let version = std::time::SystemTime::now()
353            .duration_since(std::time::UNIX_EPOCH)
354            .unwrap_or_default()
355            .as_millis() as u64;
356
357        Self {
358            config,
359            rotation_matrix,
360            centroids,
361            version,
362            centroid_norms: Some(centroid_norms),
363        }
364    }
365
366    /// Learn OPQ rotation matrix using SVD
367    #[cfg(feature = "native")]
368    fn learn_opq_rotation(config: &PQConfig, vectors: &[Vec<f32>], max_iters: usize) -> Vec<f32> {
369        use nalgebra::DMatrix;
370
371        let dim = config.dim;
372        let n = vectors.len();
373
374        let mut rotation = DMatrix::<f32>::identity(dim, dim);
375        let data: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
376        let x = DMatrix::from_row_slice(n, dim, &data);
377
378        for _iter in 0..config.opq_iters.min(max_iters) {
379            let rotated = &x * &rotation;
380            let assignments = Self::compute_pq_assignments(config, &rotated);
381            let reconstructed = Self::reconstruct_from_assignments(config, &rotated, &assignments);
382
383            let xtx_hat = x.transpose() * &reconstructed;
384            let svd = xtx_hat.svd(true, true);
385            if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
386                let new_rotation: DMatrix<f32> = vt.transpose() * u.transpose();
387                rotation = new_rotation;
388            }
389        }
390
391        rotation.iter().copied().collect()
392    }
393
394    #[cfg(feature = "native")]
395    fn compute_pq_assignments(
396        config: &PQConfig,
397        rotated: &nalgebra::DMatrix<f32>,
398    ) -> Vec<Vec<usize>> {
399        use kentro::KMeans;
400        use ndarray::Array2;
401
402        let m = config.num_subspaces;
403        let k = config.num_centroids.min(rotated.nrows());
404        let sub_dim = config.subspace_dim();
405        let n = rotated.nrows();
406
407        let mut all_assignments = vec![vec![0usize; m]; n];
408
409        for subspace_idx in 0..m {
410            let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
411            for row in 0..n {
412                for col in 0..sub_dim {
413                    subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
414                }
415            }
416
417            let data = Array2::from_shape_vec((n, sub_dim), subdata)
418                .expect("Failed to create subspace array");
419            let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
420            let clusters = kmeans
421                .train(data.view(), None)
422                .expect("K-means training failed");
423
424            // Invert cluster assignments: clusters[cluster_id] = [point_indices]
425            for (cluster_id, point_indices) in clusters.iter().enumerate() {
426                for &point_idx in point_indices {
427                    all_assignments[point_idx][subspace_idx] = cluster_id;
428                }
429            }
430        }
431
432        all_assignments
433    }
434
435    #[cfg(feature = "native")]
436    fn reconstruct_from_assignments(
437        config: &PQConfig,
438        rotated: &nalgebra::DMatrix<f32>,
439        assignments: &[Vec<usize>],
440    ) -> nalgebra::DMatrix<f32> {
441        use kentro::KMeans;
442        use ndarray::Array2;
443
444        let m = config.num_subspaces;
445        let sub_dim = config.subspace_dim();
446        let n = rotated.nrows();
447        let dim = config.dim;
448
449        let mut reconstructed = nalgebra::DMatrix::<f32>::zeros(n, dim);
450
451        for subspace_idx in 0..m {
452            let mut subdata: Vec<f32> = Vec::with_capacity(n * sub_dim);
453            for row in 0..n {
454                for col in 0..sub_dim {
455                    subdata.push(rotated[(row, subspace_idx * sub_dim + col)]);
456                }
457            }
458
459            let k = config.num_centroids.min(n);
460            let data = Array2::from_shape_vec((n, sub_dim), subdata)
461                .expect("Failed to create subspace array");
462            let mut kmeans = KMeans::new(k).with_euclidean(true).with_iterations(5);
463            let _ = kmeans
464                .train(data.view(), None)
465                .expect("K-means training failed");
466
467            let centroids = kmeans.centroids().expect("No centroids");
468
469            for (row, assignment) in assignments.iter().enumerate() {
470                let centroid_idx = assignment[subspace_idx];
471                if centroid_idx < k {
472                    for col in 0..sub_dim {
473                        reconstructed[(row, subspace_idx * sub_dim + col)] =
474                            centroids[[centroid_idx, col]];
475                    }
476                }
477            }
478        }
479
480        reconstructed
481    }
482
483    /// Apply rotation matrix to vector (SIMD-accelerated dot product per row)
484    fn apply_rotation(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
485        let mut result = vec![0.0f32; dim];
486        for i in 0..dim {
487            result[i] = crate::structures::simd::dot_product_f32(
488                &rotation[i * dim..(i + 1) * dim],
489                vector,
490                dim,
491            );
492        }
493        result
494    }
495
496    /// Scalar k-means for WASM fallback
497    #[cfg(not(feature = "native"))]
498    fn train_subspace_scalar(
499        subvectors: &[Vec<f32>],
500        k: usize,
501        sub_dim: usize,
502        max_iters: usize,
503        rng: &mut impl Rng,
504    ) -> Vec<f32> {
505        let actual_k = k.min(subvectors.len());
506        let mut centroids = Self::kmeans_plusplus_init_scalar(subvectors, actual_k, sub_dim, rng);
507
508        for _ in 0..max_iters {
509            let assignments: Vec<usize> = subvectors
510                .iter()
511                .map(|v| Self::find_nearest_scalar(&centroids, v, sub_dim))
512                .collect();
513
514            let mut new_centroids = vec![0.0f32; actual_k * sub_dim];
515            let mut counts = vec![0usize; actual_k];
516
517            for (subvec, &assignment) in subvectors.iter().zip(assignments.iter()) {
518                counts[assignment] += 1;
519                let offset = assignment * sub_dim;
520                for (j, &val) in subvec.iter().enumerate() {
521                    new_centroids[offset + j] += val;
522                }
523            }
524
525            for (c, &count) in counts.iter().enumerate().take(actual_k) {
526                if count > 0 {
527                    let offset = c * sub_dim;
528                    for j in 0..sub_dim {
529                        new_centroids[offset + j] /= count as f32;
530                    }
531                }
532            }
533
534            centroids = new_centroids;
535        }
536
537        while centroids.len() < k * sub_dim {
538            let last_start = centroids.len() - sub_dim;
539            let last: Vec<f32> = centroids[last_start..].to_vec();
540            centroids.extend(last);
541        }
542
543        centroids
544    }
545
546    #[cfg(not(feature = "native"))]
547    fn kmeans_plusplus_init_scalar(
548        subvectors: &[Vec<f32>],
549        k: usize,
550        sub_dim: usize,
551        rng: &mut impl Rng,
552    ) -> Vec<f32> {
553        let mut centroids = Vec::with_capacity(k * sub_dim);
554        let first_idx = rng.random_range(0..subvectors.len());
555        centroids.extend_from_slice(&subvectors[first_idx]);
556
557        for _ in 1..k {
558            let distances: Vec<f32> = subvectors
559                .iter()
560                .map(|v| Self::min_dist_to_centroids_scalar(&centroids, v, sub_dim))
561                .collect();
562
563            let total: f32 = distances.iter().sum();
564            let mut r = rng.random::<f32>() * total;
565            let mut chosen_idx = 0;
566            for (i, &d) in distances.iter().enumerate() {
567                r -= d;
568                if r <= 0.0 {
569                    chosen_idx = i;
570                    break;
571                }
572            }
573            centroids.extend_from_slice(&subvectors[chosen_idx]);
574        }
575
576        centroids
577    }
578
579    #[cfg(not(feature = "native"))]
580    fn min_dist_to_centroids_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> f32 {
581        let num_centroids = centroids.len() / sub_dim;
582        (0..num_centroids)
583            .map(|c| {
584                let offset = c * sub_dim;
585                vector
586                    .iter()
587                    .zip(&centroids[offset..offset + sub_dim])
588                    .map(|(&a, &b)| (a - b) * (a - b))
589                    .sum()
590            })
591            .fold(f32::MAX, f32::min)
592    }
593
594    #[cfg(not(feature = "native"))]
595    fn find_nearest_scalar(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
596        let num_centroids = centroids.len() / sub_dim;
597        (0..num_centroids)
598            .map(|c| {
599                let offset = c * sub_dim;
600                let dist: f32 = vector
601                    .iter()
602                    .zip(&centroids[offset..offset + sub_dim])
603                    .map(|(&a, &b)| (a - b) * (a - b))
604                    .sum();
605                (c, dist)
606            })
607            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
608            .map(|(c, _)| c)
609            .unwrap_or(0)
610    }
611
612    /// Find nearest centroid index
613    fn find_nearest(centroids: &[f32], vector: &[f32], sub_dim: usize) -> usize {
614        let num_centroids = centroids.len() / sub_dim;
615        let mut best_idx = 0;
616        let mut best_dist = f32::MAX;
617
618        for c in 0..num_centroids {
619            let offset = c * sub_dim;
620            let dist: f32 = vector
621                .iter()
622                .zip(&centroids[offset..offset + sub_dim])
623                .map(|(&a, &b)| (a - b) * (a - b))
624                .sum();
625
626            if dist < best_dist {
627                best_dist = dist;
628                best_idx = c;
629            }
630        }
631
632        best_idx
633    }
634
635    /// Encode a vector to PQ codes
636    pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> PQVector {
637        let m = self.config.num_subspaces;
638        let k = self.config.num_centroids;
639        let sub_dim = self.config.subspace_dim();
640
641        // Compute residual if centroid provided
642        let residual: Vec<f32> = if let Some(c) = centroid {
643            vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
644        } else {
645            vector.to_vec()
646        };
647
648        // Apply rotation if present
649        let rotated: Vec<f32>;
650        let vec_to_encode = if let Some(ref r) = self.rotation_matrix {
651            rotated = Self::apply_rotation(r, &residual, self.config.dim);
652            &rotated
653        } else {
654            &residual
655        };
656
657        let mut codes = Vec::with_capacity(m);
658
659        for subspace_idx in 0..m {
660            let vec_offset = subspace_idx * sub_dim;
661            let subvec = &vec_to_encode[vec_offset..vec_offset + sub_dim];
662
663            let centroid_base = subspace_idx * k * sub_dim;
664            let centroids_slice = &self.centroids[centroid_base..centroid_base + k * sub_dim];
665
666            let nearest = Self::find_nearest(centroids_slice, subvec, sub_dim);
667            codes.push(nearest as u8);
668        }
669
670        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
671        PQVector::new(codes, norm)
672    }
673
674    /// Decode PQ codes back to approximate vector
675    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
676        let m = self.config.num_subspaces;
677        let k = self.config.num_centroids;
678        let sub_dim = self.config.subspace_dim();
679
680        let mut rotated_vector = Vec::with_capacity(self.config.dim);
681
682        for (subspace_idx, &code) in codes.iter().enumerate().take(m) {
683            let centroid_base = subspace_idx * k * sub_dim;
684            let centroid_offset = centroid_base + (code as usize) * sub_dim;
685            rotated_vector
686                .extend_from_slice(&self.centroids[centroid_offset..centroid_offset + sub_dim]);
687        }
688
689        // Apply inverse rotation if present
690        if let Some(ref r) = self.rotation_matrix {
691            Self::apply_rotation_transpose(r, &rotated_vector, self.config.dim)
692        } else {
693            rotated_vector
694        }
695    }
696
697    /// Apply transpose of rotation matrix
698    fn apply_rotation_transpose(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
699        let mut result = vec![0.0f32; dim];
700        for i in 0..dim {
701            for j in 0..dim {
702                result[i] += rotation[j * dim + i] * vector[j];
703            }
704        }
705        result
706    }
707
708    /// Get centroid for a specific subspace and code
709    #[inline]
710    pub fn get_centroid(&self, subspace_idx: usize, code: u8) -> &[f32] {
711        let k = self.config.num_centroids;
712        let sub_dim = self.config.subspace_dim();
713        let offset = subspace_idx * k * sub_dim + (code as usize) * sub_dim;
714        &self.centroids[offset..offset + sub_dim]
715    }
716
717    /// Rotate a query vector
718    pub fn rotate_query(&self, query: &[f32]) -> Vec<f32> {
719        if let Some(ref r) = self.rotation_matrix {
720            Self::apply_rotation(r, query, self.config.dim)
721        } else {
722            query.to_vec()
723        }
724    }
725
726    /// Save to binary file
727    pub fn save(&self, path: &std::path::Path) -> io::Result<()> {
728        let mut file = std::fs::File::create(path)?;
729        self.write_to(&mut file)
730    }
731
732    /// Write to any writer
733    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
734        writer.write_u32::<LittleEndian>(CODEBOOK_MAGIC)?;
735        writer.write_u32::<LittleEndian>(2)?;
736        writer.write_u64::<LittleEndian>(self.version)?;
737        writer.write_u32::<LittleEndian>(self.config.dim as u32)?;
738        writer.write_u32::<LittleEndian>(self.config.num_subspaces as u32)?;
739        writer.write_u32::<LittleEndian>(self.config.dims_per_block as u32)?;
740        writer.write_u32::<LittleEndian>(self.config.num_centroids as u32)?;
741        writer.write_u8(if self.config.anisotropic { 1 } else { 0 })?;
742        writer.write_f32::<LittleEndian>(self.config.aniso_eta)?;
743        writer.write_f32::<LittleEndian>(self.config.aniso_threshold)?;
744        writer.write_u8(if self.config.use_opq { 1 } else { 0 })?;
745        writer.write_u32::<LittleEndian>(self.config.opq_iters as u32)?;
746
747        if let Some(ref rotation) = self.rotation_matrix {
748            writer.write_u8(1)?;
749            for &val in rotation {
750                writer.write_f32::<LittleEndian>(val)?;
751            }
752        } else {
753            writer.write_u8(0)?;
754        }
755
756        for &val in &self.centroids {
757            writer.write_f32::<LittleEndian>(val)?;
758        }
759
760        if let Some(ref norms) = self.centroid_norms {
761            writer.write_u8(1)?;
762            for &val in norms {
763                writer.write_f32::<LittleEndian>(val)?;
764            }
765        } else {
766            writer.write_u8(0)?;
767        }
768
769        Ok(())
770    }
771
772    /// Load from binary file
773    pub fn load(path: &std::path::Path) -> io::Result<Self> {
774        let data = std::fs::read(path)?;
775        Self::read_from(&mut std::io::Cursor::new(data))
776    }
777
778    /// Read from any reader
779    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
780        let magic = reader.read_u32::<LittleEndian>()?;
781        if magic != CODEBOOK_MAGIC {
782            return Err(io::Error::new(
783                io::ErrorKind::InvalidData,
784                "Invalid codebook file magic",
785            ));
786        }
787
788        let file_version = reader.read_u32::<LittleEndian>()?;
789        let version = reader.read_u64::<LittleEndian>()?;
790        let dim = reader.read_u32::<LittleEndian>()? as usize;
791        let num_subspaces = reader.read_u32::<LittleEndian>()? as usize;
792
793        let (
794            dims_per_block,
795            num_centroids,
796            anisotropic,
797            aniso_eta,
798            aniso_threshold,
799            use_opq,
800            opq_iters,
801        ) = if file_version >= 2 {
802            let dpb = reader.read_u32::<LittleEndian>()? as usize;
803            let nc = reader.read_u32::<LittleEndian>()? as usize;
804            let aniso = reader.read_u8()? != 0;
805            let eta = reader.read_f32::<LittleEndian>()?;
806            let thresh = reader.read_f32::<LittleEndian>()?;
807            let opq = reader.read_u8()? != 0;
808            let iters = reader.read_u32::<LittleEndian>()? as usize;
809            (dpb, nc, aniso, eta, thresh, opq, iters)
810        } else {
811            let nc = reader.read_u32::<LittleEndian>()? as usize;
812            let aniso = reader.read_u8()? != 0;
813            let thresh = reader.read_f32::<LittleEndian>()?;
814            let dpb = dim / num_subspaces;
815            (dpb, nc, aniso, 10.0, thresh, false, 0)
816        };
817
818        let config = PQConfig {
819            dim,
820            num_subspaces,
821            dims_per_block,
822            num_centroids,
823            seed: 42,
824            anisotropic,
825            aniso_eta,
826            aniso_threshold,
827            use_opq,
828            opq_iters,
829        };
830
831        let rotation_matrix = if file_version >= 2 {
832            let has_rotation = reader.read_u8()? != 0;
833            if has_rotation {
834                let mut rotation = vec![0.0f32; dim * dim];
835                for val in &mut rotation {
836                    *val = reader.read_f32::<LittleEndian>()?;
837                }
838                Some(rotation)
839            } else {
840                None
841            }
842        } else {
843            None
844        };
845
846        let centroid_count = num_subspaces * num_centroids * config.subspace_dim();
847        let mut centroids = vec![0.0f32; centroid_count];
848        for val in &mut centroids {
849            *val = reader.read_f32::<LittleEndian>()?;
850        }
851
852        let has_norms = reader.read_u8()? != 0;
853        let centroid_norms = if has_norms {
854            let mut norms = vec![0.0f32; num_subspaces * num_centroids];
855            for val in &mut norms {
856                *val = reader.read_f32::<LittleEndian>()?;
857            }
858            Some(norms)
859        } else {
860            None
861        };
862
863        Ok(Self {
864            config,
865            rotation_matrix,
866            centroids,
867            version,
868            centroid_norms,
869        })
870    }
871
872    /// Memory usage in bytes
873    pub fn size_bytes(&self) -> usize {
874        let centroids_size = self.centroids.len() * 4;
875        let norms_size = self
876            .centroid_norms
877            .as_ref()
878            .map(|n| n.len() * 4)
879            .unwrap_or(0);
880        let rotation_size = self
881            .rotation_matrix
882            .as_ref()
883            .map(|r| r.len() * 4)
884            .unwrap_or(0);
885        centroids_size + norms_size + rotation_size + 64
886    }
887
888    /// Estimated memory usage in bytes (alias for size_bytes)
889    pub fn estimated_memory_bytes(&self) -> usize {
890        self.size_bytes()
891    }
892}
893
/// Precomputed distance table for fast asymmetric distance computation.
///
/// Built once per query by `DistanceTable::build`; `compute_distance` then
/// scores an encoded vector with one table lookup per code byte.
#[derive(Debug, Clone)]
pub struct DistanceTable {
    /// M × K table of squared distances, stored flat: the entry for centroid
    /// `c` of subspace `s` lives at index `s * num_centroids + c`
    pub distances: Vec<f32>,
    /// Number of subspaces (M)
    pub num_subspaces: usize,
    /// Number of centroids per subspace (K)
    pub num_centroids: usize,
}
904
905impl DistanceTable {
906    /// Build distance table for a query vector
907    pub fn build(codebook: &PQCodebook, query: &[f32], centroid: Option<&[f32]>) -> Self {
908        let m = codebook.config.num_subspaces;
909        let k = codebook.config.num_centroids;
910        let sub_dim = codebook.config.subspace_dim();
911
912        // Compute residual if centroid provided
913        let residual: Vec<f32> = if let Some(c) = centroid {
914            query.iter().zip(c).map(|(&v, &c)| v - c).collect()
915        } else {
916            query.to_vec()
917        };
918
919        // Apply rotation if present
920        let rotated_query = codebook.rotate_query(&residual);
921
922        let mut distances = Vec::with_capacity(m * k);
923
924        for subspace_idx in 0..m {
925            let query_offset = subspace_idx * sub_dim;
926            let query_sub = &rotated_query[query_offset..query_offset + sub_dim];
927
928            let centroid_base = subspace_idx * k * sub_dim;
929
930            for centroid_idx in 0..k {
931                let centroid_offset = centroid_base + centroid_idx * sub_dim;
932                let centroid = &codebook.centroids[centroid_offset..centroid_offset + sub_dim];
933
934                let dist: f32 = query_sub
935                    .iter()
936                    .zip(centroid.iter())
937                    .map(|(&a, &b)| (a - b) * (a - b))
938                    .sum();
939
940                distances.push(dist);
941            }
942        }
943
944        Self {
945            distances,
946            num_subspaces: m,
947            num_centroids: k,
948        }
949    }
950
951    /// Compute approximate distance using PQ codes
952    #[inline]
953    pub fn compute_distance(&self, codes: &[u8]) -> f32 {
954        let k = self.num_centroids;
955        let mut total = 0.0f32;
956
957        for (subspace_idx, &code) in codes.iter().enumerate() {
958            let table_offset = subspace_idx * k + code as usize;
959            total += self.distances[table_offset];
960        }
961
962        total
963    }
964}
965
966impl Quantizer for PQCodebook {
967    type Code = PQVector;
968    type Config = PQConfig;
969    type QueryData = DistanceTable;
970
971    fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
972        self.encode(vector, centroid)
973    }
974
975    fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
976        DistanceTable::build(self, query, centroid)
977    }
978
979    fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
980        query_data.compute_distance(&code.codes)
981    }
982
983    fn decode(&self, code: &Self::Code) -> Option<Vec<f32>> {
984        Some(self.decode(&code.codes))
985    }
986
987    fn size_bytes(&self) -> usize {
988        self.size_bytes()
989    }
990}
991
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// The default configuration for 128 dimensions uses blocks of two,
    /// giving 64 subspaces.
    #[test]
    fn test_pq_config() {
        let cfg = PQConfig::new(128);

        assert_eq!(cfg.dim, 128);
        assert_eq!(cfg.dims_per_block, 2);
        assert_eq!(cfg.num_subspaces, 64);
    }

    /// Encoding a 32-dimensional vector yields one code byte per subspace.
    #[test]
    fn test_pq_encode_decode() {
        const DIM: usize = 32;
        const TRAINING_SET: usize = 100;

        // Training samples drawn from a fixed seed, each shifted by -0.5.
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut training: Vec<Vec<f32>> = Vec::with_capacity(TRAINING_SET);
        for _ in 0..TRAINING_SET {
            let sample: Vec<f32> = (0..DIM).map(|_| rng.random::<f32>() - 0.5).collect();
            training.push(sample);
        }

        let cfg = PQConfig::new(DIM).with_opq(false, 0);
        let codebook = PQCodebook::train(cfg, &training, 10);

        let ramp: Vec<f32> = (0..DIM).map(|i| i as f32 / DIM as f32).collect();
        let encoded = codebook.encode(&ramp, None);

        assert_eq!(encoded.codes.len(), 16); // 32 dims / 2 dims_per_block
    }

    /// A table lookup for an encoded training vector is non-negative.
    #[test]
    fn test_distance_table() {
        const DIM: usize = 16;
        const TRAINING_SET: usize = 50;

        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let mut training: Vec<Vec<f32>> = Vec::with_capacity(TRAINING_SET);
        for _ in 0..TRAINING_SET {
            let sample: Vec<f32> = (0..DIM).map(|_| rng.random::<f32>()).collect();
            training.push(sample);
        }

        let cfg = PQConfig::new(DIM).with_opq(false, 0);
        let codebook = PQCodebook::train(cfg, &training, 5);

        // The query continues the same random stream after the training samples.
        let query: Vec<f32> = (0..DIM).map(|_| rng.random::<f32>()).collect();
        let table = DistanceTable::build(&codebook, &query, None);

        let encoded = codebook.encode(&training[0], None);
        let dist = table.compute_distance(&encoded.codes);

        assert!(dist >= 0.0);
    }
}