use serde::{Deserialize, Serialize};
/// Product-quantization (PQ) codec: compresses `dim`-dimensional `f32`
/// vectors into `m` one-byte centroid ids, one per contiguous subspace,
/// using a per-subspace codebook learned at training time.
#[derive(Clone, Serialize, Deserialize)]
pub struct PqCodec {
/// Dimensionality of the full (uncompressed) input vectors.
pub dim: usize,
/// Number of subspaces; also the number of bytes per encoded vector.
pub m: usize,
/// Number of centroids per subspace codebook.
pub k: usize,
/// Components per subspace (`dim / m`).
pub sub_dim: usize,
// codebooks[sub][centroid] is a `sub_dim`-length centroid vector.
codebooks: Vec<Vec<Vec<f32>>>,
}
impl PqCodec {
    /// Trains a product quantizer on `vectors`.
    ///
    /// Each `dim`-dimensional vector is split into `m` contiguous subspaces
    /// of `dim / m` components, and a codebook of up to `k` centroids is
    /// learned per subspace with k-means (at most `max_iter` Lloyd
    /// iterations).
    ///
    /// # Panics
    /// Panics if `vectors` is empty, if any of `dim`/`m`/`k` is zero, if
    /// `dim` is not divisible by `m`, or if `k > 256` (codes are stored as
    /// one `u8` per subspace, so larger codebooks cannot be addressed).
    pub fn train(vectors: &[&[f32]], dim: usize, m: usize, k: usize, max_iter: usize) -> Self {
        assert!(!vectors.is_empty());
        assert!(dim > 0 && m > 0 && k > 0);
        // encode() stores one byte per subspace; without this guard a k > 256
        // would silently truncate centroid ids in the `as u8` cast.
        assert!(k <= 256, "k ({k}) must be <= 256 to fit centroid ids in u8");
        assert!(
            dim % m == 0,
            "dim ({dim}) must be divisible by m ({m})"
        );
        // kmeans() caps the centroid count at the number of training points;
        // mirror that here so `self.k` matches the actual codebook size.
        let k = k.min(vectors.len());
        let sub_dim = dim / m;
        let mut codebooks = Vec::with_capacity(m);
        for sub in 0..m {
            let offset = sub * sub_dim;
            // Restrict every training vector to this subspace's components.
            let sub_vectors: Vec<&[f32]> = vectors
                .iter()
                .map(|v| &v[offset..offset + sub_dim])
                .collect();
            let centroids = kmeans(&sub_vectors, sub_dim, k, max_iter);
            codebooks.push(centroids);
        }
        Self {
            dim,
            m,
            k,
            sub_dim,
            codebooks,
        }
    }

    /// Quantizes `vector` into `m` centroid ids (one byte per subspace).
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), self.dim);
        let mut code = Vec::with_capacity(self.m);
        for sub in 0..self.m {
            let offset = sub * self.sub_dim;
            let sub_vec = &vector[offset..offset + self.sub_dim];
            let nearest = self.nearest_centroid(sub, sub_vec);
            // Cast is lossless: train() guarantees k <= 256, so nearest < 256.
            code.push(nearest as u8);
        }
        code
    }

    /// Encodes `vectors` back to back into one flat buffer of
    /// `m * vectors.len()` bytes (codes appear in input order).
    pub fn encode_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.m * vectors.len());
        for v in vectors {
            out.extend(self.encode(v));
        }
        out
    }

    /// Precomputes squared L2 distances from `query` to every centroid,
    /// per subspace. `table[sub][c]` is the distance from the query's
    /// `sub`-th slice to centroid `c`; feed it to [`Self::asymmetric_distance`].
    pub fn build_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
        debug_assert_eq!(query.len(), self.dim);
        let mut table = Vec::with_capacity(self.m);
        for sub in 0..self.m {
            let offset = sub * self.sub_dim;
            let sub_query = &query[offset..offset + self.sub_dim];
            let mut dists = Vec::with_capacity(self.k);
            for centroid in &self.codebooks[sub] {
                let d = l2_sub(sub_query, centroid);
                dists.push(d);
            }
            table.push(dists);
        }
        table
    }

    /// Approximate squared L2 distance between the query that produced
    /// `table` and the vector represented by `code`: one table lookup per
    /// subspace, summed.
    #[inline]
    pub fn asymmetric_distance(&self, table: &[Vec<f32>], code: &[u8]) -> f32 {
        debug_assert_eq!(code.len(), self.m);
        let mut dist = 0.0f32;
        for (sub, &c) in code.iter().enumerate() {
            dist += table[sub][c as usize];
        }
        dist
    }

    /// Reconstructs an approximate `dim`-length vector by concatenating the
    /// centroid selected by `code` in each subspace.
    pub fn decode(&self, code: &[u8]) -> Vec<f32> {
        debug_assert_eq!(code.len(), self.m);
        let mut out = Vec::with_capacity(self.dim);
        for (sub, &c) in code.iter().enumerate() {
            out.extend_from_slice(&self.codebooks[sub][c as usize]);
        }
        out
    }

    /// Index of the centroid in `subspace` closest (squared L2) to `sub_vec`.
    /// Ties resolve to the lowest index.
    fn nearest_centroid(&self, subspace: usize, sub_vec: &[f32]) -> usize {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;
        for (i, centroid) in self.codebooks[subspace].iter().enumerate() {
            let d = l2_sub(sub_vec, centroid);
            if d < best_dist {
                best_dist = d;
                best_idx = i;
            }
        }
        best_idx
    }
}
/// Squared L2 distance between `a` and `b`, driven by `a`'s length.
///
/// Panics (via slice indexing) if `b` is shorter than `a`.
#[inline]
fn l2_sub(a: &[f32], b: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for i in 0..a.len() {
        let d = a[i] - b[i];
        sum += d * d;
    }
    sum
}

/// Deterministic k-means (Lloyd's algorithm) over `data`, returning at most
/// `k` centroids of length `dim`.
///
/// Seeding is farthest-point: the first centroid is `data[0]`, and each
/// subsequent centroid is the point with the largest squared distance to its
/// nearest already-chosen centroid (a deterministic k-means++ variant).
/// `k` is clamped to the number of points. The loop stops early once no
/// assignment changes.
fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
    let n = data.len();
    if n == 0 || k == 0 {
        return Vec::new();
    }
    // Cannot have more centroids than seed points.
    let k = k.min(n);

    // --- Farthest-point seeding. ---
    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
    centroids.push(data[0].to_vec());
    // min_dists[i]: squared distance from point i to its nearest chosen centroid.
    let mut min_dists = vec![f32::MAX; n];
    for c in 1..k {
        // Fold the most recently added centroid into the running minima.
        // (Fix: this line previously contained the mojibake `¢roids[...]`,
        // a garbled `&centroids[...]`, and did not compile.)
        for (i, point) in data.iter().enumerate() {
            let d = l2_sub(point, &centroids[c - 1]);
            if d < min_dists[i] {
                min_dists[i] = d;
            }
        }
        let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
        if total < f64::EPSILON {
            // Every point coincides with an existing centroid; duplicate the
            // first point rather than divide by zero picking a farthest one.
            centroids.push(data[0].to_vec());
            continue;
        }
        // Choose the point farthest from all chosen centroids.
        let best_idx = min_dists
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        centroids.push(data[best_idx].to_vec());
    }

    // --- Lloyd iterations: reassign points, then recompute means. ---
    let mut assignments = vec![0usize; n];
    for _ in 0..max_iter {
        let mut changed = false;
        for (i, point) in data.iter().enumerate() {
            let mut best = 0;
            let mut best_d = f32::MAX;
            for (c, centroid) in centroids.iter().enumerate() {
                let d = l2_sub(point, centroid);
                if d < best_d {
                    best_d = d;
                    best = c;
                }
            }
            if assignments[i] != best {
                assignments[i] = best;
                changed = true;
            }
        }
        if !changed {
            // Converged: assignments (and therefore the means) are stable.
            break;
        }
        let mut sums = vec![vec![0.0f32; dim]; k];
        let mut counts = vec![0usize; k];
        for (i, point) in data.iter().enumerate() {
            let c = assignments[i];
            counts[c] += 1;
            for d in 0..dim {
                sums[c][d] += point[d];
            }
        }
        for c in 0..k {
            // Empty clusters keep their previous centroid.
            if counts[c] > 0 {
                for d in 0..dim {
                    centroids[c][d] = sums[c][d] / counts[c] as f32;
                }
            }
        }
    }
    centroids
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 200 4-d vectors grouped into four well-separated clusters
    /// centered at 0, 10, 20, and 30.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        let mut out = Vec::new();
        for cluster in 0..4 {
            let base = cluster as f32 * 10.0;
            for step in 0..50 {
                let t = step as f32;
                out.push(vec![
                    base + t * 0.1,
                    base + t * 0.05,
                    base - t * 0.1,
                    base + t * 0.02,
                ]);
            }
        }
        out
    }

    #[test]
    fn encode_decode_roundtrip() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);
        for v in &vecs {
            let code = codec.encode(v);
            assert_eq!(code.len(), 2);
            let decoded = codec.decode(&code);
            assert_eq!(decoded.len(), 4);
        }
    }

    #[test]
    fn distance_table_gives_correct_ordering() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);
        let codes: Vec<Vec<u8>> = vecs.iter().map(|v| codec.encode(v)).collect();

        let query = &[5.0, 5.0, 5.0, 5.0];
        let table = codec.build_distance_table(query);

        // Rank every vector by approximate (PQ) distance.
        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .map(|(i, c)| (i, codec.asymmetric_distance(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Rank the same vectors by exact squared L2 distance.
        let mut exact_dists: Vec<(usize, f32)> = vecs
            .iter()
            .enumerate()
            .map(|(i, v)| (i, l2_sub(query, v)))
            .collect();
        exact_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Most of the PQ top-5 should land inside the exact top-10.
        let pq_top: std::collections::HashSet<usize> =
            pq_dists[..5].iter().map(|x| x.0).collect();
        let exact_top: std::collections::HashSet<usize> =
            exact_dists[..10].iter().map(|x| x.0).collect();
        let overlap = pq_top.intersection(&exact_top).count();
        assert!(overlap >= 3, "PQ recall too low: {overlap}/5 in top-10");
    }

    #[test]
    fn batch_encode() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);
        let flat = codec.encode_batch(&refs);
        // 200 input vectors, 2 bytes of code each.
        assert_eq!(flat.len(), 2 * 200);
    }
}