// grafeo_core/index/vector/quantization.rs

1//! Vector quantization algorithms for memory-efficient storage.
2//!
3//! Quantization reduces vector precision for memory savings:
4//!
//! | Method  | Compression | Accuracy | Speed    | Use Case                    |
//! |---------|-------------|----------|----------|-----------------------------|
//! | Scalar  | 4x          | ~97%     | Fast     | Default for most datasets   |
//! | Product | 8-32x       | ~90%     | Fast     | Balanced compression/recall |
//! | Binary  | 32x         | ~80%     | Fastest  | Very large datasets         |
9//!
10//! # Scalar Quantization
11//!
12//! Converts f32 values to u8 by learning min/max ranges per dimension:
13//!
14//! ```
15//! use grafeo_core::index::vector::quantization::ScalarQuantizer;
16//!
17//! // Training vectors to learn min/max ranges
18//! let vectors = vec![
19//!     vec![0.0f32, 0.3, 0.7],
20//!     vec![0.2, 0.5, 1.0],
21//!     vec![0.1, 0.6, 0.9],
22//! ];
23//! let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
24//! let quantizer = ScalarQuantizer::train(&refs);
25//!
26//! // Quantize: f32 -> u8 (4x compression)
27//! let original = vec![0.1f32, 0.5, 0.9];
28//! let quantized = quantizer.quantize(&original);
29//!
30//! // Compute distance in quantized space (approximate)
31//! let other_quantized = quantizer.quantize(&[0.15, 0.45, 0.85]);
32//! let dist = quantizer.distance_u8(&quantized, &other_quantized);
33//! ```
34//!
35//! # Binary Quantization
36//!
37//! Converts f32 values to bits (sign only), enabling hamming distance:
38//!
39//! ```
40//! use grafeo_core::index::vector::quantization::BinaryQuantizer;
41//!
42//! let v1 = vec![0.1f32, -0.5, 0.0, 0.9];
43//! let v2 = vec![0.2f32, -0.3, 0.1, 0.8];
44//! let bits1 = BinaryQuantizer::quantize(&v1);
45//! let bits2 = BinaryQuantizer::quantize(&v2);
46//!
47//! // Hamming distance (count differing bits)
48//! let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
49//! ```
50
51use serde::{Deserialize, Serialize};
52
53// ============================================================================
54// Quantization Type
55// ============================================================================
56
57/// Quantization strategy for vector storage.
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
59pub enum QuantizationType {
60    /// No quantization - full f32 precision.
61    #[default]
62    None,
63    /// Scalar quantization: f32 -> u8 (4x compression, ~97% accuracy).
64    Scalar,
65    /// Binary quantization: f32 -> 1 bit (32x compression, ~80% accuracy).
66    Binary,
67    /// Product quantization: f32 -> M u8 codes (8-32x compression, ~90% accuracy).
68    Product {
69        /// Number of subvectors (typically 8, 16, 32, 64).
70        num_subvectors: usize,
71    },
72}
73
74impl QuantizationType {
75    /// Returns the compression ratio (memory reduction factor).
76    ///
77    /// For Product quantization, this depends on dimensions and num_subvectors.
78    /// The ratio is approximate: dimensions * 4 / num_subvectors.
79    #[must_use]
80    pub fn compression_ratio(&self, dimensions: usize) -> usize {
81        match self {
82            Self::None => 1,
83            Self::Scalar => 4,  // f32 (4 bytes) -> u8 (1 byte)
84            Self::Binary => 32, // f32 (4 bytes) -> 1 bit (0.125 bytes)
85            Self::Product { num_subvectors } => {
86                // Original: dimensions * 4 bytes (f32)
87                // Compressed: num_subvectors bytes (u8 codes)
88                // Ratio: (dimensions * 4) / num_subvectors
89                let m = (*num_subvectors).max(1);
90                (dimensions * 4) / m
91            }
92        }
93    }
94
95    /// Returns the name of the quantization type.
96    #[must_use]
97    pub fn name(&self) -> &'static str {
98        match self {
99            Self::None => "none",
100            Self::Scalar => "scalar",
101            Self::Binary => "binary",
102            Self::Product { .. } => "product",
103        }
104    }
105
106    /// Parses from string (case-insensitive).
107    #[must_use]
108    pub fn from_str(s: &str) -> Option<Self> {
109        match s.to_lowercase().as_str() {
110            "none" | "full" | "f32" => Some(Self::None),
111            "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
112            "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
113            "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
114            s if s.starts_with("pq") => {
115                // Parse "pq8", "pq16", etc.
116                s[2..]
117                    .parse()
118                    .ok()
119                    .map(|n| Self::Product { num_subvectors: n })
120            }
121            _ => None,
122        }
123    }
124
125    /// Returns true if this quantization type requires training.
126    #[must_use]
127    pub const fn requires_training(&self) -> bool {
128        matches!(self, Self::Scalar | Self::Product { .. })
129    }
130}
131
132// ============================================================================
133// Scalar Quantization
134// ============================================================================
135
136/// Scalar quantizer: f32 -> u8 with per-dimension min/max scaling.
137///
138/// Training learns the min/max value for each dimension, then quantizes
139/// values to [0, 255] range. This achieves 4x compression with typically
140/// >97% recall retention.
141///
142/// # Example
143///
144/// ```
145/// use grafeo_core::index::vector::quantization::ScalarQuantizer;
146///
147/// // Training vectors
148/// let vectors = vec![
149///     vec![0.0f32, 0.5, 1.0],
150///     vec![0.2, 0.3, 0.8],
151///     vec![0.1, 0.6, 0.9],
152/// ];
153///
154/// // Train quantizer
155/// let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
156/// let quantizer = ScalarQuantizer::train(&refs);
157///
158/// // Quantize a vector
159/// let quantized = quantizer.quantize(&[0.1, 0.4, 0.85]);
160/// assert_eq!(quantized.len(), 3);
161///
162/// // Compute approximate distance
163/// let q2 = quantizer.quantize(&[0.15, 0.45, 0.9]);
164/// let dist = quantizer.distance_squared_u8(&quantized, &q2);
165/// assert!(dist < 1000.0);
166/// ```
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct ScalarQuantizer {
169    /// Minimum value per dimension.
170    min: Vec<f32>,
171    /// Scale factor per dimension: 255 / (max - min).
172    scale: Vec<f32>,
173    /// Inverse scale for distance computation: (max - min) / 255.
174    inv_scale: Vec<f32>,
175    /// Number of dimensions.
176    dimensions: usize,
177}
178
179impl ScalarQuantizer {
180    /// Trains a scalar quantizer from sample vectors.
181    ///
182    /// Learns the min/max value per dimension from the training data.
183    /// The more representative the training data, the better the quantization.
184    ///
185    /// # Arguments
186    ///
187    /// * `vectors` - Training vectors (should be representative of the dataset)
188    ///
189    /// # Panics
190    ///
191    /// Panics if `vectors` is empty or if vectors have different dimensions.
192    #[must_use]
193    pub fn train(vectors: &[&[f32]]) -> Self {
194        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
195
196        let dimensions = vectors[0].len();
197        assert!(
198            vectors.iter().all(|v| v.len() == dimensions),
199            "All training vectors must have the same dimensions"
200        );
201
202        // Find min/max per dimension
203        let mut min = vec![f32::INFINITY; dimensions];
204        let mut max = vec![f32::NEG_INFINITY; dimensions];
205
206        for vec in vectors {
207            for (i, &v) in vec.iter().enumerate() {
208                min[i] = min[i].min(v);
209                max[i] = max[i].max(v);
210            }
211        }
212
213        // Compute scale factors (avoid division by zero)
214        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
215            .iter()
216            .zip(&max)
217            .map(|(&mn, &mx)| {
218                let range = mx - mn;
219                if range.abs() < f32::EPSILON {
220                    // All values are the same, use 1.0 as scale
221                    (1.0, 1.0)
222                } else {
223                    (255.0 / range, range / 255.0)
224                }
225            })
226            .unzip();
227
228        Self {
229            min,
230            scale,
231            inv_scale,
232            dimensions,
233        }
234    }
235
236    /// Creates a quantizer with explicit ranges (useful for testing).
237    #[must_use]
238    pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
239        let dimensions = min.len();
240        assert_eq!(min.len(), max.len(), "Min and max must have same length");
241
242        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
243            .iter()
244            .zip(&max)
245            .map(|(&mn, &mx)| {
246                let range = mx - mn;
247                if range.abs() < f32::EPSILON {
248                    (1.0, 1.0)
249                } else {
250                    (255.0 / range, range / 255.0)
251                }
252            })
253            .unzip();
254
255        Self {
256            min,
257            scale,
258            inv_scale,
259            dimensions,
260        }
261    }
262
263    /// Returns the number of dimensions.
264    #[must_use]
265    pub fn dimensions(&self) -> usize {
266        self.dimensions
267    }
268
269    /// Returns the min values per dimension.
270    #[must_use]
271    pub fn min_values(&self) -> &[f32] {
272        &self.min
273    }
274
275    /// Quantizes an f32 vector to u8.
276    ///
277    /// Values are clamped to the learned [min, max] range.
278    #[must_use]
279    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
280        debug_assert_eq!(
281            vector.len(),
282            self.dimensions,
283            "Vector dimension mismatch: expected {}, got {}",
284            self.dimensions,
285            vector.len()
286        );
287
288        vector
289            .iter()
290            .enumerate()
291            .map(|(i, &v)| {
292                let normalized = (v - self.min[i]) * self.scale[i];
293                normalized.clamp(0.0, 255.0) as u8
294            })
295            .collect()
296    }
297
298    /// Quantizes multiple vectors in batch.
299    #[must_use]
300    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
301        vectors.iter().map(|v| self.quantize(v)).collect()
302    }
303
304    /// Dequantizes a u8 vector back to f32 (approximate).
305    #[must_use]
306    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
307        debug_assert_eq!(quantized.len(), self.dimensions);
308
309        quantized
310            .iter()
311            .enumerate()
312            .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
313            .collect()
314    }
315
316    /// Computes squared Euclidean distance between quantized vectors.
317    ///
318    /// This is an approximation that works well for ranking nearest neighbors.
319    /// The returned distance is scaled back to the original space.
320    #[must_use]
321    pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
322        debug_assert_eq!(a.len(), self.dimensions);
323        debug_assert_eq!(b.len(), self.dimensions);
324
325        // Compute in quantized space, then scale
326        let mut sum = 0.0f32;
327        for i in 0..a.len() {
328            let diff = (a[i] as f32) - (b[i] as f32);
329            sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
330        }
331        sum
332    }
333
334    /// Computes Euclidean distance between quantized vectors.
335    #[must_use]
336    #[inline]
337    pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
338        self.distance_squared_u8(a, b).sqrt()
339    }
340
341    /// Computes approximate cosine distance using quantized vectors.
342    ///
343    /// This is less accurate than exact computation but much faster.
344    #[must_use]
345    pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
346        debug_assert_eq!(a.len(), self.dimensions);
347        debug_assert_eq!(b.len(), self.dimensions);
348
349        let mut dot = 0.0f32;
350        let mut norm_a = 0.0f32;
351        let mut norm_b = 0.0f32;
352
353        for i in 0..a.len() {
354            // Dequantize on the fly
355            let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
356            let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
357
358            dot += va * vb;
359            norm_a += va * va;
360            norm_b += vb * vb;
361        }
362
363        let denom = (norm_a * norm_b).sqrt();
364        if denom < f32::EPSILON {
365            1.0 // Maximum distance for zero vectors
366        } else {
367            1.0 - (dot / denom)
368        }
369    }
370
371    /// Computes distance between a f32 query and a quantized vector.
372    ///
373    /// This is useful for search where we keep the query in full precision.
374    #[must_use]
375    pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
376        debug_assert_eq!(query.len(), self.dimensions);
377        debug_assert_eq!(quantized.len(), self.dimensions);
378
379        let mut sum = 0.0f32;
380        for i in 0..query.len() {
381            // Dequantize the stored vector
382            let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
383            let diff = query[i] - dequant;
384            sum += diff * diff;
385        }
386        sum
387    }
388
389    /// Computes asymmetric Euclidean distance.
390    #[must_use]
391    #[inline]
392    pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
393        self.asymmetric_distance_squared(query, quantized).sqrt()
394    }
395}
396
397// ============================================================================
398// Binary Quantization
399// ============================================================================
400
/// Binary quantizer: f32 -> 1 bit (sign only).
///
/// Provides extreme compression (32x) at the cost of accuracy (~80% recall).
/// Uses hamming distance for fast comparison. Best used with rescoring.
///
/// # Example
///
/// ```
/// use grafeo_core::index::vector::quantization::BinaryQuantizer;
///
/// let v1 = vec![0.5f32, -0.3, 0.0, 0.8, -0.1, 0.2, -0.4, 0.9];
/// let v2 = vec![0.4f32, -0.2, 0.1, 0.7, -0.2, 0.3, -0.3, 0.8];
///
/// let bits1 = BinaryQuantizer::quantize(&v1);
/// let bits2 = BinaryQuantizer::quantize(&v2);
///
/// let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
/// // Vectors are similar, so hamming distance should be low
/// assert!(dist < 4);
/// ```
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Quantizes f32 vector to binary (sign bits packed in u64).
    ///
    /// Each f32 becomes 1 bit: 1 if >= 0, 0 if < 0.
    /// Bits are packed into u64 words (64 dimensions per word).
    /// NaN values compare false against 0.0 and quantize to 0.
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        let mut result = vec![0u64; vector.len().div_ceil(64)];

        for (i, &v) in vector.iter().enumerate() {
            if v >= 0.0 {
                result[i / 64] |= 1u64 << (i % 64);
            }
        }

        result
    }

    /// Quantizes multiple vectors in batch.
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Computes hamming distance between binary vectors.
    ///
    /// Counts the number of differing bits. Lower = more similar.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Computes normalized hamming distance (0.0 to 1.0).
    ///
    /// Returns the fraction of bits that differ.
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Estimates Euclidean distance from hamming distance.
    ///
    /// Uses an empirical approximation: d_euclidean ≈ sqrt(2 * hamming / dim).
    /// This is a rough estimate suitable for initial filtering.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        // Empirical approximation: assume values are roughly unit-normalized
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// Returns the number of u64 words needed for the given dimensions.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Returns the memory footprint in bytes for quantized storage.
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}
490
491// ============================================================================
492// Product Quantization
493// ============================================================================
494
/// Product quantizer: splits vectors into M subvectors, quantizes each to K centroids.
///
/// Product Quantization (PQ) provides excellent compression (8-32x) with ~90% recall.
/// It works by:
/// 1. Dividing vectors into M subvectors
/// 2. Learning K centroids (typically 256) for each subvector via k-means
/// 3. Storing each vector as M u8 codes (indices into centroid tables)
///
/// Distance computation uses asymmetric distance tables (ADC) for efficiency.
///
/// # Example
///
/// ```
/// use grafeo_core::index::vector::quantization::ProductQuantizer;
///
/// // Training vectors (16 dimensions, split into 4 subvectors)
/// let vectors: Vec<Vec<f32>> = (0..50)
///     .map(|i| (0..16).map(|j| (i + j) as f32 * 0.1).collect())
///     .collect();
/// let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
///
/// // Train quantizer with 4 subvectors, 8 centroids each
/// let quantizer = ProductQuantizer::train(&refs, 4, 8, 5);
///
/// // Quantize a vector to 4 u8 codes
/// let query = &vectors[0];
/// let codes = quantizer.quantize(query);
/// assert_eq!(codes.len(), 4);
///
/// // Each code is an index into the centroid table (0-7)
/// assert!(codes.iter().all(|&c| c < 8));
///
/// // Reconstruct approximate vector from codes
/// let reconstructed = quantizer.reconstruct(&codes);
/// assert_eq!(reconstructed.len(), 16);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Number of subvectors (M). Also the byte length of each code vector.
    num_subvectors: usize,
    /// Number of centroids per subvector (K, typically 256 for u8 codes).
    /// Must be <= 256 so a centroid index fits in a u8.
    num_centroids: usize,
    /// Dimensions per subvector: `dimensions / num_subvectors`.
    subvector_dim: usize,
    /// Total dimensions.
    dimensions: usize,
    /// Centroids: [M][K][subvector_dim] flattened to [M * K * subvector_dim].
    /// Centroid `k` of partition `m` starts at index
    /// `(m * num_centroids + k) * subvector_dim`.
    centroids: Vec<f32>,
}
544
545impl ProductQuantizer {
546    /// Trains a product quantizer from sample vectors using k-means clustering.
547    ///
548    /// # Arguments
549    ///
550    /// * `vectors` - Training vectors (should be representative of the dataset)
551    /// * `num_subvectors` - Number of subvectors (M), must divide dimensions evenly
552    /// * `num_centroids` - Number of centroids per subvector (K), typically 256
553    /// * `iterations` - Number of k-means iterations (10-20 is usually sufficient)
554    ///
555    /// # Panics
556    ///
557    /// Panics if vectors is empty, dimensions not divisible by num_subvectors,
558    /// or num_centroids > 256.
559    #[must_use]
560    pub fn train(
561        vectors: &[&[f32]],
562        num_subvectors: usize,
563        num_centroids: usize,
564        iterations: usize,
565    ) -> Self {
566        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
567        assert!(
568            num_centroids <= 256,
569            "num_centroids must be <= 256 for u8 codes"
570        );
571        assert!(num_subvectors > 0, "num_subvectors must be > 0");
572
573        let dimensions = vectors[0].len();
574        assert!(
575            dimensions.is_multiple_of(num_subvectors),
576            "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
577        );
578        assert!(
579            vectors.iter().all(|v| v.len() == dimensions),
580            "All training vectors must have the same dimensions"
581        );
582
583        let subvector_dim = dimensions / num_subvectors;
584
585        // Train centroids for each subvector independently
586        let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
587
588        for m in 0..num_subvectors {
589            // Extract subvectors for this partition
590            let subvectors: Vec<Vec<f32>> = vectors
591                .iter()
592                .map(|v| {
593                    let start = m * subvector_dim;
594                    let end = start + subvector_dim;
595                    v[start..end].to_vec()
596                })
597                .collect();
598
599            // Run k-means on this partition
600            let partition_centroids =
601                Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
602
603            centroids.extend(partition_centroids);
604        }
605
606        Self {
607            num_subvectors,
608            num_centroids,
609            subvector_dim,
610            dimensions,
611            centroids,
612        }
613    }
614
615    /// Simple k-means clustering implementation.
616    fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
617        let n = vectors.len();
618
619        // Initialize centroids using k-means++ style (first k vectors or random sampling)
620        let actual_k = k.min(n);
621        let mut centroids: Vec<f32> = if actual_k == n {
622            vectors.iter().flat_map(|v| v.iter().copied()).collect()
623        } else {
624            // Take evenly spaced samples
625            let step = n / actual_k;
626            (0..actual_k)
627                .flat_map(|i| vectors[i * step].iter().copied())
628                .collect()
629        };
630
631        // Pad with zeros if we don't have enough training vectors
632        if actual_k < k {
633            centroids.resize(k * dims, 0.0);
634        }
635
636        let mut assignments = vec![0usize; n];
637        let mut counts = vec![0usize; k];
638
639        for _ in 0..iterations {
640            // Assignment step: find nearest centroid for each vector
641            for (i, vec) in vectors.iter().enumerate() {
642                let mut best_dist = f32::INFINITY;
643                let mut best_k = 0;
644
645                for j in 0..k {
646                    let centroid_start = j * dims;
647                    let dist: f32 = vec
648                        .iter()
649                        .enumerate()
650                        .map(|(d, &v)| {
651                            let diff = v - centroids[centroid_start + d];
652                            diff * diff
653                        })
654                        .sum();
655
656                    if dist < best_dist {
657                        best_dist = dist;
658                        best_k = j;
659                    }
660                }
661
662                assignments[i] = best_k;
663            }
664
665            // Update step: recompute centroids as mean of assigned vectors
666            centroids.fill(0.0);
667            counts.fill(0);
668
669            for (i, vec) in vectors.iter().enumerate() {
670                let k_idx = assignments[i];
671                let centroid_start = k_idx * dims;
672                counts[k_idx] += 1;
673
674                for (d, &v) in vec.iter().enumerate() {
675                    centroids[centroid_start + d] += v;
676                }
677            }
678
679            // Divide by counts to get means
680            for j in 0..k {
681                if counts[j] > 0 {
682                    let centroid_start = j * dims;
683                    let count = counts[j] as f32;
684                    for d in 0..dims {
685                        centroids[centroid_start + d] /= count;
686                    }
687                }
688            }
689        }
690
691        centroids
692    }
693
694    /// Creates a product quantizer with explicit centroids (for testing/loading).
695    #[must_use]
696    pub fn with_centroids(
697        num_subvectors: usize,
698        num_centroids: usize,
699        dimensions: usize,
700        centroids: Vec<f32>,
701    ) -> Self {
702        let subvector_dim = dimensions / num_subvectors;
703        assert_eq!(
704            centroids.len(),
705            num_subvectors * num_centroids * subvector_dim,
706            "Invalid centroid count"
707        );
708
709        Self {
710            num_subvectors,
711            num_centroids,
712            subvector_dim,
713            dimensions,
714            centroids,
715        }
716    }
717
718    /// Returns the number of subvectors (M).
719    #[must_use]
720    pub fn num_subvectors(&self) -> usize {
721        self.num_subvectors
722    }
723
724    /// Returns the number of centroids per subvector (K).
725    #[must_use]
726    pub fn num_centroids(&self) -> usize {
727        self.num_centroids
728    }
729
730    /// Returns the total dimensions.
731    #[must_use]
732    pub fn dimensions(&self) -> usize {
733        self.dimensions
734    }
735
736    /// Returns the dimensions per subvector.
737    #[must_use]
738    pub fn subvector_dim(&self) -> usize {
739        self.subvector_dim
740    }
741
742    /// Returns the memory footprint in bytes for a quantized vector.
743    #[must_use]
744    pub fn code_size(&self) -> usize {
745        self.num_subvectors // M u8 codes
746    }
747
748    /// Returns the compression ratio compared to f32 storage.
749    #[must_use]
750    pub fn compression_ratio(&self) -> usize {
751        // Original: dimensions * 4 bytes
752        // Compressed: num_subvectors bytes
753        (self.dimensions * 4) / self.num_subvectors
754    }
755
756    /// Quantizes a vector to M u8 codes.
757    ///
758    /// Each code is the index of the nearest centroid for that subvector.
759    #[must_use]
760    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
761        debug_assert_eq!(
762            vector.len(),
763            self.dimensions,
764            "Vector dimension mismatch: expected {}, got {}",
765            self.dimensions,
766            vector.len()
767        );
768
769        let mut codes = Vec::with_capacity(self.num_subvectors);
770
771        for m in 0..self.num_subvectors {
772            let subvec_start = m * self.subvector_dim;
773            let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
774
775            // Find nearest centroid for this subvector
776            let mut best_dist = f32::INFINITY;
777            let mut best_k = 0u8;
778
779            for k in 0..self.num_centroids {
780                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
781                let dist: f32 = subvec
782                    .iter()
783                    .enumerate()
784                    .map(|(d, &v)| {
785                        let diff = v - self.centroids[centroid_start + d];
786                        diff * diff
787                    })
788                    .sum();
789
790                if dist < best_dist {
791                    best_dist = dist;
792                    best_k = k as u8;
793                }
794            }
795
796            codes.push(best_k);
797        }
798
799        codes
800    }
801
802    /// Quantizes multiple vectors in batch.
803    #[must_use]
804    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
805        vectors.iter().map(|v| self.quantize(v)).collect()
806    }
807
808    /// Builds asymmetric distance table for a query vector.
809    ///
810    /// Returns a table of shape \[M\]\[K\] containing the squared distance
811    /// from each query subvector to each centroid. This allows O(M) distance
812    /// computation for quantized vectors via table lookups.
813    #[must_use]
814    pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
815        debug_assert_eq!(query.len(), self.dimensions);
816
817        let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
818
819        for m in 0..self.num_subvectors {
820            let query_start = m * self.subvector_dim;
821            let query_subvec = &query[query_start..query_start + self.subvector_dim];
822
823            for k in 0..self.num_centroids {
824                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
825
826                let dist: f32 = query_subvec
827                    .iter()
828                    .enumerate()
829                    .map(|(d, &v)| {
830                        let diff = v - self.centroids[centroid_start + d];
831                        diff * diff
832                    })
833                    .sum();
834
835                table.push(dist);
836            }
837        }
838
839        table
840    }
841
842    /// Computes asymmetric squared distance using a precomputed distance table.
843    ///
844    /// This is O(M) - just M table lookups and additions.
845    #[must_use]
846    #[inline]
847    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
848        debug_assert_eq!(codes.len(), self.num_subvectors);
849        debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
850
851        codes
852            .iter()
853            .enumerate()
854            .map(|(m, &code)| table[m * self.num_centroids + code as usize])
855            .sum()
856    }
857
858    /// Computes asymmetric squared distance from query to quantized vector.
859    ///
860    /// This builds the distance table on the fly - use `build_distance_table`
861    /// and `distance_with_table` for batch queries.
862    #[must_use]
863    pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
864        let table = self.build_distance_table(query);
865        self.distance_with_table(&table, codes)
866    }
867
868    /// Computes asymmetric distance (Euclidean).
869    #[must_use]
870    #[inline]
871    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
872        self.asymmetric_distance_squared(query, codes).sqrt()
873    }
874
875    /// Reconstructs an approximate vector from codes.
876    ///
877    /// Returns the concatenated centroids for the given codes.
878    #[must_use]
879    pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
880        debug_assert_eq!(codes.len(), self.num_subvectors);
881
882        let mut result = Vec::with_capacity(self.dimensions);
883
884        for (m, &code) in codes.iter().enumerate() {
885            let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
886            result.extend_from_slice(
887                &self.centroids[centroid_start..centroid_start + self.subvector_dim],
888            );
889        }
890
891        result
892    }
893
894    /// Returns the centroid vectors for a specific subvector partition.
895    #[must_use]
896    pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
897        assert!(partition < self.num_subvectors);
898
899        (0..self.num_centroids)
900            .map(|k| {
901                let start = (partition * self.num_centroids + k) * self.subvector_dim;
902                &self.centroids[start..start + self.subvector_dim]
903            })
904            .collect()
905    }
906}
907
908// ============================================================================
909// SIMD-Accelerated Hamming Distance
910// ============================================================================
911
/// Computes hamming distance with SIMD acceleration (if available).
///
/// Uses `u64::count_ones`, which the compiler lowers to the `popcnt`
/// instruction when the target supports it (e.g. built with
/// `-C target-cpu=native` or `-C target-feature=+popcnt`). Unlike calling
/// the `_popcnt64` intrinsic directly, this is safe on every x86_64 CPU:
/// executing the intrinsic on a pre-`popcnt` CPU without a feature check
/// is undefined behavior, whereas `count_ones` falls back to a correct
/// software popcount.
///
/// If `a` and `b` have different lengths, only the common prefix is
/// compared (zip truncates to the shorter slice).
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
}
933
/// Fallback scalar implementation.
///
/// Compiled only on non-x86_64 targets; delegates to the portable
/// [`BinaryQuantizer::hamming_distance`] so that call sites can invoke
/// `hamming_distance_simd` unconditionally regardless of architecture.
#[cfg(not(target_arch = "x86_64"))]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    BinaryQuantizer::hamming_distance(a, b)
}
940
941// ============================================================================
942// Tests
943// ============================================================================
944
// Unit tests for the quantization module: scalar (u8), binary (sign-bit),
// and product quantization, plus the SIMD hamming-distance helper.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_compression_ratio() {
        // Use 384 dimensions (common embedding size)
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);

        // Product quantization: (384 * 4) / 8 = 192x compression
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);

        // PQ16: (384 * 4) / 16 = 96x compression
        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    #[test]
    fn test_quantization_type_from_str() {
        // NOTE(review): the "SQ" case suggests parsing is case-insensitive —
        // confirm against the `from_str` implementation.
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    // ========================================================================
    // Scalar Quantization Tests
    // ========================================================================

    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let quantizer = ScalarQuantizer::train(&refs);

        // Training learns a per-dimension minimum across all training vectors:
        // dim 0 -> min(0.0, 0.2, 0.1), dim 1 -> min(0.5, 0.3, 0.6), etc.
        assert_eq!(quantizer.dimensions(), 3);
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Min value should quantize to 0
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);

        // Max value should quantize to 255
        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);

        // Middle value should quantize to ~127
        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);

        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);

        // Should be close to original (within quantization error)
        // Max step for a [0,1] range over 256 levels is ~1/255 ≈ 0.004.
        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.distance_u8(&a, &b);
        // Should be approximately 1.0 (the Euclidean distance in original space)
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Asymmetric: full-precision query vs quantized stored vector.
        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        // Range includes negatives so direction information survives quantization.
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);

        // Orthogonal vectors
        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);

        let dist = quantizer.cosine_distance_u8(&a, &b);
        // Cosine distance of orthogonal vectors = 1.0
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    // ========================================================================
    // Binary Quantization Tests
    // ========================================================================

    #[test]
    fn test_binary_quantizer_quantize() {
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 1); // 4 dims fit in 1 u64

        // Check individual bits: 0.5 >= 0 (1), -0.3 < 0 (0), 0.0 >= 0 (1), 0.8 >= 0 (1)
        // Expected bits (LSB first): 1, 0, 1, 1 = 0b1101 = 13
        // (dimension i maps to bit i, so the sequence reads right-to-left in binary)
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    #[test]
    fn test_binary_quantizer_hamming_distance() {
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0]; // All positive: 1111
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0]; // Mixed: 1010

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2); // Two bits differ
    }

    #[test]
    fn test_binary_quantizer_identical_vectors() {
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);

        // Distance of any bit pattern to itself is zero.
        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        // 64 dims exactly fills one u64 word.
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64); // All bits differ
    }

    #[test]
    fn test_binary_quantizer_large_vector() {
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);

        // 1000 dims needs ceil(1000/64) = 16 words
        assert_eq!(bits.len(), 16);
    }

    #[test]
    fn test_binary_quantizer_normalized_distance() {
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        // Normalized by the dimension count (100), not the padded bit width (128).
        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01); // All bits differ
    }

    #[test]
    fn test_binary_quantizer_words_needed() {
        // words_needed = ceil(dims / 64)
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24); // OpenAI embedding size
    }

    #[test]
    fn test_binary_quantizer_bytes_needed() {
        // Each u64 is 8 bytes
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192); // vs 6144 for f32
    }

    // ========================================================================
    // SIMD Tests
    // ========================================================================

    #[test]
    fn test_hamming_distance_simd() {
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128); // All 128 bits differ
    }

    // ========================================================================
    // Product Quantization Tests
    // ========================================================================

    #[test]
    fn test_quantization_type_product_from_str() {
        // Basic PQ parsing (bare "pq"/"product" defaults to 8 subvectors)
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );

        // PQ with specific subvector count
        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    #[test]
    fn test_quantization_type_requires_training() {
        // Scalar learns min/max ranges, PQ learns centroids; None and Binary
        // need no training data.
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    #[test]
    fn test_product_quantizer_train() {
        // Create 100 training vectors with 16 dimensions
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // Train with 4 subvectors (16/4 = 4 dims each), 8 centroids
        let pq = ProductQuantizer::train(&refs, 4, 8, 5);

        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 16, 3);

        // Quantize a vector
        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);

        // All codes should be < num_centroids
        for &code in &codes {
            assert!(code < 16);
        }
    }

    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 3, 8, 5);

        // Quantize and reconstruct
        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);

        assert_eq!(reconstructed.len(), 12);

        // Reconstructed should be somewhat close to original (not exact due to quantization)
        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        // Error should be bounded (not zero, but reasonable)
        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 8, 32, 5);

        // Distance to self should be small
        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");

        // Distance to different vector should be larger
        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 3);

        let query = &vectors[0];
        let table = pq.build_distance_table(query);

        // Table should have M * K entries
        assert_eq!(table.len(), 4 * 8);

        // Distance via table should match direct computation
        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);

        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 2);

        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);

        // Each vector yields one code per subvector partition.
        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // PQ8: 384 dims split into 8 subvectors
        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192); // (384 * 4) / 8 = 192

        // PQ48: 384 dims split into 48 subvectors (8 dims each)
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32); // (384 * 4) / 48 = 32
    }

    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // 15 is not divisible by 4
        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 3);

        // Get centroids for first partition
        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4); // 4 centroids
        assert_eq!(centroids[0].len(), 4); // 4 dims per subvector (8/2)
    }
}
1393}