oxirs_vec/
pq.rs

1//! Product Quantization (PQ) for memory-efficient vector compression and search
2//!
3//! PQ divides high-dimensional vectors into subvectors and quantizes each subvector
4//! independently using k-means clustering. This achieves high compression ratios
5//! while maintaining reasonable search accuracy.
6
7use crate::{Vector, VectorIndex};
8use anyhow::{anyhow, Result};
9use std::collections::HashMap;
10
/// Tunable parameters for a Product Quantization (PQ) index.
#[derive(Debug, Clone, PartialEq)]
pub struct PQConfig {
    /// How many sub-vectors each input vector is split into; one subquantizer
    /// is trained per sub-vector.
    pub n_subquantizers: usize,
    /// Codebook size of each subquantizer (256 is typical, giving 8-bit codes).
    pub n_centroids: usize,
    /// Bits per subquantizer code; `n_centroids` is expected to equal `2^n_bits`.
    pub n_bits: usize,
    /// Upper bound on the number of k-means iterations during training.
    pub max_iterations: usize,
    /// Threshold used by the k-means convergence test.
    pub convergence_threshold: f32,
    /// Optional seed that makes training reproducible.
    pub seed: Option<u64>,
    /// Turns on residual quantization (extra codebooks trained on the
    /// quantization error) for better accuracy.
    pub enable_residual_quantization: bool,
    /// How many residual levels to train when residual quantization is on.
    pub residual_levels: usize,
    /// Turns on multi-codebook quantization.
    pub enable_multi_codebook: bool,
    /// How many codebooks to train when multi-codebook quantization is on.
    pub num_codebooks: usize,
    /// Turns on symmetric (code-to-code) distance computation.
    pub enable_symmetric_distance: bool,
}
37
38impl Default for PQConfig {
39    fn default() -> Self {
40        Self {
41            n_subquantizers: 8,
42            n_centroids: 256,
43            n_bits: 8, // 2^8 = 256 centroids
44            max_iterations: 50,
45            convergence_threshold: 1e-4,
46            seed: None,
47            enable_residual_quantization: false,
48            residual_levels: 2,
49            enable_multi_codebook: false,
50            num_codebooks: 2,
51            enable_symmetric_distance: false,
52        }
53    }
54}
55
56impl PQConfig {
57    /// Create a new PQConfig with specified bits per subquantizer
58    pub fn with_bits(n_subquantizers: usize, n_bits: usize) -> Self {
59        Self {
60            n_subquantizers,
61            n_centroids: 1 << n_bits, // 2^n_bits
62            n_bits,
63            max_iterations: 50,
64            convergence_threshold: 1e-4,
65            seed: None,
66            enable_residual_quantization: false,
67            residual_levels: 2,
68            enable_multi_codebook: false,
69            num_codebooks: 2,
70            enable_symmetric_distance: false,
71        }
72    }
73
74    /// Create a configuration with residual quantization enabled
75    pub fn with_residual_quantization(
76        n_subquantizers: usize,
77        n_bits: usize,
78        residual_levels: usize,
79    ) -> Self {
80        Self {
81            n_subquantizers,
82            n_centroids: 1 << n_bits,
83            n_bits,
84            enable_residual_quantization: true,
85            residual_levels,
86            ..Default::default()
87        }
88    }
89
90    /// Create a configuration with multi-codebook quantization enabled
91    pub fn with_multi_codebook(
92        n_subquantizers: usize,
93        n_bits: usize,
94        num_codebooks: usize,
95    ) -> Self {
96        Self {
97            n_subquantizers,
98            n_centroids: 1 << n_bits,
99            n_bits,
100            enable_multi_codebook: true,
101            num_codebooks,
102            ..Default::default()
103        }
104    }
105
106    /// Create a configuration with all enhancements enabled
107    pub fn enhanced(n_subquantizers: usize, n_bits: usize) -> Self {
108        Self {
109            n_subquantizers,
110            n_centroids: 1 << n_bits,
111            n_bits,
112            enable_residual_quantization: true,
113            residual_levels: 2,
114            enable_multi_codebook: true,
115            num_codebooks: 2,
116            enable_symmetric_distance: true,
117            ..Default::default()
118        }
119    }
120
121    /// Validate the configuration
122    pub fn validate(&self) -> Result<()> {
123        if self.n_centroids != (1 << self.n_bits) {
124            return Err(anyhow!(
125                "n_centroids {} doesn't match 2^n_bits ({})",
126                self.n_centroids,
127                1 << self.n_bits
128            ));
129        }
130        if self.n_subquantizers == 0 {
131            return Err(anyhow!("n_subquantizers must be greater than 0"));
132        }
133        if self.n_bits == 0 || self.n_bits > 16 {
134            return Err(anyhow!("n_bits must be between 1 and 16"));
135        }
136        if self.enable_residual_quantization && self.residual_levels == 0 {
137            return Err(anyhow!(
138                "residual_levels must be greater than 0 when residual quantization is enabled"
139            ));
140        }
141        if self.enable_multi_codebook && self.num_codebooks < 2 {
142            return Err(anyhow!(
143                "num_codebooks must be at least 2 when multi-codebook quantization is enabled"
144            ));
145        }
146        Ok(())
147    }
148}
149
/// Quantizer responsible for one contiguous slice of the vector dimensions.
#[derive(Debug, Clone)]
struct SubQuantizer {
    /// First dimension covered by this quantizer (inclusive).
    start_dim: usize,
    /// One past the last dimension covered by this quantizer (exclusive).
    end_dim: usize,
    /// Codebook: one centroid per code value.
    centroids: Vec<Vec<f32>>,
}
160
161impl SubQuantizer {
162    fn new(start_dim: usize, end_dim: usize, n_centroids: usize) -> Self {
163        Self {
164            start_dim,
165            end_dim,
166            centroids: Vec::with_capacity(n_centroids),
167        }
168    }
169
170    /// Extract subvector from full vector
171    fn extract_subvector(&self, vector: &[f32]) -> Vec<f32> {
172        vector[self.start_dim..self.end_dim].to_vec()
173    }
174
175    /// Train this subquantizer on subvectors
176    fn train(&mut self, subvectors: &[Vec<f32>], config: &PQConfig) -> Result<()> {
177        if subvectors.is_empty() {
178            return Err(anyhow!("Cannot train subquantizer with empty data"));
179        }
180
181        let dims = subvectors[0].len();
182
183        // Initialize centroids using k-means++
184        self.centroids = self.initialize_centroids_kmeans_plus_plus(subvectors, config)?;
185
186        // Run k-means
187        let mut iteration = 0;
188        let mut prev_error = f32::INFINITY;
189
190        while iteration < config.max_iterations {
191            // Assign points to nearest centroids
192            let mut clusters: Vec<Vec<&Vec<f32>>> = vec![Vec::new(); config.n_centroids];
193
194            for subvector in subvectors {
195                let nearest_idx = self.find_nearest_centroid(subvector)?;
196                clusters[nearest_idx].push(subvector);
197            }
198
199            // Update centroids
200            let mut total_error = 0.0;
201            for (i, cluster) in clusters.iter().enumerate() {
202                if !cluster.is_empty() {
203                    let new_centroid = self.compute_centroid(cluster, dims);
204                    total_error += self.euclidean_distance(&self.centroids[i], &new_centroid);
205                    self.centroids[i] = new_centroid;
206                }
207            }
208
209            // Check convergence
210            if (prev_error - total_error).abs() < config.convergence_threshold {
211                break;
212            }
213
214            prev_error = total_error;
215            iteration += 1;
216        }
217
218        Ok(())
219    }
220
221    /// Initialize centroids using k-means++
222    fn initialize_centroids_kmeans_plus_plus(
223        &self,
224        subvectors: &[Vec<f32>],
225        config: &PQConfig,
226    ) -> Result<Vec<Vec<f32>>> {
227        use std::collections::hash_map::DefaultHasher;
228        use std::hash::{Hash, Hasher};
229
230        let mut hasher = DefaultHasher::new();
231        config.seed.unwrap_or(42).hash(&mut hasher);
232        let mut rng_state = hasher.finish();
233
234        let mut centroids = Vec::with_capacity(config.n_centroids);
235
236        // Choose first centroid randomly
237        let first_idx = (rng_state as usize) % subvectors.len();
238        centroids.push(subvectors[first_idx].clone());
239
240        // Choose remaining centroids
241        while centroids.len() < config.n_centroids {
242            let mut distances = Vec::with_capacity(subvectors.len());
243            let mut sum_distances = 0.0;
244
245            // Calculate distance to nearest centroid for each point
246            for subvector in subvectors {
247                let min_dist = centroids
248                    .iter()
249                    .map(|c| self.euclidean_distance(subvector, c))
250                    .fold(f32::INFINITY, |a, b| a.min(b));
251
252                distances.push(min_dist * min_dist);
253                sum_distances += min_dist * min_dist;
254            }
255
256            // Choose next centroid
257            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
258            let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
259
260            let mut cumulative = 0.0;
261            for (i, &dist) in distances.iter().enumerate() {
262                cumulative += dist;
263                if cumulative >= threshold {
264                    centroids.push(subvectors[i].clone());
265                    break;
266                }
267            }
268        }
269
270        Ok(centroids)
271    }
272
273    /// Compute centroid of a cluster
274    fn compute_centroid(&self, cluster: &[&Vec<f32>], dims: usize) -> Vec<f32> {
275        if cluster.is_empty() {
276            return vec![0.0; dims];
277        }
278
279        let mut sum = vec![0.0; dims];
280        for vector in cluster {
281            for (i, &val) in vector.iter().enumerate() {
282                sum[i] += val;
283            }
284        }
285
286        let count = cluster.len() as f32;
287        for val in &mut sum {
288            *val /= count;
289        }
290
291        sum
292    }
293
294    /// Find nearest centroid for a subvector
295    fn find_nearest_centroid(&self, subvector: &[f32]) -> Result<usize> {
296        if self.centroids.is_empty() {
297            return Err(anyhow!("No centroids available"));
298        }
299
300        let mut min_distance = f32::INFINITY;
301        let mut nearest_idx = 0;
302
303        for (i, centroid) in self.centroids.iter().enumerate() {
304            let distance = self.euclidean_distance(subvector, centroid);
305            if distance < min_distance {
306                min_distance = distance;
307                nearest_idx = i;
308            }
309        }
310
311        Ok(nearest_idx)
312    }
313
314    /// Compute Euclidean distance between two vectors
315    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
316        a.iter()
317            .zip(b.iter())
318            .map(|(x, y)| (x - y).powi(2))
319            .sum::<f32>()
320            .sqrt()
321    }
322
323    /// Encode a subvector to its nearest centroid index
324    fn encode(&self, subvector: &[f32]) -> Result<u8> {
325        if self.centroids.len() > 256 {
326            return Err(anyhow!("Too many centroids for u8 encoding"));
327        }
328
329        let idx = self.find_nearest_centroid(subvector)?;
330        Ok(idx as u8)
331    }
332
333    /// Decode a centroid index back to a subvector
334    fn decode(&self, code: u8) -> Result<Vec<f32>> {
335        let idx = code as usize;
336        if idx >= self.centroids.len() {
337            return Err(anyhow!("Invalid code: {}", code));
338        }
339        Ok(self.centroids[idx].clone())
340    }
341}
342
/// Bundle of all codes produced for one vector by the enhanced PQ encoder.
#[derive(Debug, Clone)]
pub struct EnhancedCodes {
    /// Codes from the primary subquantizers.
    pub primary: Vec<u8>,
    /// Codes from the residual quantizers, one entry per residual level.
    pub residual: Vec<Vec<u8>>,
    /// Codes from the multi-codebook quantizers, one entry per codebook.
    pub multi_codebook: Vec<Vec<u8>>,
}
353
354/// Enhanced Product Quantization index with residual and multi-codebook support
355#[derive(Debug, Clone)]
356pub struct PQIndex {
357    config: PQConfig,
358    /// Primary subquantizers
359    subquantizers: Vec<SubQuantizer>,
360    /// Residual quantizers (for each level)
361    residual_quantizers: Vec<Vec<SubQuantizer>>,
362    /// Multi-codebook quantizers
363    multi_codebook_quantizers: Vec<Vec<SubQuantizer>>,
364    /// Encoded vectors (primary codes)
365    codes: Vec<(String, Vec<u8>)>,
366    /// Residual codes (for each level)
367    residual_codes: Vec<Vec<(String, Vec<u8>)>>,
368    /// Multi-codebook codes
369    multi_codebook_codes: Vec<Vec<(String, Vec<u8>)>>,
370    /// Distance lookup tables for symmetric distance computation
371    distance_tables: Option<Vec<Vec<Vec<f32>>>>,
372    /// URI to index mapping
373    uri_to_id: HashMap<String, usize>,
374    /// Vector dimensions
375    dimensions: Option<usize>,
376    /// Whether the index has been trained
377    is_trained: bool,
378}
379
380impl PQIndex {
381    /// Create a new PQ index
382    pub fn new(config: PQConfig) -> Self {
383        Self {
384            residual_quantizers: vec![Vec::new(); config.residual_levels],
385            multi_codebook_quantizers: vec![Vec::new(); config.num_codebooks],
386            residual_codes: vec![Vec::new(); config.residual_levels],
387            multi_codebook_codes: vec![Vec::new(); config.num_codebooks],
388            distance_tables: None,
389            config,
390            subquantizers: Vec::new(),
391            codes: Vec::new(),
392            uri_to_id: HashMap::new(),
393            dimensions: None,
394            is_trained: false,
395        }
396    }
397
398    /// Train the PQ index with training vectors
399    pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
400        if training_vectors.is_empty() {
401            return Err(anyhow!("Cannot train PQ with empty training set"));
402        }
403
404        // Validate dimensions
405        let dims = training_vectors[0].dimensions;
406        if !training_vectors.iter().all(|v| v.dimensions == dims) {
407            return Err(anyhow!(
408                "All training vectors must have the same dimensions"
409            ));
410        }
411
412        if dims % self.config.n_subquantizers != 0 {
413            return Err(anyhow!(
414                "Vector dimensions {} must be divisible by n_subquantizers {}",
415                dims,
416                self.config.n_subquantizers
417            ));
418        }
419
420        self.dimensions = Some(dims);
421        let subdim = dims / self.config.n_subquantizers;
422
423        // Initialize subquantizers
424        self.subquantizers.clear();
425        for i in 0..self.config.n_subquantizers {
426            let start = i * subdim;
427            let end = start + subdim;
428            self.subquantizers
429                .push(SubQuantizer::new(start, end, self.config.n_centroids));
430        }
431
432        // Extract training data as f32
433        let training_data: Vec<Vec<f32>> = training_vectors.iter().map(|v| v.as_f32()).collect();
434
435        // Train each subquantizer
436        for sq in self.subquantizers.iter_mut() {
437            // Extract subvectors for this subquantizer
438            let subvectors: Vec<Vec<f32>> = training_data
439                .iter()
440                .map(|v| sq.extract_subvector(v))
441                .collect();
442
443            sq.train(&subvectors, &self.config)?;
444        }
445
446        // Train residual quantizers if enabled
447        if self.config.enable_residual_quantization {
448            self.train_residual_quantizers(&training_data)?;
449        }
450
451        // Train multi-codebook quantizers if enabled
452        if self.config.enable_multi_codebook {
453            self.train_multi_codebook_quantizers(&training_data)?;
454        }
455
456        // Build distance tables for symmetric distance computation if enabled
457        if self.config.enable_symmetric_distance {
458            self.build_distance_tables()?;
459        }
460
461        self.is_trained = true;
462        Ok(())
463    }
464
465    /// Train residual quantizers for improved accuracy
466    fn train_residual_quantizers(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
467        let subdim = self
468            .dimensions
469            .expect("dimensions must be set after training")
470            / self.config.n_subquantizers;
471
472        // Start with residuals from the primary quantizers
473        let mut current_residuals = training_data.to_vec();
474
475        for level in 0..self.config.residual_levels {
476            // Compute residuals from previous level
477            if level == 0 {
478                // Compute residuals from primary quantizers
479                for (i, vector) in training_data.iter().enumerate() {
480                    let primary_codes = self.encode_primary_vector(vector)?;
481                    let reconstructed = self.decode_primary_codes(&primary_codes)?;
482
483                    // Compute residual
484                    let residual: Vec<f32> = vector
485                        .iter()
486                        .zip(reconstructed.iter())
487                        .map(|(a, b)| a - b)
488                        .collect();
489                    current_residuals[i] = residual;
490                }
491            } else {
492                // Compute residuals from previous residual level
493                for (i, residual) in current_residuals.clone().iter().enumerate() {
494                    let residual_codes = self.encode_residual_vector(residual, level - 1)?;
495                    let reconstructed_residual =
496                        self.decode_residual_codes(&residual_codes, level - 1)?;
497
498                    let new_residual: Vec<f32> = residual
499                        .iter()
500                        .zip(reconstructed_residual.iter())
501                        .map(|(a, b)| a - b)
502                        .collect();
503                    current_residuals[i] = new_residual;
504                }
505            }
506
507            // Initialize residual subquantizers for this level
508            self.residual_quantizers[level].clear();
509            for i in 0..self.config.n_subquantizers {
510                let start = i * subdim;
511                let end = start + subdim;
512                self.residual_quantizers[level].push(SubQuantizer::new(
513                    start,
514                    end,
515                    self.config.n_centroids,
516                ));
517            }
518
519            // Train each residual subquantizer
520            for sq in self.residual_quantizers[level].iter_mut() {
521                let subvectors: Vec<Vec<f32>> = current_residuals
522                    .iter()
523                    .map(|v| sq.extract_subvector(v))
524                    .collect();
525
526                sq.train(&subvectors, &self.config)?;
527            }
528        }
529
530        Ok(())
531    }
532
533    /// Train multi-codebook quantizers for better coverage
534    fn train_multi_codebook_quantizers(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
535        let subdim = self
536            .dimensions
537            .expect("dimensions must be set after training")
538            / self.config.n_subquantizers;
539
540        for codebook_idx in 0..self.config.num_codebooks {
541            // Initialize subquantizers for this codebook
542            self.multi_codebook_quantizers[codebook_idx].clear();
543            for i in 0..self.config.n_subquantizers {
544                let start = i * subdim;
545                let end = start + subdim;
546                self.multi_codebook_quantizers[codebook_idx].push(SubQuantizer::new(
547                    start,
548                    end,
549                    self.config.n_centroids,
550                ));
551            }
552
553            // Use different initialization for each codebook
554            let mut modified_config = self.config.clone();
555            modified_config.seed = self.config.seed.map(|s| s + codebook_idx as u64);
556
557            // Train each subquantizer in this codebook
558            for sq in self.multi_codebook_quantizers[codebook_idx].iter_mut() {
559                let subvectors: Vec<Vec<f32>> = training_data
560                    .iter()
561                    .map(|v| sq.extract_subvector(v))
562                    .collect();
563
564                sq.train(&subvectors, &modified_config)?;
565            }
566        }
567
568        Ok(())
569    }
570
571    /// Build distance lookup tables for symmetric distance computation
572    fn build_distance_tables(&mut self) -> Result<()> {
573        let mut tables = Vec::new();
574
575        for sq_idx in 0..self.config.n_subquantizers {
576            let sq = &self.subquantizers[sq_idx];
577            let mut sq_table = Vec::new();
578
579            // Build distance table between all pairs of centroids
580            for i in 0..sq.centroids.len() {
581                let mut centroid_distances = Vec::new();
582                for j in 0..sq.centroids.len() {
583                    let distance = sq.euclidean_distance(&sq.centroids[i], &sq.centroids[j]);
584                    centroid_distances.push(distance);
585                }
586                sq_table.push(centroid_distances);
587            }
588            tables.push(sq_table);
589        }
590
591        self.distance_tables = Some(tables);
592        Ok(())
593    }
594
595    /// Helper method to encode with primary quantizers only
596    fn encode_primary_vector(&self, vector: &[f32]) -> Result<Vec<u8>> {
597        let mut codes = Vec::with_capacity(self.subquantizers.len());
598
599        for sq in &self.subquantizers {
600            let subvec = sq.extract_subvector(vector);
601            let code = sq.encode(&subvec)?;
602            codes.push(code);
603        }
604
605        Ok(codes)
606    }
607
608    /// Helper method to decode primary codes
609    fn decode_primary_codes(&self, codes: &[u8]) -> Result<Vec<f32>> {
610        let mut reconstructed = Vec::new();
611
612        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
613            let subvec = sq.decode(code)?;
614            reconstructed.extend(subvec);
615        }
616
617        Ok(reconstructed)
618    }
619
620    /// Helper method to encode with residual quantizers
621    fn encode_residual_vector(&self, vector: &[f32], level: usize) -> Result<Vec<u8>> {
622        if level >= self.residual_quantizers.len() {
623            return Err(anyhow!("Invalid residual level: {}", level));
624        }
625
626        let mut codes = Vec::with_capacity(self.residual_quantizers[level].len());
627
628        for sq in &self.residual_quantizers[level] {
629            let subvec = sq.extract_subvector(vector);
630            let code = sq.encode(&subvec)?;
631            codes.push(code);
632        }
633
634        Ok(codes)
635    }
636
637    /// Helper method to decode residual codes
638    fn decode_residual_codes(&self, codes: &[u8], level: usize) -> Result<Vec<f32>> {
639        if level >= self.residual_quantizers.len() {
640            return Err(anyhow!("Invalid residual level: {}", level));
641        }
642
643        let mut reconstructed = Vec::new();
644
645        for (sq, &code) in self.residual_quantizers[level].iter().zip(codes.iter()) {
646            let subvec = sq.decode(code)?;
647            reconstructed.extend(subvec);
648        }
649
650        Ok(reconstructed)
651    }
652
653    /// Encode a vector into PQ codes
654    fn encode_vector(&self, vector: &Vector) -> Result<Vec<u8>> {
655        if !self.is_trained {
656            return Err(anyhow!("PQ index must be trained before encoding"));
657        }
658
659        let vector_f32 = vector.as_f32();
660        let mut codes = Vec::with_capacity(self.subquantizers.len());
661
662        for sq in &self.subquantizers {
663            let subvec = sq.extract_subvector(&vector_f32);
664            let code = sq.encode(&subvec)?;
665            codes.push(code);
666        }
667
668        Ok(codes)
669    }
670
671    /// Decode PQ codes back to an approximate vector
672    fn decode_codes(&self, codes: &[u8]) -> Result<Vector> {
673        if codes.len() != self.subquantizers.len() {
674            return Err(anyhow!("Invalid code length"));
675        }
676
677        let mut reconstructed = Vec::new();
678
679        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
680            let subvec = sq.decode(code)?;
681            reconstructed.extend(subvec);
682        }
683
684        Ok(Vector::new(reconstructed))
685    }
686
687    /// Public method to encode a vector (for OPQ)
688    pub fn encode(&self, vector: &Vector) -> Result<Vec<u8>> {
689        self.encode_vector(vector)
690    }
691
692    /// Public method to decode codes (for OPQ)
693    pub fn decode(&self, codes: &[u8]) -> Result<Vector> {
694        self.decode_codes(codes)
695    }
696
697    /// Reconstruct a vector by encoding and then decoding (for OPQ)
698    pub fn reconstruct(&self, vector: &Vector) -> Result<Vector> {
699        let codes = self.encode_vector(vector)?;
700        self.decode_codes(&codes)
701    }
702
703    /// Compute asymmetric distance between a query vector and PQ codes
704    fn asymmetric_distance(&self, query: &Vector, codes: &[u8]) -> Result<f32> {
705        let query_f32 = query.as_f32();
706        let mut total_distance = 0.0;
707
708        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
709            let query_subvec = sq.extract_subvector(&query_f32);
710            let centroid = &sq.centroids[code as usize];
711
712            // Compute squared distance to avoid sqrt
713            let dist: f32 = query_subvec
714                .iter()
715                .zip(centroid.iter())
716                .map(|(a, b)| (a - b).powi(2))
717                .sum();
718
719            total_distance += dist;
720        }
721
722        Ok(total_distance.sqrt())
723    }
724
725    /// Enhanced encoding with residual and multi-codebook support
726    fn encode_vector_enhanced(&self, vector: &Vector) -> Result<EnhancedCodes> {
727        if !self.is_trained {
728            return Err(anyhow!("PQ index must be trained before encoding"));
729        }
730
731        let vector_f32 = vector.as_f32();
732
733        // Primary encoding
734        let primary_codes = self.encode_primary_vector(&vector_f32)?;
735
736        // Residual encoding if enabled
737        let mut residual_codes = Vec::new();
738        if self.config.enable_residual_quantization {
739            let mut current_residual = vector_f32.clone();
740
741            // Compute residual from primary quantization
742            let primary_reconstructed = self.decode_primary_codes(&primary_codes)?;
743            current_residual = current_residual
744                .iter()
745                .zip(primary_reconstructed.iter())
746                .map(|(a, b)| a - b)
747                .collect();
748
749            // Encode residuals at each level
750            for level in 0..self.config.residual_levels {
751                let level_codes = self.encode_residual_vector(&current_residual, level)?;
752                residual_codes.push(level_codes.clone());
753
754                // Update residual for next level
755                if level < self.config.residual_levels - 1 {
756                    let level_reconstructed = self.decode_residual_codes(&level_codes, level)?;
757                    current_residual = current_residual
758                        .iter()
759                        .zip(level_reconstructed.iter())
760                        .map(|(a, b)| a - b)
761                        .collect();
762                }
763            }
764        }
765
766        // Multi-codebook encoding if enabled
767        let mut multi_codebook_codes = Vec::new();
768        if self.config.enable_multi_codebook {
769            for codebook_idx in 0..self.config.num_codebooks {
770                let mut codes =
771                    Vec::with_capacity(self.multi_codebook_quantizers[codebook_idx].len());
772
773                for sq in &self.multi_codebook_quantizers[codebook_idx] {
774                    let subvec = sq.extract_subvector(&vector_f32);
775                    let code = sq.encode(&subvec)?;
776                    codes.push(code);
777                }
778                multi_codebook_codes.push(codes);
779            }
780        }
781
782        Ok(EnhancedCodes {
783            primary: primary_codes,
784            residual: residual_codes,
785            multi_codebook: multi_codebook_codes,
786        })
787    }
788
789    /// Symmetric distance computation between two sets of codes
790    fn symmetric_distance(&self, codes1: &[u8], codes2: &[u8]) -> Result<f32> {
791        if !self.config.enable_symmetric_distance {
792            return Err(anyhow!("Symmetric distance computation not enabled"));
793        }
794
795        let distance_tables = self
796            .distance_tables
797            .as_ref()
798            .ok_or_else(|| anyhow!("Distance tables not built"))?;
799
800        if codes1.len() != codes2.len() || codes1.len() != self.config.n_subquantizers {
801            return Err(anyhow!("Invalid code lengths for symmetric distance"));
802        }
803
804        let mut total_distance = 0.0;
805
806        for (sq_idx, (&code1, &code2)) in codes1.iter().zip(codes2.iter()).enumerate() {
807            let distance = distance_tables[sq_idx][code1 as usize][code2 as usize];
808            total_distance += distance * distance; // Squared distance
809        }
810
811        Ok(total_distance.sqrt())
812    }
813
814    /// Enhanced distance computation with residual and multi-codebook support
815    fn enhanced_distance(&self, query: &Vector, enhanced_codes: &EnhancedCodes) -> Result<f32> {
816        // Start with primary distance
817        let mut total_distance = self.asymmetric_distance(query, &enhanced_codes.primary)?;
818
819        // Add residual distances if enabled
820        if self.config.enable_residual_quantization && !enhanced_codes.residual.is_empty() {
821            let query_f32 = query.as_f32();
822            let mut current_residual = query_f32.clone();
823
824            // Compute residual from primary quantization
825            let primary_reconstructed = self.decode_primary_codes(&enhanced_codes.primary)?;
826            current_residual = current_residual
827                .iter()
828                .zip(primary_reconstructed.iter())
829                .map(|(a, b)| a - b)
830                .collect();
831
832            // Add distance from each residual level
833            for (level, residual_codes) in enhanced_codes.residual.iter().enumerate() {
834                let mut residual_distance = 0.0;
835
836                for (sq, &code) in self.residual_quantizers[level]
837                    .iter()
838                    .zip(residual_codes.iter())
839                {
840                    let query_subvec = sq.extract_subvector(&current_residual);
841                    let centroid = &sq.centroids[code as usize];
842
843                    let dist: f32 = query_subvec
844                        .iter()
845                        .zip(centroid.iter())
846                        .map(|(a, b)| (a - b).powi(2))
847                        .sum();
848
849                    residual_distance += dist;
850                }
851
852                total_distance += residual_distance.sqrt() * 0.5; // Weight residual distances
853
854                // Update residual for next level
855                if level < enhanced_codes.residual.len() - 1 {
856                    let level_reconstructed = self.decode_residual_codes(residual_codes, level)?;
857                    current_residual = current_residual
858                        .iter()
859                        .zip(level_reconstructed.iter())
860                        .map(|(a, b)| a - b)
861                        .collect();
862                }
863            }
864        }
865
866        // For multi-codebook, use the minimum distance across codebooks
867        if self.config.enable_multi_codebook && !enhanced_codes.multi_codebook.is_empty() {
868            let mut min_codebook_distance = f32::INFINITY;
869
870            for codes in &enhanced_codes.multi_codebook {
871                let codebook_distance = self.asymmetric_distance(query, codes)?;
872                min_codebook_distance = min_codebook_distance.min(codebook_distance);
873            }
874
875            // Use the minimum as a refinement
876            total_distance = total_distance.min(min_codebook_distance);
877        }
878
879        Ok(total_distance)
880    }
881
882    /// Get compression ratio
883    pub fn compression_ratio(&self) -> f32 {
884        if let Some(dims) = self.dimensions {
885            // Original: dims * 4 bytes (f32)
886            // Compressed: n_subquantizers bytes
887            (dims as f32 * 4.0) / (self.config.n_subquantizers as f32)
888        } else {
889            0.0
890        }
891    }
892
893    /// Get index statistics
894    pub fn stats(&self) -> PQStats {
895        PQStats {
896            n_vectors: self.codes.len(),
897            n_subquantizers: self.config.n_subquantizers,
898            n_centroids: self.config.n_centroids,
899            is_trained: self.is_trained,
900            dimensions: self.dimensions,
901            compression_ratio: self.compression_ratio(),
902            memory_usage_bytes: self.estimate_memory_usage(),
903        }
904    }
905
906    /// Estimate memory usage in bytes
907    fn estimate_memory_usage(&self) -> usize {
908        let codebook_size = self
909            .subquantizers
910            .iter()
911            .map(|sq| sq.centroids.len() * (sq.end_dim - sq.start_dim) * 4)
912            .sum::<usize>();
913
914        let codes_size = self.codes.len() * self.config.n_subquantizers;
915
916        codebook_size + codes_size
917    }
918
919    /// Check if the index is trained
920    pub fn is_trained(&self) -> bool {
921        self.is_trained
922    }
923
924    /// Compute distance between query and encoded vector (for IVF compatibility)
925    pub fn compute_distance(&self, query: &Vector, codes: &[u8]) -> Result<f32> {
926        self.asymmetric_distance(query, codes)
927    }
928
929    /// Decode codes to vector (for IVF compatibility)
930    pub fn decode_vector(&self, codes: &[u8]) -> Result<Vector> {
931        self.decode_codes(codes)
932    }
933}
934
935impl VectorIndex for PQIndex {
936    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
937        if !self.is_trained {
938            return Err(anyhow!("PQ index must be trained before inserting vectors"));
939        }
940
941        // Validate dimensions
942        if let Some(dims) = self.dimensions {
943            if vector.dimensions != dims {
944                return Err(anyhow!(
945                    "Vector dimensions {} don't match index dimensions {}",
946                    vector.dimensions,
947                    dims
948                ));
949            }
950        }
951
952        // Encode the vector
953        let codes = self.encode_vector(&vector)?;
954
955        // Store the codes
956        let id = self.codes.len();
957        self.uri_to_id.insert(uri.clone(), id);
958        self.codes.push((uri, codes));
959
960        Ok(())
961    }
962
963    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
964        if !self.is_trained {
965            return Err(anyhow!("PQ index must be trained before searching"));
966        }
967
968        // Compute distances to all vectors
969        let mut distances: Vec<(String, f32)> = self
970            .codes
971            .iter()
972            .map(|(uri, codes)| {
973                let dist = self
974                    .asymmetric_distance(query, codes)
975                    .unwrap_or(f32::INFINITY);
976                (uri.clone(), dist)
977            })
978            .collect();
979
980        // Sort by distance
981        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
982        distances.truncate(k);
983
984        // Convert distances to similarities
985        Ok(distances
986            .into_iter()
987            .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
988            .collect())
989    }
990
991    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
992        if !self.is_trained {
993            return Err(anyhow!("PQ index must be trained before searching"));
994        }
995
996        let mut results = Vec::new();
997
998        for (uri, codes) in &self.codes {
999            let dist = self.asymmetric_distance(query, codes)?;
1000            let similarity = 1.0 / (1.0 + dist);
1001
1002            if similarity >= threshold {
1003                results.push((uri.clone(), similarity));
1004            }
1005        }
1006
1007        // Sort by similarity
1008        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1009
1010        Ok(results)
1011    }
1012
1013    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
1014        // PQ doesn't store original vectors, only codes
1015        // Would need to decode, but that returns an approximation
1016        None
1017    }
1018}
1019
1020impl PQIndex {
1021    /// Public search method for use by OPQ and other modules
1022    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1023        self.search_knn(query, k)
1024    }
1025}
1026
/// Statistics for PQ index
///
/// A plain snapshot produced by `PQIndex::stats`; the values are copied at call
/// time and do not track later changes to the index.
#[derive(Debug, Clone)]
pub struct PQStats {
    /// Number of encoded entries currently stored in the index.
    pub n_vectors: usize,
    /// Configured number of subquantizers.
    pub n_subquantizers: usize,
    /// Configured number of centroids per subquantizer.
    pub n_centroids: usize,
    /// Value of the index's trained flag when the snapshot was taken.
    pub is_trained: bool,
    /// Dimensionality recorded by the index, if any.
    pub dimensions: Option<usize>,
    /// Value of `PQIndex::compression_ratio` (`0.0` when `dimensions` is `None`).
    pub compression_ratio: f32,
    /// Estimated memory footprint in bytes.
    pub memory_usage_bytes: usize,
}
1038
#[cfg(test)]
mod tests {
    use super::*;

    /// Six 4-dimensional vectors used as the training set by several tests.
    fn sample_vectors_4d() -> Vec<Vector> {
        let rows: [[f32; 4]; 6] = [
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 1.0, 0.0],
            [-1.0, 0.0, 0.0, -1.0],
            [0.0, -1.0, -1.0, 0.0],
            [0.5, 0.5, 0.5, 0.5],
            [-0.5, -0.5, -0.5, -0.5],
        ];
        rows.iter().map(|row| Vector::new(row.to_vec())).collect()
    }

    /// Train, insert and query a tiny index end to end.
    #[test]
    fn test_pq_basic() {
        let mut index = PQIndex::new(PQConfig {
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        });
        let samples = sample_vectors_4d();

        index.train(&samples).unwrap();
        assert!(index.is_trained);

        for (i, sample) in samples.iter().enumerate() {
            index.insert(format!("vec{i}"), sample.clone()).unwrap();
        }

        let query = Vector::new(vec![0.9, 0.1, 0.1, 0.9]);
        let neighbours = index.search_knn(&query, 3).unwrap();

        assert!(!neighbours.is_empty());
        assert!(neighbours.len() <= 3);
    }

    /// 128 dimensions with 4 subquantizers: 512 raw bytes shrink to 4 code bytes.
    #[test]
    fn test_pq_compression() {
        let mut index = PQIndex::new(PQConfig {
            n_subquantizers: 4,
            n_centroids: 16,
            ..Default::default()
        });

        let dims = 128;
        let mut samples: Vec<Vector> = Vec::new();
        for i in 0..100 {
            let values: Vec<f32> = (0..dims).map(|j| ((i + j) as f32).sin()).collect();
            samples.push(Vector::new(values));
        }

        index.train(&samples).unwrap();

        assert_eq!(index.compression_ratio(), 128.0);

        let stats = index.stats();
        assert_eq!(stats.n_subquantizers, 4);
        assert_eq!(stats.n_centroids, 16);
        assert_eq!(stats.dimensions, Some(128));
    }

    /// An encode/decode round trip stays close to the input vector.
    #[test]
    fn test_pq_reconstruction() {
        let mut index = PQIndex::new(PQConfig {
            n_subquantizers: 2,
            n_centroids: 8,
            ..Default::default()
        });

        let samples = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, 0.0]),
            Vector::new(vec![0.0, -1.0]),
        ];
        index.train(&samples).unwrap();

        let original = Vector::new(vec![0.7, 0.7]);
        let codes = index.encode_vector(&original).unwrap();
        let reconstructed = index.decode_codes(&codes).unwrap();

        // Quantization is lossy, so only require a bounded error.
        let error = original.euclidean_distance(&reconstructed).unwrap();
        assert!(error < 1.0);
    }

    /// Residual quantization trains one quantizer set per level and emits codes for each.
    #[test]
    fn test_pq_residual_quantization() {
        // 2 subquantizers, 3 bits, 2 residual levels
        let mut index = PQIndex::new(PQConfig::with_residual_quantization(2, 3, 2));
        let samples = sample_vectors_4d();

        index.train(&samples).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.residual_quantizers.len(), 2);

        let probe = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let encoded = index.encode_vector_enhanced(&probe).unwrap();

        assert!(!encoded.primary.is_empty());
        assert_eq!(encoded.residual.len(), 2);
        // Multi-codebook is not enabled in this configuration.
        assert!(encoded.multi_codebook.is_empty());
    }

    /// Multi-codebook quantization trains and emits one code set per codebook.
    #[test]
    fn test_pq_multi_codebook() {
        // 2 subquantizers, 3 bits, 3 codebooks
        let mut index = PQIndex::new(PQConfig::with_multi_codebook(2, 3, 3));
        let samples = sample_vectors_4d();

        index.train(&samples).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.multi_codebook_quantizers.len(), 3);

        let probe = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let encoded = index.encode_vector_enhanced(&probe).unwrap();

        assert!(!encoded.primary.is_empty());
        // Residual quantization is not enabled in this configuration.
        assert!(encoded.residual.is_empty());
        assert_eq!(encoded.multi_codebook.len(), 3);
    }

    /// Symmetric distance between two code vectors is finite and non-negative.
    #[test]
    fn test_pq_symmetric_distance() {
        let mut index = PQIndex::new(PQConfig {
            enable_symmetric_distance: true,
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        });

        // This test trains on the first four sample vectors only.
        let mut samples = sample_vectors_4d();
        samples.truncate(4);

        index.train(&samples).unwrap();
        assert!(index.distance_tables.is_some());

        let codes1 = vec![0, 1];
        let codes2 = vec![1, 0];
        let distance = index.symmetric_distance(&codes1, &codes2).unwrap();

        assert!(distance >= 0.0);
        assert!(distance.is_finite());
    }

    /// With every enhancement enabled, training populates all auxiliary structures
    /// and the enhanced distance does not exceed the basic one by more than 10%.
    #[test]
    fn test_pq_enhanced_features() {
        let mut index = PQIndex::new(PQConfig::enhanced(2, 3));
        let samples = sample_vectors_4d();

        index.train(&samples).unwrap();
        assert!(index.is_trained());

        assert!(!index.residual_quantizers.is_empty());
        assert!(!index.multi_codebook_quantizers.is_empty());
        assert!(index.distance_tables.is_some());

        let probe = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let encoded = index.encode_vector_enhanced(&probe).unwrap();
        let enhanced_distance = index.enhanced_distance(&probe, &encoded).unwrap();

        assert!(enhanced_distance >= 0.0);
        assert!(enhanced_distance.is_finite());

        let basic_distance = index
            .asymmetric_distance(&probe, &encoded.primary)
            .unwrap();
        // Allow some tolerance over the basic distance.
        assert!(enhanced_distance <= basic_distance * 1.1);
    }

    /// `validate` accepts the enhanced preset and rejects degenerate settings.
    #[test]
    fn test_pq_config_validation() {
        let valid = PQConfig::enhanced(4, 8);
        assert!(valid.validate().is_ok());

        // Residual quantization enabled with zero levels.
        let no_residual_levels = PQConfig {
            enable_residual_quantization: true,
            residual_levels: 0,
            ..Default::default()
        };
        assert!(no_residual_levels.validate().is_err());

        // Multi-codebook enabled with a single codebook.
        let single_codebook = PQConfig {
            enable_multi_codebook: true,
            num_codebooks: 1,
            ..Default::default()
        };
        assert!(single_codebook.validate().is_err());
    }
}