use std::sync::Arc;
/// Squared L2 distance between two byte vectors, accumulated as `u32`.
///
/// Elements are consumed in groups of eight. Within a group, positions `k`
/// and `k + 4` feed accumulator `k`, giving four independent running sums.
/// A tail of fewer than eight elements is added to the first accumulator.
/// No square root is applied to the result.
///
/// Equal lengths are only verified by `debug_assert_eq!`. Without that check
/// the function panics when `b` is shorter than `a`, and ignores any surplus
/// elements of `b`.
#[inline]
fn distance_l2_quantized_simd(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());

    // Squared gap of one byte pair; at most 255 * 255, so it always fits.
    let sq = |x: u8, y: u8| -> u32 {
        let gap = u32::from(x.abs_diff(y));
        gap * gap
    };

    // Only the first `a.len()` bytes of `b` take part (panics if `b` is shorter).
    let b = &b[..a.len()];

    let lhs_blocks = a.chunks_exact(8);
    let rhs_blocks = b.chunks_exact(8);
    let lhs_tail = lhs_blocks.remainder();
    let rhs_tail = rhs_blocks.remainder();

    let mut acc = [0_u32; 4];
    for (lhs, rhs) in lhs_blocks.zip(rhs_blocks) {
        for (lane, total) in acc.iter_mut().enumerate() {
            *total += sq(lhs[lane], rhs[lane]) + sq(lhs[lane + 4], rhs[lane + 4]);
        }
    }
    for (&x, &y) in lhs_tail.iter().zip(rhs_tail.iter()) {
        acc[0] += sq(x, y);
    }

    acc[0] + acc[1] + acc[2] + acc[3]
}
/// L2 (Euclidean) distance between an `f32` query and a quantized vector.
///
/// The work is split between `asymmetric_chunked_sum`, which handles all
/// complete groups of four components, and `asymmetric_remainder_sum`, which
/// handles the remaining zero to three components. The partial sums are added
/// left to right and the square root of the total is returned.
///
/// All four slices are expected to have the same length; this is checked with
/// `debug_assert_eq!` only.
#[inline]
fn distance_l2_asymmetric_simd(
    query: &[f32],
    quantized: &[u8],
    min_vals: &[f32],
    inv_scales: &[f32],
) -> f32 {
    const LANES: usize = 4;

    let dim = query.len();
    debug_assert_eq!(dim, quantized.len());
    debug_assert_eq!(dim, min_vals.len());
    debug_assert_eq!(dim, inv_scales.len());

    let full_groups = dim / LANES;
    let tail_len = dim % LANES;

    let (lane0, lane1, lane2, lane3) =
        asymmetric_chunked_sum(query, quantized, min_vals, inv_scales, full_groups);
    let tail = asymmetric_remainder_sum(
        query,
        quantized,
        min_vals,
        inv_scales,
        full_groups * LANES,
        tail_len,
    );

    // Keep the left-to-right addition order so rounding matches exactly.
    let squared = lane0 + lane1 + lane2 + lane3 + tail;
    squared.sqrt()
}
/// Sums squared query-vs-dequantized differences over `chunks` groups of four.
///
/// Component `4 * g + k` of group `g` is dequantized as
/// `quantized * inv_scale + min`, subtracted from the query component, squared
/// and added to accumulator `k`. The four accumulators are returned separately
/// so the caller decides how to combine them.
///
/// Panics if any slice holds fewer than `chunks * 4` elements.
#[inline]
fn asymmetric_chunked_sum(
    query: &[f32],
    quantized: &[u8],
    min_vals: &[f32],
    inv_scales: &[f32],
    chunks: usize,
) -> (f32, f32, f32, f32) {
    // Squared reconstruction error of a single component.
    let squared_error = |idx: usize| -> f32 {
        let restored = f32::from(quantized[idx]) * inv_scales[idx] + min_vals[idx];
        let delta = query[idx] - restored;
        delta * delta
    };

    let mut lanes = [0.0_f32; 4];
    for group in 0..chunks {
        let start = group * 4;
        for (offset, lane) in lanes.iter_mut().enumerate() {
            *lane += squared_error(start + offset);
        }
    }

    let [first, second, third, fourth] = lanes;
    (first, second, third, fourth)
}
/// Sums squared query-vs-dequantized differences for a short run of components.
///
/// Covers the `remainder` components starting at index `base`, dequantizing
/// each as `quantized * inv_scale + min` before comparing it with the query.
/// Returns `0.0` when `remainder` is zero.
///
/// Panics if an index in `base..base + remainder` is out of bounds for any slice.
#[inline]
fn asymmetric_remainder_sum(
    query: &[f32],
    quantized: &[u8],
    min_vals: &[f32],
    inv_scales: &[f32],
    base: usize,
    remainder: usize,
) -> f32 {
    (0..remainder)
        .map(|offset| {
            let at = base + offset;
            let restored = f32::from(quantized[at]) * inv_scales[at] + min_vals[at];
            let delta = query[at] - restored;
            delta * delta
        })
        // Explicit fold from +0.0 keeps the seed and the left-to-right order fixed.
        .fold(0.0_f32, |total, squared| total + squared)
}
/// Per-dimension linear quantizer that maps `f32` components to `u8` codes.
///
/// The field descriptions below reflect how `ScalarQuantizer::train` fills
/// them in; each vector holds one entry per dimension.
#[derive(Debug, Clone)]
pub struct ScalarQuantizer {
/// Smallest training value seen in each dimension. It is subtracted before
/// scaling in `quantize` and added back in `dequantize`.
pub min_vals: Vec<f32>,
/// Multiplier used when quantizing: `255.0 / (max - min)` for the dimension,
/// or `1.0` when that range is smaller than `1e-10`.
pub scales: Vec<f32>,
/// Reciprocal (`1.0 / scale`) of each entry in `scales`, used to dequantize.
pub inv_scales: Vec<f32>,
/// Number of components per vector, taken from the first training vector.
pub dimension: usize,
}
/// A single vector in quantized form, as produced by `ScalarQuantizer::quantize`.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
/// One `u8` code per vector component.
pub data: Vec<u8>,
}
/// Append-only collection of quantized vectors kept in one contiguous buffer.
#[derive(Debug, Clone)]
pub struct QuantizedVectorStore {
/// Quantizer used to encode every pushed vector; its `dimension` fixes the
/// number of bytes each stored vector occupies.
quantizer: Arc<ScalarQuantizer>,
/// Codes of all stored vectors laid out back to back; vector `i` is read
/// from `i * dimension .. (i + 1) * dimension`.
data: Vec<u8>,
/// Number of vectors pushed so far.
count: usize,
}
impl ScalarQuantizer {
    /// Learns per-dimension quantization parameters from `vectors`.
    ///
    /// The minimum and maximum value of every dimension are tracked across
    /// all training vectors. Each dimension then gets a scale of
    /// `255.0 / (max - min)`, or `1.0` when the range is smaller than `1e-10`
    /// (a constant dimension), plus the matching reciprocal.
    ///
    /// NaN components are skipped by the min/max tracking, so a dimension
    /// that also contains ordinary values still trains normally.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidQuantizerConfig` when
    /// * `vectors` is empty,
    /// * the vectors do not all have the same length, or
    /// * a dimension has no finite value range: it contains an infinite
    ///   value, contains nothing but NaN, or `max - min` overflows `f32`.
    ///   Such a range would give a zero scale and an infinite inverse scale,
    ///   turning dequantized values and asymmetric distances into NaN/inf.
    pub fn train(vectors: &[&[f32]]) -> crate::error::Result<Self> {
        if vectors.is_empty() {
            return Err(crate::error::Error::InvalidQuantizerConfig(
                "cannot train on empty vectors".to_string(),
            ));
        }
        let dimension = vectors[0].len();
        if vectors.iter().any(|v| v.len() != dimension) {
            return Err(crate::error::Error::InvalidQuantizerConfig(
                "all vectors must have same dimension".to_string(),
            ));
        }

        let mut min_vals = vec![f32::MAX; dimension];
        let mut max_vals = vec![f32::MIN; dimension];
        for vector in vectors {
            for ((lo, hi), &val) in min_vals
                .iter_mut()
                .zip(max_vals.iter_mut())
                .zip(vector.iter())
            {
                // `f32::min` / `f32::max` return the non-NaN operand, so NaN is ignored.
                *lo = lo.min(val);
                *hi = hi.max(val);
            }
        }

        let mut scales: Vec<f32> = Vec::with_capacity(dimension);
        for (&lo, &hi) in min_vals.iter().zip(max_vals.iter()) {
            let range = hi - lo;
            // Non-finite covers +/-inf inputs, overflow of `hi - lo`, and the
            // untouched MAX/MIN sentinels of an all-NaN dimension.
            if !range.is_finite() {
                return Err(crate::error::Error::InvalidQuantizerConfig(
                    "training data must have a finite value range in every dimension"
                        .to_string(),
                ));
            }
            scales.push(if range.abs() < 1e-10 {
                // Constant dimension: avoid dividing by (almost) zero.
                1.0
            } else {
                255.0 / range
            });
        }
        let inv_scales: Vec<f32> = scales.iter().map(|&s| 1.0 / s).collect();

        Ok(Self {
            min_vals,
            scales,
            inv_scales,
            dimension,
        })
    }

    /// Encodes `vector` as one `u8` code per component.
    ///
    /// Each component is shifted by the dimension's minimum, multiplied by
    /// its scale, rounded, and clamped to `0..=255`, so values outside the
    /// training range saturate instead of wrapping.
    ///
    /// `vector.len()` is expected to equal `self.dimension`; this is checked
    /// with `debug_assert_eq!` only. Without that check the output has as
    /// many codes as the shorter of `vector` and the quantizer's parameters.
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
        debug_assert_eq!(vector.len(), self.dimension);
        let data: Vec<u8> = vector
            .iter()
            .zip(self.min_vals.iter())
            .zip(self.scales.iter())
            .map(|((&val, &min), &scale)| {
                let q = ((val - min) * scale).round();
                // Clamped to the u8 range first, so the cast cannot truncate;
                // a NaN input converts to 0.
                q.clamp(0.0, 255.0) as u8
            })
            .collect();
        QuantizedVector { data }
    }

    /// Reconstructs approximate `f32` components from a quantized vector.
    ///
    /// Every code is mapped back as `code * inv_scale + min`. The length of
    /// `quantized.data` is checked against `self.dimension` with
    /// `debug_assert_eq!` only.
    #[must_use]
    pub fn dequantize(&self, quantized: &QuantizedVector) -> Vec<f32> {
        debug_assert_eq!(quantized.data.len(), self.dimension);
        quantized
            .data
            .iter()
            .zip(self.min_vals.iter())
            .zip(self.inv_scales.iter())
            .map(|((&q, &min), &inv_scale)| f32::from(q) * inv_scale + min)
            .collect()
    }

    /// Distance between two quantized vectors, computed directly on the codes.
    ///
    /// Delegates to `distance_l2_quantized_simd`. The result is expressed in
    /// code units and, unlike [`Self::distance_l2_asymmetric`], no square
    /// root is applied by this method.
    #[inline]
    #[must_use]
    pub fn distance_l2_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> u32 {
        debug_assert_eq!(a.data.len(), b.data.len());
        distance_l2_quantized_simd(&a.data, &b.data)
    }

    /// Same as [`Self::distance_l2_quantized`], but takes raw code slices.
    #[inline]
    #[must_use]
    pub fn distance_l2_quantized_slice(&self, a: &[u8], b: &[u8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());
        distance_l2_quantized_simd(a, b)
    }

    /// Distance between an `f32` query and a quantized vector.
    ///
    /// Delegates to `distance_l2_asymmetric_simd` with this quantizer's
    /// `min_vals` and `inv_scales`, so the quantized side is dequantized on
    /// the fly. Both lengths are checked against `self.dimension` with
    /// `debug_assert_eq!` only.
    #[inline]
    #[must_use]
    pub fn distance_l2_asymmetric(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(quantized.data.len(), self.dimension);
        distance_l2_asymmetric_simd(query, &quantized.data, &self.min_vals, &self.inv_scales)
    }

    /// Same as [`Self::distance_l2_asymmetric`], but takes a raw code slice.
    #[inline]
    #[must_use]
    pub fn distance_l2_asymmetric_slice(&self, query: &[f32], quantized: &[u8]) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(quantized.len(), self.dimension);
        distance_l2_asymmetric_simd(query, quantized, &self.min_vals, &self.inv_scales)
    }
}
impl QuantizedVectorStore {
    /// Creates an empty store that encodes vectors with `quantizer`.
    ///
    /// `capacity` is the expected number of vectors; the byte buffer is
    /// pre-allocated for `capacity * dimension` codes.
    #[must_use]
    pub fn new(quantizer: Arc<ScalarQuantizer>, capacity: usize) -> Self {
        let dimension = quantizer.dimension;
        Self {
            quantizer,
            data: Vec::with_capacity(capacity * dimension),
            count: 0,
        }
    }

    /// Quantizes `vector` and appends it to the store.
    ///
    /// # Panics
    ///
    /// Panics if `vector.len()` differs from the quantizer's dimension. The
    /// check is unconditional because every stored vector must occupy
    /// exactly `dimension` bytes: a shorter vector would shift all later
    /// offsets and make `get` / `get_slice` return misaligned data.
    pub fn push(&mut self, vector: &[f32]) {
        assert_eq!(
            vector.len(),
            self.quantizer.dimension,
            "vector length must match the quantizer dimension"
        );
        let quantized = self.quantizer.quantize(vector);
        self.data.extend(quantized.data);
        self.count += 1;
    }

    /// Returns an owned copy of the vector at `index`, or `None` when
    /// `index` is not smaller than the number of stored vectors.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<QuantizedVector> {
        self.get_slice(index).map(|codes| QuantizedVector {
            data: codes.to_vec(),
        })
    }

    /// Borrows the codes of the vector at `index` without copying, or
    /// returns `None` when `index` is not smaller than the number of stored
    /// vectors.
    #[must_use]
    pub fn get_slice(&self, index: usize) -> Option<&[u8]> {
        if index >= self.count {
            return None;
        }
        let dimension = self.quantizer.dimension;
        let start = index * dimension;
        // Checked slicing: yields `None` instead of panicking on a bad range.
        self.data.get(start..start + dimension)
    }

    /// Number of vectors stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when no vector has been pushed yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// The quantizer used to encode the stored vectors.
    #[must_use]
    pub fn quantizer(&self) -> &ScalarQuantizer {
        &self.quantizer
    }
}