use serde::{Deserialize, Serialize};
/// Vector quantization scheme selector.
///
/// Marked `#[non_exhaustive]` so new schemes can be added without a
/// breaking change; serialized via serde for persisted index configs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum QuantizationType {
    /// No quantization: vectors stored as full-precision `f32`.
    #[default]
    None,
    /// Per-dimension 8-bit quantization (one `u8` code per component).
    Scalar,
    /// 1-bit sign quantization (one bit per component, packed into `u64`s).
    Binary,
    /// Product quantization: the vector is split into `num_subvectors`
    /// partitions, each encoded as a single `u8` codebook index.
    Product {
        num_subvectors: usize,
    },
}
impl QuantizationType {
    /// Approximate memory compression ratio versus raw `f32` storage for a
    /// vector with `dimensions` components.
    ///
    /// Always returns at least 1, even for degenerate product
    /// configurations where `num_subvectors` exceeds `dimensions * 4`.
    #[must_use]
    pub fn compression_ratio(&self, dimensions: usize) -> usize {
        match self {
            Self::None => 1,
            // 4-byte f32 -> 1-byte code per dimension.
            Self::Scalar => 4,
            // 32-bit f32 -> 1 bit per dimension.
            Self::Binary => 32,
            Self::Product { num_subvectors } => {
                let m = (*num_subvectors).max(1);
                // `dimensions * 4` input bytes become one u8 code per
                // subvector; clamp so the ratio is never reported as 0.
                ((dimensions * 4) / m).max(1)
            }
        }
    }

    /// Canonical lowercase name of this scheme (inverse of [`Self::from_str`]).
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::Scalar => "scalar",
            Self::Binary => "binary",
            Self::Product { .. } => "product",
        }
    }

    /// Parses a case-insensitive scheme name, accepting several aliases per
    /// scheme plus `pq<N>` (e.g. `pq16`) to choose a product quantizer with
    /// `N` subvectors. Bare `pq`/`product` default to 8 subvectors.
    ///
    /// Returns `None` for unknown names, non-numeric `pq` suffixes, and
    /// `pq0` (a product quantizer needs at least one subvector).
    #[must_use]
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "none" | "full" | "f32" => Some(Self::None),
            "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
            "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
            "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
            s if s.starts_with("pq") => s[2..]
                .parse()
                .ok()
                // Reject "pq0": zero subvectors is not a valid configuration.
                .filter(|&n| n > 0)
                .map(|n| Self::Product { num_subvectors: n }),
            _ => None,
        }
    }

    /// Whether this scheme needs a training pass over sample vectors before
    /// it can quantize (scalar learns per-dimension ranges, product learns
    /// k-means codebooks; binary and none are stateless).
    #[must_use]
    pub const fn requires_training(&self) -> bool {
        matches!(self, Self::Scalar | Self::Product { .. })
    }
}
/// Per-dimension 8-bit scalar quantizer.
///
/// Maps each `f32` component into a `u8` code over a learned
/// `[min, min + range]` interval for that dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    /// Per-dimension minimum observed during training.
    min: Vec<f32>,
    /// Per-dimension factor mapping `value - min` into `[0, 255]`.
    scale: Vec<f32>,
    /// Per-dimension inverse of `scale`, cached for dequantization and
    /// distance computations.
    inv_scale: Vec<f32>,
    /// Number of components each vector must have.
    dimensions: usize,
}
impl ScalarQuantizer {
    /// Trains a quantizer by scanning `vectors` for the per-dimension
    /// min/max range.
    ///
    /// # Panics
    /// Panics if `vectors` is empty or the vectors have mixed dimensions.
    #[must_use]
    pub fn train(vectors: &[&[f32]]) -> Self {
        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
        let dimensions = vectors[0].len();
        assert!(
            vectors.iter().all(|v| v.len() == dimensions),
            "All training vectors must have the same dimensions"
        );
        let mut min = vec![f32::INFINITY; dimensions];
        let mut max = vec![f32::NEG_INFINITY; dimensions];
        for vec in vectors {
            for (i, &v) in vec.iter().enumerate() {
                min[i] = min[i].min(v);
                max[i] = max[i].max(v);
            }
        }
        // Delegate the range -> scale conversion so `train` and
        // `with_ranges` can never drift apart (previously duplicated).
        Self::with_ranges(min, max)
    }

    /// Builds a quantizer from explicit per-dimension `[min, max]` ranges.
    ///
    /// A (near-)zero range maps to a pass-through scale of 1.0 so the
    /// dimension quantizes without dividing by zero.
    ///
    /// # Panics
    /// Panics if `min` and `max` differ in length.
    #[must_use]
    pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
        // Validate before deriving anything from the inputs.
        assert_eq!(min.len(), max.len(), "Min and max must have same length");
        let dimensions = min.len();
        let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
            .iter()
            .zip(&max)
            .map(|(&mn, &mx)| {
                let range = mx - mn;
                if range.abs() < f32::EPSILON {
                    // Degenerate (constant) dimension: identity scale.
                    (1.0, 1.0)
                } else {
                    (255.0 / range, range / 255.0)
                }
            })
            .unzip();
        Self {
            min,
            scale,
            inv_scale,
            dimensions,
        }
    }

    /// Number of components each vector must have.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Per-dimension minimum values learned at training time.
    #[must_use]
    pub fn min_values(&self) -> &[f32] {
        &self.min
    }

    /// Quantizes `vector` into one `u8` code per dimension; out-of-range
    /// values are clamped to `[0, 255]`.
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(
            vector.len(),
            self.dimensions,
            "Vector dimension mismatch: expected {}, got {}",
            self.dimensions,
            vector.len()
        );
        vector
            .iter()
            .enumerate()
            .map(|(i, &v)| {
                let normalized = (v - self.min[i]) * self.scale[i];
                normalized.clamp(0.0, 255.0) as u8
            })
            .collect()
    }

    /// Quantizes each vector in `vectors`; see [`Self::quantize`].
    #[must_use]
    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
        vectors.iter().map(|v| self.quantize(v)).collect()
    }

    /// Reconstructs an approximate `f32` vector from its codes.
    #[must_use]
    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
        debug_assert_eq!(quantized.len(), self.dimensions);
        quantized
            .iter()
            .enumerate()
            .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
            .collect()
    }

    /// Squared Euclidean distance between two quantized vectors, rescaled
    /// back into the original value space via `inv_scale`.
    #[must_use]
    pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
        debug_assert_eq!(a.len(), self.dimensions);
        debug_assert_eq!(b.len(), self.dimensions);
        let mut sum = 0.0f32;
        for i in 0..a.len() {
            let diff = (a[i] as f32) - (b[i] as f32);
            sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
        }
        sum
    }

    /// Euclidean distance between two quantized vectors.
    #[must_use]
    #[inline]
    pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
        self.distance_squared_u8(a, b).sqrt()
    }

    /// Cosine distance (`1 - cos`) between two quantized vectors, computed
    /// on their dequantized values. Returns 1.0 (orthogonal) when either
    /// vector has (near-)zero norm.
    #[must_use]
    pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
        debug_assert_eq!(a.len(), self.dimensions);
        debug_assert_eq!(b.len(), self.dimensions);
        let mut dot = 0.0f32;
        let mut norm_a = 0.0f32;
        let mut norm_b = 0.0f32;
        for i in 0..a.len() {
            let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
            let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
            dot += va * vb;
            norm_a += va * va;
            norm_b += vb * vb;
        }
        let denom = (norm_a * norm_b).sqrt();
        if denom < f32::EPSILON {
            1.0
        } else {
            1.0 - (dot / denom)
        }
    }

    /// Squared Euclidean distance between a full-precision `query` and a
    /// stored quantized vector (asymmetric: only one side is quantized,
    /// which is more accurate than `distance_squared_u8`).
    #[must_use]
    pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
        debug_assert_eq!(query.len(), self.dimensions);
        debug_assert_eq!(quantized.len(), self.dimensions);
        let mut sum = 0.0f32;
        for i in 0..query.len() {
            let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
            let diff = query[i] - dequant;
            sum += diff * diff;
        }
        sum
    }

    /// Euclidean form of [`Self::asymmetric_distance_squared`].
    #[must_use]
    #[inline]
    pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
        self.asymmetric_distance_squared(query, quantized).sqrt()
    }
}
/// Stateless 1-bit quantizer: each `f32` component becomes a single sign
/// bit packed into `u64` words (64 dimensions per word).
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Packs `vector` into sign bits. Non-negative values (including
    /// `0.0`) set their bit; negative values — and NaN, since `NaN >= 0.0`
    /// is false — leave it clear. The last word is zero-padded.
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        // div_ceil: round word count up for lengths not divisible by 64.
        let mut result = vec![0u64; vector.len().div_ceil(64)];
        for (i, &v) in vector.iter().enumerate() {
            if v >= 0.0 {
                result[i / 64] |= 1u64 << (i % 64);
            }
        }
        result
    }

    /// Quantizes each vector in `vectors`; see [`Self::quantize`].
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Number of differing bits between two packed bit vectors.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");
        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Hamming distance divided by the original dimension count, giving a
    /// value in `[0, 1]`.
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Rough Euclidean-distance estimate from the Hamming distance,
    /// assuming unit-norm inputs.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// Number of `u64` words needed to store `dimensions` bits.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Number of bytes needed to store `dimensions` bits.
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}
/// Product quantizer: splits vectors into `num_subvectors` contiguous
/// partitions and encodes each as the index of its nearest codebook
/// centroid (one `u8` per partition).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Number of partitions (and bytes per encoded vector).
    num_subvectors: usize,
    /// Centroids per partition codebook (<= 256 so codes fit in a `u8`).
    num_centroids: usize,
    /// Components per partition (`dimensions / num_subvectors`).
    subvector_dim: usize,
    /// Full input vector length.
    dimensions: usize,
    /// Flat codebook storage, laid out as
    /// `[partition][centroid][component]` in row-major order.
    centroids: Vec<f32>,
}
impl ProductQuantizer {
    /// Trains a product quantizer by running `iterations` rounds of
    /// k-means independently on each of the `num_subvectors` partitions.
    ///
    /// # Panics
    /// Panics if `vectors` is empty, `num_subvectors == 0`,
    /// `num_centroids > 256` (codes must fit in a `u8`), the vectors have
    /// mixed dimensions, or `dimensions` is not divisible by
    /// `num_subvectors`.
    #[must_use]
    pub fn train(
        vectors: &[&[f32]],
        num_subvectors: usize,
        num_centroids: usize,
        iterations: usize,
    ) -> Self {
        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
        assert!(
            num_centroids <= 256,
            "num_centroids must be <= 256 for u8 codes"
        );
        assert!(num_subvectors > 0, "num_subvectors must be > 0");
        let dimensions = vectors[0].len();
        assert!(
            dimensions.is_multiple_of(num_subvectors),
            "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
        );
        assert!(
            vectors.iter().all(|v| v.len() == dimensions),
            "All training vectors must have the same dimensions"
        );
        let subvector_dim = dimensions / num_subvectors;
        let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
        // Learn an independent codebook per partition.
        for m in 0..num_subvectors {
            let subvectors: Vec<Vec<f32>> = vectors
                .iter()
                .map(|v| {
                    let start = m * subvector_dim;
                    let end = start + subvector_dim;
                    v[start..end].to_vec()
                })
                .collect();
            let partition_centroids =
                Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
            centroids.extend(partition_centroids);
        }
        Self {
            num_subvectors,
            num_centroids,
            subvector_dim,
            dimensions,
            centroids,
        }
    }

    /// Squared Euclidean distance between two equal-length slices.
    #[inline]
    fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                let diff = x - y;
                diff * diff
            })
            .sum()
    }

    /// Plain Lloyd's k-means over `vectors`, returning `k * dims` centroid
    /// values (row-major, one row per centroid).
    ///
    /// Initialization samples training points at a regular stride; if
    /// fewer than `k` points exist, the remaining slots are zero-filled.
    fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
        let n = vectors.len();
        let actual_k = k.min(n);
        let mut centroids: Vec<f32> = if actual_k == n {
            vectors.iter().flat_map(|v| v.iter().copied()).collect()
        } else {
            let step = n / actual_k;
            (0..actual_k)
                .flat_map(|i| vectors[i * step].iter().copied())
                .collect()
        };
        if actual_k < k {
            centroids.resize(k * dims, 0.0);
        }
        let mut assignments = vec![0usize; n];
        let mut counts = vec![0usize; k];
        // Snapshot of the previous round's centroids, so clusters that end
        // up empty keep their position instead of collapsing to the origin.
        let mut previous = vec![0.0f32; centroids.len()];
        for _ in 0..iterations {
            // Assignment step: nearest centroid by squared L2.
            for (i, vec) in vectors.iter().enumerate() {
                let mut best_dist = f32::INFINITY;
                let mut best_k = 0;
                for j in 0..k {
                    let centroid_start = j * dims;
                    let dist =
                        Self::l2_sq(vec, &centroids[centroid_start..centroid_start + dims]);
                    if dist < best_dist {
                        best_dist = dist;
                        best_k = j;
                    }
                }
                assignments[i] = best_k;
            }
            // Update step: recompute each centroid as the mean of its
            // assigned points.
            previous.copy_from_slice(&centroids);
            centroids.fill(0.0);
            counts.fill(0);
            for (i, vec) in vectors.iter().enumerate() {
                let k_idx = assignments[i];
                let centroid_start = k_idx * dims;
                counts[k_idx] += 1;
                for (d, &v) in vec.iter().enumerate() {
                    centroids[centroid_start + d] += v;
                }
            }
            for j in 0..k {
                let centroid_start = j * dims;
                if counts[j] > 0 {
                    let count = counts[j] as f32;
                    for d in 0..dims {
                        centroids[centroid_start + d] /= count;
                    }
                } else {
                    // Empty cluster: restore its previous centroid. (Fixes
                    // a bug where empty clusters were silently reset to the
                    // zero vector every iteration.)
                    centroids[centroid_start..centroid_start + dims]
                        .copy_from_slice(&previous[centroid_start..centroid_start + dims]);
                }
            }
        }
        centroids
    }

    /// Builds a quantizer from pre-trained centroids (e.g. deserialized
    /// from disk).
    ///
    /// # Panics
    /// Panics if `num_subvectors == 0` or `centroids.len()` does not equal
    /// `num_subvectors * num_centroids * (dimensions / num_subvectors)`.
    #[must_use]
    pub fn with_centroids(
        num_subvectors: usize,
        num_centroids: usize,
        dimensions: usize,
        centroids: Vec<f32>,
    ) -> Self {
        // Explicit check: previously a zero value caused a bare
        // divide-by-zero panic below.
        assert!(num_subvectors > 0, "num_subvectors must be > 0");
        let subvector_dim = dimensions / num_subvectors;
        assert_eq!(
            centroids.len(),
            num_subvectors * num_centroids * subvector_dim,
            "Invalid centroid count"
        );
        Self {
            num_subvectors,
            num_centroids,
            subvector_dim,
            dimensions,
            centroids,
        }
    }

    /// Number of partitions each vector is split into.
    #[must_use]
    pub fn num_subvectors(&self) -> usize {
        self.num_subvectors
    }

    /// Number of centroids per partition codebook.
    #[must_use]
    pub fn num_centroids(&self) -> usize {
        self.num_centroids
    }

    /// Full input vector length.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Components per partition.
    #[must_use]
    pub fn subvector_dim(&self) -> usize {
        self.subvector_dim
    }

    /// Encoded size in bytes (one `u8` code per partition).
    #[must_use]
    pub fn code_size(&self) -> usize {
        self.num_subvectors
    }

    /// Memory compression ratio versus raw `f32` storage.
    #[must_use]
    pub fn compression_ratio(&self) -> usize {
        (self.dimensions * 4) / self.num_subvectors
    }

    /// Encodes `vector` as one nearest-centroid index per partition.
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(
            vector.len(),
            self.dimensions,
            "Vector dimension mismatch: expected {}, got {}",
            self.dimensions,
            vector.len()
        );
        let mut codes = Vec::with_capacity(self.num_subvectors);
        for m in 0..self.num_subvectors {
            let subvec_start = m * self.subvector_dim;
            let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
            let mut best_dist = f32::INFINITY;
            let mut best_k = 0u8;
            for k in 0..self.num_centroids {
                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
                let dist = Self::l2_sq(
                    subvec,
                    &self.centroids[centroid_start..centroid_start + self.subvector_dim],
                );
                if dist < best_dist {
                    best_dist = dist;
                    best_k = k as u8;
                }
            }
            codes.push(best_k);
        }
        codes
    }

    /// Encodes each vector in `vectors`; see [`Self::quantize`].
    #[must_use]
    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
        vectors.iter().map(|v| self.quantize(v)).collect()
    }

    /// Precomputes the `num_subvectors * num_centroids` table of squared
    /// distances from `query`'s partitions to every centroid, enabling
    /// O(num_subvectors) lookups via [`Self::distance_with_table`].
    #[must_use]
    pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
        debug_assert_eq!(query.len(), self.dimensions);
        let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
        for m in 0..self.num_subvectors {
            let query_start = m * self.subvector_dim;
            let query_subvec = &query[query_start..query_start + self.subvector_dim];
            for k in 0..self.num_centroids {
                let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
                table.push(Self::l2_sq(
                    query_subvec,
                    &self.centroids[centroid_start..centroid_start + self.subvector_dim],
                ));
            }
        }
        table
    }

    /// Squared distance from the query used to build `table` to the vector
    /// encoded by `codes` (sum of per-partition table entries).
    #[must_use]
    #[inline]
    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
        debug_assert_eq!(codes.len(), self.num_subvectors);
        debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
        codes
            .iter()
            .enumerate()
            .map(|(m, &code)| table[m * self.num_centroids + code as usize])
            .sum()
    }

    /// One-shot squared distance from a full-precision `query` to an
    /// encoded vector. Prefer building a table once when comparing a query
    /// against many encoded vectors.
    #[must_use]
    pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
        let table = self.build_distance_table(query);
        self.distance_with_table(&table, codes)
    }

    /// Euclidean form of [`Self::asymmetric_distance_squared`].
    #[must_use]
    #[inline]
    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        self.asymmetric_distance_squared(query, codes).sqrt()
    }

    /// Reconstructs an approximate vector by concatenating each code's
    /// centroid.
    #[must_use]
    pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
        debug_assert_eq!(codes.len(), self.num_subvectors);
        let mut result = Vec::with_capacity(self.dimensions);
        for (m, &code) in codes.iter().enumerate() {
            let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
            result.extend_from_slice(
                &self.centroids[centroid_start..centroid_start + self.subvector_dim],
            );
        }
        result
    }

    /// Borrows the codebook centroids for one partition.
    ///
    /// # Panics
    /// Panics if `partition >= num_subvectors`.
    #[must_use]
    pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
        assert!(partition < self.num_subvectors);
        (0..self.num_centroids)
            .map(|k| {
                let start = (partition * self.num_centroids + k) * self.subvector_dim;
                &self.centroids[start..start + self.subvector_dim]
            })
            .collect()
    }
}
/// Hamming distance between two packed bit vectors on x86_64.
///
/// Uses `u64::count_ones`, which the compiler lowers to a single POPCNT
/// instruction when the target supports it. This replaces an unconditional
/// `unsafe` call to the `_popcnt64` intrinsic, which is undefined behavior
/// on CPUs lacking the `popcnt` target feature (no runtime detection was
/// performed).
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
}
/// Portable fallback for non-x86_64 targets: delegates to the scalar
/// popcount-based Hamming distance.
#[cfg(not(target_arch = "x86_64"))]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    BinaryQuantizer::hamming_distance(a, b)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the quantization primitives.
    use super::*;

    #[test]
    fn test_quantization_type_compression_ratio() {
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);
        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    #[test]
    fn test_quantization_type_from_str() {
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        // Aliases are case-insensitive.
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let quantizer = ScalarQuantizer::train(&refs);
        assert_eq!(quantizer.dimensions(), 3);
        // Per-dimension minima across the three training vectors.
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);
        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);
        // Midpoint lands near code 127, allowing for rounding.
        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);
        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);
        // Round trip is lossy but stays within one quantization step.
        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);
        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);
        let dist = quantizer.distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);
        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);
        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);
        // Orthogonal unit vectors: cosine distance should be ~1.
        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);
        let dist = quantizer.cosine_distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    #[test]
    fn test_binary_quantizer_quantize() {
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);
        assert_eq!(bits.len(), 1);
        // Signs + / - / + / + pack to bits 0, 2, 3.
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    #[test]
    fn test_binary_quantizer_hamming_distance() {
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0];
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0];
        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);
        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2);
    }

    #[test]
    fn test_binary_quantizer_identical_vectors() {
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);
        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];
        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);
        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64);
    }

    #[test]
    fn test_binary_quantizer_large_vector() {
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);
        // ceil(1000 / 64) = 16 words.
        assert_eq!(bits.len(), 16);
    }

    #[test]
    fn test_binary_quantizer_normalized_distance() {
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];
        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);
        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_binary_quantizer_words_needed() {
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24);
    }

    #[test]
    fn test_binary_quantizer_bytes_needed() {
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192);
    }

    #[test]
    fn test_hamming_distance_simd() {
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];
        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128);
    }

    #[test]
    fn test_quantization_type_product_from_str() {
        // Bare names default to 8 subvectors; "pq<N>" selects N.
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    #[test]
    fn test_quantization_type_requires_training() {
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    #[test]
    fn test_product_quantizer_train() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 4, 8, 5);
        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 2, 16, 3);
        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);
        // Codes must index into the 16-centroid codebooks.
        for &code in &codes {
            assert!(code < 16);
        }
    }

    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 3, 8, 5);
        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);
        assert_eq!(reconstructed.len(), 12);
        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();
        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 8, 32, 5);
        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");
        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 4, 8, 3);
        let query = &vectors[0];
        let table = pq.build_distance_table(query);
        assert_eq!(table.len(), 4 * 8);
        // Table-based lookup must agree with the direct computation.
        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);
        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 2, 4, 2);
        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);
        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192);
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32);
    }

    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let pq = ProductQuantizer::train(&refs, 2, 4, 3);
        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4);
        assert_eq!(centroids[0].len(), 4);
    }
}