Skip to main content

nodedb_vector/quantize/
sq8.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Scalar Quantization (SQ8): FP32 → INT8 per-dimension.
4//!
5//! Each dimension is independently quantized to `[0, 255]` using per-dimension
6//! min/max calibration. This is the **default production quantization** for
7//! HNSW traversal: 4x RAM reduction with <1% recall loss.
8//!
9//! Distance computation uses asymmetric mode: query stays in FP32,
10//! candidates are in INT8. This avoids quantizing the query and
11//! preserves accuracy at the cost of a dequantize-per-dimension
12//! during distance computation.
13//!
14//! Storage: D bytes per vector (vs 4D bytes for FP32).
15
16use serde::{Deserialize, Serialize};
17
18use crate::error::VectorError;
19
20/// Magic bytes that open every serialized [`Sq8Codec`] blob.
21///
22/// Blob layout: `[NDSQ\0\0 (6 bytes)][version: u8 = 1][msgpack payload]`
23pub const MAGIC: &[u8; 6] = b"NDSQ\0\0";
24
25/// Wire format version byte, written immediately after [`MAGIC`] when a [`Sq8Codec`] is serialized.
26pub const SQ8_FORMAT_VERSION: u8 = 1;
27
28/// SQ8 calibration parameters: per-dimension min/max.
29#[derive(Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
30pub struct Sq8Codec {
31    pub dim: usize,
32    /// Per-dimension minimum observed during calibration.
33    mins: Vec<f32>,
34    /// Per-dimension maximum observed during calibration.
35    maxs: Vec<f32>,
36    /// Pre-computed per-dimension scale: `(max - min) / 255.0`.
37    /// Zero if max == min (constant dimension → all quantize to 0).
38    scales: Vec<f32>,
39    /// Pre-computed per-dimension inverse scale: `255.0 / (max - min)`.
40    inv_scales: Vec<f32>,
41}
42
43impl Sq8Codec {
44    /// Calibrate min/max from a set of training vectors.
45    ///
46    /// Scans all vectors to find per-dimension min/max bounds.
47    /// At least 1000 vectors recommended for stable calibration;
48    /// for fewer vectors the bounds may be tight, causing clipping
49    /// on future inserts outside the calibration range.
50    pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
51        assert!(!vectors.is_empty(), "cannot calibrate on empty set");
52        assert!(dim > 0);
53
54        let mut mins = vec![f32::MAX; dim];
55        let mut maxs = vec![f32::MIN; dim];
56
57        for v in vectors {
58            debug_assert_eq!(v.len(), dim);
59            for d in 0..dim {
60                if v[d] < mins[d] {
61                    mins[d] = v[d];
62                }
63                if v[d] > maxs[d] {
64                    maxs[d] = v[d];
65                }
66            }
67        }
68
69        let mut scales = vec![0.0f32; dim];
70        let mut inv_scales = vec![0.0f32; dim];
71        for d in 0..dim {
72            let range = maxs[d] - mins[d];
73            if range > f32::EPSILON {
74                scales[d] = range / 255.0;
75                inv_scales[d] = 255.0 / range;
76            }
77        }
78
79        Self {
80            dim,
81            mins,
82            maxs,
83            scales,
84            inv_scales,
85        }
86    }
87
88    /// Quantize a single FP32 vector to INT8.
89    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
90        debug_assert_eq!(vector.len(), self.dim);
91        let mut out = Vec::with_capacity(self.dim);
92        for ((&v, &min), (&max, &inv_scale)) in vector
93            .iter()
94            .zip(self.mins.iter())
95            .zip(self.maxs.iter().zip(self.inv_scales.iter()))
96        {
97            let clamped = v.clamp(min, max);
98            let q = ((clamped - min) * inv_scale).round() as u8;
99            out.push(q);
100        }
101        out
102    }
103
104    /// Batch quantize: quantize all vectors into a contiguous byte array.
105    ///
106    /// Returns `dim * N` bytes laid out as `[v0_d0, v0_d1, ..., v1_d0, ...]`.
107    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
108        let mut out = Vec::with_capacity(self.dim * vectors.len());
109        for v in vectors {
110            out.extend(self.quantize(v));
111        }
112        out
113    }
114
115    /// Dequantize INT8 back to FP32 (lossy reconstruction).
116    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
117        debug_assert_eq!(quantized.len(), self.dim);
118        let mut out = Vec::with_capacity(self.dim);
119        for ((&q, &min), &scale) in quantized
120            .iter()
121            .zip(self.mins.iter())
122            .zip(self.scales.iter())
123        {
124            out.push(min + q as f32 * scale);
125        }
126        out
127    }
128
129    /// Asymmetric L2 squared distance: query (FP32) vs candidate (INT8).
130    ///
131    /// This is the hot-path function used during HNSW traversal.
132    /// The query stays in full precision; only the candidate is quantized.
133    #[inline]
134    pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
135        debug_assert_eq!(query.len(), self.dim);
136        debug_assert_eq!(candidate.len(), self.dim);
137        let mut sum = 0.0f32;
138        for d in 0..self.dim {
139            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
140            let diff = query[d] - dequant;
141            sum += diff * diff;
142        }
143        sum
144    }
145
146    /// Asymmetric cosine distance: query (FP32) vs candidate (INT8).
147    #[inline]
148    pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
149        debug_assert_eq!(query.len(), self.dim);
150        debug_assert_eq!(candidate.len(), self.dim);
151        let mut dot = 0.0f32;
152        let mut norm_q = 0.0f32;
153        let mut norm_c = 0.0f32;
154        for d in 0..self.dim {
155            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
156            dot += query[d] * dequant;
157            norm_q += query[d] * query[d];
158            norm_c += dequant * dequant;
159        }
160        let denom = (norm_q * norm_c).sqrt();
161        if denom < f32::EPSILON {
162            return 1.0;
163        }
164        (1.0 - dot / denom).max(0.0)
165    }
166
167    /// Asymmetric negative inner product: query (FP32) vs candidate (INT8).
168    #[inline]
169    pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
170        debug_assert_eq!(query.len(), self.dim);
171        debug_assert_eq!(candidate.len(), self.dim);
172        let mut dot = 0.0f32;
173        for d in 0..self.dim {
174            let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
175            dot += query[d] * dequant;
176        }
177        -dot
178    }
179
180    /// Dimension count.
181    pub fn dim(&self) -> usize {
182        self.dim
183    }
184
185    /// Serialize the codec to bytes with a versioned magic header.
186    ///
187    /// Format: `[NDSQ\0\0 (6 bytes)][version: u8 = 1][msgpack payload]`
188    pub fn to_bytes(&self) -> Vec<u8> {
189        let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
190        let mut out = Vec::with_capacity(7 + payload.len());
191        out.extend_from_slice(MAGIC);
192        out.push(SQ8_FORMAT_VERSION);
193        out.extend_from_slice(&payload);
194        out
195    }
196
197    /// Deserialize the codec from bytes produced by [`Self::to_bytes`].
198    ///
199    /// Returns `VectorError::InvalidMagic` if the header does not match
200    /// `NDSQ\0\0`, and `VectorError::UnsupportedVersion` for unknown versions.
201    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
202        if bytes.len() < 7 || &bytes[0..6] != MAGIC {
203            return Err(VectorError::InvalidMagic);
204        }
205        let version = bytes[6];
206        if version != SQ8_FORMAT_VERSION {
207            return Err(VectorError::UnsupportedVersion {
208                found: version,
209                expected: SQ8_FORMAT_VERSION,
210            });
211        }
212        zerompk::from_msgpack::<Self>(&bytes[7..])
213            .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
214    }
215}
216
217#[cfg(test)]
218mod tests {
219    use super::*;
220
221    fn make_codec() -> Sq8Codec {
222        let vecs: Vec<Vec<f32>> = (0..100)
223            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
224            .collect();
225        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
226        Sq8Codec::calibrate(&refs, 3)
227    }
228
229    #[test]
230    fn sq8_codec_golden_format() {
231        let codec = make_codec();
232        let bytes = codec.to_bytes();
233        // First 6 bytes are magic.
234        assert_eq!(&bytes[0..6], MAGIC);
235        // Byte 6 is the version.
236        assert_eq!(bytes[6], SQ8_FORMAT_VERSION);
237        // Bytes 7+ must decode back to a valid Sq8Codec.
238        let decoded = zerompk::from_msgpack::<Sq8Codec>(&bytes[7..]).unwrap();
239        assert_eq!(decoded.dim, 3);
240    }
241
242    #[test]
243    fn sq8_roundtrip() {
244        let codec = make_codec();
245        let bytes = codec.to_bytes();
246        let restored = Sq8Codec::from_bytes(&bytes).unwrap();
247        assert_eq!(restored.dim, codec.dim);
248        assert_eq!(restored.inv_scales.len(), codec.inv_scales.len());
249        for (a, b) in restored.inv_scales.iter().zip(codec.inv_scales.iter()) {
250            assert!((a - b).abs() < 1e-6, "inv_scales mismatch: {a} vs {b}");
251        }
252    }
253
254    #[test]
255    fn sq8_invalid_magic_returns_error() {
256        let mut bytes = make_codec().to_bytes();
257        bytes[0] = b'X'; // corrupt magic
258        assert!(matches!(
259            Sq8Codec::from_bytes(&bytes),
260            Err(VectorError::InvalidMagic)
261        ));
262    }
263
264    #[test]
265    fn sq8_version_mismatch_returns_error() {
266        let mut bytes = make_codec().to_bytes();
267        bytes[6] = 0; // wrong version
268        assert!(matches!(
269            Sq8Codec::from_bytes(&bytes),
270            Err(VectorError::UnsupportedVersion {
271                found: 0,
272                expected: 1
273            })
274        ));
275    }
276
277    fn make_vectors() -> Vec<Vec<f32>> {
278        (0..100)
279            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
280            .collect()
281    }
282
283    #[test]
284    fn quantize_dequantize_roundtrip() {
285        let vecs = make_vectors();
286        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
287        let codec = Sq8Codec::calibrate(&refs, 3);
288
289        for v in &vecs {
290            let q = codec.quantize(v);
291            let dq = codec.dequantize(&q);
292            for d in 0..3 {
293                let error = (v[d] - dq[d]).abs();
294                let range = codec.maxs[d] - codec.mins[d];
295                // Error should be at most half a quantization step.
296                assert!(
297                    error <= range / 255.0 + 1e-6,
298                    "d={d}: error={error}, max_step={}",
299                    range / 255.0
300                );
301            }
302        }
303    }
304
305    #[test]
306    fn asymmetric_l2_close_to_exact() {
307        let vecs = make_vectors();
308        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
309        let codec = Sq8Codec::calibrate(&refs, 3);
310
311        let query = &[5.0, 0.5, -0.5];
312        for v in &vecs {
313            let q = codec.quantize(v);
314            let exact = crate::distance::l2_squared(query, v);
315            let approx = codec.asymmetric_l2(query, &q);
316            // Allow up to 5% relative error.
317            let rel_error = if exact > 0.01 {
318                (exact - approx).abs() / exact
319            } else {
320                (exact - approx).abs()
321            };
322            assert!(
323                rel_error < 0.05 || (exact - approx).abs() < 0.1,
324                "exact={exact}, approx={approx}, rel_error={rel_error}"
325            );
326        }
327    }
328
329    #[test]
330    fn batch_quantize() {
331        let vecs = make_vectors();
332        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
333        let codec = Sq8Codec::calibrate(&refs, 3);
334
335        let batch = codec.quantize_batch(&refs);
336        assert_eq!(batch.len(), 3 * 100);
337
338        // First vector should match individual quantize.
339        let single = codec.quantize(&vecs[0]);
340        assert_eq!(&batch[0..3], &single[..]);
341    }
342
343    #[test]
344    fn constant_dimension_handled() {
345        // All vectors have the same value in dimension 0.
346        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
347        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
348        let codec = Sq8Codec::calibrate(&refs, 2);
349
350        // Constant dimension should quantize to 0 without NaN/inf.
351        let q = codec.quantize(&[5.0, 3.0]);
352        assert_eq!(q[0], 0); // constant dim
353    }
354}