use crate::simd;
use crate::RetrieveError;
/// Residual quantizer that splits each `dimension`-long vector into
/// `num_codebooks` contiguous, equal-width subvectors and encodes each subvector
/// as the index of a codeword in that subspace's own codebook.
///
/// NOTE(review): despite the name, the codeword assignment in `quantize` uses
/// plain squared Euclidean distance; no anisotropic weighting is visible here.
#[derive(Debug)]
pub struct AnisotropicQuantizer {
    /// Length of every input vector (residuals, queries).
    dimension: usize,
    /// Number of subspaces / codebooks; `new` checks that `dimension` is a
    /// multiple of it.
    num_codebooks: usize,
    /// Requested number of codewords per codebook (the `k` handed to k-means).
    codebook_size: usize,
    /// Trained codebooks, indexed `codebooks[m][k]` = codeword `k` of subspace `m`,
    /// each codeword being `dimension / num_codebooks` floats. Empty until fitted.
    pub(crate) codebooks: Vec<Vec<Vec<f32>>>,
}
impl AnisotropicQuantizer {
    /// Creates an unfitted quantizer.
    ///
    /// Vectors of length `dimension` are split into `num_codebooks` contiguous
    /// subvectors of equal width; each subvector is later encoded against its own
    /// codebook of `codebook_size` codewords. The codebooks stay empty until
    /// [`fit_residuals`](Self::fit_residuals) succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::Other`] if any parameter is zero, if `dimension`
    /// is not divisible by `num_codebooks`, or if `codebook_size` exceeds 256
    /// (codes produced by [`quantize`](Self::quantize) are `u8`, so larger
    /// codebooks would have indices that cannot be represented).
    pub fn new(
        dimension: usize,
        num_codebooks: usize,
        codebook_size: usize,
    ) -> Result<Self, RetrieveError> {
        // Largest codebook whose indices 0..size all fit in a `u8` code.
        const MAX_CODEBOOK_SIZE: usize = u8::MAX as usize + 1;

        if dimension == 0 || num_codebooks == 0 || codebook_size == 0 {
            return Err(RetrieveError::Other(
                "All parameters must be greater than 0".to_string(),
            ));
        }
        if dimension % num_codebooks != 0 {
            return Err(RetrieveError::Other(
                "Dimension must be divisible by num_codebooks".to_string(),
            ));
        }
        if codebook_size > MAX_CODEBOOK_SIZE {
            return Err(RetrieveError::Other(
                "codebook_size must be at most 256 so codes fit in u8".to_string(),
            ));
        }
        Ok(Self {
            dimension,
            num_codebooks,
            codebook_size,
            codebooks: Vec::new(),
        })
    }

    /// Trains one codebook per subspace by running k-means on the residuals.
    ///
    /// `residuals` holds `num_vectors` vectors of length `dimension`, stored
    /// back to back (vector `i` occupies `i * dimension .. (i + 1) * dimension`).
    /// For each subspace `m`, the `m`-th subvector of every residual is gathered
    /// and clustered into `codebook_size` centroids, which become codebook `m`.
    ///
    /// On any error the previously stored codebooks are left untouched.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::Other`] if `num_vectors` is zero.
    /// - [`RetrieveError::DimensionMismatch`] if `residuals.len()` is not
    ///   `num_vectors * dimension`.
    /// - Any error returned by k-means construction or fitting is propagated.
    pub fn fit_residuals(
        &mut self,
        residuals: &[f32],
        num_vectors: usize,
    ) -> Result<(), RetrieveError> {
        // Guards the division in the mismatch error below.
        if num_vectors == 0 {
            return Err(RetrieveError::Other(
                "num_vectors must be greater than 0".to_string(),
            ));
        }
        // `checked_mul` so an absurd `num_vectors` is reported, not wrapped.
        if num_vectors.checked_mul(self.dimension) != Some(residuals.len()) {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: self.dimension,
                doc_dim: residuals.len() / num_vectors,
            });
        }

        let subvector_dim = self.dimension / self.num_codebooks;
        // Train into a local so a mid-way failure cannot leave a partial set.
        let mut codebooks = Vec::with_capacity(self.num_codebooks);
        for m in 0..self.num_codebooks {
            let start_dim = m * subvector_dim;
            let mut subvectors: Vec<f32> = Vec::with_capacity(num_vectors * subvector_dim);
            for vector in residuals.chunks_exact(self.dimension) {
                subvectors.extend_from_slice(&vector[start_dim..start_dim + subvector_dim]);
            }
            let mut kmeans =
                crate::scann::partitioning::KMeans::new(subvector_dim, self.codebook_size)?;
            kmeans.fit(&subvectors, num_vectors)?;
            let codewords: Vec<Vec<f32>> = kmeans.centroids().to_vec();
            codebooks.push(codewords);
        }
        self.codebooks = codebooks;
        Ok(())
    }

    /// Encodes one residual vector as one codeword index per codebook.
    ///
    /// For each subspace `m`, the code is the index of the codeword with the
    /// smallest squared Euclidean distance to the `m`-th subvector of `residual`.
    /// Ties keep the lowest index; if no distance is below `f32::MAX` (e.g. all
    /// NaN) the code is 0. Elements of `residual` beyond `dimension` are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer has not been fitted, or if `residual` is shorter
    /// than `dimension`.
    pub fn quantize(&self, residual: &[f32]) -> Vec<u8> {
        assert!(
            self.codebooks.len() >= self.num_codebooks,
            "quantizer is not fitted: call fit_residuals first"
        );
        assert!(
            residual.len() >= self.dimension,
            "residual has {} elements, expected at least {}",
            residual.len(),
            self.dimension
        );
        let subvector_dim = self.dimension / self.num_codebooks;
        residual[..self.dimension]
            .chunks_exact(subvector_dim)
            .zip(&self.codebooks)
            .map(|(sub, codebook)| {
                let mut best_idx = 0;
                let mut min_dist = f32::MAX;
                for (k, codeword) in codebook.iter().enumerate() {
                    let dist = squared_euclidean(sub, codeword);
                    if dist < min_dist {
                        min_dist = dist;
                        best_idx = k;
                    }
                }
                // `new` caps codebook_size at 256, so this only fails if the
                // crate-visible `codebooks` field was filled by other means.
                u8::try_from(best_idx)
                    .expect("codeword index exceeds u8 range; codebook_size is capped at 256")
            })
            .collect()
    }

    /// Builds a per-subspace lookup table of query/codeword scores.
    ///
    /// `lut[m][k]` is `simd::dot` of the `m`-th subvector of `query` with
    /// codeword `k` of codebook `m`. Elements of `query` beyond `dimension` are
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer has not been fitted, or if `query` is shorter
    /// than `dimension`.
    pub fn build_lut(&self, query: &[f32]) -> Vec<Vec<f32>> {
        assert!(
            self.codebooks.len() >= self.num_codebooks,
            "quantizer is not fitted: call fit_residuals first"
        );
        assert!(
            query.len() >= self.dimension,
            "query has {} elements, expected at least {}",
            query.len(),
            self.dimension
        );
        let subvector_dim = self.dimension / self.num_codebooks;
        query[..self.dimension]
            .chunks_exact(subvector_dim)
            .zip(&self.codebooks)
            .map(|(query_sub, codebook)| {
                codebook
                    .iter()
                    .map(|codeword| simd::dot(query_sub, codeword))
                    .collect::<Vec<f32>>()
            })
            .collect()
    }
}
/// Squared L2 distance between `a` and `b`.
///
/// Pairs elements positionally; if the slices differ in length, the extra
/// trailing elements of the longer one are ignored.
fn squared_euclidean(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum()
}