use crate::{
pq::{PQConfig, PQIndex},
Vector,
};
use anyhow::{anyhow, Result};
use nalgebra::{DMatrix, DVector, SVD};
/// Configuration for an Optimized Product Quantization (OPQ) index.
#[derive(Debug, Clone)]
pub struct OPQConfig {
    /// Configuration forwarded to the underlying PQ index.
    pub pq_config: PQConfig,
    /// Number of alternating rotation/codebook optimization iterations.
    pub n_iterations: usize,
    /// Whether to subtract the per-dimension mean from the data before rotating.
    pub center_data: bool,
    /// Ridge term added to the diagonal of the correlation matrix before the
    /// SVD; 0.0 disables regularization.
    pub regularization: f32,
}
impl Default for OPQConfig {
fn default() -> Self {
Self {
pq_config: PQConfig::default(),
n_iterations: 10,
center_data: true,
regularization: 0.0,
}
}
}
/// Product quantization index that first learns an orthogonal rotation of the
/// input space to reduce quantization error (OPQ).
pub struct OPQIndex {
    // OPQ hyperparameters supplied at construction time.
    config: OPQConfig,
    // Learned orthogonal rotation; `None` until `train` succeeds.
    rotation_matrix: Option<DMatrix<f32>>,
    // Per-dimension mean of the training data; `Some` only when
    // `config.center_data` is enabled and training has run.
    data_mean: Option<DVector<f32>>,
    // Underlying PQ index that quantizes the rotated vectors.
    pq_index: PQIndex,
    // Set to true once `train` completes; gates encode/decode/search.
    is_trained: bool,
}
impl OPQIndex {
pub fn new(config: OPQConfig) -> Self {
Self {
pq_index: PQIndex::new(config.pq_config.clone()),
config,
rotation_matrix: None,
data_mean: None,
is_trained: false,
}
}
pub fn train(&mut self, vectors: &[Vector]) -> Result<()> {
if vectors.is_empty() {
return Err(anyhow!("Cannot train OPQ with empty data"));
}
let n_samples = vectors.len();
let dimensions = vectors[0].dimensions;
let mut data_matrix = DMatrix::zeros(n_samples, dimensions);
for (i, vector) in vectors.iter().enumerate() {
let vec_f32 = vector.as_f32();
for (j, &val) in vec_f32.iter().enumerate() {
data_matrix[(i, j)] = val;
}
}
if self.config.center_data {
let mean = self.compute_mean(&data_matrix);
self.center_data_matrix(&mut data_matrix, &mean);
self.data_mean = Some(mean);
}
let mut rotation = DMatrix::identity(dimensions, dimensions);
for iteration in 0..self.config.n_iterations {
println!(
"OPQ iteration {}/{}",
iteration + 1,
self.config.n_iterations
);
let rotated_data = self.apply_rotation(&data_matrix, &rotation);
let rotated_vectors = self.matrix_to_vectors(&rotated_data);
self.pq_index.train(&rotated_vectors)?;
rotation = self.optimize_rotation(&data_matrix, &rotated_vectors)?;
let error = self.compute_reconstruction_error(&data_matrix, &rotation)?;
println!("Reconstruction error: {error}");
}
self.rotation_matrix = Some(rotation);
self.is_trained = true;
Ok(())
}
fn compute_mean(&self, data: &DMatrix<f32>) -> DVector<f32> {
let n_samples = data.nrows() as f32;
let mut mean = DVector::zeros(data.ncols());
for i in 0..data.ncols() {
mean[i] = data.column(i).sum() / n_samples;
}
mean
}
fn center_data_matrix(&self, data: &mut DMatrix<f32>, mean: &DVector<f32>) {
for i in 0..data.nrows() {
for j in 0..data.ncols() {
data[(i, j)] -= mean[j];
}
}
}
fn apply_rotation(&self, data: &DMatrix<f32>, rotation: &DMatrix<f32>) -> DMatrix<f32> {
data * rotation.transpose()
}
fn matrix_to_vectors(&self, matrix: &DMatrix<f32>) -> Vec<Vector> {
let mut vectors = Vec::with_capacity(matrix.nrows());
for i in 0..matrix.nrows() {
let row: Vec<f32> = matrix.row(i).iter().cloned().collect();
vectors.push(Vector::new(row));
}
vectors
}
fn optimize_rotation(
&self,
data: &DMatrix<f32>,
rotated_vectors: &[Vector],
) -> Result<DMatrix<f32>> {
let mut reconstructed = DMatrix::zeros(data.nrows(), data.ncols());
for (i, vector) in rotated_vectors.iter().enumerate() {
if let Ok(reconstructed_vec) = self.pq_index.reconstruct(vector) {
let rec_f32 = reconstructed_vec.as_f32();
for (j, &val) in rec_f32.iter().enumerate() {
reconstructed[(i, j)] = val;
}
}
}
let correlation = data.transpose() * &reconstructed;
let mut reg_correlation = correlation.clone();
if self.config.regularization > 0.0 {
for i in 0..reg_correlation.ncols().min(reg_correlation.nrows()) {
reg_correlation[(i, i)] += self.config.regularization;
}
}
let svd = SVD::new(reg_correlation, true, true);
let u = svd.u.ok_or_else(|| anyhow!("SVD failed to compute U"))?;
let v_t = svd
.v_t
.ok_or_else(|| anyhow!("SVD failed to compute V^T"))?;
Ok(u * v_t)
}
fn compute_reconstruction_error(
&self,
data: &DMatrix<f32>,
rotation: &DMatrix<f32>,
) -> Result<f32> {
let rotated = self.apply_rotation(data, rotation);
let rotated_vecs = self.matrix_to_vectors(&rotated);
let mut total_error = 0.0;
for (i, vec) in rotated_vecs.iter().enumerate() {
if let Ok(reconstructed) = self.pq_index.reconstruct(vec) {
let rec_f32 = reconstructed.as_f32();
for (j, &val) in rec_f32.iter().enumerate() {
let diff = rotated[(i, j)] - val;
total_error += diff * diff;
}
}
}
Ok((total_error / (data.nrows() * data.ncols()) as f32).sqrt())
}
pub fn encode(&self, vector: &Vector) -> Result<Vec<u8>> {
if !self.is_trained {
return Err(anyhow!("OPQ index must be trained before encoding"));
}
let transformed = self.transform_vector(vector)?;
self.pq_index.encode(&transformed)
}
pub fn decode(&self, codes: &[u8]) -> Result<Vector> {
if !self.is_trained {
return Err(anyhow!("OPQ index must be trained before decoding"));
}
let rotated = self.pq_index.decode(codes)?;
self.inverse_transform_vector(&rotated)
}
fn transform_vector(&self, vector: &Vector) -> Result<Vector> {
let rotation = self
.rotation_matrix
.as_ref()
.ok_or_else(|| anyhow!("Rotation matrix not initialized"))?;
let vec_f32 = vector.as_f32();
let mut vec_dv = DVector::from_vec(vec_f32.to_vec());
if let Some(ref mean) = self.data_mean {
vec_dv -= mean;
}
let rotated = rotation.transpose() * vec_dv;
Ok(Vector::new(rotated.iter().cloned().collect()))
}
fn inverse_transform_vector(&self, vector: &Vector) -> Result<Vector> {
let rotation = self
.rotation_matrix
.as_ref()
.ok_or_else(|| anyhow!("Rotation matrix not initialized"))?;
let vec_f32 = vector.as_f32();
let vec_dv = DVector::from_vec(vec_f32.to_vec());
let unrotated = rotation * vec_dv;
let mut result = unrotated;
if let Some(ref mean) = self.data_mean {
result += mean;
}
Ok(Vector::new(result.iter().cloned().collect()))
}
pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
if !self.is_trained {
return Err(anyhow!("OPQ index must be trained before searching"));
}
let transformed_query = self.transform_vector(query)?;
self.pq_index.search(&transformed_query, k)
}
pub fn stats(&self) -> OPQStats {
let pq_stats = self.pq_index.stats();
OPQStats {
pq_stats,
is_trained: self.is_trained,
has_rotation: self.rotation_matrix.is_some(),
rotation_rank: self
.rotation_matrix
.as_ref()
.map(|r| r.rank(1e-6))
.unwrap_or(0),
}
}
}
/// Snapshot of an [`OPQIndex`]'s state, returned by `OPQIndex::stats`.
#[derive(Debug, Clone)]
pub struct OPQStats {
    /// Statistics from the underlying PQ index.
    pub pq_stats: crate::pq::PQStats,
    /// Whether the OPQ index has been trained.
    pub is_trained: bool,
    /// Whether a rotation matrix has been learned.
    pub has_rotation: bool,
    /// Numerical rank of the rotation matrix (0 when untrained).
    pub rotation_rank: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::VectorIndex;

    /// Trains OPQ on a deterministic synthetic set of 16-dimensional vectors
    /// and round-trips one vector through encode/decode.
    #[test]
    fn test_opq_basic() -> Result<()> {
        let config = OPQConfig {
            pq_config: PQConfig {
                n_subquantizers: 4,
                n_centroids: 16,
                ..Default::default()
            },
            n_iterations: 3,
            ..Default::default()
        };
        let mut opq = OPQIndex::new(config);

        // 100 deterministic training vectors of dimension 16.
        let mut training = Vec::with_capacity(100);
        for i in 0..100 {
            let row: Vec<f32> = (0..16)
                .map(|j| (i as f32 * 0.1 + j as f32) % 10.0)
                .collect();
            training.push(Vector::new(row));
        }
        opq.train(&training)?;

        let probe = Vector::new(vec![1.0; 16]);
        let codes = opq.encode(&probe)?;
        let decoded = opq.decode(&codes)?;
        assert_eq!(decoded.dimensions, 16);
        Ok(())
    }

    /// Trains a tiny OPQ index, inserts the training vectors, and checks that
    /// a k-NN search returns exactly k results.
    #[test]
    fn test_opq_search() -> Result<()> {
        let pq_config = PQConfig {
            n_subquantizers: 2,
            n_centroids: 4,
            n_bits: 2,
            max_iterations: 5,
            convergence_threshold: 1e-3,
            seed: None,
            enable_residual_quantization: false,
            residual_levels: 2,
            enable_multi_codebook: false,
            num_codebooks: 2,
            enable_symmetric_distance: false,
        };
        let config = OPQConfig {
            pq_config,
            n_iterations: 2,
            center_data: true,
            regularization: 0.0,
        };
        let mut opq = OPQIndex::new(config);

        // 20 deterministic 4-dimensional vectors.
        let dataset: Vec<Vector> = (0..20)
            .map(|i| Vector::new((0..4).map(|j| ((i * j) as f32).sin()).collect()))
            .collect();
        opq.train(&dataset)?;
        for (i, vec) in dataset.iter().enumerate() {
            opq.pq_index.insert(format!("vec_{i}"), vec.clone())?;
        }

        let results = opq.search(&Vector::new(vec![0.5; 4]), 5)?;
        assert_eq!(results.len(), 5);
        Ok(())
    }
}