use error_forge::ForgeError;
use iqdb_distance::compute;
use iqdb_types::{DistanceMetric, IqdbError, Result};
use crate::code::Sq8Code;
use crate::traits::Quantizer;
use crate::validate::{dim_eq, finite_non_empty, training_set};
/// Divisor that turns a dimension's trained range into a per-code step size
/// (`range / LEVELS`), and the normalised value at or above which an encoded
/// component saturates to `u8::MAX`.
const LEVELS: f32 = 255.0;
/// Per-dimension calibration stored by a successful `train` call.
///
/// Both vectors have one entry per dimension of the training set.
#[derive(Debug, Clone, PartialEq)]
struct Sq8Calibration {
/// Smallest value observed in each dimension of the training vectors.
mins: Vec<f32>,
/// Step size per dimension: `(max - min) / LEVELS`, or `0.0` when the
/// dimension's range was not positive (e.g. every sample was equal).
scales: Vec<f32>,
}
/// Scalar quantizer that encodes every `f32` component as a single byte.
///
/// A freshly constructed (or `Default`) value is untrained; `quantize`,
/// `dequantize` and `distance` return `IqdbError::InvalidConfig` until
/// `train` has succeeded.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ScalarQuantizer {
/// `None` until `train` succeeds, then the learned per-dimension
/// minimums and step sizes.
calibration: Option<Sq8Calibration>,
}
impl ScalarQuantizer {
    /// Builds a quantizer with no calibration; call `train` before using it.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Dimensionality learned during training, or `None` if the quantizer
    /// has not been trained yet.
    #[must_use]
    pub fn dim(&self) -> Option<usize> {
        let calibration = self.calibration.as_ref()?;
        Some(calibration.mins.len())
    }

    /// Borrows the learned calibration.
    ///
    /// # Errors
    ///
    /// Returns `IqdbError::InvalidConfig` when `train` has not succeeded yet.
    fn calibration(&self) -> Result<&Sq8Calibration> {
        match &self.calibration {
            Some(calibration) => Ok(calibration),
            None => Err(IqdbError::InvalidConfig {
                reason: "ScalarQuantizer has not been trained",
            }),
        }
    }
}
impl Quantizer for ScalarQuantizer {
type Quantized = Sq8Code;
#[tracing::instrument(
level = "info",
skip_all,
fields(quantizer = "sq8", training_size = vectors.len()),
)]
fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
tracing::error!(
error.kind = err.kind(),
error.reason = err.caption(),
"scalar quantizer training failed",
);
})?;
let mut mins = vec![f32::INFINITY; dim];
let mut maxs = vec![f32::NEG_INFINITY; dim];
for v in vectors {
for (i, &x) in v.iter().enumerate() {
if x < mins[i] {
mins[i] = x;
}
if x > maxs[i] {
maxs[i] = x;
}
}
}
let mut scales = vec![0.0_f32; dim];
for i in 0..dim {
let range = maxs[i] - mins[i];
scales[i] = if range > 0.0 { range / LEVELS } else { 0.0 };
}
self.calibration = Some(Sq8Calibration { mins, scales });
Ok(())
}
fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
let cal = self.calibration()?;
finite_non_empty(vector)?;
dim_eq(cal.mins.len(), vector.len())?;
let mut bytes = Vec::with_capacity(vector.len());
for (i, &x) in vector.iter().enumerate() {
bytes.push(encode_scalar(x, cal.mins[i], cal.scales[i]));
}
Ok(Sq8Code { bytes })
}
fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
let cal = self.calibration()?;
dim_eq(cal.mins.len(), quantized.bytes.len())?;
let mut out = Vec::with_capacity(quantized.bytes.len());
for (i, &b) in quantized.bytes.iter().enumerate() {
out.push(decode_scalar(b, cal.mins[i], cal.scales[i]));
}
Ok(out)
}
fn distance(
&self,
query: &[f32],
quantized: &Self::Quantized,
metric: DistanceMetric,
) -> Result<f32> {
let cal = self.calibration()?;
finite_non_empty(query)?;
dim_eq(cal.mins.len(), query.len())?;
dim_eq(cal.mins.len(), quantized.bytes.len())?;
let decoded = self.dequantize(quantized)?;
compute(metric, query, &decoded)
}
}
/// Maps `value` to the nearest code for a dimension calibrated with `min`
/// and step size `scale`.
///
/// A non-positive `scale` (constant dimension) always encodes to `0`.
/// Otherwise the rounded step count `(value - min) / scale` is saturated to
/// `0..=LEVELS` before being narrowed to a byte.
fn encode_scalar(value: f32, min: f32, scale: f32) -> u8 {
    if scale <= 0.0 {
        return 0;
    }
    let steps = ((value - min) / scale).round();
    // `clamp` passes NaN through and the float-to-int cast maps NaN to 0,
    // so an undefined step count still produces code 0.
    steps.clamp(0.0, LEVELS) as u8
}
/// Reconstructs the value represented by `byte` for a dimension calibrated
/// with `min` and step size `scale`.
///
/// A non-positive `scale` (constant dimension) decodes every code to `min`;
/// otherwise the result is `min + byte * scale`.
fn decode_scalar(byte: u8, min: f32, scale: f32) -> f32 {
    match scale {
        step if step <= 0.0 => min,
        step => min + f32::from(byte) * step,
    }
}
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a failed unwrap is simply a failed test.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use iqdb_types::{DistanceMetric, IqdbError};

    /// Quantizer trained on two 3-d vectors, giving every dimension a range
    /// of exactly 1.0 (`[0, 1]`, `[0, 1]` and `[1, 2]`).
    fn trained_unit() -> ScalarQuantizer {
        let samples: [&[f32]; 2] = [&[0.0, 1.0, 2.0], &[1.0, 0.0, 1.0]];
        let mut quantizer = ScalarQuantizer::new();
        quantizer.train(&samples).unwrap();
        quantizer
    }

    /// Fails the calling test unless `err` is the `InvalidConfig` variant.
    #[track_caller]
    fn assert_invalid_config(err: &IqdbError) {
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn quantize_before_train_returns_invalid_config() {
        let untrained = ScalarQuantizer::new();
        let outcome = untrained.quantize(&[0.5_f32, 0.5]);
        assert_invalid_config(&outcome.unwrap_err());
    }

    #[test]
    fn distance_before_train_returns_invalid_config() {
        let untrained = ScalarQuantizer::new();
        let zeros = Sq8Code {
            bytes: vec![0, 0, 0],
        };
        let query = [0.5_f32, 0.5, 0.5];
        let outcome = untrained.distance(&query, &zeros, DistanceMetric::Euclidean);
        assert_invalid_config(&outcome.unwrap_err());
    }

    #[test]
    fn dequantize_before_train_returns_invalid_config() {
        let untrained = ScalarQuantizer::new();
        let zeros = Sq8Code { bytes: vec![0, 0] };
        let outcome = untrained.dequantize(&zeros);
        assert_invalid_config(&outcome.unwrap_err());
    }

    #[test]
    fn train_empty_set_returns_invalid_config() {
        let no_vectors: &[&[f32]] = &[];
        let mut quantizer = ScalarQuantizer::new();
        let outcome = quantizer.train(no_vectors);
        assert_invalid_config(&outcome.unwrap_err());
    }

    #[test]
    fn train_inconsistent_dim_returns_dimension_mismatch() {
        let ragged: [&[f32]; 2] = [&[0.0, 1.0, 2.0], &[1.0, 0.0]];
        let mut quantizer = ScalarQuantizer::new();
        let wanted = IqdbError::DimensionMismatch {
            expected: 3,
            found: 2,
        };
        assert_eq!(quantizer.train(&ragged).unwrap_err(), wanted);
    }

    #[test]
    fn train_non_finite_returns_invalid_vector() {
        let poisoned: [&[f32]; 1] = [&[1.0, f32::NAN]];
        let mut quantizer = ScalarQuantizer::new();
        let failure = quantizer.train(&poisoned).unwrap_err();
        assert_eq!(failure, IqdbError::InvalidVector);
    }

    #[test]
    fn quantize_dim_mismatch_returns_dimension_mismatch() {
        let quantizer = trained_unit();
        let wanted = IqdbError::DimensionMismatch {
            expected: 3,
            found: 2,
        };
        assert_eq!(quantizer.quantize(&[0.5_f32, 0.5]).unwrap_err(), wanted);
    }

    #[test]
    fn quantize_non_finite_returns_invalid_vector() {
        let quantizer = trained_unit();
        let failure = quantizer
            .quantize(&[0.5_f32, f32::INFINITY, 0.5])
            .unwrap_err();
        assert_eq!(failure, IqdbError::InvalidVector);
    }

    #[test]
    fn round_trip_within_per_dim_bound() {
        let quantizer = trained_unit();
        let inputs = [0.1_f32, 0.5, 1.5];
        let decoded = quantizer
            .dequantize(&quantizer.quantize(&inputs).unwrap())
            .unwrap();
        // Every trained range is 1.0, so one step is 1/255.
        let bound = 1.0 / 255.0 + 1e-6;
        let pairs = inputs.iter().zip(decoded.iter());
        for (i, (&expected, &got)) in pairs.enumerate() {
            let err = (expected - got).abs();
            assert!(err <= bound, "dim {i}: |{expected} - {got}| = {err}");
        }
    }

    #[test]
    fn zero_range_dimension_does_not_panic_and_round_trips_to_min() {
        let samples: [&[f32]; 2] = [&[7.0, 0.0], &[7.0, 1.0]];
        let mut quantizer = ScalarQuantizer::new();
        quantizer.train(&samples).unwrap();
        // Dimension 0 is constant: both the trained value and an arbitrary
        // one must decode back to the trained minimum.
        for probe in [7.0_f32, 42.0] {
            let code = quantizer.quantize(&[probe, 0.5]).unwrap();
            let decoded = quantizer.dequantize(&code).unwrap();
            assert!((decoded[0] - 7.0).abs() < 1e-6);
        }
    }

    #[test]
    fn distance_smaller_is_nearer_for_euclidean() {
        let quantizer = trained_unit();
        let query = [0.5_f32, 0.5, 1.5];
        let near = quantizer.quantize(&query).unwrap();
        let far = quantizer.quantize(&[1.0_f32, 0.0, 1.0]).unwrap();
        let to_near = quantizer
            .distance(&query, &near, DistanceMetric::Euclidean)
            .unwrap();
        let to_far = quantizer
            .distance(&query, &far, DistanceMetric::Euclidean)
            .unwrap();
        assert!(to_near < to_far);
    }

    #[test]
    fn distance_matches_iqdb_distance_on_dequantized() {
        let quantizer = trained_unit();
        let query = [0.5_f32, 0.5, 1.5];
        let code = quantizer.quantize(&[0.4_f32, 0.6, 1.4]).unwrap();
        let reference = {
            let decoded = quantizer.dequantize(&code).unwrap();
            compute(DistanceMetric::Cosine, &query, &decoded).unwrap()
        };
        let through_quantizer = quantizer
            .distance(&query, &code, DistanceMetric::Cosine)
            .unwrap();
        // Bit-for-bit equality: both paths must run the same arithmetic.
        assert_eq!(through_quantizer.to_bits(), reference.to_bits());
    }

    #[test]
    fn encode_clamps_below_range() {
        let code = encode_scalar(-1e9, 0.0, 1.0 / 255.0);
        assert_eq!(code, 0);
    }

    #[test]
    fn encode_clamps_above_range() {
        let code = encode_scalar(1e9, 0.0, 1.0 / 255.0);
        assert_eq!(code, u8::MAX);
    }
}