oxirs_vec/
pq.rs

1//! Product Quantization (PQ) for memory-efficient vector compression and search
2//!
3//! PQ divides high-dimensional vectors into subvectors and quantizes each subvector
4//! independently using k-means clustering. This achieves high compression ratios
5//! while maintaining reasonable search accuracy.
6
7use crate::{Vector, VectorIndex};
8use anyhow::{anyhow, Result};
9use std::collections::HashMap;
10
/// Tunable parameters for a Product Quantization index.
///
/// `n_centroids` and `n_bits` describe the same quantity twice;
/// [`PQConfig::validate`] rejects configurations where `n_centroids != 2^n_bits`.
#[derive(Debug, Clone, PartialEq)]
pub struct PQConfig {
    /// How many equally sized parts each vector is split into (one subquantizer per part).
    pub n_subquantizers: usize,
    /// Codebook size of every subquantizer (typically 256, i.e. 8-bit codes).
    pub n_centroids: usize,
    /// Bits per subquantizer code; the matching codebook size is `2^n_bits`.
    pub n_bits: usize,
    /// Upper bound on k-means iterations during training.
    pub max_iterations: usize,
    /// k-means stops early once its error measure changes by less than this amount.
    pub convergence_threshold: f32,
    /// Optional seed for reproducible centroid initialization.
    pub seed: Option<u64>,
    /// Turn on residual quantization for better accuracy.
    pub enable_residual_quantization: bool,
    /// Number of residual quantization levels.
    pub residual_levels: usize,
    /// Turn on multi-codebook quantization.
    pub enable_multi_codebook: bool,
    /// Number of codebooks used by multi-codebook quantization.
    pub num_codebooks: usize,
    /// Turn on symmetric (code-to-code) distance computation.
    pub enable_symmetric_distance: bool,
}
37
38impl Default for PQConfig {
39    fn default() -> Self {
40        Self {
41            n_subquantizers: 8,
42            n_centroids: 256,
43            n_bits: 8, // 2^8 = 256 centroids
44            max_iterations: 50,
45            convergence_threshold: 1e-4,
46            seed: None,
47            enable_residual_quantization: false,
48            residual_levels: 2,
49            enable_multi_codebook: false,
50            num_codebooks: 2,
51            enable_symmetric_distance: false,
52        }
53    }
54}
55
56impl PQConfig {
57    /// Create a new PQConfig with specified bits per subquantizer
58    pub fn with_bits(n_subquantizers: usize, n_bits: usize) -> Self {
59        Self {
60            n_subquantizers,
61            n_centroids: 1 << n_bits, // 2^n_bits
62            n_bits,
63            max_iterations: 50,
64            convergence_threshold: 1e-4,
65            seed: None,
66            enable_residual_quantization: false,
67            residual_levels: 2,
68            enable_multi_codebook: false,
69            num_codebooks: 2,
70            enable_symmetric_distance: false,
71        }
72    }
73
74    /// Create a configuration with residual quantization enabled
75    pub fn with_residual_quantization(
76        n_subquantizers: usize,
77        n_bits: usize,
78        residual_levels: usize,
79    ) -> Self {
80        Self {
81            n_subquantizers,
82            n_centroids: 1 << n_bits,
83            n_bits,
84            enable_residual_quantization: true,
85            residual_levels,
86            ..Default::default()
87        }
88    }
89
90    /// Create a configuration with multi-codebook quantization enabled
91    pub fn with_multi_codebook(
92        n_subquantizers: usize,
93        n_bits: usize,
94        num_codebooks: usize,
95    ) -> Self {
96        Self {
97            n_subquantizers,
98            n_centroids: 1 << n_bits,
99            n_bits,
100            enable_multi_codebook: true,
101            num_codebooks,
102            ..Default::default()
103        }
104    }
105
106    /// Create a configuration with all enhancements enabled
107    pub fn enhanced(n_subquantizers: usize, n_bits: usize) -> Self {
108        Self {
109            n_subquantizers,
110            n_centroids: 1 << n_bits,
111            n_bits,
112            enable_residual_quantization: true,
113            residual_levels: 2,
114            enable_multi_codebook: true,
115            num_codebooks: 2,
116            enable_symmetric_distance: true,
117            ..Default::default()
118        }
119    }
120
121    /// Validate the configuration
122    pub fn validate(&self) -> Result<()> {
123        if self.n_centroids != (1 << self.n_bits) {
124            return Err(anyhow!(
125                "n_centroids {} doesn't match 2^n_bits ({})",
126                self.n_centroids,
127                1 << self.n_bits
128            ));
129        }
130        if self.n_subquantizers == 0 {
131            return Err(anyhow!("n_subquantizers must be greater than 0"));
132        }
133        if self.n_bits == 0 || self.n_bits > 16 {
134            return Err(anyhow!("n_bits must be between 1 and 16"));
135        }
136        if self.enable_residual_quantization && self.residual_levels == 0 {
137            return Err(anyhow!(
138                "residual_levels must be greater than 0 when residual quantization is enabled"
139            ));
140        }
141        if self.enable_multi_codebook && self.num_codebooks < 2 {
142            return Err(anyhow!(
143                "num_codebooks must be at least 2 when multi-codebook quantization is enabled"
144            ));
145        }
146        Ok(())
147    }
148}
149
/// Quantizer responsible for one contiguous slice of the vector dimensions.
#[derive(Debug, Clone)]
struct SubQuantizer {
    /// First dimension handled by this subquantizer (inclusive).
    start_dim: usize,
    /// One past the last dimension handled by this subquantizer (exclusive).
    end_dim: usize,
    /// Codebook: one centroid per code value, filled in by training.
    centroids: Vec<Vec<f32>>,
}
160
161impl SubQuantizer {
162    fn new(start_dim: usize, end_dim: usize, n_centroids: usize) -> Self {
163        Self {
164            start_dim,
165            end_dim,
166            centroids: Vec::with_capacity(n_centroids),
167        }
168    }
169
170    /// Extract subvector from full vector
171    fn extract_subvector(&self, vector: &[f32]) -> Vec<f32> {
172        vector[self.start_dim..self.end_dim].to_vec()
173    }
174
175    /// Train this subquantizer on subvectors
176    fn train(&mut self, subvectors: &[Vec<f32>], config: &PQConfig) -> Result<()> {
177        if subvectors.is_empty() {
178            return Err(anyhow!("Cannot train subquantizer with empty data"));
179        }
180
181        let dims = subvectors[0].len();
182
183        // Initialize centroids using k-means++
184        self.centroids = self.initialize_centroids_kmeans_plus_plus(subvectors, config)?;
185
186        // Run k-means
187        let mut iteration = 0;
188        let mut prev_error = f32::INFINITY;
189
190        while iteration < config.max_iterations {
191            // Assign points to nearest centroids
192            let mut clusters: Vec<Vec<&Vec<f32>>> = vec![Vec::new(); config.n_centroids];
193
194            for subvector in subvectors {
195                let nearest_idx = self.find_nearest_centroid(subvector)?;
196                clusters[nearest_idx].push(subvector);
197            }
198
199            // Update centroids
200            let mut total_error = 0.0;
201            for (i, cluster) in clusters.iter().enumerate() {
202                if !cluster.is_empty() {
203                    let new_centroid = self.compute_centroid(cluster, dims);
204                    total_error += self.euclidean_distance(&self.centroids[i], &new_centroid);
205                    self.centroids[i] = new_centroid;
206                }
207            }
208
209            // Check convergence
210            if (prev_error - total_error).abs() < config.convergence_threshold {
211                break;
212            }
213
214            prev_error = total_error;
215            iteration += 1;
216        }
217
218        Ok(())
219    }
220
221    /// Initialize centroids using k-means++
222    fn initialize_centroids_kmeans_plus_plus(
223        &self,
224        subvectors: &[Vec<f32>],
225        config: &PQConfig,
226    ) -> Result<Vec<Vec<f32>>> {
227        use std::collections::hash_map::DefaultHasher;
228        use std::hash::{Hash, Hasher};
229
230        let mut hasher = DefaultHasher::new();
231        config.seed.unwrap_or(42).hash(&mut hasher);
232        let mut rng_state = hasher.finish();
233
234        let mut centroids = Vec::with_capacity(config.n_centroids);
235
236        // Choose first centroid randomly
237        let first_idx = (rng_state as usize) % subvectors.len();
238        centroids.push(subvectors[first_idx].clone());
239
240        // Choose remaining centroids
241        while centroids.len() < config.n_centroids {
242            let mut distances = Vec::with_capacity(subvectors.len());
243            let mut sum_distances = 0.0;
244
245            // Calculate distance to nearest centroid for each point
246            for subvector in subvectors {
247                let min_dist = centroids
248                    .iter()
249                    .map(|c| self.euclidean_distance(subvector, c))
250                    .fold(f32::INFINITY, |a, b| a.min(b));
251
252                distances.push(min_dist * min_dist);
253                sum_distances += min_dist * min_dist;
254            }
255
256            // Choose next centroid
257            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
258            let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
259
260            let mut cumulative = 0.0;
261            for (i, &dist) in distances.iter().enumerate() {
262                cumulative += dist;
263                if cumulative >= threshold {
264                    centroids.push(subvectors[i].clone());
265                    break;
266                }
267            }
268        }
269
270        Ok(centroids)
271    }
272
273    /// Compute centroid of a cluster
274    fn compute_centroid(&self, cluster: &[&Vec<f32>], dims: usize) -> Vec<f32> {
275        if cluster.is_empty() {
276            return vec![0.0; dims];
277        }
278
279        let mut sum = vec![0.0; dims];
280        for vector in cluster {
281            for (i, &val) in vector.iter().enumerate() {
282                sum[i] += val;
283            }
284        }
285
286        let count = cluster.len() as f32;
287        for val in &mut sum {
288            *val /= count;
289        }
290
291        sum
292    }
293
294    /// Find nearest centroid for a subvector
295    fn find_nearest_centroid(&self, subvector: &[f32]) -> Result<usize> {
296        if self.centroids.is_empty() {
297            return Err(anyhow!("No centroids available"));
298        }
299
300        let mut min_distance = f32::INFINITY;
301        let mut nearest_idx = 0;
302
303        for (i, centroid) in self.centroids.iter().enumerate() {
304            let distance = self.euclidean_distance(subvector, centroid);
305            if distance < min_distance {
306                min_distance = distance;
307                nearest_idx = i;
308            }
309        }
310
311        Ok(nearest_idx)
312    }
313
314    /// Compute Euclidean distance between two vectors
315    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
316        a.iter()
317            .zip(b.iter())
318            .map(|(x, y)| (x - y).powi(2))
319            .sum::<f32>()
320            .sqrt()
321    }
322
323    /// Encode a subvector to its nearest centroid index
324    fn encode(&self, subvector: &[f32]) -> Result<u8> {
325        if self.centroids.len() > 256 {
326            return Err(anyhow!("Too many centroids for u8 encoding"));
327        }
328
329        let idx = self.find_nearest_centroid(subvector)?;
330        Ok(idx as u8)
331    }
332
333    /// Decode a centroid index back to a subvector
334    fn decode(&self, code: u8) -> Result<Vec<f32>> {
335        let idx = code as usize;
336        if idx >= self.centroids.len() {
337            return Err(anyhow!("Invalid code: {}", code));
338        }
339        Ok(self.centroids[idx].clone())
340    }
341}
342
/// Bundle of all codes produced for one vector when the advanced PQ features are in use.
#[derive(Debug, Clone)]
pub struct EnhancedCodes {
    /// Codes from the primary quantizers.
    pub primary: Vec<u8>,
    /// Codes from residual quantization, one inner vector per residual level.
    pub residual: Vec<Vec<u8>>,
    /// Codes from multi-codebook quantization, one inner vector per codebook.
    pub multi_codebook: Vec<Vec<u8>>,
}
353
354/// Enhanced Product Quantization index with residual and multi-codebook support
355#[derive(Debug, Clone)]
356pub struct PQIndex {
357    config: PQConfig,
358    /// Primary subquantizers
359    subquantizers: Vec<SubQuantizer>,
360    /// Residual quantizers (for each level)
361    residual_quantizers: Vec<Vec<SubQuantizer>>,
362    /// Multi-codebook quantizers
363    multi_codebook_quantizers: Vec<Vec<SubQuantizer>>,
364    /// Encoded vectors (primary codes)
365    codes: Vec<(String, Vec<u8>)>,
366    /// Residual codes (for each level)
367    residual_codes: Vec<Vec<(String, Vec<u8>)>>,
368    /// Multi-codebook codes
369    multi_codebook_codes: Vec<Vec<(String, Vec<u8>)>>,
370    /// Distance lookup tables for symmetric distance computation
371    distance_tables: Option<Vec<Vec<Vec<f32>>>>,
372    /// URI to index mapping
373    uri_to_id: HashMap<String, usize>,
374    /// Vector dimensions
375    dimensions: Option<usize>,
376    /// Whether the index has been trained
377    is_trained: bool,
378}
379
380impl PQIndex {
381    /// Create a new PQ index
382    pub fn new(config: PQConfig) -> Self {
383        Self {
384            residual_quantizers: vec![Vec::new(); config.residual_levels],
385            multi_codebook_quantizers: vec![Vec::new(); config.num_codebooks],
386            residual_codes: vec![Vec::new(); config.residual_levels],
387            multi_codebook_codes: vec![Vec::new(); config.num_codebooks],
388            distance_tables: None,
389            config,
390            subquantizers: Vec::new(),
391            codes: Vec::new(),
392            uri_to_id: HashMap::new(),
393            dimensions: None,
394            is_trained: false,
395        }
396    }
397
398    /// Train the PQ index with training vectors
399    pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
400        if training_vectors.is_empty() {
401            return Err(anyhow!("Cannot train PQ with empty training set"));
402        }
403
404        // Validate dimensions
405        let dims = training_vectors[0].dimensions;
406        if !training_vectors.iter().all(|v| v.dimensions == dims) {
407            return Err(anyhow!(
408                "All training vectors must have the same dimensions"
409            ));
410        }
411
412        if dims % self.config.n_subquantizers != 0 {
413            return Err(anyhow!(
414                "Vector dimensions {} must be divisible by n_subquantizers {}",
415                dims,
416                self.config.n_subquantizers
417            ));
418        }
419
420        self.dimensions = Some(dims);
421        let subdim = dims / self.config.n_subquantizers;
422
423        // Initialize subquantizers
424        self.subquantizers.clear();
425        for i in 0..self.config.n_subquantizers {
426            let start = i * subdim;
427            let end = start + subdim;
428            self.subquantizers
429                .push(SubQuantizer::new(start, end, self.config.n_centroids));
430        }
431
432        // Extract training data as f32
433        let training_data: Vec<Vec<f32>> = training_vectors.iter().map(|v| v.as_f32()).collect();
434
435        // Train each subquantizer
436        for sq in self.subquantizers.iter_mut() {
437            // Extract subvectors for this subquantizer
438            let subvectors: Vec<Vec<f32>> = training_data
439                .iter()
440                .map(|v| sq.extract_subvector(v))
441                .collect();
442
443            sq.train(&subvectors, &self.config)?;
444        }
445
446        // Train residual quantizers if enabled
447        if self.config.enable_residual_quantization {
448            self.train_residual_quantizers(&training_data)?;
449        }
450
451        // Train multi-codebook quantizers if enabled
452        if self.config.enable_multi_codebook {
453            self.train_multi_codebook_quantizers(&training_data)?;
454        }
455
456        // Build distance tables for symmetric distance computation if enabled
457        if self.config.enable_symmetric_distance {
458            self.build_distance_tables()?;
459        }
460
461        self.is_trained = true;
462        Ok(())
463    }
464
465    /// Train residual quantizers for improved accuracy
466    fn train_residual_quantizers(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
467        let subdim = self.dimensions.unwrap() / self.config.n_subquantizers;
468
469        // Start with residuals from the primary quantizers
470        let mut current_residuals = training_data.to_vec();
471
472        for level in 0..self.config.residual_levels {
473            // Compute residuals from previous level
474            if level == 0 {
475                // Compute residuals from primary quantizers
476                for (i, vector) in training_data.iter().enumerate() {
477                    let primary_codes = self.encode_primary_vector(vector)?;
478                    let reconstructed = self.decode_primary_codes(&primary_codes)?;
479
480                    // Compute residual
481                    let residual: Vec<f32> = vector
482                        .iter()
483                        .zip(reconstructed.iter())
484                        .map(|(a, b)| a - b)
485                        .collect();
486                    current_residuals[i] = residual;
487                }
488            } else {
489                // Compute residuals from previous residual level
490                for (i, residual) in current_residuals.clone().iter().enumerate() {
491                    let residual_codes = self.encode_residual_vector(residual, level - 1)?;
492                    let reconstructed_residual =
493                        self.decode_residual_codes(&residual_codes, level - 1)?;
494
495                    let new_residual: Vec<f32> = residual
496                        .iter()
497                        .zip(reconstructed_residual.iter())
498                        .map(|(a, b)| a - b)
499                        .collect();
500                    current_residuals[i] = new_residual;
501                }
502            }
503
504            // Initialize residual subquantizers for this level
505            self.residual_quantizers[level].clear();
506            for i in 0..self.config.n_subquantizers {
507                let start = i * subdim;
508                let end = start + subdim;
509                self.residual_quantizers[level].push(SubQuantizer::new(
510                    start,
511                    end,
512                    self.config.n_centroids,
513                ));
514            }
515
516            // Train each residual subquantizer
517            for sq in self.residual_quantizers[level].iter_mut() {
518                let subvectors: Vec<Vec<f32>> = current_residuals
519                    .iter()
520                    .map(|v| sq.extract_subvector(v))
521                    .collect();
522
523                sq.train(&subvectors, &self.config)?;
524            }
525        }
526
527        Ok(())
528    }
529
530    /// Train multi-codebook quantizers for better coverage
531    fn train_multi_codebook_quantizers(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
532        let subdim = self.dimensions.unwrap() / self.config.n_subquantizers;
533
534        for codebook_idx in 0..self.config.num_codebooks {
535            // Initialize subquantizers for this codebook
536            self.multi_codebook_quantizers[codebook_idx].clear();
537            for i in 0..self.config.n_subquantizers {
538                let start = i * subdim;
539                let end = start + subdim;
540                self.multi_codebook_quantizers[codebook_idx].push(SubQuantizer::new(
541                    start,
542                    end,
543                    self.config.n_centroids,
544                ));
545            }
546
547            // Use different initialization for each codebook
548            let mut modified_config = self.config.clone();
549            modified_config.seed = self.config.seed.map(|s| s + codebook_idx as u64);
550
551            // Train each subquantizer in this codebook
552            for sq in self.multi_codebook_quantizers[codebook_idx].iter_mut() {
553                let subvectors: Vec<Vec<f32>> = training_data
554                    .iter()
555                    .map(|v| sq.extract_subvector(v))
556                    .collect();
557
558                sq.train(&subvectors, &modified_config)?;
559            }
560        }
561
562        Ok(())
563    }
564
565    /// Build distance lookup tables for symmetric distance computation
566    fn build_distance_tables(&mut self) -> Result<()> {
567        let mut tables = Vec::new();
568
569        for sq_idx in 0..self.config.n_subquantizers {
570            let sq = &self.subquantizers[sq_idx];
571            let mut sq_table = Vec::new();
572
573            // Build distance table between all pairs of centroids
574            for i in 0..sq.centroids.len() {
575                let mut centroid_distances = Vec::new();
576                for j in 0..sq.centroids.len() {
577                    let distance = sq.euclidean_distance(&sq.centroids[i], &sq.centroids[j]);
578                    centroid_distances.push(distance);
579                }
580                sq_table.push(centroid_distances);
581            }
582            tables.push(sq_table);
583        }
584
585        self.distance_tables = Some(tables);
586        Ok(())
587    }
588
589    /// Helper method to encode with primary quantizers only
590    fn encode_primary_vector(&self, vector: &[f32]) -> Result<Vec<u8>> {
591        let mut codes = Vec::with_capacity(self.subquantizers.len());
592
593        for sq in &self.subquantizers {
594            let subvec = sq.extract_subvector(vector);
595            let code = sq.encode(&subvec)?;
596            codes.push(code);
597        }
598
599        Ok(codes)
600    }
601
602    /// Helper method to decode primary codes
603    fn decode_primary_codes(&self, codes: &[u8]) -> Result<Vec<f32>> {
604        let mut reconstructed = Vec::new();
605
606        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
607            let subvec = sq.decode(code)?;
608            reconstructed.extend(subvec);
609        }
610
611        Ok(reconstructed)
612    }
613
614    /// Helper method to encode with residual quantizers
615    fn encode_residual_vector(&self, vector: &[f32], level: usize) -> Result<Vec<u8>> {
616        if level >= self.residual_quantizers.len() {
617            return Err(anyhow!("Invalid residual level: {}", level));
618        }
619
620        let mut codes = Vec::with_capacity(self.residual_quantizers[level].len());
621
622        for sq in &self.residual_quantizers[level] {
623            let subvec = sq.extract_subvector(vector);
624            let code = sq.encode(&subvec)?;
625            codes.push(code);
626        }
627
628        Ok(codes)
629    }
630
631    /// Helper method to decode residual codes
632    fn decode_residual_codes(&self, codes: &[u8], level: usize) -> Result<Vec<f32>> {
633        if level >= self.residual_quantizers.len() {
634            return Err(anyhow!("Invalid residual level: {}", level));
635        }
636
637        let mut reconstructed = Vec::new();
638
639        for (sq, &code) in self.residual_quantizers[level].iter().zip(codes.iter()) {
640            let subvec = sq.decode(code)?;
641            reconstructed.extend(subvec);
642        }
643
644        Ok(reconstructed)
645    }
646
647    /// Encode a vector into PQ codes
648    fn encode_vector(&self, vector: &Vector) -> Result<Vec<u8>> {
649        if !self.is_trained {
650            return Err(anyhow!("PQ index must be trained before encoding"));
651        }
652
653        let vector_f32 = vector.as_f32();
654        let mut codes = Vec::with_capacity(self.subquantizers.len());
655
656        for sq in &self.subquantizers {
657            let subvec = sq.extract_subvector(&vector_f32);
658            let code = sq.encode(&subvec)?;
659            codes.push(code);
660        }
661
662        Ok(codes)
663    }
664
665    /// Decode PQ codes back to an approximate vector
666    fn decode_codes(&self, codes: &[u8]) -> Result<Vector> {
667        if codes.len() != self.subquantizers.len() {
668            return Err(anyhow!("Invalid code length"));
669        }
670
671        let mut reconstructed = Vec::new();
672
673        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
674            let subvec = sq.decode(code)?;
675            reconstructed.extend(subvec);
676        }
677
678        Ok(Vector::new(reconstructed))
679    }
680
681    /// Public method to encode a vector (for OPQ)
682    pub fn encode(&self, vector: &Vector) -> Result<Vec<u8>> {
683        self.encode_vector(vector)
684    }
685
686    /// Public method to decode codes (for OPQ)
687    pub fn decode(&self, codes: &[u8]) -> Result<Vector> {
688        self.decode_codes(codes)
689    }
690
691    /// Reconstruct a vector by encoding and then decoding (for OPQ)
692    pub fn reconstruct(&self, vector: &Vector) -> Result<Vector> {
693        let codes = self.encode_vector(vector)?;
694        self.decode_codes(&codes)
695    }
696
697    /// Compute asymmetric distance between a query vector and PQ codes
698    fn asymmetric_distance(&self, query: &Vector, codes: &[u8]) -> Result<f32> {
699        let query_f32 = query.as_f32();
700        let mut total_distance = 0.0;
701
702        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
703            let query_subvec = sq.extract_subvector(&query_f32);
704            let centroid = &sq.centroids[code as usize];
705
706            // Compute squared distance to avoid sqrt
707            let dist: f32 = query_subvec
708                .iter()
709                .zip(centroid.iter())
710                .map(|(a, b)| (a - b).powi(2))
711                .sum();
712
713            total_distance += dist;
714        }
715
716        Ok(total_distance.sqrt())
717    }
718
719    /// Enhanced encoding with residual and multi-codebook support
720    fn encode_vector_enhanced(&self, vector: &Vector) -> Result<EnhancedCodes> {
721        if !self.is_trained {
722            return Err(anyhow!("PQ index must be trained before encoding"));
723        }
724
725        let vector_f32 = vector.as_f32();
726
727        // Primary encoding
728        let primary_codes = self.encode_primary_vector(&vector_f32)?;
729
730        // Residual encoding if enabled
731        let mut residual_codes = Vec::new();
732        if self.config.enable_residual_quantization {
733            let mut current_residual = vector_f32.clone();
734
735            // Compute residual from primary quantization
736            let primary_reconstructed = self.decode_primary_codes(&primary_codes)?;
737            current_residual = current_residual
738                .iter()
739                .zip(primary_reconstructed.iter())
740                .map(|(a, b)| a - b)
741                .collect();
742
743            // Encode residuals at each level
744            for level in 0..self.config.residual_levels {
745                let level_codes = self.encode_residual_vector(&current_residual, level)?;
746                residual_codes.push(level_codes.clone());
747
748                // Update residual for next level
749                if level < self.config.residual_levels - 1 {
750                    let level_reconstructed = self.decode_residual_codes(&level_codes, level)?;
751                    current_residual = current_residual
752                        .iter()
753                        .zip(level_reconstructed.iter())
754                        .map(|(a, b)| a - b)
755                        .collect();
756                }
757            }
758        }
759
760        // Multi-codebook encoding if enabled
761        let mut multi_codebook_codes = Vec::new();
762        if self.config.enable_multi_codebook {
763            for codebook_idx in 0..self.config.num_codebooks {
764                let mut codes =
765                    Vec::with_capacity(self.multi_codebook_quantizers[codebook_idx].len());
766
767                for sq in &self.multi_codebook_quantizers[codebook_idx] {
768                    let subvec = sq.extract_subvector(&vector_f32);
769                    let code = sq.encode(&subvec)?;
770                    codes.push(code);
771                }
772                multi_codebook_codes.push(codes);
773            }
774        }
775
776        Ok(EnhancedCodes {
777            primary: primary_codes,
778            residual: residual_codes,
779            multi_codebook: multi_codebook_codes,
780        })
781    }
782
783    /// Symmetric distance computation between two sets of codes
784    fn symmetric_distance(&self, codes1: &[u8], codes2: &[u8]) -> Result<f32> {
785        if !self.config.enable_symmetric_distance {
786            return Err(anyhow!("Symmetric distance computation not enabled"));
787        }
788
789        let distance_tables = self
790            .distance_tables
791            .as_ref()
792            .ok_or_else(|| anyhow!("Distance tables not built"))?;
793
794        if codes1.len() != codes2.len() || codes1.len() != self.config.n_subquantizers {
795            return Err(anyhow!("Invalid code lengths for symmetric distance"));
796        }
797
798        let mut total_distance = 0.0;
799
800        for (sq_idx, (&code1, &code2)) in codes1.iter().zip(codes2.iter()).enumerate() {
801            let distance = distance_tables[sq_idx][code1 as usize][code2 as usize];
802            total_distance += distance * distance; // Squared distance
803        }
804
805        Ok(total_distance.sqrt())
806    }
807
808    /// Enhanced distance computation with residual and multi-codebook support
809    fn enhanced_distance(&self, query: &Vector, enhanced_codes: &EnhancedCodes) -> Result<f32> {
810        // Start with primary distance
811        let mut total_distance = self.asymmetric_distance(query, &enhanced_codes.primary)?;
812
813        // Add residual distances if enabled
814        if self.config.enable_residual_quantization && !enhanced_codes.residual.is_empty() {
815            let query_f32 = query.as_f32();
816            let mut current_residual = query_f32.clone();
817
818            // Compute residual from primary quantization
819            let primary_reconstructed = self.decode_primary_codes(&enhanced_codes.primary)?;
820            current_residual = current_residual
821                .iter()
822                .zip(primary_reconstructed.iter())
823                .map(|(a, b)| a - b)
824                .collect();
825
826            // Add distance from each residual level
827            for (level, residual_codes) in enhanced_codes.residual.iter().enumerate() {
828                let mut residual_distance = 0.0;
829
830                for (sq, &code) in self.residual_quantizers[level]
831                    .iter()
832                    .zip(residual_codes.iter())
833                {
834                    let query_subvec = sq.extract_subvector(&current_residual);
835                    let centroid = &sq.centroids[code as usize];
836
837                    let dist: f32 = query_subvec
838                        .iter()
839                        .zip(centroid.iter())
840                        .map(|(a, b)| (a - b).powi(2))
841                        .sum();
842
843                    residual_distance += dist;
844                }
845
846                total_distance += residual_distance.sqrt() * 0.5; // Weight residual distances
847
848                // Update residual for next level
849                if level < enhanced_codes.residual.len() - 1 {
850                    let level_reconstructed = self.decode_residual_codes(residual_codes, level)?;
851                    current_residual = current_residual
852                        .iter()
853                        .zip(level_reconstructed.iter())
854                        .map(|(a, b)| a - b)
855                        .collect();
856                }
857            }
858        }
859
860        // For multi-codebook, use the minimum distance across codebooks
861        if self.config.enable_multi_codebook && !enhanced_codes.multi_codebook.is_empty() {
862            let mut min_codebook_distance = f32::INFINITY;
863
864            for codes in &enhanced_codes.multi_codebook {
865                let codebook_distance = self.asymmetric_distance(query, codes)?;
866                min_codebook_distance = min_codebook_distance.min(codebook_distance);
867            }
868
869            // Use the minimum as a refinement
870            total_distance = total_distance.min(min_codebook_distance);
871        }
872
873        Ok(total_distance)
874    }
875
876    /// Get compression ratio
877    pub fn compression_ratio(&self) -> f32 {
878        if let Some(dims) = self.dimensions {
879            // Original: dims * 4 bytes (f32)
880            // Compressed: n_subquantizers bytes
881            (dims as f32 * 4.0) / (self.config.n_subquantizers as f32)
882        } else {
883            0.0
884        }
885    }
886
887    /// Get index statistics
888    pub fn stats(&self) -> PQStats {
889        PQStats {
890            n_vectors: self.codes.len(),
891            n_subquantizers: self.config.n_subquantizers,
892            n_centroids: self.config.n_centroids,
893            is_trained: self.is_trained,
894            dimensions: self.dimensions,
895            compression_ratio: self.compression_ratio(),
896            memory_usage_bytes: self.estimate_memory_usage(),
897        }
898    }
899
900    /// Estimate memory usage in bytes
901    fn estimate_memory_usage(&self) -> usize {
902        let codebook_size = self
903            .subquantizers
904            .iter()
905            .map(|sq| sq.centroids.len() * (sq.end_dim - sq.start_dim) * 4)
906            .sum::<usize>();
907
908        let codes_size = self.codes.len() * self.config.n_subquantizers;
909
910        codebook_size + codes_size
911    }
912
    /// Check if the index is trained.
    ///
    /// Returns the current value of the `is_trained` flag; `insert` and the
    /// search methods refuse to run while it is `false`.
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }
917
    /// Compute distance between query and encoded vector (for IVF compatibility).
    ///
    /// Public entry point that forwards directly to `asymmetric_distance`;
    /// any error from that call is returned unchanged.
    pub fn compute_distance(&self, query: &Vector, codes: &[u8]) -> Result<f32> {
        self.asymmetric_distance(query, codes)
    }
922
    /// Decode codes to vector (for IVF compatibility).
    ///
    /// Public entry point that forwards directly to `decode_codes`; any error
    /// from that call is returned unchanged.
    pub fn decode_vector(&self, codes: &[u8]) -> Result<Vector> {
        self.decode_codes(codes)
    }
927}
928
929impl VectorIndex for PQIndex {
930    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
931        if !self.is_trained {
932            return Err(anyhow!("PQ index must be trained before inserting vectors"));
933        }
934
935        // Validate dimensions
936        if let Some(dims) = self.dimensions {
937            if vector.dimensions != dims {
938                return Err(anyhow!(
939                    "Vector dimensions {} don't match index dimensions {}",
940                    vector.dimensions,
941                    dims
942                ));
943            }
944        }
945
946        // Encode the vector
947        let codes = self.encode_vector(&vector)?;
948
949        // Store the codes
950        let id = self.codes.len();
951        self.uri_to_id.insert(uri.clone(), id);
952        self.codes.push((uri, codes));
953
954        Ok(())
955    }
956
957    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
958        if !self.is_trained {
959            return Err(anyhow!("PQ index must be trained before searching"));
960        }
961
962        // Compute distances to all vectors
963        let mut distances: Vec<(String, f32)> = self
964            .codes
965            .iter()
966            .map(|(uri, codes)| {
967                let dist = self
968                    .asymmetric_distance(query, codes)
969                    .unwrap_or(f32::INFINITY);
970                (uri.clone(), dist)
971            })
972            .collect();
973
974        // Sort by distance
975        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
976        distances.truncate(k);
977
978        // Convert distances to similarities
979        Ok(distances
980            .into_iter()
981            .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
982            .collect())
983    }
984
985    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
986        if !self.is_trained {
987            return Err(anyhow!("PQ index must be trained before searching"));
988        }
989
990        let mut results = Vec::new();
991
992        for (uri, codes) in &self.codes {
993            let dist = self.asymmetric_distance(query, codes)?;
994            let similarity = 1.0 / (1.0 + dist);
995
996            if similarity >= threshold {
997                results.push((uri.clone(), similarity));
998            }
999        }
1000
1001        // Sort by similarity
1002        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1003
1004        Ok(results)
1005    }
1006
1007    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
1008        // PQ doesn't store original vectors, only codes
1009        // Would need to decode, but that returns an approximation
1010        None
1011    }
1012}
1013
impl PQIndex {
    /// Public search method for use by OPQ and other modules.
    ///
    /// Forwards to `search_knn` with the same `query` and `k` and returns its
    /// result unchanged.
    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        self.search_knn(query, k)
    }
}
1020
/// Summary of a PQ index, as produced by `PQIndex::stats`.
#[derive(Debug, Clone)]
pub struct PQStats {
    /// Number of encoded entries currently stored.
    pub n_vectors: usize,
    /// Configured number of subquantizers.
    pub n_subquantizers: usize,
    /// Configured number of centroids per subquantizer.
    pub n_centroids: usize,
    /// Whether the index has been trained.
    pub is_trained: bool,
    /// Dimensionality of the indexed vectors, if known.
    pub dimensions: Option<usize>,
    /// Value reported by `PQIndex::compression_ratio`.
    pub compression_ratio: f32,
    /// Estimated memory footprint in bytes.
    pub memory_usage_bytes: usize,
}
1032
#[cfg(test)]
mod tests {
    use super::*;

    /// Six 4-dimensional vectors shared by several tests: four mirrored
    /// 0/±1 patterns followed by the two ±0.5 diagonal points.
    fn sample_vectors_4d() -> Vec<Vector> {
        let rows: [[f32; 4]; 6] = [
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 1.0, 0.0],
            [-1.0, 0.0, 0.0, -1.0],
            [0.0, -1.0, -1.0, 0.0],
            [0.5, 0.5, 0.5, 0.5],
            [-0.5, -0.5, -0.5, -0.5],
        ];
        rows.iter().map(|row| Vector::new(row.to_vec())).collect()
    }

    /// Train, insert and run a k-NN query on a tiny index.
    #[test]
    fn test_pq_basic() {
        let mut index = PQIndex::new(PQConfig {
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        });

        let training_vectors = sample_vectors_4d();

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        // Every training vector is also inserted under the key "vec<i>".
        for (i, v) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), v.clone()).unwrap();
        }

        let query = Vector::new(vec![0.9, 0.1, 0.1, 0.9]);
        let results = index.search_knn(&query, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    /// The compression ratio and stats reflect the configuration and data.
    #[test]
    fn test_pq_compression() {
        let mut index = PQIndex::new(PQConfig {
            n_subquantizers: 4,
            n_centroids: 16,
            ..Default::default()
        });

        // 100 deterministic 128-dimensional vectors.
        let dims = 128;
        let mut training_vectors: Vec<Vector> = Vec::new();
        for i in 0..100 {
            let values: Vec<f32> = (0..dims).map(|j| ((i + j) as f32).sin()).collect();
            training_vectors.push(Vector::new(values));
        }

        index.train(&training_vectors).unwrap();

        // 128 * 4 bytes per raw vector versus 4 code bytes.
        let compression_ratio = index.compression_ratio();
        assert_eq!(compression_ratio, 128.0);

        let stats = index.stats();
        assert_eq!(stats.n_subquantizers, 4);
        assert_eq!(stats.n_centroids, 16);
        assert_eq!(stats.dimensions, Some(128));
    }

    /// Encoding followed by decoding stays close to the input vector.
    #[test]
    fn test_pq_reconstruction() {
        let mut index = PQIndex::new(PQConfig {
            n_subquantizers: 2,
            n_centroids: 8,
            ..Default::default()
        });

        let rows: [[f32; 2]; 4] = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];
        let training_vectors: Vec<Vector> =
            rows.iter().map(|row| Vector::new(row.to_vec())).collect();

        index.train(&training_vectors).unwrap();

        let original = Vector::new(vec![0.7, 0.7]);
        let codes = index.encode_vector(&original).unwrap();
        let reconstructed = index.decode_codes(&codes).unwrap();

        // Quantization is lossy, so only a loose bound is checked.
        let dist = original.euclidean_distance(&reconstructed).unwrap();
        assert!(dist < 1.0);
    }

    /// Residual quantization trains one quantizer set per level and emits
    /// one residual code set per level.
    #[test]
    fn test_pq_residual_quantization() {
        // 2 subquantizers, 3 bits, 2 residual levels
        let mut index = PQIndex::new(PQConfig::with_residual_quantization(2, 3, 2));

        let training_vectors = sample_vectors_4d();

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.residual_quantizers.len(), 2);

        let test_vector = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let enhanced_codes = index.encode_vector_enhanced(&test_vector).unwrap();

        assert!(!enhanced_codes.primary.is_empty());
        assert_eq!(enhanced_codes.residual.len(), 2);
        // Multi-codebook is not enabled in this configuration.
        assert!(enhanced_codes.multi_codebook.is_empty());
    }

    /// Multi-codebook quantization trains and emits one code set per codebook.
    #[test]
    fn test_pq_multi_codebook() {
        // 2 subquantizers, 3 bits, 3 codebooks
        let mut index = PQIndex::new(PQConfig::with_multi_codebook(2, 3, 3));

        let training_vectors = sample_vectors_4d();

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.multi_codebook_quantizers.len(), 3);

        let test_vector = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let enhanced_codes = index.encode_vector_enhanced(&test_vector).unwrap();

        assert!(!enhanced_codes.primary.is_empty());
        // Residual quantization is not enabled in this configuration.
        assert!(enhanced_codes.residual.is_empty());
        assert_eq!(enhanced_codes.multi_codebook.len(), 3);
    }

    /// Symmetric distance is available after training and yields a finite,
    /// non-negative value.
    #[test]
    fn test_pq_symmetric_distance() {
        let mut index = PQIndex::new(PQConfig {
            enable_symmetric_distance: true,
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        });

        // Only the four 0/±1 patterns are used for training here.
        let training_vectors: Vec<Vector> = sample_vectors_4d().into_iter().take(4).collect();

        index.train(&training_vectors).unwrap();
        assert!(index.distance_tables.is_some());

        let codes1 = vec![0, 1];
        let codes2 = vec![1, 0];
        let distance = index.symmetric_distance(&codes1, &codes2).unwrap();

        assert!(distance >= 0.0);
        assert!(distance.is_finite());
    }

    /// With every enhancement enabled, training populates all auxiliary
    /// structures and the enhanced distance stays within tolerance of the
    /// basic asymmetric distance.
    #[test]
    fn test_pq_enhanced_features() {
        // All features enabled
        let mut index = PQIndex::new(PQConfig::enhanced(2, 3));

        let training_vectors = sample_vectors_4d();

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());

        assert!(!index.residual_quantizers.is_empty());
        assert!(!index.multi_codebook_quantizers.is_empty());
        assert!(index.distance_tables.is_some());

        let test_vector = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let enhanced_codes = index.encode_vector_enhanced(&test_vector).unwrap();
        let enhanced_distance = index
            .enhanced_distance(&test_vector, &enhanced_codes)
            .unwrap();

        assert!(enhanced_distance >= 0.0);
        assert!(enhanced_distance.is_finite());

        // Compare against the plain asymmetric distance, allowing 10% slack.
        let basic_distance = index
            .asymmetric_distance(&test_vector, &enhanced_codes.primary)
            .unwrap();
        assert!(enhanced_distance <= basic_distance * 1.1);
    }

    /// `validate` accepts the enhanced preset and rejects degenerate settings.
    #[test]
    fn test_pq_config_validation() {
        let valid = PQConfig::enhanced(4, 8);
        assert!(valid.validate().is_ok());

        // Residual quantization enabled with zero levels.
        let zero_residual_levels = PQConfig {
            enable_residual_quantization: true,
            residual_levels: 0,
            ..Default::default()
        };
        assert!(zero_residual_levels.validate().is_err());

        // Multi-codebook enabled with a single codebook.
        let single_codebook = PQConfig {
            enable_multi_codebook: true,
            num_codebooks: 1,
            ..Default::default()
        };
        assert!(single_codebook.validate().is_err());
    }
}