use crate::partitioning::kmeans::KMeans;
use crate::simd::l2_distance_squared;
use crate::RetrieveError;
use serde::{Deserialize, Serialize};
/// Product quantizer that encodes `dimension`-wide `f32` vectors as one `u8` code
/// per subspace.
///
/// The vector is split into `num_codebooks` contiguous subspaces of `subvector_dim`
/// values each; every subspace has its own codebook of `codebook_size` codewords,
/// trained by `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
/// Length of the full vectors this quantizer encodes.
dimension: usize,
/// Number of subspaces, and therefore of codebooks and of codes per vector.
num_codebooks: usize,
/// Codewords per codebook: the requested size after `new`, replaced by `fit` with
/// the number of centroids actually produced for the first subspace.
codebook_size: usize,
/// Values per subspace; set to `dimension / num_codebooks` by `new`.
subvector_dim: usize,
/// All codewords stored flat: codeword `c` of codebook `m` starts at offset
/// `(m * codebook_size + c) * subvector_dim`. Empty after `new`; filled by `fit`.
codebooks: Vec<f32>,
}
impl ProductQuantizer {
    /// Largest number of codewords per codebook that a `u8` code can address.
    const MAX_CODEBOOK_SIZE: usize = u8::MAX as usize + 1;

    /// Creates an unfitted product quantizer.
    ///
    /// The `dimension`-wide input space is split into `num_codebooks` contiguous
    /// subspaces of `dimension / num_codebooks` values; each subspace gets its own
    /// codebook of (up to) `codebook_size` codewords once [`Self::fit`] has run.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::InvalidParameter` if any argument is zero, if
    /// `dimension` is not divisible by `num_codebooks`, or if `codebook_size`
    /// exceeds 256 (codes are stored as `u8`, so larger codebooks cannot be addressed).
    pub fn new(
        dimension: usize,
        num_codebooks: usize,
        codebook_size: usize,
    ) -> Result<Self, RetrieveError> {
        if dimension == 0 || num_codebooks == 0 || codebook_size == 0 {
            return Err(RetrieveError::InvalidParameter(
                "all parameters must be greater than 0".into(),
            ));
        }
        if dimension % num_codebooks != 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be divisible by num_codebooks".into(),
            ));
        }
        if codebook_size > Self::MAX_CODEBOOK_SIZE {
            return Err(RetrieveError::InvalidParameter(
                "codebook_size must be at most 256 because codes are stored as u8".into(),
            ));
        }
        Ok(Self {
            dimension,
            num_codebooks,
            codebook_size,
            subvector_dim: dimension / num_codebooks,
            codebooks: Vec::new(),
        })
    }

    /// Trains one k-means codebook per subspace.
    ///
    /// `vectors` is read as `num_vectors` consecutive rows of `dimension` values; any
    /// trailing data is ignored. For each subspace the matching slice of every row is
    /// gathered and clustered with `KMeans` using `codebook_size` clusters.
    /// `codebook_size` is then set to the number of centroids the clustering actually
    /// produced. `self` is only modified when training succeeds for every subspace.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::InvalidParameter` if `vectors` holds fewer than
    /// `num_vectors * dimension` values, if the first subspace yields zero or more
    /// than 256 centroids, or if subspaces yield differing centroid counts (the flat
    /// codebook layout requires one shared size). Errors from `KMeans` are propagated.
    pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
        // Reject short input up front instead of panicking inside `get_vector`.
        match num_vectors.checked_mul(self.dimension) {
            Some(required) if required <= vectors.len() => {}
            _ => {
                return Err(RetrieveError::InvalidParameter(
                    "vectors must contain at least num_vectors * dimension values".into(),
                ))
            }
        }

        let mut flat_codebooks =
            Vec::with_capacity(self.num_codebooks * self.codebook_size * self.subvector_dim);
        let mut actual_codebook_size = self.codebook_size;
        for codebook_idx in 0..self.num_codebooks {
            let start_dim = codebook_idx * self.subvector_dim;
            let end_dim = start_dim + self.subvector_dim;

            // Gather this subspace of every training vector into one contiguous buffer.
            let mut flat = Vec::with_capacity(num_vectors * self.subvector_dim);
            for i in 0..num_vectors {
                let vec = get_vector(vectors, self.dimension, i);
                flat.extend_from_slice(&vec[start_dim..end_dim]);
            }

            let mut kmeans = KMeans::new(self.subvector_dim, self.codebook_size)?;
            kmeans.fit(&flat, num_vectors)?;
            let centroids = kmeans.centroids();
            let num_centroids = centroids.len();

            if codebook_idx == 0 {
                // The first subspace fixes the codebook size used by the flat layout;
                // it must be non-empty and addressable by a `u8` code.
                if !(1..=Self::MAX_CODEBOOK_SIZE).contains(&num_centroids) {
                    return Err(RetrieveError::InvalidParameter(
                        "k-means must produce between 1 and 256 centroids per codebook".into(),
                    ));
                }
                actual_codebook_size = num_centroids;
            } else if num_centroids != actual_codebook_size {
                // `get_codeword` assumes every codebook has the same number of
                // codewords; a mismatch would misalign every later codebook.
                return Err(RetrieveError::InvalidParameter(
                    "k-means produced a different number of centroids for different codebooks"
                        .into(),
                ));
            }

            for codeword in centroids {
                flat_codebooks.extend_from_slice(codeword);
            }
        }
        self.codebook_size = actual_codebook_size;
        self.codebooks = flat_codebooks;
        Ok(())
    }

    /// Borrows codeword `code` of codebook `codebook_idx` from the flat storage.
    ///
    /// Panics if the computed range lies outside `self.codebooks` (e.g. before `fit`).
    #[inline]
    fn get_codeword(&self, codebook_idx: usize, code: usize) -> &[f32] {
        let offset = (codebook_idx * self.codebook_size + code) * self.subvector_dim;
        &self.codebooks[offset..offset + self.subvector_dim]
    }

    /// Encodes `vector` as one code per codebook: the index of the codeword with the
    /// smallest squared L2 distance to that subspace of `vector` (first index wins ties).
    ///
    /// Only the first `dimension` values of `vector` are read.
    ///
    /// # Panics
    ///
    /// Panics if `vector` has fewer than `dimension` values, or if the quantizer was
    /// created with [`Self::new`] and has not been fitted yet.
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        assert!(
            vector.len() >= self.dimension,
            "vector has fewer values than the quantizer dimension"
        );
        // Only indices representable as `u8` may be emitted. For quantizers built via
        // `new`/`fit` this never truncates the search; it guards deserialized instances
        // instead of mislabelling a far codeword as code 255.
        let searchable = self.codebook_size.min(Self::MAX_CODEBOOK_SIZE);
        (0..self.num_codebooks)
            .map(|codebook_idx| {
                let start_dim = codebook_idx * self.subvector_dim;
                let subvector = &vector[start_dim..start_dim + self.subvector_dim];
                let mut best_code = 0u8;
                let mut best_dist = f32::INFINITY;
                for code in 0..searchable {
                    let dist = l2_distance_squared(subvector, self.get_codeword(codebook_idx, code));
                    if dist < best_dist {
                        best_dist = dist;
                        // Lossless: `code < searchable <= 256`.
                        best_code = code as u8;
                    }
                }
                best_code
            })
            .collect()
    }

    /// Sums, over the given codes, the squared L2 distance between each subspace of
    /// `query` and the codeword selected by the matching code.
    ///
    /// `codes` is expected to come from [`Self::quantize`]; each code must be below
    /// `codebook_size` (checked in debug builds only).
    ///
    /// # Panics
    ///
    /// Panics if `query` is too short for the number of codes, or if `codes` has more
    /// entries than there are codebooks.
    pub fn approximate_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        debug_assert!(codes.len() <= self.num_codebooks, "more codes than codebooks");
        let mut total_dist = 0.0;
        for (codebook_idx, &code) in codes.iter().enumerate() {
            debug_assert!(
                usize::from(code) < self.codebook_size,
                "code out of codebook range"
            );
            let start_dim = codebook_idx * self.subvector_dim;
            let query_subvector = &query[start_dim..start_dim + self.subvector_dim];
            let codeword = self.get_codeword(codebook_idx, usize::from(code));
            total_dist += l2_distance_squared(query_subvector, codeword);
        }
        total_dist
    }

    /// Precomputes an asymmetric-distance lookup table for `query`.
    ///
    /// Entry `m * codebook_size + c` holds the squared L2 distance between subspace `m`
    /// of `query` and codeword `c` of codebook `m`; see [`Self::distance_with_table`].
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::DimensionMismatch` if `query.len() != dimension`, and
    /// `RetrieveError::InvalidParameter` if the quantizer has not been fitted.
    pub fn compute_adc_table(&self, query: &[f32]) -> Result<Vec<f32>, RetrieveError> {
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        if self.codebooks.is_empty() {
            return Err(RetrieveError::InvalidParameter(
                "product quantizer must be fitted before computing an ADC table".into(),
            ));
        }
        let mut table = Vec::with_capacity(self.num_codebooks * self.codebook_size);
        for codebook_idx in 0..self.num_codebooks {
            let start_dim = codebook_idx * self.subvector_dim;
            let query_subvector = &query[start_dim..start_dim + self.subvector_dim];
            for code in 0..self.codebook_size {
                let codeword = self.get_codeword(codebook_idx, code);
                table.push(l2_distance_squared(query_subvector, codeword));
            }
        }
        Ok(table)
    }

    /// Sums the table entries selected by `codes`, using a table from
    /// [`Self::compute_adc_table`].
    ///
    /// Each code must be below `codebook_size` (checked in debug builds only);
    /// otherwise the lookup reads an entry belonging to the next codebook.
    ///
    /// # Panics
    ///
    /// Panics if a computed index falls outside `table`.
    #[inline(always)]
    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
        debug_assert!(codes.len() <= self.num_codebooks, "more codes than codebooks");
        let mut total_dist = 0.0;
        for (codebook_idx, &code) in codes.iter().enumerate() {
            debug_assert!(
                usize::from(code) < self.codebook_size,
                "code out of codebook range"
            );
            total_dist += table[codebook_idx * self.codebook_size + usize::from(code)];
        }
        total_dist
    }

    /// Decodes `codes` back into an approximate vector by concatenating the selected
    /// codewords (`codes.len() * subvector_dim` values).
    ///
    /// # Panics
    ///
    /// Panics if a selected codeword lies outside the stored codebooks.
    pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.dimension);
        for (codebook_idx, &code) in codes.iter().enumerate() {
            result.extend_from_slice(self.get_codeword(codebook_idx, usize::from(code)));
        }
        result
    }

    /// Number of subspaces, i.e. codes produced per vector.
    pub fn num_codebooks(&self) -> usize {
        self.num_codebooks
    }

    /// Number of values in each subspace.
    pub fn subvector_dim(&self) -> usize {
        self.subvector_dim
    }

    /// Codewords per codebook (the fitted count after a successful [`Self::fit`]).
    pub fn codebook_size(&self) -> usize {
        self.codebook_size
    }
}
/// Borrows row `idx` of a flat, row-major buffer of vectors that are each
/// `dimension` values long.
///
/// # Panics
///
/// Panics if row `idx` does not lie entirely within `vectors`.
#[inline]
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    let row_start = idx * dimension;
    &vectors[row_start..][..dimension]
}