omendb_core/compression/scalar.rs

//! Scalar Quantization (SQ8) for OmenDB
//!
//! Compresses f32 vectors to u8 (4x compression, ~97% recall, 2-3x faster than FP32).
//!
//! # Algorithm
//!
//! Uniform min/max scaling (single scale/offset for all dimensions):
//! - Train: Compute global min, max from sample vectors
//! - Quantize: u8[d] = round((f32[d] - offset) / scale)
//! - Distance: Integer SIMD dot product with float reconstruction
//!
//! # Performance (768D, Apple M3 Max)
//!
//! - 4x compression (f32 → u8)
//! - 2-3x faster than FP32 (integer SIMD)
//! - ~97% recall (vs 99%+ for per-dimension quantization)
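//!
//! # Example
//!
//! A minimal end-to-end sketch of this module's API (the import path assumes
//! the crate layout implied by the file header):
//!
//! ```ignore
//! use omendb_core::compression::scalar::ScalarParams;
//!
//! let vectors: Vec<Vec<f32>> = vec![vec![0.1, 0.7, 0.3], vec![0.2, 0.5, 0.9]];
//! let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
//!
//! // Train global scale/offset, then quantize vectors for storage.
//! let params = ScalarParams::train(&refs).unwrap();
//! let stored: Vec<_> = vectors.iter().map(|v| params.quantize(v)).collect();
//!
//! // Prepare the query once, then score candidates with integer SIMD.
//! let prep = params.prepare_query(&vectors[0]);
//! let dist = params.distance_l2_squared(&prep, &stored[1]);
//! assert!(dist >= 0.0);
//! ```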

use serde::{Deserialize, Serialize};

#[cfg(target_arch = "x86_64")]
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;

/// Trained scalar quantization parameters (uniform quantization)
///
/// Uses a single scale/offset for all dimensions, enabling integer SIMD
/// for 2-3x speedup over FP32.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarParams {
    /// Global scale factor: (max - min) / 255
    pub scale: f32,
    /// Global offset (minimum value)
    pub offset: f32,
    /// Number of dimensions
    pub dimensions: usize,
}

/// Precomputed data for a quantized vector
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Quantized values (u8)
    pub data: Vec<u8>,
    /// Precomputed: sum of quantized values (Σ data[i])
    pub sum: i32,
    /// Precomputed: squared norm of dequantized vector
    pub norm_sq: f32,
}

/// Precomputed query data for fast integer SIMD distance
#[derive(Debug, Clone)]
pub struct QueryPrep {
    /// Quantized query values (u8 for SIMD dot product)
    pub quantized: Vec<u8>,
    /// Query squared norm: ||q||²
    pub norm_sq: f32,
    /// Sum of quantized query values
    pub sum: i32,
}

impl ScalarParams {
    /// Create uninitialized params (for lazy training)
    ///
    /// Defaults to `scale = 1/255`, `offset = 0`, which maps [0, 1] across the
    /// full u8 range until trained.
    #[must_use]
    pub fn uninitialized(dimensions: usize) -> Self {
        Self {
            scale: 1.0 / 255.0,
            offset: 0.0,
            dimensions,
        }
    }

    /// Train scalar quantization from sample vectors
    ///
    /// Clips to the 1st and 99th percentiles to handle outliers.
    ///
    /// # Errors
    /// Returns an error if `vectors` is empty or the vectors have inconsistent dimensions.
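    ///
    /// # Example
    ///
    /// A brief sketch (`refs`/`vectors` as in the module-level example):
    ///
    /// ```ignore
    /// let params = ScalarParams::train(&refs)?;
    /// // Values outside the trained range saturate to 0 / 255 in `quantize`.
    /// let q = params.quantize(&vectors[0]);
    /// ```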
    pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
        Self::train_with_percentiles(vectors, 0.01, 0.99)
    }

    /// Train with custom percentile bounds
    ///
    /// # Errors
    /// Returns an error if `vectors` is empty or the vectors have inconsistent dimensions.
    pub fn train_with_percentiles(
        vectors: &[&[f32]],
        lower_percentile: f32,
        upper_percentile: f32,
    ) -> Result<Self, &'static str> {
        if vectors.is_empty() {
            return Err("Need at least one vector to train");
        }
        let dimensions = vectors[0].len();
        if !vectors.iter().all(|v| v.len() == dimensions) {
            return Err("All vectors must have same dimensions");
        }

        // Collect ALL values across all vectors and dimensions
        let mut all_values: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let n = all_values.len();
        let lower_idx = ((n as f32 * lower_percentile) as usize).min(n - 1);
        let upper_idx = ((n as f32 * upper_percentile) as usize).min(n - 1);

        let min_val = all_values[lower_idx];
        let max_val = all_values[upper_idx];

        let range = max_val - min_val;
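        // Degenerate range (all sampled values nearly equal): shift the offset
        // down half a step so the constant value quantizes to mid-range (~128)
        // instead of dividing by a near-zero scale.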
        let (offset, scale) = if range < 1e-7 {
            (min_val - 0.5, 1.0 / 255.0)
        } else {
            (min_val, range / 255.0)
        };

        Ok(Self {
            scale,
            offset,
            dimensions,
        })
    }

    /// Quantize a vector to u8 with precomputed metadata
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
        debug_assert_eq!(vector.len(), self.dimensions);

        let inv_scale = 1.0 / self.scale;
        let data: Vec<u8> = vector
            .iter()
            .map(|&v| ((v - self.offset) * inv_scale).clamp(0.0, 255.0).round() as u8)
            .collect();

        let sum: i32 = data.iter().map(|&x| x as i32).sum();

        // Compute dequantized norm
        let norm_sq: f32 = data
            .iter()
            .map(|&x| {
                let dequant = x as f32 * self.scale + self.offset;
                dequant * dequant
            })
            .sum();

        QuantizedVector { data, sum, norm_sq }
    }

    /// Quantize a vector, returning only the u8 data (for storage)
    #[must_use]
    pub fn quantize_to_bytes(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), self.dimensions);

        let inv_scale = 1.0 / self.scale;
        vector
            .iter()
            .map(|&v| ((v - self.offset) * inv_scale).clamp(0.0, 255.0).round() as u8)
            .collect()
    }

    /// Dequantize a u8 vector back to f32 (approximate)
    #[must_use]
    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
        quantized
            .iter()
            .map(|&q| q as f32 * self.scale + self.offset)
            .collect()
    }

    /// Compute squared norm of dequantized vector: ||dequant(q)||²
    #[must_use]
    pub fn dequantized_norm_squared(&self, quantized: &[u8]) -> f32 {
        quantized
            .iter()
            .map(|&q| {
                let dequant = q as f32 * self.scale + self.offset;
                dequant * dequant
            })
            .sum()
    }

    /// Compute sum of quantized values (for distance computation)
    #[must_use]
    pub fn quantized_sum(&self, quantized: &[u8]) -> i32 {
        quantized.iter().map(|&x| x as i32).sum()
    }

    /// Prepare query for fast integer SIMD distance computation
    #[must_use]
    pub fn prepare_query(&self, query: &[f32]) -> QueryPrep {
        debug_assert_eq!(query.len(), self.dimensions);

        let inv_scale = 1.0 / self.scale;
        let quantized: Vec<u8> = query
            .iter()
            .map(|&v| ((v - self.offset) * inv_scale).clamp(0.0, 255.0).round() as u8)
            .collect();

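        // The query norm is computed from the exact f32 query (not its
        // dequantized form), so the L2² estimate in `distance_l2_squared`
        // treats the query as exact and only the stored vector as quantized.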
        let norm_sq: f32 = query.iter().map(|x| x * x).sum();
        let sum: i32 = quantized.iter().map(|&x| x as i32).sum();

        QueryPrep {
            quantized,
            norm_sq,
            sum,
        }
    }

    /// Compute L2² distance using integer SIMD
    ///
    /// Uses the identity: ||q - v||² = ||q||² + ||v||² - 2⟨q,v⟩
    /// The dot product is computed in the integer domain for speed.
    #[inline(always)]
    #[must_use]
    pub fn distance_l2_squared(&self, query_prep: &QueryPrep, vec: &QuantizedVector) -> f32 {
        // Integer dot product (SIMD accelerated) - uses u8×u8→u32
        let int_dot = self.int_dot_product(&query_prep.quantized, &vec.data);

        // Reconstruct actual dot product: scale² × int_dot + corrections
        // dot(q, v) = scale² × Σ q_int[i] × v_int[i]
        //           + scale × offset × (Σ q_int[i] + Σ v_int[i])
        //           + offset² × dim
        let scale_sq = self.scale * self.scale;
        let dot = scale_sq * int_dot as f32
            + self.scale * self.offset * (query_prep.sum + vec.sum) as f32
            + self.offset * self.offset * self.dimensions as f32;

        // L2² = ||q||² + ||v||² - 2⟨q,v⟩
        query_prep.norm_sq + vec.norm_sq - 2.0 * dot
    }

    /// Compute L2² distance from raw bytes (for storage integration)
    ///
    /// Slightly slower than `distance_l2_squared` since it computes sum and norm on the fly.
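    ///
    /// # Example
    ///
    /// A hypothetical storage-side sketch (`params`, `prep`, and `vector` are
    /// assumed bindings; each helper called here is defined in this file):
    ///
    /// ```ignore
    /// let bytes = params.quantize_to_bytes(&vector);          // persisted per vector
    /// let sum = params.quantized_sum(&bytes);                 // persisted or recomputed
    /// let norm_sq = params.dequantized_norm_squared(&bytes);  // persisted or recomputed
    /// let dist = params.distance_l2_squared_raw(&prep, &bytes, sum, norm_sq);
    /// ```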
    #[inline(always)]
    #[must_use]
    pub fn distance_l2_squared_raw(
        &self,
        query_prep: &QueryPrep,
        vec_data: &[u8],
        vec_sum: i32,
        vec_norm_sq: f32,
    ) -> f32 {
        let int_dot = self.int_dot_product(&query_prep.quantized, vec_data);

        let scale_sq = self.scale * self.scale;
        let dot = scale_sq * int_dot as f32
            + self.scale * self.offset * (query_prep.sum + vec_sum) as f32
            + self.offset * self.offset * self.dimensions as f32;

        query_prep.norm_sq + vec_norm_sq - 2.0 * dot
    }

    /// Integer dot product with SIMD acceleration (u8 × u8 → u32)
    #[inline(always)]
    fn int_dot_product(&self, query: &[u8], vec: &[u8]) -> u32 {
        debug_assert_eq!(query.len(), vec.len());

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                return unsafe { self.int_dot_product_avx2(query, vec) };
            }
            Self::int_dot_product_scalar(query, vec)
        }

        #[cfg(target_arch = "aarch64")]
        {
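            // NEON is a mandatory AArch64 feature, so no runtime detection is needed.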
            unsafe { self.int_dot_product_neon(query, vec) }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::int_dot_product_scalar(query, vec)
        }
    }

    #[inline]
    #[allow(dead_code)]
    fn int_dot_product_scalar(query: &[u8], vec: &[u8]) -> u32 {
        query
            .iter()
            .zip(vec.iter())
            .map(|(&q, &v)| q as u32 * v as u32)
            .sum()
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[allow(clippy::unused_self)]
    unsafe fn int_dot_product_avx2(&self, query: &[u8], vec: &[u8]) -> u32 {
        let mut sum = _mm256_setzero_si256();
        let mut i = 0;

        // Main loop: 32 bytes per iteration. Zero-extend u8 -> i16, then
        // _mm256_madd_epi16 multiplies i16 pairs and sums adjacent products
        // into i32 lanes (255 * 255 * 2 fits comfortably in i32).
        while i + 32 <= query.len() {
            let q = _mm256_loadu_si256(query.as_ptr().add(i).cast());
            let v = _mm256_loadu_si256(vec.as_ptr().add(i).cast());

            let q_lo = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<0>(q));
            let q_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<1>(q));
            let v_lo = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<0>(v));
            let v_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<1>(v));

            let prod_lo = _mm256_madd_epi16(q_lo, v_lo);
            let prod_hi = _mm256_madd_epi16(q_hi, v_hi);
            sum = _mm256_add_epi32(sum, prod_lo);
            sum = _mm256_add_epi32(sum, prod_hi);

            i += 32;
        }

        // 16-byte tail: widen a single 128-bit load and multiply-accumulate.
        while i + 16 <= query.len() {
            let q = _mm256_cvtepu8_epi16(_mm_loadu_si128(query.as_ptr().add(i).cast()));
            let v = _mm256_cvtepu8_epi16(_mm_loadu_si128(vec.as_ptr().add(i).cast()));
            let prod = _mm256_madd_epi16(q, v);
            sum = _mm256_add_epi32(sum, prod);
            i += 16;
        }

        // Horizontal reduction: fold eight i32 lanes down to one scalar.
        let sum128 = _mm_add_epi32(
            _mm256_extracti128_si256::<0>(sum),
            _mm256_extracti128_si256::<1>(sum),
        );
        let sum64 = _mm_add_epi32(sum128, _mm_srli_si128::<8>(sum128));
        let sum32 = _mm_add_epi32(sum64, _mm_srli_si128::<4>(sum64));
        let mut result = _mm_cvtsi128_si32(sum32) as u32;

        // Scalar tail for lengths not divisible by 16.
        for j in i..query.len() {
            result += query[j] as u32 * vec[j] as u32;
        }

        result
    }

    #[cfg(target_arch = "aarch64")]
    #[inline(always)]
    #[allow(clippy::unused_self)]
    unsafe fn int_dot_product_neon(&self, query: &[u8], vec: &[u8]) -> u32 {
        use std::arch::aarch64::{
            vaddq_u32, vaddvq_u32, vdupq_n_u32, vget_low_u8, vld1q_u8, vmull_high_u8, vmull_u8,
            vpadalq_u16,
        };

        // Use 4 accumulators to hide latency and increase ILP
        let mut sum0 = vdupq_n_u32(0);
        let mut sum1 = vdupq_n_u32(0);
        let mut sum2 = vdupq_n_u32(0);
        let mut sum3 = vdupq_n_u32(0);
        let mut i = 0;

        // Process 64 elements per iteration (4x unrolling)
        while i + 64 <= query.len() {
            let q0 = vld1q_u8(query.as_ptr().add(i));
            let v0 = vld1q_u8(vec.as_ptr().add(i));
            let prod0_lo = vmull_u8(vget_low_u8(q0), vget_low_u8(v0));
            let prod0_hi = vmull_high_u8(q0, v0);
            sum0 = vpadalq_u16(sum0, prod0_lo);
            sum0 = vpadalq_u16(sum0, prod0_hi);

            let q1 = vld1q_u8(query.as_ptr().add(i + 16));
            let v1 = vld1q_u8(vec.as_ptr().add(i + 16));
            let prod1_lo = vmull_u8(vget_low_u8(q1), vget_low_u8(v1));
            let prod1_hi = vmull_high_u8(q1, v1);
            sum1 = vpadalq_u16(sum1, prod1_lo);
            sum1 = vpadalq_u16(sum1, prod1_hi);

            let q2 = vld1q_u8(query.as_ptr().add(i + 32));
            let v2 = vld1q_u8(vec.as_ptr().add(i + 32));
            let prod2_lo = vmull_u8(vget_low_u8(q2), vget_low_u8(v2));
            let prod2_hi = vmull_high_u8(q2, v2);
            sum2 = vpadalq_u16(sum2, prod2_lo);
            sum2 = vpadalq_u16(sum2, prod2_hi);

            let q3 = vld1q_u8(query.as_ptr().add(i + 48));
            let v3 = vld1q_u8(vec.as_ptr().add(i + 48));
            let prod3_lo = vmull_u8(vget_low_u8(q3), vget_low_u8(v3));
            let prod3_hi = vmull_high_u8(q3, v3);
            sum3 = vpadalq_u16(sum3, prod3_lo);
            sum3 = vpadalq_u16(sum3, prod3_hi);

            i += 64;
        }

        while i + 16 <= query.len() {
            let q = vld1q_u8(query.as_ptr().add(i));
            let v = vld1q_u8(vec.as_ptr().add(i));
            let prod_lo = vmull_u8(vget_low_u8(q), vget_low_u8(v));
            let prod_hi = vmull_high_u8(q, v);
            sum0 = vpadalq_u16(sum0, prod_lo);
            sum0 = vpadalq_u16(sum0, prod_hi);
            i += 16;
        }

        let sum01 = vaddq_u32(sum0, sum1);
        let sum23 = vaddq_u32(sum2, sum3);
        let sum_all = vaddq_u32(sum01, sum23);
        let mut result = vaddvq_u32(sum_all);

        for j in i..query.len() {
            result += query[j] as u32 * vec[j] as u32;
        }

        result
    }
}

/// Compute symmetric L2² distance between two quantized vectors
///
/// The result is in the quantized integer domain; multiply by `scale²` to
/// approximate the corresponding f32 distance (the global offsets cancel).
#[inline]
#[must_use]
pub fn symmetric_l2_squared_u8(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = (i16::from(x) - i16::from(y)) as i32;
            (diff * diff) as u32
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_and_quantize() {
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.5, 1.0, 0.3],
            vec![0.1, 0.6, 0.9, 0.4],
            vec![0.2, 0.4, 0.8, 0.5],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();

        let params = ScalarParams::train(&refs).unwrap();

        let quantized = params.quantize(&vectors[0]);
        assert_eq!(quantized.data.len(), 4);
        assert!(quantized.sum > 0);
        assert!(quantized.norm_sq > 0.0);
    }

    #[test]
    fn test_distance_accuracy() {
        use rand::Rng;

        let dim = 128;
        let n_vectors = 100;
        let mut rng = rand::thread_rng();

        // Generate normalized vectors (common in embeddings)
        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|_| {
                let v: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
                let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                v.iter().map(|x| x / norm).collect()
            })
            .collect();

        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
        let params = ScalarParams::train(&refs).unwrap();

        let quantized: Vec<_> = vectors.iter().map(|v| params.quantize(v)).collect();

        let query = &vectors[0];
        let query_prep = params.prepare_query(query);

        let mut max_rel_error = 0.0f32;

        for (i, (orig, quant)) in vectors.iter().zip(quantized.iter()).enumerate() {
            if i == 0 {
                continue;
            }

            let true_dist: f32 = query
                .iter()
                .zip(orig.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();

            let quant_dist = params.distance_l2_squared(&query_prep, quant);

            let rel_error = (true_dist - quant_dist).abs() / true_dist.max(1e-6);
            max_rel_error = max_rel_error.max(rel_error);
        }

        println!(
            "SQ8 max relative distance error: {:.2}%",
            max_rel_error * 100.0
        );
        assert!(
            max_rel_error < 0.15,
            "Distance error too large: {max_rel_error:.4}"
        );
    }

    #[test]
    fn test_int_dot_product() {
        let vectors: Vec<Vec<f32>> = vec![vec![0.5; 768], vec![0.3; 768]];
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();

        let params = ScalarParams::train(&refs).unwrap();
        let query_prep = params.prepare_query(&vectors[0]);
        let quantized = params.quantize(&vectors[1]);

        let dist = params.distance_l2_squared(&query_prep, &quantized);
        assert!(dist >= 0.0);
        assert!(!dist.is_nan());
    }

    #[test]
    fn test_dequantize_roundtrip() {
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.5, 1.0],
            vec![0.1, 0.6, 0.9],
            vec![0.2, 0.4, 0.8],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();

        let params = ScalarParams::train(&refs).unwrap();
        let quantized = params.quantize(&vectors[0]);
        let dequantized = params.dequantize(&quantized.data);

        for (orig, deq) in vectors[0].iter().zip(dequantized.iter()) {
            assert!(
                (orig - deq).abs() < 0.05,
                "Roundtrip error too large: {orig} vs {deq}"
            );
        }
    }

    #[test]
    fn test_symmetric_distance() {
        let a: Vec<u8> = vec![0, 100, 200, 255];
        let b: Vec<u8> = vec![0, 100, 200, 255];
        let dist = symmetric_l2_squared_u8(&a, &b);
        assert_eq!(dist, 0);

        let c: Vec<u8> = vec![10, 110, 210, 245];
        let dist2 = symmetric_l2_squared_u8(&a, &c);
        assert!(dist2 > 0);
    }

    #[test]
    fn test_compression_ratio() {
        let dims = 768;
        let original_size = dims * 4; // f32 = 4 bytes
        let quantized_size = dims; // u8 = 1 byte

        assert_eq!(original_size / quantized_size, 4);
    }
}