use std::collections::HashMap;
use super::distance::{cmp_f32, l2_squared_simd};
use super::hnsw::NodeId;
/// Parameters that control how vectors are split and quantized.
///
/// A vector of `dimension` components is cut into `n_subvectors` contiguous
/// slices of equal length; each slice gets a codebook of `n_centroids`
/// centroids trained with at most `max_iterations` k-means rounds.
#[derive(Clone, Debug)]
pub struct PQConfig {
    /// Number of components in every input vector.
    pub dimension: usize,
    /// Number of equally sized slices a vector is cut into.
    pub n_subvectors: usize,
    /// Number of centroids in each per-slice codebook.
    pub n_centroids: usize,
    /// Upper bound on k-means iterations when training a codebook.
    pub max_iterations: usize,
}

impl Default for PQConfig {
    /// Returns a configuration for 128-dimensional vectors: 8 subvectors,
    /// 256 centroids per codebook and at most 25 training iterations.
    fn default() -> Self {
        Self {
            dimension: 128,
            n_subvectors: 8,
            n_centroids: 256,
            max_iterations: 25,
        }
    }
}

impl PQConfig {
    /// Creates a configuration with 256 centroids and 25 training iterations.
    ///
    /// # Panics
    ///
    /// Panics if `n_subvectors` is zero or if `dimension` is not a multiple
    /// of `n_subvectors`.
    pub fn new(dimension: usize, n_subvectors: usize) -> Self {
        // `is_multiple_of` reports 0 as a multiple of 0, so a zero split count
        // must be rejected explicitly; otherwise `subvector_dim` would later
        // divide by zero.
        assert!(n_subvectors > 0, "n_subvectors must be greater than zero");
        assert!(
            dimension.is_multiple_of(n_subvectors),
            "dimension must be divisible by n_subvectors"
        );
        Self {
            dimension,
            n_subvectors,
            n_centroids: 256,
            max_iterations: 25,
        }
    }

    /// Returns the number of components in each subvector slice.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `n_subvectors` is zero, which can only
    /// happen when the struct was built by hand rather than through [`new`].
    ///
    /// [`new`]: PQConfig::new
    pub fn subvector_dim(&self) -> usize {
        self.dimension / self.n_subvectors
    }
}
/// A single k-means codebook for one subvector slice.
#[derive(Clone)]
struct Codebook {
    /// Centroids; a centroid's position in this list is its code value.
    centroids: Vec<Vec<f32>>,
    /// Number of components per centroid.
    dim: usize,
}

impl Codebook {
    /// Creates a codebook of `n_centroids` all-zero centroids of length `dim`.
    fn new(dim: usize, n_centroids: usize) -> Self {
        Self {
            centroids: vec![vec![0.0; dim]; n_centroids],
            dim,
        }
    }

    /// Fits the centroids to `subvectors` with k-means.
    ///
    /// Centroids are seeded from samples spread evenly over the whole input,
    /// then refined for at most `max_iterations` rounds, stopping early once
    /// no centroid moves by more than `1e-4`. A centroid that attracts no
    /// samples in a round keeps its previous position.
    ///
    /// Does nothing when `subvectors` is empty or the codebook has no
    /// centroids.
    fn train(&mut self, subvectors: &[Vec<f32>], max_iterations: usize) {
        let k = self.centroids.len();
        let n = subvectors.len();
        if n == 0 || k == 0 {
            return;
        }

        // Seed centroid `i` from sample `i * n / k`. Unlike a fixed integer
        // stride (`n / k`, which is 0 whenever n < k and would seed every
        // centroid from sample 0), this walks the full sample range for any
        // n and k, so with fewer samples than centroids every sample still
        // gets a centroid. The index is always < n because i < k.
        for (i, centroid) in self.centroids.iter_mut().enumerate() {
            centroid.clone_from(&subvectors[i * n / k]);
        }

        for _ in 0..max_iterations {
            // Assignment step: bucket sample indices by nearest centroid.
            let mut members: Vec<Vec<usize>> = vec![Vec::new(); k];
            for (i, sv) in subvectors.iter().enumerate() {
                let nearest = self.find_nearest(sv);
                members[nearest].push(i);
            }

            // Update step: move each non-empty centroid to its members' mean.
            let mut converged = true;
            for (ci, indices) in members.iter().enumerate() {
                if indices.is_empty() {
                    continue;
                }
                let mut mean = vec![0.0f32; self.dim];
                for &idx in indices {
                    for (acc, &val) in mean.iter_mut().zip(&subvectors[idx]) {
                        *acc += val;
                    }
                }
                let count = indices.len() as f32;
                for val in &mut mean {
                    *val /= count;
                }
                let shift = l2_squared_simd(&mean, &self.centroids[ci]).sqrt();
                if shift > 1e-4 {
                    converged = false;
                }
                self.centroids[ci] = mean;
            }
            if converged {
                break;
            }
        }
    }

    /// Returns the index of the centroid closest to `subvector`.
    ///
    /// Ties are resolved in favour of the lowest index. Returns 0 when the
    /// codebook has no centroids.
    fn find_nearest(&self, subvector: &[f32]) -> usize {
        self.centroids
            .iter()
            .enumerate()
            .map(|(i, c)| (i, l2_squared_simd(subvector, c)))
            .min_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Returns the squared distance from `query_subvector` to every centroid,
    /// indexed by centroid position.
    fn compute_distance_table(&self, query_subvector: &[f32]) -> Vec<f32> {
        self.centroids
            .iter()
            .map(|c| l2_squared_simd(query_subvector, c))
            .collect()
    }
}
/// A quantized vector: one centroid index (one byte) per subvector.
pub type PQCode = Vec<u8>;

/// Product quantizer that compresses vectors into [`PQCode`]s.
pub struct ProductQuantizer {
    /// Split and training parameters.
    config: PQConfig,
    /// One codebook per subvector, in slice order.
    codebooks: Vec<Codebook>,
    /// Set once `train` has run on a non-empty training set.
    trained: bool,
}

impl ProductQuantizer {
    /// Largest codebook size whose indices still fit in the `u8` entries of a
    /// [`PQCode`].
    const MAX_CENTROIDS: usize = 256;

    /// Creates an untrained quantizer with one zero-initialised codebook per
    /// subvector.
    ///
    /// # Panics
    ///
    /// Panics if `config.n_centroids` is outside `1..=256`: codes store each
    /// centroid index in a `u8`, so a larger codebook would make `encode`
    /// silently wrap indices, and an empty one cannot encode anything.
    /// Also panics (division by zero) if `config.n_subvectors` is zero.
    pub fn new(config: PQConfig) -> Self {
        assert!(
            (1..=Self::MAX_CENTROIDS).contains(&config.n_centroids),
            "n_centroids must be between 1 and {} so every centroid index fits in a u8 code, got {}",
            Self::MAX_CENTROIDS,
            config.n_centroids
        );
        let subdim = config.subvector_dim();
        let codebooks = (0..config.n_subvectors)
            .map(|_| Codebook::new(subdim, config.n_centroids))
            .collect();
        Self {
            config,
            codebooks,
            trained: false,
        }
    }

    /// Creates a quantizer for `dimension`-component vectors with an
    /// automatically chosen split.
    ///
    /// Prefers 8 subvectors for `dimension >= 64` and 4 otherwise; when the
    /// preferred count does not divide `dimension`, the largest smaller count
    /// that does is used instead (1 always qualifies).
    pub fn with_dimension(dimension: usize) -> Self {
        let preferred: usize = if dimension >= 64 { 8 } else { 4 };
        let n_subvectors = (1..=preferred)
            .rev()
            .find(|n| dimension % *n == 0)
            .unwrap_or(1);
        Self::new(PQConfig::new(dimension, n_subvectors))
    }

    /// Trains every codebook on the matching slice of `vectors`.
    ///
    /// An empty training set is ignored and leaves the quantizer untrained.
    ///
    /// # Panics
    ///
    /// Panics if any vector has fewer than `dimension` components.
    pub fn train(&mut self, vectors: &[Vec<f32>]) {
        if vectors.is_empty() {
            return;
        }
        let subdim = self.config.subvector_dim();
        let max_iterations = self.config.max_iterations;
        for (m, codebook) in self.codebooks.iter_mut().enumerate() {
            let (start, end) = (m * subdim, (m + 1) * subdim);
            let subvectors: Vec<Vec<f32>> =
                vectors.iter().map(|v| v[start..end].to_vec()).collect();
            codebook.train(&subvectors, max_iterations);
        }
        self.trained = true;
    }

    /// Encodes `vector` as the index of the nearest centroid in each
    /// subvector's codebook.
    ///
    /// # Panics
    ///
    /// Panics if `vector` has fewer than `dimension` components.
    pub fn encode(&self, vector: &[f32]) -> PQCode {
        let subdim = self.config.subvector_dim();
        self.codebooks
            .iter()
            .enumerate()
            .map(|(m, codebook)| {
                let subvector = &vector[m * subdim..(m + 1) * subdim];
                let nearest = codebook.find_nearest(subvector);
                debug_assert!(nearest < Self::MAX_CENTROIDS);
                // Lossless: `new` caps every codebook at `MAX_CENTROIDS` entries.
                nearest as u8
            })
            .collect()
    }

    /// Encodes every vector in `vectors`; see [`ProductQuantizer::encode`].
    pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<PQCode> {
        vectors.iter().map(|v| self.encode(v)).collect()
    }

    /// Reconstructs an approximate vector by concatenating the centroid each
    /// code entry refers to.
    ///
    /// # Panics
    ///
    /// Panics if `code` has more entries than there are codebooks or if an
    /// entry is not a valid centroid index.
    pub fn decode(&self, code: &PQCode) -> Vec<f32> {
        let mut vector = Vec::with_capacity(self.config.dimension);
        for (m, &c) in code.iter().enumerate() {
            vector.extend_from_slice(&self.codebooks[m].centroids[c as usize]);
        }
        vector
    }

    /// Returns the approximate Euclidean distance from `query` to each code.
    ///
    /// One table of squared query-to-centroid distances is built per
    /// subvector; each code's distance is the square root of the sum of the
    /// table entries its bytes select.
    ///
    /// # Panics
    ///
    /// Panics if `query` has fewer than `dimension` components or a code
    /// refers to a codebook or centroid that does not exist.
    pub fn compute_distances(&self, query: &[f32], codes: &[PQCode]) -> Vec<f32> {
        let subdim = self.config.subvector_dim();
        let tables: Vec<Vec<f32>> = self
            .codebooks
            .iter()
            .enumerate()
            .map(|(m, codebook)| {
                let subquery = &query[m * subdim..(m + 1) * subdim];
                codebook.compute_distance_table(subquery)
            })
            .collect();
        codes
            .iter()
            .map(|code| {
                code.iter()
                    .enumerate()
                    .map(|(m, &c)| tables[m][c as usize])
                    .sum::<f32>()
                    .sqrt()
            })
            .collect()
    }

    /// Returns the ratio of raw `f32` storage (4 bytes per component) to code
    /// storage (1 byte per subvector) for one vector.
    pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.config.dimension * 4;
        let compressed_bytes = self.config.n_subvectors;
        original_bytes as f32 / compressed_bytes as f32
    }

    /// Returns the configuration this quantizer was built with.
    pub fn config(&self) -> &PQConfig {
        &self.config
    }

    /// Returns `true` once `train` has run on a non-empty training set.
    pub fn is_trained(&self) -> bool {
        self.trained
    }
}
/// Flat index that stores vectors as PQ codes and scans all of them on search.
///
/// Optionally keeps the uncompressed vectors so that a short list of
/// candidates can be re-ranked with exact distances.
pub struct PQIndex {
    /// Quantizer used to encode vectors and estimate distances.
    pq: ProductQuantizer,
    /// Code of the vector stored in each slot.
    codes: Vec<PQCode>,
    /// Id of the vector stored in each slot (parallel to `codes`).
    ids: Vec<NodeId>,
    /// Slot currently holding each id.
    id_to_idx: HashMap<NodeId, usize>,
    /// Uncompressed vectors per slot when re-ranking is enabled; an empty
    /// entry marks a slot filled before originals were being kept.
    originals: Option<Vec<Vec<f32>>>,
    /// Next id `add` will hand out; always greater than every id stored so far.
    next_id: NodeId,
}

impl PQIndex {
    /// Creates an empty index around an untrained quantizer built from `config`.
    pub fn new(config: PQConfig) -> Self {
        Self {
            pq: ProductQuantizer::new(config),
            codes: Vec::new(),
            ids: Vec::new(),
            id_to_idx: HashMap::new(),
            originals: None,
            next_id: 0,
        }
    }

    /// Enables storage of uncompressed vectors for [`PQIndex::search_rerank`].
    ///
    /// Intended to be called before any vector is added. If vectors are
    /// already present their slots are padded with empty placeholders so that
    /// later originals stay aligned with their codes; those earlier vectors
    /// are re-ranked with their approximate distance. Calling this again keeps
    /// the originals already stored.
    pub fn with_originals(mut self) -> Self {
        if self.originals.is_none() {
            self.originals = Some(vec![Vec::new(); self.codes.len()]);
        }
        self
    }

    /// Trains the underlying quantizer on `vectors`.
    pub fn train(&mut self, vectors: &[Vec<f32>]) {
        self.pq.train(vectors);
    }

    /// Adds `vector` under a freshly allocated id and returns that id.
    pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
        let id = self.next_id;
        // `add_with_id` moves `next_id` past `id`.
        self.add_with_id(id, vector);
        id
    }

    /// Adds `vector` under the caller-chosen `id`.
    ///
    /// If `id` is already present its code (and stored original, if any) is
    /// replaced in place, so an id never appears twice in search results.
    /// The automatic id counter is moved past `id` so that a later
    /// [`PQIndex::add`] cannot hand out an id that is already in use.
    pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
        let code = self.pq.encode(&vector);
        self.next_id = self.next_id.max(id.saturating_add(1));

        if let Some(&slot) = self.id_to_idx.get(&id) {
            self.codes[slot] = code;
            if let Some(stored) = self.originals.as_mut().and_then(|o| o.get_mut(slot)) {
                *stored = vector;
            }
            return;
        }

        let slot = self.codes.len();
        self.codes.push(code);
        self.ids.push(id);
        self.id_to_idx.insert(id, slot);
        if let Some(originals) = self.originals.as_mut() {
            originals.push(vector);
        }
    }

    /// Adds every vector with an automatically allocated id and returns the
    /// ids in input order.
    pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
        vectors.into_iter().map(|v| self.add(v)).collect()
    }

    /// Returns the `k` slots closest to `query` by approximate distance,
    /// ordered by ascending distance with ties broken by slot number.
    fn nearest_slots(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
        if k == 0 || self.codes.is_empty() {
            return Vec::new();
        }
        let mut ranked: Vec<(usize, f32)> = self
            .pq
            .compute_distances(query, &self.codes)
            .into_iter()
            .enumerate()
            .collect();
        let by_distance =
            |a: &(usize, f32), b: &(usize, f32)| cmp_f32(a.1, b.1).then_with(|| a.0.cmp(&b.0));
        // Partition the k best to the front first so only they need sorting.
        // The slot tie-break makes the order total, so the result equals a
        // full sort followed by truncation.
        if k < ranked.len() {
            ranked.select_nth_unstable_by(k - 1, by_distance);
            ranked.truncate(k);
        }
        ranked.sort_unstable_by(by_distance);
        ranked
    }

    /// Returns up to `k` `(id, approximate distance)` pairs closest to `query`,
    /// nearest first.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(NodeId, f32)> {
        self.nearest_slots(query, k)
            .into_iter()
            .map(|(slot, dist)| (self.ids[slot], dist))
            .collect()
    }

    /// Searches with PQ distances, then re-ranks candidates by exact distance.
    ///
    /// The best `max(rerank_k, k)` candidates by approximate distance are
    /// re-scored against their stored originals and the best `k` returned,
    /// nearest first with ties broken by id. A candidate without a stored
    /// original keeps its approximate distance. Without
    /// [`PQIndex::with_originals`] this is the same as [`PQIndex::search`].
    pub fn search_rerank(&self, query: &[f32], k: usize, rerank_k: usize) -> Vec<(NodeId, f32)> {
        let originals = match self.originals.as_ref() {
            Some(o) => o,
            None => return self.search(query, k),
        };
        // Never shortlist fewer than k candidates, or the result would be
        // shorter than requested even though more vectors are available.
        let mut reranked: Vec<(NodeId, f32)> = self
            .nearest_slots(query, rerank_k.max(k))
            .into_iter()
            .map(|(slot, approx)| {
                let dist = match originals.get(slot) {
                    Some(v) if !v.is_empty() => l2_squared_simd(query, v).sqrt(),
                    _ => approx,
                };
                (self.ids[slot], dist)
            })
            .collect();
        reranked.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
        reranked.truncate(k);
        reranked
    }

    /// Returns the number of vectors stored.
    pub fn len(&self) -> usize {
        self.codes.len()
    }

    /// Returns `true` when no vector is stored.
    pub fn is_empty(&self) -> bool {
        self.codes.is_empty()
    }

    /// Returns the quantizer's per-vector compression ratio.
    pub fn compression_ratio(&self) -> f32 {
        self.pq.compression_ratio()
    }

    /// Estimates the payload size in bytes: one byte per subvector for each
    /// code plus four bytes per component of every stored original.
    /// Container overhead and the id tables are not counted.
    pub fn memory_usage(&self) -> usize {
        let code_bytes = self.codes.len() * self.pq.config.n_subvectors;
        let original_bytes = self
            .originals
            .as_ref()
            .map(|o| o.iter().map(|v| v.len() * 4).sum::<usize>())
            .unwrap_or(0);
        code_bytes + original_bytes
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic vector of `dim` values in `[0, 1)` derived from
    /// `seed`, so tests are reproducible without a random number generator.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut values = Vec::with_capacity(dim);
        for i in 0..dim {
            let mixed = seed * 1103515245 + i as u64 * 12345;
            values.push((mixed % 1000) as f32 / 1000.0);
        }
        values
    }

    /// A trained quantizer produces one byte per subvector and decodes back to
    /// a full-length vector close to the input.
    #[test]
    fn test_pq_encode_decode() {
        let mut pq = ProductQuantizer::new(PQConfig::new(16, 4));
        let training: Vec<Vec<f32>> = (0..100).map(|seed| random_vector(16, seed)).collect();
        pq.train(&training);
        assert!(pq.is_trained());

        let original = random_vector(16, 999);
        let code = pq.encode(&original);
        let decoded = pq.decode(&code);
        assert_eq!(code.len(), 4);
        assert_eq!(decoded.len(), 16);

        let reconstruction_error: f32 = original
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        assert!(reconstruction_error < 1.0);
    }

    /// 128 floats (512 bytes) compressed to 8 code bytes is a 64x ratio.
    #[test]
    fn test_pq_compression_ratio() {
        let config = PQConfig::new(128, 8);
        let pq = ProductQuantizer::new(config);
        assert_eq!(pq.compression_ratio(), 64.0);
    }

    /// Searching with a vector that is itself in the index returns it first.
    #[test]
    fn test_pq_index_search() {
        let training: Vec<Vec<f32>> = (0..50).map(|seed| random_vector(8, seed)).collect();

        let mut index = PQIndex::new(PQConfig::new(8, 4));
        index.train(&training);
        for (position, vector) in training.iter().enumerate() {
            index.add_with_id(position as u64, vector.clone());
        }

        let query = random_vector(8, 0);
        let results = index.search(&query, 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, 0);
    }

    /// A query halfway between two encoded vectors is about equally far from
    /// both.
    #[test]
    fn test_pq_distance_tables() {
        let mut pq = ProductQuantizer::new(PQConfig::new(8, 2));
        let training: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        ];
        pq.train(&training);

        let codes = pq.encode_batch(&training);
        let query = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        let distances = pq.compute_distances(&query, &codes);
        assert_eq!(distances.len(), 2);

        let gap = (distances[0] - distances[1]).abs();
        assert!(gap < 0.1);
    }
}