use serde::{Deserialize, Serialize};
/// Per-dimension 8-bit scalar quantizer.
///
/// Each dimension `d` is mapped linearly from the calibrated range
/// `[mins[d], maxs[d]]` onto the `u8` codes `0..=255`. Values outside the range
/// are clamped when encoding.
///
/// Build one with `Sq8Codec::calibrate`. The four per-dimension vectors all
/// have `dim` entries when produced by calibration. Field names and order are
/// part of the serde representation, so do not rename or reorder them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Sq8Codec {
/// Number of components in every vector this codec encodes (also returned by the `dim()` method).
pub dim: usize,
/// Per-dimension lower bound of the calibrated range; encodes to code `0`.
mins: Vec<f32>,
/// Per-dimension upper bound of the calibrated range.
maxs: Vec<f32>,
/// Per-dimension width of one quantization step, `(max - min) / 255`;
/// `0.0` when the range is not larger than `f32::EPSILON` (constant dimension).
scales: Vec<f32>,
/// Reciprocal of the step, `255 / (max - min)`, precomputed so encoding multiplies
/// instead of divides; `0.0` for constant dimensions.
inv_scales: Vec<f32>,
}
impl Sq8Codec {
pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
assert!(!vectors.is_empty(), "cannot calibrate on empty set");
assert!(dim > 0);
let mut mins = vec![f32::MAX; dim];
let mut maxs = vec![f32::MIN; dim];
for v in vectors {
debug_assert_eq!(v.len(), dim);
for d in 0..dim {
if v[d] < mins[d] {
mins[d] = v[d];
}
if v[d] > maxs[d] {
maxs[d] = v[d];
}
}
}
let mut scales = vec![0.0f32; dim];
let mut inv_scales = vec![0.0f32; dim];
for d in 0..dim {
let range = maxs[d] - mins[d];
if range > f32::EPSILON {
scales[d] = range / 255.0;
inv_scales[d] = 255.0 / range;
}
}
Self {
dim,
mins,
maxs,
scales,
inv_scales,
}
}
pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
debug_assert_eq!(vector.len(), self.dim);
let mut out = Vec::with_capacity(self.dim);
for ((&v, &min), (&max, &inv_scale)) in vector
.iter()
.zip(self.mins.iter())
.zip(self.maxs.iter().zip(self.inv_scales.iter()))
{
let clamped = v.clamp(min, max);
let q = ((clamped - min) * inv_scale).round() as u8;
out.push(q);
}
out
}
pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
let mut out = Vec::with_capacity(self.dim * vectors.len());
for v in vectors {
out.extend(self.quantize(v));
}
out
}
pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
debug_assert_eq!(quantized.len(), self.dim);
let mut out = Vec::with_capacity(self.dim);
for ((&q, &min), &scale) in quantized
.iter()
.zip(self.mins.iter())
.zip(self.scales.iter())
{
out.push(min + q as f32 * scale);
}
out
}
#[inline]
pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
debug_assert_eq!(query.len(), self.dim);
debug_assert_eq!(candidate.len(), self.dim);
let mut sum = 0.0f32;
for d in 0..self.dim {
let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
let diff = query[d] - dequant;
sum += diff * diff;
}
sum
}
#[inline]
pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
debug_assert_eq!(query.len(), self.dim);
debug_assert_eq!(candidate.len(), self.dim);
let mut dot = 0.0f32;
let mut norm_q = 0.0f32;
let mut norm_c = 0.0f32;
for d in 0..self.dim {
let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
dot += query[d] * dequant;
norm_q += query[d] * query[d];
norm_c += dequant * dequant;
}
let denom = (norm_q * norm_c).sqrt();
if denom < f32::EPSILON {
return 1.0;
}
(1.0 - dot / denom).max(0.0)
}
#[inline]
pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
debug_assert_eq!(query.len(), self.dim);
debug_assert_eq!(candidate.len(), self.dim);
let mut dot = 0.0f32;
for d in 0..self.dim {
let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
dot += query[d] * dequant;
}
-dot
}
pub fn dim(&self) -> usize {
self.dim
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 three-dimensional samples with differently shaped ranges:
    /// a linear ramp over `[0, 9.9]` plus `sin`/`cos` of the index in `[-1, 1]`.
    fn make_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
            .collect()
    }

    /// Calibrates a codec of dimensionality `dim` on `vecs`.
    fn calibrated(vecs: &[Vec<f32>], dim: usize) -> Sq8Codec {
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        Sq8Codec::calibrate(&refs, dim)
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let vecs = make_vectors();
        let codec = calibrated(&vecs, 3);
        for v in &vecs {
            let dq = codec.dequantize(&codec.quantize(v));
            for d in 0..3 {
                let error = (v[d] - dq[d]).abs();
                let step = (codec.maxs[d] - codec.mins[d]) / 255.0;
                // Round-to-nearest bounds the error by half a step (plus f32 slack).
                // A full-step bound would not catch a regression to truncation.
                assert!(
                    error <= step / 2.0 + 1e-5,
                    "d={d}: error={error}, half_step={}",
                    step / 2.0
                );
            }
        }
    }

    #[test]
    fn asymmetric_l2_close_to_exact() {
        let vecs = make_vectors();
        let codec = calibrated(&vecs, 3);
        let query = &[5.0, 0.5, -0.5];
        for v in &vecs {
            let q = codec.quantize(v);
            let exact = crate::engine::vector::distance::l2_squared(query, v);
            let approx = codec.asymmetric_l2(query, &q);
            let rel_error = if exact > 0.01 {
                (exact - approx).abs() / exact
            } else {
                (exact - approx).abs()
            };
            assert!(
                rel_error < 0.05 || (exact - approx).abs() < 0.1,
                "exact={exact}, approx={approx}, rel_error={rel_error}"
            );
        }
    }

    #[test]
    fn asymmetric_metrics_match_dequantized_vectors() {
        let vecs = make_vectors();
        let codec = calibrated(&vecs, 3);
        let query = [5.0f32, 0.5, -0.5];
        for v in &vecs {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);
            let l2: f32 = query.iter().zip(&dq).map(|(a, b)| (a - b) * (a - b)).sum();
            let dot: f32 = query.iter().zip(&dq).map(|(a, b)| a * b).sum();
            let approx_l2 = codec.asymmetric_l2(&query, &q);
            let approx_ip = codec.asymmetric_ip(&query, &q);
            assert!(
                (approx_l2 - l2).abs() <= 1e-4 * (1.0 + l2),
                "l2: {approx_l2} vs {l2}"
            );
            assert!(
                (approx_ip + dot).abs() <= 1e-4 * (1.0 + dot.abs()),
                "ip: {approx_ip} vs -{dot}"
            );
            let cos = codec.asymmetric_cosine(&query, &q);
            assert!((0.0..=2.0).contains(&cos), "cosine distance out of range: {cos}");
        }
        // A zero query has no direction; the codec reports distance 1.0.
        let q = codec.quantize(&vecs[1]);
        assert_eq!(codec.asymmetric_cosine(&[0.0; 3], &q), 1.0);
    }

    #[test]
    fn batch_quantize() {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, 3);
        let batch = codec.quantize_batch(&refs);
        assert_eq!(batch.len(), 3 * 100);
        for (row, v) in batch.chunks_exact(3).zip(&vecs) {
            assert_eq!(row, &codec.quantize(v)[..]);
        }
    }

    #[test]
    fn constant_dimension_handled() {
        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
        let codec = calibrated(&vecs, 2);
        let q = codec.quantize(&[5.0, 3.0]);
        assert_eq!(q[0], 0);
        // The constant dimension must decode back to its single value exactly.
        assert_eq!(codec.dequantize(&q)[0], 5.0);
    }
}