Skip to main content

lean_ctx/core/
embedding_quant.rs

//! int8 scalar quantization + SIMD-friendly scoring for embedding vectors.
//!
//! Adapted from the TurboQuant approach (RyanCodrai/turbovec, ICLR 2026): a
//! data-oblivious, training-free, single-pass quantizer. At lean-ctx's scale
//! (hundreds of facts × 384-dim MiniLM vectors) the win is twofold:
//!   1. **~4× smaller** on-disk knowledge index (`i8` codes vs `f32`).
//!   2. **Faster scoring** — the full-precision query is accumulated directly
//!      against the `i8` codes and the per-vector `scale` is applied once at
//!      the end, so the full `f32` document vector is never reconstructed
//!      (turbovec's core idea).
//!
//! No heavy SIMD crate is pulled in: the chunked-lane accumulators below are
//! shaped so the autovectorizer can emit NEON/AVX code, and a scalar tail
//! keeps results correct for dimensions that are not a multiple of the lane
//! width.

15use serde::{Deserialize, Serialize};
16
/// Largest code magnitude used by the symmetric `i8` mapping. Codes span
/// −127..=127; `i8::MIN` (−128) is deliberately left unused so the mapping
/// stays symmetric around zero and has no lopsided overflow edge.
const I8_ABS_MAX: f32 = i8::MAX as f32;

/// Width of the chunked accumulators: the number of `f32` lanes in one 256-bit
/// AVX2 register (equivalently, two 128-bit NEON registers). Inputs whose
/// length is not a multiple of this fall through to a scalar tail, so any
/// dimension stays correct (a 384-dim vector splits into exactly 48 chunks).
const LANES: usize = 256 / (8 * std::mem::size_of::<f32>());

26/// A vector stored as int8 codes plus the per-vector scale needed to reconstruct
27/// approximate values: `value[i] ≈ code[i] · scale`.
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29pub struct QuantizedVector {
30    pub code: Vec<i8>,
31    pub scale: f32,
32}
33
34impl QuantizedVector {
35    #[must_use]
36    pub fn dim(&self) -> usize {
37        self.code.len()
38    }
39
40    /// Reconstruct the approximate `f32` vector. Only needed for diagnostics /
41    /// migration; the hot path scores against the codes directly via [`dot_quant`].
42    #[must_use]
43    pub fn dequantize(&self) -> Vec<f32> {
44        self.code
45            .iter()
46            .map(|&c| f32::from(c) * self.scale)
47            .collect()
48    }
49}
50
51/// Symmetric, per-vector quantization: `scale = max|x| / 127`, `code = round(x / scale)`.
52///
53/// Data-oblivious (no codebook training) and single-pass. A zero vector maps to
54/// all-zero codes with `scale = 0.0`, which [`dot_quant`] treats as a zero result.
55#[must_use]
56pub fn quantize(v: &[f32]) -> QuantizedVector {
57    let max_abs = v.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
58    if max_abs == 0.0 {
59        return QuantizedVector {
60            code: vec![0; v.len()],
61            scale: 0.0,
62        };
63    }
64    let scale = max_abs / I8_ABS_MAX;
65    let inv = 1.0 / scale;
66    let code = v
67        .iter()
68        .map(|&x| {
69            // round-half-away then clamp into the symmetric range before the cast.
70            let q = (x * inv).round().clamp(-I8_ABS_MAX, I8_ABS_MAX);
71            q as i8
72        })
73        .collect();
74    QuantizedVector { code, scale }
75}
76
77/// Asymmetric dot product: full-precision `query` · quantized `doc`.
78///
79/// Computes `Σ query[i] · code[i] · scale` without ever reconstructing the doc
80/// vector. For L2-normalized inputs this approximates cosine similarity; the
81/// quantization error is well within embedding-ranking tolerance.
82#[must_use]
83pub fn dot_quant(query: &[f32], doc: &QuantizedVector) -> f32 {
84    debug_assert_eq!(query.len(), doc.code.len(), "dim mismatch");
85    if doc.scale == 0.0 {
86        return 0.0;
87    }
88
89    let mut lanes = [0.0f32; LANES];
90    let mut q_chunks = query.chunks_exact(LANES);
91    let mut c_chunks = doc.code.chunks_exact(LANES);
92
93    for (q, c) in q_chunks.by_ref().zip(c_chunks.by_ref()) {
94        for i in 0..LANES {
95            lanes[i] += q[i] * f32::from(c[i]);
96        }
97    }
98
99    let mut tail = 0.0f32;
100    for (q, c) in q_chunks.remainder().iter().zip(c_chunks.remainder()) {
101        tail += q * f32::from(*c);
102    }
103
104    (lanes.iter().sum::<f32>() + tail) * doc.scale
105}
106
107/// SIMD-friendly `f32` dot product with chunked lane accumulators.
108///
109/// Numerically a hair different from a naïve left-fold (float add is
110/// non-associative) but far within similarity tolerance, and materially faster
111/// on the 384-dim vectors used for semantic recall.
112#[must_use]
113pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
114    debug_assert_eq!(a.len(), b.len(), "dim mismatch");
115
116    let mut lanes = [0.0f32; LANES];
117    let mut a_chunks = a.chunks_exact(LANES);
118    let mut b_chunks = b.chunks_exact(LANES);
119
120    for (x, y) in a_chunks.by_ref().zip(b_chunks.by_ref()) {
121        for i in 0..LANES {
122            lanes[i] += x[i] * y[i];
123        }
124    }
125
126    let mut tail = 0.0f32;
127    for (x, y) in a_chunks.remainder().iter().zip(b_chunks.remainder()) {
128        tail += x * y;
129    }
130
131    lanes.iter().sum::<f32>() + tail
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference left-to-right dot product used to check the chunked kernels.
    fn naive_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn dot_f32_matches_naive_within_tolerance() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32 * 0.013).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32 * 0.017).cos()).collect();
        let chunked = dot_f32(&a, &b);
        let naive = naive_dot(&a, &b);
        assert!(
            (chunked - naive).abs() < 1e-3,
            "chunked={chunked} naive={naive}"
        );
    }

    #[test]
    fn dot_f32_handles_non_multiple_of_lane_width() {
        // 13 is not a multiple of LANES (8) → exercises the scalar tail.
        let a: Vec<f32> = (0..13).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..13).map(|i| (i * 2) as f32).collect();
        assert!((dot_f32(&a, &b) - naive_dot(&a, &b)).abs() < 1e-4);
    }

    #[test]
    fn empty_inputs_score_zero() {
        let q = quantize(&[]);
        assert!(q.code.is_empty());
        assert_eq!(q.scale, 0.0);
        assert_eq!(q.dim(), 0);
        assert_eq!(dot_quant(&[], &q), 0.0);
        assert_eq!(dot_f32(&[], &[]), 0.0);
    }

    #[test]
    fn quantize_zero_vector_is_zero_scale() {
        let q = quantize(&[0.0, 0.0, 0.0]);
        assert_eq!(q.scale, 0.0);
        assert_eq!(q.code, vec![0, 0, 0]);
        assert_eq!(dot_quant(&[1.0, 1.0, 1.0], &q), 0.0);
    }

    #[test]
    fn quantize_preserves_max_magnitude_at_full_scale() {
        let q = quantize(&[1.0, 0.0, 0.0]);
        assert_eq!(q.code[0], 127);
        assert_eq!(q.code[1], 0);
        // Reconstructed peak is the original max (scale = max/127, code = 127).
        let recon = q.dequantize();
        assert!((recon[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn dequantize_error_is_at_most_half_a_step() {
        let v: Vec<f32> = (0..50).map(|i| (i as f32 * 0.37).sin() * 2.5).collect();
        let q = quantize(&v);
        assert_eq!(q.dim(), v.len());
        for (orig, recon) in v.iter().zip(q.dequantize()) {
            assert!(
                (orig - recon).abs() <= q.scale * 0.5 + 1e-6,
                "orig={orig} recon={recon} scale={}",
                q.scale
            );
        }
    }

    #[test]
    fn dot_quant_handles_non_multiple_of_lane_width() {
        // 13 = one full LANES chunk + a 5-element scalar tail, so both the
        // lane accumulators and the tail contribute.
        let query: Vec<f32> = (0..13).map(|i| i as f32 * 0.1).collect();
        let doc: Vec<f32> = (0..13).map(|i| (i as f32 - 6.0) * 0.5).collect();
        let q = quantize(&doc);
        let expected = naive_dot(&query, &q.dequantize());
        let got = dot_quant(&query, &q);
        assert!(
            (got - expected).abs() < 1e-3,
            "got={got} expected={expected}"
        );
    }

    #[test]
    fn dot_quant_approximates_cosine_for_normalized_vectors() {
        // Two similar L2-normalized vectors: quantized dot must track the true dot.
        let mut a: Vec<f32> = (0..384).map(|i| (i as f32 * 0.011).sin() + 0.3).collect();
        let mut b: Vec<f32> = a.iter().map(|x| x + 0.02).collect();
        l2_normalize(&mut a);
        l2_normalize(&mut b);

        let exact = naive_dot(&a, &b);
        let approx = dot_quant(&a, &quantize(&b));
        assert!(
            (exact - approx).abs() < 5e-3,
            "exact={exact} approx={approx}"
        );
        // Self-similarity stays ~1.0 after quantization.
        let self_sim = dot_quant(&a, &quantize(&a));
        assert!(self_sim > 0.99, "self_sim={self_sim}");
    }

    #[test]
    fn dot_quant_preserves_ranking() {
        // The most similar doc must still score highest after quantization.
        let query = normalized(vec![1.0, 0.2, 0.0, 0.1]);
        let near = quantize(&normalized(vec![0.9, 0.3, 0.0, 0.1]));
        let far = quantize(&normalized(vec![-0.5, 0.8, 0.2, 0.0]));
        assert!(dot_quant(&query, &near) > dot_quant(&query, &far));
    }

    /// Scale `v` to unit L2 norm in place; a zero vector is left unchanged.
    fn l2_normalize(v: &mut [f32]) {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    /// Owned-value convenience wrapper around [`l2_normalize`].
    fn normalized(mut v: Vec<f32>) -> Vec<f32> {
        l2_normalize(&mut v);
        v
    }
}