use super::DistanceMetric;
pub struct QuantizedVectors {
pub dims: usize,
pub num_vectors: usize,
pub data: Vec<u8>,
pub mins: Vec<f32>,
pub scales: Vec<f32>,
pub norms: Vec<f32>,
pub metric: DistanceMetric,
}
impl QuantizedVectors {
pub fn quantize(vectors: &[Vec<f32>], metric: DistanceMetric) -> Self {
if vectors.is_empty() {
return Self {
dims: 0,
num_vectors: 0,
data: Vec::new(),
mins: Vec::new(),
scales: Vec::new(),
norms: Vec::new(),
metric,
};
}
let dims = vectors[0].len();
let num_vectors = vectors.len();
let mut mins = vec![f32::MAX; dims];
let mut maxs = vec![f32::MIN; dims];
for v in vectors {
for d in 0..dims {
if v[d] < mins[d] {
mins[d] = v[d];
}
if v[d] > maxs[d] {
maxs[d] = v[d];
}
}
}
let scales: Vec<f32> = (0..dims)
.map(|d| {
let range = maxs[d] - mins[d];
if range == 0.0 { 0.0 } else { range / 255.0 }
})
.collect();
let mut data = vec![0u8; num_vectors * dims];
let mut norms = vec![0.0f32; num_vectors];
for (i, v) in vectors.iter().enumerate() {
let offset = i * dims;
let mut norm_sq = 0.0f32;
for d in 0..dims {
let q = if scales[d] == 0.0 {
128u8 } else {
((v[d] - mins[d]) / scales[d]).round().clamp(0.0, 255.0) as u8
};
data[offset + d] = q;
let dequant = mins[d] + q as f32 * scales[d];
norm_sq += dequant * dequant;
}
norms[i] = norm_sq.sqrt();
}
Self {
dims,
num_vectors,
data,
mins,
scales,
norms,
metric,
}
}
#[inline]
pub fn get(&self, idx: usize) -> &[u8] {
let start = idx * self.dims;
&self.data[start..start + self.dims]
}
#[inline]
pub fn asymmetric_distance(&self, idx: usize, query: &[f32]) -> f32 {
match self.metric {
DistanceMetric::Cosine => self.asymmetric_cosine(idx, query),
DistanceMetric::DotProduct => self.asymmetric_dot(idx, query),
DistanceMetric::L2 => self.asymmetric_l2(idx, query),
}
}
fn asymmetric_cosine(&self, idx: usize, query: &[f32]) -> f32 {
let quantized = self.get(idx);
let mut dot = 0.0f32;
for d in 0..self.dims {
let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
dot += dequant * query[d];
}
let stored_norm = self.norms[idx];
if stored_norm == 0.0 {
1.0
} else {
1.0 - dot / stored_norm
}
}
fn asymmetric_dot(&self, idx: usize, query: &[f32]) -> f32 {
let quantized = self.get(idx);
let mut dot = 0.0f32;
for d in 0..self.dims {
let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
dot += dequant * query[d];
}
-dot
}
fn asymmetric_l2(&self, idx: usize, query: &[f32]) -> f32 {
let quantized = self.get(idx);
let mut sum_sq = 0.0f32;
for d in 0..self.dims {
let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
let diff = dequant - query[d];
sum_sq += diff * diff;
}
sum_sq.sqrt()
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&(self.dims as u32).to_le_bytes());
buf.extend_from_slice(&(self.num_vectors as u32).to_le_bytes());
buf.push(self.metric as u8);
for &m in &self.mins {
buf.extend_from_slice(&m.to_le_bytes());
}
for &s in &self.scales {
buf.extend_from_slice(&s.to_le_bytes());
}
for &n in &self.norms {
buf.extend_from_slice(&n.to_le_bytes());
}
buf.extend_from_slice(&self.data);
buf
}
pub fn from_bytes(data: &[u8]) -> Self {
let dims = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize;
let num_vectors = u32::from_le_bytes(data[4..8].try_into().unwrap()) as usize;
let metric = DistanceMetric::from_byte(data[8]);
let mut pos = 9;
let mut mins = vec![0.0f32; dims];
for d in 0..dims {
mins[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
pos += 4;
}
let mut scales = vec![0.0f32; dims];
for d in 0..dims {
scales[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
pos += 4;
}
let mut norms = vec![0.0f32; num_vectors];
for i in 0..num_vectors {
norms[i] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
pos += 4;
}
let qdata = data[pos..pos + num_vectors * dims].to_vec();
Self {
dims,
num_vectors,
data: qdata,
mins,
scales,
norms,
metric,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extremes of each dimension map to codes 0 and 255.
    #[test]
    fn quantize_round_trip() {
        let vectors = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);
        assert_eq!(qv.dims, 3);
        assert_eq!(qv.num_vectors, 3);
        assert_eq!(qv.get(0), &[0, 0, 0]);
        assert_eq!(qv.get(2), &[255, 255, 255]);
    }

    /// Cosine distance preserves the exact ranking for a unit query.
    #[test]
    fn asymmetric_cosine_close_to_exact() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);
        let query = vec![1.0, 0.0, 0.0];
        let d0 = qv.asymmetric_distance(0, &query);
        let d1 = qv.asymmetric_distance(1, &query);
        let d2 = qv.asymmetric_distance(2, &query);
        assert!(d0 < d2, "d0={d0} should be < d2={d2}");
        assert!(d2 < d1, "d2={d2} should be < d1={d1}");
    }

    /// Dot-product distance is the negated inner product, so larger overlap ranks first.
    #[test]
    fn asymmetric_dot_ranks_by_inner_product() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![2.0, 2.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::DotProduct);
        let query = vec![1.0, 1.0];
        let d0 = qv.asymmetric_distance(0, &query);
        let d2 = qv.asymmetric_distance(2, &query);
        // [2, 2] · [1, 1] = 4, negated.
        assert!((d2 + 4.0).abs() < 1e-4, "d2={d2}");
        assert!(d2 < d0, "d2={d2} should be < d0={d0}");
    }

    /// L2 distance is ~0 to a stored vector itself and close to the exact distance otherwise.
    #[test]
    fn asymmetric_l2_matches_exact_distance() {
        let vectors: Vec<Vec<f32>> = vec![vec![1.5, -2.3, 0.7, 4.1], vec![-0.5, 3.2, 1.1, -1.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        let self_dist = qv.asymmetric_distance(0, &vectors[0]);
        assert!(self_dist < 1e-4, "self_dist={self_dist}");
        let exact = vectors[0]
            .iter()
            .zip(&vectors[1])
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f32>()
            .sqrt();
        let approx = qv.asymmetric_distance(0, &vectors[1]);
        assert!((approx - exact).abs() < 1e-3, "approx={approx} exact={exact}");
    }

    /// Every serialized field survives a round trip, and the byte layout has the expected size.
    #[test]
    fn serialization_round_trip() {
        let vectors = vec![vec![1.5, -2.3, 0.7, 4.1], vec![-0.5, 3.2, 1.1, -1.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        let bytes = qv.to_bytes();
        // header (9) + mins/scales (4 dims each) + norms (2 vectors), 4 bytes each, + 2x4 codes.
        assert_eq!(bytes.len(), 9 + 4 * (4 + 4 + 2) + 2 * 4);
        let qv2 = QuantizedVectors::from_bytes(&bytes);
        assert_eq!(qv.dims, qv2.dims);
        assert_eq!(qv.num_vectors, qv2.num_vectors);
        assert_eq!(qv.data, qv2.data);
        assert_eq!(qv.mins, qv2.mins);
        assert_eq!(qv.scales, qv2.scales);
        assert_eq!(qv.norms, qv2.norms);
    }

    /// Empty input yields an empty set that also survives serialization.
    #[test]
    fn empty_vectors() {
        let qv = QuantizedVectors::quantize(&[], DistanceMetric::Cosine);
        assert_eq!(qv.num_vectors, 0);
        assert_eq!(qv.dims, 0);
        let qv2 = QuantizedVectors::from_bytes(&qv.to_bytes());
        assert_eq!(qv2.num_vectors, 0);
        assert_eq!(qv2.dims, 0);
    }

    /// A constant dimension gets the mid code 128 and zero scale, dequantizing to its value.
    #[test]
    fn constant_dimension() {
        let vectors = vec![vec![1.0, 5.0], vec![2.0, 5.0], vec![3.0, 5.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        assert_eq!(qv.get(0)[1], 128);
        assert_eq!(qv.get(1)[1], 128);
        assert_eq!(qv.get(2)[1], 128);
        assert_eq!(qv.scales[1], 0.0);
        assert_eq!(qv.mins[1], 5.0);
    }
}