use crate::error::{Result, RuvectorError};
use crate::types::DistanceMetric;
use serde::{Deserialize, Serialize};
/// Configuration for training an [`OPQIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OPQConfig {
/// Number of subspaces each vector is split into. Must be non-zero, and
/// training requires it to divide the vector dimension evenly.
pub num_subspaces: usize,
/// Requested number of centroids per subspace codebook (at most 256).
/// Training uses fewer when the training set has fewer vectors.
pub codebook_size: usize,
/// Number of k-means iterations run for each codebook.
pub num_iterations: usize,
/// Number of alternating codebook / rotation rounds. Must be non-zero.
pub num_opq_iterations: usize,
/// Metric used for centroid assignment and for ADC search.
pub metric: DistanceMetric,
}
impl Default for OPQConfig {
fn default() -> Self {
Self {
num_subspaces: 8, codebook_size: 256, num_iterations: 20,
num_opq_iterations: 10, metric: DistanceMetric::Euclidean,
}
}
}
impl OPQConfig {
    /// Checks that the configuration can be used for training.
    ///
    /// # Errors
    ///
    /// Returns `RuvectorError::InvalidParameter` when:
    /// * `codebook_size` is greater than 256 (codes are stored as `u8`),
    /// * `codebook_size` is zero,
    /// * `num_subspaces` is zero,
    /// * `num_opq_iterations` is zero.
    pub fn validate(&self) -> Result<()> {
        if self.codebook_size > 256 {
            return Err(RuvectorError::InvalidParameter(format!(
                "Codebook size {} exceeds u8 max 256",
                self.codebook_size
            )));
        }
        // A zero-sized codebook cannot encode anything; reject it here instead
        // of letting k-means fail later with a less specific message.
        if self.codebook_size == 0 {
            return Err(RuvectorError::InvalidParameter(
                "codebook_size must be > 0".into(),
            ));
        }
        if self.num_subspaces == 0 {
            return Err(RuvectorError::InvalidParameter(
                "num_subspaces must be > 0".into(),
            ));
        }
        if self.num_opq_iterations == 0 {
            return Err(RuvectorError::InvalidParameter(
                "num_opq_iterations must be > 0".into(),
            ));
        }
        Ok(())
    }
}
/// Minimal dense, row-major `f32` matrix used by the rotation-learning code.
#[derive(Debug, Clone)]
struct Mat {
    rows: usize,
    cols: usize,
    /// Row-major storage: element `(r, c)` lives at index `r * cols + c`.
    data: Vec<f32>,
}

impl Mat {
    /// Creates an `r × c` matrix filled with zeros.
    fn zeros(r: usize, c: usize) -> Self {
        Self { rows: r, cols: c, data: vec![0.0; r * c] }
    }

    /// Creates the `n × n` identity matrix.
    fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        for i in 0..n {
            m.data[i * n + i] = 1.0;
        }
        m
    }

    /// Returns element `(r, c)`.
    #[inline]
    fn get(&self, r: usize, c: usize) -> f32 {
        self.data[r * self.cols + c]
    }

    /// Overwrites element `(r, c)` with `v`.
    #[inline]
    fn set(&mut self, r: usize, c: usize, v: f32) {
        self.data[r * self.cols + c] = v;
    }

    /// Returns a new matrix that is the transpose of `self`.
    fn transpose(&self) -> Self {
        let mut t = Self::zeros(self.cols, self.rows);
        for r in 0..self.rows {
            for c in 0..self.cols {
                t.set(c, r, self.get(r, c));
            }
        }
        t
    }

    /// Computes the matrix product `self * b`.
    ///
    /// # Panics
    ///
    /// Panics when `self.cols != b.rows`.
    fn mul(&self, b: &Mat) -> Mat {
        assert_eq!(self.cols, b.rows);
        let mut out = Mat::zeros(self.rows, b.cols);
        for i in 0..self.rows {
            // Work on whole rows so the index arithmetic is done once per row
            // instead of once per element. The accumulation order (k outer,
            // j inner) is the same as a plain triple loop.
            let out_row = &mut out.data[i * b.cols..(i + 1) * b.cols];
            for k in 0..self.cols {
                let a = self.data[i * self.cols + k];
                let b_row = &b.data[k * b.cols..(k + 1) * b.cols];
                for (o, &bv) in out_row.iter_mut().zip(b_row) {
                    *o += a * bv;
                }
            }
        }
        out
    }

    /// Builds a matrix whose rows are the given vectors.
    ///
    /// An empty slice yields a `0 × 0` matrix. All rows are expected to share
    /// the length of the first one; this is checked in debug builds only.
    fn from_rows(vecs: &[Vec<f32>]) -> Self {
        let rows = vecs.len();
        let cols = vecs.first().map_or(0, Vec::len);
        let mut data = Vec::with_capacity(rows * cols);
        for v in vecs {
            debug_assert_eq!(v.len(), cols, "all rows must have the same length");
            data.extend_from_slice(v);
        }
        Self { rows, cols, data }
    }

    /// Returns a copy of row `i`.
    fn row(&self, i: usize) -> Vec<f32> {
        self.data[i * self.cols..(i + 1) * self.cols].to_vec()
    }
}
fn svd_rank1(a: &Mat, max_iters: usize) -> (Vec<f32>, f32, Vec<f32>) {
let ata = a.transpose().mul(a);
let n = ata.cols;
let mut v = vec![1.0 / (n as f32).sqrt(); n];
for _ in 0..max_iters {
let mut nv = vec![0.0; n];
for i in 0..n { for j in 0..n { nv[i] += ata.get(i, j) * v[j]; } }
let norm: f32 = nv.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm < 1e-12 { break; }
for x in nv.iter_mut() { *x /= norm; }
v = nv;
}
let mut av = vec![0.0; a.rows];
for i in 0..a.rows { for j in 0..a.cols { av[i] += a.get(i, j) * v[j]; } }
let sigma: f32 = av.iter().map(|x| x * x).sum::<f32>().sqrt();
let u = if sigma > 1e-12 { av.iter().map(|x| x / sigma).collect() } else { vec![0.0; a.rows] };
(u, sigma, v)
}
/// Builds an approximate SVD `a ≈ U · diag(s) · Vᵀ` by repeated rank-1
/// extraction and deflation.
///
/// Runs `a.rows` rounds. Each round takes one triplet from the current
/// residual via `svd_rank1` and, when its singular value exceeds `1e-10`,
/// subtracts `s · u · vᵀ` from the residual. The k-th triplet becomes column
/// `k` of the returned `U` and `V`, both sized `a.rows × a.rows`.
// NOTE(review): the output sizing uses `a.rows` only, so this presumably
// expects a square input — confirm against callers.
fn svd_full(a: &Mat, iters: usize) -> (Mat, Vec<f32>, Mat) {
    let n = a.rows;
    let mut residual = a.clone();
    let mut left = Mat::zeros(n, n);
    let mut right = Mat::zeros(n, n);
    let mut sigmas = Vec::with_capacity(n);

    for k in 0..n {
        let (u, s, v) = svd_rank1(&residual, iters);
        if s > 1e-10 {
            // Deflate: remove the component that was just extracted.
            for i in 0..residual.rows {
                for j in 0..residual.cols {
                    let updated = residual.get(i, j) - s * u[i] * v[j];
                    residual.set(i, j, updated);
                }
            }
        }
        for i in 0..n {
            left.set(i, k, u[i]);
            right.set(i, k, v[i]);
        }
        sigmas.push(s);
    }
    (left, sigmas, right)
}
/// Solves the orthogonal Procrustes problem `min_R ‖x·R − y‖_F` over
/// orthogonal `R`, for data stored one vector per row.
///
/// With `xᵀ·y = U·S·Vᵀ`, the trace `tr(Rᵀ·xᵀ·y)` is maximised by
/// `R = U·Vᵀ`. The frequently quoted form `V·Uᵀ` is for the column-vector
/// convention (`R·x`). Rotations in this file are applied on the right
/// (`x·R`), so `V·Uᵀ` would be the transpose, i.e. the inverse rotation.
fn procrustes(x: &Mat, y: &Mat) -> Mat {
    let m = x.transpose().mul(y);
    let (u, _s, v) = svd_full(&m, 100);
    u.mul(&v.transpose())
}
/// Square `dim × dim` matrix stored row-major and applied to row vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RotationMatrix {
    /// Side length of the matrix.
    pub dim: usize,
    /// Row-major entries; element `(i, j)` is at index `i * dim + j`.
    pub data: Vec<f32>,
}

impl RotationMatrix {
    /// Returns the identity matrix of the given size.
    pub fn identity(dim: usize) -> Self {
        let mut data = vec![0.0; dim * dim];
        // Diagonal entries sit `dim + 1` apart in row-major storage.
        data.iter_mut().step_by(dim + 1).for_each(|x| *x = 1.0);
        Self { dim, data }
    }

    /// Computes `v · M`: output component `j` is `Σ_i v[i] · M[i][j]`.
    ///
    /// # Panics
    ///
    /// Panics when `v` has fewer than `dim` elements.
    pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
        let n = self.dim;
        let input = &v[..n];
        (0..n)
            .map(|col| {
                input
                    .iter()
                    .enumerate()
                    .map(|(row, x)| x * self.data[row * n + col])
                    .sum()
            })
            .collect()
    }

    /// Computes `v · Mᵀ`: output component `j` is `Σ_i v[i] · M[j][i]`.
    ///
    /// # Panics
    ///
    /// Panics when `v` has fewer than `dim` elements.
    pub fn inverse_rotate(&self, v: &[f32]) -> Vec<f32> {
        let n = self.dim;
        let input = &v[..n];
        (0..n)
            .map(|j| {
                let row = &self.data[j * n..(j + 1) * n];
                input.iter().zip(row).map(|(x, r)| x * r).sum()
            })
            .collect()
    }

    /// Copies an internal matrix, taking its row count as the side length.
    fn from_mat(m: &Mat) -> Self {
        RotationMatrix { data: m.data.clone(), dim: m.rows }
    }
}
/// A trained OPQ model: a learned rotation plus one codebook per subspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OPQIndex {
/// Configuration the index was trained with.
pub config: OPQConfig,
/// Rotation applied to vectors before they are split into subspaces.
pub rotation: RotationMatrix,
/// `codebooks[s][c]` is centroid `c` of subspace `s`.
pub codebooks: Vec<Vec<Vec<f32>>>,
/// Dimensionality of the vectors the index was trained on.
pub dimensions: usize,
}
impl OPQIndex {
/// Trains an index on `vectors`.
///
/// Each of the `num_opq_iterations` rounds rotates the training data with
/// the current rotation, trains fresh per-subspace codebooks on it, rebuilds
/// every vector from its codes, and then replaces the rotation with the
/// Procrustes solution aligning the original data to those reconstructions.
///
/// # Errors
///
/// Returns an error when the configuration is invalid, the training set is
/// empty, the dimension is not divisible by `num_subspaces`, the vectors
/// differ in length, or codebook training / encoding fails.
pub fn train(vectors: &[Vec<f32>], config: OPQConfig) -> Result<Self> {
    config.validate()?;
    let d = match vectors.first() {
        Some(first) => first.len(),
        None => {
            return Err(RuvectorError::InvalidParameter("Training set cannot be empty".into()));
        }
    };
    if d % config.num_subspaces != 0 {
        return Err(RuvectorError::InvalidParameter(format!(
            "Dimensions {} not divisible by num_subspaces {}", d, config.num_subspaces)));
    }
    if let Some(bad) = vectors.iter().find(|v| v.len() != d) {
        return Err(RuvectorError::DimensionMismatch { expected: d, actual: bad.len() });
    }

    let sub_dim = d / config.num_subspaces;
    let x_mat = Mat::from_rows(vectors);
    let mut rotation = Mat::identity(d);
    let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();

    for _ in 0..config.num_opq_iterations {
        let x_rot = x_mat.mul(&rotation);
        let rotated: Vec<Vec<f32>> = (0..x_rot.rows).map(|i| x_rot.row(i)).collect();
        codebooks = train_pq_codebooks(
            &rotated,
            config.num_subspaces,
            config.codebook_size,
            config.num_iterations,
            config.metric,
        )?;

        // Reconstruction of every rotated vector from its codes.
        let mut x_hat = Mat::zeros(vectors.len(), d);
        for (i, rv) in rotated.iter().enumerate() {
            let codes = encode_vec(rv, &codebooks, sub_dim, config.metric)?;
            for (j, val) in decode_vec(&codes, &codebooks).into_iter().enumerate() {
                x_hat.set(i, j, val);
            }
        }
        rotation = procrustes(&x_mat, &x_hat);
    }

    Ok(Self {
        config,
        rotation: RotationMatrix::from_mat(&rotation),
        codebooks,
        dimensions: d,
    })
}
/// Rotates `vector` and encodes it as one centroid index per subspace.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `vector` does not have `dimensions`
/// elements, or any error raised while picking centroids.
pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
    self.check_dim(vector.len())?;
    let rotated = self.rotation.rotate(vector);
    let sub_dim = self.dimensions / self.config.num_subspaces;
    encode_vec(&rotated, &self.codebooks, sub_dim, self.config.metric)
}
/// Rebuilds an approximate vector from its codes and undoes the rotation.
///
/// # Errors
///
/// * `InvalidParameter` when `codes` does not hold exactly `num_subspaces`
///   entries, or when a code points past the end of its codebook.
/// * `Internal` when the index holds a different number of codebooks than
///   `num_subspaces`.
pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
    if codes.len() != self.config.num_subspaces {
        return Err(RuvectorError::InvalidParameter(format!(
            "Expected {} codes, got {}", self.config.num_subspaces, codes.len())));
    }
    if self.codebooks.len() != codes.len() {
        return Err(RuvectorError::Internal(format!(
            "Index has {} codebooks but num_subspaces is {}",
            self.codebooks.len(), self.config.num_subspaces).into()));
    }
    // Codes come from callers (e.g. storage), so check them here rather than
    // letting the slice indexing in `decode_vec` panic.
    for (s, (&c, cb)) in codes.iter().zip(&self.codebooks).enumerate() {
        if usize::from(c) >= cb.len() {
            return Err(RuvectorError::InvalidParameter(format!(
                "Code {} out of range for subspace {} (codebook has {} entries)",
                c, s, cb.len())));
        }
    }
    Ok(self.rotation.inverse_rotate(&decode_vec(codes, &self.codebooks)))
}
pub fn search_adc(&self, query: &[f32], codes_db: &[Vec<u8>], top_k: usize,
) -> Result<Vec<(usize, f32)>> {
self.check_dim(query.len())?;
let rq = self.rotation.rotate(query);
let sub_dim = self.dimensions / self.config.num_subspaces;
let tables: Vec<Vec<f32>> = (0..self.config.num_subspaces).map(|s| {
let q_sub = &rq[s * sub_dim..(s + 1) * sub_dim];
self.codebooks[s].iter().map(|c| dist(q_sub, c, self.config.metric)).collect()
}).collect();
let mut dists: Vec<(usize, f32)> = codes_db.iter().enumerate().map(|(idx, codes)| {
let d: f32 = codes.iter().enumerate().map(|(s, &c)| tables[s][c as usize]).sum();
(idx, d)
}).collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
dists.truncate(top_k);
Ok(dists)
}
/// Mean squared reconstruction error over `vectors`: each vector is encoded,
/// decoded, and the squared differences are accumulated in `f64`.
///
/// Returns `0.0` for an empty input.
///
/// # Errors
///
/// Propagates any error from `encode` or `decode`.
pub fn quantization_error(&self, vectors: &[Vec<f32>]) -> Result<f32> {
    if vectors.is_empty() {
        return Ok(0.0);
    }
    let squared_sum = vectors.iter().try_fold(0.0f64, |acc, v| -> Result<f64> {
        let codes = self.encode(v)?;
        let approx = self.decode(&codes)?;
        let err = v
            .iter()
            .zip(&approx)
            .map(|(orig, rec)| ((orig - rec) as f64).powi(2))
            .sum::<f64>();
        Ok(acc + err)
    })?;
    Ok((squared_sum / vectors.len() as f64) as f32)
}
/// Returns `DimensionMismatch` unless `len` equals the trained dimension.
fn check_dim(&self, len: usize) -> Result<()> {
    if len == self.dimensions {
        return Ok(());
    }
    Err(RuvectorError::DimensionMismatch { expected: self.dimensions, actual: len })
}
}
/// Distance between `a` and `b` under metric `m`; smaller means closer.
///
/// * Euclidean: square root of the summed squared differences.
/// * Cosine: `1 − cos(a, b)`, or `1.0` when either vector has zero norm.
/// * DotProduct: the negated dot product.
/// * Manhattan: sum of absolute differences.
///
/// Elements are paired with `zip`, so extra elements of the longer slice
/// are ignored.
fn dist(a: &[f32], b: &[f32], m: DistanceMetric) -> f32 {
    fn dot(p: &[f32], q: &[f32]) -> f32 {
        p.iter().zip(q).map(|(x, y)| x * y).sum()
    }
    fn l2_norm(p: &[f32]) -> f32 {
        p.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    match m {
        DistanceMetric::Euclidean => {
            let squared: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
            squared.sqrt()
        }
        DistanceMetric::Cosine => {
            let inner = dot(a, b);
            let (na, nb) = (l2_norm(a), l2_norm(b));
            if na != 0.0 && nb != 0.0 {
                1.0 - inner / (na * nb)
            } else {
                1.0
            }
        }
        DistanceMetric::DotProduct => -dot(a, b),
        DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
    }
}
/// Trains one k-means codebook per subspace.
///
/// Every vector is cut into `nsub` equal slices (the slice width is taken
/// from the first vector). Slice `s` of all vectors forms the training set
/// of codebook `s`, clustered into at most `k` centroids (fewer when there
/// are fewer vectors). Stops at the first k-means error.
fn train_pq_codebooks(vecs: &[Vec<f32>], nsub: usize, k: usize, iters: usize,
    metric: DistanceMetric) -> Result<Vec<Vec<Vec<f32>>>> {
    let sub_dim = vecs[0].len() / nsub;
    let mut codebooks = Vec::with_capacity(nsub);
    for s in 0..nsub {
        let (lo, hi) = (s * sub_dim, (s + 1) * sub_dim);
        let slices: Vec<Vec<f32>> = vecs.iter().map(|v| v[lo..hi].to_vec()).collect();
        let clusters = k.min(slices.len());
        codebooks.push(kmeans(&slices, clusters, iters, metric)?);
    }
    Ok(codebooks)
}
/// Encodes an (already rotated) vector as one centroid index per subspace.
///
/// For subspace `s` the slice `v[s*sub_dim .. (s+1)*sub_dim]` is compared
/// with every centroid of `cbs[s]` and the index of the closest one is
/// stored. The first centroid wins ties. Each distance is computed once.
/// A NaN distance never beats a real one, so NaN input cannot cause a panic.
///
/// # Errors
///
/// * `DimensionMismatch` when `v` is too short for the requested slicing.
/// * `Internal` when a codebook is empty or the winning index does not fit
///   in a `u8`.
fn encode_vec(v: &[f32], cbs: &[Vec<Vec<f32>>], sub_dim: usize, m: DistanceMetric,
) -> Result<Vec<u8>> {
    cbs.iter().enumerate().map(|(s, cb)| -> Result<u8> {
        let sub = v.get(s * sub_dim..(s + 1) * sub_dim).ok_or(
            RuvectorError::DimensionMismatch { expected: cbs.len() * sub_dim, actual: v.len() })?;

        let mut best: Option<(usize, f32)> = None;
        for (i, centroid) in cb.iter().enumerate() {
            let d = dist(sub, centroid, m);
            let better = match best {
                None => true,
                // Strict `<` keeps the first of several equal minima; a NaN
                // incumbent is replaced by any non-NaN candidate.
                Some((_, bd)) => d < bd || (bd.is_nan() && !d.is_nan()),
            };
            if better {
                best = Some((i, d));
            }
        }

        let (idx, _) = best.ok_or_else(|| RuvectorError::Internal("Empty codebook".into()))?;
        if idx > usize::from(u8::MAX) {
            return Err(RuvectorError::Internal(
                format!("Centroid index {} does not fit in a u8 code", idx).into()));
        }
        Ok(idx as u8)
    }).collect()
}
/// Concatenates, subspace by subspace, the centroid selected by each code.
///
/// # Panics
///
/// Panics when there are more codes than codebooks or when a code is not a
/// valid index into its codebook.
fn decode_vec(codes: &[u8], cbs: &[Vec<Vec<f32>>]) -> Vec<f32> {
    let mut out = Vec::new();
    for (s, &code) in codes.iter().enumerate() {
        out.extend_from_slice(&cbs[s][usize::from(code)]);
    }
    out
}
/// Lloyd-style k-means clustering.
///
/// Initial centroids are up to `k` vectors sampled from `vecs` with the
/// thread-local RNG. Each of the `iters` rounds assigns every vector to its
/// nearest centroid under `metric` (the first centroid wins ties) and then
/// moves each centroid to the mean of its members. A centroid with no
/// members keeps its previous position.
///
/// Each distance is computed once per (vector, centroid) pair. A NaN
/// distance never wins the comparison, so NaN input does not panic; a vector
/// whose distances are all NaN is assigned to centroid 0.
///
/// # Errors
///
/// Returns `InvalidParameter` when `vecs` is empty or `k` is zero.
fn kmeans(vecs: &[Vec<f32>], k: usize, iters: usize, metric: DistanceMetric,
) -> Result<Vec<Vec<f32>>> {
    use rand::seq::SliceRandom;
    if vecs.is_empty() || k == 0 {
        return Err(RuvectorError::InvalidParameter("Cannot cluster empty set or k=0".into()));
    }
    let dim = vecs[0].len();
    let mut rng = rand::thread_rng();
    let mut cents: Vec<Vec<f32>> = vecs.choose_multiple(&mut rng, k).cloned().collect();

    for _ in 0..iters {
        let mut sums = vec![vec![0.0f32; dim]; k];
        let mut counts = vec![0usize; k];

        for v in vecs {
            // Nearest centroid; strict `<` keeps the first of equal minima.
            let mut nearest = 0usize;
            let mut nearest_d = f32::INFINITY;
            for (i, c) in cents.iter().enumerate() {
                let d = dist(v, c, metric);
                if d < nearest_d {
                    nearest = i;
                    nearest_d = d;
                }
            }
            counts[nearest] += 1;
            for (acc, &val) in sums[nearest].iter_mut().zip(v) {
                *acc += val;
            }
        }

        for ((c, sum), &count) in cents.iter_mut().zip(&sums).zip(&counts) {
            if count > 0 {
                for (cj, &sj) in c.iter_mut().zip(sum) {
                    *cj = sj / count as f32;
                }
            }
        }
    }
    Ok(cents)
}
// Unit tests for OPQ training, encoding/decoding, ADC search and the SVD helper.
#[cfg(test)]
mod tests {
use super::*;
// Deterministic pseudo-random `n × d` data set: a 64-bit multiply-and-add
// generator seeded with 42, whose bits above position 33 are scaled to a float.
// NOTE(review): `seed >> 33` leaves 31 bits, so after dividing by `u32::MAX`
// the values appear to fall in [-1, 0) rather than [-1, 1) — confirm whether
// that is intended.
fn make_data(n: usize, d: usize) -> Vec<Vec<f32>> {
let mut seed: u64 = 42;
(0..n).map(|_| (0..d).map(|_| {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
((seed >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0
}).collect()).collect()
}
// Small configuration shared by the tests: 2 subspaces, 4 centroids each,
// 5 k-means iterations, 3 OPQ rounds, Euclidean metric.
fn cfg() -> OPQConfig {
OPQConfig { num_subspaces: 2, codebook_size: 4, num_iterations: 5,
num_opq_iterations: 3, metric: DistanceMetric::Euclidean }
}
// Rotating with the identity matrix and then inverse-rotating returns the input.
#[test]
fn test_rotation_orthogonality() {
let r = RotationMatrix::identity(4);
let v = vec![1.0, 2.0, 3.0, 4.0];
let back = r.inverse_rotate(&r.rotate(&v));
for i in 0..4 { assert!((v[i] - back[i]).abs() < 1e-6); }
}
// The learned rotation should keep the vector norm (within a tolerance of 0.1).
#[test]
fn test_rotation_preserves_norm() {
let data = make_data(30, 4);
let idx = OPQIndex::train(&data, cfg()).unwrap();
let v = vec![1.0, 2.0, 3.0, 4.0];
let n1: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
let n2: f32 = idx.rotation.rotate(&v).iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((n1 - n2).abs() < 0.1, "norms: {} vs {}", n1, n2);
}
// Encoding yields one code per subspace; decoding restores the original length.
#[test]
fn test_pq_encoding_roundtrip() {
let data = make_data(30, 4);
let idx = OPQIndex::train(&data, cfg()).unwrap();
let codes = idx.encode(&data[0]).unwrap();
assert_eq!(codes.len(), 2);
assert_eq!(idx.decode(&codes).unwrap().len(), 4);
}
// Training produces a finite, non-negative quantization error and every
// training vector decodes to finite values of the right length.
#[test]
fn test_opq_training_convergence() {
let data = make_data(100, 4);
let idx = OPQIndex::train(&data, cfg()).unwrap();
let err = idx.quantization_error(&data).unwrap();
assert!(err.is_finite() && err >= 0.0, "error must be finite non-negative: {}", err);
for v in &data {
let codes = idx.encode(v).unwrap();
let decoded = idx.decode(&codes).unwrap();
assert_eq!(decoded.len(), v.len());
for x in &decoded { assert!(x.is_finite()); }
}
}
// ADC search returns exactly `top_k` results ordered by ascending distance.
#[test]
fn test_adc_correctness() {
let data = make_data(30, 4);
let idx = OPQIndex::train(&data, cfg()).unwrap();
let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
let res = idx.search_adc(&[0.5, -0.5, 0.5, -0.5], &db, 3).unwrap();
assert_eq!(res.len(), 3);
for w in res.windows(2) { assert!(w[0].1 <= w[1].1 + 1e-6); }
}
// Sanity bound: the quantization error is finite, non-negative and below 10.
#[test]
fn test_quantization_error_reduction() {
let data = make_data(50, 4);
let err = OPQIndex::train(&data, cfg()).unwrap().quantization_error(&data).unwrap();
assert!(err >= 0.0 && err.is_finite() && err < 10.0, "err={}", err);
}
// U * diag(s) * V^T reproduces a 2x2 diagonal matrix to within 0.1 per entry.
#[test]
fn test_svd_correctness() {
let a = Mat { rows: 2, cols: 2, data: vec![3.0, 0.0, 0.0, 2.0] };
let (u, s, v) = svd_full(&a, 200);
for i in 0..2 { for j in 0..2 {
let r: f32 = (0..2).map(|k| u.get(i, k) * s[k] * v.get(j, k)).sum();
assert!((a.get(i, j) - r).abs() < 0.1, "SVD fail ({},{}): {} vs {}", i, j, a.get(i, j), r);
}}
}
// Trains with a single OPQ round and checks that an encode/decode round trip
// keeps the vector length.
// NOTE(review): despite the name, `train` still replaces the rotation with
// the Procrustes result after that single round, so the stored rotation is
// not necessarily the identity.
#[test]
fn test_identity_rotation_baseline() {
let data = make_data(30, 4);
let idx = OPQIndex::train(&data, OPQConfig { num_opq_iterations: 1, ..cfg() }).unwrap();
let recon = idx.decode(&idx.encode(&data[0]).unwrap()).unwrap();
assert_eq!(recon.len(), data[0].len());
}
// Searching with a database vector as the query must return that vector's
// index among the top 5 results.
#[test]
fn test_search_accuracy() {
let data = make_data(40, 4);
let idx = OPQIndex::train(&data, cfg()).unwrap();
let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
let ids: Vec<usize> = idx.search_adc(&data[0], &db, 5).unwrap().iter().map(|r| r.0).collect();
assert!(ids.contains(&0), "vector 0 should be in its own top-5");
}
// Oversized codebooks, zero subspaces and zero OPQ rounds are all rejected.
#[test]
fn test_config_validation() {
assert!(OPQConfig { codebook_size: 300, ..cfg() }.validate().is_err());
assert!(OPQConfig { num_subspaces: 0, ..cfg() }.validate().is_err());
assert!(OPQConfig { num_opq_iterations: 0, ..cfg() }.validate().is_err());
}
// Vectors of the wrong length are rejected by both `encode` and `search_adc`.
#[test]
fn test_dimension_mismatch_errors() {
let idx = OPQIndex::train(&make_data(30, 4), cfg()).unwrap();
assert!(idx.encode(&[1.0, 2.0]).is_err());
assert!(idx.search_adc(&[1.0], &[], 1).is_err());
}
}