Skip to main content

iqdb_quantize/
scalar.rs

//! [`ScalarQuantizer`] — scalar quantization (SQ8, 4× compression).
//!
//! Each dimension is stored as one `u8` code. The calibration is a
//! per-dimension affine map: each dimension stores its trained `min` and a
//! `scale` defined as `(max - min) / 255`. Encoding clamps the input into
//! `[min, max]`, scales it onto `[0, 255]`, and rounds to the nearest
//! integer. Decoding reverses the affine map. A dimension with `max == min`
//! collapses to a `scale = 0` lane: every code byte there is `0` and
//! `dequantize` returns `min` — there is no division by zero.
//!
//! Asymmetric distance keeps the query as `f32`, dequantizes the candidate
//! code into a temporary `Vec<f32>`, and delegates to
//! [`iqdb_distance::compute`] for every metric. The result honours the
//! "smaller is nearer" convention used by the rest of the iqdb spine.
15
16use error_forge::ForgeError;
17use iqdb_distance::compute;
18use iqdb_types::{DistanceMetric, IqdbError, Result};
19
20use crate::code::Sq8Code;
21use crate::traits::Quantizer;
22use crate::validate::{dim_eq, finite_non_empty, training_set};
23
/// Highest `u8` code as an `f32` (`255.0`): the number of quantization steps
/// above code `0` that a dimension's trained range is divided into.
const LEVELS: f32 = u8::MAX as f32;
26
/// Per-dimension affine parameters produced by [`ScalarQuantizer::train`].
///
/// Both vectors are built together by training and hold one entry per
/// dimension: entry `i` of `mins` and entry `i` of `scales` describe the
/// same lane.
#[derive(Clone, Debug, PartialEq)]
struct Sq8Calibration {
    /// Smallest value seen on each dimension of the training sample.
    mins: Vec<f32>,
    /// Width of one code step on each dimension, `(max - min) / 255`.
    /// Stored as `0.0` when the trained range is not positive
    /// (`max == min`); `encode_scalar` and `decode_scalar` special-case
    /// such a lane instead of dividing by it.
    scales: Vec<f32>,
}
36
37/// Scalar quantizer (SQ8): one `u8` per dimension, 4× compression.
38///
39/// Build one with [`ScalarQuantizer::new`] (or [`Default`]), train it once
40/// with a representative sample, then quantize and compare. The trained
41/// quantizer is callable from multiple threads — it owns its calibration
42/// by value and exposes no interior mutability.
43///
44/// # Examples
45///
46/// ```
47/// use iqdb_quantize::{Quantizer, ScalarQuantizer};
48/// use iqdb_types::DistanceMetric;
49///
50/// let mut sq = ScalarQuantizer::new();
51/// sq.train(&[&[0.0_f32, 1.0, 2.0][..], &[1.0_f32, 0.0, 1.0][..]])
52///     .expect("two non-empty, finite vectors of equal dim");
53///
54/// let code = sq.quantize(&[0.5_f32, 0.5, 1.5]).expect("dim matches");
55/// let d = sq
56///     .distance(&[0.5_f32, 0.5, 1.5], &code, DistanceMetric::Euclidean)
57///     .expect("dim matches");
58/// assert!(d.is_finite());
59/// ```
60#[derive(Debug, Clone, Default, PartialEq)]
61pub struct ScalarQuantizer {
62    calibration: Option<Sq8Calibration>,
63}
64
65impl ScalarQuantizer {
66    /// Build an untrained scalar quantizer.
67    ///
68    /// Every hot method returns [`IqdbError::InvalidConfig`] until
69    /// [`Quantizer::train`] succeeds.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use iqdb_quantize::ScalarQuantizer;
75    /// let _sq = ScalarQuantizer::new();
76    /// ```
77    #[must_use]
78    pub fn new() -> Self {
79        Self { calibration: None }
80    }
81
82    /// The trained dimension, if any.
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// use iqdb_quantize::{Quantizer, ScalarQuantizer};
88    ///
89    /// let mut sq = ScalarQuantizer::new();
90    /// assert_eq!(sq.dim(), None);
91    /// sq.train(&[&[0.0_f32, 1.0][..]]).expect("ok");
92    /// assert_eq!(sq.dim(), Some(2));
93    /// ```
94    #[must_use]
95    pub fn dim(&self) -> Option<usize> {
96        self.calibration.as_ref().map(|c| c.mins.len())
97    }
98
99    fn calibration(&self) -> Result<&Sq8Calibration> {
100        self.calibration.as_ref().ok_or(IqdbError::InvalidConfig {
101            reason: "ScalarQuantizer has not been trained",
102        })
103    }
104}
105
106impl Quantizer for ScalarQuantizer {
107    type Quantized = Sq8Code;
108
109    #[tracing::instrument(
110        level = "info",
111        skip_all,
112        fields(quantizer = "sq8", training_size = vectors.len()),
113    )]
114    fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
115        let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
116            tracing::error!(
117                error.kind = err.kind(),
118                error.reason = err.caption(),
119                "scalar quantizer training failed",
120            );
121        })?;
122        let mut mins = vec![f32::INFINITY; dim];
123        let mut maxs = vec![f32::NEG_INFINITY; dim];
124        for v in vectors {
125            for (i, &x) in v.iter().enumerate() {
126                if x < mins[i] {
127                    mins[i] = x;
128                }
129                if x > maxs[i] {
130                    maxs[i] = x;
131                }
132            }
133        }
134        let mut scales = vec![0.0_f32; dim];
135        for i in 0..dim {
136            let range = maxs[i] - mins[i];
137            scales[i] = if range > 0.0 { range / LEVELS } else { 0.0 };
138        }
139        self.calibration = Some(Sq8Calibration { mins, scales });
140        Ok(())
141    }
142
143    fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
144        let cal = self.calibration()?;
145        finite_non_empty(vector)?;
146        dim_eq(cal.mins.len(), vector.len())?;
147        let mut bytes = Vec::with_capacity(vector.len());
148        for (i, &x) in vector.iter().enumerate() {
149            bytes.push(encode_scalar(x, cal.mins[i], cal.scales[i]));
150        }
151        Ok(Sq8Code { bytes })
152    }
153
154    fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
155        let cal = self.calibration()?;
156        dim_eq(cal.mins.len(), quantized.bytes.len())?;
157        let mut out = Vec::with_capacity(quantized.bytes.len());
158        for (i, &b) in quantized.bytes.iter().enumerate() {
159            out.push(decode_scalar(b, cal.mins[i], cal.scales[i]));
160        }
161        Ok(out)
162    }
163
164    fn distance(
165        &self,
166        query: &[f32],
167        quantized: &Self::Quantized,
168        metric: DistanceMetric,
169    ) -> Result<f32> {
170        let cal = self.calibration()?;
171        finite_non_empty(query)?;
172        dim_eq(cal.mins.len(), query.len())?;
173        dim_eq(cal.mins.len(), quantized.bytes.len())?;
174        let decoded = self.dequantize(quantized)?;
175        compute(metric, query, &decoded)
176    }
177}
178
179/// Encode one `f32` component into a `u8` code under an affine calibration.
180///
181/// Clamps the input into the trained range before the cast, so the `as u8`
182/// step cannot trip release-mode out-of-range UB even on inputs well
183/// outside `[min, max]`. A zero-`scale` lane (the `max == min` case in
184/// training) always encodes to `0`.
185fn encode_scalar(value: f32, min: f32, scale: f32) -> u8 {
186    if scale <= 0.0 {
187        return 0;
188    }
189    let normalised = ((value - min) / scale).round();
190    if normalised <= 0.0 {
191        0
192    } else if normalised >= LEVELS {
193        u8::MAX
194    } else {
195        normalised as u8
196    }
197}
198
/// Turn one `u8` code byte back into an `f32` under a lane's affine
/// calibration: `min + byte * scale`.
///
/// A lane whose `scale` is not positive decodes every byte to `min`.
fn decode_scalar(byte: u8, min: f32, scale: f32) -> f32 {
    let collapsed_lane = scale <= 0.0;
    if collapsed_lane {
        min
    } else {
        let offset = f32::from(byte) * scale;
        min + offset
    }
}
206
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a failed unwrap is itself the test failure.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use iqdb_types::{DistanceMetric, IqdbError};

    /// Quantizer trained on two 3-d vectors.
    /// Per-dim trained ranges: dim0=[0,1], dim1=[0,1], dim2=[1,2].
    fn trained_unit() -> ScalarQuantizer {
        let samples: [&[f32]; 2] = [&[0.0, 1.0, 2.0], &[1.0, 0.0, 1.0]];
        let mut sq = ScalarQuantizer::new();
        sq.train(&samples).unwrap();
        sq
    }

    /// Shared check for every "used before training" / bad-config case.
    fn assert_invalid_config(err: &IqdbError) {
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn quantize_before_train_returns_invalid_config() {
        let untrained = ScalarQuantizer::new();
        let err = untrained.quantize(&[0.5_f32, 0.5]).unwrap_err();
        assert_invalid_config(&err);
    }

    #[test]
    fn distance_before_train_returns_invalid_config() {
        let untrained = ScalarQuantizer::new();
        let code = Sq8Code {
            bytes: vec![0, 0, 0],
        };
        let query = [0.5_f32, 0.5, 0.5];
        let err = untrained
            .distance(&query, &code, DistanceMetric::Euclidean)
            .unwrap_err();
        assert_invalid_config(&err);
    }

    #[test]
    fn dequantize_before_train_returns_invalid_config() {
        let untrained = ScalarQuantizer::new();
        let code = Sq8Code { bytes: vec![0, 0] };
        let err = untrained.dequantize(&code).unwrap_err();
        assert_invalid_config(&err);
    }

    #[test]
    fn train_empty_set_returns_invalid_config() {
        let no_samples: [&[f32]; 0] = [];
        let mut sq = ScalarQuantizer::new();
        let err = sq.train(&no_samples).unwrap_err();
        assert_invalid_config(&err);
    }

    #[test]
    fn train_inconsistent_dim_returns_dimension_mismatch() {
        let three = [0.0_f32, 1.0, 2.0];
        let two = [1.0_f32, 0.0];
        let mut sq = ScalarQuantizer::new();
        let err = sq.train(&[&three[..], &two[..]]).unwrap_err();
        let wanted = IqdbError::DimensionMismatch {
            expected: 3,
            found: 2,
        };
        assert_eq!(err, wanted);
    }

    #[test]
    fn train_non_finite_returns_invalid_vector() {
        let poisoned = [1.0_f32, f32::NAN];
        let mut sq = ScalarQuantizer::new();
        let err = sq.train(&[&poisoned[..]]).unwrap_err();
        assert_eq!(err, IqdbError::InvalidVector);
    }

    #[test]
    fn quantize_dim_mismatch_returns_dimension_mismatch() {
        let err = trained_unit().quantize(&[0.5_f32, 0.5]).unwrap_err();
        let wanted = IqdbError::DimensionMismatch {
            expected: 3,
            found: 2,
        };
        assert_eq!(err, wanted);
    }

    #[test]
    fn quantize_non_finite_returns_invalid_vector() {
        let err = trained_unit()
            .quantize(&[0.5_f32, f32::INFINITY, 0.5])
            .unwrap_err();
        assert_eq!(err, IqdbError::InvalidVector);
    }

    #[test]
    fn round_trip_within_per_dim_bound() {
        let sq = trained_unit();
        // Every trained range here is 1 wide, so the per-dim round-trip
        // error is bounded by scale = 1/255; allow a tiny rounding cushion.
        let inputs = [0.1_f32, 0.5, 1.5];
        let decoded = sq.dequantize(&sq.quantize(&inputs).unwrap()).unwrap();
        for (i, (&expected, &got)) in inputs.iter().zip(decoded.iter()).enumerate() {
            let err = (expected - got).abs();
            assert!(
                err <= 1.0 / 255.0 + 1e-6,
                "dim {i}: |{expected} - {got}| = {err}",
            );
        }
    }

    #[test]
    fn zero_range_dimension_does_not_panic_and_round_trips_to_min() {
        // dim0 is 7.0 in every training vector, so max == min on that lane.
        let mut sq = ScalarQuantizer::new();
        sq.train(&[&[7.0_f32, 0.0][..], &[7.0_f32, 1.0][..]])
            .unwrap();

        // First an in-range value, then one far outside: neither can move
        // the decoded dim0 away from the lane's single trained value.
        for probe in [[7.0_f32, 0.5], [42.0_f32, 0.5]] {
            let code = sq.quantize(&probe).unwrap();
            let decoded = sq.dequantize(&code).unwrap();
            assert!((decoded[0] - 7.0).abs() < 1e-6);
        }
    }

    #[test]
    fn distance_smaller_is_nearer_for_euclidean() {
        let sq = trained_unit();
        let query = [0.5_f32, 0.5, 1.5];
        let near = sq.quantize(&[0.5_f32, 0.5, 1.5]).unwrap();
        let far = sq.quantize(&[1.0_f32, 0.0, 1.0]).unwrap();
        let d_near = sq
            .distance(&query, &near, DistanceMetric::Euclidean)
            .unwrap();
        let d_far = sq
            .distance(&query, &far, DistanceMetric::Euclidean)
            .unwrap();
        assert!(d_near < d_far);
    }

    #[test]
    fn distance_matches_iqdb_distance_on_dequantized() {
        let sq = trained_unit();
        let query = [0.5_f32, 0.5, 1.5];
        let code = sq.quantize(&[0.4_f32, 0.6, 1.4]).unwrap();
        let decoded = sq.dequantize(&code).unwrap();
        let direct = compute(DistanceMetric::Cosine, &query, &decoded).unwrap();
        let via_quant = sq.distance(&query, &code, DistanceMetric::Cosine).unwrap();
        // Bit-for-bit: `distance` must be exactly dequantize + compute.
        assert_eq!(via_quant.to_bits(), direct.to_bits());
    }

    #[test]
    fn encode_clamps_below_range() {
        // Positive scale, value far below min: the code bottoms out at 0.
        let scale = 1.0 / 255.0;
        assert_eq!(encode_scalar(-1e9, 0.0, scale), 0);
    }

    #[test]
    fn encode_clamps_above_range() {
        // Positive scale, value far above max: the code tops out at 255.
        let scale = 1.0 / 255.0;
        assert_eq!(encode_scalar(1e9, 0.0, scale), u8::MAX);
    }
}