use crate::distance::l2_squared;
use crate::error::{DiskAnnError, Result};
use bincode::{Decode, Encode};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
/// Product quantizer: splits each `dim`-dimensional vector into `m` contiguous
/// sub-vectors of `dsub` components and keeps one table of centroids per
/// sub-vector position.
///
/// The field order is part of the derived serialization layout; do not reorder.
#[derive(Clone, Serialize, Deserialize, Encode, Decode)]
pub struct ProductQuantizer {
/// Number of sub-vectors (sub-quantizers) every vector is split into.
pub m: usize,
/// Number of components in each sub-vector (`dim / m`).
pub dsub: usize,
/// Length of the full vectors this quantizer accepts.
pub dim: usize,
/// `centroids[sub][c]` is centroid `c` of sub-vector position `sub`.
/// Empty until `train` has run.
pub centroids: Vec<Vec<Vec<f32>>>,
/// Set to `true` once `train` has completed; `encode` and
/// `build_distance_table` return an error while it is `false`.
pub trained: bool,
}
impl ProductQuantizer {
pub fn new(dim: usize, m: usize) -> Result<Self> {
if dim % m != 0 {
return Err(DiskAnnError::InvalidConfig(format!(
"dim ({dim}) must be divisible by m ({m})"
)));
}
let dsub = dim / m;
Ok(Self {
m,
dsub,
dim,
centroids: Vec::new(),
trained: false,
})
}
pub fn train(&mut self, vectors: &[Vec<f32>], iterations: usize) -> Result<()> {
if vectors.is_empty() {
return Err(DiskAnnError::Empty);
}
if vectors[0].len() != self.dim {
return Err(DiskAnnError::DimensionMismatch {
expected: self.dim,
actual: vectors[0].len(),
});
}
let k = 256usize; let n = vectors.len();
let mut rng = rand::thread_rng();
self.centroids = Vec::with_capacity(self.m);
for sub in 0..self.m {
let offset = sub * self.dsub;
let subvectors: Vec<&[f32]> = vectors
.iter()
.map(|v| &v[offset..offset + self.dsub])
.collect();
let mut centers = Vec::with_capacity(k);
centers.push(subvectors[rng.gen_range(0..n)].to_vec());
for _ in 1..k.min(n) {
let dists: Vec<f32> = subvectors
.iter()
.map(|sv| {
centers
.iter()
.map(|c| l2_squared(sv, c))
.fold(f32::MAX, f32::min)
})
.collect();
let total: f32 = dists.iter().sum();
if total < 1e-10 {
while centers.len() < k {
centers.push(centers[0].clone());
}
break;
}
let mut r = rng.gen::<f32>() * total;
for (i, &d) in dists.iter().enumerate() {
r -= d;
if r <= 0.0 {
centers.push(subvectors[i].to_vec());
break;
}
}
if centers.len() < k && r > 0.0 {
centers.push(subvectors[rng.gen_range(0..n)].to_vec());
}
}
while centers.len() < k {
centers.push(centers[rng.gen_range(0..centers.len())].clone());
}
let mut assignments = vec![0u8; n];
for _ in 0..iterations {
for (i, sv) in subvectors.iter().enumerate() {
let mut best_dist = f32::MAX;
let mut best_c = 0u8;
for (c, center) in centers.iter().enumerate() {
let d = l2_squared(sv, center);
if d < best_dist {
best_dist = d;
best_c = c as u8;
}
}
assignments[i] = best_c;
}
let mut counts = vec![0usize; k];
let mut sums = vec![vec![0.0f32; self.dsub]; k];
for (i, &a) in assignments.iter().enumerate() {
let ci = a as usize;
counts[ci] += 1;
for d in 0..self.dsub {
sums[ci][d] += subvectors[i][d];
}
}
for c in 0..k {
if counts[c] > 0 {
for d in 0..self.dsub {
centers[c][d] = sums[c][d] / counts[c] as f32;
}
}
}
}
self.centroids.push(centers);
}
self.trained = true;
Ok(())
}
pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
if !self.trained {
return Err(DiskAnnError::PqNotTrained);
}
if vector.len() != self.dim {
return Err(DiskAnnError::DimensionMismatch {
expected: self.dim,
actual: vector.len(),
});
}
let mut codes = Vec::with_capacity(self.m);
for sub in 0..self.m {
let offset = sub * self.dsub;
let subvec = &vector[offset..offset + self.dsub];
let mut best_dist = f32::MAX;
let mut best_c = 0u8;
for (c, center) in self.centroids[sub].iter().enumerate() {
let d = l2_squared(subvec, center);
if d < best_dist {
best_dist = d;
best_c = c as u8;
}
}
codes.push(best_c);
}
Ok(codes)
}
pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<f32>> {
if !self.trained {
return Err(DiskAnnError::PqNotTrained);
}
if query.len() != self.dim {
return Err(DiskAnnError::DimensionMismatch {
expected: self.dim,
actual: query.len(),
});
}
let k = 256;
let mut table = vec![0.0f32; self.m * k];
for sub in 0..self.m {
let offset = sub * self.dsub;
let subquery = &query[offset..offset + self.dsub];
for (c, center) in self.centroids[sub].iter().enumerate() {
table[sub * k + c] = l2_squared(subquery, center);
}
}
Ok(table)
}
#[inline]
pub fn distance_with_table(&self, codes: &[u8], table: &[f32]) -> f32 {
crate::distance::pq_asymmetric_distance(codes, table, 256)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains on a small synthetic data set and checks that a training vector
    /// lies close to its own quantized representation.
    #[test]
    fn test_pq_train_encode() {
        // 100 eight-component vectors whose values form a simple ramp.
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(100);
        for i in 0..100 {
            let mut row = Vec::with_capacity(8);
            for d in 0..8 {
                row.push((i * 7 + d) as f32 / 100.0);
            }
            vectors.push(row);
        }

        let mut pq = ProductQuantizer::new(8, 2).unwrap();
        pq.train(&vectors, 5).unwrap();

        // One code per sub-quantizer.
        let probe = &vectors[0];
        let codes = pq.encode(probe).unwrap();
        assert_eq!(codes.len(), 2);

        // A vector compared against its own codes should be nearly at distance zero.
        let table = pq.build_distance_table(probe).unwrap();
        let dist = pq.distance_with_table(&codes, &table);
        assert!(dist < 0.1, "self-distance was {dist}");
    }
}