ipfrs_semantic/
quantization.rs

//! Vector quantization for memory-efficient storage
//!
//! This module provides various quantization methods to compress
//! high-dimensional vectors while preserving similarity search accuracy.
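//!
//! A minimal usage sketch (illustrative only; `compress` is a placeholder function and
//! the module path `ipfrs_semantic::quantization` is assumed):
//!
//! ```ignore
//! use ipfrs_semantic::quantization::ScalarQuantizer;
//!
//! fn compress(vectors: &[Vec<f32>]) -> ipfrs_core::Result<()> {
//!     // Learn per-dimension ranges, then compress one vector to 8-bit codes.
//!     let mut sq = ScalarQuantizer::uint8(vectors[0].len());
//!     sq.train(vectors)?;
//!     let q = sq.quantize(&vectors[0])?;
//!     let approx = sq.dequantize(&q)?;
//!     assert_eq!(approx.len(), vectors[0].len());
//!     Ok(())
//! }
//! ```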

use ipfrs_core::{Error, Result};
use nalgebra::{DMatrix, DVector};
use serde::{Deserialize, Serialize};

/// Scalar quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizerConfig {
    /// Number of bits for quantization (8 for int8/uint8)
    pub bits: u8,
    /// Whether to use signed quantization (int8 vs uint8)
    pub signed: bool,
    /// Per-dimension min values
    pub min_values: Vec<f32>,
    /// Per-dimension max values
    pub max_values: Vec<f32>,
}

/// Scalar quantizer for vector compression
///
/// Quantizes floating-point vectors to int8 or uint8, achieving 4x compression
/// with typically < 5% accuracy loss.
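///
/// Hedged round-trip sketch (`ignore`d; exact tolerances depend on the data range):
///
/// ```ignore
/// let mut sq = ScalarQuantizer::uint8(4);
/// sq.train(&[vec![0.0, 0.5, 1.0, -0.5], vec![1.0, 0.0, 0.5, 0.5]])?;
/// let q = sq.quantize(&[0.5, 0.25, 0.75, 0.0])?;
/// let back = sq.dequantize(&q)?;
/// // Each restored value stays close to the original (small quantization error).
/// for (orig, rest) in [0.5, 0.25, 0.75, 0.0].iter().zip(back.iter()) {
///     assert!((orig - rest).abs() < 0.05);
/// }
/// ```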
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    /// Quantization configuration
    config: ScalarQuantizerConfig,
    /// Vector dimension
    dimension: usize,
    /// Whether the quantizer has been trained
    trained: bool,
}

impl ScalarQuantizer {
    /// Create a new scalar quantizer for the given dimension
    ///
    /// # Arguments
    /// * `dimension` - Vector dimension
    /// * `signed` - Use int8 (true) or uint8 (false)
    pub fn new(dimension: usize, signed: bool) -> Self {
        Self {
            config: ScalarQuantizerConfig {
                bits: 8,
                signed,
                min_values: vec![f32::MAX; dimension],
                max_values: vec![f32::MIN; dimension],
            },
            dimension,
            trained: false,
        }
    }

    /// Create a quantizer with uint8 (unsigned) quantization
    pub fn uint8(dimension: usize) -> Self {
        Self::new(dimension, false)
    }

    /// Create a quantizer with int8 (signed) quantization
    pub fn int8(dimension: usize) -> Self {
        Self::new(dimension, true)
    }

    /// Train the quantizer on a set of vectors
    ///
    /// Learns min/max values for each dimension to establish
    /// the quantization range.
    ///
    /// # Arguments
    /// * `vectors` - Training vectors
    pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<()> {
        if vectors.is_empty() {
            return Err(Error::InvalidInput(
                "Cannot train on empty vector set".to_string(),
            ));
        }

        // Validate dimensions
        for (i, vec) in vectors.iter().enumerate() {
            if vec.len() != self.dimension {
                return Err(Error::InvalidInput(format!(
                    "Vector {} has dimension {}, expected {}",
                    i,
                    vec.len(),
                    self.dimension
                )));
            }
        }

        // Reset min/max values
        self.config.min_values = vec![f32::MAX; self.dimension];
        self.config.max_values = vec![f32::MIN; self.dimension];

        // Compute per-dimension min/max
        for vec in vectors {
            for (i, &val) in vec.iter().enumerate() {
                if val < self.config.min_values[i] {
                    self.config.min_values[i] = val;
                }
                if val > self.config.max_values[i] {
                    self.config.max_values[i] = val;
                }
            }
        }

        // Add small margin to avoid edge cases
        for i in 0..self.dimension {
            let range = self.config.max_values[i] - self.config.min_values[i];
            if range < 1e-6 {
                // Constant dimension, set arbitrary range
                self.config.min_values[i] -= 0.5;
                self.config.max_values[i] += 0.5;
            } else {
                let margin = range * 0.01;
                self.config.min_values[i] -= margin;
                self.config.max_values[i] += margin;
            }
        }

        self.trained = true;
        Ok(())
    }

    /// Train on a single vector (incremental training)
    pub fn train_incremental(&mut self, vector: &[f32]) -> Result<()> {
        if vector.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector has dimension {}, expected {}",
                vector.len(),
                self.dimension
            )));
        }

        for (i, &val) in vector.iter().enumerate() {
            if val < self.config.min_values[i] {
                self.config.min_values[i] = val;
            }
            if val > self.config.max_values[i] {
                self.config.max_values[i] = val;
            }
        }

        self.trained = true;
        Ok(())
    }

    /// Check if the quantizer has been trained
    pub fn is_trained(&self) -> bool {
        self.trained
    }

    /// Quantize a vector to 8-bit codes (int8 or uint8, depending on configuration)
    pub fn quantize(&self, vector: &[f32]) -> Result<QuantizedVector> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Quantizer must be trained before use".to_string(),
            ));
        }

        if vector.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector has dimension {}, expected {}",
                vector.len(),
                self.dimension
            )));
        }

        let mut quantized = Vec::with_capacity(self.dimension);

        for (i, &val) in vector.iter().enumerate() {
            let min = self.config.min_values[i];
            let max = self.config.max_values[i];
            let range = max - min;

            // Normalize to [0, 1]
            let normalized = if range > 1e-6 {
                ((val - min) / range).clamp(0.0, 1.0)
            } else {
                0.5
            };

            // Quantize
            let q = if self.config.signed {
                // int8: [-128, 127]
                ((normalized * 255.0 - 128.0).round() as i8) as u8
            } else {
                // uint8: [0, 255]
                (normalized * 255.0).round() as u8
            };

            quantized.push(q);
        }

        Ok(QuantizedVector {
            data: quantized,
            signed: self.config.signed,
        })
    }

    /// Dequantize a vector back to f32
    pub fn dequantize(&self, quantized: &QuantizedVector) -> Result<Vec<f32>> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Quantizer must be trained before use".to_string(),
            ));
        }

        if quantized.data.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Quantized vector has dimension {}, expected {}",
                quantized.data.len(),
                self.dimension
            )));
        }

        let mut result = Vec::with_capacity(self.dimension);

        for (i, &q) in quantized.data.iter().enumerate() {
            let min = self.config.min_values[i];
            let max = self.config.max_values[i];
            let range = max - min;

            // Dequantize
            let normalized = if self.config.signed {
                // int8: [-128, 127] -> [0, 1]
                ((q as i8) as f32 + 128.0) / 255.0
            } else {
                // uint8: [0, 255] -> [0, 1]
                q as f32 / 255.0
            };

            let val = min + normalized * range;
            result.push(val);
        }

        Ok(result)
    }

    /// Compute L2 distance between two quantized vectors (approximate)
    ///
    /// Uses integer arithmetic for fast computation
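    ///
    /// Illustrative sketch (assumes a trained quantizer `sq` and two vectors `a`, `b`):
    ///
    /// ```ignore
    /// let qa = sq.quantize(&a)?;
    /// let qb = sq.quantize(&b)?;
    /// // Fast integer math; per-dimension scaling is folded into a single 1/255 factor,
    /// // so treat the result as a ranking signal rather than an exact L2 value.
    /// let approx = sq.distance_l2_quantized(&qa, &qb)?;
    /// ```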
    pub fn distance_l2_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> Result<f32> {
        if a.data.len() != b.data.len() {
            return Err(Error::InvalidInput(
                "Vectors must have same dimension".to_string(),
            ));
        }

        let mut sum_sq: i64 = 0;

        for (qa, qb) in a.data.iter().zip(b.data.iter()) {
            let diff = if a.signed {
                (*qa as i8 as i64) - (*qb as i8 as i64)
            } else {
                (*qa as i64) - (*qb as i64)
            };
            sum_sq += diff * diff;
        }

        // Scale back to approximate original distance
        // This is an approximation since we lose per-dimension scaling info
        Ok((sum_sq as f32).sqrt() / 255.0)
    }

    /// Compute dot product between two quantized vectors (approximate)
    pub fn dot_product_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> Result<f32> {
        if a.data.len() != b.data.len() {
            return Err(Error::InvalidInput(
                "Vectors must have same dimension".to_string(),
            ));
        }

        let mut sum: i64 = 0;

        for (qa, qb) in a.data.iter().zip(b.data.iter()) {
            if a.signed {
                sum += (*qa as i8 as i64) * (*qb as i8 as i64);
            } else {
                sum += (*qa as i64) * (*qb as i64);
            }
        }

        // Normalize
        Ok(sum as f32 / (255.0 * 255.0))
    }

    /// Get the dimension
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Get compression ratio (always 4x for 8-bit quantization)
    pub fn compression_ratio(&self) -> f32 {
        4.0 // f32 (4 bytes) -> u8 (1 byte)
    }

    /// Get memory usage estimate for a given number of vectors
    pub fn memory_estimate(&self, num_vectors: usize) -> usize {
        // Per-vector: dimension bytes
        // Plus overhead: 2 * dimension * 4 bytes for min/max values
        num_vectors * self.dimension + 2 * self.dimension * 4
    }
}

/// Quantized vector storage
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Quantized data (uint8 storage)
    pub data: Vec<u8>,
    /// Whether values are signed
    pub signed: bool,
}

impl QuantizedVector {
    /// Create a new quantized vector
    pub fn new(data: Vec<u8>, signed: bool) -> Self {
        Self { data, signed }
    }

    /// Get the dimension
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Get the raw bytes
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Get memory size in bytes
    pub fn size_bytes(&self) -> usize {
        self.data.len()
    }
}

/// Product Quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizerConfig {
    /// Vector dimension
    pub dimension: usize,
    /// Number of sub-quantizers (sub-vectors)
    pub num_subquantizers: usize,
    /// Bits per sub-quantizer (usually 8)
    pub bits_per_subquantizer: u8,
    /// Codebooks for each sub-quantizer (centroids)
    pub codebooks: Vec<Vec<Vec<f32>>>,
}

/// Product Quantizer for high compression
///
/// Achieves 8-32x compression by dividing vectors into sub-vectors
/// and quantizing each with a codebook.
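///
/// Typical flow, as a sketch (`ignore`d; `training_vectors` and `query` are placeholders):
///
/// ```ignore
/// // 128-dim vectors split into 8 sub-vectors of 16 dims, 256 centroids each.
/// let mut pq = ProductQuantizer::new(128, 8, 8)?;
/// pq.train(&training_vectors, 25)?;
///
/// let code = pq.quantize(&training_vectors[0])?;     // 8 bytes per stored vector
/// let approx = pq.dequantize(&code)?;                // lossy reconstruction
/// let dist = pq.asymmetric_distance(&query, &code)?; // query stays full precision
/// ```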
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Configuration
    config: ProductQuantizerConfig,
    /// Sub-vector dimension
    subdimension: usize,
    /// Number of centroids per sub-quantizer
    num_centroids: usize,
    /// Whether trained
    trained: bool,
}

impl ProductQuantizer {
    /// Create a new product quantizer
    ///
    /// # Arguments
    /// * `dimension` - Vector dimension (must be divisible by num_subquantizers)
    /// * `num_subquantizers` - Number of sub-quantizers (typically 8, 16, or 32)
    /// * `bits` - Bits per code (typically 8, giving 256 centroids)
    pub fn new(dimension: usize, num_subquantizers: usize, bits: u8) -> Result<Self> {
        if !dimension.is_multiple_of(num_subquantizers) {
            return Err(Error::InvalidInput(format!(
                "Dimension {} must be divisible by num_subquantizers {}",
                dimension, num_subquantizers
            )));
        }

        // PQ codes are stored as u8, so at most 8 bits (256 centroids) per sub-quantizer
        if bits > 8 {
            return Err(Error::InvalidInput(
                "Bits per subquantizer must be <= 8".to_string(),
            ));
        }

        let subdimension = dimension / num_subquantizers;
        let num_centroids = 1 << bits;

        Ok(Self {
            config: ProductQuantizerConfig {
                dimension,
                num_subquantizers,
                bits_per_subquantizer: bits,
                codebooks: Vec::new(),
            },
            subdimension,
            num_centroids,
            trained: false,
        })
    }

    /// Create a standard PQ with 8 sub-quantizers and 8 bits
    pub fn standard(dimension: usize) -> Result<Self> {
        Self::new(dimension, 8, 8)
    }

    /// Train the product quantizer using k-means clustering
    ///
    /// # Arguments
    /// * `vectors` - Training vectors
    /// * `max_iterations` - Maximum k-means iterations
    pub fn train(&mut self, vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
        if vectors.is_empty() {
            return Err(Error::InvalidInput(
                "Cannot train on empty vector set".to_string(),
            ));
        }

        // Validate dimensions
        for (i, vec) in vectors.iter().enumerate() {
            if vec.len() != self.config.dimension {
                return Err(Error::InvalidInput(format!(
                    "Vector {} has dimension {}, expected {}",
                    i,
                    vec.len(),
                    self.config.dimension
                )));
            }
        }

        // Train each sub-quantizer independently
        self.config.codebooks = Vec::with_capacity(self.config.num_subquantizers);

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;

            // Extract sub-vectors for this sub-quantizer
            let subvectors: Vec<Vec<f32>> =
                vectors.iter().map(|v| v[start..end].to_vec()).collect();

            // Run k-means to find centroids
            let centroids = self.kmeans(&subvectors, self.num_centroids, max_iterations)?;
            self.config.codebooks.push(centroids);
        }

        self.trained = true;
        Ok(())
    }

    /// Simple k-means implementation
    fn kmeans(&self, data: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
        if data.is_empty() {
            return Err(Error::InvalidInput("Empty data for k-means".to_string()));
        }

        let dim = data[0].len();
        let n = data.len();
        let actual_k = k.min(n); // Can't have more centroids than data points

        // Initialize centroids with deterministic farthest-point seeding (k-means++-like)
        let mut centroids = Vec::with_capacity(actual_k);

        // Seed the first centroid with the first data point (deterministic)
        centroids.push(data[0].clone());

        // Greedily pick each remaining centroid as the point farthest from the current set
        for _ in 1..actual_k {
            let mut best_idx = 0;
            let mut best_dist = 0.0f32;

            for (i, vec) in data.iter().enumerate() {
                let min_dist = centroids
                    .iter()
                    .map(|c| self.l2_distance(vec, c))
                    .fold(f32::MAX, |a, b| a.min(b));

                if min_dist > best_dist {
                    best_dist = min_dist;
                    best_idx = i;
                }
            }

            centroids.push(data[best_idx].clone());
        }

        // Run k-means iterations
        let mut assignments = vec![0usize; n];

        for _iter in 0..max_iterations {
            // Assign points to nearest centroid
            let mut changed = false;
            for (i, vec) in data.iter().enumerate() {
                let nearest = centroids
                    .iter()
                    .enumerate()
                    .map(|(j, c)| (j, self.l2_distance(vec, c)))
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                    .map(|(j, _)| j)
                    .unwrap_or(0);

                if assignments[i] != nearest {
                    assignments[i] = nearest;
                    changed = true;
                }
            }

            if !changed {
                break;
            }

            // Update centroids
            let mut new_centroids = vec![vec![0.0f32; dim]; actual_k];
            let mut counts = vec![0usize; actual_k];

            for (i, vec) in data.iter().enumerate() {
                let cluster = assignments[i];
                counts[cluster] += 1;
                for (j, &val) in vec.iter().enumerate() {
                    new_centroids[cluster][j] += val;
                }
            }

            for (i, centroid) in new_centroids.iter_mut().enumerate() {
                if counts[i] > 0 {
                    for val in centroid.iter_mut() {
                        *val /= counts[i] as f32;
                    }
                } else {
                    // Keep old centroid if empty
                    *centroid = centroids[i].clone();
                }
            }

            centroids = new_centroids;
        }

        // Ensure we have exactly k centroids (pad with duplicates if needed)
        while centroids.len() < k {
            centroids.push(centroids[centroids.len() - 1].clone());
        }

        Ok(centroids)
    }

    fn l2_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Quantize a vector to PQ codes
    pub fn quantize(&self, vector: &[f32]) -> Result<PQCode> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained before use".to_string(),
            ));
        }

        if vector.len() != self.config.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector has dimension {}, expected {}",
                vector.len(),
                self.config.dimension
            )));
        }

        let mut codes = Vec::with_capacity(self.config.num_subquantizers);

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;
            let subvector = &vector[start..end];

            // Find nearest centroid
            let codebook = &self.config.codebooks[sq];
            let (best_idx, _) = codebook
                .iter()
                .enumerate()
                .map(|(i, c)| (i, self.l2_distance(subvector, c)))
                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .unwrap_or((0, 0.0));

            codes.push(best_idx as u8);
        }

        Ok(PQCode { codes })
    }

    /// Dequantize PQ codes back to approximate vector
    pub fn dequantize(&self, code: &PQCode) -> Result<Vec<f32>> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained before use".to_string(),
            ));
        }

        if code.codes.len() != self.config.num_subquantizers {
            return Err(Error::InvalidInput(format!(
                "PQ code has {} elements, expected {}",
                code.codes.len(),
                self.config.num_subquantizers
            )));
        }

        let mut result = Vec::with_capacity(self.config.dimension);

        for (sq, &idx) in code.codes.iter().enumerate() {
            let centroid = &self.config.codebooks[sq][idx as usize];
            result.extend_from_slice(centroid);
        }

        Ok(result)
    }

    /// Compute asymmetric distance (query is not quantized)
    ///
    /// Preferred for search: keeping the query at full precision is more accurate
    /// than symmetrically comparing two quantized codes.
    pub fn asymmetric_distance(&self, query: &[f32], code: &PQCode) -> Result<f32> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained".to_string(),
            ));
        }

        let mut total_dist_sq = 0.0f32;

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;
            let subquery = &query[start..end];
            let centroid = &self.config.codebooks[sq][code.codes[sq] as usize];

            for (q, c) in subquery.iter().zip(centroid.iter()) {
                let diff = q - c;
                total_dist_sq += diff * diff;
            }
        }

        Ok(total_dist_sq.sqrt())
    }

    /// Precompute distance tables for fast ADC (Asymmetric Distance Computation)
    ///
    /// Returns a table\[sq\]\[centroid\] = squared L2 distance from the query sub-vector
    /// to that centroid; `distance_from_table` sums these and takes a single square root.
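    ///
    /// Sketch of table-based scanning (illustrative; `codes` stands in for stored `PQCode`s):
    ///
    /// ```ignore
    /// let table = pq.compute_distance_table(&query)?;
    /// let mut best = (usize::MAX, f32::MAX);
    /// for (i, code) in codes.iter().enumerate() {
    ///     // num_subquantizers table lookups per code instead of a full L2 pass.
    ///     let d = pq.distance_from_table(&table, code);
    ///     if d < best.1 {
    ///         best = (i, d);
    ///     }
    /// }
    /// ```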
    pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained".to_string(),
            ));
        }

        let mut table = Vec::with_capacity(self.config.num_subquantizers);

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;
            let subquery = &query[start..end];

            let distances: Vec<f32> = self.config.codebooks[sq]
                .iter()
                .map(|c| {
                    subquery
                        .iter()
                        .zip(c.iter())
                        .map(|(q, c)| (q - c) * (q - c))
                        .sum::<f32>()
                })
                .collect();

            table.push(distances);
        }

        Ok(table)
    }

    /// Fast distance computation using precomputed table
    pub fn distance_from_table(&self, table: &[Vec<f32>], code: &PQCode) -> f32 {
        let mut total = 0.0f32;
        for (sq, &idx) in code.codes.iter().enumerate() {
            total += table[sq][idx as usize];
        }
        total.sqrt()
    }

    /// Get compression ratio
    pub fn compression_ratio(&self) -> f32 {
        // Original: dimension * 4 bytes (f32)
        // Compressed: num_subquantizers * 1 byte
        (self.config.dimension * 4) as f32 / self.config.num_subquantizers as f32
    }

    /// Check if trained
    pub fn is_trained(&self) -> bool {
        self.trained
    }
}

/// Product Quantization code
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCode {
    /// Centroid indices for each sub-quantizer
    pub codes: Vec<u8>,
}

impl PQCode {
    /// Get memory size in bytes
    pub fn size_bytes(&self) -> usize {
        self.codes.len()
    }
}

/// Optimized Product Quantization (OPQ)
///
/// OPQ extends PQ by learning a rotation matrix that minimizes quantization error.
/// It achieves better accuracy than standard PQ at the same compression ratio.
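///
/// Rough usage outline (`ignore`d sketch; `training_vectors` is a placeholder):
///
/// ```ignore
/// let mut opq = OptimizedProductQuantizer::standard(128)?;
/// // 25 k-means iterations per sub-quantizer, 5 outer rotation/codebook rounds.
/// opq.train(&training_vectors, 25, 5)?;
/// let code = opq.quantize(&training_vectors[0])?;
/// let dist = opq.asymmetric_distance(&training_vectors[1], &code)?;
/// ```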
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizedProductQuantizer {
    /// Base product quantizer
    pq: ProductQuantizer,
    /// Rotation matrix (dimension x dimension)
    #[serde(with = "rotation_matrix_serde")]
    rotation: Option<DMatrix<f32>>,
    /// Whether rotation is trained
    rotation_trained: bool,
}

// Custom serialization for DMatrix
mod rotation_matrix_serde {
    use super::DMatrix;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    #[derive(Serialize, Deserialize)]
    struct MatrixData {
        nrows: usize,
        ncols: usize,
        data: Vec<f32>,
    }

    pub fn serialize<S>(
        matrix: &Option<DMatrix<f32>>,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let opt_data = matrix.as_ref().map(|m| MatrixData {
            nrows: m.nrows(),
            ncols: m.ncols(),
            data: m.as_slice().to_vec(),
        });
        opt_data.serialize(serializer)
    }

    pub fn deserialize<'de, D>(
        deserializer: D,
    ) -> std::result::Result<Option<DMatrix<f32>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let opt: Option<MatrixData> = Option::deserialize(deserializer)?;
        Ok(opt.map(|data| DMatrix::from_vec(data.nrows, data.ncols, data.data)))
    }
}

impl OptimizedProductQuantizer {
    /// Create a new OPQ quantizer
    ///
    /// # Arguments
    /// * `dimension` - Vector dimension (must be divisible by num_subquantizers)
    /// * `num_subquantizers` - Number of sub-quantizers
    /// * `bits` - Bits per code (typically 8)
    pub fn new(dimension: usize, num_subquantizers: usize, bits: u8) -> Result<Self> {
        let pq = ProductQuantizer::new(dimension, num_subquantizers, bits)?;
        Ok(Self {
            pq,
            rotation: None,
            rotation_trained: false,
        })
    }

    /// Create standard OPQ with 8 sub-quantizers and 8 bits
    pub fn standard(dimension: usize) -> Result<Self> {
        Self::new(dimension, 8, 8)
    }

    /// Train OPQ with rotation learning
    ///
    /// Uses iterative optimization: alternate between learning rotation and PQ codebooks
    ///
    /// # Arguments
    /// * `vectors` - Training vectors
    /// * `max_iterations` - Max iterations for PQ k-means
    /// * `rotation_iterations` - Iterations for rotation learning (typically 5-10)
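    ///
    /// The alternation, sketched in pseudocode (comments only, `ignore`d):
    ///
    /// ```ignore
    /// // R <- identity
    /// // repeat rotation_iterations times:
    /// //     Y  <- R * X                   (rotate training data)
    /// //     PQ <- k-means codebooks on Y  (retrain sub-quantizers)
    /// //     R  <- argmin_R ||R X - reconstruct(Y)||_F   (orthogonal Procrustes via SVD)
    /// ```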
    #[allow(clippy::too_many_arguments)]
    pub fn train(
        &mut self,
        vectors: &[Vec<f32>],
        max_iterations: usize,
        rotation_iterations: usize,
    ) -> Result<()> {
        if vectors.is_empty() {
            return Err(Error::InvalidInput(
                "Cannot train on empty vector set".to_string(),
            ));
        }

        let dim = self.pq.config.dimension;

        // Validate dimensions
        for (i, vec) in vectors.iter().enumerate() {
            if vec.len() != dim {
                return Err(Error::InvalidInput(format!(
                    "Vector {} has dimension {}, expected {}",
                    i,
                    vec.len(),
                    dim
                )));
            }
        }

        // Initialize with identity rotation
        let mut rotation = DMatrix::<f32>::identity(dim, dim);

        // Iteratively optimize rotation and PQ
        for iteration in 0..rotation_iterations {
            // Step 1: Rotate vectors
            let rotated = self.apply_rotation_batch(vectors, &rotation);

            // Step 2: Train PQ on rotated vectors
            self.pq.train(&rotated, max_iterations)?;

            // Step 3: Learn a better rotation from the rotated reconstructions
            // (skipped on the last iteration, whose codebooks are kept)
            if iteration < rotation_iterations - 1 {
                rotation = self.learn_rotation(vectors, &rotated, &self.pq)?;
            }
        }

        self.rotation = Some(rotation);
        self.rotation_trained = true;

        Ok(())
    }

    /// Learn rotation matrix using SVD (orthogonal Procrustes)
    ///
    /// Finds the rotation R minimizing ||R X - Y_hat||_F, where X holds the original
    /// training vectors and Y_hat the PQ reconstructions of their rotated counterparts.
    fn learn_rotation(
        &self,
        originals: &[Vec<f32>],
        rotated: &[Vec<f32>],
        pq: &ProductQuantizer,
    ) -> Result<DMatrix<f32>> {
        let dim = pq.config.dimension;
        let n = originals.len();

        // Accumulate M = sum_i y_hat_i * x_i^T
        let mut cov = DMatrix::<f32>::zeros(dim, dim);

        for (orig, rot) in originals.iter().zip(rotated.iter()) {
            // Quantize and reconstruct the rotated vector
            let code = pq.quantize(rot)?;
            let reconstructed = pq.dequantize(&code)?;

            // Outer product of reconstruction and original
            let x = DVector::from_vec(orig.clone());
            let y_hat = DVector::from_vec(reconstructed);
            cov += y_hat * x.transpose();
        }

        cov /= n as f32;

        // SVD: M = U S V^T, optimal rotation R = U * V^T
        let svd = cov.svd(true, true);
        match (svd.u, svd.v_t) {
            (Some(u), Some(vt)) => Ok(u * vt),
            _ => {
                // If SVD fails, keep the identity rotation
                Ok(DMatrix::identity(dim, dim))
            }
        }
    }

    /// Apply rotation to a batch of vectors
    fn apply_rotation_batch(&self, vectors: &[Vec<f32>], rotation: &DMatrix<f32>) -> Vec<Vec<f32>> {
        vectors
            .iter()
            .map(|v| self.apply_rotation(v, rotation))
            .collect()
    }

    /// Apply rotation to a single vector
    fn apply_rotation(&self, vector: &[f32], rotation: &DMatrix<f32>) -> Vec<f32> {
        let v = DVector::from_vec(vector.to_vec());
        let rotated = rotation * v;
        rotated.as_slice().to_vec()
    }

    /// Quantize a vector
    pub fn quantize(&self, vector: &[f32]) -> Result<PQCode> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        // Apply rotation before quantization
        let rotated = match &self.rotation {
            Some(r) => self.apply_rotation(vector, r),
            None => vector.to_vec(),
        };

        self.pq.quantize(&rotated)
    }

    /// Dequantize a code back to vector
    pub fn dequantize(&self, code: &PQCode) -> Result<Vec<f32>> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        // Dequantize
        let rotated = self.pq.dequantize(code)?;

        // Apply inverse rotation
        match &self.rotation {
            Some(r) => {
                // For orthogonal matrices, inverse = transpose
                let r_inv = r.transpose();
                Ok(self.apply_rotation(&rotated, &r_inv))
            }
            None => Ok(rotated),
        }
    }

    /// Compute asymmetric distance (query is not quantized)
    pub fn asymmetric_distance(&self, query: &[f32], code: &PQCode) -> Result<f32> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        // Rotate query
        let rotated_query = match &self.rotation {
            Some(r) => self.apply_rotation(query, r),
            None => query.to_vec(),
        };

        self.pq.asymmetric_distance(&rotated_query, code)
    }

    /// Compute distance table for fast batch queries
    pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        // Rotate query
        let rotated_query = match &self.rotation {
            Some(r) => self.apply_rotation(query, r),
            None => query.to_vec(),
        };

        self.pq.compute_distance_table(&rotated_query)
    }

    /// Fast distance using precomputed table
    pub fn distance_from_table(&self, table: &[Vec<f32>], code: &PQCode) -> f32 {
        self.pq.distance_from_table(table, code)
    }

    /// Get compression ratio
    pub fn compression_ratio(&self) -> f32 {
        self.pq.compression_ratio()
    }

    /// Check if trained
    pub fn is_trained(&self) -> bool {
        self.pq.is_trained() && self.rotation_trained
    }

    /// Get the underlying PQ (for testing)
    #[allow(dead_code)]
    pub fn inner_pq(&self) -> &ProductQuantizer {
        &self.pq
    }
}

/// Quantization benchmark results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationBenchmark {
    /// Recall at k for various k values
    pub recall_at_k: Vec<(usize, f32)>,
    /// Compression ratio achieved
    pub compression_ratio: f32,
    /// Average quantization error
    pub avg_quantization_error: f32,
    /// Max quantization error
    pub max_quantization_error: f32,
    /// Memory savings in bytes
    pub memory_savings: usize,
    /// Speed improvement factor (approximate)
    pub speed_factor: f32,
}

impl QuantizationBenchmark {
    /// Generate a summary string
    pub fn summary(&self) -> String {
        let recall_str: Vec<String> = self
            .recall_at_k
            .iter()
            .map(|(k, r)| format!("R@{}: {:.2}%", k, r * 100.0))
            .collect();

        format!(
            "Compression: {:.1}x, Avg Error: {:.4}, {}, Memory Saved: {} bytes",
            self.compression_ratio,
            self.avg_quantization_error,
            recall_str.join(", "),
            self.memory_savings
        )
    }
}

/// Benchmark utilities for quantization evaluation
pub struct QuantizationBenchmarker;

impl QuantizationBenchmarker {
    /// Benchmark scalar quantization on a dataset
    ///
    /// # Arguments
    /// * `quantizer` - Trained scalar quantizer
    /// * `vectors` - Test vectors
    /// * `queries` - Query vectors
    /// * `ground_truth` - Ground truth k-NN for each query (indices into vectors)
    /// * `k_values` - Values of k to measure recall at
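    ///
    /// End-to-end sketch (illustrative; building `vectors` and `queries` is elided):
    ///
    /// ```ignore
    /// let gt = QuantizationBenchmarker::compute_ground_truth(&vectors, &queries, 10);
    /// let mut sq = ScalarQuantizer::uint8(vectors[0].len());
    /// sq.train(&vectors)?;
    /// let report =
    ///     QuantizationBenchmarker::benchmark_scalar(&sq, &vectors, &queries, &gt, &[1, 5, 10])?;
    /// println!("{}", report.summary());
    /// ```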
    pub fn benchmark_scalar(
        quantizer: &ScalarQuantizer,
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        ground_truth: &[Vec<usize>],
        k_values: &[usize],
    ) -> Result<QuantizationBenchmark> {
        if !quantizer.is_trained() {
            return Err(Error::InvalidInput("Quantizer must be trained".to_string()));
        }

        // Quantize all vectors
        let quantized: Vec<QuantizedVector> = vectors
            .iter()
            .map(|v| quantizer.quantize(v))
            .collect::<Result<Vec<_>>>()?;

        // Compute quantization error
        let mut total_error = 0.0f32;
        let mut max_error = 0.0f32;

        for (i, qv) in quantized.iter().enumerate() {
            let restored = quantizer.dequantize(qv)?;
            let error: f32 = vectors[i]
                .iter()
                .zip(restored.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            total_error += error;
            max_error = max_error.max(error);
        }

        let avg_error = total_error / vectors.len() as f32;

        // Compute recall at k
        let mut recall_at_k = Vec::new();

        for &k in k_values {
            let mut total_recall = 0.0f32;

            for (qi, query) in queries.iter().enumerate() {
                let query_quantized = quantizer.quantize(query)?;

                // Find k nearest using quantized distances
                let mut distances: Vec<(usize, f32)> = quantized
                    .iter()
                    .enumerate()
                    .map(|(i, qv)| {
                        let dist = quantizer
                            .distance_l2_quantized(&query_quantized, qv)
                            .unwrap_or(f32::MAX);
                        (i, dist)
                    })
                    .collect();

                distances
                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                let found: std::collections::HashSet<usize> =
                    distances.iter().take(k).map(|(i, _)| *i).collect();

                let gt: std::collections::HashSet<usize> =
                    ground_truth[qi].iter().take(k).copied().collect();

                let intersection = found.intersection(&gt).count();
                total_recall += intersection as f32 / k.min(gt.len()) as f32;
            }

            let recall = total_recall / queries.len() as f32;
            recall_at_k.push((k, recall));
        }

        // Calculate memory savings
        let original_size = vectors.len() * vectors[0].len() * 4; // f32 = 4 bytes
        let quantized_size = vectors.len() * vectors[0].len(); // u8 = 1 byte
        let memory_savings = original_size - quantized_size;

        Ok(QuantizationBenchmark {
            recall_at_k,
            compression_ratio: quantizer.compression_ratio(),
            avg_quantization_error: avg_error,
            max_quantization_error: max_error,
            memory_savings,
            speed_factor: 2.0, // Approximate for int ops vs float ops
        })
    }

    /// Benchmark product quantization on a dataset
    pub fn benchmark_pq(
        pq: &ProductQuantizer,
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        ground_truth: &[Vec<usize>],
        k_values: &[usize],
    ) -> Result<QuantizationBenchmark> {
        if !pq.is_trained() {
            return Err(Error::InvalidInput("PQ must be trained".to_string()));
        }

        // Quantize all vectors
        let codes: Vec<PQCode> = vectors
            .iter()
            .map(|v| pq.quantize(v))
            .collect::<Result<Vec<_>>>()?;

        // Compute quantization error
        let mut total_error = 0.0f32;
        let mut max_error = 0.0f32;

        for (i, code) in codes.iter().enumerate() {
            let restored = pq.dequantize(code)?;
            let error: f32 = vectors[i]
                .iter()
                .zip(restored.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            total_error += error;
            max_error = max_error.max(error);
        }

        let avg_error = total_error / vectors.len() as f32;

        // Compute recall at k using asymmetric distance
        let mut recall_at_k = Vec::new();

        for &k in k_values {
            let mut total_recall = 0.0f32;

            for (qi, query) in queries.iter().enumerate() {
                // Use distance table for fast computation
                let table = pq.compute_distance_table(query)?;

                let mut distances: Vec<(usize, f32)> = codes
                    .iter()
                    .enumerate()
                    .map(|(i, code)| (i, pq.distance_from_table(&table, code)))
                    .collect();

                distances
                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                let found: std::collections::HashSet<usize> =
                    distances.iter().take(k).map(|(i, _)| *i).collect();

                let gt: std::collections::HashSet<usize> =
                    ground_truth[qi].iter().take(k).copied().collect();

                let intersection = found.intersection(&gt).count();
                total_recall += intersection as f32 / k.min(gt.len()) as f32;
            }

            let recall = total_recall / queries.len() as f32;
            recall_at_k.push((k, recall));
        }

        // Calculate memory savings
        let original_size = vectors.len() * vectors[0].len() * 4;
        let quantized_size = vectors.len() * codes[0].size_bytes();
        let memory_savings = original_size.saturating_sub(quantized_size);

        Ok(QuantizationBenchmark {
            recall_at_k,
            compression_ratio: pq.compression_ratio(),
            avg_quantization_error: avg_error,
            max_quantization_error: max_error,
            memory_savings,
            speed_factor: 4.0, // Approximate for table lookup vs float ops
        })
    }

    /// Compute ground truth k-NN using brute force L2 distance
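    ///
    /// Brute force is O(queries * vectors * dimension), so this is meant for modest
    /// evaluation sets. Small sketch:
    ///
    /// ```ignore
    /// // Indices of the 10 nearest stored vectors for each query, by exact L2 distance.
    /// let gt = QuantizationBenchmarker::compute_ground_truth(&vectors, &queries, 10);
    /// assert_eq!(gt.len(), queries.len());
    /// ```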
    pub fn compute_ground_truth(
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        k: usize,
    ) -> Vec<Vec<usize>> {
        queries
            .iter()
            .map(|query| {
                let mut distances: Vec<(usize, f32)> = vectors
                    .iter()
                    .enumerate()
                    .map(|(i, v)| {
                        let dist: f32 = query
                            .iter()
                            .zip(v.iter())
                            .map(|(a, b)| (a - b).powi(2))
                            .sum();
                        (i, dist)
                    })
                    .collect();

                distances
                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                distances.iter().take(k).map(|(i, _)| *i).collect()
            })
            .collect()
    }

    /// Compare multiple quantization methods
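    ///
    /// Sketch of reading the comparison (illustrative; `vectors` and `queries` are placeholders):
    ///
    /// ```ignore
    /// let cmp = QuantizationBenchmarker::compare_methods(&vectors, &queries, &[1, 5, 10])?;
    /// println!("{}", cmp.summary());
    /// let (method, recall) = cmp.best_method_for_k(10);
    /// println!("best @10: {} ({:.2}%)", method, recall * 100.0);
    /// ```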
    pub fn compare_methods(
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        k_values: &[usize],
    ) -> Result<QuantizationComparison> {
        let max_k = *k_values.iter().max().unwrap_or(&10);
        let ground_truth = Self::compute_ground_truth(vectors, queries, max_k);

        // Benchmark scalar quantization
        let mut sq = ScalarQuantizer::uint8(vectors[0].len());
        sq.train(vectors)?;
        let scalar_results =
            Self::benchmark_scalar(&sq, vectors, queries, &ground_truth, k_values)?;

        // Benchmark PQ (if dimension allows)
        let dim = vectors[0].len();
        let pq_results = if dim >= 8 && dim.is_multiple_of(8) {
            let mut pq = ProductQuantizer::new(dim, 8, 8)?;
            pq.train(vectors, 20)?;
            Some(Self::benchmark_pq(
                &pq,
                vectors,
                queries,
                &ground_truth,
                k_values,
            )?)
        } else {
            None
        };

        Ok(QuantizationComparison {
            scalar: scalar_results,
            product: pq_results,
            dataset_size: vectors.len(),
            dimension: dim,
        })
    }
}

/// Comparison of multiple quantization methods
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationComparison {
    /// Scalar quantization results
    pub scalar: QuantizationBenchmark,
    /// Product quantization results (if applicable)
    pub product: Option<QuantizationBenchmark>,
    /// Dataset size
    pub dataset_size: usize,
    /// Vector dimension
    pub dimension: usize,
}

impl QuantizationComparison {
    /// Generate a comparison summary
    pub fn summary(&self) -> String {
        let mut result = format!(
            "Dataset: {} vectors, {} dimensions\n\nScalar Quantization:\n  {}\n",
            self.dataset_size,
            self.dimension,
            self.scalar.summary()
        );

        if let Some(ref pq) = self.product {
            result.push_str(&format!("\nProduct Quantization:\n  {}\n", pq.summary()));
        }

        result
    }

    /// Get the best method for a given k value based on recall
    pub fn best_method_for_k(&self, k: usize) -> (&str, f32) {
        let scalar_recall = self
            .scalar
            .recall_at_k
            .iter()
            .find(|(kv, _)| *kv == k)
            .map(|(_, r)| *r)
            .unwrap_or(0.0);

        if let Some(ref pq) = self.product {
            let pq_recall = pq
                .recall_at_k
                .iter()
                .find(|(kv, _)| *kv == k)
                .map(|(_, r)| *r)
                .unwrap_or(0.0);

            if pq_recall > scalar_recall {
                return ("ProductQuantization", pq_recall);
            }
        }

        ("ScalarQuantization", scalar_recall)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantizer_uint8() {
        let mut quantizer = ScalarQuantizer::uint8(4);

        // Train on some vectors
        let vectors = vec![
            vec![0.0, 0.5, 1.0, -0.5],
            vec![1.0, 0.0, 0.5, 0.5],
            vec![0.5, 0.5, 0.0, 0.0],
        ];

        quantizer.train(&vectors).unwrap();
        assert!(quantizer.is_trained());

        // Quantize and dequantize
        let original = vec![0.5, 0.25, 0.75, 0.0];
        let quantized = quantizer.quantize(&original).unwrap();
        let restored = quantizer.dequantize(&quantized).unwrap();

        // Check approximate equality
        for (o, r) in original.iter().zip(restored.iter()) {
            assert!((o - r).abs() < 0.05, "Expected {} ~= {}", o, r);
        }

        assert_eq!(quantizer.compression_ratio(), 4.0);
    }

    #[test]
    fn test_scalar_quantizer_int8() {
        let mut quantizer = ScalarQuantizer::int8(4);

        let vectors = vec![vec![-1.0, 0.0, 1.0, -0.5], vec![1.0, -1.0, 0.5, 0.5]];

        quantizer.train(&vectors).unwrap();

        let original = vec![0.0, -0.5, 0.5, 0.25];
        let quantized = quantizer.quantize(&original).unwrap();
        let restored = quantizer.dequantize(&quantized).unwrap();

        for (o, r) in original.iter().zip(restored.iter()) {
            assert!((o - r).abs() < 0.1, "Expected {} ~= {}", o, r);
        }
    }

    #[test]
    fn test_scalar_distance() {
        let mut quantizer = ScalarQuantizer::uint8(4);

        let vectors = vec![vec![0.0, 0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0, 1.0]];

        quantizer.train(&vectors).unwrap();

        let a = quantizer.quantize(&[0.0, 0.0, 0.0, 0.0]).unwrap();
        let b = quantizer.quantize(&[1.0, 1.0, 1.0, 1.0]).unwrap();

        let dist = quantizer.distance_l2_quantized(&a, &b).unwrap();
        assert!(dist > 0.0, "Distance should be positive");
    }

    #[test]
    fn test_product_quantizer() {
        // 8-dimensional vectors with 2 sub-quantizers
        let mut pq = ProductQuantizer::new(8, 2, 4).unwrap(); // 4 bits = 16 centroids

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..8).map(|j| i as f32 * 0.01 + j as f32 * 0.1).collect())
            .collect();

        pq.train(&vectors, 10).unwrap();
        assert!(pq.is_trained());

        // Quantize and dequantize
        let original = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = pq.quantize(&original).unwrap();
        let restored = pq.dequantize(&code).unwrap();

        // Check code size
        assert_eq!(code.codes.len(), 2);
        assert_eq!(restored.len(), 8);

        // Compression ratio should be 8*4 / 2 = 16
        assert_eq!(pq.compression_ratio(), 16.0);
    }

    #[test]
    fn test_pq_distance_table() {
        let mut pq = ProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| i as f32 * 0.02 + j as f32 * 0.1).collect())
            .collect();

        pq.train(&vectors, 5).unwrap();

        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = pq.quantize(&vectors[0]).unwrap();

        // Compute distance two ways
        let direct_dist = pq.asymmetric_distance(&query, &code).unwrap();
        let table = pq.compute_distance_table(&query).unwrap();
        let table_dist = pq.distance_from_table(&table, &code);

        // Should be equal (within floating point tolerance)
        assert!((direct_dist - table_dist).abs() < 1e-5);
    }

    #[test]
    fn test_quantization_benchmarker() {
        // Create test dataset
        let dim = 8;
        let n_vectors = 50;
        let n_queries = 5;

        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|i| (0..dim).map(|j| (i as f32 + j as f32) * 0.1).collect())
            .collect();

        let queries: Vec<Vec<f32>> = (0..n_queries)
            .map(|i| {
                (0..dim)
                    .map(|j| (i as f32 * 2.0 + j as f32) * 0.1)
                    .collect()
            })
            .collect();

        // Test ground truth computation
        let gt = QuantizationBenchmarker::compute_ground_truth(&vectors, &queries, 10);
        assert_eq!(gt.len(), n_queries);
        assert_eq!(gt[0].len(), 10);

        // Test scalar quantization benchmark
        let mut sq = ScalarQuantizer::uint8(dim);
        sq.train(&vectors).unwrap();

        let sq_benchmark =
            QuantizationBenchmarker::benchmark_scalar(&sq, &vectors, &queries, &gt, &[1, 5, 10])
                .unwrap();

        assert_eq!(sq_benchmark.recall_at_k.len(), 3);
        assert!(sq_benchmark.compression_ratio > 1.0);
        assert!(sq_benchmark.memory_savings > 0);
    }

    #[test]
    fn test_quantization_comparison() {
        let dim = 8;
        let n_vectors = 100;
        let n_queries = 10;

        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|i| (0..dim).map(|j| (i as f32 + j as f32) * 0.05).collect())
            .collect();

        let queries: Vec<Vec<f32>> = (0..n_queries)
            .map(|i| {
                (0..dim)
                    .map(|j| (i as f32 * 3.0 + j as f32) * 0.05)
                    .collect()
            })
            .collect();

        let comparison =
            QuantizationBenchmarker::compare_methods(&vectors, &queries, &[1, 5, 10]).unwrap();

        assert_eq!(comparison.dataset_size, n_vectors);
        assert_eq!(comparison.dimension, dim);
        assert!(comparison.product.is_some()); // dim = 8 is divisible by 8

        // Check best method selection
        let (method, recall) = comparison.best_method_for_k(10);
        assert!(!method.is_empty());
        assert!((0.0..=1.0).contains(&recall));
    }

    #[test]
    fn test_opq_basic() {
        // 8-dimensional vectors with 2 sub-quantizers
        let mut opq = OptimizedProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..8).map(|j| i as f32 * 0.01 + j as f32 * 0.1).collect())
            .collect();

        // Train with rotation learning (5 iterations)
        opq.train(&vectors, 10, 5).unwrap();
        assert!(opq.is_trained());

        // Quantize and dequantize
        let original = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = opq.quantize(&original).unwrap();
        let restored = opq.dequantize(&code).unwrap();

        assert_eq!(code.codes.len(), 2);
        assert_eq!(restored.len(), 8);

        // Compression ratio should be same as PQ
        assert_eq!(opq.compression_ratio(), 16.0);

        // Test asymmetric distance
        let query = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];
        let dist = opq.asymmetric_distance(&query, &code).unwrap();
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_opq_distance_table() {
        let mut opq = OptimizedProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| i as f32 * 0.02 + j as f32 * 0.1).collect())
            .collect();

        opq.train(&vectors, 5, 3).unwrap();

        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = opq.quantize(&vectors[0]).unwrap();

        // Compute distance two ways
        let direct_dist = opq.asymmetric_distance(&query, &code).unwrap();
        let table = opq.compute_distance_table(&query).unwrap();
        let table_dist = opq.distance_from_table(&table, &code);

        // Should be equal (within floating point tolerance)
        assert!((direct_dist - table_dist).abs() < 1e-4);
    }

    #[test]
    fn test_opq_serialization() {
        let mut opq = OptimizedProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| i as f32 * 0.02 + j as f32 * 0.1).collect())
            .collect();

        opq.train(&vectors, 5, 3).unwrap();

        // Serialize
        let serialized = oxicode::serde::encode_to_vec(&opq, oxicode::config::standard()).unwrap();

        // Deserialize
        let deserialized: OptimizedProductQuantizer =
            oxicode::serde::decode_owned_from_slice(&serialized, oxicode::config::standard())
                .map(|(v, _)| v)
                .unwrap();

        assert!(deserialized.is_trained());
        assert_eq!(deserialized.compression_ratio(), opq.compression_ratio());

        // Test that deserialized works correctly
        let original = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code1 = opq.quantize(&original).unwrap();
        let code2 = deserialized.quantize(&original).unwrap();

        // Codes should be identical
        assert_eq!(code1.codes, code2.codes);
    }
}