use crate::partitioning::kmeans::KMeans;
use crate::RetrieveError;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
dimension: usize,
num_codebooks: usize,
codebook_size: usize,
subvector_dim: usize,
codebooks: Vec<Vec<Vec<f32>>>, }
impl ProductQuantizer {
pub fn new(
dimension: usize,
num_codebooks: usize,
codebook_size: usize,
) -> Result<Self, RetrieveError> {
if dimension == 0 || num_codebooks == 0 || codebook_size == 0 {
return Err(RetrieveError::Other(
"All parameters must be greater than 0".to_string(),
));
}
if dimension % num_codebooks != 0 {
return Err(RetrieveError::Other(
"Dimension must be divisible by num_codebooks".to_string(),
));
}
Ok(Self {
dimension,
num_codebooks,
codebook_size,
subvector_dim: dimension / num_codebooks,
codebooks: Vec::new(),
})
}
pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
self.codebooks = Vec::new();
for codebook_idx in 0..self.num_codebooks {
let start_dim = codebook_idx * self.subvector_dim;
let end_dim = (codebook_idx + 1) * self.subvector_dim;
let mut subvectors = Vec::new();
for i in 0..num_vectors {
let vec = get_vector(vectors, self.dimension, i);
subvectors.push(vec[start_dim..end_dim].to_vec());
}
let mut kmeans = KMeans::new(self.subvector_dim, self.codebook_size)?;
let mut flat = Vec::with_capacity(num_vectors * self.subvector_dim);
for subvec in &subvectors {
flat.extend_from_slice(subvec);
}
kmeans.fit(&flat, num_vectors)?;
self.codebooks.push(kmeans.centroids().to_vec());
}
Ok(())
}
pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
let mut codes = Vec::with_capacity(self.num_codebooks);
for codebook_idx in 0..self.num_codebooks {
let start_dim = codebook_idx * self.subvector_dim;
let end_dim = (codebook_idx + 1) * self.subvector_dim;
let subvector = &vector[start_dim..end_dim];
let mut best_code = 0u8;
let mut best_dist = f32::INFINITY;
for (code, codeword) in self.codebooks[codebook_idx].iter().enumerate() {
let dist = cosine_distance(subvector, codeword);
if dist < best_dist {
best_dist = dist;
best_code = code.min(255) as u8;
}
}
codes.push(best_code);
}
codes
}
pub fn approximate_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
let mut total_dist = 0.0;
for (codebook_idx, &code) in codes.iter().enumerate() {
let start_dim = codebook_idx * self.subvector_dim;
let end_dim = (codebook_idx + 1) * self.subvector_dim;
let query_subvector = &query[start_dim..end_dim];
let codeword = &self.codebooks[codebook_idx][code as usize];
total_dist += cosine_distance(query_subvector, codeword);
}
total_dist
}
pub fn compute_adc_table(&self, query: &[f32]) -> Result<Vec<f32>, RetrieveError> {
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: self.dimension,
doc_dim: query.len(),
});
}
let mut table = Vec::with_capacity(self.num_codebooks * self.codebook_size);
for codebook_idx in 0..self.num_codebooks {
let start_dim = codebook_idx * self.subvector_dim;
let end_dim = (codebook_idx + 1) * self.subvector_dim;
let query_subvector = &query[start_dim..end_dim];
for codeword in &self.codebooks[codebook_idx] {
let dist = cosine_distance(query_subvector, codeword);
table.push(dist);
}
}
Ok(table)
}
#[inline(always)]
pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
let mut total_dist = 0.0;
for (codebook_idx, &code) in codes.iter().enumerate() {
let idx = codebook_idx * self.codebook_size + code as usize;
total_dist += table[idx];
}
total_dist
}
pub fn codebooks(&self) -> &[Vec<Vec<f32>>] {
&self.codebooks
}
pub fn codebooks_mut(&mut self) -> &mut [Vec<Vec<f32>>] {
&mut self.codebooks
}
}
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
crate::distance::cosine_distance_normalized(a, b)
}
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
let start = idx * dimension;
let end = start + dimension;
&vectors[start..end]
}