Skip to main content

nodedb_vector/quantize/
sq8.rs

// SPDX-License-Identifier: Apache-2.0

//! Scalar Quantization (SQ8): FP32 → INT8 per-dimension.
//!
//! Each dimension is independently quantized to `[0, 255]` using per-dimension
//! min/max calibration. This is the **default production quantization** for
//! HNSW traversal: 4x RAM reduction with <1% recall loss.
//!
//! Distance computation uses asymmetric mode: the query stays in FP32 while
//! candidates are stored as 8-bit codes. This avoids quantizing the query and
//! preserves accuracy, at the cost of dequantizing each candidate dimension
//! during distance computation.
//!
//! Storage: D bytes per vector (vs 4D bytes for FP32).
15
16use serde::{Deserialize, Serialize};
17
18use crate::error::VectorError;
19
/// Six-byte signature that prefixes every serialized [`Sq8Codec`] blob.
///
/// Layout: `[NDSQ\0\0 (6 bytes)][version: u8 = 1][msgpack payload]`
pub const MAGIC: &[u8; 6] = &[b'N', b'D', b'S', b'Q', 0x00, 0x00];

/// Wire-format version byte written directly after [`MAGIC`] by [`Sq8Codec`] serialization.
pub const SQ8_FORMAT_VERSION: u8 = 0x01;
27
28/// SQ8 calibration parameters: per-dimension min/max.
29#[derive(Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
30pub struct Sq8Codec {
31    pub dim: usize,
32    /// Per-dimension minimum observed during calibration.
33    mins: Vec<f32>,
34    /// Per-dimension maximum observed during calibration.
35    maxs: Vec<f32>,
36    /// Pre-computed per-dimension scale: `(max - min) / 255.0`.
37    /// Zero if max == min (constant dimension → all quantize to 0).
38    scales: Vec<f32>,
39    /// Pre-computed per-dimension inverse scale: `255.0 / (max - min)`.
40    inv_scales: Vec<f32>,
41}
42
43impl Sq8Codec {
44    /// Calibrate min/max from a set of training vectors.
45    ///
46    /// Scans all vectors to find per-dimension min/max bounds.
47    /// At least 1000 vectors recommended for stable calibration;
48    /// for fewer vectors the bounds may be tight, causing clipping
49    /// on future inserts outside the calibration range.
50    pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
51        assert!(!vectors.is_empty(), "cannot calibrate on empty set");
52        assert!(dim > 0);
53
54        let mut mins = vec![f32::MAX; dim];
55        let mut maxs = vec![f32::MIN; dim];
56
57        for v in vectors {
58            debug_assert_eq!(v.len(), dim);
59            for d in 0..dim {
60                if v[d] < mins[d] {
61                    mins[d] = v[d];
62                }
63                if v[d] > maxs[d] {
64                    maxs[d] = v[d];
65                }
66            }
67        }
68
69        let mut scales = vec![0.0f32; dim];
70        let mut inv_scales = vec![0.0f32; dim];
71        for d in 0..dim {
72            let range = maxs[d] - mins[d];
73            if range > f32::EPSILON {
74                scales[d] = range / 255.0;
75                inv_scales[d] = 255.0 / range;
76            }
77        }
78
79        Self {
80            dim,
81            mins,
82            maxs,
83            scales,
84            inv_scales,
85        }
86    }
87
88    /// Quantize a single FP32 vector to INT8.
89    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
90        debug_assert_eq!(vector.len(), self.dim);
91        // no-governor: hot-path per-vector quantize; dim-bounded, instrument cost exceeds benefit
92        let mut out = Vec::with_capacity(self.dim);
93        for ((&v, &min), (&max, &inv_scale)) in vector
94            .iter()
95            .zip(self.mins.iter())
96            .zip(self.maxs.iter().zip(self.inv_scales.iter()))
97        {
98            let clamped = v.clamp(min, max);
99            let q = ((clamped - min) * inv_scale).round() as u8;
100            out.push(q);
101        }
102        out
103    }
104
105    /// Batch quantize: quantize all vectors into a contiguous byte array.
106    ///
107    /// Returns `dim * N` bytes laid out as `[v0_d0, v0_d1, ..., v1_d0, ...]`.
108    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
109        // no-governor: cold batch quantize; governed at segment build call site
110        let mut out = Vec::with_capacity(self.dim * vectors.len());
111        for v in vectors {
112            out.extend(self.quantize(v));
113        }
114        out
115    }
116
117    /// Dequantize INT8 back to FP32 (lossy reconstruction).
118    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
119        debug_assert_eq!(quantized.len(), self.dim);
120        // no-governor: hot-path per-vector dequantize; dim-bounded, instrument cost exceeds benefit
121        let mut out = Vec::with_capacity(self.dim);
122        for ((&q, &min), &scale) in quantized
123            .iter()
124            .zip(self.mins.iter())
125            .zip(self.scales.iter())
126        {
127            out.push(min + q as f32 * scale);
128        }
129        out
130    }
131
132    /// Asymmetric L2 squared distance: query (FP32) vs candidate (INT8).
133    ///
134    /// This is the hot-path function used during HNSW traversal.
135    /// The query stays in full precision; only the candidate is quantized.
136    #[inline]
137    pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
138        debug_assert_eq!(query.len(), self.dim);
139        debug_assert_eq!(candidate.len(), self.dim);
140        let mut sum = 0.0f32;
141        for d in 0..self.dim {
142            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
143            let diff = query[d] - dequant;
144            sum += diff * diff;
145        }
146        sum
147    }
148
149    /// Asymmetric cosine distance: query (FP32) vs candidate (INT8).
150    #[inline]
151    pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
152        debug_assert_eq!(query.len(), self.dim);
153        debug_assert_eq!(candidate.len(), self.dim);
154        let mut dot = 0.0f32;
155        let mut norm_q = 0.0f32;
156        let mut norm_c = 0.0f32;
157        for d in 0..self.dim {
158            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
159            dot += query[d] * dequant;
160            norm_q += query[d] * query[d];
161            norm_c += dequant * dequant;
162        }
163        let denom = (norm_q * norm_c).sqrt();
164        if denom < f32::EPSILON {
165            return 1.0;
166        }
167        (1.0 - dot / denom).max(0.0)
168    }
169
170    /// Asymmetric negative inner product: query (FP32) vs candidate (INT8).
171    #[inline]
172    pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
173        debug_assert_eq!(query.len(), self.dim);
174        debug_assert_eq!(candidate.len(), self.dim);
175        let mut dot = 0.0f32;
176        for d in 0..self.dim {
177            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
178            dot += query[d] * dequant;
179        }
180        -dot
181    }
182
183    /// Dimension count.
184    pub fn dim(&self) -> usize {
185        self.dim
186    }
187
188    /// Serialize the codec to bytes with a versioned magic header.
189    ///
190    /// Format: `[NDSQ\0\0 (6 bytes)][version: u8 = 1][msgpack payload]`
191    pub fn to_bytes(&self) -> Vec<u8> {
192        let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
193        // no-governor: cold serialization; fixed header + msgpack payload, governed at checkpoint call site
194        let mut out = Vec::with_capacity(7 + payload.len());
195        out.extend_from_slice(MAGIC);
196        out.push(SQ8_FORMAT_VERSION);
197        out.extend_from_slice(&payload);
198        out
199    }
200
201    /// Deserialize the codec from bytes produced by [`Self::to_bytes`].
202    ///
203    /// Returns `VectorError::InvalidMagic` if the header does not match
204    /// `NDSQ\0\0`, and `VectorError::UnsupportedVersion` for unknown versions.
205    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
206        if bytes.len() < 7 || &bytes[0..6] != MAGIC {
207            return Err(VectorError::InvalidMagic);
208        }
209        let version = bytes[6];
210        if version != SQ8_FORMAT_VERSION {
211            return Err(VectorError::UnsupportedVersion {
212                found: version,
213                expected: SQ8_FORMAT_VERSION,
214            });
215        }
216        zerompk::from_msgpack::<Self>(&bytes[7..])
217            .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 3-D training set: a linear ramp plus a sin/cos pair.
    fn make_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
            .collect()
    }

    /// Codec calibrated on the vectors from `make_vectors`.
    fn make_codec() -> Sq8Codec {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        Sq8Codec::calibrate(&refs, 3)
    }

    #[test]
    fn sq8_codec_golden_format() {
        let codec = make_codec();
        let bytes = codec.to_bytes();
        // First 6 bytes are magic.
        assert_eq!(&bytes[0..6], MAGIC);
        // Byte 6 is the version.
        assert_eq!(bytes[6], SQ8_FORMAT_VERSION);
        // Bytes 7+ must decode back to a valid Sq8Codec.
        let decoded = zerompk::from_msgpack::<Sq8Codec>(&bytes[7..]).unwrap();
        assert_eq!(decoded.dim, 3);
    }

    #[test]
    fn sq8_roundtrip() {
        let codec = make_codec();
        let bytes = codec.to_bytes();
        let restored = Sq8Codec::from_bytes(&bytes).unwrap();
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.inv_scales.len(), codec.inv_scales.len());
        for (a, b) in restored.inv_scales.iter().zip(codec.inv_scales.iter()) {
            assert!((a - b).abs() < 1e-6, "inv_scales mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn sq8_invalid_magic_returns_error() {
        let mut bytes = make_codec().to_bytes();
        bytes[0] = b'X'; // corrupt magic
        assert!(matches!(
            Sq8Codec::from_bytes(&bytes),
            Err(VectorError::InvalidMagic)
        ));
    }

    #[test]
    fn sq8_truncated_header_returns_invalid_magic() {
        // Magic present but the version byte is missing.
        assert!(matches!(
            Sq8Codec::from_bytes(MAGIC),
            Err(VectorError::InvalidMagic)
        ));
        assert!(matches!(
            Sq8Codec::from_bytes(&[]),
            Err(VectorError::InvalidMagic)
        ));
    }

    #[test]
    fn sq8_version_mismatch_returns_error() {
        let mut bytes = make_codec().to_bytes();
        bytes[6] = 0; // wrong version
        assert!(matches!(
            Sq8Codec::from_bytes(&bytes),
            Err(VectorError::UnsupportedVersion {
                found: 0,
                expected: 1
            })
        ));
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let vecs = make_vectors();
        let codec = make_codec();

        for v in &vecs {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);
            for d in 0..3 {
                let error = (v[d] - dq[d]).abs();
                let range = codec.maxs[d] - codec.mins[d];
                // Round-to-nearest gives at most half a step of error; the bound
                // below allows a full step as slack for float rounding in the
                // scale / inverse-scale products.
                assert!(
                    error <= range / 255.0 + 1e-6,
                    "d={d}: error={error}, max_step={}",
                    range / 255.0
                );
            }
        }
    }

    #[test]
    fn asymmetric_l2_close_to_exact() {
        let vecs = make_vectors();
        let codec = make_codec();

        let query = &[5.0, 0.5, -0.5];
        for v in &vecs {
            let q = codec.quantize(v);
            let exact = crate::distance::l2_squared(query, v);
            let approx = codec.asymmetric_l2(query, &q);
            // Allow up to 5% relative error.
            let rel_error = if exact > 0.01 {
                (exact - approx).abs() / exact
            } else {
                (exact - approx).abs()
            };
            assert!(
                rel_error < 0.05 || (exact - approx).abs() < 0.1,
                "exact={exact}, approx={approx}, rel_error={rel_error}"
            );
        }
    }

    #[test]
    fn asymmetric_ip_and_cosine_match_dequantized_reference() {
        let vecs = make_vectors();
        let codec = make_codec();
        let query = [5.0f32, 0.5, -0.5];
        let norm_q = query.iter().map(|x| x * x).sum::<f32>().sqrt();

        for v in &vecs {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);
            let dot: f32 = query.iter().zip(&dq).map(|(a, b)| a * b).sum();
            let norm_c = dq.iter().map(|x| x * x).sum::<f32>().sqrt();

            let ip = codec.asymmetric_ip(&query, &q);
            assert!((ip + dot).abs() < 1e-3, "ip={ip}, dot={dot}");

            let expected_cos = (1.0 - dot / (norm_q * norm_c)).max(0.0);
            let cos = codec.asymmetric_cosine(&query, &q);
            assert!(
                (cos - expected_cos).abs() < 1e-3,
                "cos={cos}, expected={expected_cos}"
            );
        }
    }

    #[test]
    fn batch_quantize() {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = make_codec();

        let batch = codec.quantize_batch(&refs);
        assert_eq!(batch.len(), 3 * 100);

        // First vector should match individual quantize.
        let single = codec.quantize(&vecs[0]);
        assert_eq!(&batch[0..3], &single[..]);
    }

    #[test]
    fn constant_dimension_handled() {
        // All vectors have the same value in dimension 0.
        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, 2);

        // Constant dimension should quantize to 0 without NaN/inf.
        let q = codec.quantize(&[5.0, 3.0]);
        assert_eq!(q[0], 0); // constant dim
        // ...and reconstruct exactly to the constant value.
        let dq = codec.dequantize(&q);
        assert_eq!(dq[0], 5.0);
        assert!(dq.iter().all(|x| x.is_finite()));
    }
}