// omendb_core/compression/rabitq.rs

1//! Multi-bit Scalar Quantization for `OmenDB`
2//!
3//! Provides flexible vector compression with arbitrary bit rates (2-8 bits per
4//! dimension) using per-dimension min/max quantization with trained parameters.
5//!
6//! **Note:** This module is named `rabitq` for historical reasons but implements
7//! standard scalar quantization, NOT the RaBitQ algorithm from arXiv:2405.12497.
8//! True RaBitQ requires random orthogonal rotation + binary quantization.
9//!
10//! # Tiered Compression Strategy
11//!
12//! - L0-L2 (hot): Full precision f32 (no compression)
13//! - L3-L4 (warm): 4-bit (8× compression)
14//! - L5-L6 (cold): 2-bit (16× compression)
15//!
16//! # Key Features
17//!
//! - Flexible compression (1, 2, 3, 4, 5, 7, or 8 bits/dimension)
19//! - Per-dimension min/max training (percentile-based for outlier robustness)
20//! - SIMD-accelerated distance (AVX2/NEON)
21//! - Same query speed as 8-bit scalar quantization
22//! - Better accuracy than binary quantization
23
24use serde::{Deserialize, Serialize};
25use smallvec::SmallVec;
26use std::fmt;
27
28#[cfg(target_arch = "aarch64")]
29use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
30#[cfg(target_arch = "x86_64")]
31#[allow(clippy::wildcard_imports)]
32use std::arch::x86_64::*;
33
/// Inline capacity of each per-dimension ADC lookup table.
///
/// 16 entries covers up to 4-bit quantization; 8-bit tables need 256 entries
/// and therefore spill the `SmallVec` to the heap (still correct, just slower
/// to build and less cache-friendly).
const MAX_CODES: usize = 16;
36
/// Number of bits per dimension for quantization
///
/// Supported widths are 1, 2, 3, 4, 5, 7, and 8 bits. There is no 6-bit
/// variant (NOTE(review): presumably intentional — confirm). Do not reorder
/// variants: serde's derived impls use variant order in index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationBits {
    /// 1 bit per dimension (32x compression) - Binary/BBQ
    Bits1,
    /// 2 bits per dimension (16x compression)
    Bits2,
    /// 3 bits per dimension (~10x compression)
    Bits3,
    /// 4 bits per dimension (8x compression)
    Bits4,
    /// 5 bits per dimension (~6x compression)
    Bits5,
    /// 7 bits per dimension (~4x compression)
    Bits7,
    /// 8 bits per dimension (4x compression)
    Bits8,
}
55
56impl QuantizationBits {
57    /// Convert to number of bits
58    #[must_use]
59    pub fn to_u8(self) -> u8 {
60        match self {
61            QuantizationBits::Bits1 => 1,
62            QuantizationBits::Bits2 => 2,
63            QuantizationBits::Bits3 => 3,
64            QuantizationBits::Bits4 => 4,
65            QuantizationBits::Bits5 => 5,
66            QuantizationBits::Bits7 => 7,
67            QuantizationBits::Bits8 => 8,
68        }
69    }
70
71    /// Get number of quantization levels (2^bits)
72    #[must_use]
73    pub fn levels(self) -> usize {
74        1 << self.to_u8()
75    }
76
77    /// Get compression ratio vs f32 (32 bits / `bits_per_dim`)
78    #[must_use]
79    pub fn compression_ratio(self) -> f32 {
80        32.0 / self.to_u8() as f32
81    }
82
83    /// Get number of values that fit in one byte
84    #[must_use]
85    pub fn values_per_byte(self) -> usize {
86        8 / self.to_u8() as usize
87    }
88}
89
90impl fmt::Display for QuantizationBits {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        write!(f, "{}-bit", self.to_u8())
93    }
94}
95
/// Configuration for `RaBitQ` quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQParams {
    /// Number of bits per dimension
    pub bits_per_dim: QuantizationBits,

    /// Number of rescaling factors to try (DEPRECATED: use trained quantizer)
    ///
    /// Higher values = better quantization quality but slower
    /// Typical range: 8-16
    /// NOTE(review): kept despite deprecation — presumably for serde
    /// backward compatibility; confirm before removing.
    pub num_rescale_factors: usize,

    /// Range of rescaling factors to try (DEPRECATED: use trained quantizer)
    ///
    /// Typical range: (0.5, 2.0) means try scales from 0.5x to 2.0x
    pub rescale_range: (f32, f32),
}
113
/// Trained quantization parameters computed from data
///
/// Stores per-dimension min/max values learned from a representative sample.
/// This enables consistent quantization across all vectors and correct ADC distances.
///
/// Invariant (upheld by `train_with_percentiles`):
/// `mins.len() == maxs.len() == dimensions`, and `maxs[d] > mins[d]` for
/// every dimension (a constant dimension is widened by ±0.5 during training).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainedParams {
    /// Minimum value per dimension (percentile-clipped lower bound)
    pub mins: Vec<f32>,
    /// Maximum value per dimension (percentile-clipped upper bound)
    pub maxs: Vec<f32>,
    /// Number of dimensions
    pub dimensions: usize,
}
127
128impl TrainedParams {
129    /// Train quantization parameters from sample vectors
130    ///
131    /// Computes per-dimension min/max using percentiles to exclude outliers.
132    /// Uses 1st and 99th percentiles by default for robustness.
133    ///
134    /// # Arguments
135    /// * `vectors` - Sample vectors to train from (should be representative)
136    ///
137    /// # Errors
138    /// Returns error if vectors is empty or vectors have inconsistent dimensions.
139    pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
140        Self::train_with_percentiles(vectors, 0.01, 0.99)
141    }
142
143    /// Train with custom percentile bounds
144    ///
145    /// # Arguments
146    /// * `vectors` - Sample vectors to train from
147    /// * `lower_percentile` - Lower bound percentile (e.g., 0.01 for 1st percentile)
148    /// * `upper_percentile` - Upper bound percentile (e.g., 0.99 for 99th percentile)
149    ///
150    /// # Errors
151    /// Returns error if vectors is empty or vectors have inconsistent dimensions.
152    pub fn train_with_percentiles(
153        vectors: &[&[f32]],
154        lower_percentile: f32,
155        upper_percentile: f32,
156    ) -> Result<Self, &'static str> {
157        if vectors.is_empty() {
158            return Err("Need at least one vector to train");
159        }
160        let dimensions = vectors[0].len();
161        if !vectors.iter().all(|v| v.len() == dimensions) {
162            return Err("All vectors must have same dimensions");
163        }
164
165        let n = vectors.len();
166        let lower_idx = ((n as f32 * lower_percentile) as usize).min(n - 1);
167        let upper_idx = ((n as f32 * upper_percentile) as usize).min(n - 1);
168
169        let mut mins = Vec::with_capacity(dimensions);
170        let mut maxs = Vec::with_capacity(dimensions);
171
172        // For each dimension, collect values and compute percentiles
173        let mut dim_values: Vec<f32> = Vec::with_capacity(n);
174        for d in 0..dimensions {
175            dim_values.clear();
176            for v in vectors {
177                dim_values.push(v[d]);
178            }
179            dim_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
180
181            let min_val = dim_values[lower_idx];
182            let max_val = dim_values[upper_idx];
183
184            // Ensure non-zero range (add small epsilon if needed)
185            let range = max_val - min_val;
186            if range < 1e-7 {
187                mins.push(min_val - 0.5);
188                maxs.push(max_val + 0.5);
189            } else {
190                mins.push(min_val);
191                maxs.push(max_val);
192            }
193        }
194
195        Ok(Self {
196            mins,
197            maxs,
198            dimensions,
199        })
200    }
201
202    /// Quantize a single value using trained parameters for given dimension
203    #[inline]
204    #[must_use]
205    pub fn quantize_value(&self, value: f32, dim: usize, levels: usize) -> u8 {
206        let min = self.mins[dim];
207        let max = self.maxs[dim];
208        let range = max - min;
209
210        // Map value to [0, 1] range, then to [0, levels-1]
211        let normalized = (value - min) / range;
212        let level = (normalized * (levels - 1) as f32).round();
213        level.clamp(0.0, (levels - 1) as f32) as u8
214    }
215
216    /// Dequantize a code to reconstructed value for given dimension
217    #[inline]
218    #[must_use]
219    pub fn dequantize_value(&self, code: u8, dim: usize, levels: usize) -> f32 {
220        let min = self.mins[dim];
221        let max = self.maxs[dim];
222        let range = max - min;
223
224        // Map code back to original range
225        (code as f32 / (levels - 1) as f32) * range + min
226    }
227}
228
229impl Default for RaBitQParams {
230    fn default() -> Self {
231        Self {
232            bits_per_dim: QuantizationBits::Bits4, // 8x compression
233            num_rescale_factors: 12,               // Good balance
234            rescale_range: (0.5, 2.0),             // Paper recommendation
235        }
236    }
237}
238
239impl RaBitQParams {
240    /// Create parameters for 2-bit quantization (16x compression)
241    #[must_use]
242    pub fn bits2() -> Self {
243        Self {
244            bits_per_dim: QuantizationBits::Bits2,
245            ..Default::default()
246        }
247    }
248
249    /// Create parameters for 4-bit quantization (8x compression, recommended)
250    #[must_use]
251    pub fn bits4() -> Self {
252        Self {
253            bits_per_dim: QuantizationBits::Bits4,
254            ..Default::default()
255        }
256    }
257
258    /// Create parameters for 8-bit quantization (4x compression, highest quality)
259    #[must_use]
260    pub fn bits8() -> Self {
261        Self {
262            bits_per_dim: QuantizationBits::Bits8,
263            num_rescale_factors: 16,   // More factors for higher precision
264            rescale_range: (0.7, 1.5), // Narrower range for 8-bit
265        }
266    }
267}
268
/// A quantized vector with optimal rescaling
///
/// Storage format:
/// - data: Packed quantized values (multiple values per byte)
/// - scale: Optimal rescaling factor for this vector (legacy per-vector path)
/// - bits: Number of bits per dimension
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Packed quantized values
    ///
    /// Format depends on `bits_per_dim`:
    /// - 2-bit: 4 values per byte
    /// - 3-bit: Not byte-aligned, needs special packing
    /// - 4-bit: 2 values per byte
    /// - 8-bit: 1 value per byte
    pub data: Vec<u8>,

    /// Optimal rescaling factor for this vector
    ///
    /// This is the scale factor that minimized quantization error
    /// during the rescaling search. Unused by the trained ADC path,
    /// which relies on `TrainedParams` instead.
    pub scale: f32,

    /// Number of bits per dimension
    pub bits: u8,

    /// Original vector dimensions (for unpacking)
    pub dimensions: usize,
}
298
299impl QuantizedVector {
300    /// Create a new quantized vector
301    #[must_use]
302    pub fn new(data: Vec<u8>, scale: f32, bits: u8, dimensions: usize) -> Self {
303        Self {
304            data,
305            scale,
306            bits,
307            dimensions,
308        }
309    }
310
311    /// Get memory usage in bytes
312    #[must_use]
313    pub fn memory_bytes(&self) -> usize {
314        std::mem::size_of::<Self>() + self.data.len()
315    }
316
317    /// Get compression ratio vs original f32 vector
318    #[must_use]
319    pub fn compression_ratio(&self) -> f32 {
320        let original_bytes = self.dimensions * 4; // f32 = 4 bytes
321        let compressed_bytes = self.data.len() + 4 + 1; // data + scale + bits
322        original_bytes as f32 / compressed_bytes as f32
323    }
324}
325
/// `RaBitQ` quantizer
///
/// Implements scalar quantization with trained per-dimension ranges.
/// Training computes min/max per dimension from sample data, enabling
/// consistent quantization across all vectors and correct ADC distances.
///
/// # Usage
///
/// ```ignore
/// // Create and train quantizer
/// let mut quantizer = RaBitQ::new(RaBitQParams::bits4());
/// quantizer.train(&sample_vectors);
///
/// // Quantize vectors
/// let quantized = quantizer.quantize(&vector);
///
/// // ADC search (distances are mathematically correct)
/// let adc = quantizer.build_adc_table(&query);
/// let dist = adc.distance(&quantized.data);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQ {
    params: RaBitQParams,
    /// Trained parameters (per-dimension min/max)
    /// When None, falls back to legacy per-vector scaling (deprecated).
    /// NOTE(review): when Some, `trained.dimensions` is expected to match the
    /// vectors being quantized — confirm against `quantize()` (not in view).
    trained: Option<TrainedParams>,
}
353
/// Asymmetric Distance Computation (ADC) lookup table for fast quantized search
///
/// Precomputes partial squared distances from a query vector to all possible
/// quantized codes. This enables O(1) distance computation per dimension instead
/// of O(dim) decompression + distance calculation.
///
/// # Performance
///
/// - **Memory**: For 4-bit: dim * 16 * 4 bytes (e.g., 1536D = 96KB per query)
/// - **Speedup**: 5-10x faster distance computation vs full decompression
/// - **Use case**: Scanning many candidate vectors during HNSW search
///
/// # Algorithm
///
/// Instead of:
/// ```ignore
/// for each candidate:
///     decompress(candidate) -> O(dim)
///     distance(query, decompressed) -> O(dim)
/// ```
///
/// With ADC:
/// ```ignore
/// precompute table[code] = (query_value - dequantize(code))^2 for all codes
/// for each candidate:
///     sum(table[candidate[i]]) -> O(1) per dimension
/// ```
#[derive(Debug, Clone)]
pub struct ADCTable {
    /// Lookup table: `table[dim_idx][code] = partial squared distance`
    /// For 4-bit: each inner array has 16 entries (codes 0-15)
    /// For 2-bit: each inner array has 4 entries (codes 0-3)
    /// For 8-bit: 256 entries, which exceeds the inline capacity
    /// (`MAX_CODES` = 16) and so spills the `SmallVec` to the heap.
    table: Vec<SmallVec<[f32; MAX_CODES]>>,

    /// Quantization parameters (bits per dimension)
    bits: u8,

    /// Number of dimensions
    dimensions: usize,
}
394
395impl ADCTable {
396    /// Build ADC lookup table using trained quantization parameters
397    ///
398    /// This is the production method that computes correct distances.
399    /// Uses per-dimension min/max ranges from training.
400    ///
401    /// # Arguments
402    ///
403    /// * `query` - The uncompressed query vector
404    /// * `trained` - Trained parameters with per-dimension min/max
405    /// * `params` - Quantization parameters (bits per dimension)
406    #[must_use]
407    pub fn new_trained(query: &[f32], trained: &TrainedParams, params: &RaBitQParams) -> Self {
408        let bits = params.bits_per_dim.to_u8();
409        let num_codes = params.bits_per_dim.levels();
410        let dimensions = query.len();
411
412        let mut table = Vec::with_capacity(dimensions);
413
414        // For each dimension, compute distances to all possible codes
415        for (d, &q_value) in query.iter().enumerate() {
416            let mut dim_table = SmallVec::new();
417
418            for code in 0..num_codes {
419                // Dequantize using trained min/max for this dimension
420                let reconstructed = trained.dequantize_value(code as u8, d, num_codes);
421
422                // Compute squared difference (partial L2 distance)
423                let diff = q_value - reconstructed;
424                dim_table.push(diff * diff);
425            }
426
427            table.push(dim_table);
428        }
429
430        Self {
431            table,
432            bits,
433            dimensions,
434        }
435    }
436
437    /// Build ADC lookup table for a query vector (DEPRECATED)
438    ///
439    /// For each dimension and each possible quantized code, precomputes the
440    /// squared distance contribution: (query[i] - dequantize(code, scale))^2
441    ///
442    /// WARNING: This method uses a fixed scale which produces incorrect distances
443    /// when vectors were quantized with different scales. Use `new_trained()` instead.
444    ///
445    /// # Arguments
446    ///
447    /// * `query` - The uncompressed query vector
448    /// * `scale` - The scale factor used for quantization (from training or default)
449    /// * `params` - Quantization parameters (bits per dimension)
450    ///
451    /// # Returns
452    ///
453    /// An `ADCTable` that can compute distances via `distance()` method
454    ///
455    /// # Note
456    /// Prefer `ADCTable::new_trained()` with `TrainedParams` for correct ADC distances.
457    /// This method uses per-vector scale which gives lower accuracy.
458    #[must_use]
459    pub fn new(query: &[f32], scale: f32, params: &RaBitQParams) -> Self {
460        let bits = params.bits_per_dim.to_u8();
461        let num_codes = params.bits_per_dim.levels();
462        let dimensions = query.len();
463
464        let mut table = Vec::with_capacity(dimensions);
465
466        // Dequantization factor: value = code / (levels - 1) / scale
467        let levels = num_codes as f32;
468        let dequant_factor = 1.0 / ((levels - 1.0) * scale);
469
470        // For each dimension, compute distances to all possible codes
471        for &q_value in query {
472            let mut dim_table = SmallVec::new();
473
474            for code in 0..num_codes {
475                // Dequantize the code to get the reconstructed value
476                let reconstructed = (code as f32) * dequant_factor;
477
478                // Compute squared difference (partial L2 distance)
479                let diff = q_value - reconstructed;
480                dim_table.push(diff * diff);
481            }
482
483            table.push(dim_table);
484        }
485
486        Self {
487            table,
488            bits,
489            dimensions,
490        }
491    }
492
493    /// Compute approximate L2 squared distance using lookup table
494    ///
495    /// This is the hot path for search! Instead of decompressing and computing
496    /// distance, we just sum up precomputed values from the table.
497    ///
498    /// # Performance
499    ///
500    /// - 4-bit: ~5-10x faster than decompression + distance
501    /// - Cache-friendly: sequential access patterns
502    /// - SIMD-friendly: can vectorize the summation
503    ///
504    /// # Arguments
505    ///
506    /// * `data` - Packed quantized bytes
507    ///
508    /// # Returns
509    ///
510    /// Approximate squared L2 distance (not square-rooted for efficiency)
511    #[inline]
512    #[must_use]
513    pub fn distance_squared(&self, data: &[u8]) -> f32 {
514        match self.bits {
515            4 => self.distance_squared_4bit(data),
516            2 => self.distance_squared_2bit(data),
517            8 => self.distance_squared_8bit(data),
518            _ => self.distance_squared_generic(data),
519        }
520    }
521
522    /// Compute distance and return square root (actual L2 distance)
523    ///
524    /// Uses SIMD-accelerated distance computation (AVX2 on `x86_64`, NEON on aarch64).
525    #[inline]
526    #[must_use]
527    pub fn distance(&self, data: &[u8]) -> f32 {
528        self.distance_squared_simd(data).sqrt()
529    }
530
    /// Fast path for 4-bit quantization (most common case)
    ///
    /// Two 4-bit codes are packed per byte: the HIGH nibble holds the code
    /// for the even dimension `i*2`, the LOW nibble for dimension `i*2 + 1`.
    /// NOTE(review): this nibble order must match the packing produced by the
    /// quantizer's `quantize()` (not visible in this chunk) — confirm.
    ///
    /// # Safety invariants (maintained by `ADCTable::new`)
    /// - `self.table.len() == self.dimensions`
    /// - Each `table[i]` has exactly 16 entries (4-bit = 2^4 codes)
    /// - Input `data` has `ceil(dimensions/2)` bytes (2 values per byte)
    #[inline]
    fn distance_squared_4bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;
        let num_pairs = self.dimensions / 2;

        // Process pairs of dimensions (2 codes per byte)
        for i in 0..num_pairs {
            // Defensive bound: tolerate short input rather than read past it.
            if i >= data.len() {
                break;
            }

            // SAFETY: i < data.len() (checked above)
            let byte = unsafe { *data.get_unchecked(i) };
            let code_hi = (byte >> 4) as usize; // 0..=15
            let code_lo = (byte & 0x0F) as usize; // 0..=15

            // SAFETY:
            // - i*2 < dimensions (since i < num_pairs = dimensions/2)
            // - i*2+1 < dimensions (same reasoning)
            // - code_hi, code_lo in 0..16 (4-bit mask guarantees this)
            // - table has 16 entries per dimension (4-bit quantization)
            sum += unsafe {
                *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                    + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo)
            };
        }

        // Handle odd dimension: the final code sits in the HIGH nibble of the
        // trailing byte.
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            // SAFETY: num_pairs < data.len() checked above
            let byte = unsafe { *data.get_unchecked(num_pairs) };
            let code_hi = (byte >> 4) as usize; // 0..=15
            // SAFETY:
            // - dimensions-1 is a valid index (dimensions >= 1 when odd)
            // - code_hi in 0..16 (4-bit mask)
            sum += unsafe {
                *self
                    .table
                    .get_unchecked(self.dimensions - 1)
                    .get_unchecked(code_hi)
            };
        }

        sum
    }
582
    /// Fast path for 2-bit quantization
    ///
    /// Four 2-bit codes are packed per byte, lowest bits first: bits 0-1 hold
    /// dimension `i*4`, bits 2-3 hold `i*4 + 1`, and so on.
    /// NOTE(review): opposite end-first ordering from the 4-bit path (which
    /// is high-nibble first) — confirm both match the quantizer's packing.
    ///
    /// # Safety invariants (maintained by `ADCTable::new`)
    /// - `self.table.len() == self.dimensions`
    /// - Each `table[i]` has exactly 4 entries (2-bit = 2^2 codes)
    /// - Input `data` has `ceil(dimensions/4)` bytes (4 values per byte)
    #[inline]
    fn distance_squared_2bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;
        let num_quads = self.dimensions / 4;

        // Process quads of dimensions (4 codes per byte)
        for i in 0..num_quads {
            // Defensive bound: tolerate short input rather than read past it.
            if i >= data.len() {
                break;
            }

            // SAFETY: i < data.len() (checked above)
            let byte = unsafe { *data.get_unchecked(i) };

            // SAFETY:
            // - i*4+k < dimensions for k in 0..4 (since i < num_quads = dimensions/4)
            // - all codes in 0..4 (2-bit mask guarantees this)
            // - table has 4 entries per dimension (2-bit quantization)
            sum += unsafe {
                *self
                    .table
                    .get_unchecked(i * 4)
                    .get_unchecked((byte & 0b11) as usize)
                    + *self
                        .table
                        .get_unchecked(i * 4 + 1)
                        .get_unchecked(((byte >> 2) & 0b11) as usize)
                    + *self
                        .table
                        .get_unchecked(i * 4 + 2)
                        .get_unchecked(((byte >> 4) & 0b11) as usize)
                    + *self
                        .table
                        .get_unchecked(i * 4 + 3)
                        .get_unchecked(((byte >> 6) & 0b11) as usize)
            };
        }

        // Handle the 1-3 leftover dimensions in the final partial byte
        let remaining = self.dimensions % 4;
        if remaining > 0 && num_quads < data.len() {
            // SAFETY: num_quads < data.len() checked above
            let byte = unsafe { *data.get_unchecked(num_quads) };
            for j in 0..remaining {
                let code = ((byte >> (j * 2)) & 0b11) as usize; // 0..=3
                // SAFETY:
                // - num_quads*4+j < dimensions (j < remaining = dimensions%4)
                // - code in 0..4 (2-bit mask)
                sum += unsafe {
                    *self
                        .table
                        .get_unchecked(num_quads * 4 + j)
                        .get_unchecked(code)
                };
            }
        }

        sum
    }
648
649    /// Fast path for 8-bit quantization
650    ///
651    /// # Safety invariants (maintained by `ADCTable::new`)
652    /// - `self.table.len() == self.dimensions`
653    /// - Each `table[i]` has exactly 256 entries (8-bit = 2^8 codes)
654    /// - Input `data` has `dimensions` bytes (1 value per byte)
655    #[inline]
656    fn distance_squared_8bit(&self, data: &[u8]) -> f32 {
657        let mut sum = 0.0f32;
658
659        for (i, &byte) in data.iter().enumerate().take(self.dimensions) {
660            // SAFETY:
661            // - i < dimensions (take() ensures this)
662            // - byte as usize in 0..256 (u8 range)
663            // - table has 256 entries per dimension (8-bit quantization)
664            sum += unsafe { *self.table.get_unchecked(i).get_unchecked(byte as usize) };
665        }
666
667        sum
668    }
669
670    /// Generic fallback for other bit widths
671    #[inline]
672    fn distance_squared_generic(&self, data: &[u8]) -> f32 {
673        // For non-standard bit widths, fall back to bounds-checked access
674        let mut sum = 0.0f32;
675
676        for (i, dim_table) in self.table.iter().enumerate() {
677            if i >= data.len() {
678                break;
679            }
680            let code = data[i] as usize;
681            if let Some(&dist) = dim_table.get(code) {
682                sum += dist;
683            }
684        }
685
686        sum
687    }
688
689    /// SIMD-accelerated distance computation for 4-bit quantization
690    ///
691    /// Uses AVX2 on `x86_64` or NEON on `aarch64` to process multiple lookups in parallel.
692    /// Falls back to scalar implementation if SIMD not available.
693    #[inline]
694    #[must_use]
695    pub fn distance_squared_simd(&self, data: &[u8]) -> f32 {
696        match self.bits {
697            4 => {
698                #[cfg(target_arch = "x86_64")]
699                {
700                    if is_x86_feature_detected!("avx2") {
701                        unsafe { self.distance_squared_4bit_avx2(data) }
702                    } else {
703                        // x86_64 fallback to scalar
704                        self.distance_squared_4bit(data)
705                    }
706                }
707                #[cfg(target_arch = "aarch64")]
708                {
709                    // NEON is always available on aarch64
710                    unsafe { self.distance_squared_4bit_neon(data) }
711                }
712                #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
713                {
714                    // Other architectures fallback to scalar
715                    self.distance_squared_4bit(data)
716                }
717            }
718            2 => {
719                // For 2-bit, scalar is already quite fast
720                self.distance_squared_2bit(data)
721            }
722            8 => {
723                // For 8-bit, could use SIMD gather but scalar is reasonable
724                self.distance_squared_8bit(data)
725            }
726            _ => self.distance_squared_generic(data),
727        }
728    }
729
    /// AVX2 implementation for 4-bit ADC distance
    ///
    /// The table lookups themselves are scalar gathers into a stack array;
    /// AVX2 only vectorizes the accumulation of 8 per-pair partial sums at a
    /// time, with scalar handling of tail pairs and an odd final dimension.
    ///
    /// NOTE(review): on malformed (too-short) `data` the chunk loop breaks and
    /// the scalar remainder loop starts at `chunks * 8`, so a partial trailing
    /// chunk is skipped here but would be processed by the scalar path —
    /// harmless for correctly sized input, but the two paths disagree on
    /// short input.
    /// NOTE(review): `fma` is enabled but no FMA intrinsics are used — confirm
    /// whether the target feature is still required.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA (the enabled
    /// target features), and that `self.table` upholds the 4-bit invariants:
    /// `table.len() == dimensions` with 16 entries per dimension.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[target_feature(enable = "fma")]
    unsafe fn distance_squared_4bit_avx2(&self, data: &[u8]) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let num_pairs = self.dimensions / 2;

        // Process 8 pairs (16 dimensions) at a time using AVX2
        let chunks = num_pairs / 8;
        for chunk_idx in 0..chunks {
            let byte_idx = chunk_idx * 8;
            // Defensive bound: never taken for well-formed data.
            if byte_idx + 8 > data.len() {
                break;
            }

            // Gather 8 bytes (16 4-bit codes) into per-pair partial sums.
            let mut values = [0.0f32; 8];
            for (i, value) in values.iter_mut().enumerate() {
                let byte = *data.get_unchecked(byte_idx + i);
                let code_hi = (byte >> 4) as usize;
                let code_lo = (byte & 0x0F) as usize;

                let dist_hi = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2)
                    .get_unchecked(code_hi);
                let dist_lo = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2 + 1)
                    .get_unchecked(code_lo);
                *value = dist_hi + dist_lo;
            }

            let vec = _mm256_loadu_ps(values.as_ptr());
            sum = _mm256_add_ps(sum, vec);
        }

        // Horizontal sum of AVX2 register
        let mut result = horizontal_sum_avx2(sum);

        // Handle remainder pairs with scalar lookups
        for i in (chunks * 8)..num_pairs {
            if i >= data.len() {
                break;
            }
            let byte = *data.get_unchecked(i);
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;

            result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
        }

        // Handle odd dimension (final code in the high nibble of the last byte)
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = *data.get_unchecked(num_pairs);
            let code_hi = (byte >> 4) as usize;
            result += *self
                .table
                .get_unchecked(self.dimensions - 1)
                .get_unchecked(code_hi);
        }

        result
    }
796
    /// NEON implementation for 4-bit ADC distance
    ///
    /// Mirrors the AVX2 kernel: scalar gathers into a stack array, with NEON
    /// vectorizing the accumulation of 4 per-pair partial sums at a time.
    ///
    /// NOTE(review): like the AVX2 path, a partial trailing chunk of
    /// too-short `data` is skipped here but processed by the scalar path —
    /// the two only agree on correctly sized input.
    ///
    /// # Safety
    /// NEON is mandatory on aarch64, so the intrinsics are always available;
    /// the caller must still ensure `self.table` upholds the 4-bit invariants:
    /// `table.len() == dimensions` with 16 entries per dimension.
    #[cfg(target_arch = "aarch64")]
    unsafe fn distance_squared_4bit_neon(&self, data: &[u8]) -> f32 {
        let mut sum = vdupq_n_f32(0.0);
        let num_pairs = self.dimensions / 2;

        // Process 4 pairs (8 dimensions) at a time using NEON
        let chunks = num_pairs / 4;
        for chunk_idx in 0..chunks {
            let byte_idx = chunk_idx * 4;
            // Defensive bound: never taken for well-formed data.
            if byte_idx + 4 > data.len() {
                break;
            }

            // Gather 4 bytes (8 4-bit codes) into per-pair partial sums.
            let mut values = [0.0f32; 4];
            for (i, value) in values.iter_mut().enumerate() {
                let byte = *data.get_unchecked(byte_idx + i);
                let code_hi = (byte >> 4) as usize;
                let code_lo = (byte & 0x0F) as usize;

                let dist_hi = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2)
                    .get_unchecked(code_hi);
                let dist_lo = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2 + 1)
                    .get_unchecked(code_lo);
                *value = dist_hi + dist_lo;
            }

            let vec = vld1q_f32(values.as_ptr());
            sum = vaddq_f32(sum, vec);
        }

        // Horizontal sum of the NEON register
        let mut result = vaddvq_f32(sum);

        // Handle remainder pairs with scalar lookups
        for i in (chunks * 4)..num_pairs {
            if i >= data.len() {
                break;
            }
            let byte = *data.get_unchecked(i);
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;

            result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
        }

        // Handle odd dimension (final code in the high nibble of the last byte)
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = *data.get_unchecked(num_pairs);
            let code_hi = (byte >> 4) as usize;
            result += *self
                .table
                .get_unchecked(self.dimensions - 1)
                .get_unchecked(code_hi);
        }

        result
    }
859
    /// Get bits per dimension
    ///
    /// Returns the bit width this lookup table was built for (the `bits`
    /// field stored at construction time).
    #[must_use]
    pub fn bits(&self) -> u8 {
        self.bits
    }
865
    /// Get number of dimensions
    ///
    /// Returns the vector dimensionality this table covers (the `dimensions`
    /// field stored at construction time).
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
871
872    /// Get partial distance for a dimension and code
873    ///
874    /// Returns 0.0 if indices are out of bounds.
875    #[must_use]
876    pub fn get(&self, dim: usize, code: usize) -> f32 {
877        self.table
878            .get(dim)
879            .and_then(|t| t.get(code))
880            .copied()
881            .unwrap_or(0.0)
882    }
883
884    /// Get memory usage in bytes
885    #[must_use]
886    pub fn memory_bytes(&self) -> usize {
887        std::mem::size_of::<Self>()
888            + self.table.len() * std::mem::size_of::<SmallVec<[f32; MAX_CODES]>>()
889            + self.table.iter().map(|t| t.len() * 4).sum::<usize>()
890    }
891}
892
893impl RaBitQ {
894    /// Create a new `RaBitQ` quantizer (untrained)
895    ///
896    /// Call `train()` before use to enable correct ADC distances.
897    #[must_use]
898    pub fn new(params: RaBitQParams) -> Self {
899        Self {
900            params,
901            trained: None,
902        }
903    }
904
905    /// Create a trained `RaBitQ` quantizer
906    #[must_use]
907    pub fn new_trained(params: RaBitQParams, trained: TrainedParams) -> Self {
908        Self {
909            params,
910            trained: Some(trained),
911        }
912    }
913
914    /// Create with default 4-bit quantization
915    #[must_use]
916    pub fn default_4bit() -> Self {
917        Self::new(RaBitQParams::bits4())
918    }
919
    /// Get quantization parameters
    ///
    /// Borrows the `RaBitQParams` this quantizer was constructed with.
    #[must_use]
    pub fn params(&self) -> &RaBitQParams {
        &self.params
    }
925
    /// Check if quantizer has been trained
    ///
    /// Trained quantizers use per-dimension min/max ranges; untrained ones
    /// fall back to the deprecated per-vector scale search.
    #[must_use]
    pub fn is_trained(&self) -> bool {
        self.trained.is_some()
    }
931
    /// Get trained parameters (if any)
    ///
    /// Returns `None` when the quantizer has not been trained yet.
    #[must_use]
    pub fn trained_params(&self) -> Option<&TrainedParams> {
        self.trained.as_ref()
    }
937
938    /// Train quantizer on sample vectors
939    ///
940    /// Computes per-dimension min/max ranges from the sample.
941    /// Must be called before quantization for correct ADC distances.
942    ///
943    /// # Arguments
944    /// * `vectors` - Representative sample of vectors to train from
945    ///
946    /// # Errors
947    /// Returns error if vectors is empty or have inconsistent dimensions.
948    pub fn train(&mut self, vectors: &[&[f32]]) -> Result<(), &'static str> {
949        self.trained = Some(TrainedParams::train(vectors)?);
950        Ok(())
951    }
952
953    /// Train with owned vectors (convenience method)
954    ///
955    /// # Errors
956    /// Returns error if vectors is empty or have inconsistent dimensions.
957    pub fn train_owned(&mut self, vectors: &[Vec<f32>]) -> Result<(), &'static str> {
958        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
959        self.train(&refs)
960    }
961
962    /// Quantize a vector using trained parameters
963    ///
964    /// If trained: uses per-dimension min/max for consistent quantization
965    /// If untrained: falls back to legacy per-vector scaling (deprecated)
966    #[must_use]
967    pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
968        // Use trained quantization if available
969        if let Some(ref trained) = self.trained {
970            return self.quantize_trained(vector, trained);
971        }
972
973        // Legacy fallback: per-vector scale search (deprecated)
974        let mut best_error = f32::MAX;
975        let mut best_quantized = Vec::new();
976        let mut best_scale = 1.0;
977
978        // Generate rescaling factors to try
979        let scales = self.generate_scales();
980
981        // Try each scale and find the one with minimum error
982        for scale in scales {
983            let quantized = self.quantize_with_scale(vector, scale);
984            let error = self.compute_error(vector, &quantized, scale);
985
986            if error < best_error {
987                best_error = error;
988                best_quantized = quantized;
989                best_scale = scale;
990            }
991        }
992
993        QuantizedVector::new(
994            best_quantized,
995            best_scale,
996            self.params.bits_per_dim.to_u8(),
997            vector.len(),
998        )
999    }
1000
1001    /// Quantize using trained per-dimension min/max ranges
1002    ///
1003    /// This is the production path that enables correct ADC distances.
1004    fn quantize_trained(&self, vector: &[f32], trained: &TrainedParams) -> QuantizedVector {
1005        let bits = self.params.bits_per_dim.to_u8();
1006        let levels = self.params.bits_per_dim.levels();
1007
1008        // Quantize each dimension using trained min/max
1009        let quantized: Vec<u8> = vector
1010            .iter()
1011            .enumerate()
1012            .map(|(d, &v)| trained.quantize_value(v, d, levels))
1013            .collect();
1014
1015        // Pack into bytes
1016        let packed = self.pack_quantized(&quantized, bits);
1017
1018        // Scale=1.0 for trained quantization (min/max handles the range)
1019        QuantizedVector::new(packed, 1.0, bits, vector.len())
1020    }
1021
1022    /// Generate rescaling factors to try
1023    ///
1024    /// Returns a vector of scale factors evenly spaced between
1025    /// `rescale_range.0` and `rescale_range.1`
1026    fn generate_scales(&self) -> Vec<f32> {
1027        let (min_scale, max_scale) = self.params.rescale_range;
1028        let n = self.params.num_rescale_factors;
1029
1030        if n == 1 {
1031            return vec![f32::midpoint(min_scale, max_scale)];
1032        }
1033
1034        let step = (max_scale - min_scale) / (n - 1) as f32;
1035        (0..n).map(|i| min_scale + i as f32 * step).collect()
1036    }
1037
1038    /// Quantize a vector with a specific scale factor
1039    ///
1040    /// Algorithm (Extended RaBitQ):
1041    /// 1. Scale: v' = v * scale
1042    /// 2. Quantize to grid: q = round(v' * (2^bits - 1))
1043    /// 3. Clamp to valid range
1044    /// 4. Pack into bytes
1045    fn quantize_with_scale(&self, vector: &[f32], scale: f32) -> Vec<u8> {
1046        let bits = self.params.bits_per_dim.to_u8();
1047        let levels = self.params.bits_per_dim.levels() as f32;
1048        let max_level = (levels - 1.0) as u8;
1049
1050        // Scale and quantize directly (no normalization needed)
1051        let quantized: Vec<u8> = vector
1052            .iter()
1053            .map(|&v| {
1054                // Scale the value
1055                let scaled = v * scale;
1056                // Quantize to grid [0, levels-1]
1057                let level = (scaled * (levels - 1.0)).round();
1058                // Clamp to valid range
1059                level.clamp(0.0, max_level as f32) as u8
1060            })
1061            .collect();
1062
1063        // Pack into bytes
1064        self.pack_quantized(&quantized, bits)
1065    }
1066
1067    /// Pack quantized values into bytes
1068    ///
1069    /// Packing depends on bits per dimension:
1070    /// - 2-bit: 4 values per byte (00 00 00 00)
1071    /// - 4-bit: 2 values per byte (0000 0000)
1072    /// - 8-bit: 1 value per byte
1073    #[allow(clippy::unused_self)]
1074    fn pack_quantized(&self, values: &[u8], bits: u8) -> Vec<u8> {
1075        match bits {
1076            2 => {
1077                // 4 values per byte
1078                let mut packed = Vec::with_capacity(values.len().div_ceil(4));
1079                for chunk in values.chunks(4) {
1080                    let mut byte = 0u8;
1081                    for (i, &val) in chunk.iter().enumerate() {
1082                        byte |= (val & 0b11) << (i * 2);
1083                    }
1084                    packed.push(byte);
1085                }
1086                packed
1087            }
1088            4 => {
1089                // 2 values per byte
1090                let mut packed = Vec::with_capacity(values.len().div_ceil(2));
1091                for chunk in values.chunks(2) {
1092                    let byte = if chunk.len() == 2 {
1093                        (chunk[0] << 4) | (chunk[1] & 0x0F)
1094                    } else {
1095                        chunk[0] << 4
1096                    };
1097                    packed.push(byte);
1098                }
1099                packed
1100            }
1101            8 => {
1102                // 1 value per byte (no packing needed)
1103                values.to_vec()
1104            }
1105            _ => {
1106                // 3, 5, 7-bit: fall back to 8-bit storage
1107                // Not implementing proper bit-packing because:
1108                // - Public API only exposes 2, 4, 8-bit (see python/src/lib.rs)
1109                // - Cross-byte packing is complex with marginal compression benefit
1110                // - 4-bit (8x) vs 5-bit (~6x) isn't worth the code complexity
1111                values.to_vec()
1112            }
1113        }
1114    }
1115
1116    /// Unpack quantized bytes into individual values
1117    #[must_use]
1118    pub fn unpack_quantized(&self, packed: &[u8], bits: u8, dimensions: usize) -> Vec<u8> {
1119        match bits {
1120            2 => {
1121                // 4 values per byte
1122                let mut values = Vec::with_capacity(dimensions);
1123                for &byte in packed {
1124                    for i in 0..4 {
1125                        if values.len() < dimensions {
1126                            values.push((byte >> (i * 2)) & 0b11);
1127                        }
1128                    }
1129                }
1130                values
1131            }
1132            4 => {
1133                // 2 values per byte
1134                let mut values = Vec::with_capacity(dimensions);
1135                for &byte in packed {
1136                    values.push(byte >> 4);
1137                    if values.len() < dimensions {
1138                        values.push(byte & 0x0F);
1139                    }
1140                }
1141                values.truncate(dimensions);
1142                values
1143            }
1144            8 => {
1145                // 1 value per byte
1146                packed[..dimensions.min(packed.len())].to_vec()
1147            }
1148            _ => {
1149                // For other bit widths, assume 8-bit storage
1150                packed[..dimensions.min(packed.len())].to_vec()
1151            }
1152        }
1153    }
1154
1155    /// Compute quantization error (reconstruction error)
1156    ///
1157    /// Error = ||original - reconstructed||²
1158    fn compute_error(&self, original: &[f32], quantized: &[u8], scale: f32) -> f32 {
1159        let reconstructed = self.reconstruct(quantized, scale, original.len());
1160
1161        original
1162            .iter()
1163            .zip(reconstructed.iter())
1164            .map(|(o, r)| (o - r).powi(2))
1165            .sum()
1166    }
1167
1168    /// Reconstruct (dequantize) a quantized vector
1169    ///
1170    /// Algorithm (Extended RaBitQ):
1171    /// 1. Unpack bytes to quantized values [0, 2^bits-1]
1172    /// 2. Denormalize: v' = q / (2^bits - 1)
1173    /// 3. Unscale: v = v' / scale
1174    #[must_use]
1175    pub fn reconstruct(&self, quantized: &[u8], scale: f32, dimensions: usize) -> Vec<f32> {
1176        let bits = self.params.bits_per_dim.to_u8();
1177        let levels = self.params.bits_per_dim.levels() as f32;
1178
1179        // Unpack bytes
1180        let values = self.unpack_quantized(quantized, bits, dimensions);
1181
1182        // Dequantize: reverse the quantization process
1183        values
1184            .iter()
1185            .map(|&q| {
1186                // Denormalize from [0, levels-1] to [0, 1]
1187                let denorm = q as f32 / (levels - 1.0);
1188                // Unscale
1189                denorm / scale
1190            })
1191            .collect()
1192    }
1193
1194    /// Compute L2 (Euclidean) distance between two quantized vectors
1195    ///
1196    /// This reconstructs both vectors and computes standard L2 distance.
1197    /// For maximum accuracy, use this with original vectors for reranking.
1198    #[must_use]
1199    pub fn distance_l2(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1200        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1201        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1202
1203        v1.iter()
1204            .zip(v2.iter())
1205            .map(|(a, b)| (a - b).powi(2))
1206            .sum::<f32>()
1207            .sqrt()
1208    }
1209
1210    /// Compute cosine distance between two quantized vectors
1211    ///
1212    /// Cosine distance = 1 - cosine similarity
1213    #[must_use]
1214    pub fn distance_cosine(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1215        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1216        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1217
1218        let dot: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
1219        let norm1: f32 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
1220        let norm2: f32 = v2.iter().map(|b| b * b).sum::<f32>().sqrt();
1221
1222        if norm1 < 1e-10 || norm2 < 1e-10 {
1223            return 1.0; // Maximum distance for zero vectors
1224        }
1225
1226        let cosine_sim = dot / (norm1 * norm2);
1227        1.0 - cosine_sim
1228    }
1229
1230    /// Compute dot product between two quantized vectors
1231    #[must_use]
1232    pub fn distance_dot(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1233        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1234        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1235
1236        // Return negative dot product (for nearest neighbor search)
1237        -v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum::<f32>()
1238    }
1239
1240    /// Compute approximate distance using quantized values directly (fast path)
1241    ///
1242    /// This computes distance in the quantized space without full reconstruction.
1243    /// Faster but less accurate than `distance_l2`.
1244    #[must_use]
1245    pub fn distance_approximate(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1246        // Unpack to quantized values (u8)
1247        let v1 = self.unpack_quantized(&qv1.data, qv1.bits, qv1.dimensions);
1248        let v2 = self.unpack_quantized(&qv2.data, qv2.bits, qv2.dimensions);
1249
1250        // Compute L2 distance in quantized space
1251        v1.iter()
1252            .zip(v2.iter())
1253            .map(|(a, b)| {
1254                let diff = (*a as i16 - *b as i16) as f32;
1255                diff * diff
1256            })
1257            .sum::<f32>()
1258            .sqrt()
1259    }
1260
1261    /// Compute asymmetric L2 distance (query vs quantized) without full reconstruction
1262    ///
1263    /// This is the hot path for search! It unpacks quantized values on the fly and
1264    /// computes distance against the uncompressed query vector.
1265    ///
1266    /// When trained: Uses per-dimension min/max for correct distance computation.
1267    /// When untrained: Falls back to per-vector scale (deprecated, lower accuracy).
1268    #[must_use]
1269    pub fn distance_asymmetric_l2(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
1270        // Use trained parameters when available for correct distance computation
1271        if let Some(trained) = &self.trained {
1272            return self.distance_asymmetric_l2_trained(query, &quantized.data, trained);
1273        }
1274        // Fallback to per-vector scale (deprecated, only for untrained quantizers)
1275        self.distance_asymmetric_l2_raw(query, &quantized.data, quantized.scale, quantized.bits)
1276    }
1277
1278    /// Asymmetric L2 distance from flat storage (no QuantizedVector wrapper)
1279    ///
1280    /// This is the preferred method for flat contiguous storage layouts.
1281    /// When trained: Uses per-dimension min/max (scale ignored).
1282    /// When untrained: Falls back to per-vector scale.
1283    #[must_use]
1284    #[inline]
1285    pub fn distance_asymmetric_l2_flat(&self, query: &[f32], data: &[u8], scale: f32) -> f32 {
1286        if let Some(trained) = &self.trained {
1287            return self.distance_asymmetric_l2_trained(query, data, trained);
1288        }
1289        // Fallback to per-vector scale (deprecated, only for untrained quantizers)
1290        self.distance_asymmetric_l2_raw(query, data, scale, self.params.bits_per_dim.to_u8())
1291    }
1292
    /// Asymmetric L2 distance using trained per-dimension parameters
    ///
    /// This is the correct distance computation for trained quantizers.
    /// Each dimension uses its own min/max range for accurate dequantization.
    ///
    /// Returns `f32::MAX` as a sentinel when `data` is too short to hold
    /// `query.len()` codes at the configured bit width.
    ///
    /// NOTE(review): indexes `trained.mins`/`trained.maxs` up to
    /// `query.len() - 1` — assumes the quantizer was trained on vectors of the
    /// same dimensionality as `query`; panics otherwise. TODO confirm callers
    /// guarantee this.
    #[must_use]
    fn distance_asymmetric_l2_trained(
        &self,
        query: &[f32],
        data: &[u8],
        trained: &TrainedParams,
    ) -> f32 {
        let levels = self.params.bits_per_dim.levels() as f32;
        let bits = self.params.bits_per_dim.to_u8();

        // Use SmallVec for stack allocation when possible
        let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

        match bits {
            4 => {
                // Two codes per byte: high nibble is the even dimension,
                // low nibble the odd one (matches pack_quantized's layout).
                let num_pairs = query.len() / 2;
                if data.len() < query.len().div_ceil(2) {
                    return f32::MAX;
                }

                for i in 0..num_pairs {
                    // SAFETY: i < num_pairs <= data.len(), verified by the
                    // div_ceil bounds check above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    let d0 = i * 2;
                    let d1 = i * 2 + 1;

                    // Dequantize using per-dimension min/max
                    let code0 = (byte >> 4) as f32;
                    let code1 = (byte & 0x0F) as f32;

                    let range0 = trained.maxs[d0] - trained.mins[d0];
                    let range1 = trained.maxs[d1] - trained.mins[d1];

                    buffer.push((code0 / (levels - 1.0)) * range0 + trained.mins[d0]);
                    buffer.push((code1 / (levels - 1.0)) * range1 + trained.mins[d1]);
                }

                // Odd dimension count: last code lives alone in a high nibble.
                if !query.len().is_multiple_of(2) {
                    // SAFETY: odd length means div_ceil(2) == num_pairs + 1,
                    // so index num_pairs is in bounds per the check above.
                    let byte = unsafe { *data.get_unchecked(num_pairs) };
                    let d = num_pairs * 2;
                    let code = (byte >> 4) as f32;
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
            2 => {
                // Four codes per byte, least-significant 2-bit pair first.
                let num_quads = query.len() / 4;
                if data.len() < query.len().div_ceil(4) {
                    return f32::MAX;
                }

                for i in 0..num_quads {
                    // SAFETY: i < num_quads <= data.len(), verified above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    for j in 0..4 {
                        let d = i * 4 + j;
                        let code = ((byte >> (j * 2)) & 0b11) as f32;
                        let range = trained.maxs[d] - trained.mins[d];
                        buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                    }
                }

                // Tail: 1-3 codes in the final partial byte.
                let remaining = query.len() % 4;
                if remaining > 0 {
                    // SAFETY: remaining > 0 means div_ceil(4) == num_quads + 1,
                    // so index num_quads is in bounds per the check above.
                    let byte = unsafe { *data.get_unchecked(num_quads) };
                    for j in 0..remaining {
                        let d = num_quads * 4 + j;
                        let code = ((byte >> (j * 2)) & 0b11) as f32;
                        let range = trained.maxs[d] - trained.mins[d];
                        buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                    }
                }
            }
            8 => {
                // One code per byte, no unpacking required.
                if data.len() < query.len() {
                    return f32::MAX;
                }
                for (d, &byte) in data.iter().enumerate().take(query.len()) {
                    let code = byte as f32;
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
            _ => {
                // Generic fallback for other bit widths
                let unpacked = self.unpack_quantized(data, bits, query.len());
                for (d, &code) in unpacked.iter().enumerate().take(query.len()) {
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code as f32 / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
        }

        simd_l2_distance(query, &buffer)
    }
1390
1391    /// Build an ADC (Asymmetric Distance Computation) lookup table for a query
1392    ///
1393    /// If trained: uses per-dimension min/max for correct distances
1394    /// If untrained: uses provided scale (deprecated, incorrect distances)
1395    ///
1396    /// # Example
1397    ///
1398    /// ```ignore
1399    /// // Preferred: train first, then build ADC table
1400    /// quantizer.train(&sample_vectors);
1401    /// let adc_table = quantizer.build_adc_table(&query);
1402    /// for candidate in candidates {
1403    ///     let dist = adc_table.distance(&candidate.data);
1404    /// }
1405    /// ```
1406    #[must_use]
1407    pub fn build_adc_table(&self, query: &[f32]) -> Option<ADCTable> {
1408        self.trained
1409            .as_ref()
1410            .map(|trained| ADCTable::new_trained(query, trained, &self.params))
1411    }
1412
    /// Build ADC table with explicit scale
    ///
    /// Delegates to `ADCTable::new` with this quantizer's parameters.
    ///
    /// # Note
    /// Prefer `build_adc_table()` on a trained quantizer for correct ADC distances.
    /// This method uses per-vector scale which gives lower accuracy.
    #[must_use]
    pub fn build_adc_table_with_scale(&self, query: &[f32], scale: f32) -> ADCTable {
        ADCTable::new(query, scale, &self.params)
    }
1422
1423    /// Compute distance using ADC table (convenience wrapper)
1424    ///
1425    /// Returns None if quantizer is not trained.
1426    #[must_use]
1427    pub fn distance_with_adc(&self, query: &[f32], quantized: &QuantizedVector) -> Option<f32> {
1428        let adc = self.build_adc_table(query)?;
1429        Some(adc.distance(&quantized.data))
1430    }
1431
    /// Low-level asymmetric distance computation on raw bytes
    ///
    /// Enables zero-copy access from mmap (no `QuantizedVector` allocation needed).
    /// Uses `SmallVec` to unpack on stack and SIMD for distance computation.
    ///
    /// Returns `f32::MAX` as a sentinel when `data` is too short for
    /// `query.len()` codes at the given bit width.
    ///
    /// NOTE(review): `levels` is taken from `self.params`, while unpacking is
    /// driven by the `bits` argument — passing a `bits` that differs from the
    /// quantizer's configured width would apply a mismatched dequantization
    /// factor. TODO confirm callers always pass the quantizer's own bit width.
    #[must_use]
    pub fn distance_asymmetric_l2_raw(
        &self,
        query: &[f32],
        data: &[u8],
        scale: f32,
        bits: u8,
    ) -> f32 {
        let levels = self.params.bits_per_dim.levels() as f32;

        // Dequantization factor: value = q * factor
        // derived from: q / (levels - 1) / scale
        // So factor = 1.0 / ((levels - 1.0) * scale)
        let factor = 1.0 / ((levels - 1.0) * scale);

        match bits {
            4 => {
                // Unpack to stack buffer (up to 256 dims = 1KB). Falls back to heap for larger.
                let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

                let num_pairs = query.len() / 2;

                // Check bounds once to avoid checks in loop
                if data.len() < query.len().div_ceil(2) {
                    // Fallback if data is truncated (shouldn't happen in valid storage)
                    return f32::MAX;
                }

                for i in 0..num_pairs {
                    // SAFETY: i < num_pairs <= data.len(), verified by the
                    // div_ceil bounds check above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    // High nibble first, then low — matches pack_quantized.
                    buffer.push((byte >> 4) as f32 * factor);
                    buffer.push((byte & 0x0F) as f32 * factor);
                }

                // Odd dimension count: trailing code sits in a high nibble.
                if !query.len().is_multiple_of(2) {
                    // SAFETY: odd length means div_ceil(2) == num_pairs + 1,
                    // so index num_pairs is in bounds per the check above.
                    let byte = unsafe { *data.get_unchecked(num_pairs) };
                    buffer.push((byte >> 4) as f32 * factor);
                }

                simd_l2_distance(query, &buffer)
            }
            2 => {
                let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
                let num_quads = query.len() / 4;

                if data.len() < query.len().div_ceil(4) {
                    return f32::MAX;
                }

                for i in 0..num_quads {
                    // SAFETY: i < num_quads <= data.len(), verified above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    // Least-significant 2-bit pair first — matches pack_quantized.
                    buffer.push((byte & 0b11) as f32 * factor);
                    buffer.push(((byte >> 2) & 0b11) as f32 * factor);
                    buffer.push(((byte >> 4) & 0b11) as f32 * factor);
                    buffer.push(((byte >> 6) & 0b11) as f32 * factor);
                }

                // Handle remainder
                let remaining = query.len() % 4;
                if remaining > 0 {
                    // SAFETY: remaining > 0 means div_ceil(4) == num_quads + 1,
                    // so index num_quads is in bounds per the check above.
                    let byte = unsafe { *data.get_unchecked(num_quads) };
                    for i in 0..remaining {
                        buffer.push(((byte >> (i * 2)) & 0b11) as f32 * factor);
                    }
                }

                simd_l2_distance(query, &buffer)
            }
            _ => {
                // Generic fallback using existing unpack (allocates Vec if > 256, but correct)
                // Actually unpack_quantized returns Vec<u8>, so this path allocates.
                // That's fine for non-optimized bit widths.
                let unpacked = self.unpack_quantized(data, bits, query.len());
                let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

                for &q in &unpacked {
                    buffer.push(q as f32 * factor);
                }

                simd_l2_distance(query, &buffer)
            }
        }
    }
1519
1520    // SIMD-optimized distance functions
1521
1522    /// Compute L2 distance using SIMD acceleration
1523    ///
1524    /// Uses runtime CPU detection to select the best SIMD implementation:
1525    /// - `x86_64`: AVX2 > SSE2 > scalar
1526    /// - aarch64: NEON > scalar
1527    #[inline]
1528    #[must_use]
1529    pub fn distance_l2_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1530        // Reconstruct to f32 vectors
1531        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1532        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1533
1534        // Use SIMD distance computation
1535        simd_l2_distance(&v1, &v2)
1536    }
1537
1538    /// Compute cosine distance using SIMD acceleration
1539    #[inline]
1540    #[must_use]
1541    pub fn distance_cosine_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1542        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1543        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1544
1545        simd_cosine_distance(&v1, &v2)
1546    }
1547}
1548
1549// SIMD distance computation functions
1550
1551/// Compute L2 distance using SIMD
1552#[inline]
1553fn simd_l2_distance(v1: &[f32], v2: &[f32]) -> f32 {
1554    #[cfg(target_arch = "x86_64")]
1555    {
1556        if is_x86_feature_detected!("avx2") {
1557            return unsafe { l2_distance_avx2(v1, v2) };
1558        } else if is_x86_feature_detected!("sse2") {
1559            return unsafe { l2_distance_sse2(v1, v2) };
1560        }
1561        // Scalar fallback for x86_64 without SIMD
1562        l2_distance_scalar(v1, v2)
1563    }
1564
1565    #[cfg(target_arch = "aarch64")]
1566    {
1567        // NEON always available on aarch64
1568        unsafe { l2_distance_neon(v1, v2) }
1569    }
1570
1571    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1572    {
1573        // Scalar fallback for other architectures
1574        l2_distance_scalar(v1, v2)
1575    }
1576}
1577
1578/// Compute cosine distance using SIMD
1579#[inline]
1580fn simd_cosine_distance(v1: &[f32], v2: &[f32]) -> f32 {
1581    #[cfg(target_arch = "x86_64")]
1582    {
1583        if is_x86_feature_detected!("avx2") {
1584            return unsafe { cosine_distance_avx2(v1, v2) };
1585        } else if is_x86_feature_detected!("sse2") {
1586            return unsafe { cosine_distance_sse2(v1, v2) };
1587        }
1588        // Scalar fallback for x86_64 without SIMD
1589        cosine_distance_scalar(v1, v2)
1590    }
1591
1592    #[cfg(target_arch = "aarch64")]
1593    {
1594        // NEON always available on aarch64
1595        unsafe { cosine_distance_neon(v1, v2) }
1596    }
1597
1598    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1599    {
1600        // Scalar fallback for other architectures
1601        cosine_distance_scalar(v1, v2)
1602    }
1603}
1604
1605// Scalar implementations
1606
#[inline]
#[allow(dead_code)] // Used as SIMD fallback on x86_64 without AVX2/SSE2
fn l2_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
    // Plain loop over the zipped (shorter-length) pair of slices:
    // sqrt of the running sum of squared differences.
    let mut sum = 0.0f32;
    for (a, b) in v1.iter().zip(v2.iter()) {
        let d = a - b;
        sum += d * d;
    }
    sum.sqrt()
}
1616
#[inline]
#[allow(dead_code)] // Used as SIMD fallback on x86_64 without AVX2/SSE2
fn cosine_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
    // 1 - cos(v1, v2); near-zero norms yield the maximum distance of 1.0
    // to avoid dividing by zero.
    let mut dot = 0.0f32;
    for (a, b) in v1.iter().zip(v2.iter()) {
        dot += a * b;
    }
    let norm1 = v1.iter().fold(0.0f32, |s, a| s + a * a).sqrt();
    let norm2 = v2.iter().fold(0.0f32, |s, b| s + b * b).sqrt();

    if norm1 < 1e-10 || norm2 < 1e-10 {
        return 1.0;
    }

    1.0 - dot / (norm1 * norm2)
}
1631
1632// AVX2 implementations (x86_64)
1633
/// AVX2+FMA L2 (Euclidean) distance over the first `min(v1.len(), v2.len())`
/// elements: 8 lanes per iteration with fused multiply-add, then a horizontal
/// sum and a scalar tail loop.
///
/// # Safety
/// The CPU must support AVX2 and FMA (verify with `is_x86_feature_detected!`
/// before calling).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn l2_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        let mut sum = _mm256_setzero_ps();

        // Accumulate (a - b)^2 in 8-wide chunks via FMA.
        let chunks = len / 8;
        for i in 0..chunks {
            let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
            let diff = _mm256_sub_ps(a, b);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }

        // Horizontal sum
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(sum_low, sum_high);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        let mut result = _mm_cvtss_f32(sum32);

        // Handle remainder
        for i in (chunks * 8)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1667
/// AVX2+FMA cosine distance (1 - cosine similarity) over the first
/// `min(v1.len(), v2.len())` elements. Accumulates dot product and both
/// squared norms in one pass; near-zero norms return 1.0 (max distance).
///
/// # Safety
/// The CPU must support AVX2 and FMA (verify with `is_x86_feature_detected!`
/// before calling).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn cosine_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm1_sum = _mm256_setzero_ps();
        let mut norm2_sum = _mm256_setzero_ps();

        // Three FMA accumulators per 8-wide chunk: a·b, a·a, b·b.
        let chunks = len / 8;
        for i in 0..chunks {
            let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
            dot_sum = _mm256_fmadd_ps(a, b, dot_sum);
            norm1_sum = _mm256_fmadd_ps(a, a, norm1_sum);
            norm2_sum = _mm256_fmadd_ps(b, b, norm2_sum);
        }

        // Horizontal sums
        let mut dot = horizontal_sum_avx2(dot_sum);
        let mut norm1 = horizontal_sum_avx2(norm1_sum);
        let mut norm2 = horizontal_sum_avx2(norm2_sum);

        // Handle remainder
        for i in (chunks * 8)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // Note: norm1/norm2 are still squared norms here; the 1e-10
        // threshold is applied pre-sqrt (matches the scalar fallback's
        // behavior only approximately).
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1707
/// Reduces the 8 lanes of an AVX register to their scalar sum.
///
/// # Safety
/// Requires AVX support on the executing CPU; callers are
/// `#[target_feature(enable = "avx2")]` functions that already guarantee
/// this.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    unsafe {
        // Add the upper 128-bit half onto the lower half...
        let sum_high = _mm256_extractf128_ps(v, 1);
        let sum_low = _mm256_castps256_ps128(v);
        let sum128 = _mm_add_ps(sum_low, sum_high);
        // ...then fold lanes 4 -> 2 -> 1.
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        _mm_cvtss_f32(sum32)
    }
}
1720
1721// SSE2 implementations (x86_64 fallback)
1722
/// SSE2 Euclidean (L2) distance over the common prefix of `v1` and `v2`.
///
/// Fallback path for x86_64 CPUs without AVX2: accumulates squared
/// differences 4 lanes at a time (multiply + add, no FMA), then handles
/// the non-multiple-of-4 tail with scalar code.
///
/// # Safety
/// SSE2 is part of the x86_64 baseline, so any x86_64 CPU satisfies the
/// `#[target_feature]` requirement; the raw-pointer loads rely on the
/// in-bounds reasoning noted below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn l2_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        // Only compare the overlapping prefix if lengths differ.
        let len = v1.len().min(v2.len());
        let mut sum = _mm_setzero_ps();

        let chunks = len / 4;
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < len, so both unaligned 4-float loads
            // stay within the bounds of their slices.
            let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
            let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
            let diff = _mm_sub_ps(a, b);
            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        }

        // Horizontal sum
        let sum64 = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        let mut result = _mm_cvtss_f32(sum32);

        // Handle remainder
        for i in (chunks * 4)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1752
/// SSE2 cosine distance (`1 - cos θ`) over the common prefix of the slices.
///
/// Fallback path for x86_64 CPUs without AVX2: accumulates the dot product
/// and both squared norms 4 lanes at a time, then finishes the tail with
/// scalar code. Near-zero-norm inputs yield 1.0 (matches the scalar path).
///
/// # Safety
/// SSE2 is part of the x86_64 baseline, so any x86_64 CPU satisfies the
/// `#[target_feature]` requirement; the raw-pointer loads rely on the
/// in-bounds reasoning noted below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn cosine_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        // Only compare the overlapping prefix if lengths differ.
        let len = v1.len().min(v2.len());
        let mut dot_sum = _mm_setzero_ps();
        let mut norm1_sum = _mm_setzero_ps();
        let mut norm2_sum = _mm_setzero_ps();

        let chunks = len / 4;
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < len, so both unaligned loads are in bounds.
            let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
            let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
            dot_sum = _mm_add_ps(dot_sum, _mm_mul_ps(a, b));
            norm1_sum = _mm_add_ps(norm1_sum, _mm_mul_ps(a, a));
            norm2_sum = _mm_add_ps(norm2_sum, _mm_mul_ps(b, b));
        }

        // Horizontal sums
        let mut dot = horizontal_sum_sse2(dot_sum);
        let mut norm1 = horizontal_sum_sse2(norm1_sum);
        let mut norm2 = horizontal_sum_sse2(norm2_sum);

        // Handle remainder
        for i in (chunks * 4)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // Degenerate (near-zero) vectors: define the distance as 1.0.
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1791
/// Reduces the 4 lanes of an SSE register to their scalar sum.
///
/// # Safety
/// Uses only SSE/SSE2 operations, which are part of the x86_64 baseline;
/// callable from any x86_64 code path.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_sse2(v: __m128) -> f32 {
    unsafe {
        // Fold lanes 4 -> 2 -> 1.
        let sum64 = _mm_add_ps(v, _mm_movehl_ps(v, v));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        _mm_cvtss_f32(sum32)
    }
}
1801
1802// NEON implementations (aarch64)
1803
/// NEON Euclidean (L2) distance over the common prefix of `v1` and `v2`.
///
/// Accumulates squared differences 4 lanes at a time with fused
/// multiply-add (`vfmaq_f32`), reduces across lanes with `vaddvq_f32`,
/// then handles the non-multiple-of-4 tail with scalar code.
///
/// # Safety
/// NEON is mandatory on aarch64, so the intrinsics are always available;
/// the raw-pointer loads rely on the in-bounds reasoning noted below.
#[cfg(target_arch = "aarch64")]
unsafe fn l2_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    // Only compare the overlapping prefix if lengths differ.
    let len = v1.len().min(v2.len());

    // SAFETY: All SIMD operations wrapped in unsafe block for Rust 2024
    unsafe {
        let mut sum = vdupq_n_f32(0.0);

        let chunks = len / 4;
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < len, so both 4-float loads are in bounds.
            let a = vld1q_f32(v1.as_ptr().add(i * 4));
            let b = vld1q_f32(v2.as_ptr().add(i * 4));
            let diff = vsubq_f32(a, b);
            sum = vfmaq_f32(sum, diff, diff); // sum += diff * diff (fused)
        }

        // Horizontal sum
        let mut result = vaddvq_f32(sum);

        // Handle remainder
        for i in (chunks * 4)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1832
/// NEON cosine distance (`1 - cos θ`) over the common prefix of the slices.
///
/// Accumulates the dot product and both squared norms 4 lanes at a time
/// with fused multiply-add, reduces with `vaddvq_f32`, then finishes the
/// tail with scalar code. Near-zero-norm inputs yield 1.0 (matches the
/// scalar path).
///
/// # Safety
/// NEON is mandatory on aarch64, so the intrinsics are always available;
/// the raw-pointer loads rely on the in-bounds reasoning noted below.
#[cfg(target_arch = "aarch64")]
unsafe fn cosine_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    // Only compare the overlapping prefix if lengths differ.
    let len = v1.len().min(v2.len());

    // SAFETY: All SIMD operations wrapped in unsafe block for Rust 2024
    unsafe {
        let mut dot_sum = vdupq_n_f32(0.0);
        let mut norm1_sum = vdupq_n_f32(0.0);
        let mut norm2_sum = vdupq_n_f32(0.0);

        let chunks = len / 4;
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < len, so both 4-float loads are in bounds.
            let a = vld1q_f32(v1.as_ptr().add(i * 4));
            let b = vld1q_f32(v2.as_ptr().add(i * 4));
            dot_sum = vfmaq_f32(dot_sum, a, b);
            norm1_sum = vfmaq_f32(norm1_sum, a, a);
            norm2_sum = vfmaq_f32(norm2_sum, b, b);
        }

        // Horizontal sums
        let mut dot = vaddvq_f32(dot_sum);
        let mut norm1 = vaddvq_f32(norm1_sum);
        let mut norm2 = vaddvq_f32(norm2_sum);

        // Handle remainder
        for i in (chunks * 4)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // Degenerate (near-zero) vectors: define the distance as 1.0.
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1872
1873#[cfg(test)]
1874#[allow(clippy::float_cmp)]
1875mod tests {
1876    use super::*;
1877
    // Enum variant -> raw bit count conversion.
    #[test]
    fn test_quantization_bits_conversion() {
        assert_eq!(QuantizationBits::Bits2.to_u8(), 2);
        assert_eq!(QuantizationBits::Bits4.to_u8(), 4);
        assert_eq!(QuantizationBits::Bits8.to_u8(), 8);
    }

    // levels() = 2^bits, the number of distinct quantization codes.
    #[test]
    fn test_quantization_bits_levels() {
        assert_eq!(QuantizationBits::Bits2.levels(), 4); // 2^2
        assert_eq!(QuantizationBits::Bits4.levels(), 16); // 2^4
        assert_eq!(QuantizationBits::Bits8.levels(), 256); // 2^8
    }

    // Compression ratio is relative to 32-bit floats.
    #[test]
    fn test_quantization_bits_compression() {
        assert_eq!(QuantizationBits::Bits2.compression_ratio(), 16.0); // 32/2
        assert_eq!(QuantizationBits::Bits4.compression_ratio(), 8.0); // 32/4
        assert_eq!(QuantizationBits::Bits8.compression_ratio(), 4.0); // 32/8
    }

    // How many packed values fit in one byte at each bit width.
    #[test]
    fn test_quantization_bits_values_per_byte() {
        assert_eq!(QuantizationBits::Bits2.values_per_byte(), 4); // 8/2
        assert_eq!(QuantizationBits::Bits4.values_per_byte(), 2); // 8/4
        assert_eq!(QuantizationBits::Bits8.values_per_byte(), 1); // 8/8
    }

    // Default parameter set is the 4-bit configuration.
    #[test]
    fn test_default_params() {
        let params = RaBitQParams::default();
        assert_eq!(params.bits_per_dim, QuantizationBits::Bits4);
        assert_eq!(params.num_rescale_factors, 12);
        assert_eq!(params.rescale_range, (0.5, 2.0));
    }

    // Preset constructors select the expected bit widths.
    #[test]
    fn test_preset_params() {
        let params2 = RaBitQParams::bits2();
        assert_eq!(params2.bits_per_dim, QuantizationBits::Bits2);

        let params4 = RaBitQParams::bits4();
        assert_eq!(params4.bits_per_dim, QuantizationBits::Bits4);

        let params8 = RaBitQParams::bits8();
        assert_eq!(params8.bits_per_dim, QuantizationBits::Bits8);
        assert_eq!(params8.num_rescale_factors, 16);
    }

    // QuantizedVector::new stores its fields verbatim.
    #[test]
    fn test_quantized_vector_creation() {
        let data = vec![0u8, 128, 255];
        let qv = QuantizedVector::new(data.clone(), 1.5, 8, 3);

        assert_eq!(qv.data, data);
        assert_eq!(qv.scale, 1.5);
        assert_eq!(qv.bits, 8);
        assert_eq!(qv.dimensions, 3);
    }

    // memory_bytes() accounts for at least the packed payload.
    #[test]
    fn test_quantized_vector_memory() {
        let data = vec![0u8; 16]; // 16 bytes
        let qv = QuantizedVector::new(data, 1.0, 4, 32);

        // Should be: struct overhead + data length
        let expected_min = 16; // At least the data
        assert!(qv.memory_bytes() >= expected_min);
    }

    // compression_ratio() includes the metadata overhead in its denominator.
    #[test]
    fn test_quantized_vector_compression_ratio() {
        // 128 dimensions, 4-bit = 64 bytes
        let data = vec![0u8; 64];
        let qv = QuantizedVector::new(data, 1.0, 4, 128);

        // Original: 128 * 4 = 512 bytes
        // Compressed: 64 + 4 (scale) + 1 (bits) = 69 bytes
        // Ratio: 512 / 69 ≈ 7.4x
        let ratio = qv.compression_ratio();
        assert!(ratio > 7.0 && ratio < 8.0);
    }

    // default_4bit() exposes 4-bit params through the accessor.
    #[test]
    fn test_create_quantizer() {
        let quantizer = RaBitQ::default_4bit();
        assert_eq!(quantizer.params().bits_per_dim, QuantizationBits::Bits4);
    }
1966
    // Phase 2 Tests: Core Algorithm

    // Rescale factors are evenly spaced across the configured range.
    #[test]
    fn test_generate_scales() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 5,
            rescale_range: (0.5, 1.5),
        });

        let scales = quantizer.generate_scales();
        assert_eq!(scales.len(), 5);
        assert_eq!(scales[0], 0.5);
        assert_eq!(scales[4], 1.5);
        assert!((scales[2] - 1.0).abs() < 0.01); // Middle should be ~1.0
    }

    // A single rescale factor degenerates to the range midpoint.
    #[test]
    fn test_generate_scales_single() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 1,
            rescale_range: (0.5, 1.5),
        });

        let scales = quantizer.generate_scales();
        assert_eq!(scales.len(), 1);
        assert_eq!(scales[0], 1.0); // Average of min and max
    }

    // Bit packing round-trips losslessly at 2 bits per value.
    #[test]
    fn test_pack_unpack_2bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits2,
            ..Default::default()
        });

        // 8 values (2 bits each) = 2 bytes
        let values = vec![0u8, 1, 2, 3, 0, 1, 2, 3];
        let packed = quantizer.pack_quantized(&values, 2);
        assert_eq!(packed.len(), 2); // 8 values / 4 per byte = 2 bytes

        let unpacked = quantizer.unpack_quantized(&packed, 2, 8);
        assert_eq!(unpacked, values);
    }

    // Bit packing round-trips losslessly at 4 bits per value.
    #[test]
    fn test_pack_unpack_4bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            ..Default::default()
        });

        // 8 values (4 bits each) = 4 bytes
        let values = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let packed = quantizer.pack_quantized(&values, 4);
        assert_eq!(packed.len(), 4); // 8 values / 2 per byte = 4 bytes

        let unpacked = quantizer.unpack_quantized(&packed, 4, 8);
        assert_eq!(unpacked, values);
    }

    // 8-bit packing is a pass-through: one byte per value.
    #[test]
    fn test_pack_unpack_8bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            ..Default::default()
        });

        // 8 values (8 bits each) = 8 bytes
        let values = vec![0u8, 10, 20, 30, 40, 50, 60, 70];
        let packed = quantizer.pack_quantized(&values, 8);
        assert_eq!(packed.len(), 8); // 8 values = 8 bytes

        let unpacked = quantizer.unpack_quantized(&packed, 8, 8);
        assert_eq!(unpacked, values);
    }

    // End-to-end quantization produces a sane structure and packed size.
    #[test]
    fn test_quantize_simple_vector() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 4,
            rescale_range: (0.5, 1.5),
        });

        // Simple vector: [0.0, 0.25, 0.5, 0.75, 1.0]
        let vector = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let quantized = quantizer.quantize(&vector);

        // Check structure
        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits, 4);
        assert!(quantized.scale > 0.0);

        // Check compression: 5 floats * 4 bytes = 20 bytes original
        // Quantized: 5 values * 4 bits = 20 bits = 3 bytes (rounded up)
        assert!(quantized.data.len() <= 4);
    }

    // 8-bit quantize -> reconstruct stays within a small per-value error.
    #[test]
    fn test_quantize_reconstruct_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        // Test vector
        let vector = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let quantized = quantizer.quantize(&vector);

        // Reconstruct
        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());

        // Check reconstruction is close (8-bit should be accurate)
        for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            assert!(error < 0.1, "Error too large: {orig} vs {recon}");
        }
    }

    // A constant vector reconstructs to (approximately) constant values.
    #[test]
    fn test_quantize_uniform_vector() {
        let quantizer = RaBitQ::default_4bit();

        // All values the same
        let vector = vec![0.5; 10];
        let quantized = quantizer.quantize(&vector);

        // Reconstruct should also be uniform
        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());

        // All values should be similar
        let avg = reconstructed.iter().sum::<f32>() / reconstructed.len() as f32;
        for &val in &reconstructed {
            assert!((val - avg).abs() < 0.2);
        }
    }

    // compute_error yields a finite, non-negative reconstruction error.
    #[test]
    fn test_compute_error() {
        let quantizer = RaBitQ::default_4bit();

        let original = vec![0.1, 0.2, 0.3, 0.4];
        let quantized_vec = quantizer.quantize(&original);

        // Compute error
        let error = quantizer.compute_error(&original, &quantized_vec.data, quantized_vec.scale);

        // Error should be non-negative and finite
        assert!(error >= 0.0);
        assert!(error.is_finite());
    }

    // Packed size grows monotonically with the configured bit width.
    #[test]
    fn test_quantize_different_bit_widths() {
        let test_vector = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        // Test 2-bit
        let q2 = RaBitQ::new(RaBitQParams::bits2());
        let qv2 = q2.quantize(&test_vector);
        assert_eq!(qv2.bits, 2);

        // Test 4-bit
        let q4 = RaBitQ::default_4bit();
        let qv4 = q4.quantize(&test_vector);
        assert_eq!(qv4.bits, 4);

        // Test 8-bit
        let q8 = RaBitQ::new(RaBitQParams::bits8());
        let qv8 = q8.quantize(&test_vector);
        assert_eq!(qv8.bits, 8);

        // Higher bits = larger packed size (for same dimensions)
        assert!(qv2.data.len() <= qv4.data.len());
        assert!(qv4.data.len() <= qv8.data.len());
    }

    // 128D at 4 bits packs to exactly 64 bytes and reconstructs fully.
    #[test]
    fn test_quantize_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        // 128D vector (like small embeddings)
        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let quantized = quantizer.quantize(&vector);

        assert_eq!(quantized.dimensions, 128);
        assert_eq!(quantized.bits, 4);

        // 128 dimensions * 4 bits = 512 bits = 64 bytes
        assert_eq!(quantized.data.len(), 64);

        // Verify reconstruction
        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, 128);
        assert_eq!(reconstructed.len(), 128);
    }
2164
    // Phase 3 Tests: Distance Computation

    // Unit-separated vectors have L2 distance ~1.0 after quantization.
    #[test]
    fn test_distance_l2() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_l2(&qv1, &qv2);

        // Distance should be approximately 1.0
        assert!((dist - 1.0).abs() < 0.2, "Distance: {dist}");
    }

    // Quantizing the same vector twice yields near-zero L2 distance.
    #[test]
    fn test_distance_l2_identical() {
        let quantizer = RaBitQ::default_4bit();

        let v = vec![0.5, 0.3, 0.8, 0.2];
        let qv1 = quantizer.quantize(&v);
        let qv2 = quantizer.quantize(&v);

        let dist = quantizer.distance_l2(&qv1, &qv2);

        // Identical vectors should have near-zero distance
        assert!(dist < 0.3, "Distance should be near zero, got: {dist}");
    }

    // Orthogonal vectors have cosine distance ~1.0 after quantization.
    #[test]
    fn test_distance_cosine() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        // Orthogonal vectors
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_cosine(&qv1, &qv2);

        // Orthogonal vectors: cosine = 0, distance = 1
        assert!((dist - 1.0).abs() < 0.3, "Distance: {dist}");
    }

    // Identical direction -> cosine distance near zero.
    #[test]
    fn test_distance_cosine_identical() {
        let quantizer = RaBitQ::default_4bit();

        let v = vec![0.5, 0.3, 0.8];
        let qv1 = quantizer.quantize(&v);
        let qv2 = quantizer.quantize(&v);

        let dist = quantizer.distance_cosine(&qv1, &qv2);

        // Identical vectors: cosine = 1, distance = 0
        assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
    }

    // Dot-product "distance" is the negated dot product.
    #[test]
    fn test_distance_dot() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_dot(&qv1, &qv2);

        // Dot product of [1,0,0] with itself = 1, negated = -1
        assert!((dist + 1.0).abs() < 0.3, "Distance: {dist}");
    }

    // The fast approximate distance is sane and roughly ordered like exact L2.
    #[test]
    fn test_distance_approximate() {
        let quantizer = RaBitQ::default_4bit();

        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![0.5, 0.5, 0.5];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_approx = quantizer.distance_approximate(&qv1, &qv2);
        let dist_exact = quantizer.distance_l2(&qv1, &qv2);

        // Approximate should be non-negative and finite
        assert!(dist_approx >= 0.0);
        assert!(dist_approx.is_finite());

        // Approximate and exact should be correlated (not exact match)
        // Just verify both increase/decrease together
        let v3 = vec![1.0, 1.0, 1.0];
        let qv3 = quantizer.quantize(&v3);

        let dist_approx2 = quantizer.distance_approximate(&qv1, &qv3);
        let dist_exact2 = quantizer.distance_l2(&qv1, &qv3);

        // If v3 is farther from v1 than v2, both metrics should reflect that
        if dist_exact2 > dist_exact {
            assert!(dist_approx2 > dist_approx * 0.5); // Allow some variance
        }
    }

    // Quantized distances preserve the ordering of ground-truth distances.
    #[test]
    fn test_distance_correlation() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision for correlation
            num_rescale_factors: 12,
            rescale_range: (0.8, 1.2),
        });

        // Create multiple vectors
        let vectors = [
            vec![0.1, 0.2, 0.3],
            vec![0.4, 0.5, 0.6],
            vec![0.7, 0.8, 0.9],
        ];

        // Quantize all
        let quantized: Vec<QuantizedVector> =
            vectors.iter().map(|v| quantizer.quantize(v)).collect();

        // Ground truth L2 distances
        let ground_truth_01 = vectors[0]
            .iter()
            .zip(vectors[1].iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        let ground_truth_02 = vectors[0]
            .iter()
            .zip(vectors[2].iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        // Quantized distances
        let quantized_01 = quantizer.distance_l2(&quantized[0], &quantized[1]);
        let quantized_02 = quantizer.distance_l2(&quantized[0], &quantized[2]);

        // Check correlation: if ground truth says v2 > v1, quantized should too
        if ground_truth_02 > ground_truth_01 {
            assert!(
                quantized_02 > quantized_01 * 0.8,
                "Order not preserved: {quantized_01} vs {quantized_02}"
            );
        }
    }

    // Zero vectors must not produce NaN/inf in any metric.
    #[test]
    fn test_distance_zero_vectors() {
        let quantizer = RaBitQ::default_4bit();

        let v_zero = vec![0.0, 0.0, 0.0];
        let qv_zero = quantizer.quantize(&v_zero);

        // Distance to itself should be zero
        let dist = quantizer.distance_l2(&qv_zero, &qv_zero);
        assert!(dist < 0.1);

        // Cosine distance with zero vector should handle gracefully
        let dist_cosine = quantizer.distance_cosine(&qv_zero, &qv_zero);
        assert!(dist_cosine.is_finite());
    }

    // All distance metrics stay finite and sensible at 128 dimensions.
    #[test]
    fn test_distance_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        // 128D vectors
        let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| ((i + 10) as f32) / 128.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // All distance metrics should work on high-dimensional vectors
        let dist_l2 = quantizer.distance_l2(&qv1, &qv2);
        let dist_cosine = quantizer.distance_cosine(&qv1, &qv2);
        let dist_dot = quantizer.distance_dot(&qv1, &qv2);
        let dist_approx = quantizer.distance_approximate(&qv1, &qv2);

        assert!(dist_l2 > 0.0 && dist_l2.is_finite());
        assert!(dist_cosine >= 0.0 && dist_cosine.is_finite());
        assert!(dist_dot.is_finite());
        assert!(dist_approx > 0.0 && dist_approx.is_finite());
    }

    // Asymmetric (exact query vs quantized vector) L2 agrees with symmetric.
    #[test]
    fn test_distance_asymmetric_l2() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.1, 0.2, 0.3, 0.4];
        // Vector close to query
        let vector = vec![0.12, 0.22, 0.32, 0.42];

        let quantized = quantizer.quantize(&vector);

        // Symmetric distance (allocates)
        let dist_sym = quantizer.distance_l2_simd(&quantized, &quantizer.quantize(&query));

        // Asymmetric distance (no allocation)
        let dist_asym = quantizer.distance_asymmetric_l2(&query, &quantized);

        // Should be reasonably close (asymmetric is actually MORE accurate because query is exact)
        // But for this test, just ensure it's sane
        assert!(dist_asym >= 0.0);
        assert!((dist_asym - dist_sym).abs() < 0.2);
    }
2394
    // Phase 4 Tests: SIMD Optimizations

    // SIMD L2 path agrees with the scalar path on an 8-element vector.
    #[test]
    fn test_simd_l2_matches_scalar() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let v2 = vec![0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);

        // SIMD should match scalar within floating point precision
        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    // SIMD cosine path agrees with the scalar path.
    #[test]
    fn test_simd_cosine_matches_scalar() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_cosine(&qv1, &qv2);
        let dist_simd = quantizer.distance_cosine_simd(&qv1, &qv2);

        // SIMD should match scalar within floating point precision
        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    // SIMD/scalar agreement holds at 128 dimensions.
    #[test]
    fn test_simd_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        // 128D vectors (realistic embeddings)
        let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| ((i + 1) as f32) / 128.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);

        // Should be close (allow for quantization + FP variance)
        let diff = (dist_scalar - dist_simd).abs();
        assert!(
            diff < 0.1,
            "High-D SIMD vs scalar: {dist_simd} vs {dist_scalar}"
        );
    }

    // Vectors shorter than one SIMD lane-group exercise the remainder path.
    #[test]
    fn test_simd_scalar_fallback() {
        let quantizer = RaBitQ::default_4bit();

        // Small vector (tests remainder handling)
        let v1 = vec![0.1, 0.2, 0.3];
        let v2 = vec![0.4, 0.5, 0.6];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // Should not crash on small vectors
        let dist_l2 = quantizer.distance_l2_simd(&qv1, &qv2);
        let dist_cosine = quantizer.distance_cosine_simd(&qv1, &qv2);

        assert!(dist_l2.is_finite());
        assert!(dist_cosine.is_finite());
    }

    // Smoke test: SIMD path works on large (1536D) vectors.
    #[test]
    fn test_simd_performance_improvement() {
        let quantizer = RaBitQ::default_4bit();

        // Large vectors (1536D like OpenAI embeddings)
        let v1: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
        let v2: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // Just verify SIMD works on large vectors
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);
        assert!(dist_simd > 0.0 && dist_simd.is_finite());

        // Note: Actual performance benchmarks in Phase 6
    }

    // Exercises the free-function scalar fallbacks on known inputs.
    #[test]
    fn test_scalar_distance_functions() {
        // Test the scalar fallback functions directly
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let dist = l2_distance_scalar(&v1, &v2);
        assert!((dist - 1.0).abs() < 0.001);

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let dist = cosine_distance_scalar(&v1, &v2);
        assert!((dist - 1.0).abs() < 0.001);
    }
2515
    // ADC Tests

    // ADC table has one sub-table per dimension, 2^bits entries each.
    #[test]
    fn test_adc_table_creation() {
        let quantizer = RaBitQ::default_4bit();
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let scale = 1.0;

        let adc = quantizer.build_adc_table_with_scale(&query, scale);

        // Check structure
        assert_eq!(adc.dimensions, 4);
        assert_eq!(adc.bits, 4);
        assert_eq!(adc.table.len(), 4);

        // Each dimension should have 16 codes (4-bit)
        for dim_table in &adc.table {
            assert_eq!(dim_table.len(), 16);
        }
    }

    // 2-bit quantizer produces 4-entry sub-tables.
    #[test]
    fn test_adc_table_2bit() {
        let quantizer = RaBitQ::new(RaBitQParams::bits2());
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let scale = 1.0;

        let adc = quantizer.build_adc_table_with_scale(&query, scale);

        // Each dimension should have 4 codes (2-bit)
        for dim_table in &adc.table {
            assert_eq!(dim_table.len(), 4);
        }
    }

    // ADC lookup distance agrees with the direct asymmetric computation.
    #[test]
    fn test_adc_distance_matches_asymmetric() {
        let quantizer = RaBitQ::default_4bit();

        // Create query and vector
        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];

        // Quantize the vector
        let quantized = quantizer.quantize(&vector);

        // Compute distance with asymmetric method
        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);

        // Compute distance with ADC (using build_adc_table_with_scale for untrained quantizer)
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist_adc = adc.distance(&quantized.data);

        // ADC should give similar results to asymmetric distance
        // They use different computation paths but should be close
        let diff = (dist_asymmetric - dist_adc).abs();
        assert!(
            diff < 0.1,
            "ADC vs asymmetric: {dist_adc} vs {dist_asymmetric}, diff: {diff}"
        );
    }

    // ADC distance of a vector to itself is near zero at 8-bit precision.
    #[test]
    fn test_adc_distance_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 16,
            rescale_range: (0.8, 1.2),
        });

        let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.1, 0.2, 0.3, 0.4]; // Same as query

        let quantized = quantizer.quantize(&vector);

        // Build ADC table
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        // Distance should be near zero (same vector)
        let dist = adc.distance(&quantized.data);
        assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
    }
2598
2599    #[test]
2600    fn test_adc_distance_ordering() {
2601        let quantizer = RaBitQ::default_4bit();
2602
2603        let query = vec![0.5, 0.5, 0.5, 0.5];
2604        let v1 = vec![0.5, 0.5, 0.5, 0.5]; // Closest
2605        let v2 = vec![0.6, 0.6, 0.6, 0.6]; // Medium
2606        let v3 = vec![0.9, 0.9, 0.9, 0.9]; // Farthest
2607
2608        let qv1 = quantizer.quantize(&v1);
2609        let qv2 = quantizer.quantize(&v2);
2610        let qv3 = quantizer.quantize(&v3);
2611
2612        // Build ADC tables with respective scales
2613        let adc1 = quantizer.build_adc_table_with_scale(&query, qv1.scale);
2614        let adc2 = quantizer.build_adc_table_with_scale(&query, qv2.scale);
2615        let adc3 = quantizer.build_adc_table_with_scale(&query, qv3.scale);
2616
2617        let dist1 = adc1.distance(&qv1.data);
2618        let dist2 = adc2.distance(&qv2.data);
2619        let dist3 = adc3.distance(&qv3.data);
2620
2621        // Order should be preserved
2622        assert!(
2623            dist1 < dist2,
2624            "v1 should be closer than v2: {dist1} vs {dist2}"
2625        );
2626        assert!(
2627            dist2 < dist3,
2628            "v2 should be closer than v3: {dist2} vs {dist3}"
2629        );
2630    }
2631
2632    #[test]
2633    fn test_adc_high_dimensional() {
2634        let quantizer = RaBitQ::default_4bit();
2635
2636        // 128D vectors (realistic embedding size)
2637        let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2638        let vector: Vec<f32> = (0..128).map(|i| ((i + 5) as f32) / 128.0).collect();
2639
2640        let quantized = quantizer.quantize(&vector);
2641
2642        // Build ADC table
2643        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2644
2645        // Should handle high dimensions without panic
2646        let dist = adc.distance(&quantized.data);
2647        assert!(dist > 0.0 && dist.is_finite());
2648    }
2649
2650    #[test]
2651    fn test_adc_batch_search() {
2652        let quantizer = RaBitQ::default_4bit();
2653
2654        let query = vec![0.5, 0.5, 0.5, 0.5];
2655        let candidates = [
2656            vec![0.5, 0.5, 0.5, 0.5],
2657            vec![0.6, 0.6, 0.6, 0.6],
2658            vec![0.4, 0.4, 0.4, 0.4],
2659            vec![0.7, 0.7, 0.7, 0.7],
2660        ];
2661
2662        // Quantize all candidates
2663        let quantized: Vec<QuantizedVector> =
2664            candidates.iter().map(|v| quantizer.quantize(v)).collect();
2665
2666        // Scan all candidates using ADC tables
2667        let mut results: Vec<(usize, f32)> = quantized
2668            .iter()
2669            .enumerate()
2670            .map(|(i, qv)| {
2671                let adc = quantizer.build_adc_table_with_scale(&query, qv.scale);
2672                (i, adc.distance(&qv.data))
2673            })
2674            .collect();
2675
2676        // Sort by distance
2677        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
2678
2679        // First result should be index 0 (identical to query)
2680        assert_eq!(results[0].0, 0, "Results: {results:?}");
2681    }
2682
2683    #[test]
2684    fn test_adc_distance_squared() {
2685        let quantizer = RaBitQ::default_4bit();
2686
2687        let query = vec![0.0, 0.0, 0.0];
2688        let vector = vec![1.0, 0.0, 0.0];
2689
2690        let quantized = quantizer.quantize(&vector);
2691        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2692
2693        let dist_squared = adc.distance_squared(&quantized.data);
2694        let dist = adc.distance(&quantized.data);
2695
2696        // distance_squared should be dist^2 (approximately)
2697        let diff = (dist_squared - dist * dist).abs();
2698        assert!(
2699            diff < 0.01,
2700            "distance_squared != dist^2: {} vs {}",
2701            dist_squared,
2702            dist * dist
2703        );
2704    }
2705
2706    #[test]
2707    fn test_adc_simd_matches_scalar() {
2708        let quantizer = RaBitQ::default_4bit();
2709
2710        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
2711        let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];
2712
2713        let quantized = quantizer.quantize(&vector);
2714        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2715
2716        let dist_scalar = adc.distance_squared(&quantized.data);
2717        let dist_simd = adc.distance_squared_simd(&quantized.data);
2718
2719        // SIMD should match scalar within floating point precision
2720        let diff = (dist_scalar - dist_simd).abs();
2721        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
2722    }
2723
2724    #[test]
2725    fn test_adc_simd_high_dimensional() {
2726        let quantizer = RaBitQ::default_4bit();
2727
2728        // 1536D vectors (OpenAI embeddings)
2729        let query: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
2730        let vector: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();
2731
2732        let quantized = quantizer.quantize(&vector);
2733        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2734
2735        // Should handle large dimensions efficiently
2736        let dist_simd = adc.distance_squared_simd(&quantized.data);
2737        assert!(dist_simd > 0.0 && dist_simd.is_finite());
2738    }
2739
2740    #[test]
2741    fn test_adc_memory_usage() {
2742        let quantizer = RaBitQ::default_4bit();
2743
2744        let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2745        let adc = quantizer.build_adc_table_with_scale(&query, 1.0);
2746
2747        let memory = adc.memory_bytes();
2748
2749        // For 128D, 4-bit: 128 * 16 * 4 bytes = 8KB (plus overhead)
2750        let expected_min = 128 * 16 * 4;
2751        assert!(
2752            memory >= expected_min,
2753            "Memory {memory} should be at least {expected_min}"
2754        );
2755    }
2756
    #[test]
    fn test_adc_different_scales() {
        // Smoke test: building ADC tables with arbitrary (non-stored) scale
        // factors must still produce finite distances.
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.5, 0.5, 0.5, 0.5];
        let vector = vec![0.6, 0.6, 0.6, 0.6];

        let quantized = quantizer.quantize(&vector);

        // Build ADC tables with different scales
        let adc1 = quantizer.build_adc_table_with_scale(&query, 0.5);
        let adc2 = quantizer.build_adc_table_with_scale(&query, 1.0);
        let adc3 = quantizer.build_adc_table_with_scale(&query, 2.0);

        // NOTE(review): only finiteness is asserted below — the scale
        // sensitivity itself (that distances actually differ across scales)
        // is never checked. Consider asserting e.g. dist1 != dist3 to pin
        // that behavior.
        let dist1 = adc1.distance(&quantized.data);
        let dist2 = adc2.distance(&quantized.data);
        let dist3 = adc3.distance(&quantized.data);

        // All should be valid finite numbers
        assert!(dist1.is_finite());
        assert!(dist2.is_finite());
        assert!(dist3.is_finite());
    }
2781
2782    #[test]
2783    fn test_adc_edge_cases() {
2784        let quantizer = RaBitQ::default_4bit();
2785
2786        // Test with very small vector
2787        let query = vec![0.5];
2788        let vector = vec![0.6];
2789        let quantized = quantizer.quantize(&vector);
2790        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2791        let dist = adc.distance(&quantized.data);
2792        assert!(dist.is_finite());
2793
2794        // Test with all zeros
2795        let query = vec![0.0, 0.0, 0.0, 0.0];
2796        let vector = vec![0.0, 0.0, 0.0, 0.0];
2797        let quantized = quantizer.quantize(&vector);
2798        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2799        let dist = adc.distance(&quantized.data);
2800        assert!(dist.is_finite());
2801    }
2802
2803    #[test]
2804    fn test_adc_2bit_accuracy() {
2805        let quantizer = RaBitQ::new(RaBitQParams::bits2());
2806
2807        let query = vec![0.1, 0.2, 0.3, 0.4];
2808        let vector = vec![0.12, 0.22, 0.32, 0.42];
2809
2810        let quantized = quantizer.quantize(&vector);
2811
2812        // Test ADC for 2-bit quantization
2813        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2814        let dist_adc = adc.distance(&quantized.data);
2815        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2816
2817        // Should be reasonably close despite lower precision
2818        let diff = (dist_adc - dist_asymmetric).abs();
2819        assert!(diff < 0.2, "2-bit ADC diff too large: {diff}");
2820    }
2821
2822    #[test]
2823    fn test_adc_8bit_accuracy() {
2824        let quantizer = RaBitQ::new(RaBitQParams::bits8());
2825
2826        let query = vec![0.1, 0.2, 0.3, 0.4];
2827        let vector = vec![0.12, 0.22, 0.32, 0.42];
2828
2829        let quantized = quantizer.quantize(&vector);
2830
2831        // Test ADC for 8-bit quantization (highest precision)
2832        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2833        let dist_adc = adc.distance(&quantized.data);
2834        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2835
2836        // 8-bit should be very accurate
2837        let diff = (dist_adc - dist_asymmetric).abs();
2838        assert!(
2839            diff < 0.05,
2840            "8-bit ADC should be highly accurate, diff: {diff}"
2841        );
2842    }
2843}