use super::types::Id;
use anyhow::{anyhow, Result};
use rand::seq::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Tuning parameters for a [`ProductQuantizer`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PQConfig {
    /// Number of equal-width slices each vector is split into; one `u8` code is
    /// stored per slice. The vector dimension must be divisible by this value.
    pub num_subvectors: usize,
    /// Number of k-means centroids learned per slice (the codebook size).
    /// Training needs at least this many training vectors, and indices above 255
    /// cannot be represented by the `u8` codes.
    pub num_centroids: usize,
    /// Upper bound on the number of k-means iterations run per slice during training.
    pub training_iterations: usize,
}
impl Default for PQConfig {
fn default() -> Self {
Self {
num_subvectors: 16, num_centroids: 256, training_iterations: 20, }
}
}
/// Product quantizer: splits each vector into equal-width subvectors and encodes
/// every subvector as the index of its nearest centroid in a per-subvector codebook.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Hyper-parameters supplied at construction.
    config: PQConfig,
    /// Length of the full vectors this quantizer accepts.
    dimension: usize,
    /// Length of each subvector (`dimension / config.num_subvectors`).
    subvector_dim: usize,
    /// `codebooks[m][c]` is centroid `c` of subvector `m`; empty until trained.
    codebooks: Vec<Vec<Vec<f32>>>,
    /// Set once training has populated `codebooks`.
    trained: bool,
}
impl ProductQuantizer {
    /// Largest codebook that a `u8` code can address.
    const MAX_CENTROIDS: usize = u8::MAX as usize + 1;

    /// Creates an untrained quantizer for vectors of length `dimension`.
    ///
    /// Each vector is split into `config.num_subvectors` contiguous, equally sized
    /// subvectors, each quantized against its own codebook once trained.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.num_subvectors` is zero, if `dimension` is not
    /// divisible by `config.num_subvectors`, or if `config.num_centroids` is not in
    /// `1..=256` (codes are stored as `u8`, so larger codebooks cannot be addressed).
    pub fn new(dimension: usize, config: PQConfig) -> Result<Self> {
        // Checked first: a zero divisor would otherwise reach `dimension / 0` below.
        if config.num_subvectors == 0 {
            return Err(anyhow!("num_subvectors must be greater than zero"));
        }
        if !dimension.is_multiple_of(config.num_subvectors) {
            return Err(anyhow!(
                "Dimension {} must be divisible by num_subvectors {}",
                dimension,
                config.num_subvectors
            ));
        }
        if config.num_centroids == 0 || config.num_centroids > Self::MAX_CENTROIDS {
            return Err(anyhow!(
                "num_centroids must be between 1 and {} (got {})",
                Self::MAX_CENTROIDS,
                config.num_centroids
            ));
        }
        let subvector_dim = dimension / config.num_subvectors;
        Ok(Self {
            config,
            dimension,
            subvector_dim,
            codebooks: Vec::new(),
            trained: false,
        })
    }

    /// Learns one k-means codebook per subvector from `training_vectors`.
    ///
    /// Subspaces are trained in parallel with rayon on native targets and
    /// sequentially on `wasm32`. Progress is reported on stdout (native) or the
    /// browser console (`wasm32` with the `wasm` feature).
    ///
    /// # Errors
    ///
    /// Returns an error if `training_vectors` is empty, holds fewer than
    /// `num_centroids` vectors, or contains a vector whose length is not the
    /// quantizer's dimension. On error the quantizer's previous state is untouched.
    pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
        if training_vectors.is_empty() {
            return Err(anyhow!("Need training vectors"));
        }
        if training_vectors.len() < self.config.num_centroids {
            return Err(anyhow!(
                "Need at least {} training vectors (got {})",
                self.config.num_centroids,
                training_vectors.len()
            ));
        }
        // Validate up front: slicing a short vector would otherwise panic inside a worker.
        if let Some(bad) = training_vectors
            .iter()
            .position(|v| v.len() != self.dimension)
        {
            return Err(anyhow!(
                "Training vector {} has dimension {}, expected {}",
                bad,
                training_vectors[bad].len(),
                self.dimension
            ));
        }
        #[cfg(not(target_arch = "wasm32"))]
        println!(
            "Training PQ: {} subvectors, {} centroids, {} training vectors",
            self.config.num_subvectors,
            self.config.num_centroids,
            training_vectors.len()
        );
        #[cfg(target_arch = "wasm32")]
        {
            #[cfg(feature = "wasm")]
            web_sys::console::log_1(
                &format!(
                    "Training PQ: {} subvectors, {} centroids, {} training vectors",
                    self.config.num_subvectors,
                    self.config.num_centroids,
                    training_vectors.len()
                )
                .into(),
            );
        }
        // Build the new codebooks completely before touching `self`, so a failure
        // cannot leave the quantizer half-trained.
        #[cfg(not(target_arch = "wasm32"))]
        let codebooks = (0..self.config.num_subvectors)
            .into_par_iter()
            .map(|m| self.train_subspace(training_vectors, m))
            .collect::<Result<Vec<_>>>()?;
        #[cfg(target_arch = "wasm32")]
        let codebooks = (0..self.config.num_subvectors)
            .map(|m| self.train_subspace(training_vectors, m))
            .collect::<Result<Vec<_>>>()?;
        self.codebooks = codebooks;
        self.trained = true;
        #[cfg(not(target_arch = "wasm32"))]
        println!("✅ PQ training complete");
        #[cfg(target_arch = "wasm32")]
        {
            #[cfg(feature = "wasm")]
            web_sys::console::log_1(&"✅ PQ training complete".into());
        }
        Ok(())
    }

    /// Index range of subvector `m` within a full-length vector.
    fn subspace_range(&self, m: usize) -> std::ops::Range<usize> {
        let start = m * self.subvector_dim;
        start..start + self.subvector_dim
    }

    /// Extracts subvector `m` of every training vector and runs k-means on it.
    fn train_subspace(&self, training_vectors: &[Vec<f32>], m: usize) -> Result<Vec<Vec<f32>>> {
        let range = self.subspace_range(m);
        let subvectors: Vec<Vec<f32>> = training_vectors
            .iter()
            .map(|v| v[range.clone()].to_vec())
            .collect();
        self.kmeans(&subvectors, self.config.num_centroids)
    }

    /// Index of the centroid nearest (Euclidean) to `point`; ties go to the lowest index.
    ///
    /// Uses `f32::total_cmp`, so NaN distances no longer cause a panic.
    ///
    /// # Panics
    ///
    /// Panics if `centroids` is empty.
    fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
        centroids
            .iter()
            .map(|centroid| euclidean_distance(point, centroid))
            .enumerate()
            .min_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(idx, _)| idx)
            .expect("codebook must contain at least one centroid")
    }

    /// Lloyd's k-means with centroids seeded from `k` distinct random input vectors.
    ///
    /// Runs at most `training_iterations` rounds and stops early once assignments
    /// stop changing (further rounds would reproduce the same centroids). A
    /// centroid whose cluster becomes empty keeps its previous position.
    ///
    /// # Errors
    ///
    /// Returns an error if `vectors` is empty.
    fn kmeans(&self, vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
        let Some(first) = vectors.first() else {
            return Err(anyhow!("k-means needs at least one vector"));
        };
        let dim = first.len();
        let mut rng = rand::thread_rng();
        let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
        let mut previous: Option<Vec<usize>> = None;
        for _ in 0..self.config.training_iterations {
            let assignments: Vec<usize> = vectors
                .iter()
                .map(|v| Self::nearest_centroid(v, &centroids))
                .collect();
            if previous.as_ref() == Some(&assignments) {
                break;
            }
            // Bucket vectors by cluster in a single pass rather than rescanning
            // every vector for each centroid.
            let mut clusters: Vec<Vec<&Vec<f32>>> = vec![Vec::new(); centroids.len()];
            for (v, &cluster_idx) in vectors.iter().zip(&assignments) {
                clusters[cluster_idx].push(v);
            }
            for (centroid, cluster) in centroids.iter_mut().zip(&clusters) {
                if !cluster.is_empty() {
                    *centroid = compute_mean(cluster, dim);
                }
            }
            previous = Some(assignments);
        }
        Ok(centroids)
    }

    /// Encodes `vector` as one code per subvector: the index of the nearest centroid.
    ///
    /// # Errors
    ///
    /// Returns an error if the quantizer is untrained, if `vector` has the wrong
    /// dimension, or if a centroid index does not fit in `u8`. That last case is
    /// possible only when a deserialized config bypassed the check in `new`.
    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
        if !self.trained {
            return Err(anyhow!("Quantizer not trained"));
        }
        if vector.len() != self.dimension {
            return Err(anyhow!(
                "Vector dimension mismatch: expected {}, got {}",
                self.dimension,
                vector.len()
            ));
        }
        let mut codes = Vec::with_capacity(self.config.num_subvectors);
        for (m, codebook) in self.codebooks.iter().enumerate() {
            let code = Self::nearest_centroid(&vector[self.subspace_range(m)], codebook);
            // Never truncate silently: an out-of-range code would decode to the wrong centroid.
            if code > usize::from(u8::MAX) {
                return Err(anyhow!(
                    "Centroid index {} for subvector {} does not fit in a u8 code",
                    code,
                    m
                ));
            }
            codes.push(code as u8);
        }
        Ok(codes)
    }

    /// Reconstructs an approximate vector by concatenating the centroids named by `codes`.
    ///
    /// # Errors
    ///
    /// Returns an error if the quantizer is untrained, if `codes` does not contain
    /// exactly `num_subvectors` entries, or if a code is out of range for its codebook.
    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
        if !self.trained {
            return Err(anyhow!("Quantizer not trained"));
        }
        if codes.len() != self.config.num_subvectors {
            return Err(anyhow!("Invalid number of codes"));
        }
        let mut vector = Vec::with_capacity(self.dimension);
        for (m, &code) in codes.iter().enumerate() {
            let codebook = &self.codebooks[m];
            let centroid = codebook.get(usize::from(code)).ok_or_else(|| {
                anyhow!(
                    "Code {} out of range for subvector {} ({} centroids)",
                    code,
                    m,
                    codebook.len()
                )
            })?;
            vector.extend_from_slice(centroid);
        }
        Ok(vector)
    }

    /// Approximate Euclidean distance between a query and an encoded vector
    /// (asymmetric distance computation).
    ///
    /// `distance_table` must come from [`Self::compute_distance_table`] for the
    /// query. Its rows hold per-subvector Euclidean distances `d_m`. Subvectors are
    /// disjoint, so the full distance is `sqrt(Σ_m d_m²)`, which equals
    /// `euclidean_distance(query, decode(codes))` up to rounding.
    ///
    /// # Panics
    ///
    /// Panics if `codes` is longer than `distance_table` or a code is out of range
    /// for its row.
    pub fn asymmetric_distance(&self, codes: &[u8], distance_table: &[Vec<f32>]) -> f32 {
        codes
            .iter()
            .enumerate()
            .map(|(m, &code)| {
                let d = distance_table[m][usize::from(code)];
                d * d
            })
            .sum::<f32>()
            .sqrt()
    }

    /// Precomputes, for every subvector `m` and centroid `c`, the Euclidean
    /// distance between the query's `m`-th subvector and `codebooks[m][c]`.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer is untrained (no codebooks) or `query` is shorter
    /// than the quantizer's dimension. Components beyond the dimension are ignored.
    pub fn compute_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
        (0..self.config.num_subvectors)
            .map(|m| {
                let query_subvector = &query[self.subspace_range(m)];
                self.codebooks[m]
                    .iter()
                    .map(|centroid| euclidean_distance(query_subvector, centroid))
                    .collect()
            })
            .collect()
    }

    /// Ratio of raw `f32` storage (4 bytes per component) to code storage
    /// (1 byte per subvector).
    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.dimension * std::mem::size_of::<f32>();
        let compressed_size = self.config.num_subvectors;
        original_size as f32 / compressed_size as f32
    }

    /// Whether `train` has completed successfully.
    pub fn is_trained(&self) -> bool {
        self.trained
    }

    /// The configuration this quantizer was built with.
    pub fn config(&self) -> &PQConfig {
        &self.config
    }
}
/// In-memory vector store that keeps only PQ codes and answers k-nearest-neighbour
/// queries by asymmetric distance against those codes.
pub struct PQVectorStore {
    /// Quantizer used to encode inserted vectors and to build query distance tables.
    quantizer: ProductQuantizer,
    /// Compressed codes keyed by vector id.
    codes: HashMap<Id, Vec<u8>>,
    /// Set once `train` succeeds; `add` and `search` refuse to run before that.
    trained: bool,
}
impl PQVectorStore {
    /// Creates an empty, untrained store for vectors of length `dimension`.
    ///
    /// # Errors
    ///
    /// Propagates configuration errors from [`ProductQuantizer::new`].
    pub fn new(dimension: usize, config: PQConfig) -> Result<Self> {
        let quantizer = ProductQuantizer::new(dimension, config)?;
        Ok(Self {
            quantizer,
            codes: HashMap::new(),
            trained: false,
        })
    }

    /// Trains the underlying quantizer. This must succeed before `add` or `search`.
    ///
    /// NOTE: codes already in the store are not re-encoded. Retraining a non-empty
    /// store leaves them pointing into the old codebooks, so their search
    /// distances become meaningless. Re-add those vectors after retraining.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`ProductQuantizer::train`]; the trained flag is then
    /// left unchanged.
    pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
        self.quantizer.train(training_vectors)?;
        self.trained = true;
        Ok(())
    }

    /// Encodes `vector` and stores its codes under `id`, replacing any existing entry.
    ///
    /// # Errors
    ///
    /// Fails if the store is untrained or `vector` has the wrong dimension.
    pub fn add(&mut self, id: Id, vector: &[f32]) -> Result<()> {
        if !self.trained {
            return Err(anyhow!("Store not trained"));
        }
        let codes = self.quantizer.encode(vector)?;
        self.codes.insert(id, codes);
        Ok(())
    }

    /// Returns up to `k` stored ids closest to `query`, nearest first, with their
    /// approximate distances as reported by [`ProductQuantizer::asymmetric_distance`].
    ///
    /// # Errors
    ///
    /// Fails if the store is untrained or `query` has the wrong dimension.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(Id, f32)>> {
        if !self.trained {
            return Err(anyhow!("Store not trained"));
        }
        // Reject bad queries here: the distance table would otherwise panic on a
        // short query and silently ignore the tail of a long one.
        if query.len() != self.quantizer.dimension {
            return Err(anyhow!(
                "Query dimension mismatch: expected {}, got {}",
                self.quantizer.dimension,
                query.len()
            ));
        }
        let distance_table = self.quantizer.compute_distance_table(query);
        #[cfg(not(target_arch = "wasm32"))]
        let mut results: Vec<(Id, f32)> = self
            .codes
            .par_iter()
            .map(|(id, codes)| {
                let distance = self.quantizer.asymmetric_distance(codes, &distance_table);
                (id.clone(), distance)
            })
            .collect();
        #[cfg(target_arch = "wasm32")]
        let mut results: Vec<(Id, f32)> = self
            .codes
            .iter()
            .map(|(id, codes)| {
                let distance = self.quantizer.asymmetric_distance(codes, &distance_table);
                (id.clone(), distance)
            })
            .collect();
        // total_cmp never panics, unlike partial_cmp().unwrap() on NaN distances.
        let by_distance = |a: &(Id, f32), b: &(Id, f32)| a.1.total_cmp(&b.1);
        if k < results.len() {
            // Move the k nearest to the front in O(n), then sort only those.
            results.select_nth_unstable_by(k, by_distance);
            results.truncate(k);
        }
        results.sort_by(by_distance);
        Ok(results)
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.codes.len()
    }

    /// Whether no vectors are stored.
    pub fn is_empty(&self) -> bool {
        self.codes.is_empty()
    }

    /// Compression ratio of the underlying quantizer.
    pub fn compression_ratio(&self) -> f32 {
        self.quantizer.compression_ratio()
    }

    /// Bytes occupied by stored PQ codes (one byte per subvector per vector).
    ///
    /// Ids, codebooks and hash-map overhead are not included.
    pub fn memory_usage(&self) -> usize {
        self.codes.len() * self.quantizer.config.num_subvectors
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so only the common prefix is compared.
/// Callers are expected to pass slices of equal length.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs).powi(2))
        .sum();
    squared_sum.sqrt()
}
/// Component-wise mean of `vectors`, returned as a vector of length `dim`.
///
/// Each component is pre-divided by the count before being accumulated. An empty
/// input yields all zeros.
///
/// # Panics
///
/// Panics if any input vector is longer than `dim`.
fn compute_mean(vectors: &[&Vec<f32>], dim: usize) -> Vec<f32> {
    let count = vectors.len() as f32;
    vectors.iter().fold(vec![0.0; dim], |mut acc, v| {
        for (idx, val) in v.iter().copied().enumerate() {
            acc[idx] += val / count;
        }
        acc
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` vectors of length `dim` whose components come from `rand::random`.
    fn generate_random_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(count);
        for _ in 0..count {
            out.push((0..dim).map(|_| rand::random::<f32>()).collect());
        }
        out
    }

    /// Shorthand for the explicit configurations used below.
    fn make_config(
        num_subvectors: usize,
        num_centroids: usize,
        training_iterations: usize,
    ) -> PQConfig {
        PQConfig {
            num_subvectors,
            num_centroids,
            training_iterations,
        }
    }

    #[test]
    fn test_pq_basic() {
        let mut pq = ProductQuantizer::new(64, make_config(8, 16, 5)).unwrap();
        let data = generate_random_vectors(100, 64);
        pq.train(&data).unwrap();
        assert!(pq.is_trained());

        // One code per subvector; decoding restores the full dimension.
        let codes = pq.encode(&data[0]).unwrap();
        assert_eq!(codes.len(), 8);
        assert_eq!(pq.decode(&codes).unwrap().len(), 64);
    }

    #[test]
    fn test_pq_store() {
        let mut store = PQVectorStore::new(64, make_config(8, 256, 10)).unwrap();
        let data = generate_random_vectors(500, 64);
        store.train(&data).unwrap();
        for (i, v) in data.iter().enumerate().take(100) {
            store.add(format!("vec_{}", i), v).unwrap();
        }
        assert_eq!(store.len(), 100);

        // Querying with a stored vector: its own codes are the nearest centroid in
        // every subvector, so it is expected to rank first.
        let hits = store.search(&data[0], 10).unwrap();
        assert_eq!(hits.len(), 10);
        assert_eq!(hits[0].0, "vec_0");
    }

    #[test]
    fn test_compression_ratio() {
        let pq = ProductQuantizer::new(128, make_config(16, 256, 5)).unwrap();
        // 128 f32 components (512 bytes) become 16 one-byte codes.
        assert_eq!(pq.compression_ratio(), 32.0);
    }
}