use crate::simd;
use crate::RetrieveError;
/// Product quantizer for residual vectors.
///
/// The `dimension`-wide space is split into `num_codebooks` contiguous subspaces of
/// `subvector_dim` values each. Each subspace gets its own codebook of `codebook_size`
/// codewords, and a vector is encoded as one `u8` code per subspace.
///
/// NOTE(review): despite the name, [`AnisotropicQuantizer::fit_residuals`] trains each
/// codebook with plain k-means; no anisotropic (score-aware) loss weighting is applied here.
#[derive(Debug)]
pub struct AnisotropicQuantizer {
    /// Full vector dimensionality.
    dimension: usize,
    /// Number of subspaces / codebooks; always divides `dimension`.
    num_codebooks: usize,
    /// Codewords per codebook. `new` sets this to the requested size (at most 256).
    /// `fit_residuals` then replaces it with the number of centroids k-means produced.
    codebook_size: usize,
    /// Width of each subspace: `dimension / num_codebooks`.
    subvector_dim: usize,
    /// Base seed; codebook `m` is trained with `seed.wrapping_add(m)`.
    seed: u64,
    /// All codewords, flattened as `[codebook][code][subvector_dim]`.
    /// Empty until `fit_residuals` succeeds.
    codebooks: Vec<f32>,
}

impl AnisotropicQuantizer {
    /// Largest codebook size whose codes still fit in the `u8`s returned by [`Self::quantize`].
    const MAX_CODEBOOK_SIZE: usize = u8::MAX as usize + 1;

    /// Creates an untrained quantizer; call [`Self::fit_residuals`] before encoding.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::InvalidParameter`] in these cases:
    /// - any of `dimension`, `num_codebooks` or `codebook_size` is zero;
    /// - `dimension` is not divisible by `num_codebooks`;
    /// - `codebook_size` exceeds 256. Codes are stored as `u8`, so larger codebooks would
    ///   silently alias.
    pub fn new(
        dimension: usize,
        num_codebooks: usize,
        codebook_size: usize,
        seed: u64,
    ) -> Result<Self, RetrieveError> {
        if dimension == 0 || num_codebooks == 0 || codebook_size == 0 {
            return Err(RetrieveError::InvalidParameter(
                "all parameters must be greater than 0".into(),
            ));
        }
        if dimension % num_codebooks != 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be divisible by num_codebooks".into(),
            ));
        }
        if codebook_size > Self::MAX_CODEBOOK_SIZE {
            return Err(RetrieveError::InvalidParameter(
                "codebook_size must be at most 256 so codes fit in a u8".into(),
            ));
        }
        Ok(Self {
            dimension,
            num_codebooks,
            codebook_size,
            subvector_dim: dimension / num_codebooks,
            seed,
            codebooks: Vec::new(),
        })
    }

    /// Returns codeword `code` of codebook `codebook_idx` as a `subvector_dim`-long slice.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer has not been fitted, or if the indices are out of range.
    #[inline]
    fn get_codeword(&self, codebook_idx: usize, code: usize) -> &[f32] {
        let offset = (codebook_idx * self.codebook_size + code) * self.subvector_dim;
        &self.codebooks[offset..offset + self.subvector_dim]
    }

    /// Trains one k-means codebook per subspace from `num_vectors` row-major residuals.
    ///
    /// On success this replaces the stored codebooks and sets `codebook_size` to the number
    /// of centroids k-means produced. On error, `self` is left unchanged.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::DimensionMismatch`] if `residuals.len() != num_vectors * dimension`.
    /// - [`RetrieveError::InvalidParameter`] if k-means yields an out-of-range number of
    ///   centroids, or different counts for different codebooks.
    /// - Any error from constructing or fitting the underlying k-means is propagated.
    pub fn fit_residuals(
        &mut self,
        residuals: &[f32],
        num_vectors: usize,
    ) -> Result<(), RetrieveError> {
        // `checked_mul` avoids a wrapped product accidentally matching the length.
        if num_vectors.checked_mul(self.dimension) != Some(residuals.len()) {
            return Err(RetrieveError::DimensionMismatch {
                // `checked_div` avoids a divide-by-zero panic when `num_vectors == 0`.
                query_dim: residuals
                    .len()
                    .checked_div(num_vectors)
                    .unwrap_or(residuals.len()),
                doc_dim: self.dimension,
            });
        }

        let mut flat_codebooks =
            Vec::with_capacity(self.num_codebooks * self.codebook_size * self.subvector_dim);
        // Centroid count of codebook 0. Every other codebook must match it so that the
        // fixed-stride indexing in `get_codeword` stays valid.
        let mut fitted_size: Option<usize> = None;
        for m in 0..self.num_codebooks {
            let start_dim = m * self.subvector_dim;
            let end_dim = start_dim + self.subvector_dim;

            // Gather the m-th subvector of every residual into one contiguous buffer.
            let mut subvectors: Vec<f32> = Vec::with_capacity(num_vectors * self.subvector_dim);
            for vector in residuals.chunks_exact(self.dimension) {
                subvectors.extend_from_slice(&vector[start_dim..end_dim]);
            }

            let mut kmeans =
                crate::scann::partitioning::KMeans::new(self.subvector_dim, self.codebook_size)?
                    .with_seed(self.seed.wrapping_add(m as u64));
            kmeans.fit(&subvectors, num_vectors)?;
            let centroids = kmeans.centroids();

            let count = centroids.len();
            if count == 0 || count > Self::MAX_CODEBOOK_SIZE {
                return Err(RetrieveError::InvalidParameter(
                    format!(
                        "k-means produced {count} centroids for codebook {m}; expected between 1 and 256"
                    )
                    .into(),
                ));
            }
            let expected = *fitted_size.get_or_insert(count);
            if count != expected {
                return Err(RetrieveError::InvalidParameter(
                    format!(
                        "k-means produced {count} centroids for codebook {m}, but {expected} for codebook 0"
                    )
                    .into(),
                ));
            }
            for codeword in centroids {
                flat_codebooks.extend_from_slice(codeword);
            }
        }

        // `num_codebooks >= 1` (enforced by `new`), so this is always `Some` here.
        if let Some(size) = fitted_size {
            self.codebook_size = size;
        }
        self.codebooks = flat_codebooks;
        Ok(())
    }

    /// Encodes `residual` as one code per codebook.
    ///
    /// Each code is the index of the nearest codeword under squared Euclidean distance. On
    /// ties, the lowest index wins.
    ///
    /// Only the first `dimension` values of `residual` are read.
    ///
    /// # Panics
    ///
    /// Panics if `residual` is shorter than `dimension`, or if the quantizer has not been fitted.
    pub fn quantize(&self, residual: &[f32]) -> Vec<u8> {
        residual[..self.dimension]
            .chunks_exact(self.subvector_dim)
            .enumerate()
            .map(|(m, sub)| {
                let mut best_idx = 0;
                let mut min_dist = f32::MAX;
                for k in 0..self.codebook_size {
                    let dist = squared_euclidean(sub, self.get_codeword(m, k));
                    // Strict `<` keeps the lowest index on ties.
                    if dist < min_dist {
                        min_dist = dist;
                        best_idx = k;
                    }
                }
                debug_assert!(best_idx <= u8::MAX as usize);
                // Lossless: `codebook_size <= 256` is enforced by `new` and `fit_residuals`.
                best_idx as u8
            })
            .collect()
    }

    /// Builds an inner-product lookup table for `query`.
    ///
    /// `lut[m][k]` is the dot product of the m-th query subvector with codeword `k` of
    /// codebook `m`. Presumably, callers sum `lut[m][codes[m]]` over `m` to score an encoded
    /// residual (verify against callers).
    ///
    /// # Panics
    ///
    /// Panics if `query` is shorter than `dimension`, or if the quantizer has not been fitted.
    pub fn build_lut(&self, query: &[f32]) -> Vec<Vec<f32>> {
        (0..self.num_codebooks)
            .map(|m| {
                let start_dim = m * self.subvector_dim;
                let query_sub = &query[start_dim..start_dim + self.subvector_dim];
                (0..self.codebook_size)
                    .map(|k| simd::dot(query_sub, self.get_codeword(m, k)))
                    .collect()
            })
            .collect()
    }
}
fn squared_euclidean(a: &[f32], b: &[f32]) -> f32 {
crate::simd::l2_distance_squared(a, b)
}