// grafeo_core/index/vector/quantization.rs
//! Vector quantization algorithms for memory-efficient storage.
//!
//! Quantization reduces vector precision for memory savings:
//!
//! | Method  | Compression | Accuracy | Speed    | Use Case                    |
//! |---------|-------------|----------|----------|-----------------------------|
//! | Scalar  | 4x          | ~97%     | Fast     | Default for most datasets   |
//! | Binary  | 32x         | ~80%     | Fastest  | Very large datasets         |
//!
//! # Scalar Quantization
//!
//! Converts f32 values to u8 by learning min/max ranges per dimension:
//!
//! ```
//! use grafeo_core::index::vector::quantization::ScalarQuantizer;
//!
//! // Training vectors to learn min/max ranges
//! let vectors = vec![
//!     vec![0.0f32, 0.3, 0.7],
//!     vec![0.2, 0.5, 1.0],
//!     vec![0.1, 0.6, 0.9],
//! ];
//! let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
//! let quantizer = ScalarQuantizer::train(&refs);
//!
//! // Quantize: f32 -> u8 (4x compression)
//! let original = vec![0.1f32, 0.5, 0.9];
//! let quantized = quantizer.quantize(&original);
//!
//! // Compute distance in quantized space (approximate)
//! let other_quantized = quantizer.quantize(&[0.15, 0.45, 0.85]);
//! let dist = quantizer.distance_u8(&quantized, &other_quantized);
//! ```
//!
//! # Binary Quantization
//!
//! Converts f32 values to bits (sign only), enabling hamming distance:
//!
//! ```
//! use grafeo_core::index::vector::quantization::BinaryQuantizer;
//!
//! let v1 = vec![0.1f32, -0.5, 0.0, 0.9];
//! let v2 = vec![0.2f32, -0.3, 0.1, 0.8];
//! let bits1 = BinaryQuantizer::quantize(&v1);
//! let bits2 = BinaryQuantizer::quantize(&v2);
//!
//! // Hamming distance (count differing bits)
//! let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
//! ```

use serde::{Deserialize, Serialize};

// ============================================================================
// Quantization Type
// ============================================================================

57/// Quantization strategy for vector storage.
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
59#[non_exhaustive]
60pub enum QuantizationType {
61    /// No quantization - full f32 precision.
62    #[default]
63    None,
64    /// Scalar quantization: f32 -> u8 (4x compression, ~97% accuracy).
65    Scalar,
66    /// Binary quantization: f32 -> 1 bit (32x compression, ~80% accuracy).
67    Binary,
68    /// Product quantization: f32 -> M u8 codes (8-32x compression, ~90% accuracy).
69    Product {
70        /// Number of subvectors (typically 8, 16, 32, 64).
71        num_subvectors: usize,
72    },
73}
74
75impl QuantizationType {
76    /// Returns the compression ratio (memory reduction factor).
77    ///
78    /// For Product quantization, this depends on dimensions and num_subvectors.
79    /// The ratio is approximate: dimensions * 4 / num_subvectors.
80    #[must_use]
81    pub fn compression_ratio(&self, dimensions: usize) -> usize {
82        match self {
83            Self::None => 1,
84            Self::Scalar => 4,  // f32 (4 bytes) -> u8 (1 byte)
85            Self::Binary => 32, // f32 (4 bytes) -> 1 bit (0.125 bytes)
86            Self::Product { num_subvectors } => {
87                // Original: dimensions * 4 bytes (f32)
88                // Compressed: num_subvectors bytes (u8 codes)
89                // Ratio: (dimensions * 4) / num_subvectors
90                let m = (*num_subvectors).max(1);
91                (dimensions * 4) / m
92            }
93        }
94    }
95
96    /// Returns the name of the quantization type.
97    #[must_use]
98    pub fn name(&self) -> &'static str {
99        match self {
100            Self::None => "none",
101            Self::Scalar => "scalar",
102            Self::Binary => "binary",
103            Self::Product { .. } => "product",
104        }
105    }
106
107    /// Parses from string (case-insensitive).
108    #[must_use]
109    pub fn from_str(s: &str) -> Option<Self> {
110        match s.to_lowercase().as_str() {
111            "none" | "full" | "f32" => Some(Self::None),
112            "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
113            "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
114            "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
115            s if s.starts_with("pq") => {
116                // Parse "pq8", "pq16", etc.
117                s[2..]
118                    .parse()
119                    .ok()
120                    .map(|n| Self::Product { num_subvectors: n })
121            }
122            _ => None,
123        }
124    }
125
126    /// Returns true if this quantization type requires training.
127    #[must_use]
128    pub const fn requires_training(&self) -> bool {
129        matches!(self, Self::Scalar | Self::Product { .. })
130    }
131}
132
// ============================================================================
// Scalar Quantization
// ============================================================================

137/// Scalar quantizer: f32 -> u8 with per-dimension min/max scaling.
138///
139/// Training learns the min/max value for each dimension, then quantizes
140/// values to [0, 255] range. This achieves 4x compression with typically
141/// >97% recall retention.
142///
143/// # Example
144///
145/// ```
146/// use grafeo_core::index::vector::quantization::ScalarQuantizer;
147///
148/// // Training vectors
149/// let vectors = vec![
150///     vec![0.0f32, 0.5, 1.0],
151///     vec![0.2, 0.3, 0.8],
152///     vec![0.1, 0.6, 0.9],
153/// ];
154///
155/// // Train quantizer
156/// let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
157/// let quantizer = ScalarQuantizer::train(&refs);
158///
159/// // Quantize a vector
160/// let quantized = quantizer.quantize(&[0.1, 0.4, 0.85]);
161/// assert_eq!(quantized.len(), 3);
162///
163/// // Compute approximate distance
164/// let q2 = quantizer.quantize(&[0.15, 0.45, 0.9]);
165/// let dist = quantizer.distance_squared_u8(&quantized, &q2);
166/// assert!(dist < 1000.0);
167/// ```
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct ScalarQuantizer {
170    /// Minimum value per dimension.
171    min: Vec<f32>,
172    /// Scale factor per dimension: 255 / (max - min).
173    scale: Vec<f32>,
174    /// Inverse scale for distance computation: (max - min) / 255.
175    inv_scale: Vec<f32>,
176    /// Number of dimensions.
177    dimensions: usize,
178}
179
180impl ScalarQuantizer {
181    /// Trains a scalar quantizer from sample vectors.
182    ///
183    /// Learns the min/max value per dimension from the training data.
184    /// The more representative the training data, the better the quantization.
185    ///
186    /// # Arguments
187    ///
188    /// * `vectors` - Training vectors (should be representative of the dataset)
189    ///
190    /// # Panics
191    ///
192    /// Panics if `vectors` is empty or if vectors have different dimensions.
193    #[must_use]
194    pub fn train(vectors: &[&[f32]]) -> Self {
195        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
196
197        let dimensions = vectors[0].len();
198        assert!(
199            vectors.iter().all(|v| v.len() == dimensions),
200            "All training vectors must have the same dimensions"
201        );
202
203        // Find min/max per dimension
204        let mut min = vec![f32::INFINITY; dimensions];
205        let mut max = vec![f32::NEG_INFINITY; dimensions];
206
207        for vec in vectors {
208            for (i, &v) in vec.iter().enumerate() {
209                min[i] = min[i].min(v);
210                max[i] = max[i].max(v);
211            }
212        }
213
214        // Compute scale factors (avoid division by zero)
215        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
216            .iter()
217            .zip(&max)
218            .map(|(&mn, &mx)| {
219                let range = mx - mn;
220                if range.abs() < f32::EPSILON {
221                    // All values are the same, use 1.0 as scale
222                    (1.0, 1.0)
223                } else {
224                    (255.0 / range, range / 255.0)
225                }
226            })
227            .unzip();
228
229        Self {
230            min,
231            scale,
232            inv_scale,
233            dimensions,
234        }
235    }
236
237    /// Creates a quantizer with explicit ranges (useful for testing).
238    ///
239    /// # Panics
240    ///
241    /// Panics if `min` and `max` have different lengths.
242    #[must_use]
243    pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
244        let dimensions = min.len();
245        assert_eq!(min.len(), max.len(), "Min and max must have same length");
246
247        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
248            .iter()
249            .zip(&max)
250            .map(|(&mn, &mx)| {
251                let range = mx - mn;
252                if range.abs() < f32::EPSILON {
253                    (1.0, 1.0)
254                } else {
255                    (255.0 / range, range / 255.0)
256                }
257            })
258            .unzip();
259
260        Self {
261            min,
262            scale,
263            inv_scale,
264            dimensions,
265        }
266    }
267
268    /// Returns the number of dimensions.
269    #[must_use]
270    pub fn dimensions(&self) -> usize {
271        self.dimensions
272    }
273
274    /// Returns the min values per dimension.
275    #[must_use]
276    pub fn min_values(&self) -> &[f32] {
277        &self.min
278    }
279
280    /// Quantizes an f32 vector to u8.
281    ///
282    /// Values are clamped to the learned [min, max] range.
283    #[must_use]
284    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
285        debug_assert_eq!(
286            vector.len(),
287            self.dimensions,
288            "Vector dimension mismatch: expected {}, got {}",
289            self.dimensions,
290            vector.len()
291        );
292
293        vector
294            .iter()
295            .enumerate()
296            .map(|(i, &v)| {
297                let normalized = (v - self.min[i]) * self.scale[i];
298                // reason: clamped to [0, 255] by clamp(), safe to cast to u8
299                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
300                {
301                    normalized.clamp(0.0, 255.0) as u8
302                }
303            })
304            .collect()
305    }
306
307    /// Quantizes multiple vectors in batch.
308    #[must_use]
309    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
310        vectors.iter().map(|v| self.quantize(v)).collect()
311    }
312
313    /// Dequantizes a u8 vector back to f32 (approximate).
314    #[must_use]
315    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
316        debug_assert_eq!(quantized.len(), self.dimensions);
317
318        quantized
319            .iter()
320            .enumerate()
321            .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
322            .collect()
323    }
324
325    /// Computes squared Euclidean distance between quantized vectors.
326    ///
327    /// This is an approximation that works well for ranking nearest neighbors.
328    /// The returned distance is scaled back to the original space.
329    #[must_use]
330    pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
331        debug_assert_eq!(a.len(), self.dimensions);
332        debug_assert_eq!(b.len(), self.dimensions);
333
334        // Compute in quantized space, then scale
335        let mut sum = 0.0f32;
336        for i in 0..a.len() {
337            let diff = (a[i] as f32) - (b[i] as f32);
338            sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
339        }
340        sum
341    }
342
343    /// Computes Euclidean distance between quantized vectors.
344    #[must_use]
345    #[inline]
346    pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
347        self.distance_squared_u8(a, b).sqrt()
348    }
349
350    /// Computes approximate cosine distance using quantized vectors.
351    ///
352    /// This is less accurate than exact computation but much faster.
353    #[must_use]
354    pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
355        debug_assert_eq!(a.len(), self.dimensions);
356        debug_assert_eq!(b.len(), self.dimensions);
357
358        let mut dot = 0.0f32;
359        let mut norm_a = 0.0f32;
360        let mut norm_b = 0.0f32;
361
362        for i in 0..a.len() {
363            // Dequantize on the fly
364            let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
365            let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
366
367            dot += va * vb;
368            norm_a += va * va;
369            norm_b += vb * vb;
370        }
371
372        let denom = (norm_a * norm_b).sqrt();
373        if denom < f32::EPSILON {
374            1.0 // Maximum distance for zero vectors
375        } else {
376            1.0 - (dot / denom)
377        }
378    }
379
380    /// Computes distance between a f32 query and a quantized vector.
381    ///
382    /// This is useful for search where we keep the query in full precision.
383    #[must_use]
384    pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
385        debug_assert_eq!(query.len(), self.dimensions);
386        debug_assert_eq!(quantized.len(), self.dimensions);
387
388        let mut sum = 0.0f32;
389        for i in 0..query.len() {
390            // Dequantize the stored vector
391            let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
392            let diff = query[i] - dequant;
393            sum += diff * diff;
394        }
395        sum
396    }
397
398    /// Computes asymmetric Euclidean distance.
399    #[must_use]
400    #[inline]
401    pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
402        self.asymmetric_distance_squared(query, quantized).sqrt()
403    }
404}
405
// ============================================================================
// Binary Quantization
// ============================================================================

/// Binary quantizer: f32 -> 1 bit (sign only).
///
/// Provides extreme compression (32x) at the cost of accuracy (~80% recall).
/// Uses hamming distance for fast comparison. Best used with rescoring.
///
/// # Example
///
/// ```
/// use grafeo_core::index::vector::quantization::BinaryQuantizer;
///
/// let v1 = vec![0.5f32, -0.3, 0.0, 0.8, -0.1, 0.2, -0.4, 0.9];
/// let v2 = vec![0.4f32, -0.2, 0.1, 0.7, -0.2, 0.3, -0.3, 0.8];
///
/// let bits1 = BinaryQuantizer::quantize(&v1);
/// let bits2 = BinaryQuantizer::quantize(&v2);
///
/// let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
/// // Vectors are similar, so hamming distance should be low
/// assert!(dist < 4);
/// ```
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Quantizes f32 vector to binary (sign bits packed in u64).
    ///
    /// Each f32 becomes 1 bit: 1 if >= 0, 0 if < 0.
    /// Bits are packed into u64 words (64 dimensions per word), with
    /// dimension `i` stored at bit `i % 64` of word `i / 64`.
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        let num_words = vector.len().div_ceil(64);
        let mut result = vec![0u64; num_words];

        for (i, &v) in vector.iter().enumerate() {
            if v >= 0.0 {
                result[i / 64] |= 1u64 << (i % 64);
            }
        }

        result
    }

    /// Quantizes multiple vectors in batch.
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Computes hamming distance between binary vectors.
    ///
    /// Counts the number of differing bits. Lower = more similar.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Computes normalized hamming distance (0.0 to 1.0).
    ///
    /// Returns the fraction of bits that differ.
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Estimates Euclidean distance from hamming distance.
    ///
    /// Uses an empirical approximation: d_euclidean ≈ sqrt(2 * hamming / dim).
    /// This is a rough estimate suitable for initial filtering.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        // Empirical approximation: assume values are roughly unit-normalized
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// Returns the number of u64 words needed for the given dimensions.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Returns the memory footprint in bytes for quantized storage.
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}

// ============================================================================
// Product Quantization
// ============================================================================

504/// Product quantizer: splits vectors into M subvectors, quantizes each to K centroids.
505///
506/// Product Quantization (PQ) provides excellent compression (8-32x) with ~90% recall.
507/// It works by:
508/// 1. Dividing vectors into M subvectors
509/// 2. Learning K centroids (typically 256) for each subvector via k-means
510/// 3. Storing each vector as M u8 codes (indices into centroid tables)
511///
512/// Distance computation uses asymmetric distance tables (ADC) for efficiency.
513///
514/// # Example
515///
516/// ```
517/// use grafeo_core::index::vector::quantization::ProductQuantizer;
518///
519/// // Training vectors (16 dimensions, split into 4 subvectors)
520/// let vectors: Vec<Vec<f32>> = (0..50)
521///     .map(|i| (0..16).map(|j| (i + j) as f32 * 0.1).collect())
522///     .collect();
523/// let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
524///
525/// // Train quantizer with 4 subvectors, 8 centroids each
526/// let quantizer = ProductQuantizer::train(&refs, 4, 8, 5);
527///
528/// // Quantize a vector to 4 u8 codes
529/// let query = &vectors[0];
530/// let codes = quantizer.quantize(query);
531/// assert_eq!(codes.len(), 4);
532///
533/// // Each code is an index into the centroid table (0-7)
534/// assert!(codes.iter().all(|&c| c < 8));
535///
536/// // Reconstruct approximate vector from codes
537/// let reconstructed = quantizer.reconstruct(&codes);
538/// assert_eq!(reconstructed.len(), 16);
539/// ```
540#[derive(Debug, Clone, Serialize, Deserialize)]
541pub struct ProductQuantizer {
542    /// Number of subvectors (M).
543    num_subvectors: usize,
544    /// Number of centroids per subvector (K, typically 256 for u8 codes).
545    num_centroids: usize,
546    /// Dimensions per subvector.
547    subvector_dim: usize,
548    /// Total dimensions.
549    dimensions: usize,
550    /// Centroids: [M][K][subvector_dim] flattened to [M * K * subvector_dim].
551    centroids: Vec<f32>,
552}
553
554impl ProductQuantizer {
555    /// Trains a product quantizer from sample vectors using k-means clustering.
556    ///
557    /// # Arguments
558    ///
559    /// * `vectors` - Training vectors (should be representative of the dataset)
560    /// * `num_subvectors` - Number of subvectors (M), must divide dimensions evenly
561    /// * `num_centroids` - Number of centroids per subvector (K), typically 256
562    /// * `iterations` - Number of k-means iterations (10-20 is usually sufficient)
563    ///
564    /// # Panics
565    ///
566    /// Panics if vectors is empty, dimensions not divisible by num_subvectors,
567    /// or num_centroids > 256.
568    #[must_use]
569    pub fn train(
570        vectors: &[&[f32]],
571        num_subvectors: usize,
572        num_centroids: usize,
573        iterations: usize,
574    ) -> Self {
575        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
576        assert!(
577            num_centroids <= 256,
578            "num_centroids must be <= 256 for u8 codes"
579        );
580        assert!(num_subvectors > 0, "num_subvectors must be > 0");
581
582        let dimensions = vectors[0].len();
583        assert!(
584            dimensions.is_multiple_of(num_subvectors),
585            "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
586        );
587        assert!(
588            vectors.iter().all(|v| v.len() == dimensions),
589            "All training vectors must have the same dimensions"
590        );
591
592        let subvector_dim = dimensions / num_subvectors;
593
594        // Train centroids for each subvector independently
595        let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
596
597        for m in 0..num_subvectors {
598            // Extract subvectors for this partition
599            let subvectors: Vec<Vec<f32>> = vectors
600                .iter()
601                .map(|v| {
602                    let start = m * subvector_dim;
603                    let end = start + subvector_dim;
604                    v[start..end].to_vec()
605                })
606                .collect();
607
608            // Run k-means on this partition
609            let partition_centroids =
610                Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
611
612            centroids.extend(partition_centroids);
613        }
614
615        Self {
616            num_subvectors,
617            num_centroids,
618            subvector_dim,
619            dimensions,
620            centroids,
621        }
622    }
623
624    /// Simple k-means clustering implementation.
625    fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
626        let n = vectors.len();
627
628        // Initialize centroids using k-means++ style (first k vectors or random sampling)
629        let actual_k = k.min(n);
630        let mut centroids: Vec<f32> = if actual_k == n {
631            vectors.iter().flat_map(|v| v.iter().copied()).collect()
632        } else {
633            // Take evenly spaced samples
634            let step = n / actual_k;
635            (0..actual_k)
636                .flat_map(|i| vectors[i * step].iter().copied())
637                .collect()
638        };
639
640        // Pad with zeros if we don't have enough training vectors
641        if actual_k < k {
642            centroids.resize(k * dims, 0.0);
643        }
644
645        let mut assignments = vec![0usize; n];
646        let mut counts = vec![0usize; k];
647
648        for _ in 0..iterations {
649            // Assignment step: find nearest centroid for each vector
650            for (i, vec) in vectors.iter().enumerate() {
651                let mut best_dist = f32::INFINITY;
652                let mut best_k = 0;
653
654                for j in 0..k {
655                    let centroid_start = j * dims;
656                    let dist: f32 = vec
657                        .iter()
658                        .enumerate()
659                        .map(|(d, &v)| {
660                            let diff = v - centroids[centroid_start + d];
661                            diff * diff
662                        })
663                        .sum();
664
665                    if dist < best_dist {
666                        best_dist = dist;
667                        best_k = j;
668                    }
669                }
670
671                assignments[i] = best_k;
672            }
673
674            // Update step: recompute centroids as mean of assigned vectors
675            centroids.fill(0.0);
676            counts.fill(0);
677
678            for (i, vec) in vectors.iter().enumerate() {
679                let k_idx = assignments[i];
680                let centroid_start = k_idx * dims;
681                counts[k_idx] += 1;
682
683                for (d, &v) in vec.iter().enumerate() {
684                    centroids[centroid_start + d] += v;
685                }
686            }
687
688            // Divide by counts to get means
689            for j in 0..k {
690                if counts[j] > 0 {
691                    let centroid_start = j * dims;
692                    let count = counts[j] as f32;
693                    for d in 0..dims {
694                        centroids[centroid_start + d] /= count;
695                    }
696                }
697            }
698        }
699
700        centroids
701    }
702
703    /// Creates a product quantizer with explicit centroids (for testing/loading).
704    ///
705    /// # Panics
706    ///
707    /// Panics if `num_subvectors` is zero (division by zero), or if `centroids.len()`
708    /// does not equal `num_subvectors * num_centroids * (dimensions / num_subvectors)`.
709    #[must_use]
710    pub fn with_centroids(
711        num_subvectors: usize,
712        num_centroids: usize,
713        dimensions: usize,
714        centroids: Vec<f32>,
715    ) -> Self {
716        let subvector_dim = dimensions / num_subvectors;
717        assert_eq!(
718            centroids.len(),
719            num_subvectors * num_centroids * subvector_dim,
720            "Invalid centroid count"
721        );
722
723        Self {
724            num_subvectors,
725            num_centroids,
726            subvector_dim,
727            dimensions,
728            centroids,
729        }
730    }
731
732    /// Returns the number of subvectors (M).
733    #[must_use]
734    pub fn num_subvectors(&self) -> usize {
735        self.num_subvectors
736    }
737
738    /// Returns the number of centroids per subvector (K).
739    #[must_use]
740    pub fn num_centroids(&self) -> usize {
741        self.num_centroids
742    }
743
744    /// Returns the total dimensions.
745    #[must_use]
746    pub fn dimensions(&self) -> usize {
747        self.dimensions
748    }
749
750    /// Returns the dimensions per subvector.
751    #[must_use]
752    pub fn subvector_dim(&self) -> usize {
753        self.subvector_dim
754    }
755
756    /// Returns the memory footprint in bytes for a quantized vector.
757    #[must_use]
758    pub fn code_size(&self) -> usize {
759        self.num_subvectors // M u8 codes
760    }
761
762    /// Returns the compression ratio compared to f32 storage.
763    #[must_use]
764    pub fn compression_ratio(&self) -> usize {
765        // Original: dimensions * 4 bytes
766        // Compressed: num_subvectors bytes
767        (self.dimensions * 4) / self.num_subvectors
768    }
769
770    /// Quantizes a vector to M u8 codes.
771    ///
772    /// Each code is the index of the nearest centroid for that subvector.
773    #[must_use]
774    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
775        debug_assert_eq!(
776            vector.len(),
777            self.dimensions,
778            "Vector dimension mismatch: expected {}, got {}",
779            self.dimensions,
780            vector.len()
781        );
782
783        let mut codes = Vec::with_capacity(self.num_subvectors);
784
785        for m in 0..self.num_subvectors {
786            let subvec_start = m * self.subvector_dim;
787            let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
788
789            // Find nearest centroid for this subvector
790            let mut best_dist = f32::INFINITY;
791            let mut best_k = 0u8;
792
793            for k in 0..self.num_centroids {
794                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
795                let dist: f32 = subvec
796                    .iter()
797                    .enumerate()
798                    .map(|(d, &v)| {
799                        let diff = v - self.centroids[centroid_start + d];
800                        diff * diff
801                    })
802                    .sum();
803
804                if dist < best_dist {
805                    best_dist = dist;
806                    // reason: k < num_centroids which is typically <= 256
807                    #[allow(clippy::cast_possible_truncation)]
808                    {
809                        best_k = k as u8;
810                    }
811                }
812            }
813
814            codes.push(best_k);
815        }
816
817        codes
818    }
819
820    /// Quantizes multiple vectors in batch.
821    #[must_use]
822    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
823        vectors.iter().map(|v| self.quantize(v)).collect()
824    }
825
826    /// Builds asymmetric distance table for a query vector.
827    ///
828    /// Returns a table of shape \[M\]\[K\] containing the squared distance
829    /// from each query subvector to each centroid. This allows O(M) distance
830    /// computation for quantized vectors via table lookups.
831    #[must_use]
832    pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
833        debug_assert_eq!(query.len(), self.dimensions);
834
835        let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
836
837        for m in 0..self.num_subvectors {
838            let query_start = m * self.subvector_dim;
839            let query_subvec = &query[query_start..query_start + self.subvector_dim];
840
841            for k in 0..self.num_centroids {
842                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
843
844                let dist: f32 = query_subvec
845                    .iter()
846                    .enumerate()
847                    .map(|(d, &v)| {
848                        let diff = v - self.centroids[centroid_start + d];
849                        diff * diff
850                    })
851                    .sum();
852
853                table.push(dist);
854            }
855        }
856
857        table
858    }
859
860    /// Computes asymmetric squared distance using a precomputed distance table.
861    ///
862    /// This is O(M) - just M table lookups and additions.
863    #[must_use]
864    #[inline]
865    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
866        debug_assert_eq!(codes.len(), self.num_subvectors);
867        debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
868
869        codes
870            .iter()
871            .enumerate()
872            .map(|(m, &code)| table[m * self.num_centroids + code as usize])
873            .sum()
874    }
875
876    /// Computes asymmetric squared distance from query to quantized vector.
877    ///
878    /// This builds the distance table on the fly - use `build_distance_table`
879    /// and `distance_with_table` for batch queries.
880    #[must_use]
881    pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
882        let table = self.build_distance_table(query);
883        self.distance_with_table(&table, codes)
884    }
885
886    /// Computes asymmetric distance (Euclidean).
887    #[must_use]
888    #[inline]
889    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
890        self.asymmetric_distance_squared(query, codes).sqrt()
891    }
892
893    /// Reconstructs an approximate vector from codes.
894    ///
895    /// Returns the concatenated centroids for the given codes.
896    #[must_use]
897    pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
898        debug_assert_eq!(codes.len(), self.num_subvectors);
899
900        let mut result = Vec::with_capacity(self.dimensions);
901
902        for (m, &code) in codes.iter().enumerate() {
903            let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
904            result.extend_from_slice(
905                &self.centroids[centroid_start..centroid_start + self.subvector_dim],
906            );
907        }
908
909        result
910    }
911
912    /// Returns the centroid vectors for a specific subvector partition.
913    ///
914    /// # Panics
915    ///
916    /// Panics if `partition` is greater than or equal to `num_subvectors`.
917    #[must_use]
918    pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
919        assert!(partition < self.num_subvectors);
920
921        (0..self.num_centroids)
922            .map(|k| {
923                let start = (partition * self.num_centroids + k) * self.subvector_dim;
924                &self.centroids[start..start + self.subvector_dim]
925            })
926            .collect()
927    }
928}
929
930// ============================================================================
931// SIMD-Accelerated Hamming Distance
932// ============================================================================
933
/// Computes hamming distance between two bit-packed vectors.
///
/// Uses [`u64::count_ones`], which the compiler lowers to the `popcnt`
/// instruction on targets that support it (virtually all x86_64 CPUs since
/// Nehalem, 2008) and to an efficient bit-twiddling sequence elsewhere.
///
/// This replaces the previous hand-written `_popcnt64` intrinsic: that
/// intrinsic is `unsafe` because it requires the `popcnt` *target feature*,
/// and calling it without feature detection is undefined behavior on CPUs
/// lacking the instruction — a contract the old SAFETY comment did not
/// discharge. `count_ones` is safe, portable, and equally fast, so no
/// per-architecture variants are needed.
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    // XOR marks the differing bits; summing the popcounts of each word
    // yields the total hamming distance.
    a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
}
963
964// ============================================================================
965// Tests
966// ============================================================================
967
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // QuantizationType metadata
    // ------------------------------------------------------------------

    #[test]
    fn test_quantization_type_compression_ratio() {
        // Use 384 dimensions (common embedding size)
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);

        // Product quantization: (384 * 4) / 8 = 192x compression
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);

        // PQ16: (384 * 4) / 16 = 96x compression
        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    #[test]
    fn test_quantization_type_from_str() {
        // Aliases: "sq" (case-insensitive) maps to Scalar, "bit" to Binary;
        // unknown strings yield None rather than panicking.
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    // ========================================================================
    // Scalar Quantization Tests
    // ========================================================================

    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let quantizer = ScalarQuantizer::train(&refs);

        // Training learns the per-dimension minimum over the training set:
        // dim 0: min(0.0, 0.2, 0.1), dim 1: min(0.5, 0.3, 0.6), dim 2: min(1.0, 0.8, 0.9).
        assert_eq!(quantizer.dimensions(), 3);
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Min value should quantize to 0
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);

        // Max value should quantize to 255
        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);

        // Middle value should quantize to ~127
        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);

        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);

        // Should be close to original (within quantization error)
        // With a [0, 1] range and 256 levels the step is ~1/255 ≈ 0.004, so
        // a 0.01 tolerance comfortably covers the round-trip error.
        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.distance_u8(&a, &b);
        // Should be approximately 1.0 (the Euclidean distance in original space)
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Asymmetric: full-precision query against a quantized stored vector.
        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        // Range spans negative values so sign information survives quantization.
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);

        // Orthogonal vectors
        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);

        let dist = quantizer.cosine_distance_u8(&a, &b);
        // Cosine distance of orthogonal vectors = 1.0
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    // ========================================================================
    // Binary Quantization Tests
    // ========================================================================

    #[test]
    fn test_binary_quantizer_quantize() {
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 1); // 4 dims fit in 1 u64

        // Check individual bits: 0.5 >= 0 (1), -0.3 < 0 (0), 0.0 >= 0 (1), 0.8 >= 0 (1)
        // Expected bits (LSB first): 1, 0, 1, 1 = 0b1101 = 13
        // Masking with 0xF isolates the 4 bits actually set by this vector.
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    #[test]
    fn test_binary_quantizer_hamming_distance() {
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0]; // All positive: 1111
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0]; // Mixed: 1010

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2); // Two bits differ
    }

    #[test]
    fn test_binary_quantizer_identical_vectors() {
        // Hamming distance of a vector with itself must be exactly zero.
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);

        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        // 64 dimensions exactly fill one u64 word.
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64); // All bits differ
    }

    #[test]
    fn test_binary_quantizer_large_vector() {
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);

        // 1000 dims needs ceil(1000/64) = 16 words
        assert_eq!(bits.len(), 16);
    }

    #[test]
    fn test_binary_quantizer_normalized_distance() {
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        // Normalization divides by the dimension count (100), not the word
        // count, so fully-opposite vectors yield ~1.0.
        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01); // All bits differ
    }

    #[test]
    fn test_binary_quantizer_words_needed() {
        // words_needed is ceil(dims / 64): exact multiples and the +1
        // boundary (65) are the interesting cases.
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24); // OpenAI embedding size
    }

    #[test]
    fn test_binary_quantizer_bytes_needed() {
        // Each u64 is 8 bytes
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192); // vs 6144 for f32
    }

    // ========================================================================
    // SIMD Tests
    // ========================================================================

    #[test]
    fn test_hamming_distance_simd() {
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128); // All 128 bits differ
    }

    // ========================================================================
    // Product Quantization Tests
    // ========================================================================

    #[test]
    fn test_quantization_type_product_from_str() {
        // Basic PQ parsing
        // Bare "pq"/"product" default to 8 subvectors.
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );

        // PQ with specific subvector count
        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    #[test]
    fn test_quantization_type_requires_training() {
        // Scalar (min/max ranges) and Product (centroid codebooks) need
        // training data; None and Binary (sign-only) do not.
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    #[test]
    fn test_product_quantizer_train() {
        // Create 100 training vectors with 16 dimensions
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // Train with 4 subvectors (16/4 = 4 dims each), 8 centroids
        let pq = ProductQuantizer::train(&refs, 4, 8, 5);

        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 16, 3);

        // Quantize a vector
        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);

        // All codes should be < num_centroids
        for &code in &codes {
            assert!(code < 16);
        }
    }

    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 3, 8, 5);

        // Quantize and reconstruct
        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);

        assert_eq!(reconstructed.len(), 12);

        // Reconstructed should be somewhat close to original (not exact due to quantization)
        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        // Error should be bounded (not zero, but reasonable)
        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 8, 32, 5);

        // Distance to self should be small
        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");

        // Distance to different vector should be larger
        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 3);

        let query = &vectors[0];
        let table = pq.build_distance_table(query);

        // Table should have M * K entries
        assert_eq!(table.len(), 4 * 8);

        // Distance via table should match direct computation
        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);

        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 2);

        // Batch quantization returns one code vector per input, each with
        // one code per subvector.
        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);

        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // PQ8: 384 dims split into 8 subvectors
        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192); // (384 * 4) / 8 = 192

        // PQ48: 384 dims split into 48 subvectors (8 dims each)
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32); // (384 * 4) / 48 = 32
    }

    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // 15 is not divisible by 4
        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 3);

        // Get centroids for first partition
        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4); // 4 centroids
        assert_eq!(centroids[0].len(), 4); // 4 dims per subvector (8/2)
    }
}