use std::collections::HashMap;
use parking_lot::RwLock;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use common::{DistanceMetric, Vector, VectorId};
/// Configuration for product quantization (PQ) training and search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
/// Number of subspaces the vector is split into; must evenly divide the dimension.
pub num_subquantizers: usize,
/// Centroids per subspace codebook (codes are stored as `u8`, so 256 is the practical max).
pub num_centroids: usize,
/// Number of Lloyd iterations run per subspace during training.
pub kmeans_iterations: usize,
/// Metric used when building query distance tables.
pub distance_metric: DistanceMetric,
}
impl Default for PQConfig {
/// Defaults: 8 subquantizers x 256 centroids (one byte per subspace),
/// 20 k-means iterations, Euclidean distance.
fn default() -> Self {
Self {
num_subquantizers: 8,
num_centroids: 256,
kmeans_iterations: 20,
distance_metric: DistanceMetric::Euclidean,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
pub config: PQConfig,
pub codebooks: Vec<Vec<Vec<f32>>>,
pub dimension: usize,
pub subvector_dim: usize,
}
impl ProductQuantizer {
/// Creates an untrained quantizer for vectors of `dimension` floats.
///
/// # Errors
/// Returns `Err` when `num_subquantizers` is zero, when `num_centroids` is
/// zero or exceeds 256 (codes produced by `encode` are single `u8`s), or
/// when `dimension` is not evenly divisible by `num_subquantizers`.
pub fn new(config: PQConfig, dimension: usize) -> Result<Self, String> {
    // Guard against a zero divisor before computing `subvector_dim`.
    if config.num_subquantizers == 0 {
        return Err("num_subquantizers must be greater than zero".to_string());
    }
    // Codes are one u8 per subquantizer, so more than 256 centroids would
    // silently truncate during encoding.
    if config.num_centroids == 0 || config.num_centroids > 256 {
        return Err(format!(
            "num_centroids must be in 1..=256, got {}",
            config.num_centroids
        ));
    }
    if !dimension.is_multiple_of(config.num_subquantizers) {
        return Err(format!(
            "Dimension {} not divisible by num_subquantizers {}",
            dimension, config.num_subquantizers
        ));
    }
    let subvector_dim = dimension / config.num_subquantizers;
    Ok(Self {
        config,
        codebooks: Vec::new(),
        dimension,
        subvector_dim,
    })
}
/// Trains one k-means codebook per subspace from the given sample vectors.
///
/// # Errors
/// Returns `Err` when `vectors` is empty or when any vector's dimension
/// differs from the dimension this quantizer was built for.
pub fn train(&mut self, vectors: &[Vector]) -> Result<(), String> {
    if vectors.is_empty() {
        return Err("Cannot train on empty vectors".to_string());
    }
    // Validate every vector up front: the slicing below would panic on a
    // vector shorter than `dimension`, so checking only the first element
    // is not enough.
    for v in vectors {
        if v.values.len() != self.dimension {
            return Err(format!(
                "Vector dimension {} doesn't match expected {}",
                v.values.len(),
                self.dimension
            ));
        }
    }
    let m = self.config.num_subquantizers;
    let k = self.config.num_centroids;
    let d = self.subvector_dim;
    // Build into a local so a mid-training failure leaves `self` untrained
    // instead of with a partial codebook set (is_trained would lie).
    let mut codebooks = Vec::with_capacity(m);
    for subspace_idx in 0..m {
        let start = subspace_idx * d;
        let end = start + d;
        // Project every training vector onto this subspace.
        let subvectors: Vec<Vec<f32>> = vectors
            .iter()
            .map(|v| v.values[start..end].to_vec())
            .collect();
        codebooks.push(self.train_kmeans(&subvectors, k)?);
    }
    self.codebooks = codebooks;
    Ok(())
}
fn train_kmeans(&self, subvectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>, String> {
if subvectors.is_empty() {
return Err("Cannot train k-means on empty subvectors".to_string());
}
let actual_k = k.min(subvectors.len());
let dim = subvectors[0].len();
let mut centroids = self.kmeans_plus_plus(subvectors, actual_k);
for _ in 0..self.config.kmeans_iterations {
let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); actual_k];
for (i, subvec) in subvectors.iter().enumerate() {
let nearest = self.find_nearest_centroid(subvec, ¢roids);
assignments[nearest].push(i);
}
for (c_idx, assigned) in assignments.iter().enumerate() {
if assigned.is_empty() {
continue;
}
let mut new_centroid = vec![0.0f32; dim];
for &vec_idx in assigned {
for (j, &val) in subvectors[vec_idx].iter().enumerate() {
new_centroid[j] += val;
}
}
let count = assigned.len() as f32;
for val in &mut new_centroid {
*val /= count;
}
centroids[c_idx] = new_centroid;
}
}
Ok(centroids)
}
/// Picks up to `k` initial centroids using the k-means++ D^2-weighting
/// scheme; may return fewer when all remaining samples coincide with a
/// chosen centroid.
fn kmeans_plus_plus(&self, subvectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
    let mut rng = rand::thread_rng();
    let mut centroids = Vec::with_capacity(k);
    // The first seed is drawn uniformly at random.
    match subvectors.choose(&mut rng) {
        Some(seed) => centroids.push(seed.clone()),
        None => return centroids,
    }
    for _ in 1..k {
        // Weight each sample by its squared distance to the closest
        // centroid picked so far.
        let weights: Vec<f32> = subvectors
            .iter()
            .map(|point| {
                centroids
                    .iter()
                    .map(|c| self.squared_distance(point, c))
                    .fold(f32::MAX, f32::min)
            })
            .collect();
        let total: f32 = weights.iter().sum();
        if total == 0.0 {
            // Every sample sits on an existing centroid; stop early.
            break;
        }
        // Sample an index proportionally to its weight (inverse-CDF walk).
        let threshold: f32 = rand::random::<f32>() * total;
        let mut running = 0.0;
        for (idx, &w) in weights.iter().enumerate() {
            running += w;
            if running >= threshold {
                centroids.push(subvectors[idx].clone());
                break;
            }
        }
    }
    centroids
}
/// Index of the centroid with the smallest squared Euclidean distance to
/// `subvec`; ties resolve to the earliest index.
fn find_nearest_centroid(&self, subvec: &[f32], centroids: &[Vec<f32>]) -> usize {
    let mut best = (0usize, f32::MAX);
    for (idx, centroid) in centroids.iter().enumerate() {
        let dist = self.squared_distance(subvec, centroid);
        if dist < best.1 {
            best = (idx, dist);
        }
    }
    best.0
}
/// Squared Euclidean distance between two equal-length slices.
#[inline]
fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
/// True once `train` has populated the per-subspace codebooks.
pub fn is_trained(&self) -> bool {
!self.codebooks.is_empty()
}
/// Encodes a vector into one `u8` centroid index per subquantizer.
///
/// # Errors
/// Fails when the quantizer is untrained, the dimension mismatches, or a
/// centroid index does not fit in a `u8`.
pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>, String> {
    if !self.is_trained() {
        return Err("Quantizer not trained".to_string());
    }
    if vector.len() != self.dimension {
        return Err(format!(
            "Vector dimension {} doesn't match expected {}",
            vector.len(),
            self.dimension
        ));
    }
    let m = self.config.num_subquantizers;
    let d = self.subvector_dim;
    let mut codes = Vec::with_capacity(m);
    for subspace_idx in 0..m {
        let start = subspace_idx * d;
        let subvec = &vector[start..start + d];
        let nearest = self.find_nearest_centroid(subvec, &self.codebooks[subspace_idx]);
        // `nearest as u8` would silently truncate indices >= 256 (possible
        // if a codebook somehow exceeds 256 entries); fail loudly instead.
        let code = u8::try_from(nearest)
            .map_err(|_| format!("Centroid index {} does not fit in u8", nearest))?;
        codes.push(code);
    }
    Ok(codes)
}
/// Reconstructs an approximate vector by concatenating the centroids
/// addressed by `codes`.
///
/// # Errors
/// Fails when the quantizer is untrained, `codes` has the wrong length, or
/// a code refers past the end of its subspace codebook.
pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>, String> {
    if !self.is_trained() {
        return Err("Quantizer not trained".to_string());
    }
    if codes.len() != self.config.num_subquantizers {
        return Err(format!(
            "Code length {} doesn't match num_subquantizers {}",
            codes.len(),
            self.config.num_subquantizers
        ));
    }
    let mut vector = Vec::with_capacity(self.dimension);
    for (subspace_idx, &code) in codes.iter().enumerate() {
        // Checked lookup: codes produced by a different quantizer (smaller
        // codebooks) would otherwise panic on direct indexing.
        let codebook = &self.codebooks[subspace_idx];
        let centroid = codebook.get(code as usize).ok_or_else(|| {
            format!(
                "Code {} out of range for codebook {} (len {})",
                code,
                subspace_idx,
                codebook.len()
            )
        })?;
        vector.extend_from_slice(centroid);
    }
    Ok(vector)
}
/// Builds the ADC lookup table for `query`: one row per subspace, one
/// score per centroid. Scores are oriented so that higher means better
/// (Euclidean distances are negated).
///
/// # Errors
/// Fails when the quantizer is untrained or `query` has the wrong dimension.
pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, String> {
    if !self.is_trained() {
        return Err("Quantizer not trained".to_string());
    }
    if query.len() != self.dimension {
        return Err(format!(
            "Query dimension {} doesn't match expected {}",
            query.len(),
            self.dimension
        ));
    }
    let d = self.subvector_dim;
    // Metric-specific score for a (query subvector, centroid) pair.
    let score = |sub: &[f32], centroid: &[f32]| -> f32 {
        match self.config.distance_metric {
            DistanceMetric::Euclidean => -self.squared_distance(sub, centroid).sqrt(),
            DistanceMetric::Cosine => self.cosine_sim(sub, centroid),
            DistanceMetric::DotProduct => self.dot_product(sub, centroid),
        }
    };
    let mut table = Vec::with_capacity(self.config.num_subquantizers);
    for subspace_idx in 0..self.config.num_subquantizers {
        let start = subspace_idx * d;
        let sub = &query[start..start + d];
        let row: Vec<f32> = self.codebooks[subspace_idx]
            .iter()
            .map(|centroid| score(sub, centroid))
            .collect();
        table.push(row);
    }
    Ok(table)
}
/// Sums the per-subspace table scores selected by an encoded vector
/// (asymmetric distance computation).
#[inline]
pub fn compute_distance_adc(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
    codes
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (idx, &code)| acc + table[idx][code as usize])
}
/// Cosine similarity; returns 0.0 when either input has zero norm.
#[inline]
fn cosine_sim(&self, a: &[f32], b: &[f32]) -> f32 {
    // Accumulate dot product and both squared norms in one pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (x, y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Plain dot product of two equal-length slices.
#[inline]
fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
}
/// Thread-safe product-quantization index mapping vector ids to PQ codes.
pub struct PQIndex {
/// Quantizer shared by all operations; written only by `train`.
quantizer: RwLock<ProductQuantizer>,
/// Compressed representation: one `u8` code per subquantizer per vector.
encoded_vectors: RwLock<HashMap<VectorId, Vec<u8>>>,
/// Full-precision copies, populated only when `store_originals` is true.
original_vectors: RwLock<HashMap<VectorId, Vector>>,
/// Whether `add` also retains the original vectors for exact retrieval.
store_originals: bool,
}
/// A single search hit; higher `score` means a better match.
#[derive(Debug, Clone)]
pub struct PQSearchResult {
pub id: VectorId,
/// ADC score from the distance table (Euclidean distances are negated).
pub score: f32,
/// Original vector, present only when the index stores originals.
pub vector: Option<Vector>,
}
impl PQIndex {
/// Builds an empty, untrained index. `store_originals` controls whether
/// `add` also keeps full-precision copies of inserted vectors.
///
/// # Errors
/// Propagates quantizer-construction errors (invalid config/dimension).
pub fn new(config: PQConfig, dimension: usize, store_originals: bool) -> Result<Self, String> {
    ProductQuantizer::new(config, dimension).map(|quantizer| Self {
        quantizer: RwLock::new(quantizer),
        encoded_vectors: RwLock::new(HashMap::new()),
        original_vectors: RwLock::new(HashMap::new()),
        store_originals,
    })
}
/// Trains the underlying quantizer on the given sample vectors.
pub fn train(&self, vectors: &[Vector]) -> Result<(), String> {
    self.quantizer.write().train(vectors)
}
/// Whether the underlying quantizer has been trained.
pub fn is_trained(&self) -> bool {
self.quantizer.read().is_trained()
}
/// Encodes and inserts `vectors`, returning how many were stored.
/// Existing entries with the same id are overwritten.
///
/// The whole batch is encoded before either map is mutated, so a failing
/// vector leaves the index unchanged instead of partially updated.
///
/// # Errors
/// Fails when the index is untrained or any vector fails to encode.
pub fn add(&self, vectors: Vec<Vector>) -> Result<usize, String> {
    let quantizer = self.quantizer.read();
    if !quantizer.is_trained() {
        return Err("Index not trained".to_string());
    }
    // Encode everything first: bail out before taking the write locks if
    // any vector is malformed.
    let mut batch = Vec::with_capacity(vectors.len());
    for vector in vectors {
        let codes = quantizer.encode(&vector.values)?;
        batch.push((vector, codes));
    }
    let mut encoded = self.encoded_vectors.write();
    let mut originals = self.original_vectors.write();
    let count = batch.len();
    for (vector, codes) in batch {
        encoded.insert(vector.id.clone(), codes);
        if self.store_originals {
            originals.insert(vector.id.clone(), vector);
        }
    }
    Ok(count)
}
/// Removes the given ids from both maps; returns how many encoded entries
/// were actually present.
pub fn remove(&self, ids: &[VectorId]) -> usize {
    let mut encoded = self.encoded_vectors.write();
    let mut originals = self.original_vectors.write();
    let mut removed = 0;
    for id in ids {
        // Originals are dropped unconditionally; only encoded hits count.
        originals.remove(id);
        if encoded.remove(id).is_some() {
            removed += 1;
        }
    }
    removed
}
/// Returns the top-`k` entries ranked by ADC score, highest first.
///
/// # Errors
/// Fails when the index is untrained or `query` has the wrong dimension.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<PQSearchResult>, String> {
    let quantizer = self.quantizer.read();
    if !quantizer.is_trained() {
        return Err("Index not trained".to_string());
    }
    // One table per query; each candidate then costs only m lookups.
    let table = quantizer.compute_distance_table(query)?;
    let encoded = self.encoded_vectors.read();
    let originals = self.original_vectors.read();
    let mut hits = Vec::with_capacity(encoded.len());
    for (id, codes) in encoded.iter() {
        hits.push(PQSearchResult {
            id: id.clone(),
            score: quantizer.compute_distance_adc(&table, codes),
            vector: originals.get(id).cloned(),
        });
    }
    // Descending by score; NaN scores compare as equal and keep position.
    hits.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    hits.truncate(k);
    Ok(hits)
}
/// Number of encoded vectors currently stored.
pub fn len(&self) -> usize {
self.encoded_vectors.read().len()
}
/// True when the index holds no vectors.
pub fn is_empty(&self) -> bool {
self.encoded_vectors.read().is_empty()
}
/// Ratio of original storage (`dimension` f32 values) to compressed
/// storage (one `u8` code per subquantizer); e.g. 128 dims with 8
/// subquantizers compresses 64x.
pub fn compression_ratio(&self) -> f32 {
    let quantizer = self.quantizer.read();
    // One f32 per dimension originally vs. one byte per subquantizer.
    let original_bytes = quantizer.dimension * std::mem::size_of::<f32>();
    let compressed_bytes = quantizer.config.num_subquantizers;
    original_bytes as f32 / compressed_bytes as f32
}
/// Reconstructs the approximate vector stored under `id`.
///
/// # Errors
/// Fails when the id is unknown or the stored codes fail to decode.
pub fn decode(&self, id: &VectorId) -> Result<Vec<f32>, String> {
    let quantizer = self.quantizer.read();
    let guard = self.encoded_vectors.read();
    match guard.get(id) {
        Some(codes) => quantizer.decode(codes),
        None => Err(format!("Vector {} not found", id)),
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Deterministic synthetic fixtures: vector `i` has values sin((i + j) * 0.1).
fn test_vectors(n: usize, dim: usize) -> Vec<Vector> {
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        let values = (0..dim).map(|j| ((i + j) as f32 * 0.1).sin()).collect();
        out.push(Vector {
            id: format!("v{}", i),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        });
    }
    out
}
#[test]
fn test_pq_config_validation() {
    let config = PQConfig {
        num_subquantizers: 8,
        ..PQConfig::default()
    };
    // 64 splits evenly into 8 subspaces; 65 does not.
    assert!(ProductQuantizer::new(config.clone(), 64).is_ok());
    assert!(ProductQuantizer::new(config, 65).is_err());
}
#[test]
fn test_pq_train() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        kmeans_iterations: 10,
        ..Default::default()
    };
    let mut pq = ProductQuantizer::new(config, 32).unwrap();
    assert!(!pq.is_trained());
    pq.train(&test_vectors(100, 32)).unwrap();
    assert!(pq.is_trained());
    // 4 codebooks of 16 centroids, each spanning 32 / 4 = 8 dimensions.
    assert_eq!(pq.codebooks.len(), 4);
    assert_eq!(pq.codebooks[0].len(), 16);
    assert_eq!(pq.codebooks[0][0].len(), 8);
}
#[test]
fn test_pq_encode_decode() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        ..Default::default()
    };
    let mut pq = ProductQuantizer::new(config, 32).unwrap();
    let vectors = test_vectors(100, 32);
    pq.train(&vectors).unwrap();
    let original = &vectors[0].values;
    let codes = pq.encode(original).unwrap();
    assert_eq!(codes.len(), 4);
    let decoded = pq.decode(&codes).unwrap();
    assert_eq!(decoded.len(), 32);
    // Round-trip reconstruction error should stay modest.
    let error = original
        .iter()
        .zip(&decoded)
        .map(|(a, b)| (a - b) * (a - b))
        .sum::<f32>()
        .sqrt();
    assert!(error < 5.0, "Quantization error too high: {}", error);
}
#[test]
fn test_pq_distance_table() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        ..Default::default()
    };
    let mut pq = ProductQuantizer::new(config, 32).unwrap();
    let vectors = test_vectors(100, 32);
    pq.train(&vectors).unwrap();
    // Expect one row per subspace and one score per centroid.
    let table = pq.compute_distance_table(&vectors[0].values).unwrap();
    assert_eq!(table.len(), 4);
    assert_eq!(table[0].len(), 16);
}
#[test]
fn test_pq_adc() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        ..Default::default()
    };
    let mut pq = ProductQuantizer::new(config, 32).unwrap();
    let vectors = test_vectors(100, 32);
    pq.train(&vectors).unwrap();
    let query = &vectors[50].values;
    let table = pq.compute_distance_table(query).unwrap();
    let codes = pq.encode(query).unwrap();
    // Scoring a vector against its own codes should land near 0 (the max
    // for negated Euclidean scores).
    let dist = pq.compute_distance_adc(&table, &codes);
    assert!(
        dist > -3.0,
        "Self-distance should be relatively small, got {}",
        dist
    );
}
#[test]
fn test_pq_index_basic() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        ..Default::default()
    };
    let index = PQIndex::new(config, 32, true).unwrap();
    let vectors = test_vectors(100, 32);
    index.train(&vectors).unwrap();
    assert!(index.is_trained());
    // All 100 vectors are accepted and reflected in len().
    assert_eq!(index.add(vectors.clone()).unwrap(), 100);
    assert_eq!(index.len(), 100);
}
#[test]
fn test_pq_index_search() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 32,
        kmeans_iterations: 15,
        distance_metric: DistanceMetric::Euclidean,
    };
    let index = PQIndex::new(config, 32, true).unwrap();
    let vectors = test_vectors(200, 32);
    index.train(&vectors).unwrap();
    index.add(vectors.clone()).unwrap();
    let results = index.search(&vectors[100].values, 10).unwrap();
    assert!(!results.is_empty());
    assert!(results.len() <= 10);
    // Scores must come back sorted in descending order.
    for pair in results.windows(2) {
        assert!(pair[0].score >= pair[1].score);
    }
    // The query vector itself should rank among the top hits.
    assert!(
        results.iter().any(|r| r.id == "v100"),
        "Query vector not found in top results"
    );
}
#[test]
fn test_pq_index_remove() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        ..Default::default()
    };
    let index = PQIndex::new(config, 32, false).unwrap();
    let vectors = test_vectors(50, 32);
    index.train(&vectors).unwrap();
    index.add(vectors).unwrap();
    assert_eq!(index.len(), 50);
    // Removing two known ids shrinks the index by exactly two.
    let victims = ["v0".to_string(), "v1".to_string()];
    assert_eq!(index.remove(&victims), 2);
    assert_eq!(index.len(), 48);
}
#[test]
fn test_pq_compression_ratio() {
    let config = PQConfig {
        num_subquantizers: 8,
        num_centroids: 256,
        ..Default::default()
    };
    let index = PQIndex::new(config, 128, false).unwrap();
    // 128 dims * 4 bytes = 512 bytes, down to 8 bytes of codes -> 64x.
    assert!((index.compression_ratio() - 64.0).abs() < 0.1);
}
#[test]
fn test_pq_decode_from_index() {
    let config = PQConfig {
        num_subquantizers: 4,
        num_centroids: 16,
        ..Default::default()
    };
    let index = PQIndex::new(config, 32, false).unwrap();
    let vectors = test_vectors(50, 32);
    index.train(&vectors).unwrap();
    index.add(vectors).unwrap();
    // A known id decodes to a full-dimension vector.
    assert_eq!(index.decode(&"v10".to_string()).unwrap().len(), 32);
    // An unknown id is an error rather than a panic.
    assert!(index.decode(&"nonexistent".to_string()).is_err());
}
}