1use error_forge::ForgeError;
17use iqdb_distance::compute;
18use iqdb_types::{DistanceMetric, IqdbError, Result};
19
20use crate::code::Sq8Code;
21use crate::traits::Quantizer;
22use crate::validate::{dim_eq, finite_non_empty, training_set};
23
/// Number of equal quantisation steps a calibrated range is divided into.
///
/// This is the largest byte value, so codes span `0..=255` (256 representable values).
const LEVELS: f32 = u8::MAX as f32;
26
/// Per-dimension affine calibration learned by `ScalarQuantizer::train`.
///
/// Byte `b` in dimension `i` is decoded relative to `mins[i]` in steps of
/// `scales[i]`. A step of `0.0` marks a dimension whose training values were
/// all equal.
#[derive(Clone, Debug, PartialEq)]
struct Sq8Calibration {
    /// Smallest training value observed in each dimension.
    mins: Vec<f32>,
    /// Width of one quantisation step in each dimension (`0.0` for constant dimensions).
    scales: Vec<f32>,
}
36
37#[derive(Debug, Clone, Default, PartialEq)]
61pub struct ScalarQuantizer {
62 calibration: Option<Sq8Calibration>,
63}
64
65impl ScalarQuantizer {
66 #[must_use]
78 pub fn new() -> Self {
79 Self { calibration: None }
80 }
81
82 #[must_use]
95 pub fn dim(&self) -> Option<usize> {
96 self.calibration.as_ref().map(|c| c.mins.len())
97 }
98
99 fn calibration(&self) -> Result<&Sq8Calibration> {
100 self.calibration.as_ref().ok_or(IqdbError::InvalidConfig {
101 reason: "ScalarQuantizer has not been trained",
102 })
103 }
104}
105
106impl Quantizer for ScalarQuantizer {
107 type Quantized = Sq8Code;
108
109 #[tracing::instrument(
110 level = "info",
111 skip_all,
112 fields(quantizer = "sq8", training_size = vectors.len()),
113 )]
114 fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
115 let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
116 tracing::error!(
117 error.kind = err.kind(),
118 error.reason = err.caption(),
119 "scalar quantizer training failed",
120 );
121 })?;
122 let mut mins = vec![f32::INFINITY; dim];
123 let mut maxs = vec![f32::NEG_INFINITY; dim];
124 for v in vectors {
125 for (i, &x) in v.iter().enumerate() {
126 if x < mins[i] {
127 mins[i] = x;
128 }
129 if x > maxs[i] {
130 maxs[i] = x;
131 }
132 }
133 }
134 let mut scales = vec![0.0_f32; dim];
135 for i in 0..dim {
136 let range = maxs[i] - mins[i];
137 scales[i] = if range > 0.0 { range / LEVELS } else { 0.0 };
138 }
139 self.calibration = Some(Sq8Calibration { mins, scales });
140 Ok(())
141 }
142
143 fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
144 let cal = self.calibration()?;
145 finite_non_empty(vector)?;
146 dim_eq(cal.mins.len(), vector.len())?;
147 let mut bytes = Vec::with_capacity(vector.len());
148 for (i, &x) in vector.iter().enumerate() {
149 bytes.push(encode_scalar(x, cal.mins[i], cal.scales[i]));
150 }
151 Ok(Sq8Code { bytes })
152 }
153
154 fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
155 let cal = self.calibration()?;
156 dim_eq(cal.mins.len(), quantized.bytes.len())?;
157 let mut out = Vec::with_capacity(quantized.bytes.len());
158 for (i, &b) in quantized.bytes.iter().enumerate() {
159 out.push(decode_scalar(b, cal.mins[i], cal.scales[i]));
160 }
161 Ok(out)
162 }
163
164 fn distance(
165 &self,
166 query: &[f32],
167 quantized: &Self::Quantized,
168 metric: DistanceMetric,
169 ) -> Result<f32> {
170 let cal = self.calibration()?;
171 finite_non_empty(query)?;
172 dim_eq(cal.mins.len(), query.len())?;
173 dim_eq(cal.mins.len(), quantized.bytes.len())?;
174 let decoded = self.dequantize(quantized)?;
175 compute(metric, query, &decoded)
176 }
177}
178
179fn encode_scalar(value: f32, min: f32, scale: f32) -> u8 {
186 if scale <= 0.0 {
187 return 0;
188 }
189 let normalised = ((value - min) / scale).round();
190 if normalised <= 0.0 {
191 0
192 } else if normalised >= LEVELS {
193 u8::MAX
194 } else {
195 normalised as u8
196 }
197}
198
/// Maps a stored byte back to an approximate `f32` value: `min + byte * scale`.
///
/// A non-positive `scale` marks a constant dimension, which always decodes to
/// `min`. If the `f32` computation overflows (a calibration whose range spans
/// most of the `f32` domain), the value is recomputed in `f64` and saturated to
/// the finite `f32` range instead of returning `inf`.
fn decode_scalar(byte: u8, min: f32, scale: f32) -> f32 {
    if scale <= 0.0 {
        return min;
    }
    let value = min + f32::from(byte) * scale;
    if value.is_finite() {
        return value;
    }
    // `byte * scale` can exceed f32::MAX even when the final sum would not.
    // A NaN (e.g. `0 * inf` from an infinite `scale`) passes through `clamp`
    // unchanged.
    let wide = f64::from(min) + f64::from(byte) * f64::from(scale);
    wide.clamp(f64::from(f32::MIN), f64::from(f32::MAX)) as f32
}
206
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use iqdb_types::{DistanceMetric, IqdbError};

    /// Quantizer calibrated on two 3-d rows; every dimension spans exactly 1.0.
    fn trained_unit() -> ScalarQuantizer {
        let rows: [&[f32]; 2] = [&[0.0, 1.0, 2.0], &[1.0, 0.0, 1.0]];
        let mut sq = ScalarQuantizer::new();
        sq.train(&rows).unwrap();
        sq
    }

    /// Asserts that `err` is the "not trained / bad config" variant.
    fn assert_invalid_config(err: IqdbError) {
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn quantize_before_train_returns_invalid_config() {
        let err = ScalarQuantizer::new().quantize(&[0.5_f32, 0.5]).unwrap_err();
        assert_invalid_config(err);
    }

    #[test]
    fn distance_before_train_returns_invalid_config() {
        let code = Sq8Code { bytes: vec![0; 3] };
        let err = ScalarQuantizer::new()
            .distance(&[0.5_f32; 3], &code, DistanceMetric::Euclidean)
            .unwrap_err();
        assert_invalid_config(err);
    }

    #[test]
    fn dequantize_before_train_returns_invalid_config() {
        let code = Sq8Code { bytes: vec![0; 2] };
        assert_invalid_config(ScalarQuantizer::new().dequantize(&code).unwrap_err());
    }

    #[test]
    fn train_empty_set_returns_invalid_config() {
        let empty: [&[f32]; 0] = [];
        let mut sq = ScalarQuantizer::new();
        assert_invalid_config(sq.train(&empty).unwrap_err());
    }

    #[test]
    fn train_inconsistent_dim_returns_dimension_mismatch() {
        let (a, b) = ([0.0_f32, 1.0, 2.0], [1.0_f32, 0.0]);
        let mut sq = ScalarQuantizer::new();
        let expected = IqdbError::DimensionMismatch {
            expected: 3,
            found: 2,
        };
        assert_eq!(sq.train(&[&a[..], &b[..]]).unwrap_err(), expected);
    }

    #[test]
    fn train_non_finite_returns_invalid_vector() {
        let poisoned = [1.0_f32, f32::NAN];
        let mut sq = ScalarQuantizer::new();
        assert_eq!(
            sq.train(&[&poisoned[..]]).unwrap_err(),
            IqdbError::InvalidVector,
        );
    }

    #[test]
    fn quantize_dim_mismatch_returns_dimension_mismatch() {
        let err = trained_unit().quantize(&[0.5_f32, 0.5]).unwrap_err();
        let expected = IqdbError::DimensionMismatch {
            expected: 3,
            found: 2,
        };
        assert_eq!(err, expected);
    }

    #[test]
    fn quantize_non_finite_returns_invalid_vector() {
        let err = trained_unit()
            .quantize(&[0.5_f32, f32::INFINITY, 0.5])
            .unwrap_err();
        assert_eq!(err, IqdbError::InvalidVector);
    }

    #[test]
    fn round_trip_within_per_dim_bound() {
        let sq = trained_unit();
        let inputs = [0.1_f32, 0.5, 1.5];
        let decoded = sq.dequantize(&sq.quantize(&inputs).unwrap()).unwrap();
        // Each dimension spans 1.0, so one quantisation step is 1/255; the
        // rounding error must stay within that step (plus float slack).
        let bound = 1.0 / 255.0 + 1e-6;
        for (i, (expected, got)) in inputs.iter().zip(&decoded).enumerate() {
            let err = (expected - got).abs();
            assert!(err <= bound, "dim {i}: |{expected} - {got}| = {err}");
        }
    }

    #[test]
    fn zero_range_dimension_does_not_panic_and_round_trips_to_min() {
        // Dimension 0 is pinned at 7.0 in every training row.
        let mut sq = ScalarQuantizer::new();
        sq.train(&[&[7.0_f32, 0.0][..], &[7.0_f32, 1.0][..]])
            .unwrap();

        // Both an in-range and a far out-of-range probe decode back to the pinned value.
        for probe in [[7.0_f32, 0.5], [42.0_f32, 0.5]] {
            let decoded = sq.dequantize(&sq.quantize(&probe).unwrap()).unwrap();
            assert!((decoded[0] - 7.0).abs() < 1e-6);
        }
    }

    #[test]
    fn distance_smaller_is_nearer_for_euclidean() {
        let sq = trained_unit();
        let q = [0.5_f32, 0.5, 1.5];
        let near = sq.quantize(&q).unwrap();
        let far = sq.quantize(&[1.0_f32, 0.0, 1.0]).unwrap();
        let dist = |code: &Sq8Code| sq.distance(&q, code, DistanceMetric::Euclidean).unwrap();
        assert!(dist(&near) < dist(&far));
    }

    #[test]
    fn distance_matches_iqdb_distance_on_dequantized() {
        let sq = trained_unit();
        let q = [0.5_f32, 0.5, 1.5];
        let code = sq.quantize(&[0.4_f32, 0.6, 1.4]).unwrap();
        let reference =
            compute(DistanceMetric::Cosine, &q, &sq.dequantize(&code).unwrap()).unwrap();
        let via_quantizer = sq.distance(&q, &code, DistanceMetric::Cosine).unwrap();
        assert_eq!(via_quantizer.to_bits(), reference.to_bits());
    }

    #[test]
    fn encode_clamps_below_range() {
        assert_eq!(encode_scalar(-1e9, 0.0, 1.0 / 255.0), 0);
    }

    #[test]
    fn encode_clamps_above_range() {
        assert_eq!(encode_scalar(1e9, 0.0, 1.0 / 255.0), u8::MAX);
    }
}
378}