Skip to main content

grafeo_core/index/vector/
quantization.rs

1//! Vector quantization algorithms for memory-efficient storage.
2//!
3//! Quantization reduces vector precision for memory savings:
4//!
5//! | Method  | Compression | Accuracy | Speed    | Use Case                    |
6//! |---------|-------------|----------|----------|----------------------------|
7//! | Scalar  | 4x          | ~97%     | Fast     | Default for most datasets   |
8//! | Binary  | 32x         | ~80%     | Fastest  | Very large datasets         |
9//!
10//! # Scalar Quantization
11//!
12//! Converts f32 values to u8 by learning min/max ranges per dimension:
13//!
14//! ```ignore
15//! use grafeo_core::index::vector::quantization::ScalarQuantizer;
16//!
17//! let vectors: Vec<Vec<f32>> = get_training_vectors();
18//! let quantizer = ScalarQuantizer::train(&vectors);
19//!
20//! // Quantize: f32 -> u8 (4x compression)
21//! let original = vec![0.1f32, 0.5, 0.9];
22//! let quantized = quantizer.quantize(&original);
23//!
24//! // Compute distance in quantized space (approximate)
25//! let dist = quantizer.distance_u8(&quantized, &other_quantized);
26//! ```
27//!
28//! # Binary Quantization
29//!
//! Converts f32 values to bits (sign only), enabling Hamming distance:
31//!
32//! ```ignore
33//! use grafeo_core::index::vector::quantization::BinaryQuantizer;
34//!
35//! let v = vec![0.1f32, -0.5, 0.0, 0.9];
36//! let bits = BinaryQuantizer::quantize(&v);
37//!
38//! // Hamming distance (count differing bits)
39//! let dist = BinaryQuantizer::hamming_distance(&bits, &other_bits);
40//! ```
41
42use serde::{Deserialize, Serialize};
43
44// ============================================================================
45// Quantization Type
46// ============================================================================
47
/// Quantization strategy for vector storage.
///
/// See the module-level docs for a compression/accuracy comparison table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum QuantizationType {
    /// No quantization - full f32 precision.
    #[default]
    None,
    /// Scalar quantization: f32 -> u8 (4x compression, ~97% accuracy).
    /// Requires training to learn per-dimension min/max ranges.
    Scalar,
    /// Binary quantization: f32 -> 1 bit (32x compression, ~80% accuracy).
    /// Training-free: only the sign of each component is kept.
    Binary,
    /// Product quantization: f32 -> M u8 codes (8-32x compression, ~90% accuracy).
    /// Requires k-means training for each subvector partition.
    Product {
        /// Number of subvectors (typically 8, 16, 32, 64).
        /// Must evenly divide the vector dimensions.
        num_subvectors: usize,
    },
}
64
65impl QuantizationType {
66    /// Returns the compression ratio (memory reduction factor).
67    ///
68    /// For Product quantization, this depends on dimensions and num_subvectors.
69    /// The ratio is approximate: dimensions * 4 / num_subvectors.
70    #[must_use]
71    pub fn compression_ratio(&self, dimensions: usize) -> usize {
72        match self {
73            Self::None => 1,
74            Self::Scalar => 4,  // f32 (4 bytes) -> u8 (1 byte)
75            Self::Binary => 32, // f32 (4 bytes) -> 1 bit (0.125 bytes)
76            Self::Product { num_subvectors } => {
77                // Original: dimensions * 4 bytes (f32)
78                // Compressed: num_subvectors bytes (u8 codes)
79                // Ratio: (dimensions * 4) / num_subvectors
80                let m = (*num_subvectors).max(1);
81                (dimensions * 4) / m
82            }
83        }
84    }
85
86    /// Returns the name of the quantization type.
87    #[must_use]
88    pub fn name(&self) -> &'static str {
89        match self {
90            Self::None => "none",
91            Self::Scalar => "scalar",
92            Self::Binary => "binary",
93            Self::Product { .. } => "product",
94        }
95    }
96
97    /// Parses from string (case-insensitive).
98    #[must_use]
99    pub fn from_str(s: &str) -> Option<Self> {
100        match s.to_lowercase().as_str() {
101            "none" | "full" | "f32" => Some(Self::None),
102            "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
103            "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
104            "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
105            s if s.starts_with("pq") => {
106                // Parse "pq8", "pq16", etc.
107                s[2..]
108                    .parse()
109                    .ok()
110                    .map(|n| Self::Product { num_subvectors: n })
111            }
112            _ => None,
113        }
114    }
115
116    /// Returns true if this quantization type requires training.
117    #[must_use]
118    pub const fn requires_training(&self) -> bool {
119        matches!(self, Self::Scalar | Self::Product { .. })
120    }
121}
122
123// ============================================================================
124// Scalar Quantization
125// ============================================================================
126
/// Scalar quantizer: f32 -> u8 with per-dimension min/max scaling.
///
/// Training learns the min/max value for each dimension, then quantizes
/// values to [0, 255] range. This achieves 4x compression with typically
/// >97% recall retention.
///
/// # Example
///
/// ```
/// use grafeo_core::index::vector::quantization::ScalarQuantizer;
///
/// // Training vectors
/// let vectors = vec![
///     vec![0.0f32, 0.5, 1.0],
///     vec![0.2, 0.3, 0.8],
///     vec![0.1, 0.6, 0.9],
/// ];
///
/// // Train quantizer
/// let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
/// let quantizer = ScalarQuantizer::train(&refs);
///
/// // Quantize a vector
/// let quantized = quantizer.quantize(&[0.1, 0.4, 0.85]);
/// assert_eq!(quantized.len(), 3);
///
/// // Compute approximate distance
/// let q2 = quantizer.quantize(&[0.15, 0.45, 0.9]);
/// let dist = quantizer.distance_squared_u8(&quantized, &q2);
/// assert!(dist < 1000.0);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    /// Minimum value per dimension (learned lower bound of the input range).
    min: Vec<f32>,
    /// Scale factor per dimension: 255 / (max - min). Maps an input value
    /// into the [0, 255] code range. Degenerate dimensions (max == min)
    /// use 1.0 to avoid division by zero.
    scale: Vec<f32>,
    /// Inverse scale per dimension: (max - min) / 255. Used to dequantize
    /// codes and to rescale quantized-space distances back to the original
    /// value space.
    inv_scale: Vec<f32>,
    /// Number of dimensions every quantized vector must have.
    dimensions: usize,
}
169
170impl ScalarQuantizer {
171    /// Trains a scalar quantizer from sample vectors.
172    ///
173    /// Learns the min/max value per dimension from the training data.
174    /// The more representative the training data, the better the quantization.
175    ///
176    /// # Arguments
177    ///
178    /// * `vectors` - Training vectors (should be representative of the dataset)
179    ///
180    /// # Panics
181    ///
182    /// Panics if `vectors` is empty or if vectors have different dimensions.
183    #[must_use]
184    pub fn train(vectors: &[&[f32]]) -> Self {
185        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
186
187        let dimensions = vectors[0].len();
188        assert!(
189            vectors.iter().all(|v| v.len() == dimensions),
190            "All training vectors must have the same dimensions"
191        );
192
193        // Find min/max per dimension
194        let mut min = vec![f32::INFINITY; dimensions];
195        let mut max = vec![f32::NEG_INFINITY; dimensions];
196
197        for vec in vectors {
198            for (i, &v) in vec.iter().enumerate() {
199                min[i] = min[i].min(v);
200                max[i] = max[i].max(v);
201            }
202        }
203
204        // Compute scale factors (avoid division by zero)
205        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
206            .iter()
207            .zip(&max)
208            .map(|(&mn, &mx)| {
209                let range = mx - mn;
210                if range.abs() < f32::EPSILON {
211                    // All values are the same, use 1.0 as scale
212                    (1.0, 1.0)
213                } else {
214                    (255.0 / range, range / 255.0)
215                }
216            })
217            .unzip();
218
219        Self {
220            min,
221            scale,
222            inv_scale,
223            dimensions,
224        }
225    }
226
227    /// Creates a quantizer with explicit ranges (useful for testing).
228    #[must_use]
229    pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
230        let dimensions = min.len();
231        assert_eq!(min.len(), max.len(), "Min and max must have same length");
232
233        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
234            .iter()
235            .zip(&max)
236            .map(|(&mn, &mx)| {
237                let range = mx - mn;
238                if range.abs() < f32::EPSILON {
239                    (1.0, 1.0)
240                } else {
241                    (255.0 / range, range / 255.0)
242                }
243            })
244            .unzip();
245
246        Self {
247            min,
248            scale,
249            inv_scale,
250            dimensions,
251        }
252    }
253
254    /// Returns the number of dimensions.
255    #[must_use]
256    pub fn dimensions(&self) -> usize {
257        self.dimensions
258    }
259
260    /// Returns the min values per dimension.
261    #[must_use]
262    pub fn min_values(&self) -> &[f32] {
263        &self.min
264    }
265
266    /// Quantizes an f32 vector to u8.
267    ///
268    /// Values are clamped to the learned [min, max] range.
269    #[must_use]
270    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
271        debug_assert_eq!(
272            vector.len(),
273            self.dimensions,
274            "Vector dimension mismatch: expected {}, got {}",
275            self.dimensions,
276            vector.len()
277        );
278
279        vector
280            .iter()
281            .enumerate()
282            .map(|(i, &v)| {
283                let normalized = (v - self.min[i]) * self.scale[i];
284                normalized.clamp(0.0, 255.0) as u8
285            })
286            .collect()
287    }
288
289    /// Quantizes multiple vectors in batch.
290    #[must_use]
291    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
292        vectors.iter().map(|v| self.quantize(v)).collect()
293    }
294
295    /// Dequantizes a u8 vector back to f32 (approximate).
296    #[must_use]
297    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
298        debug_assert_eq!(quantized.len(), self.dimensions);
299
300        quantized
301            .iter()
302            .enumerate()
303            .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
304            .collect()
305    }
306
307    /// Computes squared Euclidean distance between quantized vectors.
308    ///
309    /// This is an approximation that works well for ranking nearest neighbors.
310    /// The returned distance is scaled back to the original space.
311    #[must_use]
312    pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
313        debug_assert_eq!(a.len(), self.dimensions);
314        debug_assert_eq!(b.len(), self.dimensions);
315
316        // Compute in quantized space, then scale
317        let mut sum = 0.0f32;
318        for i in 0..a.len() {
319            let diff = (a[i] as f32) - (b[i] as f32);
320            sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
321        }
322        sum
323    }
324
325    /// Computes Euclidean distance between quantized vectors.
326    #[must_use]
327    #[inline]
328    pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
329        self.distance_squared_u8(a, b).sqrt()
330    }
331
332    /// Computes approximate cosine distance using quantized vectors.
333    ///
334    /// This is less accurate than exact computation but much faster.
335    #[must_use]
336    pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
337        debug_assert_eq!(a.len(), self.dimensions);
338        debug_assert_eq!(b.len(), self.dimensions);
339
340        let mut dot = 0.0f32;
341        let mut norm_a = 0.0f32;
342        let mut norm_b = 0.0f32;
343
344        for i in 0..a.len() {
345            // Dequantize on the fly
346            let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
347            let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
348
349            dot += va * vb;
350            norm_a += va * va;
351            norm_b += vb * vb;
352        }
353
354        let denom = (norm_a * norm_b).sqrt();
355        if denom < f32::EPSILON {
356            1.0 // Maximum distance for zero vectors
357        } else {
358            1.0 - (dot / denom)
359        }
360    }
361
362    /// Computes distance between a f32 query and a quantized vector.
363    ///
364    /// This is useful for search where we keep the query in full precision.
365    #[must_use]
366    pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
367        debug_assert_eq!(query.len(), self.dimensions);
368        debug_assert_eq!(quantized.len(), self.dimensions);
369
370        let mut sum = 0.0f32;
371        for i in 0..query.len() {
372            // Dequantize the stored vector
373            let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
374            let diff = query[i] - dequant;
375            sum += diff * diff;
376        }
377        sum
378    }
379
380    /// Computes asymmetric Euclidean distance.
381    #[must_use]
382    #[inline]
383    pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
384        self.asymmetric_distance_squared(query, quantized).sqrt()
385    }
386}
387
388// ============================================================================
389// Binary Quantization
390// ============================================================================
391
/// Binary quantizer: f32 -> 1 bit (sign only).
///
/// Provides extreme compression (32x) at the cost of accuracy (~80% recall).
/// Uses Hamming distance for fast comparison. Best used with rescoring.
///
/// # Example
///
/// ```
/// use grafeo_core::index::vector::quantization::BinaryQuantizer;
///
/// let v1 = vec![0.5f32, -0.3, 0.0, 0.8, -0.1, 0.2, -0.4, 0.9];
/// let v2 = vec![0.4f32, -0.2, 0.1, 0.7, -0.2, 0.3, -0.3, 0.8];
///
/// let bits1 = BinaryQuantizer::quantize(&v1);
/// let bits2 = BinaryQuantizer::quantize(&v2);
///
/// let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
/// // Vectors are similar, so hamming distance should be low
/// assert!(dist < 4);
/// ```
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Quantizes f32 vector to binary (sign bits packed in u64).
    ///
    /// Each f32 becomes 1 bit: 1 if >= 0, 0 if < 0 (NaN compares false
    /// against 0.0, so NaN also maps to 0).
    /// Bits are packed into u64 words (64 dimensions per word); bit `i % 64`
    /// of word `i / 64` holds dimension `i`.
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        let mut result = vec![0u64; vector.len().div_ceil(64)];

        for (i, &v) in vector.iter().enumerate() {
            if v >= 0.0 {
                result[i / 64] |= 1u64 << (i % 64);
            }
        }

        result
    }

    /// Quantizes multiple vectors in batch.
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Computes hamming distance between binary vectors.
    ///
    /// Counts the number of differing bits. Lower = more similar.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Computes normalized hamming distance (0.0 to 1.0).
    ///
    /// Returns the fraction of bits that differ.
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Estimates Euclidean distance from hamming distance.
    ///
    /// Uses an empirical approximation: d_euclidean ≈ sqrt(2 * hamming / dim).
    /// This is a rough estimate suitable for initial filtering.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        // Empirical approximation: assume values are roughly unit-normalized
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// Returns the number of u64 words needed for the given dimensions.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Returns the memory footprint in bytes for quantized storage.
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}
481
482// ============================================================================
483// Product Quantization
484// ============================================================================
485
/// Product quantizer: splits vectors into M subvectors, quantizes each to K centroids.
///
/// Product Quantization (PQ) provides excellent compression (8-32x) with ~90% recall.
/// It works by:
/// 1. Dividing vectors into M subvectors
/// 2. Learning K centroids (typically 256) for each subvector via k-means
/// 3. Storing each vector as M u8 codes (indices into centroid tables)
///
/// Distance computation uses asymmetric distance tables (ADC) for efficiency.
///
/// # Example
///
/// ```
/// use grafeo_core::index::vector::quantization::ProductQuantizer;
///
/// // Training vectors (16 dimensions, split into 4 subvectors)
/// let vectors: Vec<Vec<f32>> = (0..50)
///     .map(|i| (0..16).map(|j| (i + j) as f32 * 0.1).collect())
///     .collect();
/// let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
///
/// // Train quantizer with 4 subvectors, 8 centroids each
/// let quantizer = ProductQuantizer::train(&refs, 4, 8, 5);
///
/// // Quantize a vector to 4 u8 codes
/// let query = &vectors[0];
/// let codes = quantizer.quantize(query);
/// assert_eq!(codes.len(), 4);
///
/// // Each code is an index into the centroid table (0-7)
/// assert!(codes.iter().all(|&c| c < 8));
///
/// // Reconstruct approximate vector from codes
/// let reconstructed = quantizer.reconstruct(&codes);
/// assert_eq!(reconstructed.len(), 16);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Number of subvectors (M). Each quantized vector is stored as M u8 codes.
    num_subvectors: usize,
    /// Number of centroids per subvector (K, typically 256 for u8 codes).
    /// Never exceeds 256 so every code index fits in a u8.
    num_centroids: usize,
    /// Dimensions per subvector: dimensions / num_subvectors.
    subvector_dim: usize,
    /// Total dimensions of the full (unquantized) vectors.
    dimensions: usize,
    /// Centroids: [M][K][subvector_dim] flattened to [M * K * subvector_dim].
    /// Centroid k of partition m starts at (m * K + k) * subvector_dim.
    centroids: Vec<f32>,
}
535
536impl ProductQuantizer {
537    /// Trains a product quantizer from sample vectors using k-means clustering.
538    ///
539    /// # Arguments
540    ///
541    /// * `vectors` - Training vectors (should be representative of the dataset)
542    /// * `num_subvectors` - Number of subvectors (M), must divide dimensions evenly
543    /// * `num_centroids` - Number of centroids per subvector (K), typically 256
544    /// * `iterations` - Number of k-means iterations (10-20 is usually sufficient)
545    ///
546    /// # Panics
547    ///
548    /// Panics if vectors is empty, dimensions not divisible by num_subvectors,
549    /// or num_centroids > 256.
550    #[must_use]
551    pub fn train(
552        vectors: &[&[f32]],
553        num_subvectors: usize,
554        num_centroids: usize,
555        iterations: usize,
556    ) -> Self {
557        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
558        assert!(
559            num_centroids <= 256,
560            "num_centroids must be <= 256 for u8 codes"
561        );
562        assert!(num_subvectors > 0, "num_subvectors must be > 0");
563
564        let dimensions = vectors[0].len();
565        assert!(
566            dimensions.is_multiple_of(num_subvectors),
567            "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
568        );
569        assert!(
570            vectors.iter().all(|v| v.len() == dimensions),
571            "All training vectors must have the same dimensions"
572        );
573
574        let subvector_dim = dimensions / num_subvectors;
575
576        // Train centroids for each subvector independently
577        let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
578
579        for m in 0..num_subvectors {
580            // Extract subvectors for this partition
581            let subvectors: Vec<Vec<f32>> = vectors
582                .iter()
583                .map(|v| {
584                    let start = m * subvector_dim;
585                    let end = start + subvector_dim;
586                    v[start..end].to_vec()
587                })
588                .collect();
589
590            // Run k-means on this partition
591            let partition_centroids =
592                Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
593
594            centroids.extend(partition_centroids);
595        }
596
597        Self {
598            num_subvectors,
599            num_centroids,
600            subvector_dim,
601            dimensions,
602            centroids,
603        }
604    }
605
606    /// Simple k-means clustering implementation.
607    fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
608        let n = vectors.len();
609
610        // Initialize centroids using k-means++ style (first k vectors or random sampling)
611        let actual_k = k.min(n);
612        let mut centroids: Vec<f32> = if actual_k == n {
613            vectors.iter().flat_map(|v| v.iter().copied()).collect()
614        } else {
615            // Take evenly spaced samples
616            let step = n / actual_k;
617            (0..actual_k)
618                .flat_map(|i| vectors[i * step].iter().copied())
619                .collect()
620        };
621
622        // Pad with zeros if we don't have enough training vectors
623        if actual_k < k {
624            centroids.resize(k * dims, 0.0);
625        }
626
627        let mut assignments = vec![0usize; n];
628        let mut counts = vec![0usize; k];
629
630        for _ in 0..iterations {
631            // Assignment step: find nearest centroid for each vector
632            for (i, vec) in vectors.iter().enumerate() {
633                let mut best_dist = f32::INFINITY;
634                let mut best_k = 0;
635
636                for j in 0..k {
637                    let centroid_start = j * dims;
638                    let dist: f32 = vec
639                        .iter()
640                        .enumerate()
641                        .map(|(d, &v)| {
642                            let diff = v - centroids[centroid_start + d];
643                            diff * diff
644                        })
645                        .sum();
646
647                    if dist < best_dist {
648                        best_dist = dist;
649                        best_k = j;
650                    }
651                }
652
653                assignments[i] = best_k;
654            }
655
656            // Update step: recompute centroids as mean of assigned vectors
657            centroids.fill(0.0);
658            counts.fill(0);
659
660            for (i, vec) in vectors.iter().enumerate() {
661                let k_idx = assignments[i];
662                let centroid_start = k_idx * dims;
663                counts[k_idx] += 1;
664
665                for (d, &v) in vec.iter().enumerate() {
666                    centroids[centroid_start + d] += v;
667                }
668            }
669
670            // Divide by counts to get means
671            for j in 0..k {
672                if counts[j] > 0 {
673                    let centroid_start = j * dims;
674                    let count = counts[j] as f32;
675                    for d in 0..dims {
676                        centroids[centroid_start + d] /= count;
677                    }
678                }
679            }
680        }
681
682        centroids
683    }
684
685    /// Creates a product quantizer with explicit centroids (for testing/loading).
686    #[must_use]
687    pub fn with_centroids(
688        num_subvectors: usize,
689        num_centroids: usize,
690        dimensions: usize,
691        centroids: Vec<f32>,
692    ) -> Self {
693        let subvector_dim = dimensions / num_subvectors;
694        assert_eq!(
695            centroids.len(),
696            num_subvectors * num_centroids * subvector_dim,
697            "Invalid centroid count"
698        );
699
700        Self {
701            num_subvectors,
702            num_centroids,
703            subvector_dim,
704            dimensions,
705            centroids,
706        }
707    }
708
709    /// Returns the number of subvectors (M).
710    #[must_use]
711    pub fn num_subvectors(&self) -> usize {
712        self.num_subvectors
713    }
714
715    /// Returns the number of centroids per subvector (K).
716    #[must_use]
717    pub fn num_centroids(&self) -> usize {
718        self.num_centroids
719    }
720
721    /// Returns the total dimensions.
722    #[must_use]
723    pub fn dimensions(&self) -> usize {
724        self.dimensions
725    }
726
727    /// Returns the dimensions per subvector.
728    #[must_use]
729    pub fn subvector_dim(&self) -> usize {
730        self.subvector_dim
731    }
732
733    /// Returns the memory footprint in bytes for a quantized vector.
734    #[must_use]
735    pub fn code_size(&self) -> usize {
736        self.num_subvectors // M u8 codes
737    }
738
739    /// Returns the compression ratio compared to f32 storage.
740    #[must_use]
741    pub fn compression_ratio(&self) -> usize {
742        // Original: dimensions * 4 bytes
743        // Compressed: num_subvectors bytes
744        (self.dimensions * 4) / self.num_subvectors
745    }
746
747    /// Quantizes a vector to M u8 codes.
748    ///
749    /// Each code is the index of the nearest centroid for that subvector.
750    #[must_use]
751    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
752        debug_assert_eq!(
753            vector.len(),
754            self.dimensions,
755            "Vector dimension mismatch: expected {}, got {}",
756            self.dimensions,
757            vector.len()
758        );
759
760        let mut codes = Vec::with_capacity(self.num_subvectors);
761
762        for m in 0..self.num_subvectors {
763            let subvec_start = m * self.subvector_dim;
764            let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
765
766            // Find nearest centroid for this subvector
767            let mut best_dist = f32::INFINITY;
768            let mut best_k = 0u8;
769
770            for k in 0..self.num_centroids {
771                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
772                let dist: f32 = subvec
773                    .iter()
774                    .enumerate()
775                    .map(|(d, &v)| {
776                        let diff = v - self.centroids[centroid_start + d];
777                        diff * diff
778                    })
779                    .sum();
780
781                if dist < best_dist {
782                    best_dist = dist;
783                    best_k = k as u8;
784                }
785            }
786
787            codes.push(best_k);
788        }
789
790        codes
791    }
792
793    /// Quantizes multiple vectors in batch.
794    #[must_use]
795    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
796        vectors.iter().map(|v| self.quantize(v)).collect()
797    }
798
799    /// Builds asymmetric distance table for a query vector.
800    ///
801    /// Returns a table of shape \[M\]\[K\] containing the squared distance
802    /// from each query subvector to each centroid. This allows O(M) distance
803    /// computation for quantized vectors via table lookups.
804    #[must_use]
805    pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
806        debug_assert_eq!(query.len(), self.dimensions);
807
808        let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
809
810        for m in 0..self.num_subvectors {
811            let query_start = m * self.subvector_dim;
812            let query_subvec = &query[query_start..query_start + self.subvector_dim];
813
814            for k in 0..self.num_centroids {
815                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
816
817                let dist: f32 = query_subvec
818                    .iter()
819                    .enumerate()
820                    .map(|(d, &v)| {
821                        let diff = v - self.centroids[centroid_start + d];
822                        diff * diff
823                    })
824                    .sum();
825
826                table.push(dist);
827            }
828        }
829
830        table
831    }
832
833    /// Computes asymmetric squared distance using a precomputed distance table.
834    ///
835    /// This is O(M) - just M table lookups and additions.
836    #[must_use]
837    #[inline]
838    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
839        debug_assert_eq!(codes.len(), self.num_subvectors);
840        debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
841
842        codes
843            .iter()
844            .enumerate()
845            .map(|(m, &code)| table[m * self.num_centroids + code as usize])
846            .sum()
847    }
848
849    /// Computes asymmetric squared distance from query to quantized vector.
850    ///
851    /// This builds the distance table on the fly - use `build_distance_table`
852    /// and `distance_with_table` for batch queries.
853    #[must_use]
854    pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
855        let table = self.build_distance_table(query);
856        self.distance_with_table(&table, codes)
857    }
858
859    /// Computes asymmetric distance (Euclidean).
860    #[must_use]
861    #[inline]
862    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
863        self.asymmetric_distance_squared(query, codes).sqrt()
864    }
865
866    /// Reconstructs an approximate vector from codes.
867    ///
868    /// Returns the concatenated centroids for the given codes.
869    #[must_use]
870    pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
871        debug_assert_eq!(codes.len(), self.num_subvectors);
872
873        let mut result = Vec::with_capacity(self.dimensions);
874
875        for (m, &code) in codes.iter().enumerate() {
876            let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
877            result.extend_from_slice(
878                &self.centroids[centroid_start..centroid_start + self.subvector_dim],
879            );
880        }
881
882        result
883    }
884
885    /// Returns the centroid vectors for a specific subvector partition.
886    #[must_use]
887    pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
888        assert!(partition < self.num_subvectors);
889
890        (0..self.num_centroids)
891            .map(|k| {
892                let start = (partition * self.num_centroids + k) * self.subvector_dim;
893                &self.centroids[start..start + self.subvector_dim]
894            })
895            .collect()
896    }
897}
898
899// ============================================================================
900// SIMD-Accelerated Hamming Distance
901// ============================================================================
902
/// Computes hamming distance between two bit-packed vectors.
///
/// Uses `u64::count_ones`, which LLVM lowers to the hardware `popcnt`
/// instruction whenever the target supports it, and to an efficient
/// software sequence otherwise.
///
/// The previous implementation called `_popcnt64` inside `unsafe` without a
/// feature guard. Invoking an intrinsic that requires a target feature
/// (`popcnt`) when that feature is neither enabled at compile time nor
/// verified at runtime (`is_x86_feature_detected!`) is undefined behavior —
/// "almost all CPUs have it" is not a soundness argument. `count_ones` gives
/// the same codegen on capable targets with no `unsafe` at all.
///
/// If the slices differ in length, the extra words of the longer slice are
/// ignored (`zip` stops at the shorter one).
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
}
924
/// Fallback scalar implementation.
///
/// On non-x86_64 targets this delegates to `BinaryQuantizer::hamming_distance`,
/// so callers can use `hamming_distance_simd` unconditionally regardless of
/// architecture.
#[cfg(not(target_arch = "x86_64"))]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    BinaryQuantizer::hamming_distance(a, b)
}
931
932// ============================================================================
933// Tests
934// ============================================================================
935
#[cfg(test)]
mod tests {
    // Unit tests for the quantization module: QuantizationType parsing and
    // compression math, scalar (u8) quantization, binary (sign-bit)
    // quantization, the SIMD hamming helper, and product quantization.
    use super::*;

    #[test]
    fn test_quantization_type_compression_ratio() {
        // Use 384 dimensions (common embedding size)
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);

        // Product quantization: (384 * 4) / 8 = 192x compression
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);

        // PQ16: (384 * 4) / 16 = 96x compression
        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    #[test]
    fn test_quantization_type_from_str() {
        // Parsing accepts both full names and short aliases (e.g. "SQ", "bit").
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    // ========================================================================
    // Scalar Quantization Tests
    // ========================================================================

    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let quantizer = ScalarQuantizer::train(&refs);

        // Training learns a per-dimension minimum across the training set.
        assert_eq!(quantizer.dimensions(), 3);
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Min value should quantize to 0
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);

        // Max value should quantize to 255
        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);

        // Middle value should quantize to ~127
        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);

        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);

        // Should be close to original (within quantization error)
        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.distance_u8(&a, &b);
        // Should be approximately 1.0 (the Euclidean distance in original space)
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        // Asymmetric: full-precision query against a quantized stored vector.
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);

        // Orthogonal vectors
        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);

        let dist = quantizer.cosine_distance_u8(&a, &b);
        // Cosine distance of orthogonal vectors = 1.0
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    // ========================================================================
    // Binary Quantization Tests
    // ========================================================================

    #[test]
    fn test_binary_quantizer_quantize() {
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 1); // 4 dims fit in 1 u64

        // Check individual bits: 0.5 >= 0 (1), -0.3 < 0 (0), 0.0 >= 0 (1), 0.8 >= 0 (1)
        // Expected bits (LSB first): 1, 0, 1, 1 = 0b1101 = 13
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    #[test]
    fn test_binary_quantizer_hamming_distance() {
        // Bit layout is LSB-first per dimension (see test_binary_quantizer_quantize).
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0]; // All positive: 1111
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0]; // Mixed: 1010

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2); // Two bits differ
    }

    #[test]
    fn test_binary_quantizer_identical_vectors() {
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);

        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64); // All bits differ
    }

    #[test]
    fn test_binary_quantizer_large_vector() {
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);

        // 1000 dims needs ceil(1000/64) = 16 words
        assert_eq!(bits.len(), 16);
    }

    #[test]
    fn test_binary_quantizer_normalized_distance() {
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        // Normalized by the logical dimension count (100), not the word count.
        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01); // All bits differ
    }

    #[test]
    fn test_binary_quantizer_words_needed() {
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24); // OpenAI embedding size
    }

    #[test]
    fn test_binary_quantizer_bytes_needed() {
        // Each u64 is 8 bytes
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192); // vs 6144 for f32
    }

    // ========================================================================
    // SIMD Tests
    // ========================================================================

    #[test]
    fn test_hamming_distance_simd() {
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128); // All 128 bits differ
    }

    // ========================================================================
    // Product Quantization Tests
    // ========================================================================

    #[test]
    fn test_quantization_type_product_from_str() {
        // Basic PQ parsing
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );

        // PQ with specific subvector count
        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    #[test]
    fn test_quantization_type_requires_training() {
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    #[test]
    fn test_product_quantizer_train() {
        // Create 100 training vectors with 16 dimensions
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // Train with 4 subvectors (16/4 = 4 dims each), 8 centroids
        let pq = ProductQuantizer::train(&refs, 4, 8, 5);

        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 16, 3);

        // Quantize a vector
        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);

        // All codes should be < num_centroids
        for &code in &codes {
            assert!(code < 16);
        }
    }

    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 3, 8, 5);

        // Quantize and reconstruct
        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);

        assert_eq!(reconstructed.len(), 12);

        // Reconstructed should be somewhat close to original (not exact due to quantization)
        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        // Error should be bounded (not zero, but reasonable)
        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 8, 32, 5);

        // Distance to self should be small
        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");

        // Distance to different vector should be larger
        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 3);

        let query = &vectors[0];
        let table = pq.build_distance_table(query);

        // Table should have M * K entries
        assert_eq!(table.len(), 4 * 8);

        // Distance via table should match direct computation
        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);

        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 2);

        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);

        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // PQ8: 384 dims split into 8 subvectors
        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192); // (384 * 4) / 8 = 192

        // PQ48: 384 dims split into 48 subvectors (8 dims each)
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32); // (384 * 4) / 48 = 32
    }

    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // 15 is not divisible by 4
        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 3);

        // Get centroids for first partition
        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4); // 4 centroids
        assert_eq!(centroids[0].len(), 4); // 4 dims per subvector (8/2)
    }
}