use pulp::Simd;
use serde::{Deserialize, Serialize};
/// Common interface for vector quantization schemes.
///
/// A quantizer compresses `f32` vectors into a compact representation
/// (`Self::Quantized`) and supports distance computation both between two
/// compressed vectors (`distance_quantized`) and between a raw query and a
/// compressed vector (`distance_asymmetric`).
pub trait Quantizer: Send + Sync {
/// Compressed representation of a vector.
type Quantized: Clone + Send + Sync;
/// Compresses `vector` into the quantized representation.
fn quantize(&self, vector: &[f32]) -> Self::Quantized;
/// Reconstructs an approximate `f32` vector from `quantized`.
fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32>;
/// Distance between two quantized vectors.
fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32;
/// Distance between a raw (unquantized) query and a quantized vector.
fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32;
}
/// Per-dimension scalar quantization parameters: maps a `[min, max]` range
/// onto the 256 codes of a `u8`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizationParams {
/// Lower bound of the value range for this dimension.
pub min: f32,
/// Upper bound of the value range for this dimension.
pub max: f32,
/// Width of one quantization step: `(max - min) / 255`, or `1.0` when the
/// range is degenerate (see `ScalarQuantizationParams::new`).
pub scale: f32,
}
impl ScalarQuantizationParams {
    /// Builds parameters for the value range `[min, max]`.
    ///
    /// A degenerate (zero or negative) range falls back to a scale of `1.0`
    /// so quantization never divides by zero.
    pub fn new(min: f32, max: f32) -> Self {
        let range = max - min;
        let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
        Self { min, max, scale }
    }

    /// Maps `value` to an 8-bit code.
    ///
    /// Fix: round to the nearest code instead of truncating. The previous
    /// `as u8` truncation biased every code downward and doubled the
    /// worst-case quantization error (a full step instead of half a step).
    #[inline]
    pub fn quantize_value(&self, value: f32) -> u8 {
        let normalized = (value - self.min) / self.scale;
        normalized.round().clamp(0.0, 255.0) as u8
    }

    /// Reconstructs the approximate `f32` value a code represents.
    #[inline]
    pub fn dequantize_value(&self, quantized: u8) -> f32 {
        (quantized as f32) * self.scale + self.min
    }
}
/// Scalar (SQ8) quantizer: one `u8` code per dimension, each with its own
/// per-dimension `[min, max]` bounds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
/// Per-dimension quantization parameters; one entry per dimension.
params: Vec<ScalarQuantizationParams>,
/// Expected input vector dimensionality.
dim: usize,
}
/// A vector compressed by `ScalarQuantizer`: one `u8` code per dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizedVector {
/// Quantized codes, one byte per dimension.
pub data: Vec<u8>,
}
impl ScalarQuantizer {
    /// Learns per-dimension `[min, max]` bounds from a set of training vectors.
    ///
    /// Panics if `training_vectors` is empty or the vectors disagree on length.
    pub fn fit(training_vectors: &[Vec<f32>]) -> Self {
        assert!(
            !training_vectors.is_empty(),
            "Need at least one training vector"
        );
        let dim = training_vectors[0].len();
        // One (min, max) slot per dimension, filled in a single pass.
        let mut bounds = vec![(f32::INFINITY, f32::NEG_INFINITY); dim];
        for vector in training_vectors {
            assert_eq!(vector.len(), dim, "Inconsistent vector dimensions");
            for (slot, &value) in bounds.iter_mut().zip(vector.iter()) {
                slot.0 = slot.0.min(value);
                slot.1 = slot.1.max(value);
            }
        }
        let params = bounds
            .into_iter()
            .map(|(lo, hi)| ScalarQuantizationParams::new(lo, hi))
            .collect();
        Self { params, dim }
    }

    /// Builds a quantizer that applies the same `[min, max]` range to every
    /// dimension.
    pub fn with_bounds(dim: usize, min: f32, max: f32) -> Self {
        Self {
            params: (0..dim)
                .map(|_| ScalarQuantizationParams::new(min, max))
                .collect(),
            dim,
        }
    }

    /// Convenience constructor for vectors whose components lie in `[-1, 1]`.
    pub fn for_normalized(dim: usize) -> Self {
        Self::with_bounds(dim, -1.0, 1.0)
    }

    /// Dimensionality this quantizer was built for.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Per-dimension quantization parameters.
    pub fn params(&self) -> &[ScalarQuantizationParams] {
        &self.params
    }
}
impl Quantizer for ScalarQuantizer {
    type Quantized = ScalarQuantizedVector;

    /// Encodes each component using its per-dimension parameters.
    fn quantize(&self, vector: &[f32]) -> Self::Quantized {
        debug_assert_eq!(vector.len(), self.dim);
        let mut codes = Vec::with_capacity(vector.len());
        for (&component, param) in vector.iter().zip(self.params.iter()) {
            codes.push(param.quantize_value(component));
        }
        ScalarQuantizedVector { data: codes }
    }

    /// Reconstructs an approximate `f32` vector from stored codes.
    fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32> {
        let mut values = Vec::with_capacity(quantized.data.len());
        for (&code, param) in quantized.data.iter().zip(self.params.iter()) {
            values.push(param.dequantize_value(code));
        }
        values
    }

    /// Symmetric L2 distance computed directly on the `u8` codes.
    fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
        sq8_l2_distance_simd(&a.data, &b.data)
    }

    /// Asymmetric L2 distance between a raw query and stored codes.
    fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32 {
        sq8_asymmetric_l2_distance_simd(query, &quantized.data, &self.params)
    }
}
/// Binary (1-bit) quantizer: stores only the sign of each component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantizer {
/// Expected input vector dimensionality.
dim: usize,
/// Bytes needed per quantized vector: `ceil(dim / 8)`.
byte_len: usize,
}
/// A vector compressed by `BinaryQuantizer`: one bit per dimension, packed
/// little-endian within each byte (bit `i % 8` of byte `i / 8`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantizedVector {
/// Bit-packed sign data.
pub data: Vec<u8>,
}
impl BinaryQuantizer {
    /// Creates a quantizer for `dim`-dimensional vectors (one bit per
    /// dimension, rounded up to whole bytes).
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            byte_len: (dim + 7) / 8,
        }
    }

    /// Dimensionality this quantizer handles.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Bytes needed to store one quantized vector.
    pub fn byte_len(&self) -> usize {
        self.byte_len
    }
}
impl Quantizer for BinaryQuantizer {
    type Quantized = BinaryQuantizedVector;

    /// Packs the sign of each component into one bit (non-negative → 1).
    fn quantize(&self, vector: &[f32]) -> Self::Quantized {
        debug_assert_eq!(vector.len(), self.dim);
        let mut data = vec![0u8; self.byte_len];
        for (i, &val) in vector.iter().enumerate() {
            if val >= 0.0 {
                data[i / 8] |= 1 << (i % 8);
            }
        }
        BinaryQuantizedVector { data }
    }

    /// Expands the packed bits back to `+1.0` / `-1.0` per dimension.
    fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32> {
        let mut result = vec![0.0f32; self.dim];
        for i in 0..self.dim {
            let bit = (quantized.data[i / 8] >> (i % 8)) & 1;
            result[i] = if bit == 1 { 1.0 } else { -1.0 };
        }
        result
    }

    /// Hamming distance between two bit-packed vectors.
    fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
        hamming_distance_simd(&a.data, &b.data) as f32
    }

    /// Counts sign mismatches between a raw query and a quantized vector.
    fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32 {
        // Fix: enforce the same dimension precondition that `quantize` checks.
        // Without it, a query longer than `dim` indexes past `quantized.data`
        // and panics with an opaque out-of-bounds error.
        debug_assert_eq!(query.len(), self.dim);
        let mut mismatches = 0u32;
        for (i, &val) in query.iter().enumerate() {
            let quantized_bit = (quantized.data[i / 8] >> (i % 8)) & 1;
            let query_bit = if val >= 0.0 { 1 } else { 0 };
            if quantized_bit != query_bit {
                mismatches += 1;
            }
        }
        mismatches as f32
    }
}
/// L2 distance between two SQ8 code vectors, computed in the u8 code domain.
///
/// Returns `sqrt(sum((a[i] - b[i])^2))` over the raw codes; lengths must match.
#[inline]
pub fn sq8_l2_distance_simd(a: &[u8], b: &[u8]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Fix: `sq8_l2_distance_impl` performs its own `pulp` dispatch, so the
    // previous extra `simd.dispatch(|| ...)` wrapper here was a redundant
    // second runtime dispatch that added overhead without enabling any
    // additional SIMD specialization.
    sq8_l2_distance_impl(pulp::Arch::new(), a, b)
}
/// Scalar implementation of SQ8 L2 distance, dispatched through `pulp` so the
/// compiler can auto-vectorize the inner loop for the selected target.
#[inline(always)]
fn sq8_l2_distance_impl(simd: pulp::Arch, a: &[u8], b: &[u8]) -> f32 {
    struct Sq8L2<'a> {
        a: &'a [u8],
        b: &'a [u8],
    }
    impl pulp::WithSimd for Sq8L2<'_> {
        type Output = f32;
        #[inline(always)]
        fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
            // Fix: accumulate in u64. A u32 accumulator overflows once the
            // running sum exceeds u32::MAX, which happens for vectors with
            // dimension above ~66k when differences are large (255^2 per
            // element), silently corrupting the distance.
            let mut sum_sq: u64 = 0;
            let a_chunks = self.a.chunks_exact(4);
            let b_chunks = self.b.chunks_exact(4);
            let a_rem = a_chunks.remainder();
            let b_rem = b_chunks.remainder();
            // 4-way unrolled main loop; each chunk sum fits easily in i32
            // (4 * 255^2 = 260100).
            for (ac, bc) in a_chunks.zip(b_chunks) {
                let d0 = (ac[0] as i32) - (bc[0] as i32);
                let d1 = (ac[1] as i32) - (bc[1] as i32);
                let d2 = (ac[2] as i32) - (bc[2] as i32);
                let d3 = (ac[3] as i32) - (bc[3] as i32);
                sum_sq += (d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3) as u64;
            }
            // Leftover tail (0..3 elements).
            for (&x, &y) in a_rem.iter().zip(b_rem.iter()) {
                let diff = (x as i32) - (y as i32);
                sum_sq += (diff * diff) as u64;
            }
            (sum_sq as f32).sqrt()
        }
    }
    simd.dispatch(Sq8L2 { a, b })
}
/// Asymmetric L2 distance: a raw `f32` query against SQ8 codes, dequantizing
/// each code on the fly with its per-dimension parameters.
///
/// All three slices must have the same length.
#[inline]
pub fn sq8_asymmetric_l2_distance_simd(
    query: &[f32],
    quantized: &[u8],
    params: &[ScalarQuantizationParams],
) -> f32 {
    debug_assert_eq!(query.len(), quantized.len());
    debug_assert_eq!(query.len(), params.len());
    // Fix: `sq8_asymmetric_l2_impl` performs its own `pulp` dispatch; the
    // previous extra `simd.dispatch(|| ...)` wrapper was a redundant second
    // runtime dispatch.
    sq8_asymmetric_l2_impl(pulp::Arch::new(), query, quantized, params)
}
// Implementation of the asymmetric SQ8 L2 distance, dispatched through `pulp`.
// NOTE: the 4-way unrolled accumulation order below determines the exact f32
// result; keep the statement order if reproducible distances matter.
#[inline(always)]
fn sq8_asymmetric_l2_impl(
simd: pulp::Arch,
query: &[f32],
quantized: &[u8],
params: &[ScalarQuantizationParams],
) -> f32 {
// Payload struct handed to `pulp::Arch::dispatch` for SIMD specialization.
struct AsymL2<'a> {
query: &'a [f32],
quantized: &'a [u8],
params: &'a [ScalarQuantizationParams],
}
impl pulp::WithSimd for AsymL2<'_> {
type Output = f32;
#[inline(always)]
fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
let mut sum_sq: f32 = 0.0;
let n = self.query.len();
let mut i = 0;
// Main loop, unrolled by 4: dequantize each code with its own
// per-dimension params and accumulate squared differences.
while i + 4 <= n {
let d0 = self.query[i] - self.params[i].dequantize_value(self.quantized[i]);
let d1 =
self.query[i + 1] - self.params[i + 1].dequantize_value(self.quantized[i + 1]);
let d2 =
self.query[i + 2] - self.params[i + 2].dequantize_value(self.quantized[i + 2]);
let d3 =
self.query[i + 3] - self.params[i + 3].dequantize_value(self.quantized[i + 3]);
sum_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
i += 4;
}
// Scalar tail for the remaining 0..3 elements.
while i < n {
let dequantized = self.params[i].dequantize_value(self.quantized[i]);
let diff = self.query[i] - dequantized;
sum_sq += diff * diff;
i += 1;
}
sum_sq.sqrt()
}
}
simd.dispatch(AsymL2 {
query,
quantized,
params,
})
}
/// Hamming distance (number of differing bits) between two byte slices of
/// equal length.
#[inline]
pub fn hamming_distance_simd(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());
    // Fix: `hamming_distance_impl` performs its own `pulp` dispatch; the
    // previous extra `simd.dispatch(|| ...)` wrapper was a redundant second
    // runtime dispatch.
    hamming_distance_impl(pulp::Arch::new(), a, b)
}
/// Implementation of the Hamming distance, dispatched through `pulp`.
#[inline(always)]
fn hamming_distance_impl(simd: pulp::Arch, a: &[u8], b: &[u8]) -> u32 {
    struct Hamming<'a> {
        a: &'a [u8],
        b: &'a [u8],
    }
    impl pulp::WithSimd for Hamming<'_> {
        type Output = u32;
        #[inline(always)]
        fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
            let full_words = self.a.len() / 8;
            let mut total = 0u32;
            // Process eight bytes at a time as little-endian u64 words so a
            // single popcount covers 64 bits.
            for word in 0..full_words {
                let start = word * 8;
                let mut lhs = [0u8; 8];
                let mut rhs = [0u8; 8];
                lhs.copy_from_slice(&self.a[start..start + 8]);
                rhs.copy_from_slice(&self.b[start..start + 8]);
                total += (u64::from_le_bytes(lhs) ^ u64::from_le_bytes(rhs)).count_ones();
            }
            // XOR-popcount any leftover tail bytes individually.
            for idx in (full_words * 8)..self.a.len() {
                total += (self.a[idx] ^ self.b[idx]).count_ones();
            }
            total
        }
    }
    simd.dispatch(Hamming { a, b })
}
/// Dot product between a raw `f32` query and a binary-quantized vector,
/// treating each stored bit as `+1.0` (bit set) or `-1.0` (bit clear).
///
/// `dim` is the number of dimensions to process; `query` must supply at least
/// `dim` values and `quantized.data` at least `ceil(dim / 8)` bytes.
#[inline]
pub fn binary_dot_product(query: &[f32], quantized: &BinaryQuantizedVector, dim: usize) -> f32 {
    // Fix: make the preconditions explicit instead of panicking with an
    // opaque index-out-of-bounds deep in the loop.
    debug_assert!(query.len() >= dim, "query shorter than dim");
    debug_assert!(
        quantized.data.len() * 8 >= dim,
        "quantized data shorter than dim"
    );
    let mut sum = 0.0f32;
    for (i, &q) in query.iter().enumerate().take(dim) {
        let bit = (quantized.data[i / 8] >> (i % 8)) & 1;
        // bit 1 → +1.0, bit 0 → -1.0 (same mapping as `bit * 2.0 - 1.0`).
        let sign = if bit == 1 { 1.0 } else { -1.0 };
        sum += q * sign;
    }
    sum
}
/// Configuration for a product quantizer: the vector is split into
/// `num_subvectors` contiguous pieces, each encoded with its own codebook of
/// `2^bits_per_subvector` centroids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizerConfig {
/// Full vector dimensionality.
pub dim: usize,
/// Number of subvectors the vector is split into.
pub num_subvectors: usize,
/// Bits used to encode each subvector's centroid index.
pub bits_per_subvector: usize,
}
impl ProductQuantizerConfig {
    /// Default configuration: up to 8 subvectors with 8 bits (256 centroids)
    /// each.
    ///
    /// Fix: choose the largest subvector count `<= 8` that evenly divides
    /// `dim`. The previous hard-coded `8.min(dim)` made `subvector_dim()`
    /// floor-divide, silently dropping `dim % num_subvectors` trailing
    /// dimensions for dims not divisible by 8 (and dividing by zero for
    /// `dim == 0`). For dims divisible by 8 — the common case — the result is
    /// unchanged.
    pub fn default_for_dim(dim: usize) -> Self {
        let cap = 8.min(dim.max(1));
        let num_subvectors = (1..=cap).rev().find(|d| dim % d == 0).unwrap_or(1);
        Self {
            dim,
            num_subvectors,
            bits_per_subvector: 8,
        }
    }

    /// Dimensionality of each subvector.
    pub fn subvector_dim(&self) -> usize {
        self.dim / self.num_subvectors
    }

    /// Centroids per subvector codebook (`2^bits_per_subvector`).
    pub fn num_centroids(&self) -> usize {
        1 << self.bits_per_subvector
    }

    /// Bytes needed to store one compressed vector.
    pub fn compressed_size(&self) -> usize {
        self.num_subvectors * ((self.bits_per_subvector + 7) / 8)
    }
}
// Unit tests covering scalar (SQ8) and binary quantization, the SIMD distance
// kernels, the PQ configuration defaults, and end-to-end recall sanity checks.
#[cfg(test)]
mod tests {
use super::*;
// Tolerance for exact-float comparisons.
const EPSILON: f32 = 1e-5;
// `fit` learns per-dimension min/max from the training set.
#[test]
fn test_scalar_quantizer_fit() {
let vectors = vec![
vec![0.0, 0.5, 1.0],
vec![0.2, 0.3, 0.8],
vec![0.1, 0.6, 0.9],
];
let sq = ScalarQuantizer::fit(&vectors);
assert_eq!(sq.dim(), 3);
assert!((sq.params[0].min - 0.0).abs() < EPSILON);
assert!((sq.params[0].max - 0.2).abs() < EPSILON);
assert!((sq.params[2].min - 0.8).abs() < EPSILON);
assert!((sq.params[2].max - 1.0).abs() < EPSILON);
}
// quantize → dequantize should reconstruct within one quantization step.
#[test]
fn test_scalar_quantizer_roundtrip() {
let vectors = vec![vec![-1.0, 0.0, 1.0], vec![-0.5, 0.5, 0.5]];
let sq = ScalarQuantizer::fit(&vectors);
let original = vec![-0.7, 0.3, 0.8];
let quantized = sq.quantize(&original);
let reconstructed = sq.dequantize(&quantized);
for (o, r) in original.iter().zip(reconstructed.iter()) {
assert!((o - r).abs() < 0.02, "orig={}, recon={}", o, r);
}
}
// The [-1, 1] preset keeps reconstruction error below one step (2/255).
#[test]
fn test_scalar_quantizer_for_normalized() {
let sq = ScalarQuantizer::for_normalized(384);
assert_eq!(sq.dim(), 384);
let vector: Vec<f32> = (0..384).map(|i| (i as f32 / 192.0) - 1.0).collect();
let quantized = sq.quantize(&vector);
let reconstructed = sq.dequantize(&quantized);
let max_error: f32 = vector
.iter()
.zip(reconstructed.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, |a, b| a.max(b));
assert!(max_error < 0.01, "Max error: {}", max_error);
}
// Identical inputs quantize to (near-)identical codes → tiny distance.
#[test]
fn test_sq8_distance_quantized() {
let sq = ScalarQuantizer::for_normalized(4);
let a = vec![1.0, 0.0, -1.0, 0.5];
let b = vec![1.0, 0.0, -1.0, 0.5];
let qa = sq.quantize(&a);
let qb = sq.quantize(&b);
let dist = sq.distance_quantized(&qa, &qb);
assert!(dist < 1.0, "Same vectors should have near-zero distance");
}
// Opposite corners of the code space should be far apart in code units.
#[test]
fn test_sq8_distance_different() {
let sq = ScalarQuantizer::for_normalized(4);
let a = vec![1.0, 1.0, 1.0, 1.0];
let b = vec![-1.0, -1.0, -1.0, -1.0];
let qa = sq.quantize(&a);
let qb = sq.quantize(&b);
let dist = sq.distance_quantized(&qa, &qb);
assert!(dist > 100.0, "Opposite vectors should have large distance");
}
// Basic accessors.
#[test]
fn test_binary_quantizer_basic() {
let bq = BinaryQuantizer::new(8);
assert_eq!(bq.dim(), 8);
assert_eq!(bq.byte_len(), 1);
}
// byte_len is ceil(dim / 8).
#[test]
fn test_binary_quantizer_byte_len() {
assert_eq!(BinaryQuantizer::new(1).byte_len(), 1);
assert_eq!(BinaryQuantizer::new(8).byte_len(), 1);
assert_eq!(BinaryQuantizer::new(9).byte_len(), 2);
assert_eq!(BinaryQuantizer::new(16).byte_len(), 2);
assert_eq!(BinaryQuantizer::new(384).byte_len(), 48);
}
// All non-negative components → all bits set.
#[test]
fn test_binary_quantizer_all_positive() {
let bq = BinaryQuantizer::new(8);
let vector = vec![0.5, 0.3, 0.1, 0.9, 0.2, 0.4, 0.6, 0.8];
let quantized = bq.quantize(&vector);
assert_eq!(quantized.data[0], 0xFF);
}
// All negative components → all bits clear.
#[test]
fn test_binary_quantizer_all_negative() {
let bq = BinaryQuantizer::new(8);
let vector = vec![-0.5, -0.3, -0.1, -0.9, -0.2, -0.4, -0.6, -0.8];
let quantized = bq.quantize(&vector);
assert_eq!(quantized.data[0], 0x00);
}
// Alternating signs: dimension i maps to bit i within the byte (LSB first).
#[test]
fn test_binary_quantizer_mixed() {
let bq = BinaryQuantizer::new(8);
let vector = vec![0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8];
let quantized = bq.quantize(&vector);
assert_eq!(quantized.data[0], 0b01010101);
}
// Fully opposite sign patterns differ in every bit.
#[test]
fn test_binary_hamming_distance() {
let bq = BinaryQuantizer::new(8);
let a = vec![1.0; 8]; let b = vec![-1.0; 8];
let qa = bq.quantize(&a);
let qb = bq.quantize(&b);
let dist = bq.distance_quantized(&qa, &qb);
assert_eq!(dist, 8.0); }
// Identical vectors have zero Hamming distance.
#[test]
fn test_binary_hamming_same() {
let bq = BinaryQuantizer::new(16);
let a = vec![
0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8, 0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8,
];
let qa = bq.quantize(&a);
let qb = bq.quantize(&a);
let dist = bq.distance_quantized(&qa, &qb);
assert_eq!(dist, 0.0); }
// Dequantization yields ±1.0 per dimension according to the stored bit.
#[test]
fn test_binary_dequantize() {
let bq = BinaryQuantizer::new(4);
let vector = vec![0.5, -0.3, 0.1, -0.9];
let quantized = bq.quantize(&vector);
let dequantized = bq.dequantize(&quantized);
assert_eq!(dequantized, vec![1.0, -1.0, 1.0, -1.0]);
}
// Roundtrip at a realistic embedding dimension (384 → 48 packed bytes).
#[test]
fn test_binary_large_dimension() {
let bq = BinaryQuantizer::new(384);
let vector: Vec<f32> = (0..384)
.map(|i| if i % 2 == 0 { 0.5 } else { -0.5 })
.collect();
let quantized = bq.quantize(&vector);
assert_eq!(quantized.data.len(), 48);
let dequantized = bq.dequantize(&quantized);
for (i, &val) in dequantized.iter().enumerate() {
let expected = if i % 2 == 0 { 1.0 } else { -1.0 };
assert_eq!(val, expected);
}
}
// First byte differs in all 8 bits, second in none → distance 8.
#[test]
fn test_hamming_distance_simd_basic() {
let a = vec![0b11110000u8, 0b10101010];
let b = vec![0b00001111u8, 0b10101010];
let dist = hamming_distance_simd(&a, &b);
assert_eq!(dist, 8);
}
// Equal byte slices → zero distance.
#[test]
fn test_hamming_distance_simd_same() {
let a = vec![0xFF, 0x00, 0xAB, 0xCD];
let b = a.clone();
let dist = hamming_distance_simd(&a, &b);
assert_eq!(dist, 0);
}
// Equal code vectors → zero L2 distance (length 6 exercises the tail loop).
#[test]
fn test_sq8_l2_distance_simd_basic() {
let a = vec![0u8, 50, 100, 150, 200, 250];
let b = vec![0u8, 50, 100, 150, 200, 250];
let dist = sq8_l2_distance_simd(&a, &b);
assert!(dist < EPSILON);
}
// Max per-element difference: sqrt(4 * 255^2) = 510.
#[test]
fn test_sq8_l2_distance_simd_different() {
let a = vec![0u8, 0, 0, 0];
let b = vec![255u8, 255, 255, 255];
let dist = sq8_l2_distance_simd(&a, &b);
assert!((dist - 510.0).abs() < 1.0);
}
// Default PQ config for dim=384: 8 subvectors of 48 dims, 256 centroids.
#[test]
fn test_pq_config_defaults() {
let config = ProductQuantizerConfig::default_for_dim(384);
assert_eq!(config.dim, 384);
assert_eq!(config.num_subvectors, 8);
assert_eq!(config.bits_per_subvector, 8);
assert_eq!(config.subvector_dim(), 48);
assert_eq!(config.num_centroids(), 256);
assert_eq!(config.compressed_size(), 8); }
// SQ8 nearest-neighbor ranking should closely track exact L2 ranking
// (recall@10 >= 7 on seeded random data).
#[test]
fn test_sq8_recall_approximation() {
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let dim = 128;
let num_vectors = 100;
let vectors: Vec<Vec<f32>> = (0..num_vectors)
.map(|_| {
(0..dim)
.map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
.collect()
})
.collect();
let sq = ScalarQuantizer::fit(&vectors);
let quantized: Vec<_> = vectors.iter().map(|v| sq.quantize(v)).collect();
let query_idx = 42;
let query = &vectors[query_idx];
let query_q = &quantized[query_idx];
let mut exact_distances: Vec<(usize, f32)> = vectors
.iter()
.enumerate()
.filter(|(i, _)| *i != query_idx)
.map(|(i, v)| {
let dist: f32 = query
.iter()
.zip(v.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt();
(i, dist)
})
.collect();
let mut quantized_distances: Vec<(usize, f32)> = quantized
.iter()
.enumerate()
.filter(|(i, _)| *i != query_idx)
.map(|(i, q)| (i, sq.distance_quantized(query_q, q)))
.collect();
exact_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
quantized_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let exact_top10: std::collections::HashSet<_> =
exact_distances[..10].iter().map(|(i, _)| *i).collect();
let quantized_top10: std::collections::HashSet<_> =
quantized_distances[..10].iter().map(|(i, _)| *i).collect();
let recall = exact_top10.intersection(&quantized_top10).count();
assert!(recall >= 7, "Recall@10: {}/10", recall);
}
// Binary (1-bit) ranking is a coarser approximation: Hamming ranking is
// compared against exact cosine ranking; recall@10 >= 5 is expected.
#[test]
fn test_binary_recall_approximation() {
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(123);
let dim = 128;
let num_vectors = 100;
let vectors: Vec<Vec<f32>> = (0..num_vectors)
.map(|_| {
(0..dim)
.map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
.collect()
})
.collect();
let bq = BinaryQuantizer::new(dim);
let quantized: Vec<_> = vectors.iter().map(|v| bq.quantize(v)).collect();
let query_idx = 42;
let query = &vectors[query_idx];
let query_q = &quantized[query_idx];
let mut exact_distances: Vec<(usize, f32)> = vectors
.iter()
.enumerate()
.filter(|(i, _)| *i != query_idx)
.map(|(i, v)| {
let dot: f32 = query.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
let norm_q: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_v: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
let cosine = dot / (norm_q * norm_v);
(i, 1.0 - cosine) })
.collect();
let mut quantized_distances: Vec<(usize, f32)> = quantized
.iter()
.enumerate()
.filter(|(i, _)| *i != query_idx)
.map(|(i, q)| (i, bq.distance_quantized(query_q, q)))
.collect();
exact_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
quantized_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let exact_top10: std::collections::HashSet<_> =
exact_distances[..10].iter().map(|(i, _)| *i).collect();
let quantized_top10: std::collections::HashSet<_> =
quantized_distances[..10].iter().map(|(i, _)| *i).collect();
let recall = exact_top10.intersection(&quantized_top10).count();
assert!(recall >= 5, "Binary recall@10: {}/10", recall);
}
}