// omendb_core/compression/rabitq.rs

1//! `RaBitQ` Quantization for `OmenDB`
2//!
3//! `RaBitQ` (SIGMOD 2025) provides flexible vector compression with arbitrary
4//! bit rates (2-8 bits per dimension) using optimal rescaling per vector.
5//!
6//! # Tiered Compression Strategy
7//!
8//! - L0-L2 (hot): Full precision f32 (no compression)
9//! - L3-L4 (warm): `RaBitQ` 4-bit (8× compression, 98% recall)
10//! - L5-L6 (cold): `RaBitQ` 2-bit (16× compression, 95% recall)
11//!
12//! # Key Features
13//!
14//! - Flexible compression (2, 3, 4, 5, 7, 8 bits/dimension)
15//! - Optimal rescaling for each vector
16//! - SIMD-accelerated distance (AVX2/NEON)
17//! - Same query speed as scalar quantization
18//! - Better accuracy than binary quantization
19
20use serde::{Deserialize, Serialize};
21use smallvec::SmallVec;
22use std::fmt;
23
24#[cfg(target_arch = "aarch64")]
25use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
26#[cfg(target_arch = "x86_64")]
27#[allow(clippy::wildcard_imports)]
28use std::arch::x86_64::*;
29
30/// Maximum number of codes per subspace (16 for 4-bit quantization)
31const MAX_CODES: usize = 16;
32
/// Number of bits per dimension for quantization
///
/// Determines the number of quantization levels (`2^bits`, see `levels()`)
/// and the compression ratio versus f32 (see `compression_ratio()`).
///
/// NOTE(review): there is no 6-bit variant, and the module docs advertise
/// 2-8 bits while `Bits1` also exists here — confirm the intended set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationBits {
    /// 1 bit per dimension (32x compression) - Binary/BBQ
    Bits1,
    /// 2 bits per dimension (16x compression)
    Bits2,
    /// 3 bits per dimension (~10x compression)
    Bits3,
    /// 4 bits per dimension (8x compression)
    Bits4,
    /// 5 bits per dimension (~6x compression)
    Bits5,
    /// 7 bits per dimension (~4x compression)
    Bits7,
    /// 8 bits per dimension (4x compression)
    Bits8,
}
51
52impl QuantizationBits {
53    /// Convert to number of bits
54    #[must_use]
55    pub fn to_u8(self) -> u8 {
56        match self {
57            QuantizationBits::Bits1 => 1,
58            QuantizationBits::Bits2 => 2,
59            QuantizationBits::Bits3 => 3,
60            QuantizationBits::Bits4 => 4,
61            QuantizationBits::Bits5 => 5,
62            QuantizationBits::Bits7 => 7,
63            QuantizationBits::Bits8 => 8,
64        }
65    }
66
67    /// Get number of quantization levels (2^bits)
68    #[must_use]
69    pub fn levels(self) -> usize {
70        1 << self.to_u8()
71    }
72
73    /// Get compression ratio vs f32 (32 bits / `bits_per_dim`)
74    #[must_use]
75    pub fn compression_ratio(self) -> f32 {
76        32.0 / self.to_u8() as f32
77    }
78
79    /// Get number of values that fit in one byte
80    #[must_use]
81    pub fn values_per_byte(self) -> usize {
82        8 / self.to_u8() as usize
83    }
84}
85
86impl fmt::Display for QuantizationBits {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        write!(f, "{}-bit", self.to_u8())
89    }
90}
91
/// Configuration for `RaBitQ` quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQParams {
    /// Number of bits per dimension
    pub bits_per_dim: QuantizationBits,

    /// Number of rescaling factors to try (DEPRECATED: use trained quantizer)
    ///
    /// Higher values = better quantization quality but slower.
    /// Typical range: 8-16.
    pub num_rescale_factors: usize,

    /// Range of rescaling factors to try (DEPRECATED: use trained quantizer)
    ///
    /// `(min, max)` scale multipliers; e.g. (0.5, 2.0) means try scales
    /// from 0.5x to 2.0x.
    pub rescale_range: (f32, f32),
}
109
/// Trained quantization parameters computed from data
///
/// Stores per-dimension min/max values learned from a representative sample.
/// This enables consistent quantization across all vectors and correct ADC distances.
///
/// Expected invariant (upheld when built via `train`): `mins.len() ==
/// maxs.len() == dimensions` and `maxs[d] > mins[d]` for every dimension
/// (training widens degenerate ranges). Fields are public, so hand-built
/// values may not uphold this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainedParams {
    /// Minimum value per dimension
    pub mins: Vec<f32>,
    /// Maximum value per dimension
    pub maxs: Vec<f32>,
    /// Number of dimensions
    pub dimensions: usize,
}
123
124impl TrainedParams {
125    /// Train quantization parameters from sample vectors
126    ///
127    /// Computes per-dimension min/max using percentiles to exclude outliers.
128    /// Uses 1st and 99th percentiles by default for robustness.
129    ///
130    /// # Arguments
131    /// * `vectors` - Sample vectors to train from (should be representative)
132    ///
133    /// # Errors
134    /// Returns error if vectors is empty or vectors have inconsistent dimensions.
135    pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
136        Self::train_with_percentiles(vectors, 0.01, 0.99)
137    }
138
139    /// Train with custom percentile bounds
140    ///
141    /// # Arguments
142    /// * `vectors` - Sample vectors to train from
143    /// * `lower_percentile` - Lower bound percentile (e.g., 0.01 for 1st percentile)
144    /// * `upper_percentile` - Upper bound percentile (e.g., 0.99 for 99th percentile)
145    ///
146    /// # Errors
147    /// Returns error if vectors is empty or vectors have inconsistent dimensions.
148    pub fn train_with_percentiles(
149        vectors: &[&[f32]],
150        lower_percentile: f32,
151        upper_percentile: f32,
152    ) -> Result<Self, &'static str> {
153        if vectors.is_empty() {
154            return Err("Need at least one vector to train");
155        }
156        let dimensions = vectors[0].len();
157        if !vectors.iter().all(|v| v.len() == dimensions) {
158            return Err("All vectors must have same dimensions");
159        }
160
161        let n = vectors.len();
162        let lower_idx = ((n as f32 * lower_percentile) as usize).min(n - 1);
163        let upper_idx = ((n as f32 * upper_percentile) as usize).min(n - 1);
164
165        let mut mins = Vec::with_capacity(dimensions);
166        let mut maxs = Vec::with_capacity(dimensions);
167
168        // For each dimension, collect values and compute percentiles
169        let mut dim_values: Vec<f32> = Vec::with_capacity(n);
170        for d in 0..dimensions {
171            dim_values.clear();
172            for v in vectors {
173                dim_values.push(v[d]);
174            }
175            dim_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
176
177            let min_val = dim_values[lower_idx];
178            let max_val = dim_values[upper_idx];
179
180            // Ensure non-zero range (add small epsilon if needed)
181            let range = max_val - min_val;
182            if range < 1e-7 {
183                mins.push(min_val - 0.5);
184                maxs.push(max_val + 0.5);
185            } else {
186                mins.push(min_val);
187                maxs.push(max_val);
188            }
189        }
190
191        Ok(Self {
192            mins,
193            maxs,
194            dimensions,
195        })
196    }
197
198    /// Quantize a single value using trained parameters for given dimension
199    #[inline]
200    #[must_use]
201    pub fn quantize_value(&self, value: f32, dim: usize, levels: usize) -> u8 {
202        let min = self.mins[dim];
203        let max = self.maxs[dim];
204        let range = max - min;
205
206        // Map value to [0, 1] range, then to [0, levels-1]
207        let normalized = (value - min) / range;
208        let level = (normalized * (levels - 1) as f32).round();
209        level.clamp(0.0, (levels - 1) as f32) as u8
210    }
211
212    /// Dequantize a code to reconstructed value for given dimension
213    #[inline]
214    #[must_use]
215    pub fn dequantize_value(&self, code: u8, dim: usize, levels: usize) -> f32 {
216        let min = self.mins[dim];
217        let max = self.maxs[dim];
218        let range = max - min;
219
220        // Map code back to original range
221        (code as f32 / (levels - 1) as f32) * range + min
222    }
223}
224
225impl Default for RaBitQParams {
226    fn default() -> Self {
227        Self {
228            bits_per_dim: QuantizationBits::Bits4, // 8x compression
229            num_rescale_factors: 12,               // Good balance
230            rescale_range: (0.5, 2.0),             // Paper recommendation
231        }
232    }
233}
234
235impl RaBitQParams {
236    /// Create parameters for 2-bit quantization (16x compression)
237    #[must_use]
238    pub fn bits2() -> Self {
239        Self {
240            bits_per_dim: QuantizationBits::Bits2,
241            ..Default::default()
242        }
243    }
244
245    /// Create parameters for 4-bit quantization (8x compression, recommended)
246    #[must_use]
247    pub fn bits4() -> Self {
248        Self {
249            bits_per_dim: QuantizationBits::Bits4,
250            ..Default::default()
251        }
252    }
253
254    /// Create parameters for 8-bit quantization (4x compression, highest quality)
255    #[must_use]
256    pub fn bits8() -> Self {
257        Self {
258            bits_per_dim: QuantizationBits::Bits8,
259            num_rescale_factors: 16,   // More factors for higher precision
260            rescale_range: (0.7, 1.5), // Narrower range for 8-bit
261        }
262    }
263}
264
/// A quantized vector with optimal rescaling
///
/// Storage format:
/// - data: Packed quantized values (multiple values per byte)
/// - scale: Optimal rescaling factor for this vector
/// - bits: Number of bits per dimension
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Packed quantized values
    ///
    /// Format depends on `bits_per_dim`:
    /// - 2-bit: 4 values per byte
    /// - 3-bit: Not byte-aligned, needs special packing
    /// - 4-bit: 2 values per byte
    /// - 8-bit: 1 value per byte
    pub data: Vec<u8>,

    /// Optimal rescaling factor for this vector
    ///
    /// This is the scale factor that minimized quantization error
    /// during the rescaling search.
    pub scale: f32,

    /// Number of bits per dimension
    pub bits: u8,

    /// Original vector dimensions (needed to unpack `data`, since trailing
    /// bits of the last byte may be padding)
    pub dimensions: usize,
}
294
295impl QuantizedVector {
296    /// Create a new quantized vector
297    #[must_use]
298    pub fn new(data: Vec<u8>, scale: f32, bits: u8, dimensions: usize) -> Self {
299        Self {
300            data,
301            scale,
302            bits,
303            dimensions,
304        }
305    }
306
307    /// Get memory usage in bytes
308    #[must_use]
309    pub fn memory_bytes(&self) -> usize {
310        std::mem::size_of::<Self>() + self.data.len()
311    }
312
313    /// Get compression ratio vs original f32 vector
314    #[must_use]
315    pub fn compression_ratio(&self) -> f32 {
316        let original_bytes = self.dimensions * 4; // f32 = 4 bytes
317        let compressed_bytes = self.data.len() + 4 + 1; // data + scale + bits
318        original_bytes as f32 / compressed_bytes as f32
319    }
320}
321
/// `RaBitQ` quantizer
///
/// Implements scalar quantization with trained per-dimension ranges.
/// Training computes min/max per dimension from sample data, enabling
/// consistent quantization across all vectors and correct ADC distances.
///
/// # Usage
///
/// ```ignore
/// // Create and train quantizer
/// let mut quantizer = RaBitQ::new(RaBitQParams::bits4());
/// quantizer.train(&sample_vectors);
///
/// // Quantize vectors
/// let quantized = quantizer.quantize(&vector);
///
/// // ADC search (distances are mathematically correct)
/// let adc = quantizer.build_adc_table(&query);
/// let dist = adc.distance(&quantized.data);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQ {
    /// Quantization configuration (bit width, legacy rescale settings)
    params: RaBitQParams,
    /// Trained parameters (per-dimension min/max)
    /// When None, falls back to legacy per-vector scaling (deprecated)
    trained: Option<TrainedParams>,
}
349
/// Asymmetric Distance Computation (ADC) lookup table for fast quantized search
///
/// Precomputes partial squared distances from a query vector to all possible
/// quantized codes. This enables O(1) distance computation per dimension instead
/// of O(dim) decompression + distance calculation.
///
/// # Performance
///
/// - **Memory**: For 4-bit: dim * 16 * 4 bytes (e.g., 1536D = 96KB per query)
/// - **Speedup**: 5-10x faster distance computation vs full decompression
/// - **Use case**: Scanning many candidate vectors during HNSW search
///
/// # Algorithm
///
/// Instead of:
/// ```ignore
/// for each candidate:
///     decompress(candidate) -> O(dim)
///     distance(query, decompressed) -> O(dim)
/// ```
///
/// With ADC:
/// ```ignore
/// precompute table[code] = (query_value - dequantize(code))^2 for all codes
/// for each candidate:
///     sum(table[candidate[i]]) -> O(1) per dimension
/// ```
#[derive(Debug, Clone)]
pub struct ADCTable {
    /// Lookup table: `table[dim_idx][code] = partial squared distance`
    /// For 4-bit: each inner array has 16 entries (codes 0-15)
    /// For 2-bit: each inner array has 4 entries (codes 0-3)
    /// (8-bit tables hold 256 entries and spill the SmallVec to the heap)
    table: Vec<SmallVec<[f32; MAX_CODES]>>,

    /// Quantization parameters (bits per dimension)
    bits: u8,

    /// Number of dimensions
    dimensions: usize,
}
390
391impl ADCTable {
392    /// Build ADC lookup table using trained quantization parameters
393    ///
394    /// This is the production method that computes correct distances.
395    /// Uses per-dimension min/max ranges from training.
396    ///
397    /// # Arguments
398    ///
399    /// * `query` - The uncompressed query vector
400    /// * `trained` - Trained parameters with per-dimension min/max
401    /// * `params` - Quantization parameters (bits per dimension)
402    #[must_use]
403    pub fn new_trained(query: &[f32], trained: &TrainedParams, params: &RaBitQParams) -> Self {
404        let bits = params.bits_per_dim.to_u8();
405        let num_codes = params.bits_per_dim.levels();
406        let dimensions = query.len();
407
408        let mut table = Vec::with_capacity(dimensions);
409
410        // For each dimension, compute distances to all possible codes
411        for (d, &q_value) in query.iter().enumerate() {
412            let mut dim_table = SmallVec::new();
413
414            for code in 0..num_codes {
415                // Dequantize using trained min/max for this dimension
416                let reconstructed = trained.dequantize_value(code as u8, d, num_codes);
417
418                // Compute squared difference (partial L2 distance)
419                let diff = q_value - reconstructed;
420                dim_table.push(diff * diff);
421            }
422
423            table.push(dim_table);
424        }
425
426        Self {
427            table,
428            bits,
429            dimensions,
430        }
431    }
432
433    /// Build ADC lookup table for a query vector (DEPRECATED)
434    ///
435    /// For each dimension and each possible quantized code, precomputes the
436    /// squared distance contribution: (query[i] - dequantize(code, scale))^2
437    ///
438    /// WARNING: This method uses a fixed scale which produces incorrect distances
439    /// when vectors were quantized with different scales. Use `new_trained()` instead.
440    ///
441    /// # Arguments
442    ///
443    /// * `query` - The uncompressed query vector
444    /// * `scale` - The scale factor used for quantization (from training or default)
445    /// * `params` - Quantization parameters (bits per dimension)
446    ///
447    /// # Returns
448    ///
449    /// An `ADCTable` that can compute distances via `distance()` method
450    ///
451    /// # Note
452    /// Prefer `ADCTable::new_trained()` with `TrainedParams` for correct ADC distances.
453    /// This method uses per-vector scale which gives lower accuracy.
454    #[must_use]
455    pub fn new(query: &[f32], scale: f32, params: &RaBitQParams) -> Self {
456        let bits = params.bits_per_dim.to_u8();
457        let num_codes = params.bits_per_dim.levels();
458        let dimensions = query.len();
459
460        let mut table = Vec::with_capacity(dimensions);
461
462        // Dequantization factor: value = code / (levels - 1) / scale
463        let levels = num_codes as f32;
464        let dequant_factor = 1.0 / ((levels - 1.0) * scale);
465
466        // For each dimension, compute distances to all possible codes
467        for &q_value in query {
468            let mut dim_table = SmallVec::new();
469
470            for code in 0..num_codes {
471                // Dequantize the code to get the reconstructed value
472                let reconstructed = (code as f32) * dequant_factor;
473
474                // Compute squared difference (partial L2 distance)
475                let diff = q_value - reconstructed;
476                dim_table.push(diff * diff);
477            }
478
479            table.push(dim_table);
480        }
481
482        Self {
483            table,
484            bits,
485            dimensions,
486        }
487    }
488
489    /// Compute approximate L2 squared distance using lookup table
490    ///
491    /// This is the hot path for search! Instead of decompressing and computing
492    /// distance, we just sum up precomputed values from the table.
493    ///
494    /// # Performance
495    ///
496    /// - 4-bit: ~5-10x faster than decompression + distance
497    /// - Cache-friendly: sequential access patterns
498    /// - SIMD-friendly: can vectorize the summation
499    ///
500    /// # Arguments
501    ///
502    /// * `data` - Packed quantized bytes
503    ///
504    /// # Returns
505    ///
506    /// Approximate squared L2 distance (not square-rooted for efficiency)
507    #[inline]
508    #[must_use]
509    pub fn distance_squared(&self, data: &[u8]) -> f32 {
510        match self.bits {
511            4 => self.distance_squared_4bit(data),
512            2 => self.distance_squared_2bit(data),
513            8 => self.distance_squared_8bit(data),
514            _ => self.distance_squared_generic(data),
515        }
516    }
517
518    /// Compute distance and return square root (actual L2 distance)
519    ///
520    /// Uses SIMD-accelerated distance computation (AVX2 on `x86_64`, NEON on aarch64).
521    #[inline]
522    #[must_use]
523    pub fn distance(&self, data: &[u8]) -> f32 {
524        self.distance_squared_simd(data).sqrt()
525    }
526
    /// Fast path for 4-bit quantization (most common case)
    ///
    /// Splits each packed byte into high/low nibbles and accumulates the
    /// precomputed contributions for the two dimensions it encodes.
    ///
    /// # Safety invariants (maintained by `ADCTable::new`)
    /// - `self.table.len() == self.dimensions`
    /// - Each `table[i]` has exactly 16 entries (4-bit = 2^4 codes)
    /// - Input `data` has `ceil(dimensions/2)` bytes (2 values per byte)
    #[inline]
    fn distance_squared_4bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;
        let num_pairs = self.dimensions / 2;

        // Process pairs of dimensions (2 codes per byte)
        for i in 0..num_pairs {
            // Defensive: truncated `data` ends the scan early (partial sum).
            if i >= data.len() {
                break;
            }

            // SAFETY: i < num_pairs <= data.len() (checked above)
            let byte = unsafe { *data.get_unchecked(i) };
            let code_hi = (byte >> 4) as usize; // 0..=15
            let code_lo = (byte & 0x0F) as usize; // 0..=15

            // SAFETY:
            // - i*2 < dimensions (since i < num_pairs = dimensions/2)
            // - i*2+1 < dimensions (same reasoning)
            // - code_hi, code_lo in 0..16 (4-bit mask guarantees this)
            // - table has 16 entries per dimension (4-bit quantization)
            sum += unsafe {
                *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                    + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo)
            };
        }

        // Handle odd dimension count: the final value lives in the high
        // nibble of the last byte.
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            // SAFETY: num_pairs < data.len() checked above
            let byte = unsafe { *data.get_unchecked(num_pairs) };
            let code_hi = (byte >> 4) as usize; // 0..=15
            // SAFETY:
            // - dimensions-1 is valid index (dimensions >= 1 when odd)
            // - code_hi in 0..16 (4-bit mask)
            sum += unsafe {
                *self
                    .table
                    .get_unchecked(self.dimensions - 1)
                    .get_unchecked(code_hi)
            };
        }

        sum
    }
578
    /// Fast path for 2-bit quantization
    ///
    /// Each byte packs four 2-bit codes, least-significant bits first:
    /// dimension `i*4` is in bits 0-1 and `i*4+3` in bits 6-7.
    ///
    /// # Safety invariants (maintained by `ADCTable::new`)
    /// - `self.table.len() == self.dimensions`
    /// - Each `table[i]` has exactly 4 entries (2-bit = 2^2 codes)
    /// - Input `data` has `ceil(dimensions/4)` bytes (4 values per byte)
    #[inline]
    fn distance_squared_2bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;
        let num_quads = self.dimensions / 4;

        // Process quads of dimensions (4 codes per byte)
        for i in 0..num_quads {
            // Defensive: truncated `data` ends the scan early (partial sum).
            if i >= data.len() {
                break;
            }

            // SAFETY: i < num_quads <= data.len() (checked above)
            let byte = unsafe { *data.get_unchecked(i) };

            // SAFETY:
            // - i*4+k < dimensions for k in 0..4 (since i < num_quads = dimensions/4)
            // - all codes in 0..4 (2-bit mask guarantees this)
            // - table has 4 entries per dimension (2-bit quantization)
            sum += unsafe {
                *self
                    .table
                    .get_unchecked(i * 4)
                    .get_unchecked((byte & 0b11) as usize)
                    + *self
                        .table
                        .get_unchecked(i * 4 + 1)
                        .get_unchecked(((byte >> 2) & 0b11) as usize)
                    + *self
                        .table
                        .get_unchecked(i * 4 + 2)
                        .get_unchecked(((byte >> 4) & 0b11) as usize)
                    + *self
                        .table
                        .get_unchecked(i * 4 + 3)
                        .get_unchecked(((byte >> 6) & 0b11) as usize)
            };
        }

        // Handle remainder: the last 1-3 dimensions occupy the low bits of
        // the final byte.
        let remaining = self.dimensions % 4;
        if remaining > 0 && num_quads < data.len() {
            // SAFETY: num_quads < data.len() checked above
            let byte = unsafe { *data.get_unchecked(num_quads) };
            for j in 0..remaining {
                let code = ((byte >> (j * 2)) & 0b11) as usize; // 0..=3
                // SAFETY:
                // - num_quads*4+j < dimensions (j < remaining = dimensions%4)
                // - code in 0..4 (2-bit mask)
                sum += unsafe {
                    *self
                        .table
                        .get_unchecked(num_quads * 4 + j)
                        .get_unchecked(code)
                };
            }
        }

        sum
    }
644
    /// Fast path for 8-bit quantization
    ///
    /// One code per byte, so `data[i]` indexes dimension `i`'s table directly.
    ///
    /// # Safety invariants (maintained by `ADCTable::new`)
    /// - `self.table.len() == self.dimensions`
    /// - Each `table[i]` has exactly 256 entries (8-bit = 2^8 codes)
    /// - Input `data` has `dimensions` bytes (1 value per byte)
    #[inline]
    fn distance_squared_8bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;

        for (i, &byte) in data.iter().enumerate().take(self.dimensions) {
            // SAFETY:
            // - i < dimensions (take() ensures this)
            // - byte as usize in 0..256 (u8 range)
            // - table has 256 entries per dimension (8-bit quantization)
            sum += unsafe { *self.table.get_unchecked(i).get_unchecked(byte as usize) };
        }

        sum
    }
665
666    /// Generic fallback for other bit widths
667    #[inline]
668    fn distance_squared_generic(&self, data: &[u8]) -> f32 {
669        // For non-standard bit widths, fall back to bounds-checked access
670        let mut sum = 0.0f32;
671
672        for (i, dim_table) in self.table.iter().enumerate() {
673            if i >= data.len() {
674                break;
675            }
676            let code = data[i] as usize;
677            if let Some(&dist) = dim_table.get(code) {
678                sum += dist;
679            }
680        }
681
682        sum
683    }
684
    /// SIMD-accelerated distance computation for 4-bit quantization
    ///
    /// Uses AVX2 on `x86_64` (runtime-detected) or NEON on `aarch64` to batch
    /// the per-pair sums. All other bit widths and architectures fall back to
    /// the scalar implementations.
    #[inline]
    #[must_use]
    pub fn distance_squared_simd(&self, data: &[u8]) -> f32 {
        match self.bits {
            4 => {
                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx2") {
                        // SAFETY: AVX2 availability verified by the runtime check above
                        unsafe { self.distance_squared_4bit_avx2(data) }
                    } else {
                        // x86_64 without AVX2: fall back to scalar
                        self.distance_squared_4bit(data)
                    }
                }
                #[cfg(target_arch = "aarch64")]
                {
                    // NEON is always available on aarch64
                    unsafe { self.distance_squared_4bit_neon(data) }
                }
                #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
                {
                    // Other architectures fallback to scalar
                    self.distance_squared_4bit(data)
                }
            }
            2 => {
                // For 2-bit, scalar is already quite fast
                self.distance_squared_2bit(data)
            }
            8 => {
                // For 8-bit, could use SIMD gather but scalar is reasonable
                self.distance_squared_8bit(data)
            }
            _ => self.distance_squared_generic(data),
        }
    }
725
    /// AVX2 implementation for 4-bit ADC distance
    ///
    /// Gathers 8 bytes (16 nibble codes) per iteration into a stack array and
    /// accumulates them as one 8-lane f32 vector, then finishes the tail and
    /// any odd dimension in scalar.
    ///
    /// NOTE(review): if `data` is shorter than `ceil(dimensions/2)` bytes, the
    /// chunked loop breaks early but the scalar tail restarts at `chunks * 8`,
    /// so pairs between the break point and `chunks * 8` are skipped — the
    /// scalar `distance_squared_4bit` would still sum those. Only differs on
    /// malformed (truncated) input; confirm whether that can occur.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[target_feature(enable = "fma")]
    unsafe fn distance_squared_4bit_avx2(&self, data: &[u8]) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let num_pairs = self.dimensions / 2;

        // Process 8 pairs (16 dimensions) at a time using AVX2
        let chunks = num_pairs / 8;
        for chunk_idx in 0..chunks {
            let byte_idx = chunk_idx * 8;
            if byte_idx + 8 > data.len() {
                break;
            }

            // Load 8 bytes (16 4-bit codes); table lookups stay scalar, the
            // pairwise sums are then added as a single 8-lane vector.
            let mut values = [0.0f32; 8];
            for (i, value) in values.iter_mut().enumerate() {
                let byte = *data.get_unchecked(byte_idx + i);
                let code_hi = (byte >> 4) as usize;
                let code_lo = (byte & 0x0F) as usize;

                let dist_hi = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2)
                    .get_unchecked(code_hi);
                let dist_lo = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2 + 1)
                    .get_unchecked(code_lo);
                *value = dist_hi + dist_lo;
            }

            let vec = _mm256_loadu_ps(values.as_ptr());
            sum = _mm256_add_ps(sum, vec);
        }

        // Horizontal sum of AVX2 register
        let mut result = horizontal_sum_avx2(sum);

        // Handle remainder with scalar
        for i in (chunks * 8)..num_pairs {
            if i >= data.len() {
                break;
            }
            let byte = *data.get_unchecked(i);
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;

            result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
        }

        // Handle odd dimension (last value in the high nibble of the final byte)
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = *data.get_unchecked(num_pairs);
            let code_hi = (byte >> 4) as usize;
            result += *self
                .table
                .get_unchecked(self.dimensions - 1)
                .get_unchecked(code_hi);
        }

        result
    }
792
    /// NEON implementation for 4-bit ADC distance
    ///
    /// Same structure as the AVX2 path but with 4-lane f32 vectors: gather
    /// 4 bytes (8 nibble codes) per iteration, then finish in scalar.
    ///
    /// NOTE(review): shares the AVX2 path's behavior on truncated `data` — an
    /// early break in the chunked loop skips pairs up to `chunks * 4` that
    /// the scalar `distance_squared_4bit` would still sum. Only matters for
    /// malformed input; confirm whether that can occur.
    #[cfg(target_arch = "aarch64")]
    unsafe fn distance_squared_4bit_neon(&self, data: &[u8]) -> f32 {
        let mut sum = vdupq_n_f32(0.0);
        let num_pairs = self.dimensions / 2;

        // Process 4 pairs (8 dimensions) at a time using NEON
        let chunks = num_pairs / 4;
        for chunk_idx in 0..chunks {
            let byte_idx = chunk_idx * 4;
            if byte_idx + 4 > data.len() {
                break;
            }

            let mut values = [0.0f32; 4];
            for (i, value) in values.iter_mut().enumerate() {
                let byte = *data.get_unchecked(byte_idx + i);
                let code_hi = (byte >> 4) as usize;
                let code_lo = (byte & 0x0F) as usize;

                let dist_hi = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2)
                    .get_unchecked(code_hi);
                let dist_lo = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2 + 1)
                    .get_unchecked(code_lo);
                *value = dist_hi + dist_lo;
            }

            let vec = vld1q_f32(values.as_ptr());
            sum = vaddq_f32(sum, vec);
        }

        // Horizontal add of the 4 lanes
        let mut result = vaddvq_f32(sum);

        // Handle remainder with scalar
        for i in (chunks * 4)..num_pairs {
            if i >= data.len() {
                break;
            }
            let byte = *data.get_unchecked(i);
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;

            result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
        }

        // Handle odd dimension (last value in the high nibble of the final byte)
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = *data.get_unchecked(num_pairs);
            let code_hi = (byte >> 4) as usize;
            result += *self
                .table
                .get_unchecked(self.dimensions - 1)
                .get_unchecked(code_hi);
        }

        result
    }
855
    /// Get bits per dimension
    #[must_use]
    pub fn bits(&self) -> u8 {
        self.bits
    }

    /// Get number of dimensions
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Get partial distance for a dimension and code
    ///
    /// Bounds-checked convenience accessor; returns 0.0 if either index is
    /// out of bounds.
    #[must_use]
    pub fn get(&self, dim: usize, code: usize) -> f32 {
        self.table
            .get(dim)
            .and_then(|t| t.get(code))
            .copied()
            .unwrap_or(0.0)
    }
879
880    /// Get memory usage in bytes
881    #[must_use]
882    pub fn memory_bytes(&self) -> usize {
883        std::mem::size_of::<Self>()
884            + self.table.len() * std::mem::size_of::<SmallVec<[f32; MAX_CODES]>>()
885            + self.table.iter().map(|t| t.len() * 4).sum::<usize>()
886    }
887}
888
889impl RaBitQ {
890    /// Create a new `RaBitQ` quantizer (untrained)
891    ///
892    /// Call `train()` before use to enable correct ADC distances.
893    #[must_use]
894    pub fn new(params: RaBitQParams) -> Self {
895        Self {
896            params,
897            trained: None,
898        }
899    }
900
901    /// Create a trained `RaBitQ` quantizer
902    #[must_use]
903    pub fn new_trained(params: RaBitQParams, trained: TrainedParams) -> Self {
904        Self {
905            params,
906            trained: Some(trained),
907        }
908    }
909
910    /// Create with default 4-bit quantization
911    #[must_use]
912    pub fn default_4bit() -> Self {
913        Self::new(RaBitQParams::bits4())
914    }
915
    /// Get quantization parameters
    ///
    /// Borrowed view of the configuration this quantizer was built with.
    #[must_use]
    pub fn params(&self) -> &RaBitQParams {
        &self.params
    }
921
    /// Check if quantizer has been trained
    ///
    /// Training (via `train`/`train_owned` or `new_trained`) is what enables
    /// the correct per-dimension ADC distance paths; untrained quantizers
    /// fall back to deprecated per-vector scaling.
    #[must_use]
    pub fn is_trained(&self) -> bool {
        self.trained.is_some()
    }
927
    /// Get trained parameters (if any)
    ///
    /// `None` until training succeeds or the quantizer was constructed
    /// with `new_trained`.
    #[must_use]
    pub fn trained_params(&self) -> Option<&TrainedParams> {
        self.trained.as_ref()
    }
933
934    /// Train quantizer on sample vectors
935    ///
936    /// Computes per-dimension min/max ranges from the sample.
937    /// Must be called before quantization for correct ADC distances.
938    ///
939    /// # Arguments
940    /// * `vectors` - Representative sample of vectors to train from
941    ///
942    /// # Errors
943    /// Returns error if vectors is empty or have inconsistent dimensions.
944    pub fn train(&mut self, vectors: &[&[f32]]) -> Result<(), &'static str> {
945        self.trained = Some(TrainedParams::train(vectors)?);
946        Ok(())
947    }
948
949    /// Train with owned vectors (convenience method)
950    ///
951    /// # Errors
952    /// Returns error if vectors is empty or have inconsistent dimensions.
953    pub fn train_owned(&mut self, vectors: &[Vec<f32>]) -> Result<(), &'static str> {
954        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
955        self.train(&refs)
956    }
957
958    /// Quantize a vector using trained parameters
959    ///
960    /// If trained: uses per-dimension min/max for consistent quantization
961    /// If untrained: falls back to legacy per-vector scaling (deprecated)
962    #[must_use]
963    pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
964        // Use trained quantization if available
965        if let Some(ref trained) = self.trained {
966            return self.quantize_trained(vector, trained);
967        }
968
969        // Legacy fallback: per-vector scale search (deprecated)
970        let mut best_error = f32::MAX;
971        let mut best_quantized = Vec::new();
972        let mut best_scale = 1.0;
973
974        // Generate rescaling factors to try
975        let scales = self.generate_scales();
976
977        // Try each scale and find the one with minimum error
978        for scale in scales {
979            let quantized = self.quantize_with_scale(vector, scale);
980            let error = self.compute_error(vector, &quantized, scale);
981
982            if error < best_error {
983                best_error = error;
984                best_quantized = quantized;
985                best_scale = scale;
986            }
987        }
988
989        QuantizedVector::new(
990            best_quantized,
991            best_scale,
992            self.params.bits_per_dim.to_u8(),
993            vector.len(),
994        )
995    }
996
997    /// Quantize using trained per-dimension min/max ranges
998    ///
999    /// This is the production path that enables correct ADC distances.
1000    fn quantize_trained(&self, vector: &[f32], trained: &TrainedParams) -> QuantizedVector {
1001        let bits = self.params.bits_per_dim.to_u8();
1002        let levels = self.params.bits_per_dim.levels();
1003
1004        // Quantize each dimension using trained min/max
1005        let quantized: Vec<u8> = vector
1006            .iter()
1007            .enumerate()
1008            .map(|(d, &v)| trained.quantize_value(v, d, levels))
1009            .collect();
1010
1011        // Pack into bytes
1012        let packed = self.pack_quantized(&quantized, bits);
1013
1014        // Scale=1.0 for trained quantization (min/max handles the range)
1015        QuantizedVector::new(packed, 1.0, bits, vector.len())
1016    }
1017
1018    /// Generate rescaling factors to try
1019    ///
1020    /// Returns a vector of scale factors evenly spaced between
1021    /// `rescale_range.0` and `rescale_range.1`
1022    fn generate_scales(&self) -> Vec<f32> {
1023        let (min_scale, max_scale) = self.params.rescale_range;
1024        let n = self.params.num_rescale_factors;
1025
1026        if n == 1 {
1027            return vec![f32::midpoint(min_scale, max_scale)];
1028        }
1029
1030        let step = (max_scale - min_scale) / (n - 1) as f32;
1031        (0..n).map(|i| min_scale + i as f32 * step).collect()
1032    }
1033
1034    /// Quantize a vector with a specific scale factor
1035    ///
1036    /// Algorithm (Extended RaBitQ):
1037    /// 1. Scale: v' = v * scale
1038    /// 2. Quantize to grid: q = round(v' * (2^bits - 1))
1039    /// 3. Clamp to valid range
1040    /// 4. Pack into bytes
1041    fn quantize_with_scale(&self, vector: &[f32], scale: f32) -> Vec<u8> {
1042        let bits = self.params.bits_per_dim.to_u8();
1043        let levels = self.params.bits_per_dim.levels() as f32;
1044        let max_level = (levels - 1.0) as u8;
1045
1046        // Scale and quantize directly (no normalization needed)
1047        let quantized: Vec<u8> = vector
1048            .iter()
1049            .map(|&v| {
1050                // Scale the value
1051                let scaled = v * scale;
1052                // Quantize to grid [0, levels-1]
1053                let level = (scaled * (levels - 1.0)).round();
1054                // Clamp to valid range
1055                level.clamp(0.0, max_level as f32) as u8
1056            })
1057            .collect();
1058
1059        // Pack into bytes
1060        self.pack_quantized(&quantized, bits)
1061    }
1062
1063    /// Pack quantized values into bytes
1064    ///
1065    /// Packing depends on bits per dimension:
1066    /// - 2-bit: 4 values per byte (00 00 00 00)
1067    /// - 4-bit: 2 values per byte (0000 0000)
1068    /// - 8-bit: 1 value per byte
1069    #[allow(clippy::unused_self)]
1070    fn pack_quantized(&self, values: &[u8], bits: u8) -> Vec<u8> {
1071        match bits {
1072            2 => {
1073                // 4 values per byte
1074                let mut packed = Vec::with_capacity(values.len().div_ceil(4));
1075                for chunk in values.chunks(4) {
1076                    let mut byte = 0u8;
1077                    for (i, &val) in chunk.iter().enumerate() {
1078                        byte |= (val & 0b11) << (i * 2);
1079                    }
1080                    packed.push(byte);
1081                }
1082                packed
1083            }
1084            4 => {
1085                // 2 values per byte
1086                let mut packed = Vec::with_capacity(values.len().div_ceil(2));
1087                for chunk in values.chunks(2) {
1088                    let byte = if chunk.len() == 2 {
1089                        (chunk[0] << 4) | (chunk[1] & 0x0F)
1090                    } else {
1091                        chunk[0] << 4
1092                    };
1093                    packed.push(byte);
1094                }
1095                packed
1096            }
1097            8 => {
1098                // 1 value per byte (no packing needed)
1099                values.to_vec()
1100            }
1101            _ => {
1102                // 3, 5, 7-bit: fall back to 8-bit storage
1103                // Not implementing proper bit-packing because:
1104                // - Public API only exposes 2, 4, 8-bit (see python/src/lib.rs)
1105                // - Cross-byte packing is complex with marginal compression benefit
1106                // - 4-bit (8x) vs 5-bit (~6x) isn't worth the code complexity
1107                values.to_vec()
1108            }
1109        }
1110    }
1111
1112    /// Unpack quantized bytes into individual values
1113    #[must_use]
1114    pub fn unpack_quantized(&self, packed: &[u8], bits: u8, dimensions: usize) -> Vec<u8> {
1115        match bits {
1116            2 => {
1117                // 4 values per byte
1118                let mut values = Vec::with_capacity(dimensions);
1119                for &byte in packed {
1120                    for i in 0..4 {
1121                        if values.len() < dimensions {
1122                            values.push((byte >> (i * 2)) & 0b11);
1123                        }
1124                    }
1125                }
1126                values
1127            }
1128            4 => {
1129                // 2 values per byte
1130                let mut values = Vec::with_capacity(dimensions);
1131                for &byte in packed {
1132                    values.push(byte >> 4);
1133                    if values.len() < dimensions {
1134                        values.push(byte & 0x0F);
1135                    }
1136                }
1137                values.truncate(dimensions);
1138                values
1139            }
1140            8 => {
1141                // 1 value per byte
1142                packed[..dimensions.min(packed.len())].to_vec()
1143            }
1144            _ => {
1145                // For other bit widths, assume 8-bit storage
1146                packed[..dimensions.min(packed.len())].to_vec()
1147            }
1148        }
1149    }
1150
1151    /// Compute quantization error (reconstruction error)
1152    ///
1153    /// Error = ||original - reconstructed||²
1154    fn compute_error(&self, original: &[f32], quantized: &[u8], scale: f32) -> f32 {
1155        let reconstructed = self.reconstruct(quantized, scale, original.len());
1156
1157        original
1158            .iter()
1159            .zip(reconstructed.iter())
1160            .map(|(o, r)| (o - r).powi(2))
1161            .sum()
1162    }
1163
1164    /// Reconstruct (dequantize) a quantized vector
1165    ///
1166    /// Algorithm (Extended RaBitQ):
1167    /// 1. Unpack bytes to quantized values [0, 2^bits-1]
1168    /// 2. Denormalize: v' = q / (2^bits - 1)
1169    /// 3. Unscale: v = v' / scale
1170    #[must_use]
1171    pub fn reconstruct(&self, quantized: &[u8], scale: f32, dimensions: usize) -> Vec<f32> {
1172        let bits = self.params.bits_per_dim.to_u8();
1173        let levels = self.params.bits_per_dim.levels() as f32;
1174
1175        // Unpack bytes
1176        let values = self.unpack_quantized(quantized, bits, dimensions);
1177
1178        // Dequantize: reverse the quantization process
1179        values
1180            .iter()
1181            .map(|&q| {
1182                // Denormalize from [0, levels-1] to [0, 1]
1183                let denorm = q as f32 / (levels - 1.0);
1184                // Unscale
1185                denorm / scale
1186            })
1187            .collect()
1188    }
1189
1190    /// Compute L2 (Euclidean) distance between two quantized vectors
1191    ///
1192    /// This reconstructs both vectors and computes standard L2 distance.
1193    /// For maximum accuracy, use this with original vectors for reranking.
1194    #[must_use]
1195    pub fn distance_l2(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1196        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1197        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1198
1199        v1.iter()
1200            .zip(v2.iter())
1201            .map(|(a, b)| (a - b).powi(2))
1202            .sum::<f32>()
1203            .sqrt()
1204    }
1205
1206    /// Compute cosine distance between two quantized vectors
1207    ///
1208    /// Cosine distance = 1 - cosine similarity
1209    #[must_use]
1210    pub fn distance_cosine(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1211        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1212        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1213
1214        let dot: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
1215        let norm1: f32 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
1216        let norm2: f32 = v2.iter().map(|b| b * b).sum::<f32>().sqrt();
1217
1218        if norm1 < 1e-10 || norm2 < 1e-10 {
1219            return 1.0; // Maximum distance for zero vectors
1220        }
1221
1222        let cosine_sim = dot / (norm1 * norm2);
1223        1.0 - cosine_sim
1224    }
1225
1226    /// Compute dot product between two quantized vectors
1227    #[must_use]
1228    pub fn distance_dot(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1229        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1230        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1231
1232        // Return negative dot product (for nearest neighbor search)
1233        -v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum::<f32>()
1234    }
1235
1236    /// Compute approximate distance using quantized values directly (fast path)
1237    ///
1238    /// This computes distance in the quantized space without full reconstruction.
1239    /// Faster but less accurate than `distance_l2`.
1240    #[must_use]
1241    pub fn distance_approximate(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1242        // Unpack to quantized values (u8)
1243        let v1 = self.unpack_quantized(&qv1.data, qv1.bits, qv1.dimensions);
1244        let v2 = self.unpack_quantized(&qv2.data, qv2.bits, qv2.dimensions);
1245
1246        // Compute L2 distance in quantized space
1247        v1.iter()
1248            .zip(v2.iter())
1249            .map(|(a, b)| {
1250                let diff = (*a as i16 - *b as i16) as f32;
1251                diff * diff
1252            })
1253            .sum::<f32>()
1254            .sqrt()
1255    }
1256
1257    /// Compute asymmetric L2 distance (query vs quantized) without full reconstruction
1258    ///
1259    /// This is the hot path for search! It unpacks quantized values on the fly and
1260    /// computes distance against the uncompressed query vector.
1261    ///
1262    /// When trained: Uses per-dimension min/max for correct distance computation.
1263    /// When untrained: Falls back to per-vector scale (deprecated, lower accuracy).
1264    #[must_use]
1265    pub fn distance_asymmetric_l2(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
1266        // Use trained parameters when available for correct distance computation
1267        if let Some(trained) = &self.trained {
1268            return self.distance_asymmetric_l2_trained(query, &quantized.data, trained);
1269        }
1270        // Fallback to per-vector scale (deprecated, only for untrained quantizers)
1271        self.distance_asymmetric_l2_raw(query, &quantized.data, quantized.scale, quantized.bits)
1272    }
1273
1274    /// Asymmetric L2 distance from flat storage (no QuantizedVector wrapper)
1275    ///
1276    /// This is the preferred method for flat contiguous storage layouts.
1277    /// When trained: Uses per-dimension min/max (scale ignored).
1278    /// When untrained: Falls back to per-vector scale.
1279    #[must_use]
1280    #[inline]
1281    pub fn distance_asymmetric_l2_flat(&self, query: &[f32], data: &[u8], scale: f32) -> f32 {
1282        if let Some(trained) = &self.trained {
1283            return self.distance_asymmetric_l2_trained(query, data, trained);
1284        }
1285        // Fallback to per-vector scale (deprecated, only for untrained quantizers)
1286        self.distance_asymmetric_l2_raw(query, data, scale, self.params.bits_per_dim.to_u8())
1287    }
1288
    /// Asymmetric L2 distance using trained per-dimension parameters
    ///
    /// This is the correct distance computation for trained quantizers.
    /// Each dimension uses its own min/max range for accurate dequantization.
    ///
    /// Returns `f32::MAX` when `data` is too short to hold `query.len()`
    /// codes at the configured bit width (truncated or corrupt storage).
    ///
    /// NOTE(review): indexes `trained.mins`/`trained.maxs` by dimension, so
    /// it assumes the trained arrays cover at least `query.len()` dimensions;
    /// shorter arrays would panic. Confirm callers guarantee this.
    #[must_use]
    fn distance_asymmetric_l2_trained(
        &self,
        query: &[f32],
        data: &[u8],
        trained: &TrainedParams,
    ) -> f32 {
        let levels = self.params.bits_per_dim.levels() as f32;
        let bits = self.params.bits_per_dim.to_u8();

        // Use SmallVec for stack allocation when possible (heap only > 256 dims)
        let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

        match bits {
            4 => {
                // Two 4-bit codes per byte, high nibble first (matches `pack_quantized`)
                let num_pairs = query.len() / 2;
                if data.len() < query.len().div_ceil(2) {
                    return f32::MAX;
                }

                for i in 0..num_pairs {
                    // SAFETY: i < num_pairs <= div_ceil(query.len(), 2) <= data.len(),
                    // verified by the length check above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    let d0 = i * 2;
                    let d1 = i * 2 + 1;

                    // Dequantize using per-dimension min/max
                    let code0 = (byte >> 4) as f32;
                    let code1 = (byte & 0x0F) as f32;

                    let range0 = trained.maxs[d0] - trained.mins[d0];
                    let range1 = trained.maxs[d1] - trained.mins[d1];

                    buffer.push((code0 / (levels - 1.0)) * range0 + trained.mins[d0]);
                    buffer.push((code1 / (levels - 1.0)) * range1 + trained.mins[d1]);
                }

                // Odd dimension count: the final code sits alone in the high nibble
                if !query.len().is_multiple_of(2) {
                    // SAFETY: for odd query.len(), num_pairs < div_ceil(query.len(), 2)
                    // <= data.len(), so this index is in bounds.
                    let byte = unsafe { *data.get_unchecked(num_pairs) };
                    let d = num_pairs * 2;
                    let code = (byte >> 4) as f32;
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
            2 => {
                // Four 2-bit codes per byte, least-significant pair first
                let num_quads = query.len() / 4;
                if data.len() < query.len().div_ceil(4) {
                    return f32::MAX;
                }

                for i in 0..num_quads {
                    // SAFETY: i < num_quads <= data.len(), verified above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    for j in 0..4 {
                        let d = i * 4 + j;
                        let code = ((byte >> (j * 2)) & 0b11) as f32;
                        let range = trained.maxs[d] - trained.mins[d];
                        buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                    }
                }

                // Trailing 1-3 codes share the final byte
                let remaining = query.len() % 4;
                if remaining > 0 {
                    // SAFETY: when remaining > 0, num_quads < div_ceil(query.len(), 4)
                    // <= data.len(), so this index is in bounds.
                    let byte = unsafe { *data.get_unchecked(num_quads) };
                    for j in 0..remaining {
                        let d = num_quads * 4 + j;
                        let code = ((byte >> (j * 2)) & 0b11) as f32;
                        let range = trained.maxs[d] - trained.mins[d];
                        buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                    }
                }
            }
            8 => {
                // One code per byte: no unpacking needed
                if data.len() < query.len() {
                    return f32::MAX;
                }
                for (d, &byte) in data.iter().enumerate().take(query.len()) {
                    let code = byte as f32;
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
            _ => {
                // Generic fallback for other bit widths (allocates via unpack)
                let unpacked = self.unpack_quantized(data, bits, query.len());
                for (d, &code) in unpacked.iter().enumerate().take(query.len()) {
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code as f32 / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
        }

        simd_l2_distance(query, &buffer)
    }
1386
    /// Build an ADC (Asymmetric Distance Computation) lookup table for a query
    ///
    /// Returns `Some(table)` only when the quantizer has been trained; the
    /// table then uses per-dimension min/max for correct distances.
    /// Returns `None` when untrained — use `build_adc_table_with_scale`
    /// (deprecated, lower accuracy) in that case.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Preferred: train first, then build ADC table
    /// quantizer.train(&sample_vectors)?;
    /// let adc_table = quantizer.build_adc_table(&query).expect("trained");
    /// for candidate in candidates {
    ///     let dist = adc_table.distance(&candidate.data);
    /// }
    /// ```
    #[must_use]
    pub fn build_adc_table(&self, query: &[f32]) -> Option<ADCTable> {
        self.trained
            .as_ref()
            .map(|trained| ADCTable::new_trained(query, trained, &self.params))
    }
1408
    /// Build ADC table with explicit scale
    ///
    /// Works without training, unlike `build_adc_table`.
    ///
    /// # Note
    /// Prefer `build_adc_table()` on a trained quantizer for correct ADC distances.
    /// This method uses the deprecated per-vector scale which gives lower accuracy.
    #[must_use]
    pub fn build_adc_table_with_scale(&self, query: &[f32], scale: f32) -> ADCTable {
        ADCTable::new(query, scale, &self.params)
    }
1418
1419    /// Compute distance using ADC table (convenience wrapper)
1420    ///
1421    /// Returns None if quantizer is not trained.
1422    #[must_use]
1423    pub fn distance_with_adc(&self, query: &[f32], quantized: &QuantizedVector) -> Option<f32> {
1424        let adc = self.build_adc_table(query)?;
1425        Some(adc.distance(&quantized.data))
1426    }
1427
    /// Low-level asymmetric distance computation on raw bytes
    ///
    /// Enables zero-copy access from mmap (no `QuantizedVector` allocation needed).
    /// Uses `SmallVec` to unpack on stack and SIMD for distance computation.
    ///
    /// Dequantizes with the deprecated per-vector `scale` (value = code * factor);
    /// trained quantizers should go through `distance_asymmetric_l2` instead.
    /// Returns `f32::MAX` when `data` is too short for `query.len()` codes.
    #[must_use]
    pub fn distance_asymmetric_l2_raw(
        &self,
        query: &[f32],
        data: &[u8],
        scale: f32,
        bits: u8,
    ) -> f32 {
        let levels = self.params.bits_per_dim.levels() as f32;

        // Dequantization factor: value = q * factor
        // derived from: q / (levels - 1) / scale
        // So factor = 1.0 / ((levels - 1.0) * scale)
        let factor = 1.0 / ((levels - 1.0) * scale);

        match bits {
            4 => {
                // Unpack to stack buffer (up to 256 dims = 1KB). Falls back to heap for larger.
                let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

                let num_pairs = query.len() / 2;

                // Check bounds once to avoid checks in loop
                if data.len() < query.len().div_ceil(2) {
                    // Fallback if data is truncated (shouldn't happen in valid storage)
                    return f32::MAX;
                }

                for i in 0..num_pairs {
                    // SAFETY: i < num_pairs <= div_ceil(query.len(), 2) <= data.len(),
                    // verified by the bounds check above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    buffer.push((byte >> 4) as f32 * factor);
                    buffer.push((byte & 0x0F) as f32 * factor);
                }

                // Odd dimension count: the final code occupies the high nibble
                if !query.len().is_multiple_of(2) {
                    // SAFETY: for odd query.len(), num_pairs < div_ceil(query.len(), 2)
                    // <= data.len(), so this index is in bounds.
                    let byte = unsafe { *data.get_unchecked(num_pairs) };
                    buffer.push((byte >> 4) as f32 * factor);
                }

                simd_l2_distance(query, &buffer)
            }
            2 => {
                let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
                let num_quads = query.len() / 4;

                if data.len() < query.len().div_ceil(4) {
                    return f32::MAX;
                }

                for i in 0..num_quads {
                    // SAFETY: i < num_quads <= data.len(), verified above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    buffer.push((byte & 0b11) as f32 * factor);
                    buffer.push(((byte >> 2) & 0b11) as f32 * factor);
                    buffer.push(((byte >> 4) & 0b11) as f32 * factor);
                    buffer.push(((byte >> 6) & 0b11) as f32 * factor);
                }

                // Handle remainder (1-3 trailing codes share the final byte)
                let remaining = query.len() % 4;
                if remaining > 0 {
                    // SAFETY: when remaining > 0, num_quads < div_ceil(query.len(), 4)
                    // <= data.len(), so this index is in bounds.
                    let byte = unsafe { *data.get_unchecked(num_quads) };
                    for i in 0..remaining {
                        buffer.push(((byte >> (i * 2)) & 0b11) as f32 * factor);
                    }
                }

                simd_l2_distance(query, &buffer)
            }
            _ => {
                // Generic fallback using existing unpack (allocates Vec if > 256, but correct)
                // Actually unpack_quantized returns Vec<u8>, so this path allocates.
                // That's fine for non-optimized bit widths (includes 8-bit).
                let unpacked = self.unpack_quantized(data, bits, query.len());
                let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

                for &q in &unpacked {
                    buffer.push(q as f32 * factor);
                }

                simd_l2_distance(query, &buffer)
            }
        }
    }
1515
1516    // SIMD-optimized distance functions
1517
1518    /// Compute L2 distance using SIMD acceleration
1519    ///
1520    /// Uses runtime CPU detection to select the best SIMD implementation:
1521    /// - `x86_64`: AVX2 > SSE2 > scalar
1522    /// - aarch64: NEON > scalar
1523    #[inline]
1524    #[must_use]
1525    pub fn distance_l2_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1526        // Reconstruct to f32 vectors
1527        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1528        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1529
1530        // Use SIMD distance computation
1531        simd_l2_distance(&v1, &v2)
1532    }
1533
1534    /// Compute cosine distance using SIMD acceleration
1535    #[inline]
1536    #[must_use]
1537    pub fn distance_cosine_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1538        let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1539        let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1540
1541        simd_cosine_distance(&v1, &v2)
1542    }
1543}
1544
1545// SIMD distance computation functions
1546
1547/// Compute L2 distance using SIMD
1548#[inline]
1549fn simd_l2_distance(v1: &[f32], v2: &[f32]) -> f32 {
1550    #[cfg(target_arch = "x86_64")]
1551    {
1552        if is_x86_feature_detected!("avx2") {
1553            return unsafe { l2_distance_avx2(v1, v2) };
1554        } else if is_x86_feature_detected!("sse2") {
1555            return unsafe { l2_distance_sse2(v1, v2) };
1556        }
1557        // Scalar fallback for x86_64 without SIMD
1558        l2_distance_scalar(v1, v2)
1559    }
1560
1561    #[cfg(target_arch = "aarch64")]
1562    {
1563        // NEON always available on aarch64
1564        unsafe { l2_distance_neon(v1, v2) }
1565    }
1566
1567    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1568    {
1569        // Scalar fallback for other architectures
1570        l2_distance_scalar(v1, v2)
1571    }
1572}
1573
1574/// Compute cosine distance using SIMD
1575#[inline]
1576fn simd_cosine_distance(v1: &[f32], v2: &[f32]) -> f32 {
1577    #[cfg(target_arch = "x86_64")]
1578    {
1579        if is_x86_feature_detected!("avx2") {
1580            return unsafe { cosine_distance_avx2(v1, v2) };
1581        } else if is_x86_feature_detected!("sse2") {
1582            return unsafe { cosine_distance_sse2(v1, v2) };
1583        }
1584        // Scalar fallback for x86_64 without SIMD
1585        cosine_distance_scalar(v1, v2)
1586    }
1587
1588    #[cfg(target_arch = "aarch64")]
1589    {
1590        // NEON always available on aarch64
1591        unsafe { cosine_distance_neon(v1, v2) }
1592    }
1593
1594    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1595    {
1596        // Scalar fallback for other architectures
1597        cosine_distance_scalar(v1, v2)
1598    }
1599}
1600
1601// Scalar implementations
1602
/// Scalar (non-SIMD) L2 distance over the common prefix of the two slices.
#[inline]
#[allow(dead_code)] // Used as SIMD fallback on x86_64 without AVX2/SSE2
fn l2_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
    let sum_sq = v1.iter().zip(v2).fold(0.0f32, |acc, (a, b)| {
        let d = a - b;
        acc + d * d
    });
    sum_sq.sqrt()
}
1612
/// Scalar (non-SIMD) cosine distance: 1 - cos(v1, v2).
///
/// Returns 1.0 (maximum distance) when either norm is below 1e-10.
#[inline]
#[allow(dead_code)] // Used as SIMD fallback on x86_64 without AVX2/SSE2
fn cosine_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    for (a, b) in v1.iter().zip(v2) {
        dot += a * b;
    }
    let norm1 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
    let norm2 = v2.iter().map(|b| b * b).sum::<f32>().sqrt();

    if norm1 < 1e-10 || norm2 < 1e-10 {
        return 1.0;
    }

    1.0 - dot / (norm1 * norm2)
}
1627
1628// AVX2 implementations (x86_64)
1629
/// AVX2 + FMA kernel for L2 distance over the common prefix of both slices.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (checked via
/// `is_x86_feature_detected!` in `simd_l2_distance`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn l2_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        // Eight-lane accumulator of squared differences.
        let mut sum = _mm256_setzero_ps();

        let chunks = len / 8;
        for i in 0..chunks {
            // Unaligned loads: input slices carry no alignment guarantee.
            let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
            let diff = _mm256_sub_ps(a, b);
            // sum += diff * diff (fused multiply-add)
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }

        // Horizontal sum: fold 256 -> 128 -> 64 -> 32 bits
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(sum_low, sum_high);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        let mut result = _mm_cvtss_f32(sum32);

        // Handle remainder (last len % 8 elements) with scalar code
        for i in (chunks * 8)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1663
/// AVX2 + FMA kernel for cosine distance over the common prefix of both slices.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (checked via
/// `is_x86_feature_detected!` in `simd_cosine_distance`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn cosine_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        // Three accumulators: dot(v1,v2), sum(v1²), sum(v2²) — 8 lanes each.
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm1_sum = _mm256_setzero_ps();
        let mut norm2_sum = _mm256_setzero_ps();

        let chunks = len / 8;
        for i in 0..chunks {
            // Unaligned loads: input slices carry no alignment guarantee.
            let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
            dot_sum = _mm256_fmadd_ps(a, b, dot_sum);
            norm1_sum = _mm256_fmadd_ps(a, a, norm1_sum);
            norm2_sum = _mm256_fmadd_ps(b, b, norm2_sum);
        }

        // Horizontal sums
        let mut dot = horizontal_sum_avx2(dot_sum);
        let mut norm1 = horizontal_sum_avx2(norm1_sum);
        let mut norm2 = horizontal_sum_avx2(norm2_sum);

        // Handle remainder (last len % 8 elements) with scalar code
        for i in (chunks * 8)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // NOTE(review): `norm1`/`norm2` are *squared* norms here, so this
        // 1e-10 threshold is stricter than the scalar path (which compares
        // sqrt'd norms against 1e-10) — confirm whether they should match.
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1703
/// Reduce all 8 `f32` lanes of an AVX register to their scalar sum.
///
/// # Safety
///
/// Executes AVX/SSE intrinsics; only call from code running on a CPU with
/// those features. NOTE(review): this helper lacks its own
/// `#[target_feature(enable = "avx2")]` attribute — it is presumably sound
/// because it is `#[inline]`d into the AVX2-gated callers above; confirm.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    unsafe {
        // Fold 256 -> 128 -> 64 -> 32 bits, adding lane pairs at each step.
        let sum_high = _mm256_extractf128_ps(v, 1);
        let sum_low = _mm256_castps256_ps128(v);
        let sum128 = _mm_add_ps(sum_low, sum_high);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        _mm_cvtss_f32(sum32)
    }
}
1716
1717// SSE2 implementations (x86_64 fallback)
1718
/// Euclidean (L2) distance, SSE2 fallback: 4 lanes per iteration using
/// separate multiply + add (SSE2 has no FMA).
///
/// Mirrors `l2_distance_avx2` but at half the vector width; the tail
/// (`len % 4` elements) is handled scalar-wise.
///
/// # Safety
///
/// The caller must guarantee the CPU supports SSE2 (always true on x86_64,
/// but the `#[target_feature]` contract still makes this an unsafe fn).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn l2_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
    // SAFETY: unaligned loads stay in bounds — i < chunks implies
    // i * 4 + 4 <= len, and len <= both slice lengths.
    unsafe {
        let len = v1.len().min(v2.len());
        let mut sum = _mm_setzero_ps();

        let chunks = len / 4;
        for i in 0..chunks {
            let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
            let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
            let diff = _mm_sub_ps(a, b);
            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        }

        // Horizontal sum
        // NOTE(review): duplicates `horizontal_sum_sse2`; consider calling it.
        let sum64 = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        let mut result = _mm_cvtss_f32(sum32);

        // Handle remainder
        for i in (chunks * 4)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1748
/// Cosine distance, SSE2 fallback: one 4-lane pass accumulating dot product
/// and both squared norms, scalar tail, `1.0` returned for near-zero norms
/// (threshold `1e-10`) — identical semantics to the AVX2 and scalar versions.
///
/// # Safety
///
/// The caller must guarantee the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn cosine_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
    // SAFETY: unaligned loads stay in bounds — i < chunks implies
    // i * 4 + 4 <= len, and len <= both slice lengths.
    unsafe {
        let len = v1.len().min(v2.len());
        let mut dot_sum = _mm_setzero_ps();
        let mut norm1_sum = _mm_setzero_ps();
        let mut norm2_sum = _mm_setzero_ps();

        let chunks = len / 4;
        for i in 0..chunks {
            let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
            let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
            dot_sum = _mm_add_ps(dot_sum, _mm_mul_ps(a, b));
            norm1_sum = _mm_add_ps(norm1_sum, _mm_mul_ps(a, a));
            norm2_sum = _mm_add_ps(norm2_sum, _mm_mul_ps(b, b));
        }

        // Horizontal sums
        let mut dot = horizontal_sum_sse2(dot_sum);
        let mut norm1 = horizontal_sum_sse2(norm1_sum);
        let mut norm2 = horizontal_sum_sse2(norm2_sum);

        // Handle remainder
        for i in (chunks * 4)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // norm1/norm2 are squared norms; guard the sqrt-and-divide below.
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1787
/// Reduce all 4 `f32` lanes of an SSE register to their scalar sum.
///
/// # Safety
///
/// Executes SSE intrinsics; only call on a CPU that supports them
/// (guaranteed on x86_64 for SSE2-level operations).
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_sse2(v: __m128) -> f32 {
    unsafe {
        // Fold 128 -> 64 -> 32 bits, adding lane pairs at each step.
        let sum64 = _mm_add_ps(v, _mm_movehl_ps(v, v));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        _mm_cvtss_f32(sum32)
    }
}
1797
1798// NEON implementations (aarch64)
1799
/// Euclidean (L2) distance using NEON: 4 lanes per iteration with fused
/// multiply-add (`vfmaq_f32`), a hardware horizontal add (`vaddvq_f32`),
/// and a scalar tail for the last `len % 4` elements.
///
/// No `#[target_feature]` gate is needed: NEON (ASIMD) is a mandatory
/// feature of the AArch64 architecture.
///
/// # Safety
///
/// Uses raw-pointer NEON loads; safe because every load is bounds-checked
/// by construction (see SAFETY comment below).
#[cfg(target_arch = "aarch64")]
unsafe fn l2_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    let len = v1.len().min(v2.len());

    // SAFETY: All SIMD operations wrapped in unsafe block for Rust 2024.
    // Loads stay in bounds: i < chunks implies i * 4 + 4 <= len.
    unsafe {
        let mut sum = vdupq_n_f32(0.0);

        let chunks = len / 4;
        for i in 0..chunks {
            let a = vld1q_f32(v1.as_ptr().add(i * 4));
            let b = vld1q_f32(v2.as_ptr().add(i * 4));
            let diff = vsubq_f32(a, b);
            sum = vfmaq_f32(sum, diff, diff);
        }

        // Horizontal sum
        let mut result = vaddvq_f32(sum);

        // Handle remainder
        for i in (chunks * 4)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1828
/// Cosine distance using NEON: one 4-lane FMA pass accumulating dot product
/// and both squared norms, scalar tail, and `1.0` returned for near-zero
/// norms (threshold `1e-10`) — same semantics as the x86 and scalar paths.
///
/// # Safety
///
/// Uses raw-pointer NEON loads; safe because every load is bounds-checked
/// by construction (see SAFETY comment below).
#[cfg(target_arch = "aarch64")]
unsafe fn cosine_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    let len = v1.len().min(v2.len());

    // SAFETY: All SIMD operations wrapped in unsafe block for Rust 2024.
    // Loads stay in bounds: i < chunks implies i * 4 + 4 <= len.
    unsafe {
        let mut dot_sum = vdupq_n_f32(0.0);
        let mut norm1_sum = vdupq_n_f32(0.0);
        let mut norm2_sum = vdupq_n_f32(0.0);

        let chunks = len / 4;
        for i in 0..chunks {
            let a = vld1q_f32(v1.as_ptr().add(i * 4));
            let b = vld1q_f32(v2.as_ptr().add(i * 4));
            dot_sum = vfmaq_f32(dot_sum, a, b);
            norm1_sum = vfmaq_f32(norm1_sum, a, a);
            norm2_sum = vfmaq_f32(norm2_sum, b, b);
        }

        // Horizontal sums
        let mut dot = vaddvq_f32(dot_sum);
        let mut norm1 = vaddvq_f32(norm1_sum);
        let mut norm2 = vaddvq_f32(norm2_sum);

        // Handle remainder
        for i in (chunks * 4)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // norm1/norm2 are squared norms; guard the sqrt-and-divide below.
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1868
1869#[cfg(test)]
1870#[allow(clippy::float_cmp)]
1871mod tests {
1872    use super::*;
1873
    // `to_u8` must report the configured bit width verbatim.
    #[test]
    fn test_quantization_bits_conversion() {
        assert_eq!(QuantizationBits::Bits2.to_u8(), 2);
        assert_eq!(QuantizationBits::Bits4.to_u8(), 4);
        assert_eq!(QuantizationBits::Bits8.to_u8(), 8);
    }

    // `levels` is 2^bits — the number of distinct codes per dimension.
    #[test]
    fn test_quantization_bits_levels() {
        assert_eq!(QuantizationBits::Bits2.levels(), 4); // 2^2
        assert_eq!(QuantizationBits::Bits4.levels(), 16); // 2^4
        assert_eq!(QuantizationBits::Bits8.levels(), 256); // 2^8
    }

    // Compression ratio is 32 (f32 bits) divided by bits-per-dimension.
    #[test]
    fn test_quantization_bits_compression() {
        assert_eq!(QuantizationBits::Bits2.compression_ratio(), 16.0); // 32/2
        assert_eq!(QuantizationBits::Bits4.compression_ratio(), 8.0); // 32/4
        assert_eq!(QuantizationBits::Bits8.compression_ratio(), 4.0); // 32/8
    }

    // How many packed values fit in one byte (8 / bits).
    #[test]
    fn test_quantization_bits_values_per_byte() {
        assert_eq!(QuantizationBits::Bits2.values_per_byte(), 4); // 8/2
        assert_eq!(QuantizationBits::Bits4.values_per_byte(), 2); // 8/4
        assert_eq!(QuantizationBits::Bits8.values_per_byte(), 1); // 8/8
    }

    // Default params pin the documented 4-bit / 12-factor / (0.5, 2.0) config.
    #[test]
    fn test_default_params() {
        let params = RaBitQParams::default();
        assert_eq!(params.bits_per_dim, QuantizationBits::Bits4);
        assert_eq!(params.num_rescale_factors, 12);
        assert_eq!(params.rescale_range, (0.5, 2.0));
    }

    // Preset constructors select the matching bit width (8-bit also bumps
    // the rescale-factor count to 16).
    #[test]
    fn test_preset_params() {
        let params2 = RaBitQParams::bits2();
        assert_eq!(params2.bits_per_dim, QuantizationBits::Bits2);

        let params4 = RaBitQParams::bits4();
        assert_eq!(params4.bits_per_dim, QuantizationBits::Bits4);

        let params8 = RaBitQParams::bits8();
        assert_eq!(params8.bits_per_dim, QuantizationBits::Bits8);
        assert_eq!(params8.num_rescale_factors, 16);
    }

    // `QuantizedVector::new` stores its arguments without transformation.
    #[test]
    fn test_quantized_vector_creation() {
        let data = vec![0u8, 128, 255];
        let qv = QuantizedVector::new(data.clone(), 1.5, 8, 3);

        assert_eq!(qv.data, data);
        assert_eq!(qv.scale, 1.5);
        assert_eq!(qv.bits, 8);
        assert_eq!(qv.dimensions, 3);
    }

    // Reported memory must at least cover the packed payload.
    #[test]
    fn test_quantized_vector_memory() {
        let data = vec![0u8; 16]; // 16 bytes
        let qv = QuantizedVector::new(data, 1.0, 4, 32);

        // Should be: struct overhead + data length
        let expected_min = 16; // At least the data
        assert!(qv.memory_bytes() >= expected_min);
    }

    #[test]
    fn test_quantized_vector_compression_ratio() {
        // 128 dimensions, 4-bit = 64 bytes
        let data = vec![0u8; 64];
        let qv = QuantizedVector::new(data, 1.0, 4, 128);

        // Original: 128 * 4 = 512 bytes
        // Compressed: 64 + 4 (scale) + 1 (bits) = 69 bytes
        // Ratio: 512 / 69 ≈ 7.4x
        let ratio = qv.compression_ratio();
        assert!(ratio > 7.0 && ratio < 8.0);
    }

    // Smoke test: the default 4-bit quantizer exposes its params.
    #[test]
    fn test_create_quantizer() {
        let quantizer = RaBitQ::default_4bit();
        assert_eq!(quantizer.params().bits_per_dim, QuantizationBits::Bits4);
    }
1962
1963    // Phase 2 Tests: Core Algorithm
1964
    // Rescale factors are spread evenly across the configured range,
    // endpoints included.
    #[test]
    fn test_generate_scales() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 5,
            rescale_range: (0.5, 1.5),
        });

        let scales = quantizer.generate_scales();
        assert_eq!(scales.len(), 5);
        assert_eq!(scales[0], 0.5);
        assert_eq!(scales[4], 1.5);
        assert!((scales[2] - 1.0).abs() < 0.01); // Middle should be ~1.0
    }

    // Degenerate case: a single factor collapses to the range midpoint.
    #[test]
    fn test_generate_scales_single() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 1,
            rescale_range: (0.5, 1.5),
        });

        let scales = quantizer.generate_scales();
        assert_eq!(scales.len(), 1);
        assert_eq!(scales[0], 1.0); // Average of min and max
    }

    // Bit packing must round-trip losslessly at 2 bits (4 values per byte).
    #[test]
    fn test_pack_unpack_2bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits2,
            ..Default::default()
        });

        // 8 values (2 bits each) = 2 bytes
        let values = vec![0u8, 1, 2, 3, 0, 1, 2, 3];
        let packed = quantizer.pack_quantized(&values, 2);
        assert_eq!(packed.len(), 2); // 8 values / 4 per byte = 2 bytes

        let unpacked = quantizer.unpack_quantized(&packed, 2, 8);
        assert_eq!(unpacked, values);
    }

    // Round-trip at 4 bits (2 values per byte).
    #[test]
    fn test_pack_unpack_4bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            ..Default::default()
        });

        // 8 values (4 bits each) = 4 bytes
        let values = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let packed = quantizer.pack_quantized(&values, 4);
        assert_eq!(packed.len(), 4); // 8 values / 2 per byte = 4 bytes

        let unpacked = quantizer.unpack_quantized(&packed, 4, 8);
        assert_eq!(unpacked, values);
    }

    // Round-trip at 8 bits (packing is a pass-through, 1 value per byte).
    #[test]
    fn test_pack_unpack_8bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            ..Default::default()
        });

        // 8 values (8 bits each) = 8 bytes
        let values = vec![0u8, 10, 20, 30, 40, 50, 60, 70];
        let packed = quantizer.pack_quantized(&values, 8);
        assert_eq!(packed.len(), 8); // 8 values = 8 bytes

        let unpacked = quantizer.unpack_quantized(&packed, 8, 8);
        assert_eq!(unpacked, values);
    }

    // End-to-end quantization of a small vector: metadata and packed size.
    #[test]
    fn test_quantize_simple_vector() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 4,
            rescale_range: (0.5, 1.5),
        });

        // Simple vector: [0.0, 0.25, 0.5, 0.75, 1.0]
        let vector = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let quantized = quantizer.quantize(&vector);

        // Check structure
        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits, 4);
        assert!(quantized.scale > 0.0);

        // Check compression: 5 floats * 4 bytes = 20 bytes original
        // Quantized: 5 values * 4 bits = 20 bits = 3 bytes (rounded up)
        assert!(quantized.data.len() <= 4);
    }

    // 8-bit quantize -> reconstruct must stay within 0.1 per element.
    #[test]
    fn test_quantize_reconstruct_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        // Test vector
        let vector = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let quantized = quantizer.quantize(&vector);

        // Reconstruct
        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());

        // Check reconstruction is close (8-bit should be accurate)
        for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            assert!(error < 0.1, "Error too large: {orig} vs {recon}");
        }
    }

    // A constant vector should reconstruct to (nearly) constant values.
    #[test]
    fn test_quantize_uniform_vector() {
        let quantizer = RaBitQ::default_4bit();

        // All values the same
        let vector = vec![0.5; 10];
        let quantized = quantizer.quantize(&vector);

        // Reconstruct should also be uniform
        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());

        // All values should be similar
        let avg = reconstructed.iter().sum::<f32>() / reconstructed.len() as f32;
        for &val in &reconstructed {
            assert!((val - avg).abs() < 0.2);
        }
    }

    // Quantization error metric must be a sane (non-negative, finite) number.
    #[test]
    fn test_compute_error() {
        let quantizer = RaBitQ::default_4bit();

        let original = vec![0.1, 0.2, 0.3, 0.4];
        let quantized_vec = quantizer.quantize(&original);

        // Compute error
        let error = quantizer.compute_error(&original, &quantized_vec.data, quantized_vec.scale);

        // Error should be non-negative and finite
        assert!(error >= 0.0);
        assert!(error.is_finite());
    }

    // Packed size must grow monotonically with bit width for a fixed vector.
    #[test]
    fn test_quantize_different_bit_widths() {
        let test_vector = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        // Test 2-bit
        let q2 = RaBitQ::new(RaBitQParams::bits2());
        let qv2 = q2.quantize(&test_vector);
        assert_eq!(qv2.bits, 2);

        // Test 4-bit
        let q4 = RaBitQ::default_4bit();
        let qv4 = q4.quantize(&test_vector);
        assert_eq!(qv4.bits, 4);

        // Test 8-bit
        let q8 = RaBitQ::new(RaBitQParams::bits8());
        let qv8 = q8.quantize(&test_vector);
        assert_eq!(qv8.bits, 8);

        // Higher bits = larger packed size (for same dimensions)
        assert!(qv2.data.len() <= qv4.data.len());
        assert!(qv4.data.len() <= qv8.data.len());
    }

    // Realistic embedding size: exact packed length and reconstructability.
    #[test]
    fn test_quantize_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        // 128D vector (like small embeddings)
        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let quantized = quantizer.quantize(&vector);

        assert_eq!(quantized.dimensions, 128);
        assert_eq!(quantized.bits, 4);

        // 128 dimensions * 4 bits = 512 bits = 64 bytes
        assert_eq!(quantized.data.len(), 64);

        // Verify reconstruction
        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, 128);
        assert_eq!(reconstructed.len(), 128);
    }
2160
2161    // Phase 3 Tests: Distance Computation
2162
    // L2 distance between unit-separated vectors should be ~1.0 after
    // quantization at 8-bit precision.
    #[test]
    fn test_distance_l2() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_l2(&qv1, &qv2);

        // Distance should be approximately 1.0
        assert!((dist - 1.0).abs() < 0.2, "Distance: {dist}");
    }

    // The same vector quantized twice should be near-zero apart.
    #[test]
    fn test_distance_l2_identical() {
        let quantizer = RaBitQ::default_4bit();

        let v = vec![0.5, 0.3, 0.8, 0.2];
        let qv1 = quantizer.quantize(&v);
        let qv2 = quantizer.quantize(&v);

        let dist = quantizer.distance_l2(&qv1, &qv2);

        // Identical vectors should have near-zero distance
        assert!(dist < 0.3, "Distance should be near zero, got: {dist}");
    }

    #[test]
    fn test_distance_cosine() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        // Orthogonal vectors
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_cosine(&qv1, &qv2);

        // Orthogonal vectors: cosine = 0, distance = 1
        assert!((dist - 1.0).abs() < 0.3, "Distance: {dist}");
    }

    #[test]
    fn test_distance_cosine_identical() {
        let quantizer = RaBitQ::default_4bit();

        let v = vec![0.5, 0.3, 0.8];
        let qv1 = quantizer.quantize(&v);
        let qv2 = quantizer.quantize(&v);

        let dist = quantizer.distance_cosine(&qv1, &qv2);

        // Identical vectors: cosine = 1, distance = 0
        assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
    }

    // `distance_dot` returns the *negated* dot product so that smaller is
    // "closer", matching the other metrics.
    #[test]
    fn test_distance_dot() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_dot(&qv1, &qv2);

        // Dot product of [1,0,0] with itself = 1, negated = -1
        assert!((dist + 1.0).abs() < 0.3, "Distance: {dist}");
    }

    // The fast approximate distance only needs to preserve relative order,
    // not match the exact metric.
    #[test]
    fn test_distance_approximate() {
        let quantizer = RaBitQ::default_4bit();

        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![0.5, 0.5, 0.5];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_approx = quantizer.distance_approximate(&qv1, &qv2);
        let dist_exact = quantizer.distance_l2(&qv1, &qv2);

        // Approximate should be non-negative and finite
        assert!(dist_approx >= 0.0);
        assert!(dist_approx.is_finite());

        // Approximate and exact should be correlated (not exact match)
        // Just verify both increase/decrease together
        let v3 = vec![1.0, 1.0, 1.0];
        let qv3 = quantizer.quantize(&v3);

        let dist_approx2 = quantizer.distance_approximate(&qv1, &qv3);
        let dist_exact2 = quantizer.distance_l2(&qv1, &qv3);

        // If v3 is farther from v1 than v2, both metrics should reflect that
        if dist_exact2 > dist_exact {
            assert!(dist_approx2 > dist_approx * 0.5); // Allow some variance
        }
    }

    // Quantized distances must preserve the ground-truth ordering of
    // neighbors (the property ANN search actually relies on).
    #[test]
    fn test_distance_correlation() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision for correlation
            num_rescale_factors: 12,
            rescale_range: (0.8, 1.2),
        });

        // Create multiple vectors
        let vectors = [
            vec![0.1, 0.2, 0.3],
            vec![0.4, 0.5, 0.6],
            vec![0.7, 0.8, 0.9],
        ];

        // Quantize all
        let quantized: Vec<QuantizedVector> =
            vectors.iter().map(|v| quantizer.quantize(v)).collect();

        // Ground truth L2 distances
        let ground_truth_01 = vectors[0]
            .iter()
            .zip(vectors[1].iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        let ground_truth_02 = vectors[0]
            .iter()
            .zip(vectors[2].iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        // Quantized distances
        let quantized_01 = quantizer.distance_l2(&quantized[0], &quantized[1]);
        let quantized_02 = quantizer.distance_l2(&quantized[0], &quantized[2]);

        // Check correlation: if ground truth says v2 > v1, quantized should too
        if ground_truth_02 > ground_truth_01 {
            assert!(
                quantized_02 > quantized_01 * 0.8,
                "Order not preserved: {quantized_01} vs {quantized_02}"
            );
        }
    }

    // Zero vectors exercise the degenerate-norm guard in cosine distance.
    #[test]
    fn test_distance_zero_vectors() {
        let quantizer = RaBitQ::default_4bit();

        let v_zero = vec![0.0, 0.0, 0.0];
        let qv_zero = quantizer.quantize(&v_zero);

        // Distance to itself should be zero
        let dist = quantizer.distance_l2(&qv_zero, &qv_zero);
        assert!(dist < 0.1);

        // Cosine distance with zero vector should handle gracefully
        let dist_cosine = quantizer.distance_cosine(&qv_zero, &qv_zero);
        assert!(dist_cosine.is_finite());
    }

    // All four metrics must stay finite and sane on 128D inputs.
    #[test]
    fn test_distance_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        // 128D vectors
        let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| ((i + 10) as f32) / 128.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // All distance metrics should work on high-dimensional vectors
        let dist_l2 = quantizer.distance_l2(&qv1, &qv2);
        let dist_cosine = quantizer.distance_cosine(&qv1, &qv2);
        let dist_dot = quantizer.distance_dot(&qv1, &qv2);
        let dist_approx = quantizer.distance_approximate(&qv1, &qv2);

        assert!(dist_l2 > 0.0 && dist_l2.is_finite());
        assert!(dist_cosine >= 0.0 && dist_cosine.is_finite());
        assert!(dist_dot.is_finite());
        assert!(dist_approx > 0.0 && dist_approx.is_finite());
    }

    // Asymmetric (exact-query vs quantized-vector) distance should agree
    // with the symmetric quantized-vs-quantized path.
    #[test]
    fn test_distance_asymmetric_l2() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.1, 0.2, 0.3, 0.4];
        // Vector close to query
        let vector = vec![0.12, 0.22, 0.32, 0.42];

        let quantized = quantizer.quantize(&vector);

        // Symmetric distance (allocates)
        let dist_sym = quantizer.distance_l2_simd(&quantized, &quantizer.quantize(&query));

        // Asymmetric distance (no allocation)
        let dist_asym = quantizer.distance_asymmetric_l2(&query, &quantized);

        // Should be reasonably close (asymmetric is actually MORE accurate because query is exact)
        // But for this test, just ensure it's sane
        assert!(dist_asym >= 0.0);
        assert!((dist_asym - dist_sym).abs() < 0.2);
    }
2390
2391    // Phase 4 Tests: SIMD Optimizations
2392
    // The SIMD L2 path must agree with the scalar path on identical inputs.
    #[test]
    fn test_simd_l2_matches_scalar() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let v2 = vec![0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);

        // SIMD should match scalar within floating point precision
        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    // Same agreement requirement for the cosine path.
    #[test]
    fn test_simd_cosine_matches_scalar() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_cosine(&qv1, &qv2);
        let dist_simd = quantizer.distance_cosine_simd(&qv1, &qv2);

        // SIMD should match scalar within floating point precision
        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    // 128D exercises the full vectorized loop, not just the remainder path.
    #[test]
    fn test_simd_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        // 128D vectors (realistic embeddings)
        let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| ((i + 1) as f32) / 128.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);

        // Should be close (allow for quantization + FP variance)
        let diff = (dist_scalar - dist_simd).abs();
        assert!(
            diff < 0.1,
            "High-D SIMD vs scalar: {dist_simd} vs {dist_scalar}"
        );
    }

    // 3D vectors are shorter than one SIMD lane group: only the scalar
    // remainder loop runs, which must not crash or misbehave.
    #[test]
    fn test_simd_scalar_fallback() {
        let quantizer = RaBitQ::default_4bit();

        // Small vector (tests remainder handling)
        let v1 = vec![0.1, 0.2, 0.3];
        let v2 = vec![0.4, 0.5, 0.6];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // Should not crash on small vectors
        let dist_l2 = quantizer.distance_l2_simd(&qv1, &qv2);
        let dist_cosine = quantizer.distance_cosine_simd(&qv1, &qv2);

        assert!(dist_l2.is_finite());
        assert!(dist_cosine.is_finite());
    }

    #[test]
    fn test_simd_performance_improvement() {
        let quantizer = RaBitQ::default_4bit();

        // Large vectors (1536D like OpenAI embeddings)
        let v1: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
        let v2: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // Just verify SIMD works on large vectors
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);
        assert!(dist_simd > 0.0 && dist_simd.is_finite());

        // Note: Actual performance benchmarks in Phase 6
    }

    // Exercise the free-function scalar fallbacks on exact (unquantized)
    // inputs with known answers.
    #[test]
    fn test_scalar_distance_functions() {
        // Test the scalar fallback functions directly
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let dist = l2_distance_scalar(&v1, &v2);
        assert!((dist - 1.0).abs() < 0.001);

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let dist = cosine_distance_scalar(&v1, &v2);
        assert!((dist - 1.0).abs() < 0.001);
    }
2511
2512    // ADC Tests
2513
    // An ADC lookup table has one sub-table per dimension, each with
    // 2^bits precomputed entries.
    #[test]
    fn test_adc_table_creation() {
        let quantizer = RaBitQ::default_4bit();
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let scale = 1.0;

        let adc = quantizer.build_adc_table_with_scale(&query, scale);

        // Check structure
        assert_eq!(adc.dimensions, 4);
        assert_eq!(adc.bits, 4);
        assert_eq!(adc.table.len(), 4);

        // Each dimension should have 16 codes (4-bit)
        for dim_table in &adc.table {
            assert_eq!(dim_table.len(), 16);
        }
    }

    // 2-bit quantization shrinks each sub-table to 4 entries.
    #[test]
    fn test_adc_table_2bit() {
        let quantizer = RaBitQ::new(RaBitQParams::bits2());
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let scale = 1.0;

        let adc = quantizer.build_adc_table_with_scale(&query, scale);

        // Each dimension should have 4 codes (2-bit)
        for dim_table in &adc.table {
            assert_eq!(dim_table.len(), 4);
        }
    }

    // ADC table lookup and the direct asymmetric computation are two paths
    // to the same quantity; they must agree closely.
    #[test]
    fn test_adc_distance_matches_asymmetric() {
        let quantizer = RaBitQ::default_4bit();

        // Create query and vector
        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];

        // Quantize the vector
        let quantized = quantizer.quantize(&vector);

        // Compute distance with asymmetric method
        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);

        // Compute distance with ADC (using build_adc_table_with_scale for untrained quantizer)
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist_adc = adc.distance(&quantized.data);

        // ADC should give similar results to asymmetric distance
        // They use different computation paths but should be close
        let diff = (dist_asymmetric - dist_adc).abs();
        assert!(
            diff < 0.1,
            "ADC vs asymmetric: {dist_adc} vs {dist_asymmetric}, diff: {diff}"
        );
    }

    // Distance from a query to its own quantized form should be ~0.
    #[test]
    fn test_adc_distance_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8, // High precision
            num_rescale_factors: 16,
            rescale_range: (0.8, 1.2),
        });

        let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.1, 0.2, 0.3, 0.4]; // Same as query

        let quantized = quantizer.quantize(&vector);

        // Build ADC table
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        // Distance should be near zero (same vector)
        let dist = adc.distance(&quantized.data);
        assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
    }
2594
2595    #[test]
2596    fn test_adc_distance_ordering() {
2597        let quantizer = RaBitQ::default_4bit();
2598
2599        let query = vec![0.5, 0.5, 0.5, 0.5];
2600        let v1 = vec![0.5, 0.5, 0.5, 0.5]; // Closest
2601        let v2 = vec![0.6, 0.6, 0.6, 0.6]; // Medium
2602        let v3 = vec![0.9, 0.9, 0.9, 0.9]; // Farthest
2603
2604        let qv1 = quantizer.quantize(&v1);
2605        let qv2 = quantizer.quantize(&v2);
2606        let qv3 = quantizer.quantize(&v3);
2607
2608        // Build ADC tables with respective scales
2609        let adc1 = quantizer.build_adc_table_with_scale(&query, qv1.scale);
2610        let adc2 = quantizer.build_adc_table_with_scale(&query, qv2.scale);
2611        let adc3 = quantizer.build_adc_table_with_scale(&query, qv3.scale);
2612
2613        let dist1 = adc1.distance(&qv1.data);
2614        let dist2 = adc2.distance(&qv2.data);
2615        let dist3 = adc3.distance(&qv3.data);
2616
2617        // Order should be preserved
2618        assert!(
2619            dist1 < dist2,
2620            "v1 should be closer than v2: {dist1} vs {dist2}"
2621        );
2622        assert!(
2623            dist2 < dist3,
2624            "v2 should be closer than v3: {dist2} vs {dist3}"
2625        );
2626    }
2627
2628    #[test]
2629    fn test_adc_high_dimensional() {
2630        let quantizer = RaBitQ::default_4bit();
2631
2632        // 128D vectors (realistic embedding size)
2633        let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2634        let vector: Vec<f32> = (0..128).map(|i| ((i + 5) as f32) / 128.0).collect();
2635
2636        let quantized = quantizer.quantize(&vector);
2637
2638        // Build ADC table
2639        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2640
2641        // Should handle high dimensions without panic
2642        let dist = adc.distance(&quantized.data);
2643        assert!(dist > 0.0 && dist.is_finite());
2644    }
2645
2646    #[test]
2647    fn test_adc_batch_search() {
2648        let quantizer = RaBitQ::default_4bit();
2649
2650        let query = vec![0.5, 0.5, 0.5, 0.5];
2651        let candidates = [
2652            vec![0.5, 0.5, 0.5, 0.5],
2653            vec![0.6, 0.6, 0.6, 0.6],
2654            vec![0.4, 0.4, 0.4, 0.4],
2655            vec![0.7, 0.7, 0.7, 0.7],
2656        ];
2657
2658        // Quantize all candidates
2659        let quantized: Vec<QuantizedVector> =
2660            candidates.iter().map(|v| quantizer.quantize(v)).collect();
2661
2662        // Scan all candidates using ADC tables
2663        let mut results: Vec<(usize, f32)> = quantized
2664            .iter()
2665            .enumerate()
2666            .map(|(i, qv)| {
2667                let adc = quantizer.build_adc_table_with_scale(&query, qv.scale);
2668                (i, adc.distance(&qv.data))
2669            })
2670            .collect();
2671
2672        // Sort by distance
2673        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
2674
2675        // First result should be index 0 (identical to query)
2676        assert_eq!(results[0].0, 0, "Results: {results:?}");
2677    }
2678
2679    #[test]
2680    fn test_adc_distance_squared() {
2681        let quantizer = RaBitQ::default_4bit();
2682
2683        let query = vec![0.0, 0.0, 0.0];
2684        let vector = vec![1.0, 0.0, 0.0];
2685
2686        let quantized = quantizer.quantize(&vector);
2687        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2688
2689        let dist_squared = adc.distance_squared(&quantized.data);
2690        let dist = adc.distance(&quantized.data);
2691
2692        // distance_squared should be dist^2 (approximately)
2693        let diff = (dist_squared - dist * dist).abs();
2694        assert!(
2695            diff < 0.01,
2696            "distance_squared != dist^2: {} vs {}",
2697            dist_squared,
2698            dist * dist
2699        );
2700    }
2701
2702    #[test]
2703    fn test_adc_simd_matches_scalar() {
2704        let quantizer = RaBitQ::default_4bit();
2705
2706        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
2707        let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];
2708
2709        let quantized = quantizer.quantize(&vector);
2710        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2711
2712        let dist_scalar = adc.distance_squared(&quantized.data);
2713        let dist_simd = adc.distance_squared_simd(&quantized.data);
2714
2715        // SIMD should match scalar within floating point precision
2716        let diff = (dist_scalar - dist_simd).abs();
2717        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
2718    }
2719
2720    #[test]
2721    fn test_adc_simd_high_dimensional() {
2722        let quantizer = RaBitQ::default_4bit();
2723
2724        // 1536D vectors (OpenAI embeddings)
2725        let query: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
2726        let vector: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();
2727
2728        let quantized = quantizer.quantize(&vector);
2729        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2730
2731        // Should handle large dimensions efficiently
2732        let dist_simd = adc.distance_squared_simd(&quantized.data);
2733        assert!(dist_simd > 0.0 && dist_simd.is_finite());
2734    }
2735
2736    #[test]
2737    fn test_adc_memory_usage() {
2738        let quantizer = RaBitQ::default_4bit();
2739
2740        let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2741        let adc = quantizer.build_adc_table_with_scale(&query, 1.0);
2742
2743        let memory = adc.memory_bytes();
2744
2745        // For 128D, 4-bit: 128 * 16 * 4 bytes = 8KB (plus overhead)
2746        let expected_min = 128 * 16 * 4;
2747        assert!(
2748            memory >= expected_min,
2749            "Memory {memory} should be at least {expected_min}"
2750        );
2751    }
2752
2753    #[test]
2754    fn test_adc_different_scales() {
2755        let quantizer = RaBitQ::default_4bit();
2756
2757        let query = vec![0.5, 0.5, 0.5, 0.5];
2758        let vector = vec![0.6, 0.6, 0.6, 0.6];
2759
2760        let quantized = quantizer.quantize(&vector);
2761
2762        // Build ADC tables with different scales
2763        let adc1 = quantizer.build_adc_table_with_scale(&query, 0.5);
2764        let adc2 = quantizer.build_adc_table_with_scale(&query, 1.0);
2765        let adc3 = quantizer.build_adc_table_with_scale(&query, 2.0);
2766
2767        // Distances should differ based on scale
2768        let dist1 = adc1.distance(&quantized.data);
2769        let dist2 = adc2.distance(&quantized.data);
2770        let dist3 = adc3.distance(&quantized.data);
2771
2772        // All should be valid finite numbers
2773        assert!(dist1.is_finite());
2774        assert!(dist2.is_finite());
2775        assert!(dist3.is_finite());
2776    }
2777
2778    #[test]
2779    fn test_adc_edge_cases() {
2780        let quantizer = RaBitQ::default_4bit();
2781
2782        // Test with very small vector
2783        let query = vec![0.5];
2784        let vector = vec![0.6];
2785        let quantized = quantizer.quantize(&vector);
2786        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2787        let dist = adc.distance(&quantized.data);
2788        assert!(dist.is_finite());
2789
2790        // Test with all zeros
2791        let query = vec![0.0, 0.0, 0.0, 0.0];
2792        let vector = vec![0.0, 0.0, 0.0, 0.0];
2793        let quantized = quantizer.quantize(&vector);
2794        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2795        let dist = adc.distance(&quantized.data);
2796        assert!(dist.is_finite());
2797    }
2798
2799    #[test]
2800    fn test_adc_2bit_accuracy() {
2801        let quantizer = RaBitQ::new(RaBitQParams::bits2());
2802
2803        let query = vec![0.1, 0.2, 0.3, 0.4];
2804        let vector = vec![0.12, 0.22, 0.32, 0.42];
2805
2806        let quantized = quantizer.quantize(&vector);
2807
2808        // Test ADC for 2-bit quantization
2809        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2810        let dist_adc = adc.distance(&quantized.data);
2811        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2812
2813        // Should be reasonably close despite lower precision
2814        let diff = (dist_adc - dist_asymmetric).abs();
2815        assert!(diff < 0.2, "2-bit ADC diff too large: {diff}");
2816    }
2817
2818    #[test]
2819    fn test_adc_8bit_accuracy() {
2820        let quantizer = RaBitQ::new(RaBitQParams::bits8());
2821
2822        let query = vec![0.1, 0.2, 0.3, 0.4];
2823        let vector = vec![0.12, 0.22, 0.32, 0.42];
2824
2825        let quantized = quantizer.quantize(&vector);
2826
2827        // Test ADC for 8-bit quantization (highest precision)
2828        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2829        let dist_adc = adc.distance(&quantized.data);
2830        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2831
2832        // 8-bit should be very accurate
2833        let diff = (dist_adc - dist_asymmetric).abs();
2834        assert!(
2835            diff < 0.05,
2836            "8-bit ADC should be highly accurate, diff: {diff}"
2837        );
2838    }
2839}