Skip to main content

nodedb_vector/quantize/
sq8.rs

//! Scalar Quantization (SQ8): FP32 → 8-bit codes, per dimension.
//!
//! Each dimension is independently quantized to an unsigned code in `[0, 255]`
//! (`u8`) using per-dimension min/max calibration. This is the **default
//! production quantization** for HNSW traversal: a 4x RAM reduction, typically
//! with <1% recall loss (dataset-dependent).
//!
//! Distance computation uses asymmetric mode: the query stays in FP32 while
//! candidates are stored as 8-bit codes. This avoids quantizing the query and
//! preserves accuracy, at the cost of dequantizing each candidate dimension
//! during distance computation.
//!
//! Storage: D bytes per vector (vs 4D bytes for FP32).
13
14use serde::{Deserialize, Serialize};
15
16/// SQ8 calibration parameters: per-dimension min/max.
17#[derive(Clone, Serialize, Deserialize)]
18pub struct Sq8Codec {
19    pub dim: usize,
20    /// Per-dimension minimum observed during calibration.
21    mins: Vec<f32>,
22    /// Per-dimension maximum observed during calibration.
23    maxs: Vec<f32>,
24    /// Pre-computed per-dimension scale: `(max - min) / 255.0`.
25    /// Zero if max == min (constant dimension → all quantize to 0).
26    scales: Vec<f32>,
27    /// Pre-computed per-dimension inverse scale: `255.0 / (max - min)`.
28    inv_scales: Vec<f32>,
29}
30
31impl Sq8Codec {
32    /// Calibrate min/max from a set of training vectors.
33    ///
34    /// Scans all vectors to find per-dimension min/max bounds.
35    /// At least 1000 vectors recommended for stable calibration;
36    /// for fewer vectors the bounds may be tight, causing clipping
37    /// on future inserts outside the calibration range.
38    pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
39        assert!(!vectors.is_empty(), "cannot calibrate on empty set");
40        assert!(dim > 0);
41
42        let mut mins = vec![f32::MAX; dim];
43        let mut maxs = vec![f32::MIN; dim];
44
45        for v in vectors {
46            debug_assert_eq!(v.len(), dim);
47            for d in 0..dim {
48                if v[d] < mins[d] {
49                    mins[d] = v[d];
50                }
51                if v[d] > maxs[d] {
52                    maxs[d] = v[d];
53                }
54            }
55        }
56
57        let mut scales = vec![0.0f32; dim];
58        let mut inv_scales = vec![0.0f32; dim];
59        for d in 0..dim {
60            let range = maxs[d] - mins[d];
61            if range > f32::EPSILON {
62                scales[d] = range / 255.0;
63                inv_scales[d] = 255.0 / range;
64            }
65        }
66
67        Self {
68            dim,
69            mins,
70            maxs,
71            scales,
72            inv_scales,
73        }
74    }
75
76    /// Quantize a single FP32 vector to INT8.
77    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
78        debug_assert_eq!(vector.len(), self.dim);
79        let mut out = Vec::with_capacity(self.dim);
80        for ((&v, &min), (&max, &inv_scale)) in vector
81            .iter()
82            .zip(self.mins.iter())
83            .zip(self.maxs.iter().zip(self.inv_scales.iter()))
84        {
85            let clamped = v.clamp(min, max);
86            let q = ((clamped - min) * inv_scale).round() as u8;
87            out.push(q);
88        }
89        out
90    }
91
92    /// Batch quantize: quantize all vectors into a contiguous byte array.
93    ///
94    /// Returns `dim * N` bytes laid out as `[v0_d0, v0_d1, ..., v1_d0, ...]`.
95    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
96        let mut out = Vec::with_capacity(self.dim * vectors.len());
97        for v in vectors {
98            out.extend(self.quantize(v));
99        }
100        out
101    }
102
103    /// Dequantize INT8 back to FP32 (lossy reconstruction).
104    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
105        debug_assert_eq!(quantized.len(), self.dim);
106        let mut out = Vec::with_capacity(self.dim);
107        for ((&q, &min), &scale) in quantized
108            .iter()
109            .zip(self.mins.iter())
110            .zip(self.scales.iter())
111        {
112            out.push(min + q as f32 * scale);
113        }
114        out
115    }
116
117    /// Asymmetric L2 squared distance: query (FP32) vs candidate (INT8).
118    ///
119    /// This is the hot-path function used during HNSW traversal.
120    /// The query stays in full precision; only the candidate is quantized.
121    #[inline]
122    pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
123        debug_assert_eq!(query.len(), self.dim);
124        debug_assert_eq!(candidate.len(), self.dim);
125        let mut sum = 0.0f32;
126        for d in 0..self.dim {
127            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
128            let diff = query[d] - dequant;
129            sum += diff * diff;
130        }
131        sum
132    }
133
134    /// Asymmetric cosine distance: query (FP32) vs candidate (INT8).
135    #[inline]
136    pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
137        debug_assert_eq!(query.len(), self.dim);
138        debug_assert_eq!(candidate.len(), self.dim);
139        let mut dot = 0.0f32;
140        let mut norm_q = 0.0f32;
141        let mut norm_c = 0.0f32;
142        for d in 0..self.dim {
143            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
144            dot += query[d] * dequant;
145            norm_q += query[d] * query[d];
146            norm_c += dequant * dequant;
147        }
148        let denom = (norm_q * norm_c).sqrt();
149        if denom < f32::EPSILON {
150            return 1.0;
151        }
152        (1.0 - dot / denom).max(0.0)
153    }
154
155    /// Asymmetric negative inner product: query (FP32) vs candidate (INT8).
156    #[inline]
157    pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
158        debug_assert_eq!(query.len(), self.dim);
159        debug_assert_eq!(candidate.len(), self.dim);
160        let mut dot = 0.0f32;
161        for d in 0..self.dim {
162            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
163            dot += query[d] * dequant;
164        }
165        -dot
166    }
167
168    /// Dimension count.
169    pub fn dim(&self) -> usize {
170        self.dim
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 three-dimensional vectors: a linear ramp, a sine channel and a
    /// cosine channel.
    fn make_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
            .collect()
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, 3);

        for v in &vecs {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);
            for d in 0..3 {
                let error = (v[d] - dq[d]).abs();
                let range = codec.maxs[d] - codec.mins[d];
                // Round-to-nearest bounds the reconstruction error by half a
                // quantization step; the small slack absorbs f32 rounding.
                let half_step = range / 510.0;
                assert!(
                    error <= half_step + 1e-5,
                    "d={d}: error={error}, half_step={half_step}"
                );
            }
        }
    }

    #[test]
    fn asymmetric_l2_close_to_exact() {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, 3);

        let query = &[5.0, 0.5, -0.5];
        for v in &vecs {
            let q = codec.quantize(v);
            let exact = crate::distance::l2_squared(query, v);
            let approx = codec.asymmetric_l2(query, &q);
            // Allow up to 5% relative error (absolute error for tiny distances).
            let rel_error = if exact > 0.01 {
                (exact - approx).abs() / exact
            } else {
                (exact - approx).abs()
            };
            assert!(
                rel_error < 0.05 || (exact - approx).abs() < 0.1,
                "exact={exact}, approx={approx}, rel_error={rel_error}"
            );
        }
    }

    #[test]
    fn batch_quantize() {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, 3);

        let batch = codec.quantize_batch(&refs);
        assert_eq!(batch.len(), 3 * 100);

        // First vector should match individual quantize.
        let single = codec.quantize(&vecs[0]);
        assert_eq!(&batch[0..3], &single[..]);
    }

    #[test]
    fn constant_dimension_handled() {
        // All vectors have the same value in dimension 0.
        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, 2);

        // Constant dimension should quantize to 0 without NaN/inf.
        let q = codec.quantize(&[5.0, 3.0]);
        assert_eq!(q[0], 0); // constant dim
        // ...and reconstruct exactly to the constant, keeping distances finite.
        let dq = codec.dequantize(&q);
        assert_eq!(dq[0], 5.0);
        assert!(codec.asymmetric_l2(&[5.0, 3.0], &q).is_finite());
    }
}