use super::pq::ProductQuantizer;
use crate::RetrieveError;
use serde::{Deserialize, Serialize};
/// Product quantizer with a learned orthogonal rotation applied to vectors
/// before quantization (OPQ), reducing quantization error versus plain PQ.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizedProductQuantizer {
    // Dimensionality of input vectors (and side length of `rotation`).
    dimension: usize,
    // Number of sub-quantizers; `quantize` emits one code byte per codebook.
    num_codebooks: usize,
    // Centroids per codebook; codes are `u8`, so presumably <= 256 —
    // NOTE(review): confirm `ProductQuantizer::new` enforces this.
    codebook_size: usize,
    // Row-major `dimension x dimension` rotation matrix; `None` until `fit` runs.
    rotation: Option<Vec<f32>>,
    // Underlying product quantizer, trained on rotated data.
    pq: ProductQuantizer,
}
impl OptimizedProductQuantizer {
    /// Creates an untrained OPQ quantizer over `dimension`-dimensional vectors
    /// split into `num_codebooks` sub-spaces with `codebook_size` centroids each.
    ///
    /// # Errors
    /// Propagates any parameter-validation error from `ProductQuantizer::new`.
    pub fn new(
        dimension: usize,
        num_codebooks: usize,
        codebook_size: usize,
    ) -> Result<Self, RetrieveError> {
        let pq = ProductQuantizer::new(dimension, num_codebooks, codebook_size)?;
        Ok(Self {
            dimension,
            num_codebooks,
            codebook_size,
            rotation: None,
            pq,
        })
    }

    /// Trains the rotation and the PQ codebooks by alternating optimization:
    /// rotate the data, fit codebooks, then solve for the rotation that best
    /// aligns PQ reconstructions with the original data.
    ///
    /// `data` holds `num_vectors` vectors of `self.dimension` floats each,
    /// concatenated row-major.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when there are fewer training vectors than
    /// centroids per codebook, or when `data` is too short for
    /// `num_vectors * dimension` floats; propagates errors from the PQ fit.
    pub fn fit(
        &mut self,
        data: &[f32],
        num_vectors: usize,
        iterations: usize,
    ) -> Result<(), RetrieveError> {
        if num_vectors < self.codebook_size {
            return Err(RetrieveError::InvalidParameter(format!(
                "need at least {} training vectors, got {}",
                self.codebook_size, num_vectors
            )));
        }
        // Validate up front instead of panicking on an out-of-range slice
        // inside the training loop.
        if data.len() < num_vectors * self.dimension {
            return Err(RetrieveError::InvalidParameter(format!(
                "data length {} is too short for {} vectors of dimension {}",
                data.len(),
                num_vectors,
                self.dimension
            )));
        }
        let mut rotation = identity_matrix(self.dimension);
        let mut rotated_data = data.to_vec();
        for iter in 0..iterations {
            // First pass uses the identity rotation, for which `rotated_data`
            // already equals `data`.
            if iter > 0 {
                apply_rotation_batch(
                    data,
                    num_vectors,
                    self.dimension,
                    &rotation,
                    &mut rotated_data,
                );
            }
            self.pq.fit(&rotated_data, num_vectors)?;
            rotation =
                compute_optimal_rotation(data, num_vectors, self.dimension, &self.pq, &rotation);
        }
        // The loop updates the rotation *after* the last codebook fit, so the
        // codebooks would otherwise be trained against the previous rotation
        // while `quantize` applies the final one. Refit once under the final
        // rotation to keep the (rotation, codebooks) pair consistent.
        apply_rotation_batch(
            data,
            num_vectors,
            self.dimension,
            &rotation,
            &mut rotated_data,
        );
        self.pq.fit(&rotated_data, num_vectors)?;
        self.rotation = Some(rotation);
        Ok(())
    }

    /// Rotates `vector` (when trained) and quantizes it to one code byte per
    /// codebook.
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        let rotated = self.rotate_vector(vector);
        self.pq.quantize(&rotated)
    }

    /// Builds an asymmetric-distance (ADC) lookup table for `query` in the
    /// rotated space; pair with `distance_with_table`.
    ///
    /// # Errors
    /// Propagates any error from the underlying `compute_adc_table`.
    pub fn approximate_distance_table(&self, query: &[f32]) -> Result<Vec<f32>, RetrieveError> {
        let rotated = self.rotate_vector(query);
        self.pq.compute_adc_table(&rotated)
    }

    /// Approximate distance between the query behind `table` and the vector
    /// encoded by `codes`.
    #[inline(always)]
    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
        self.pq.distance_with_table(table, codes)
    }

    /// Applies the learned rotation, or passes the vector through unchanged
    /// when no rotation has been trained yet.
    fn rotate_vector(&self, vector: &[f32]) -> Vec<f32> {
        match &self.rotation {
            Some(r) => matrix_vector_multiply(r, vector, self.dimension),
            None => vector.to_vec(),
        }
    }

    /// Learned row-major rotation matrix, or `None` before training.
    pub fn rotation_matrix(&self) -> Option<&[f32]> {
        self.rotation.as_deref()
    }
}
/// Builds the `d x d` identity matrix in row-major order.
fn identity_matrix(d: usize) -> Vec<f32> {
    // Row-major diagonal entries sit at indices i * (d + 1); the largest,
    // (d - 1) * (d + 1) = d*d - 1, is the only multiples of (d + 1) below d*d.
    (0..d * d)
        .map(|idx| if idx % (d + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
/// Dense matrix-vector product: returns `m * v`, where `m` is a `d x d`
/// matrix stored row-major and `v` has at least `d` elements.
fn matrix_vector_multiply(m: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    debug_assert!(m.len() >= d * d);
    debug_assert!(v.len() >= d);
    // `chunks_exact(0)` panics, and a 0 x 0 product is just the empty vector.
    if d == 0 {
        return Vec::new();
    }
    // `chunks_exact` yields whole rows, so the inner dot product runs without
    // per-element bounds checks and can auto-vectorize. Accumulation order
    // matches the original index loop, so results are bit-identical.
    m.chunks_exact(d)
        .take(d)
        .map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum())
        .collect()
}
/// Applies `rotation` (a row-major `dimension x dimension` matrix) to the
/// first `num_vectors` vectors packed in `data`, writing each rotated vector
/// into the corresponding position of `output`.
///
/// Results are written directly into `output`, avoiding the temporary
/// allocation per vector that rotating through `matrix_vector_multiply`
/// would incur.
///
/// # Panics
/// Panics if `data` or `output` holds fewer than `num_vectors * dimension`
/// floats.
fn apply_rotation_batch(
    data: &[f32],
    num_vectors: usize,
    dimension: usize,
    rotation: &[f32],
    output: &mut [f32],
) {
    // `chunks_exact(0)` panics; zero-dimensional vectors need no work.
    if dimension == 0 {
        return;
    }
    let total = num_vectors * dimension;
    // One hoisted length check instead of a bounds check per element.
    assert!(data.len() >= total && output.len() >= total);
    let sources = data[..total].chunks_exact(dimension);
    let targets = output[..total].chunks_exact_mut(dimension);
    for (src, dst) in sources.zip(targets) {
        for (out, row) in dst.iter_mut().zip(rotation.chunks_exact(dimension)) {
            // Dot product of one rotation row with the source vector.
            *out = row.iter().zip(src).map(|(a, b)| a * b).sum();
        }
    }
}
/// Solves for the orthogonal rotation that best aligns PQ reconstructions
/// with the original vectors (an orthogonal Procrustes problem), returning
/// the new rotation as a row-major `dimension x dimension` matrix.
///
/// Falls back to the identity matrix when there is no data or the SVD does
/// not yield its factors.
fn compute_optimal_rotation(
    data: &[f32],
    num_vectors: usize,
    dimension: usize,
    pq: &ProductQuantizer,
    current_rotation: &[f32],
) -> Vec<f32> {
    // Degenerate geometry: nothing to rotate.
    if dimension == 0 {
        return Vec::new();
    }
    // Cap the sample to bound the O(sample * dimension^2) accumulation cost.
    let sample_size = num_vectors.min(5000);
    if sample_size == 0 {
        return identity_matrix(dimension);
    }
    // Accumulate the correlation matrix M = sum_k recon_k * orig_k^T, where
    // recon_k is the PQ reconstruction of vector k under the current rotation.
    let mut m = vec![0.0f32; dimension * dimension];
    for k in 0..sample_size {
        let src_start = k * dimension;
        let original = &data[src_start..src_start + dimension];
        // Quantize in the rotated space the codebooks were trained in, then
        // decode back to an approximate (rotated-space) vector.
        let rotated = matrix_vector_multiply(current_rotation, original, dimension);
        let codes = pq.quantize(&rotated);
        let reconstruction = reconstruct_vector(pq, &codes, dimension);
        // Outer-product accumulation: M[i][j] += recon[i] * orig[j].
        for i in 0..dimension {
            for j in 0..dimension {
                m[i * dimension + j] += reconstruction[i] * original[j];
            }
        }
    }
    // Orthogonal Procrustes: with the SVD M = U * S * V^T, the orthogonal
    // matrix closest to M (maximizing trace alignment) is R = U * V^T.
    let na_m = nalgebra::DMatrix::from_row_slice(dimension, dimension, &m);
    let svd = nalgebra::linalg::SVD::new(na_m, true, true);
    // `u`/`v_t` are Options; fall back to identity rather than panic if the
    // factors were not computed.
    let u = match svd.u {
        Some(ref u) => u,
        None => return identity_matrix(dimension),
    };
    let vt = match svd.v_t {
        Some(ref vt) => vt,
        None => return identity_matrix(dimension),
    };
    let rotation = u * vt;
    // Copy the nalgebra matrix back into the flat row-major representation
    // used throughout this module.
    let mut result = vec![0.0f32; dimension * dimension];
    for i in 0..dimension {
        for j in 0..dimension {
            result[i * dimension + j] = rotation[(i, j)];
        }
    }
    result
}
/// Decodes PQ `codes` back into an approximate vector via the codebooks.
///
/// `_dimension` is unused: `pq.reconstruct` determines the output length
/// itself. The parameter is kept so call sites document the expected size.
fn reconstruct_vector(pq: &ProductQuantizer, codes: &[u8], _dimension: usize) -> Vec<f32> {
    pq.reconstruct(codes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no trained rotation, `rotate_vector` must be a pass-through.
    #[test]
    fn test_identity_rotation() {
        let opq = OptimizedProductQuantizer::new(8, 2, 256).unwrap();
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let output = opq.rotate_vector(&input);
        assert_eq!(input, output);
    }

    /// Identity leaves a vector untouched; a quarter turn maps (1, 0) to (0, 1).
    #[test]
    fn test_matrix_vector_multiply() {
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let v = vec![3.0, 4.0];
        assert_eq!(matrix_vector_multiply(&identity, &v, 2), vec![3.0, 4.0]);

        let quarter_turn = vec![0.0, -1.0, 1.0, 0.0];
        let e1 = vec![1.0, 0.0];
        let turned = matrix_vector_multiply(&quarter_turn, &e1, 2);
        assert!((turned[0] - 0.0).abs() < 1e-6);
        assert!((turned[1] - 1.0).abs() < 1e-6);
    }

    /// Training must produce a rotation whose rows are (near-)unit-norm.
    #[test]
    fn test_opq_training() {
        let dimension = 16;
        let num_codebooks = 4;
        let codebook_size = 8;
        let num_vectors = 100;
        // Deterministic pseudo-random training data in [0, 1).
        let data: Vec<f32> = (0..num_vectors)
            .flat_map(|i| (0..dimension).map(move |j| ((i * 7 + j * 3) % 100) as f32 / 100.0))
            .collect();

        let mut opq =
            OptimizedProductQuantizer::new(dimension, num_codebooks, codebook_size).unwrap();
        assert!(opq.fit(&data, num_vectors, 2).is_ok());
        assert!(opq.rotation.is_some());

        // An orthonormal matrix has rows of norm 1; allow slack for the
        // approximate alternating optimization.
        let r = opq.rotation.as_ref().unwrap();
        for (i, row) in r.chunks(dimension).enumerate() {
            let norm_sq: f32 = row.iter().map(|x| x * x).sum();
            assert!(
                (norm_sq - 1.0).abs() < 0.1,
                "Row {} norm = {}",
                i,
                norm_sq.sqrt()
            );
        }
    }

    /// After training, codes have one byte per codebook, each in range.
    #[test]
    fn test_opq_quantize() {
        let dimension = 8;
        let num_codebooks = 2;
        let codebook_size = 4;
        let num_vectors = 50;
        let data: Vec<f32> = (0..num_vectors)
            .flat_map(|i| (0..dimension).map(move |j| ((i + j) % 10) as f32 / 10.0))
            .collect();

        let mut opq =
            OptimizedProductQuantizer::new(dimension, num_codebooks, codebook_size).unwrap();
        opq.fit(&data, num_vectors, 2).unwrap();

        let codes = opq.quantize(&vec![0.5; dimension]);
        assert_eq!(codes.len(), num_codebooks);
        assert!(codes.iter().all(|&c| (c as usize) < codebook_size));
    }
}