use std::ops::Deref;
use serde::{Deserialize, Serialize};
/// Number of bytes in one product-quantization code.
pub const PQ_BYTES: usize = 32;
/// Number of sub-spaces a vector is split into; one code byte is produced per sub-space.
pub const NUM_SUBSPACES: usize = PQ_BYTES;
/// Maximum number of centroids trained per sub-space (256, so a centroid index fits in a `u8`).
pub const CENTROIDS_PER_SUBSPACE: usize = 256;
/// A quantized vector: one centroid index (`u8`) per sub-space.
pub type PqCode = [u8; PQ_BYTES];
/// Product-quantization codebook: one list of centroids per sub-space.
///
/// The field descriptions below reflect what `PqCodebook::train` produces; a
/// deserialized value is not validated here.
#[derive(Clone, Serialize, Deserialize)]
pub struct PqCodebook {
/// Length of the full (unsplit) vectors this codebook was trained for.
pub full_dim: usize,
/// Length of each sub-vector, i.e. `full_dim / NUM_SUBSPACES`.
pub sub_dim: usize,
/// Centroids indexed as `[sub-space][centroid][component]`.
/// After training there are `NUM_SUBSPACES` outer entries, each holding at most
/// `CENTROIDS_PER_SUBSPACE` centroids of `sub_dim` components.
pub centroids: Vec<Vec<Vec<f32>>>,
}
impl PqCodebook {
    /// Trains a codebook by running k-means separately on every sub-space.
    ///
    /// Each training vector is cut into `NUM_SUBSPACES` contiguous slices of
    /// `full_dim / NUM_SUBSPACES` components; the slices at position `m` of all
    /// vectors are clustered into at most `CENTROIDS_PER_SUBSPACE` centroids.
    ///
    /// * `vectors` - training set; every vector must have exactly `full_dim` components.
    /// * `full_dim` - dimensionality of the vectors; must be a multiple of `NUM_SUBSPACES`.
    /// * `max_iters` - upper bound on k-means iterations per sub-space.
    ///
    /// # Panics
    ///
    /// Panics if `vectors` is empty, if `full_dim` is not divisible by
    /// `NUM_SUBSPACES`, or if any training vector's length differs from `full_dim`.
    pub fn train(vectors: &[Vec<f32>], full_dim: usize, max_iters: usize) -> Self {
        assert!(!vectors.is_empty(), "need at least one training vector");
        assert_eq!(
            full_dim % NUM_SUBSPACES,
            0,
            "full_dim must be divisible by {NUM_SUBSPACES}"
        );
        // Reject ragged input up front. Without this check a short vector fails
        // later with an opaque slice-range panic and a long one is silently
        // truncated, producing a codebook that disagrees with `full_dim`.
        for (i, v) in vectors.iter().enumerate() {
            assert_eq!(
                v.len(),
                full_dim,
                "training vector {i} has length {}, expected {full_dim}",
                v.len()
            );
        }
        let sub_dim = full_dim / NUM_SUBSPACES;
        let centroids: Vec<Vec<Vec<f32>>> = (0..NUM_SUBSPACES)
            .map(|m| {
                let start = m * sub_dim;
                let slices: Vec<&[f32]> = vectors
                    .iter()
                    .map(|v| &v[start..start + sub_dim])
                    .collect();
                kmeans(&slices, CENTROIDS_PER_SUBSPACE, sub_dim, max_iters)
            })
            .collect();
        Self {
            full_dim,
            sub_dim,
            centroids,
        }
    }

    /// Quantizes `vector` into a [`PqCode`]: for each sub-space, the index of the
    /// centroid closest (squared Euclidean distance) to the matching sub-vector.
    ///
    /// The length check is a `debug_assert`, so it is only enforced in debug
    /// builds; in any build a vector shorter than `full_dim` panics on slicing.
    pub fn encode(&self, vector: &[f32]) -> PqCode {
        debug_assert_eq!(vector.len(), self.full_dim);
        let mut code = [0u8; PQ_BYTES];
        for (m, byte) in code.iter_mut().enumerate() {
            let start = m * self.sub_dim;
            let sub = &vector[start..start + self.sub_dim];
            // `train` never stores more than CENTROIDS_PER_SUBSPACE (256) centroids
            // per sub-space, so the index fits in a byte for trained codebooks.
            *byte = nearest_centroid(sub, &self.centroids[m]) as u8;
        }
        code
    }

    /// Precomputes, for one query, the dot product of every query sub-vector
    /// with every centroid of the matching sub-space.
    ///
    /// Entry `[m][c]` of the returned table is the dot product of the `m`-th
    /// query slice with centroid `c` of sub-space `m`. Cells whose centroid
    /// index does not exist in the codebook stay `0.0`.
    ///
    /// The length check is a `debug_assert`, so it is only enforced in debug
    /// builds; in any build a query shorter than `full_dim` panics on slicing.
    pub fn build_distance_table(&self, query: &[f32]) -> DistanceTable {
        debug_assert_eq!(query.len(), self.full_dim);
        let mut table = [[0.0f32; CENTROIDS_PER_SUBSPACE]; NUM_SUBSPACES];
        for (m, row) in table.iter_mut().enumerate() {
            let start = m * self.sub_dim;
            let q_sub = &query[start..start + self.sub_dim];
            // `zip` stops at the shorter side, so surplus cells keep their 0.0.
            for (cell, centroid) in row.iter_mut().zip(&self.centroids[m]) {
                *cell = dot(q_sub, centroid);
            }
        }
        DistanceTable(table)
    }
}
/// Per-query lookup table indexed as `[sub-space][centroid index]`.
///
/// As built by `PqCodebook::build_distance_table`, each cell holds the dot product
/// of the query's sub-vector with that centroid (despite the name, not a distance).
pub struct DistanceTable(pub [[f32; CENTROIDS_PER_SUBSPACE]; NUM_SUBSPACES]);
impl DistanceTable {
    /// Approximates the dot product between the table's query and an encoded
    /// vector by adding up, for every sub-space, the table cell selected by the
    /// corresponding code byte.
    #[inline]
    pub fn approximate_dot(&self, code: &PqCode) -> f32 {
        self.0
            .iter()
            .zip(code)
            .fold(0.0f32, |acc, (row, &centroid_idx)| {
                acc + row[usize::from(centroid_idx)]
            })
    }

    /// Returns exactly the same value as [`approximate_dot`](Self::approximate_dot).
    ///
    /// NOTE(review): this is only a cosine similarity if both the query and the
    /// encoded vector were unit length (e.g. passed through `l2_normalize`);
    /// nothing here enforces that - confirm against callers.
    #[inline]
    pub fn approximate_cosine(&self, code: &PqCode) -> f32 {
        self.approximate_dot(code)
    }
}
fn kmeans(data: &[&[f32]], k: usize, dim: usize, max_iters: usize) -> Vec<Vec<f32>> {
let k = k.min(data.len());
let mut centroids: Vec<Vec<f32>> = data.iter().take(k).map(|s| s.to_vec()).collect();
let mut assignments = vec![0usize; data.len()];
for _ in 0..max_iters {
let mut changed = false;
for (i, point) in data.iter().enumerate() {
let best = nearest_centroid(point, ¢roids);
if best != assignments[i] {
assignments[i] = best;
changed = true;
}
}
if !changed {
break;
}
let mut sums = vec![vec![0.0f32; dim]; k];
let mut counts = vec![0usize; k];
for (i, point) in data.iter().enumerate() {
let c = assignments[i];
counts[c] += 1;
for (j, &val) in point.iter().enumerate() {
sums[c][j] += val;
}
}
for c in 0..k {
if counts[c] > 0 {
let cnt = counts[c] as f32;
for j in 0..dim {
centroids[c][j] = sums[c][j] / cnt;
}
}
}
}
centroids
}
/// Returns the index of the centroid with the smallest squared Euclidean
/// distance to `point`.
///
/// Ties go to the lowest index. A centroid whose distance is NaN is never
/// selected over one with a comparable distance; if no distance is smaller than
/// `+inf` (all NaN or infinite) index `0` is returned.
///
/// # Panics
///
/// Panics if `centroids` is empty.
fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
    assert!(
        !centroids.is_empty(),
        "nearest_centroid requires at least one centroid"
    );
    let mut best_idx = 0;
    let mut best_dist = f32::INFINITY;
    for (idx, centroid) in centroids.iter().enumerate() {
        let dist = sq_dist(point, centroid.deref());
        // Strict `<` keeps the first minimum and is false for NaN, so a NaN
        // distance is skipped instead of aborting the search with a panic.
        if dist < best_dist {
            best_idx = idx;
            best_dist = dist;
        }
    }
    best_idx
}
/// Dot product of `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length the
/// extra tail of the longer one is ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length the
/// extra tail of the longer one is ignored.
#[inline]
fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(lhs, rhs)| {
        let diff = lhs - rhs;
        diff.powi(2)
    });
    squared_diffs.sum()
}
/// Scales `v` in place to unit Euclidean length.
///
/// The slice is left untouched when its norm is not greater than `1e-10`
/// (which includes the all-zero, empty and NaN-norm cases), so no division by
/// a vanishing norm takes place.
pub fn l2_normalize(v: &mut [f32]) {
    let squared_len: f32 = v.iter().map(|c| c * c).sum();
    let norm = squared_len.sqrt();
    if norm > 1e-10 {
        for component in v.iter_mut() {
            *component /= norm;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Builds `n` vectors of `dim` components drawn from `-1.0..1.0`, each
    /// L2-normalised.
    fn random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rng = rand::thread_rng();
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            let mut v: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
            l2_normalize(&mut v);
            out.push(v);
        }
        out
    }

    /// The table-based approximate similarity between a raw query and an encoded
    /// vector must stay within 0.35 of the exact dot product of the two vectors.
    #[test]
    fn encode_decode_roundtrip_preserves_similarity() {
        let dim = 128;
        let training_set = random_vectors(500, dim);
        let codebook = PqCodebook::train(&training_set, dim, 20);

        let (query, target) = (&training_set[0], &training_set[1]);
        let exact = dot(query, target);

        let table = codebook.build_distance_table(query);
        let target_code = codebook.encode(target);
        let approx = table.approximate_cosine(&target_code);

        let error = (exact - approx).abs();
        assert!(
            error < 0.35,
            "PQ approximation error {error:.4} too large (exact={exact:.4}, approx={approx:.4})"
        );
    }
}