use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
mins: Vec<f32>,
scales: Vec<f32>,
dims: usize,
}
impl ScalarQuantizer {
#[must_use]
pub fn fit(vectors: &[&[f32]]) -> Self {
assert!(!vectors.is_empty(), "need at least one vector to fit");
let dims = vectors[0].len();
let mut mins = vec![f32::INFINITY; dims];
let mut maxs = vec![f32::NEG_INFINITY; dims];
for vec in vectors {
assert_eq!(vec.len(), dims, "all vectors must have the same dimension");
for (i, &v) in vec.iter().enumerate() {
if !v.is_finite() {
continue;
}
if v < mins[i] {
mins[i] = v;
}
if v > maxs[i] {
maxs[i] = v;
}
}
}
let scales: Vec<f32> = mins
.iter()
.zip(maxs.iter())
.map(|(&mn, &mx)| {
let range = mx - mn;
if range < f32::EPSILON {
0.0 } else {
range / 255.0
}
})
.collect();
Self { mins, scales, dims }
}
#[must_use]
pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
assert_eq!(
vector.len(),
self.dims,
"vector dimension mismatch: expected {}, got {}",
self.dims,
vector.len()
);
vector
.iter()
.enumerate()
.map(|(i, &v)| {
if self.scales[i] < f32::EPSILON {
0
} else {
let normalized = (v - self.mins[i]) / self.scales[i];
let clamped = normalized.round().clamp(0.0, 255.0);
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let q = if clamped.is_nan() { 0 } else { clamped as u8 };
q
}
})
.collect()
}
#[must_use]
pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
assert_eq!(
quantized.len(),
self.dims,
"quantized vector dimension mismatch"
);
quantized
.iter()
.enumerate()
.map(|(i, &q)| f32::from(q).mul_add(self.scales[i], self.mins[i]))
.collect()
}
#[must_use]
pub fn dot_product_quantized(&self, stored: &[u8], query: &[f32]) -> f32 {
assert_eq!(stored.len(), self.dims);
assert_eq!(query.len(), self.dims);
let mut sum = 0.0_f32;
for i in 0..self.dims {
let dequantized = f32::from(stored[i]).mul_add(self.scales[i], self.mins[i]);
sum = dequantized.mul_add(query[i], sum);
}
sum
}
#[must_use]
pub fn cosine_similarity_quantized(&self, stored: &[u8], query: &[f32]) -> f32 {
assert_eq!(stored.len(), self.dims);
assert_eq!(query.len(), self.dims);
let mut dot = 0.0_f32;
let mut norm_s = 0.0_f32;
let mut norm_q = 0.0_f32;
for i in 0..self.dims {
let s = f32::from(stored[i]).mul_add(self.scales[i], self.mins[i]);
dot = s.mul_add(query[i], dot);
norm_s = s.mul_add(s, norm_s);
norm_q = query[i].mul_add(query[i], norm_q);
}
let denom = norm_s.sqrt() * norm_q.sqrt();
if denom < f32::EPSILON {
return 0.0;
}
dot / denom
}
#[must_use]
pub const fn dims(&self) -> usize {
self.dims
}
#[must_use]
pub fn mins(&self) -> &[f32] {
&self.mins
}
#[must_use]
pub fn scales(&self) -> &[f32] {
&self.scales
}
#[must_use]
pub const fn quantized_vector_bytes(&self) -> usize {
self.dims }
#[must_use]
pub const fn parameter_bytes(&self) -> usize {
self.dims * 4 * 2 }
#[must_use]
pub fn max_error_per_dim(&self) -> Vec<f32> {
self.scales.iter().map(|s| s / 2.0).collect()
}
#[must_use]
pub fn cosine_error_bound(&self) -> f32 {
let max_scale = self.scales.iter().copied().fold(0.0_f32, f32::max);
#[allow(clippy::cast_precision_loss)]
let sqrt_d = (self.dims as f32).sqrt();
max_scale * sqrt_d
}
}
#[cfg(test)]
mod tests {
use super::*;
fn l2_normalize(v: &[f32]) -> Vec<f32> {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm < f32::EPSILON {
return vec![0.0; v.len()];
}
v.iter().map(|x| x / norm).collect()
}
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
let denom = na * nb;
if denom < f32::EPSILON {
0.0
} else {
dot / denom
}
}
#[test]
fn fit_and_roundtrip_small() {
let vectors = [
vec![0.1_f32, 0.5, -0.3],
vec![0.2, 0.8, -0.1],
vec![-0.1, 0.6, 0.4],
];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
assert_eq!(q.dims(), 3);
for vec in &vectors {
let quantized = q.quantize(vec);
let restored = q.dequantize(&quantized);
for (orig, rest) in vec.iter().zip(restored.iter()) {
assert!(
(orig - rest).abs() < 0.01,
"roundtrip error too large: {orig} vs {rest}"
);
}
}
}
#[test]
fn fit_single_vector() {
let vectors = [vec![1.0_f32, 2.0, 3.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
assert!(q.scales().iter().all(|s| *s < f32::EPSILON));
}
#[test]
fn cosine_similarity_preserved_after_quantization() {
let vectors: Vec<Vec<f32>> = (0..20)
.map(|i| {
#[allow(clippy::cast_precision_loss)]
let i_f = i as f32;
l2_normalize(&[
(i_f * 0.1).sin(),
(i_f * 0.2).cos(),
(i_f * 0.3).sin(),
(i_f * 0.4).cos(),
])
})
.collect();
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
for i in 0..vectors.len() {
for j in (i + 1)..vectors.len() {
let original_sim = cosine_sim(&vectors[i], &vectors[j]);
let qi = q.quantize(&vectors[i]);
let ri = q.dequantize(&qi);
let qj = q.quantize(&vectors[j]);
let rj = q.dequantize(&qj);
let quantized_sim = cosine_sim(&ri, &rj);
assert!(
(original_sim - quantized_sim).abs() < 0.05,
"cosine sim diverged: {original_sim} vs {quantized_sim} for pair ({i}, {j})"
);
}
}
}
#[test]
fn dot_product_quantized_matches_dequantized() {
let vectors = [
vec![0.1_f32, 0.5, -0.3, 0.8],
vec![0.2, 0.8, -0.1, 0.3],
vec![-0.1, 0.6, 0.4, -0.2],
];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let query = vec![0.3_f32, 0.4, -0.2, 0.5];
let stored = q.quantize(&vectors[0]);
let adc_dot = q.dot_product_quantized(&stored, &query);
let deq = q.dequantize(&stored);
let full_dot: f32 = deq.iter().zip(query.iter()).map(|(a, b)| a * b).sum();
assert!(
(adc_dot - full_dot).abs() < 1e-5,
"ADC dot mismatch: {adc_dot} vs {full_dot}"
);
}
#[test]
fn cosine_similarity_quantized_matches() {
let vectors: Vec<Vec<f32>> = (0..10)
.map(|i| {
#[allow(clippy::cast_precision_loss)]
let i_f = i as f32;
l2_normalize(&[(i_f * 0.5).sin(), (i_f * 0.7).cos(), (i_f * 0.3).sin()])
})
.collect();
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let query = l2_normalize(&[0.5, 0.3, -0.1]);
for vec in &vectors {
let stored = q.quantize(vec);
let adc_cos = q.cosine_similarity_quantized(&stored, &query);
let deq = q.dequantize(&stored);
let full_cos = cosine_sim(&deq, &query);
assert!(
(adc_cos - full_cos).abs() < 1e-4,
"ADC cosine mismatch: {adc_cos} vs {full_cos}"
);
}
}
#[test]
fn error_bound_is_valid() {
let vectors: Vec<Vec<f32>> = (0..50)
.map(|i| {
#[allow(clippy::cast_precision_loss)]
let i_f = i as f32;
l2_normalize(&[
(i_f * 0.1).sin(),
(i_f * 0.2).cos(),
(i_f * 0.3).sin(),
(i_f * 0.4).cos(),
(i_f * 0.5).sin(),
(i_f * 0.6).cos(),
(i_f * 0.7).sin(),
(i_f * 0.8).cos(),
])
})
.collect();
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let bound = q.cosine_error_bound();
assert!(bound > 0.0);
assert!(bound < 1.0, "error bound too large: {bound}");
for i in 0..vectors.len() {
for j in (i + 1)..vectors.len() {
let orig = cosine_sim(&vectors[i], &vectors[j]);
let qi = q.quantize(&vectors[i]);
let qj = q.quantize(&vectors[j]);
let ri = q.dequantize(&qi);
let rj = q.dequantize(&qj);
let quant = cosine_sim(&ri, &rj);
assert!(
(orig - quant).abs() <= bound + 0.01, "error {:.4} exceeds bound {:.4} for pair ({i}, {j})",
(orig - quant).abs(),
bound,
);
}
}
}
#[test]
fn memory_accounting() {
let vectors = [vec![0.0_f32; 384]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
assert_eq!(q.quantized_vector_bytes(), 384);
assert_eq!(q.parameter_bytes(), 384 * 8); }
#[test]
fn constant_dimension_handled() {
let vectors = [
vec![0.5_f32, 0.0, 0.3], vec![0.1, 0.0, 0.8],
vec![0.9, 0.0, -0.2],
];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
assert!(q.scales()[1] < f32::EPSILON);
let quantized = q.quantize(&vectors[0]);
let restored = q.dequantize(&quantized);
assert!((restored[1] - 0.0).abs() < f32::EPSILON);
}
#[test]
fn values_outside_training_range_clamp() {
let vectors = [vec![0.0_f32, 1.0], vec![1.0, 0.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let outlier = vec![2.0_f32, -1.0];
let quantized = q.quantize(&outlier);
assert_eq!(quantized[0], 255);
assert_eq!(quantized[1], 0);
}
#[test]
fn serde_roundtrip() {
let vectors = [vec![0.1_f32, 0.5, -0.3], vec![0.2, 0.8, -0.1]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let json = serde_json::to_string(&q).unwrap();
let decoded: ScalarQuantizer = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.dims(), q.dims());
let original_q = q.quantize(&vectors[0]);
let decoded_q = decoded.quantize(&vectors[0]);
assert_eq!(original_q, decoded_q);
}
#[test]
fn scalar_quantizer_implements_debug() {
let vectors = [vec![1.0_f32, 2.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let debug_str = format!("{q:?}");
assert!(debug_str.contains("ScalarQuantizer"));
}
#[test]
fn scalar_quantizer_implements_clone() {
let vectors = [vec![0.1_f32, 0.5], vec![0.9, 0.2]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let cloned = q.clone();
assert_eq!(cloned.dims(), q.dims());
assert_eq!(cloned.mins(), q.mins());
assert_eq!(cloned.scales(), q.scales());
let input = vec![0.5_f32, 0.3];
assert_eq!(q.quantize(&input), cloned.quantize(&input));
}
#[test]
#[should_panic(expected = "need at least one vector")]
fn fit_panics_on_empty_vectors() {
let refs: Vec<&[f32]> = vec![];
let _ = ScalarQuantizer::fit(&refs);
}
#[test]
#[should_panic(expected = "same dimension")]
fn fit_panics_on_inconsistent_dimensions() {
let v1 = vec![1.0_f32, 2.0];
let v2 = vec![1.0_f32, 2.0, 3.0];
let refs: Vec<&[f32]> = vec![v1.as_slice(), v2.as_slice()];
let _ = ScalarQuantizer::fit(&refs);
}
#[test]
#[should_panic(expected = "dimension mismatch")]
fn quantize_panics_on_wrong_dimension() {
let vectors = [vec![1.0_f32, 2.0, 3.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let _ = q.quantize(&[1.0, 2.0]);
}
#[test]
#[should_panic(expected = "dimension mismatch")]
fn dequantize_panics_on_wrong_dimension() {
let vectors = [vec![1.0_f32, 2.0, 3.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let _ = q.dequantize(&[0, 128]);
}
#[test]
#[should_panic(expected = "assertion `left == right` failed")]
fn dot_product_quantized_panics_on_stored_dim_mismatch() {
let vectors = [vec![1.0_f32, 2.0], vec![3.0, 4.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let _ = q.dot_product_quantized(&[0, 128, 255], &[1.0, 2.0]);
}
#[test]
#[should_panic(expected = "assertion `left == right` failed")]
fn dot_product_quantized_panics_on_query_dim_mismatch() {
let vectors = [vec![1.0_f32, 2.0], vec![3.0, 4.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let stored = q.quantize(&vectors[0]);
let _ = q.dot_product_quantized(&stored, &[1.0, 2.0, 3.0]);
}
#[test]
#[should_panic(expected = "assertion `left == right` failed")]
fn cosine_similarity_quantized_panics_on_dim_mismatch() {
let vectors = [vec![1.0_f32, 2.0], vec![3.0, 4.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let stored = q.quantize(&vectors[0]);
let _ = q.cosine_similarity_quantized(&stored, &[1.0]);
}
#[test]
fn cosine_similarity_zero_stored_vector_returns_zero() {
let vectors = [vec![0.0_f32, 0.0, 0.0], vec![0.0, 0.0, 0.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let stored = q.quantize(&[0.0, 0.0, 0.0]);
let sim = q.cosine_similarity_quantized(&stored, &[1.0, 0.0, 0.0]);
assert!(sim.abs() < f32::EPSILON);
}
#[test]
fn cosine_similarity_self_is_near_one() {
let vectors = [vec![0.3_f32, 0.7, -0.2], vec![-0.1, 0.5, 0.9]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let query = &[0.3_f32, 0.7, -0.2];
let stored = q.quantize(query);
let sim = q.cosine_similarity_quantized(&stored, query);
assert!(
(sim - 1.0).abs() < 0.02,
"self-similarity should be ~1.0, got {sim}"
);
}
#[test]
fn dot_product_with_zero_query_returns_zero() {
let vectors = [vec![1.0_f32, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let stored = q.quantize(&vectors[0]);
let dot = q.dot_product_quantized(&stored, &[0.0, 0.0, 0.0]);
assert!(
dot.abs() < 1e-5,
"dot with zero query should be ~0, got {dot}"
);
}
#[test]
fn max_error_per_dim_values() {
let vectors = [vec![0.0_f32, 10.0], vec![1.0, 20.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let errors = q.max_error_per_dim();
assert_eq!(errors.len(), 2);
assert!((errors[0] - (1.0 / 255.0 / 2.0)).abs() < 1e-6);
assert!((errors[1] - (10.0 / 255.0 / 2.0)).abs() < 1e-6);
}
#[test]
fn max_error_per_dim_constant_dimension_is_zero() {
let vectors = [vec![5.0_f32, 0.0], vec![5.0, 1.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let errors = q.max_error_per_dim();
assert!(errors[0].abs() < f32::EPSILON);
assert!(errors[1] > 0.0);
}
#[test]
fn cosine_error_bound_positive_for_nonconstant() {
let vectors = [vec![0.1_f32, 0.9], vec![0.9, 0.1]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let bound = q.cosine_error_bound();
assert!(bound > 0.0);
assert!(bound < 1.0);
}
#[test]
fn cosine_error_bound_zero_for_all_constant() {
let vectors = [vec![3.0_f32, 3.0], vec![3.0, 3.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let bound = q.cosine_error_bound();
assert!(bound.abs() < f32::EPSILON);
}
#[test]
fn mins_accessor_returns_correct_values() {
let vectors = [vec![-5.0_f32, 10.0, 0.0], vec![5.0, 20.0, 0.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let mins = q.mins();
assert_eq!(mins.len(), 3);
assert!((mins[0] - (-5.0)).abs() < f32::EPSILON);
assert!((mins[1] - 10.0).abs() < f32::EPSILON);
assert!((mins[2] - 0.0).abs() < f32::EPSILON);
}
#[test]
fn quantized_vector_bytes_equals_dims() {
let vectors = [vec![0.0_f32; 128]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
assert_eq!(q.quantized_vector_bytes(), 128);
assert_eq!(q.parameter_bytes(), 128 * 8);
}
#[test]
fn fit_two_identical_vectors_all_scales_zero() {
let vectors = [vec![0.5_f32, -0.3, 0.7], vec![0.5, -0.3, 0.7]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
assert!(q.scales().iter().all(|s| *s < f32::EPSILON));
}
#[test]
fn quantize_identical_inputs_yields_identical_output() {
let vectors = [vec![0.1_f32, 0.5, -0.3], vec![0.9, 0.2, 0.8]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let input = vec![0.4_f32, 0.3, 0.1];
let q1 = q.quantize(&input);
let q2 = q.quantize(&input);
assert_eq!(q1, q2);
}
#[test]
fn quantize_constant_dimension_maps_to_zero() {
let vectors = [vec![7.0_f32, 1.0], vec![7.0, 5.0]];
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
let q = ScalarQuantizer::fit(&refs);
let quantized = q.quantize(&[7.0, 3.0]);
assert_eq!(quantized[0], 0);
assert!(quantized[1] > 0);
assert!(quantized[1] < 255);
}
}