// reddb_server/storage/engine/int8_quantize.rs
//! int8 quantization for vector embeddings.
//!
//! Compresses fp32 vectors to int8 (8 bits per dimension) for compact
//! storage and fast approximate distance computation.
//!
//! # Compression Ratio
//!
//! - fp32: 4 bytes per dimension
//! - int8: 1 byte per dimension = 4x compression
//!
//! Example: 1024-dim vector
//! - fp32: 4096 bytes
//! - int8: 1024 bytes
//!
//! # Quantization Methods
//!
//! ## Symmetric Quantization (implemented by this module)
//! Maps [-max_abs, +max_abs] → [-127, +127]
//! - scale = max(|v|) / 127
//! - quantized = round(v / scale)
//!
//! ## Asymmetric Quantization (documented for reference; not yet implemented here)
//! Maps [min, max] → [0, 255]
//! - scale = (max - min) / 255
//! - zero_point = round(-min / scale)
//! - quantized = round(v / scale) + zero_point
//!
//! # Usage
//!
//! ```ignore
//! // Quantize a vector (symmetric)
//! let int8 = Int8Vector::from_f32(&embedding);
//!
//! // Compute dot product (SIMD accelerated)
//! let dot = int8.dot_product(&other);
//!
//! // Rescore binary search candidates (on an Int8Index)
//! let rescored = index.rescore_candidates(&binary_results, &query);
//! ```

use std::cmp::Ordering;

/// An int8-quantized vector together with its dequantization scale
/// and the original vector's L2 norm.
#[derive(Clone, Debug)]
pub struct Int8Vector {
    /// Quantized values; symmetric quantization clamps them to -127..=127.
    data: Vec<i8>,
    /// Scale factor: original value ≈ data[i] as f32 * scale.
    scale: f32,
    /// L2 norm of the original fp32 vector (used by `cosine_distance`).
    norm: f32,
}
53
54impl Int8Vector {
55    /// Create int8 vector from fp32 using symmetric quantization
56    ///
57    /// Best for normalized embeddings centered around 0.
58    pub fn from_f32(values: &[f32]) -> Self {
59        if values.is_empty() {
60            return Self {
61                data: Vec::new(),
62                scale: 1.0,
63                norm: 0.0,
64            };
65        }
66
67        // Find maximum absolute value
68        let max_abs = values
69            .iter()
70            .map(|v| v.abs())
71            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
72            .unwrap_or(1.0);
73
74        // Compute scale (avoid division by zero)
75        let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
76
77        // Quantize
78        let data: Vec<i8> = values
79            .iter()
80            .map(|&v| {
81                let quantized = (v / scale).round();
82                quantized.clamp(-127.0, 127.0) as i8
83            })
84            .collect();
85
86        // Compute original norm
87        let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
88
89        Self { data, scale, norm }
90    }
91
92    /// Create int8 vector with pre-computed scale
93    pub fn from_f32_with_scale(values: &[f32], scale: f32) -> Self {
94        let data: Vec<i8> = values
95            .iter()
96            .map(|&v| {
97                let quantized = (v / scale).round();
98                quantized.clamp(-127.0, 127.0) as i8
99            })
100            .collect();
101
102        let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
103
104        Self { data, scale, norm }
105    }
106
107    /// Create from raw quantized data
108    pub fn from_raw(data: Vec<i8>, scale: f32, norm: f32) -> Self {
109        Self { data, scale, norm }
110    }
111
112    /// Get the dimensionality
113    #[inline]
114    pub fn dim(&self) -> usize {
115        self.data.len()
116    }
117
118    /// Get the quantized data
119    #[inline]
120    pub fn data(&self) -> &[i8] {
121        &self.data
122    }
123
124    /// Get the scale factor
125    #[inline]
126    pub fn scale(&self) -> f32 {
127        self.scale
128    }
129
130    /// Get size in bytes
131    #[inline]
132    pub fn size_bytes(&self) -> usize {
133        self.data.len() + 8 // data + scale (f32) + norm (f32)
134    }
135
136    /// Dequantize to fp32
137    pub fn to_f32(&self) -> Vec<f32> {
138        self.data.iter().map(|&v| v as f32 * self.scale).collect()
139    }
140
141    /// Compute dot product with another int8 vector
142    ///
143    /// Returns scaled dot product in fp32.
144    #[inline]
145    pub fn dot_product(&self, other: &Self) -> f32 {
146        debug_assert_eq!(self.data.len(), other.data.len(), "Dimensions must match");
147
148        let raw_dot = dot_product_i8_simd(&self.data, &other.data);
149        raw_dot as f32 * self.scale * other.scale
150    }
151
152    /// Compute dot product with fp32 query (asymmetric)
153    ///
154    /// Query stays in fp32 for better precision.
155    /// This is the recommended approach for rescoring.
156    #[inline]
157    pub fn dot_product_f32(&self, query: &[f32]) -> f32 {
158        debug_assert_eq!(self.data.len(), query.len(), "Dimensions must match");
159
160        dot_product_i8_f32_simd(&self.data, query) * self.scale
161    }
162
163    /// Compute L2 squared distance to another int8 vector
164    #[inline]
165    pub fn l2_squared(&self, other: &Self) -> f32 {
166        debug_assert_eq!(self.data.len(), other.data.len(), "Dimensions must match");
167
168        let raw_dist = l2_squared_i8_simd(&self.data, &other.data);
169        raw_dist as f32 * self.scale * other.scale
170    }
171
172    /// Compute cosine distance using normalized dot product
173    ///
174    /// Assumes vectors were normalized before quantization.
175    #[inline]
176    pub fn cosine_distance(&self, other: &Self) -> f32 {
177        let dot = self.dot_product(other);
178        let denom = self.norm * other.norm;
179        if denom > 0.0 {
180            1.0 - (dot / denom)
181        } else {
182            1.0
183        }
184    }
185}
186
187// ============================================================================
188// SIMD Operations
189// ============================================================================
190
/// Compute the dot product of two i8 vectors, using SIMD when available.
///
/// Runtime-dispatches to AVX2, then SSE4.1, then a scalar fallback.
/// Lengths must match; the SIMD kernels index `b` up to `a.len()`.
#[inline]
pub fn dot_product_i8_simd(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel is called only after its CPU feature is detected.
        if is_x86_feature_detected!("avx2") {
            return unsafe { dot_product_i8_avx2(a, b) };
        }
        if is_x86_feature_detected!("sse4.1") {
            return unsafe { dot_product_i8_sse4(a, b) };
        }
    }

    dot_product_i8_scalar(a, b)
}
208
/// Portable scalar fallback for the i8 dot product.
///
/// Each element is widened to i32 before multiplying, so individual
/// products cannot overflow.
#[inline]
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| i32::from(x) * i32::from(y))
        .sum()
}
218
/// AVX2 implementation of the i8 dot product.
///
/// Sign-extends i8 → i16 per 128-bit half, then uses `_mm256_madd_epi16`
/// (multiply adjacent i16 pairs, sum each pair into an i32 lane) to
/// accumulate; i8×i8 products fit comfortably in i32.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2. Both slices are indexed up
/// to `a.len()`, so the caller must pass equal-length slices (the
/// dispatcher debug-asserts this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    // Process 32 i8 lanes per iteration.
    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);

        // Split into low and high 128-bit lanes and sign-extend to i16
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        // Multiply i16 pairs and horizontally add each pair into i32 lanes
        let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);

        sum = _mm256_add_epi32(sum, prod_lo);
        sum = _mm256_add_epi32(sum, prod_hi);
    }

    // Horizontal reduction: fold 8 i32 lanes down to one scalar
    let sum128 = _mm_add_epi32(
        _mm256_castsi256_si128(sum),
        _mm256_extracti128_si256(sum, 1),
    );
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let mut result = _mm_cvtsi128_si32(sum32);

    // Scalar tail for lengths not divisible by 32
    for i in (chunks * 32)..len {
        result += (a[i] as i32) * (b[i] as i32);
    }

    result
}
266
/// SSE4.1 implementation of the i8 dot product.
///
/// Same strategy as the AVX2 path, but over 128-bit registers:
/// sign-extend i8 → i16, then `_mm_madd_epi16` multiplies adjacent i16
/// pairs and sums each pair into an i32 lane.
///
/// # Safety
/// Caller must ensure the CPU supports SSE4.1. Both slices are indexed
/// up to `a.len()`, so the caller must pass equal-length slices (the
/// dispatcher debug-asserts this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
#[inline]
unsafe fn dot_product_i8_sse4(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm_setzero_si128();

    // Process 16 i8 lanes per iteration.
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);

        // Sign-extend the low and high 8 bytes to i16
        let va_lo = _mm_cvtepi8_epi16(va);
        let va_hi = _mm_cvtepi8_epi16(_mm_srli_si128(va, 8));
        let vb_lo = _mm_cvtepi8_epi16(vb);
        let vb_hi = _mm_cvtepi8_epi16(_mm_srli_si128(vb, 8));

        // Multiply i16 pairs and accumulate into four i32 lanes
        let prod_lo = _mm_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm_madd_epi16(va_hi, vb_hi);

        sum = _mm_add_epi32(sum, prod_lo);
        sum = _mm_add_epi32(sum, prod_hi);
    }

    // Horizontal reduction of the four i32 lanes
    let sum64 = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let mut result = _mm_cvtsi128_si32(sum32);

    // Scalar tail for lengths not divisible by 16
    for i in (chunks * 16)..len {
        result += (a[i] as i32) * (b[i] as i32);
    }

    result
}
310
311/// Compute dot product of i8 vector with f32 query (asymmetric)
312#[inline]
313pub fn dot_product_i8_f32_simd(a: &[i8], b: &[f32]) -> f32 {
314    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
315
316    #[cfg(target_arch = "x86_64")]
317    {
318        if is_x86_feature_detected!("avx2") {
319            return unsafe { dot_product_i8_f32_avx2(a, b) };
320        }
321    }
322
323    dot_product_i8_f32_scalar(a, b)
324}
325
/// Portable scalar fallback for the mixed i8 × f32 dot product.
#[inline]
fn dot_product_i8_f32_scalar(a: &[i8], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| f32::from(x) * y)
        .sum()
}
335
/// AVX2 implementation of the mixed i8 × f32 dot product.
///
/// Widens 8 i8 values per iteration (i8 → i16 → i32 → f32) and
/// accumulates against the f32 query with a fused multiply-add.
///
/// NOTE(review): `_mm256_fmadd_ps` is an FMA instruction, but only
/// "avx2" is enabled on this function. Ensure runtime dispatch checks
/// both "avx2" and "fma" before calling — on a CPU with AVX2 but no
/// FMA this is undefined behavior.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (see note), and
/// must pass equal-length slices (`b` is indexed up to `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_i8_f32_avx2(a: &[i8], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 elements per iteration.
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;

        // Load 8 i8 values (64 bits) and widen to 8 f32 lanes
        let va_i8 = _mm_loadl_epi64(a.as_ptr().add(idx) as *const __m128i);
        let va_i16 = _mm_cvtepi8_epi16(va_i8);
        let va_i32 = _mm256_cvtepi16_epi32(va_i16);
        let va_f32 = _mm256_cvtepi32_ps(va_i32);

        // Load the matching 8 f32 query values
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        // sum += va * vb (fused multiply-add)
        sum = _mm256_fmadd_ps(va_f32, vb, sum);
    }

    // Horizontal reduction: fold 8 f32 lanes down to one scalar
    let sum128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1));
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    let mut result = _mm_cvtss_f32(sum32);

    // Scalar tail for lengths not divisible by 8
    for i in (chunks * 8)..len {
        result += (a[i] as f32) * b[i];
    }

    result
}
377
/// Squared L2 distance between two i8 vectors, in raw integer units.
///
/// Scalar implementation for now (SIMD can be added later); element
/// differences are widened to i32 so the squared terms cannot overflow.
#[inline]
pub fn l2_squared_i8_simd(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let d = i32::from(x) - i32::from(y);
            d * d
        })
        .sum()
}
391
392// ============================================================================
393// Storage Index
394// ============================================================================
395
/// Flat, append-only store of int8 vectors for batch operations.
///
/// All quantized vectors share one contiguous buffer: vector `i`
/// occupies `vectors[i * dim .. (i + 1) * dim]`, with its scale and
/// norm stored in the parallel `scales`/`norms` arrays.
#[derive(Clone)]
pub struct Int8Index {
    /// All quantized vectors, flattened row-major.
    vectors: Vec<i8>,
    /// Per-vector scale factor for dequantization.
    scales: Vec<f32>,
    /// Per-vector original (pre-quantization) L2 norm.
    norms: Vec<f32>,
    /// Dimensionality shared by every vector in the index.
    dim: usize,
    /// Number of vectors currently stored.
    n_vectors: usize,
}
410
411impl Int8Index {
412    /// Create a new int8 index
413    pub fn new(dim: usize) -> Self {
414        Self {
415            vectors: Vec::new(),
416            scales: Vec::new(),
417            norms: Vec::new(),
418            dim,
419            n_vectors: 0,
420        }
421    }
422
423    /// Create with pre-allocated capacity
424    pub fn with_capacity(dim: usize, capacity: usize) -> Self {
425        Self {
426            vectors: Vec::with_capacity(capacity * dim),
427            scales: Vec::with_capacity(capacity),
428            norms: Vec::with_capacity(capacity),
429            dim,
430            n_vectors: 0,
431        }
432    }
433
434    /// Add a vector to the index
435    pub fn add(&mut self, vector: &Int8Vector) {
436        debug_assert_eq!(vector.dim(), self.dim, "Dimension mismatch");
437        self.vectors.extend_from_slice(&vector.data);
438        self.scales.push(vector.scale);
439        self.norms.push(vector.norm);
440        self.n_vectors += 1;
441    }
442
443    /// Add a fp32 vector (will be quantized)
444    pub fn add_f32(&mut self, vector: &[f32]) {
445        let int8 = Int8Vector::from_f32(vector);
446        self.add(&int8);
447    }
448
449    /// Get number of vectors
450    #[inline]
451    pub fn len(&self) -> usize {
452        self.n_vectors
453    }
454
455    /// Check if empty
456    #[inline]
457    pub fn is_empty(&self) -> bool {
458        self.n_vectors == 0
459    }
460
461    /// Get memory usage in bytes
462    pub fn memory_bytes(&self) -> usize {
463        self.vectors.len() + self.scales.len() * 4 + self.norms.len() * 4
464    }
465
466    /// Get a vector by index
467    pub fn get(&self, idx: usize) -> Option<Int8Vector> {
468        if idx >= self.n_vectors {
469            return None;
470        }
471        let start = idx * self.dim;
472        let end = start + self.dim;
473        Some(Int8Vector::from_raw(
474            self.vectors[start..end].to_vec(),
475            self.scales[idx],
476            self.norms[idx],
477        ))
478    }
479
480    /// Get raw data slice for a vector
481    #[inline]
482    pub fn get_data(&self, idx: usize) -> &[i8] {
483        let start = idx * self.dim;
484        let end = start + self.dim;
485        &self.vectors[start..end]
486    }
487
488    /// Compute dot product with fp32 query for a specific vector
489    #[inline]
490    pub fn dot_product_f32(&self, idx: usize, query: &[f32]) -> f32 {
491        let data = self.get_data(idx);
492        let scale = self.scales[idx];
493        dot_product_i8_f32_simd(data, query) * scale
494    }
495
496    /// Rescore candidates from binary search using int8 dot product
497    ///
498    /// Takes (index, hamming_distance) pairs and returns (index, rescored_distance).
499    pub fn rescore_candidates(
500        &self,
501        candidates: &[(usize, u32)],
502        query: &[f32],
503    ) -> Vec<(usize, f32)> {
504        let mut results: Vec<(usize, f32)> = candidates
505            .iter()
506            .filter_map(|&(idx, _)| {
507                if idx < self.n_vectors {
508                    // Use negative dot product as distance (higher dot = lower distance)
509                    let dot = self.dot_product_f32(idx, query);
510                    Some((idx, -dot))
511                } else {
512                    None
513                }
514            })
515            .collect();
516
517        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
518        results
519    }
520}
521
522// ============================================================================
523// Tests
524// ============================================================================
525
#[cfg(test)]
mod tests {
    use super::*;

    // Symmetric quantization: extremes map onto ±127, zero onto 0.
    #[test]
    fn test_int8_quantization() {
        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0];
        let quantized = Int8Vector::from_f32(&input);

        // max |v| = 1.0, so scale = 1.0 / 127 ≈ 0.00787.
        assert_eq!(quantized.data[0], 127);
        assert_eq!(quantized.data[1], -127);
        assert_eq!(quantized.data[4], 0);
    }

    // Self dot product should be close to the exact fp32 value.
    #[test]
    fn test_dot_product_identical() {
        let lhs = Int8Vector::from_f32(&[1.0, 2.0, 3.0, 4.0]);
        let rhs = Int8Vector::from_f32(&[1.0, 2.0, 3.0, 4.0]);

        let expected = 1.0 + 4.0 + 9.0 + 16.0; // 30.0
        // Allow slack for quantization error.
        assert!((lhs.dot_product(&rhs) - expected).abs() < 1.0);
    }

    // Asymmetric path: the query stays in fp32.
    #[test]
    fn test_dot_product_f32() {
        let quantized = Int8Vector::from_f32(&[1.0, 0.0, -1.0, 0.5]);
        let query = vec![1.0, 1.0, 1.0, 1.0];

        // 1*1 + 0*1 + (-1)*1 + 0.5*1 = 0.5
        assert!((quantized.dot_product_f32(&query) - 0.5).abs() < 0.1);
    }

    // 1024 dims: fp32 needs 4096 bytes, int8 needs 1024 + 8 = 1032 (~4x).
    #[test]
    fn test_compression_ratio() {
        let fp32_size = 1024 * 4;
        let quantized = Int8Vector::from_f32(&vec![1.0; 1024]);

        assert_eq!(quantized.size_bytes(), 1032);
        assert!(fp32_size / quantized.size_bytes() >= 3); // at least 3x
    }

    // Rescoring ranks the vector most aligned with the query first.
    #[test]
    fn test_index_rescore() {
        let mut index = Int8Index::new(4);
        index.add_f32(&[1.0, 0.0, 0.0, 0.0]);
        index.add_f32(&[0.0, 1.0, 0.0, 0.0]);
        index.add_f32(&[0.0, 0.0, 1.0, 0.0]);

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let candidates = vec![(0, 10), (1, 20), (2, 30)]; // simulated binary hits

        let rescored = index.rescore_candidates(&candidates, &query);

        // Highest dot product ⇒ lowest distance ⇒ first result.
        assert_eq!(rescored[0].0, 0);
    }

    // The SIMD dispatch must agree exactly with the scalar reference.
    #[test]
    fn test_simd_vs_scalar() {
        let a: Vec<i8> = (0..128).map(|i| (i % 127) as i8).collect();
        let b: Vec<i8> = (0..128).map(|i| ((127 - i) % 127) as i8).collect();

        let scalar = dot_product_i8_scalar(&a, &b);

        #[cfg(target_arch = "x86_64")]
        {
            assert_eq!(scalar, dot_product_i8_simd(&a, &b));
        }
    }
}