// foxstash_core/vector/quantize.rs

1//! Vector quantization for memory-efficient storage
2//!
3//! This module provides quantization methods to reduce memory footprint while
4//! maintaining acceptable search quality:
5//!
6//! - **Scalar Quantization (SQ8)**: f32 → i8 (4x compression, ~95% recall)
7//! - **Binary Quantization (BQ)**: f32 → bit (32x compression, ~85% recall)
8//!
9//! # Memory Comparison (1M vectors × 384 dims)
10//!
11//! | Format | Size | Compression |
12//! |--------|------|-------------|
13//! | f32    | 1.5 GB | 1x |
14//! | int8   | 384 MB | 4x |
15//! | binary | 48 MB  | 32x |
16//!
17//! # Usage
18//!
19//! ```
20//! use foxstash_core::vector::quantize::{ScalarQuantizer, BinaryQuantizer, Quantizer};
21//!
22//! let vector = vec![0.5, -0.3, 0.8, -0.1];
23//!
24//! // Scalar quantization (4x compression)
25//! let sq = ScalarQuantizer::fit(&[vector.clone()]);
26//! let quantized = sq.quantize(&vector);
27//! let reconstructed = sq.dequantize(&quantized);
28//!
29//! // Binary quantization (32x compression)
30//! let bq = BinaryQuantizer::new(4);
31//! let binary = bq.quantize(&vector);
32//! ```
33
34use pulp::Simd;
35use serde::{Deserialize, Serialize};
36
/// Trait for vector quantization
///
/// Implementors map full-precision `f32` vectors into a compact
/// `Quantized` representation and provide distance computation in both
/// symmetric (quantized vs quantized) and asymmetric (full-precision
/// query vs quantized) forms.
pub trait Quantizer: Send + Sync {
    /// Quantized representation type
    type Quantized: Clone + Send + Sync;

    /// Quantize a vector
    fn quantize(&self, vector: &[f32]) -> Self::Quantized;

    /// Dequantize back to f32 (lossy)
    fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32>;

    /// Compute distance between quantized vectors (fast path)
    fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32;

    /// Compute distance between f32 query and quantized vector (asymmetric)
    ///
    /// Keeping the query at full precision typically improves ranking
    /// quality compared to the fully quantized path.
    fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32;
}
54
55// ============================================================================
56// Scalar Quantization (SQ8)
57// ============================================================================
58
/// Scalar quantization parameters for a single dimension
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizationParams {
    /// Minimum value in training data
    pub min: f32,
    /// Maximum value in training data
    pub max: f32,
    /// Scale factor: (max - min) / 255
    ///
    /// Falls back to 1.0 when the range is degenerate (`max <= min`);
    /// see `ScalarQuantizationParams::new`.
    pub scale: f32,
}
69
70impl ScalarQuantizationParams {
71    /// Create parameters from min/max values
72    pub fn new(min: f32, max: f32) -> Self {
73        let range = max - min;
74        let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
75        Self { min, max, scale }
76    }
77
78    /// Quantize a single value to u8
79    #[inline]
80    pub fn quantize_value(&self, value: f32) -> u8 {
81        let normalized = (value - self.min) / self.scale;
82        normalized.clamp(0.0, 255.0) as u8
83    }
84
85    /// Dequantize a u8 back to f32
86    #[inline]
87    pub fn dequantize_value(&self, quantized: u8) -> f32 {
88        (quantized as f32) * self.scale + self.min
89    }
90}
91
/// Scalar quantization (SQ8): f32 → u8
///
/// Maps each dimension independently to [0, 255] range based on min/max values
/// from training data. Provides 4x memory reduction with minimal quality loss.
///
/// # Example
///
/// ```
/// use foxstash_core::vector::quantize::{ScalarQuantizer, Quantizer};
///
/// // Fit quantizer on training data
/// let training_vectors = vec![
///     vec![0.1, 0.5, 0.9],
///     vec![0.2, 0.4, 0.8],
///     vec![0.3, 0.6, 0.7],
/// ];
/// let quantizer = ScalarQuantizer::fit(&training_vectors);
///
/// // Quantize new vectors
/// let query = vec![0.15, 0.45, 0.85];
/// let quantized = quantizer.quantize(&query);
///
/// // Compute distance efficiently
/// let db_vec = quantizer.quantize(&training_vectors[0]);
/// let distance = quantizer.distance_quantized(&quantized, &db_vec);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    /// Per-dimension quantization parameters (one entry per dimension)
    params: Vec<ScalarQuantizationParams>,
    /// Dimensionality (cached so quantize() can cheaply validate input length)
    dim: usize,
}
125
/// Quantized vector representation for SQ8
///
/// Holds one u8 code per dimension; decode with `ScalarQuantizer::dequantize`
/// using the quantizer that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizedVector {
    /// Quantized values (u8 per dimension)
    pub data: Vec<u8>,
}
132
133impl ScalarQuantizer {
134    /// Fit quantizer parameters from training vectors
135    ///
136    /// Computes min/max for each dimension across all training vectors.
137    ///
138    /// # Panics
139    ///
140    /// Panics if training vectors have inconsistent dimensions.
141    pub fn fit(training_vectors: &[Vec<f32>]) -> Self {
142        assert!(
143            !training_vectors.is_empty(),
144            "Need at least one training vector"
145        );
146
147        let dim = training_vectors[0].len();
148        let mut mins = vec![f32::INFINITY; dim];
149        let mut maxs = vec![f32::NEG_INFINITY; dim];
150
151        for vector in training_vectors {
152            assert_eq!(vector.len(), dim, "Inconsistent vector dimensions");
153            for (i, &val) in vector.iter().enumerate() {
154                mins[i] = mins[i].min(val);
155                maxs[i] = maxs[i].max(val);
156            }
157        }
158
159        let params: Vec<_> = mins
160            .iter()
161            .zip(maxs.iter())
162            .map(|(&min, &max)| ScalarQuantizationParams::new(min, max))
163            .collect();
164
165        Self { params, dim }
166    }
167
168    /// Create quantizer with known min/max bounds
169    ///
170    /// Use this when you know the expected value range (e.g., normalized embeddings).
171    pub fn with_bounds(dim: usize, min: f32, max: f32) -> Self {
172        let params = vec![ScalarQuantizationParams::new(min, max); dim];
173        Self { params, dim }
174    }
175
176    /// Create quantizer for normalized vectors ([-1, 1] range)
177    pub fn for_normalized(dim: usize) -> Self {
178        Self::with_bounds(dim, -1.0, 1.0)
179    }
180
181    /// Get the dimensionality
182    pub fn dim(&self) -> usize {
183        self.dim
184    }
185
186    /// Get quantization parameters for analysis
187    pub fn params(&self) -> &[ScalarQuantizationParams] {
188        &self.params
189    }
190}
191
192impl Quantizer for ScalarQuantizer {
193    type Quantized = ScalarQuantizedVector;
194
195    fn quantize(&self, vector: &[f32]) -> Self::Quantized {
196        debug_assert_eq!(vector.len(), self.dim);
197
198        let data: Vec<u8> = vector
199            .iter()
200            .zip(self.params.iter())
201            .map(|(&val, param)| param.quantize_value(val))
202            .collect();
203
204        ScalarQuantizedVector { data }
205    }
206
207    fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32> {
208        quantized
209            .data
210            .iter()
211            .zip(self.params.iter())
212            .map(|(&val, param)| param.dequantize_value(val))
213            .collect()
214    }
215
216    fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
217        // L2 distance in quantized space (scaled)
218        sq8_l2_distance_simd(&a.data, &b.data)
219    }
220
221    fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32 {
222        // Asymmetric distance: full precision query vs quantized database
223        sq8_asymmetric_l2_distance_simd(query, &quantized.data, &self.params)
224    }
225}
226
227// ============================================================================
228// Binary Quantization (BQ)
229// ============================================================================
230
/// Binary quantization: f32 → bit
///
/// Maps each dimension to a single bit based on sign (positive = 1, negative = 0).
/// Provides 32x memory reduction. Best used for initial filtering with reranking.
///
/// Distance is computed using Hamming distance (number of differing bits).
///
/// # Example
///
/// ```
/// use foxstash_core::vector::quantize::{BinaryQuantizer, Quantizer};
///
/// let quantizer = BinaryQuantizer::new(128);
///
/// let vec_a = vec![0.5; 128];  // All positive → all 1s
/// let vec_b = vec![-0.5; 128]; // All negative → all 0s
///
/// let qa = quantizer.quantize(&vec_a);
/// let qb = quantizer.quantize(&vec_b);
///
/// // Maximum Hamming distance (all bits differ)
/// let distance = quantizer.distance_quantized(&qa, &qb);
/// assert_eq!(distance, 128.0);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantizer {
    /// Vector dimensionality (number of bits per quantized vector)
    dim: usize,
    /// Number of bytes needed to store dim bits
    byte_len: usize,
}
261
/// Quantized vector representation for binary quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantizedVector {
    /// Packed bits (ceil(dim/8) bytes); dimension i is stored at
    /// byte i/8, bit position i%8 (LSB-first within each byte)
    pub data: Vec<u8>,
}
268
269impl BinaryQuantizer {
270    /// Create a binary quantizer for vectors of given dimension
271    pub fn new(dim: usize) -> Self {
272        let byte_len = (dim + 7) / 8; // Ceiling division
273        Self { dim, byte_len }
274    }
275
276    /// Get the dimensionality
277    pub fn dim(&self) -> usize {
278        self.dim
279    }
280
281    /// Get the byte length of quantized vectors
282    pub fn byte_len(&self) -> usize {
283        self.byte_len
284    }
285}
286
287impl Quantizer for BinaryQuantizer {
288    type Quantized = BinaryQuantizedVector;
289
290    fn quantize(&self, vector: &[f32]) -> Self::Quantized {
291        debug_assert_eq!(vector.len(), self.dim);
292
293        let mut data = vec![0u8; self.byte_len];
294
295        for (i, &val) in vector.iter().enumerate() {
296            if val >= 0.0 {
297                let byte_idx = i / 8;
298                let bit_idx = i % 8;
299                data[byte_idx] |= 1 << bit_idx;
300            }
301        }
302
303        BinaryQuantizedVector { data }
304    }
305
306    fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32> {
307        // Binary quantization is highly lossy - we can only recover sign
308        let mut result = vec![0.0f32; self.dim];
309
310        for i in 0..self.dim {
311            let byte_idx = i / 8;
312            let bit_idx = i % 8;
313            let bit = (quantized.data[byte_idx] >> bit_idx) & 1;
314            result[i] = if bit == 1 { 1.0 } else { -1.0 };
315        }
316
317        result
318    }
319
320    fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
321        // Hamming distance (number of differing bits)
322        hamming_distance_simd(&a.data, &b.data) as f32
323    }
324
325    fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32 {
326        // Asymmetric: count mismatches between query sign and quantized bits
327        let mut mismatches = 0u32;
328
329        for (i, &val) in query.iter().enumerate() {
330            let byte_idx = i / 8;
331            let bit_idx = i % 8;
332            let quantized_bit = (quantized.data[byte_idx] >> bit_idx) & 1;
333            let query_bit = if val >= 0.0 { 1 } else { 0 };
334
335            if quantized_bit != query_bit {
336                mismatches += 1;
337            }
338        }
339
340        mismatches as f32
341    }
342}
343
344// ============================================================================
345// SIMD-Accelerated Distance Functions
346// ============================================================================
347
348/// SIMD-accelerated L2 distance for SQ8 vectors
349#[inline]
350pub fn sq8_l2_distance_simd(a: &[u8], b: &[u8]) -> f32 {
351    debug_assert_eq!(a.len(), b.len());
352
353    let simd = pulp::Arch::new();
354    simd.dispatch(|| sq8_l2_distance_impl(simd, a, b))
355}
356
/// Internal SQ8 L2 distance implementation
///
/// Accumulates squared differences of u8 codes in a u32 and returns the
/// square root. Since each squared difference is at most 255^2 = 65025,
/// the u32 accumulator can overflow only for dimensions above ~66,000 —
/// far beyond typical embedding sizes.
#[inline(always)]
fn sq8_l2_distance_impl(simd: pulp::Arch, a: &[u8], b: &[u8]) -> f32 {
    struct Sq8L2<'a> {
        a: &'a [u8],
        b: &'a [u8],
    }

    impl pulp::WithSimd for Sq8L2<'_> {
        type Output = f32;

        #[inline(always)]
        fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
            // pulp has no u8/i16 SIMD ops, so we use scalar unrolling.
            // 4-at-a-time unrolling reduces loop overhead and gives the compiler
            // headroom to auto-vectorize the inner body using integer SIMD.
            let mut sum_sq: u32 = 0;

            let mut chunks = self.a.chunks_exact(4).zip(self.b.chunks_exact(4));
            for (a_chunk, b_chunk) in &mut chunks {
                // Widen to i32 before subtracting so the difference can't wrap.
                let d0 = (a_chunk[0] as i32) - (b_chunk[0] as i32);
                let d1 = (a_chunk[1] as i32) - (b_chunk[1] as i32);
                let d2 = (a_chunk[2] as i32) - (b_chunk[2] as i32);
                let d3 = (a_chunk[3] as i32) - (b_chunk[3] as i32);
                sum_sq += (d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3) as u32;
            }

            // Handle remainder (0..3 elements)
            let rem_start = self.a.len() - self.a.len() % 4;
            for i in rem_start..self.a.len() {
                let diff = (self.a[i] as i32) - (self.b[i] as i32);
                sum_sq += (diff * diff) as u32;
            }

            (sum_sq as f32).sqrt()
        }
    }

    simd.dispatch(Sq8L2 { a, b })
}
397
398/// SIMD-accelerated asymmetric L2 distance (f32 query vs SQ8 database)
399#[inline]
400pub fn sq8_asymmetric_l2_distance_simd(
401    query: &[f32],
402    quantized: &[u8],
403    params: &[ScalarQuantizationParams],
404) -> f32 {
405    debug_assert_eq!(query.len(), quantized.len());
406    debug_assert_eq!(query.len(), params.len());
407
408    let simd = pulp::Arch::new();
409    simd.dispatch(|| sq8_asymmetric_l2_impl(simd, query, quantized, params))
410}
411
/// Internal asymmetric L2 implementation
///
/// Each database code is dequantized with its own per-dimension params
/// before differencing, so the result is in the original f32 space.
#[inline(always)]
fn sq8_asymmetric_l2_impl(
    simd: pulp::Arch,
    query: &[f32],
    quantized: &[u8],
    params: &[ScalarQuantizationParams],
) -> f32 {
    struct AsymL2<'a> {
        query: &'a [f32],
        quantized: &'a [u8],
        params: &'a [ScalarQuantizationParams],
    }

    impl pulp::WithSimd for AsymL2<'_> {
        type Output = f32;

        #[inline(always)]
        fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
            // pulp has no u8/i16 SIMD ops, so we use scalar unrolling.
            // 4-at-a-time unrolling reduces loop overhead; the f32 accumulation
            // gives the compiler room to auto-vectorize using f32 SIMD.
            let mut sum_sq: f32 = 0.0;
            let n = self.query.len();

            let mut i = 0;
            while i + 4 <= n {
                let d0 = self.query[i] - self.params[i].dequantize_value(self.quantized[i]);
                let d1 =
                    self.query[i + 1] - self.params[i + 1].dequantize_value(self.quantized[i + 1]);
                let d2 =
                    self.query[i + 2] - self.params[i + 2].dequantize_value(self.quantized[i + 2]);
                let d3 =
                    self.query[i + 3] - self.params[i + 3].dequantize_value(self.quantized[i + 3]);
                sum_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
                i += 4;
            }

            // Handle remainder (0..3 elements)
            while i < n {
                let dequantized = self.params[i].dequantize_value(self.quantized[i]);
                let diff = self.query[i] - dequantized;
                sum_sq += diff * diff;
                i += 1;
            }

            sum_sq.sqrt()
        }
    }

    simd.dispatch(AsymL2 {
        query,
        quantized,
        params,
    })
}
468
469/// SIMD-accelerated Hamming distance for binary vectors
470#[inline]
471pub fn hamming_distance_simd(a: &[u8], b: &[u8]) -> u32 {
472    debug_assert_eq!(a.len(), b.len());
473
474    let simd = pulp::Arch::new();
475    simd.dispatch(|| hamming_distance_impl(simd, a, b))
476}
477
478/// Internal Hamming distance implementation
479#[inline(always)]
480fn hamming_distance_impl(simd: pulp::Arch, a: &[u8], b: &[u8]) -> u32 {
481    struct Hamming<'a> {
482        a: &'a [u8],
483        b: &'a [u8],
484    }
485
486    impl pulp::WithSimd for Hamming<'_> {
487        type Output = u32;
488
489        #[inline(always)]
490        fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
491            // Process 8 bytes at a time using u64 for efficient popcount
492            let mut distance = 0u32;
493
494            // Process full u64 chunks
495            let chunks = self.a.len() / 8;
496            for i in 0..chunks {
497                let offset = i * 8;
498                let a_u64 = u64::from_le_bytes([
499                    self.a[offset],
500                    self.a[offset + 1],
501                    self.a[offset + 2],
502                    self.a[offset + 3],
503                    self.a[offset + 4],
504                    self.a[offset + 5],
505                    self.a[offset + 6],
506                    self.a[offset + 7],
507                ]);
508                let b_u64 = u64::from_le_bytes([
509                    self.b[offset],
510                    self.b[offset + 1],
511                    self.b[offset + 2],
512                    self.b[offset + 3],
513                    self.b[offset + 4],
514                    self.b[offset + 5],
515                    self.b[offset + 6],
516                    self.b[offset + 7],
517                ]);
518                distance += (a_u64 ^ b_u64).count_ones();
519            }
520
521            // Handle remainder bytes
522            for i in (chunks * 8)..self.a.len() {
523                distance += (self.a[i] ^ self.b[i]).count_ones();
524            }
525
526            distance
527        }
528    }
529
530    simd.dispatch(Hamming { a, b })
531}
532
533/// Compute dot product between f32 query and binary quantized vector
534///
535/// This is useful for cosine similarity approximation with binary vectors.
536/// Returns the number of positive query components that match positive bits
537/// minus those that mismatch.
538#[inline]
539pub fn binary_dot_product(query: &[f32], quantized: &BinaryQuantizedVector, dim: usize) -> f32 {
540    let mut sum = 0.0f32;
541
542    for i in 0..dim {
543        let byte_idx = i / 8;
544        let bit_idx = i % 8;
545        let bit = ((quantized.data[byte_idx] >> bit_idx) & 1) as f32;
546        // Map bit: 0 → -1, 1 → +1
547        let sign = bit * 2.0 - 1.0;
548        sum += query[i] * sign;
549    }
550
551    sum
552}
553
554// ============================================================================
555// Product Quantization (PQ) - For very large datasets
556// ============================================================================
557
/// Product Quantization configuration
///
/// Configuration only — the k-means training and encoding steps are not
/// implemented in this module yet (see the note following the impl).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizerConfig {
    /// Original vector dimension
    pub dim: usize,
    /// Number of subvectors
    pub num_subvectors: usize,
    /// Bits per subvector (number of centroids = 2^bits)
    pub bits_per_subvector: usize,
}
568
569impl ProductQuantizerConfig {
570    /// Create a default PQ config for given dimension
571    ///
572    /// Uses 8 subvectors with 8 bits (256 centroids) each by default.
573    pub fn default_for_dim(dim: usize) -> Self {
574        let num_subvectors = 8.min(dim);
575        Self {
576            dim,
577            num_subvectors,
578            bits_per_subvector: 8,
579        }
580    }
581
582    /// Dimension of each subvector
583    pub fn subvector_dim(&self) -> usize {
584        self.dim / self.num_subvectors
585    }
586
587    /// Number of centroids per subvector
588    pub fn num_centroids(&self) -> usize {
589        1 << self.bits_per_subvector
590    }
591
592    /// Compressed size in bytes
593    pub fn compressed_size(&self) -> usize {
594        self.num_subvectors * ((self.bits_per_subvector + 7) / 8)
595    }
596}
597
598// Note: Full PQ implementation with k-means training would go here.
599// For now, we provide the config structure for future expansion.
600
601// ============================================================================
602// Tests
603// ============================================================================
604
#[cfg(test)]
mod tests {
    use super::*;

    // Tolerance for float comparisons in exact-value checks.
    const EPSILON: f32 = 1e-5;

    // ========================================================================
    // Scalar Quantization Tests
    // ========================================================================

    #[test]
    fn test_scalar_quantizer_fit() {
        let vectors = vec![
            vec![0.0, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];

        let sq = ScalarQuantizer::fit(&vectors);
        assert_eq!(sq.dim(), 3);

        // Check that params capture min/max
        // (dim 0 values: 0.0, 0.2, 0.1 — dim 2 values: 1.0, 0.8, 0.9)
        assert!((sq.params[0].min - 0.0).abs() < EPSILON);
        assert!((sq.params[0].max - 0.2).abs() < EPSILON);
        assert!((sq.params[2].min - 0.8).abs() < EPSILON);
        assert!((sq.params[2].max - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_scalar_quantizer_roundtrip() {
        let vectors = vec![vec![-1.0, 0.0, 1.0], vec![-0.5, 0.5, 0.5]];

        let sq = ScalarQuantizer::fit(&vectors);
        let original = vec![-0.7, 0.3, 0.8];
        let quantized = sq.quantize(&original);
        let reconstructed = sq.dequantize(&quantized);

        // Should be close but not exact (quantization is lossy)
        for (o, r) in original.iter().zip(reconstructed.iter()) {
            assert!((o - r).abs() < 0.02, "orig={}, recon={}", o, r);
        }
    }

    #[test]
    fn test_scalar_quantizer_for_normalized() {
        let sq = ScalarQuantizer::for_normalized(384);
        assert_eq!(sq.dim(), 384);

        // Test with normalized vector spanning the full [-1, 1) range
        let vector: Vec<f32> = (0..384).map(|i| (i as f32 / 192.0) - 1.0).collect();
        let quantized = sq.quantize(&vector);
        let reconstructed = sq.dequantize(&quantized);

        // Should be close: the quantization step for [-1, 1] is 2/255 ≈ 0.0078
        let max_error: f32 = vector
            .iter()
            .zip(reconstructed.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, |a, b| a.max(b));

        assert!(max_error < 0.01, "Max error: {}", max_error);
    }

    #[test]
    fn test_sq8_distance_quantized() {
        let sq = ScalarQuantizer::for_normalized(4);

        let a = vec![1.0, 0.0, -1.0, 0.5];
        let b = vec![1.0, 0.0, -1.0, 0.5];

        let qa = sq.quantize(&a);
        let qb = sq.quantize(&b);

        // Identical inputs quantize to identical codes
        let dist = sq.distance_quantized(&qa, &qb);
        assert!(dist < 1.0, "Same vectors should have near-zero distance");
    }

    #[test]
    fn test_sq8_distance_different() {
        let sq = ScalarQuantizer::for_normalized(4);

        let a = vec![1.0, 1.0, 1.0, 1.0];
        let b = vec![-1.0, -1.0, -1.0, -1.0];

        let qa = sq.quantize(&a);
        let qb = sq.quantize(&b);

        // Distance is in code units: ~sqrt(4 * 255^2) = 510 for opposite corners
        let dist = sq.distance_quantized(&qa, &qb);
        assert!(dist > 100.0, "Opposite vectors should have large distance");
    }

    // ========================================================================
    // Binary Quantization Tests
    // ========================================================================

    #[test]
    fn test_binary_quantizer_basic() {
        let bq = BinaryQuantizer::new(8);
        assert_eq!(bq.dim(), 8);
        assert_eq!(bq.byte_len(), 1);
    }

    #[test]
    fn test_binary_quantizer_byte_len() {
        // Test various dimensions: byte_len is ceil(dim / 8)
        assert_eq!(BinaryQuantizer::new(1).byte_len(), 1);
        assert_eq!(BinaryQuantizer::new(8).byte_len(), 1);
        assert_eq!(BinaryQuantizer::new(9).byte_len(), 2);
        assert_eq!(BinaryQuantizer::new(16).byte_len(), 2);
        assert_eq!(BinaryQuantizer::new(384).byte_len(), 48);
    }

    #[test]
    fn test_binary_quantizer_all_positive() {
        let bq = BinaryQuantizer::new(8);
        let vector = vec![0.5, 0.3, 0.1, 0.9, 0.2, 0.4, 0.6, 0.8];
        let quantized = bq.quantize(&vector);

        // All positive → all bits set
        assert_eq!(quantized.data[0], 0xFF);
    }

    #[test]
    fn test_binary_quantizer_all_negative() {
        let bq = BinaryQuantizer::new(8);
        let vector = vec![-0.5, -0.3, -0.1, -0.9, -0.2, -0.4, -0.6, -0.8];
        let quantized = bq.quantize(&vector);

        // All negative → no bits set
        assert_eq!(quantized.data[0], 0x00);
    }

    #[test]
    fn test_binary_quantizer_mixed() {
        let bq = BinaryQuantizer::new(8);
        let vector = vec![0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8];
        // Positive at indices: 0, 2, 4, 6 → bits 0, 2, 4, 6 → 0b01010101
        let quantized = bq.quantize(&vector);
        assert_eq!(quantized.data[0], 0b01010101);
    }

    #[test]
    fn test_binary_hamming_distance() {
        let bq = BinaryQuantizer::new(8);

        let a = vec![1.0; 8]; // All positive
        let b = vec![-1.0; 8]; // All negative

        let qa = bq.quantize(&a);
        let qb = bq.quantize(&b);

        let dist = bq.distance_quantized(&qa, &qb);
        assert_eq!(dist, 8.0); // All 8 bits differ
    }

    #[test]
    fn test_binary_hamming_same() {
        let bq = BinaryQuantizer::new(16);

        let a = vec![
            0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8, 0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8,
        ];

        let qa = bq.quantize(&a);
        let qb = bq.quantize(&a);

        let dist = bq.distance_quantized(&qa, &qb);
        assert_eq!(dist, 0.0); // Same vector → zero distance
    }

    #[test]
    fn test_binary_dequantize() {
        let bq = BinaryQuantizer::new(4);
        let vector = vec![0.5, -0.3, 0.1, -0.9];
        let quantized = bq.quantize(&vector);
        let dequantized = bq.dequantize(&quantized);

        // Dequantize only recovers sign
        assert_eq!(dequantized, vec![1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_binary_large_dimension() {
        // 384 dims exercises multi-byte packing (48 bytes) and the bit layout
        let bq = BinaryQuantizer::new(384);
        let vector: Vec<f32> = (0..384)
            .map(|i| if i % 2 == 0 { 0.5 } else { -0.5 })
            .collect();

        let quantized = bq.quantize(&vector);
        assert_eq!(quantized.data.len(), 48);

        let dequantized = bq.dequantize(&quantized);
        for (i, &val) in dequantized.iter().enumerate() {
            let expected = if i % 2 == 0 { 1.0 } else { -1.0 };
            assert_eq!(val, expected);
        }
    }

    // ========================================================================
    // SIMD Function Tests
    // ========================================================================

    #[test]
    fn test_hamming_distance_simd_basic() {
        let a = vec![0b11110000u8, 0b10101010];
        let b = vec![0b00001111u8, 0b10101010];

        let dist = hamming_distance_simd(&a, &b);
        // First byte: all 8 bits differ
        // Second byte: 0 bits differ
        assert_eq!(dist, 8);
    }

    #[test]
    fn test_hamming_distance_simd_same() {
        let a = vec![0xFF, 0x00, 0xAB, 0xCD];
        let b = a.clone();

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_sq8_l2_distance_simd_basic() {
        let a = vec![0u8, 50, 100, 150, 200, 250];
        let b = vec![0u8, 50, 100, 150, 200, 250];

        // 6 elements exercises both the 4-wide unrolled loop and the remainder
        let dist = sq8_l2_distance_simd(&a, &b);
        assert!(dist < EPSILON);
    }

    #[test]
    fn test_sq8_l2_distance_simd_different() {
        let a = vec![0u8, 0, 0, 0];
        let b = vec![255u8, 255, 255, 255];

        let dist = sq8_l2_distance_simd(&a, &b);
        // Expected: sqrt(255^2 * 4) = sqrt(260100) ≈ 510
        assert!((dist - 510.0).abs() < 1.0);
    }

    // ========================================================================
    // Product Quantization Config Tests
    // ========================================================================

    #[test]
    fn test_pq_config_defaults() {
        let config = ProductQuantizerConfig::default_for_dim(384);

        assert_eq!(config.dim, 384);
        assert_eq!(config.num_subvectors, 8);
        assert_eq!(config.bits_per_subvector, 8);
        assert_eq!(config.subvector_dim(), 48);
        assert_eq!(config.num_centroids(), 256);
        assert_eq!(config.compressed_size(), 8); // 8 bytes for 8 subvectors × 8 bits
    }

    // ========================================================================
    // Accuracy Tests
    // ========================================================================

    #[test]
    fn test_sq8_recall_approximation() {
        // Generate random vectors and test that SQ8 preserves ordering
        // (seeded RNG keeps the test deterministic)
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        let dim = 128;
        let num_vectors = 100;

        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|_| {
                (0..dim)
                    .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
                    .collect()
            })
            .collect();

        let sq = ScalarQuantizer::fit(&vectors);
        let quantized: Vec<_> = vectors.iter().map(|v| sq.quantize(v)).collect();

        // Pick a random query
        let query_idx = 42;
        let query = &vectors[query_idx];
        let query_q = &quantized[query_idx];

        // Find nearest neighbors using exact and quantized distance
        let mut exact_distances: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, v)| {
                let dist: f32 = query
                    .iter()
                    .zip(v.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f32>()
                    .sqrt();
                (i, dist)
            })
            .collect();

        let mut quantized_distances: Vec<(usize, f32)> = quantized
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, q)| (i, sq.distance_quantized(query_q, q)))
            .collect();

        exact_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        quantized_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Check recall@10: how many of top-10 exact neighbors are in top-10 quantized
        let exact_top10: std::collections::HashSet<_> =
            exact_distances[..10].iter().map(|(i, _)| *i).collect();
        let quantized_top10: std::collections::HashSet<_> =
            quantized_distances[..10].iter().map(|(i, _)| *i).collect();

        let recall = exact_top10.intersection(&quantized_top10).count();
        // SQ8 should have at least 70% recall@10
        assert!(recall >= 7, "Recall@10: {}/10", recall);
    }

    #[test]
    fn test_binary_recall_approximation() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);

        let dim = 128;
        let num_vectors = 100;

        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|_| {
                (0..dim)
                    .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
                    .collect()
            })
            .collect();

        let bq = BinaryQuantizer::new(dim);
        let quantized: Vec<_> = vectors.iter().map(|v| bq.quantize(v)).collect();

        let query_idx = 42;
        let query = &vectors[query_idx];
        let query_q = &quantized[query_idx];

        // Find nearest using exact cosine and quantized hamming
        let mut exact_distances: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, v)| {
                let dot: f32 = query.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
                let norm_q: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_v: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                let cosine = dot / (norm_q * norm_v);
                (i, 1.0 - cosine) // Convert to distance
            })
            .collect();

        let mut quantized_distances: Vec<(usize, f32)> = quantized
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, q)| (i, bq.distance_quantized(query_q, q)))
            .collect();

        exact_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        quantized_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Binary should have at least 50% recall@10 (it's more lossy)
        let exact_top10: std::collections::HashSet<_> =
            exact_distances[..10].iter().map(|(i, _)| *i).collect();
        let quantized_top10: std::collections::HashSet<_> =
            quantized_distances[..10].iter().map(|(i, _)| *i).collect();

        let recall = exact_top10.intersection(&quantized_top10).count();
        assert!(recall >= 5, "Binary recall@10: {}/10", recall);
    }
}