#[cfg(not(feature = "scann"))]
compile_error!("Online PQ requires 'scann' feature for k-means clustering");
use crate::RetrieveError;

/// Product quantizer that can be updated incrementally as new vectors arrive.
/// Each vector is split into `num_codebooks` contiguous subvectors, and every
/// subvector is encoded as the index of its nearest centroid in that codebook.
pub struct OnlineProductQuantizer {
    dimension: usize,
    num_codebooks: usize,
    codebook_size: usize,
    subvector_dim: usize,
    codebooks: Vec<Vec<Vec<f32>>>,
    learning_rate: f32,
    forgetting_rate: f32,
    /// Number of assignments observed per centroid in each codebook.
    codebook_counts: Vec<Vec<usize>>,
}

impl OnlineProductQuantizer {
    /// Creates an untrained quantizer; call `initialize` before quantizing or updating.
    pub fn new(
        dimension: usize,
        num_codebooks: usize,
        codebook_size: usize,
        learning_rate: f32,
        forgetting_rate: f32,
    ) -> Result<Self, RetrieveError> {
        if dimension == 0 || num_codebooks == 0 || codebook_size == 0 {
            return Err(RetrieveError::Other(
                "All parameters must be greater than 0".to_string(),
            ));
        }
        // Codes are stored as u8, so each codebook can hold at most 256 centroids.
        if codebook_size > 256 {
            return Err(RetrieveError::Other(
                "Codebook size must be at most 256 to fit in a u8 code".to_string(),
            ));
        }
        if dimension % num_codebooks != 0 {
            return Err(RetrieveError::Other(
                "Dimension must be divisible by num_codebooks".to_string(),
            ));
        }
        if learning_rate <= 0.0 || learning_rate > 1.0 {
            return Err(RetrieveError::Other(
                "Learning rate must be in (0, 1]".to_string(),
            ));
        }
        if forgetting_rate < 0.0 || forgetting_rate > 1.0 {
            return Err(RetrieveError::Other(
                "Forgetting rate must be in [0, 1]".to_string(),
            ));
        }
        let subvector_dim = dimension / num_codebooks;
        let codebook_counts = vec![vec![0; codebook_size]; num_codebooks];
        Ok(Self {
            dimension,
            num_codebooks,
            codebook_size,
            subvector_dim,
            codebooks: Vec::new(),
            learning_rate,
            forgetting_rate,
            codebook_counts,
        })
    }

    /// Fits the initial codebooks with k-means over a flat, row-major buffer of
    /// `num_vectors * dimension` floats, replacing any existing codebooks.
    pub fn initialize(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
        if vectors.len() != num_vectors * self.dimension {
            return Err(RetrieveError::Other(format!(
                "Expected {} floats, got {}",
                num_vectors * self.dimension,
                vectors.len()
            )));
        }
        self.codebooks = Vec::with_capacity(self.num_codebooks);
        for codebook_idx in 0..self.num_codebooks {
            let start_dim = codebook_idx * self.subvector_dim;
            let end_dim = (codebook_idx + 1) * self.subvector_dim;
            // Gather this codebook's subvector slice from every training vector.
            let mut flat = Vec::with_capacity(num_vectors * self.subvector_dim);
            for i in 0..num_vectors {
                let vec = get_vector(vectors, self.dimension, i);
                flat.extend_from_slice(&vec[start_dim..end_dim]);
            }
            let mut kmeans = crate::scann::partitioning::KMeans::new(
                self.subvector_dim,
                self.codebook_size,
            )?;
            kmeans.fit(&flat, num_vectors)?;
            self.codebooks.push(kmeans.centroids().to_vec());
        }
        // Re-count assignments against the freshly trained codebooks.
        for counts in &mut self.codebook_counts {
            counts.fill(0);
        }
        for i in 0..num_vectors {
            let vec = get_vector(vectors, self.dimension, i);
            let codes = self.quantize_internal(vec);
            for (codebook_idx, &code) in codes.iter().enumerate() {
                self.codebook_counts[codebook_idx][code as usize] += 1;
            }
        }
        Ok(())
    }

    /// Assigns `vector` to its nearest centroids, nudges those centroids toward
    /// it by `learning_rate`, and returns the resulting PQ codes.
    pub fn update(&mut self, vector: &[f32]) -> Result<Vec<u8>, RetrieveError> {
        if vector.len() != self.dimension {
            return Err(RetrieveError::Other(format!(
                "Vector dimension {} != {}",
                vector.len(),
                self.dimension
            )));
        }
        let codes = self.quantize_internal(vector);
        for codebook_idx in 0..self.num_codebooks {
            let start_dim = codebook_idx * self.subvector_dim;
            let end_dim = (codebook_idx + 1) * self.subvector_dim;
            let subvector = &vector[start_dim..end_dim];
            let code = codes[codebook_idx] as usize;
            // Exponential moving average of the winning centroid toward the new sample.
            let centroid = &mut self.codebooks[codebook_idx][code];
            for (i, &val) in subvector.iter().enumerate() {
                centroid[i] = (1.0 - self.learning_rate) * centroid[i] + self.learning_rate * val;
            }
            self.codebook_counts[codebook_idx][code] += 1;
            if self.forgetting_rate > 0.0 {
                // Slowly shrink the centroids that were not selected so stale
                // codewords lose influence over time.
                for (other_code, other_centroid) in
                    self.codebooks[codebook_idx].iter_mut().enumerate()
                {
                    if other_code != code {
                        for val in other_centroid.iter_mut() {
                            *val *= 1.0 - self.forgetting_rate * 0.01;
                        }
                    }
                }
            }
        }
        Ok(codes)
    }

    /// Applies `update` to each row of a flat, row-major batch.
    pub fn update_batch(
        &mut self,
        vectors: &[f32],
        num_vectors: usize,
    ) -> Result<Vec<Vec<u8>>, RetrieveError> {
        let mut all_codes = Vec::with_capacity(num_vectors);
        for i in 0..num_vectors {
            let vec = get_vector(vectors, self.dimension, i);
            all_codes.push(self.update(vec)?);
        }
        Ok(all_codes)
    }

    /// Encodes `vector` without updating any codebook state.
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        self.quantize_internal(vector)
    }

    fn quantize_internal(&self, vector: &[f32]) -> Vec<u8> {
        let mut codes = Vec::with_capacity(self.num_codebooks);
        for codebook_idx in 0..self.num_codebooks {
            let start_dim = codebook_idx * self.subvector_dim;
            let end_dim = (codebook_idx + 1) * self.subvector_dim;
            let subvector = &vector[start_dim..end_dim];
            // Pick the nearest codeword; `new` guarantees codebook_size <= 256,
            // so the centroid index always fits in a u8.
            let mut best_code = 0u8;
            let mut best_dist = f32::INFINITY;
            for (code, codeword) in self.codebooks[codebook_idx].iter().enumerate() {
                let dist = cosine_distance(subvector, codeword);
                if dist < best_dist {
                    best_dist = dist;
                    best_code = code as u8;
                }
            }
            codes.push(best_code);
        }
        codes
    }

    /// Approximates the distance between a full-precision query and an encoded
    /// vector by summing per-subspace distances to the selected codewords.
    pub fn approximate_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        let mut total_dist = 0.0;
        for (codebook_idx, &code) in codes.iter().enumerate() {
            let start_dim = codebook_idx * self.subvector_dim;
            let end_dim = (codebook_idx + 1) * self.subvector_dim;
            let query_subvector = &query[start_dim..end_dim];
            let codeword = &self.codebooks[codebook_idx][code as usize];
            total_dist += cosine_distance(query_subvector, codeword);
        }
        total_dist
    }

    /// Returns how many vectors have been assigned to each centroid, per codebook.
    pub fn codebook_counts(&self) -> &[Vec<usize>] {
        &self.codebook_counts
    }

    /// Returns the current centroids of every codebook.
    pub fn codebooks(&self) -> &[Vec<Vec<f32>>] {
        &self.codebooks
    }
}

fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    crate::distance::cosine_distance_normalized(a, b)
}

/// Returns the `idx`-th vector from a flat, row-major buffer.
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    let start = idx * dimension;
    let end = start + dimension;
    &vectors[start..end]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_online_pq_basic() {
        let mut opq = OnlineProductQuantizer::new(128, 8, 256, 0.1, 0.01).unwrap();
        let num_init = 1000;
        let mut init_vectors = Vec::new();
        for i in 0..num_init {
            let mut vec = vec![0.0; 128];
            for j in 0..128 {
                vec[j] = ((i * 128 + j) as f32) * 0.01;
            }
            init_vectors.extend_from_slice(&vec);
        }
        opq.initialize(&init_vectors, num_init).unwrap();
        let new_vec = vec![1.0; 128];
        let codes = opq.update(&new_vec).unwrap();
        assert_eq!(codes.len(), 8);
    }
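
    // A minimal usage sketch (not part of the original suite), assuming the
    // crate's KMeans and cosine distance behave as used above: a query's own
    // codes can never be farther from the query than another vector's codes
    // under `approximate_distance`, because quantization minimizes that same
    // per-subspace distance.
    #[test]
    fn test_online_pq_approximate_distance_sketch() {
        let mut opq = OnlineProductQuantizer::new(16, 4, 8, 0.1, 0.0).unwrap();
        let num_init = 64;
        let mut init_vectors = Vec::with_capacity(num_init * 16);
        for i in 0..num_init {
            for j in 0..16 {
                init_vectors.push(((i + j) as f32).sin());
            }
        }
        opq.initialize(&init_vectors, num_init).unwrap();

        let query: Vec<f32> = (0..16).map(|j| (j as f32).sin()).collect();
        let other: Vec<f32> = (0..16).map(|j| ((j + 7) as f32).cos()).collect();
        let query_codes = opq.quantize(&query);
        let other_codes = opq.quantize(&other);
        assert_eq!(query_codes.len(), 4);
        assert!(query_codes.iter().all(|&c| (c as usize) < 8));

        let d_query = opq.approximate_distance(&query, &query_codes);
        let d_other = opq.approximate_distance(&query, &other_codes);
        assert!(d_query <= d_other + 1e-6);
    }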
}