use nalgebra::{DMatrix, SVD};
use crate::vector_quant::codec::{AdcLut, VectorCodec};
use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
use crate::vector_quant::opq_kmeans::l2_sq;
use crate::vector_quant::opq_kmeans::lloyd;
/// Optimized Product Quantization (OPQ) codec: a learned orthogonal
/// rotation followed by per-subspace product quantization.
pub struct OpqCodec {
    /// Full input vector dimension.
    pub dim: usize,
    /// Number of PQ subspaces.
    pub m: usize,
    /// Number of centroids per subspace codebook.
    pub k: usize,
    /// Dimension of each subspace (`dim / m`).
    pub sub_dim: usize,
    /// Learned rotation matrix, row-major `dim x dim` flat buffer.
    rotation: Vec<f32>,
    /// `m` codebooks, each holding `k` centroids of length `sub_dim`.
    codebooks: Vec<Vec<Vec<f32>>>,
}
impl OpqCodec {
    /// Train an OPQ codec by alternating two steps:
    /// 1. fit per-subspace k-means codebooks on the rotated data, and
    /// 2. re-estimate the orthogonal rotation as the orthogonal-Procrustes
    ///    solution between the data and its PQ reconstruction.
    ///
    /// * `vectors` — non-empty training set; each slice must have length `dim`.
    /// * `dim` — full vector dimension, divisible by `m`.
    /// * `m` — number of PQ subspaces.
    /// * `k` — centroids per subspace; at most 256 because codes are `u8`.
    /// * `opq_iters` — alternation count (clamped to at least 1).
    /// * `kmeans_iters` — Lloyd iterations per codebook fit.
    ///
    /// # Panics
    /// Panics if `vectors` is empty, any of `dim`/`m`/`k` is zero, `dim` is
    /// not divisible by `m`, or `k` exceeds 256.
    pub fn train(
        vectors: &[&[f32]],
        dim: usize,
        m: usize,
        k: usize,
        opq_iters: usize,
        kmeans_iters: usize,
    ) -> Self {
        assert!(!vectors.is_empty(), "training set must be non-empty");
        assert!(dim > 0 && m > 0 && k > 0, "dim/m/k must be positive");
        assert!(
            dim.is_multiple_of(m),
            "dim ({dim}) must be divisible by m ({m})"
        );
        // Codes are stored as one u8 per subspace (see pq_encode), so more
        // than 256 centroids would silently truncate code values.
        assert!(k <= 256, "k ({k}) must be at most 256 to fit codes in u8");
        let sub_dim = dim / m;
        // Deterministic seed derived from the hyperparameters so repeated
        // training runs are reproducible.
        let seed = dim as u64 ^ ((m as u64) << 16) ^ ((k as u64) << 32);
        let mut rotation = identity(dim);
        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
        let iters = opq_iters.max(1);
        for iter in 0..iters {
            let rotated: Vec<Vec<f32>> =
                vectors.iter().map(|v| matvec(&rotation, v, dim)).collect();
            codebooks = train_codebooks(&rotated, m, k, sub_dim, kmeans_iters, seed ^ iter as u64);
            // Skip the rotation update on the final pass so the returned
            // codebooks always match the returned rotation.
            if iter + 1 < iters {
                let n = vectors.len();
                // X: original data as a dim x n matrix (one sample per column).
                let x_mat = DMatrix::from_fn(dim, n, |row, col| vectors[col][row]);
                // Y: PQ reconstructions of the rotated data, same layout.
                let y_mat = {
                    let recon: Vec<Vec<f32>> = rotated
                        .iter()
                        .map(|rv| {
                            let codes = pq_encode(rv, &codebooks, m, sub_dim);
                            dequantize_codes(&codes, &codebooks)
                        })
                        .collect();
                    DMatrix::from_fn(dim, n, |row, col| recon[col][row])
                };
                // Orthogonal Procrustes: minimize ||R X - Y||_F. With
                // M = X Y^T = U S V^T, the optimum is R = V U^T.
                let m_mat = &x_mat * y_mat.transpose();
                // Degenerate inputs can poison the SVD; keep the previous
                // rotation instead of propagating NaNs into it.
                let has_nan = m_mat.iter().any(|x| x.is_nan());
                if !has_nan {
                    let svd = SVD::new(m_mat, true, true);
                    if let (Some(u), Some(v_t)) = (svd.u, svd.v_t) {
                        let r_new = v_t.transpose() * u.transpose();
                        // Flatten the nalgebra matrix into the row-major
                        // buffer layout expected by matvec.
                        let mut buf = Vec::with_capacity(dim * dim);
                        for i in 0..dim {
                            for j in 0..dim {
                                buf.push(r_new[(i, j)]);
                            }
                        }
                        rotation = buf;
                    }
                }
            }
        }
        Self {
            dim,
            m,
            k,
            sub_dim,
            rotation,
            codebooks,
        }
    }
    /// Apply the learned rotation `R * v` to a full-dimension vector.
    pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
        matvec(&self.rotation, v, self.dim)
    }
    /// Rotate, PQ-encode, and wrap the codes in the unified layout.
    /// NOTE(review): `self.dim as u16` assumes dim <= 65535 — confirm upstream.
    fn encode_inner(&self, v: &[f32]) -> (Vec<u8>, UnifiedQuantizedVector) {
        let rotated = self.apply_rotation(v);
        let codes = pq_encode(&rotated, &self.codebooks, self.m, self.sub_dim);
        let uqv = make_uqv(&codes, self.dim as u16);
        (codes, uqv)
    }
    /// Reconstruct the rotated-space approximation from PQ codes.
    fn dequantize(&self, codes: &[u8]) -> Vec<f32> {
        dequantize_codes(codes, &self.codebooks)
    }
}
/// Build a row-major `dim x dim` identity matrix as a flat buffer.
fn identity(dim: usize) -> Vec<f32> {
    // In a row-major square matrix the diagonal entries sit at flat
    // indices 0, dim+1, 2*(dim+1), ... — exactly the multiples of dim+1.
    (0..dim * dim)
        .map(|idx| if idx % (dim + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
/// Reconstruct a full-dimension vector from per-subspace codes by
/// concatenating, for each subspace `s`, centroid `codes[s]` of
/// `codebooks[s]`.
///
/// Returns an empty vector when `codes` is empty (including the case of
/// empty `codebooks`, which previously panicked on the capacity lookup).
///
/// # Panics
/// Panics if `codes` is longer than `codebooks`, or a code indexes past
/// the end of its codebook.
fn dequantize_codes(codes: &[u8], codebooks: &[Vec<Vec<f32>>]) -> Vec<f32> {
    // Derive the sub-dimension defensively instead of indexing
    // codebooks[0][0] unconditionally; capacity is sized from codes.len()
    // since that is how many centroids actually get concatenated.
    let sub_dim = codebooks
        .first()
        .and_then(|cb| cb.first())
        .map_or(0, |c| c.len());
    let mut out = Vec::with_capacity(codes.len() * sub_dim);
    for (s, &c) in codes.iter().enumerate() {
        out.extend_from_slice(&codebooks[s][c as usize]);
    }
    out
}
/// Multiply the row-major `dim x dim` matrix `r` by vector `v`,
/// returning `R * v` as a freshly allocated vector.
#[inline]
fn matvec(r: &[f32], v: &[f32], dim: usize) -> Vec<f32> {
    (0..dim)
        .map(|row_idx| {
            // Each output entry is the dot product of one matrix row with v.
            let row = &r[row_idx * dim..(row_idx + 1) * dim];
            row.iter().zip(v).map(|(a, b)| a * b).sum()
        })
        .collect()
}
fn pq_encode(v: &[f32], codebooks: &[Vec<Vec<f32>>], m: usize, sub_dim: usize) -> Vec<u8> {
let mut codes = Vec::with_capacity(m);
#[allow(clippy::needless_range_loop)]
for s in 0..m {
let offset = s * sub_dim;
let sub = &v[offset..offset + sub_dim];
let best = codebooks[s]
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
l2_sq(sub, a)
.partial_cmp(&l2_sq(sub, b))
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0);
codes.push(best as u8);
}
codes
}
/// Fit one k-means codebook per subspace of the (already rotated) training
/// vectors. Subspace `s` covers coordinates `[s*sub_dim, (s+1)*sub_dim)`.
fn train_codebooks(
    rotated: &[Vec<f32>],
    m: usize,
    k: usize,
    sub_dim: usize,
    kmeans_iters: usize,
    seed: u64,
) -> Vec<Vec<Vec<f32>>> {
    (0..m)
        .map(|subspace| {
            let start = subspace * sub_dim;
            // Project every training vector onto this subspace.
            let slices: Vec<Vec<f32>> = rotated
                .iter()
                .map(|v| v[start..start + sub_dim].to_vec())
                .collect();
            // Perturb the seed per subspace so each k-means run differs.
            lloyd(
                &slices,
                sub_dim,
                k,
                kmeans_iters,
                seed ^ (subspace as u64 * 0x1234567),
            )
        })
        .collect()
}
/// Wrap raw PQ codes in the unified quantized-vector layout with a
/// PQ-mode header.
fn make_uqv(codes: &[u8], dim: u16) -> UnifiedQuantizedVector {
    let header = QuantHeader {
        quant_mode: QuantMode::Pq as u16,
        dim,
        // OPQ reconstruction comes purely from codebook lookups, so the
        // scale/norm/dot header fields carry neutral values.
        global_scale: 1.0,
        residual_norm: 0.0,
        dot_quantized: 0.0,
        outlier_bitmask: 0,
        reserved: [0; 8],
    };
    // Codes go in the primary payload; no residual bytes for OPQ.
    UnifiedQuantizedVector::new(header, codes, &[])
        .expect("make_uqv: layout construction must not fail for valid inputs")
}
/// A vector encoded by `OpqCodec`: the raw per-subspace codes plus the
/// serialized unified-layout view of the same data.
pub struct OpqQuantized {
    /// One `u8` centroid index per subspace.
    codes: Vec<u8>,
    /// Unified layout wrapping the same codes (see `make_uqv`).
    uqv: UnifiedQuantizedVector,
}
impl AsRef<UnifiedQuantizedVector> for OpqQuantized {
    /// Expose the unified layout so storage-layer code generic over
    /// `AsRef<UnifiedQuantizedVector>` can consume OPQ-encoded vectors.
    fn as_ref(&self) -> &UnifiedQuantizedVector {
        &self.uqv
    }
}
/// Query-side state for asymmetric distance computation (ADC).
pub struct OpqQuery {
    /// Flat `m * k` table; entry `s * k + c` is the squared L2 distance
    /// between subspace `s` of the rotated query and centroid `c`.
    pub distance_table: Vec<f32>,
    /// Rotated query vector, currently unused by any distance path.
    #[allow(dead_code)]
    rotated: Vec<f32>,
}
impl VectorCodec for OpqCodec {
    type Quantized = OpqQuantized;
    type Query = OpqQuery;
    /// Rotate and PQ-encode `v`, keeping both raw codes and layout.
    fn encode(&self, v: &[f32]) -> Self::Quantized {
        let (codes, uqv) = self.encode_inner(v);
        OpqQuantized { codes, uqv }
    }
    /// Precompute the ADC table: squared L2 distance from each subspace of
    /// the rotated query to every centroid of that subspace's codebook.
    fn prepare_query(&self, q: &[f32]) -> Self::Query {
        let rotated = self.apply_rotation(q);
        let mut table = vec![0.0f32; self.m * self.k];
        for s in 0..self.m {
            let offset = s * self.sub_dim;
            let sub_q = &rotated[offset..offset + self.sub_dim];
            for c in 0..self.k {
                // Table layout: entry for (subspace s, centroid c) at s*k + c.
                table[s * self.k + c] = l2_sq(sub_q, &self.codebooks[s][c]);
            }
        }
        OpqQuery {
            distance_table: table,
            rotated,
        }
    }
    /// Export the precomputed distance table as a generic ADC LUT.
    /// NOTE(review): assumes `AdcLut::new(m, k)` allocates exactly `m * k`
    /// table entries — `copy_from_slice` panics on a length mismatch.
    fn adc_lut(&self, q: &Self::Query) -> Option<AdcLut> {
        let mut lut = AdcLut::new(self.m as u16, self.k as u16);
        lut.table.copy_from_slice(&q.distance_table);
        Some(lut)
    }
    /// Symmetric distance by fully reconstructing both operands — exact for
    /// the quantized representatives, but O(dim) per call rather than a
    /// table lookup.
    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
        let qv = self.dequantize(&q.codes);
        let vv = self.dequantize(&v.codes);
        l2_sq(&qv, &vv)
    }
    /// Standard ADC distance: sum the table entry selected by each
    /// subspace code of the stored vector.
    fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
        v.codes
            .iter()
            .enumerate()
            .map(|(s, &code)| q.distance_table[s * self.k + code as usize])
            .sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten vectors in R^8 spread along a line with small fixed per-coordinate
    /// offsets, so nearest-neighbor structure is well separated.
    fn tiny_dataset() -> Vec<Vec<f32>> {
        (0..10)
            .map(|i| {
                let base = i as f32 * 2.0;
                vec![
                    base,
                    base + 0.1,
                    base - 0.1,
                    base + 0.2,
                    base * 0.5,
                    base * 0.5 + 0.1,
                    base * 0.5 - 0.1,
                    base * 0.5 + 0.05,
                ]
            })
            .collect()
    }

    /// Train a small codec on the tiny dataset:
    /// dim=8, m=2 subspaces, k=4 centroids, 10 OPQ iters, 30 k-means iters.
    fn train_tiny() -> OpqCodec {
        let vecs = tiny_dataset();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        OpqCodec::train(&refs, 8, 2, 4, 10, 30)
    }

    // Every encoded vector must carry exactly one code byte per subspace.
    #[test]
    fn encode_produces_m_bytes() {
        let codec = train_tiny();
        let vecs = tiny_dataset();
        for v in &vecs {
            let q = codec.encode(v);
            assert_eq!(q.codes.len(), codec.m);
        }
    }

    // Both distance paths are sums of squared L2 terms and can never be
    // negative, even for a vector compared against itself.
    #[test]
    fn distance_is_non_negative() {
        let codec = train_tiny();
        let vecs = tiny_dataset();
        for v in &vecs {
            let qv = codec.encode(v);
            let qq = codec.prepare_query(v);
            let asym = codec.exact_asymmetric_distance(&qq, &qv);
            let sym = codec.fast_symmetric_distance(&qv, &qv);
            assert!(
                asym >= 0.0,
                "asymmetric distance must be non-negative, got {asym}"
            );
            assert!(
                sym >= 0.0,
                "symmetric distance must be non-negative, got {sym}"
            );
        }
    }

    // Querying with a training vector should usually rank that same vector
    // first among all encoded training vectors; require >= 70% top-1 recall.
    #[test]
    fn top1_recall_on_training_set() {
        let vecs = tiny_dataset();
        let codec = train_tiny();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let encoded: Vec<_> = refs.iter().map(|v| codec.encode(v)).collect();
        let mut correct = 0usize;
        for (i, v) in refs.iter().enumerate() {
            let query = codec.prepare_query(v);
            let best = encoded
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    codec
                        .exact_asymmetric_distance(&query, a)
                        .partial_cmp(&codec.exact_asymmetric_distance(&query, b))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(idx, _)| idx)
                .unwrap_or(usize::MAX);
            if best == i {
                correct += 1;
            }
        }
        let recall = correct as f64 / vecs.len() as f64;
        assert!(
            recall >= 0.70,
            "top-1 recall on training set too low: {correct}/{} = {recall:.2}",
            vecs.len()
        );
    }

    // More OPQ alternations should not make reconstruction meaningfully
    // worse; the 5% tolerance absorbs k-means seeding noise.
    #[test]
    fn more_iterations_reduce_reconstruction_error() {
        let vecs = tiny_dataset();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec_1 = OpqCodec::train(&refs, 8, 2, 4, 1, 10);
        let codec_5 = OpqCodec::train(&refs, 8, 2, 4, 5, 10);
        // Mean squared reconstruction error measured in rotated space.
        let mean_recon_error = |codec: &OpqCodec| -> f32 {
            refs.iter()
                .map(|v| {
                    let rotated = codec.apply_rotation(v);
                    let codes = pq_encode(&rotated, &codec.codebooks, codec.m, codec.sub_dim);
                    let recon = dequantize_codes(&codes, &codec.codebooks);
                    l2_sq(&rotated, &recon)
                })
                .sum::<f32>()
                / refs.len() as f32
        };
        let err_1 = mean_recon_error(&codec_1);
        let err_5 = mean_recon_error(&codec_5);
        assert!(
            err_5 <= err_1 * 1.05,
            "5-iter OPQ (err={err_5:.4}) should have ≤ reconstruction error than 1-iter (err={err_1:.4})"
        );
    }
}