use std::mem::size_of;
use std::sync::Arc;
use nodedb_mem::{EngineId, MemoryGovernor};
use serde::{Deserialize, Serialize};
use crate::error::VectorError;
/// Reserves `bytes` of vector-engine budget from `governor`, if one is attached.
///
/// With no governor this is a no-op that yields `Ok(None)`. Any error from
/// `MemoryGovernor::reserve` is converted into a [`VectorError`] via `?`.
/// Callers in this file bind the returned guard to a local for the duration of
/// the allocation it covers — presumably the reservation is released on drop.
#[inline]
fn try_reserve_or_skip(
    governor: &Option<Arc<MemoryGovernor>>,
    bytes: usize,
) -> Result<Option<nodedb_mem::BudgetGuard>, VectorError> {
    let Some(gov) = governor else {
        return Ok(None);
    };
    let guard = gov.reserve(EngineId::Vector, bytes)?;
    Ok(Some(guard))
}
/// Product-quantization (PQ) codec.
///
/// A `dim`-dimensional vector is split into `m` contiguous subvectors of `sub_dim`
/// floats. Each subvector is replaced by the index of its nearest centroid in that
/// subspace's codebook, so one encoded vector occupies `m` bytes.
///
/// NOTE(review): `zerompk` may encode fields positionally — confirm before
/// reordering or inserting fields, since that could change the persisted format.
#[derive(
    Clone, Debug, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct PqCodec {
    /// Full vector dimensionality; `train` requires it to be a multiple of `m`.
    pub dim: usize,
    /// Number of subspaces, i.e. bytes per encoded vector.
    pub m: usize,
    /// Centroid count requested at training time. The stored codebooks may hold
    /// fewer entries, because `kmeans` clamps it to the number of training vectors.
    pub k: usize,
    /// Floats per subspace (`dim / m` for a trained codec).
    pub sub_dim: usize,
    /// `codebooks[sub][centroid]` is one centroid (of `sub_dim` floats for a
    /// trained codec) belonging to subspace `sub`.
    codebooks: Vec<Vec<Vec<f32>>>,
    /// Optional memory governor consulted before the codec's larger allocations.
    /// It is skipped by both serde and msgpack serialization, so it is never
    /// persisted; serde restores it as `None` via `Default` (presumably the same
    /// for msgpack decoding — confirm).
    #[serde(skip, default)]
    #[msgpack(ignore)]
    governor: Option<Arc<MemoryGovernor>>,
}
impl PqCodec {
    /// Magic prefix that starts every serialized codec.
    const FORMAT_MAGIC: &'static [u8; 6] = b"NDPQ\0\0";
    /// Format version byte written immediately after [`Self::FORMAT_MAGIC`].
    const FORMAT_VERSION: u8 = 1;
    /// Bytes preceding the msgpack payload: the 6-byte magic plus the version byte.
    const HEADER_LEN: usize = 7;
    /// Codes store one `u8` per subspace, so at most 256 centroids are addressable.
    const MAX_CENTROIDS: usize = u8::MAX as usize + 1;

    /// Attaches a memory governor.
    ///
    /// `encode_batch`, `build_distance_table`, `decode` and `to_bytes` then reserve
    /// budget from it before allocating their output.
    pub fn with_governor(mut self, governor: Arc<MemoryGovernor>) -> Self {
        self.governor = Some(governor);
        self
    }

    /// Trains a codec by running k-means independently on each of the `m` subspaces.
    ///
    /// Every training vector is split into `m` contiguous chunks of `dim / m` floats.
    /// Each subspace is clustered into at most `k` centroids (fewer when there are
    /// fewer than `k` training vectors). The returned codec has no governor attached.
    ///
    /// # Panics
    ///
    /// Panics if `vectors` is empty, if `dim`, `m` or `k` is zero, if `k > 256`
    /// (centroid indices must fit in a `u8` code), if `dim` is not a multiple of
    /// `m`, or if any training vector is shorter than `dim`.
    pub fn train(vectors: &[&[f32]], dim: usize, m: usize, k: usize, max_iter: usize) -> Self {
        assert!(!vectors.is_empty());
        assert!(dim > 0 && m > 0 && k > 0);
        // Without this bound, `encode` would silently truncate centroid indices
        // >= 256 when narrowing them to `u8`, producing wrong codes.
        assert!(
            k <= Self::MAX_CENTROIDS,
            "k ({k}) must be at most {} so every centroid index fits in a u8 code",
            Self::MAX_CENTROIDS
        );
        assert!(
            dim.is_multiple_of(m),
            "dim ({dim}) must be divisible by m ({m})"
        );
        let sub_dim = dim / m;
        let codebooks: Vec<Vec<Vec<f32>>> = (0..m)
            .map(|sub| {
                let offset = sub * sub_dim;
                let sub_vectors: Vec<&[f32]> = vectors
                    .iter()
                    .map(|v| &v[offset..offset + sub_dim])
                    .collect();
                kmeans(&sub_vectors, sub_dim, k, max_iter)
            })
            .collect();
        Self {
            dim,
            m,
            k,
            sub_dim,
            codebooks,
            governor: None,
        }
    }

    /// Encodes `vector` as `m` bytes, each the index of the nearest centroid in
    /// the corresponding subspace.
    ///
    /// # Panics
    ///
    /// Panics if `vector` has fewer than `dim` elements (debug builds also assert
    /// that the length equals `dim` exactly).
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let mut code = Vec::with_capacity(self.m);
        self.encode_into(vector, &mut code);
        code
    }

    /// Appends the `m`-byte code for `vector` to `out` without allocating.
    fn encode_into(&self, vector: &[f32], out: &mut Vec<u8>) {
        debug_assert_eq!(vector.len(), self.dim);
        for sub in 0..self.m {
            let offset = sub * self.sub_dim;
            let sub_vec = &vector[offset..offset + self.sub_dim];
            // `train` caps each codebook at 256 centroids, so the index fits in a u8.
            out.push(self.nearest_centroid(sub, sub_vec) as u8);
        }
    }

    /// Encodes every vector in `vectors` and concatenates the codes (`m` bytes per
    /// vector, in input order) into a single buffer.
    ///
    /// # Errors
    ///
    /// Returns the governor's reservation error when a governor is attached and
    /// the output buffer does not fit the budget.
    pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<u8>, VectorError> {
        let capacity = self.m * vectors.len();
        let _guard = try_reserve_or_skip(&self.governor, capacity * size_of::<u8>())?;
        let mut out = Vec::with_capacity(capacity);
        for v in vectors {
            self.encode_into(v, &mut out);
        }
        Ok(out)
    }

    /// Precomputes, for each subspace, the squared L2 distance from the matching
    /// chunk of `query` to every centroid: `table[sub][centroid]`.
    ///
    /// The table is meant to be passed to [`Self::asymmetric_distance`].
    ///
    /// # Errors
    ///
    /// Returns the governor's reservation error when a governor is attached and
    /// the table does not fit the budget.
    pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, VectorError> {
        debug_assert_eq!(query.len(), self.dim);
        let total_bytes = self.m * self.k * size_of::<f32>();
        let _guard = try_reserve_or_skip(&self.governor, total_bytes)?;
        let table: Vec<Vec<f32>> = (0..self.m)
            .map(|sub| {
                let offset = sub * self.sub_dim;
                let sub_query = &query[offset..offset + self.sub_dim];
                self.codebooks[sub]
                    .iter()
                    .map(|centroid| l2_sub(sub_query, centroid))
                    .collect::<Vec<f32>>()
            })
            .collect();
        Ok(table)
    }

    /// Approximates the squared L2 distance between the query used to build
    /// `table` and the vector encoded by `code`, by summing one table entry per
    /// subspace.
    #[inline]
    pub fn asymmetric_distance(&self, table: &[Vec<f32>], code: &[u8]) -> f32 {
        debug_assert_eq!(code.len(), self.m);
        let mut dist = 0.0f32;
        for (sub, &c) in code.iter().enumerate() {
            dist += table[sub][c as usize];
        }
        dist
    }

    /// Reconstructs an approximate vector by concatenating the centroids selected
    /// by `code`.
    ///
    /// # Errors
    ///
    /// Returns the governor's reservation error when a governor is attached and
    /// the output vector does not fit the budget.
    ///
    /// # Panics
    ///
    /// Panics if `code` has more than `m` bytes or a byte indexes past its
    /// subspace's codebook.
    pub fn decode(&self, code: &[u8]) -> Result<Vec<f32>, VectorError> {
        debug_assert_eq!(code.len(), self.m);
        let _guard = try_reserve_or_skip(&self.governor, self.dim * size_of::<f32>())?;
        let mut out = Vec::with_capacity(self.dim);
        for (sub, &c) in code.iter().enumerate() {
            out.extend_from_slice(&self.codebooks[sub][c as usize]);
        }
        Ok(out)
    }

    /// Serializes the codec as `magic (6 bytes) | version (1 byte) | msgpack payload`.
    /// The governor is not part of the payload.
    ///
    /// # Errors
    ///
    /// Returns the governor's reservation error when a governor is attached and
    /// the estimated output size does not fit the budget.
    pub fn to_bytes(&self) -> Result<Vec<u8>, VectorError> {
        let estimated = self.m * self.k * self.sub_dim * size_of::<f32>() + 64;
        let _guard = try_reserve_or_skip(&self.governor, estimated)?;
        // NOTE(review): a msgpack encoding failure is swallowed here and yields a
        // header-only buffer, which presumably fails to decode in `from_bytes`;
        // consider surfacing it once `VectorError` has a serialization variant.
        let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
        let mut out = Vec::with_capacity(Self::HEADER_LEN + payload.len());
        out.extend_from_slice(Self::FORMAT_MAGIC);
        out.push(Self::FORMAT_VERSION);
        out.extend_from_slice(&payload);
        Ok(out)
    }

    /// Deserializes a codec produced by [`Self::to_bytes`]. The result has no
    /// governor attached.
    ///
    /// # Errors
    ///
    /// - [`VectorError::InvalidMagic`] if the input is shorter than the header or
    ///   does not start with the expected magic.
    /// - [`VectorError::UnsupportedVersion`] if the version byte is unknown.
    /// - [`VectorError::DeserializationFailed`] if the payload cannot be decoded,
    ///   or decodes to a codec whose dimensions and codebooks are inconsistent.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
        if bytes.len() < Self::HEADER_LEN || &bytes[0..6] != Self::FORMAT_MAGIC {
            return Err(VectorError::InvalidMagic);
        }
        let version = bytes[6];
        if version != Self::FORMAT_VERSION {
            return Err(VectorError::UnsupportedVersion {
                found: version,
                expected: Self::FORMAT_VERSION,
            });
        }
        let codec = zerompk::from_msgpack::<Self>(&bytes[Self::HEADER_LEN..])
            .map_err(|e| VectorError::DeserializationFailed(e.to_string()))?;
        codec.check_layout()?;
        Ok(codec)
    }

    /// Verifies the structural invariants that `encode`, `decode` and the distance
    /// helpers index on. This makes corrupted bytes fail at load time instead of
    /// panicking on first use.
    fn check_layout(&self) -> Result<(), VectorError> {
        if self.m.checked_mul(self.sub_dim) != Some(self.dim) {
            return Err(VectorError::DeserializationFailed(format!(
                "PQ codec layout mismatch: m ({}) * sub_dim ({}) != dim ({})",
                self.m, self.sub_dim, self.dim
            )));
        }
        if self.codebooks.len() != self.m {
            return Err(VectorError::DeserializationFailed(format!(
                "PQ codec has {} codebooks, expected m = {}",
                self.codebooks.len(),
                self.m
            )));
        }
        for (sub, codebook) in self.codebooks.iter().enumerate() {
            if codebook.is_empty() {
                return Err(VectorError::DeserializationFailed(format!(
                    "PQ codebook {sub} is empty"
                )));
            }
            if let Some(c) = codebook
                .iter()
                .position(|centroid| centroid.len() != self.sub_dim)
            {
                return Err(VectorError::DeserializationFailed(format!(
                    "PQ codebook {sub} centroid {c} has {} floats, expected sub_dim = {}",
                    codebook[c].len(),
                    self.sub_dim
                )));
            }
        }
        Ok(())
    }

    /// Index of the centroid in `subspace` closest (squared L2) to `sub_vec`.
    /// On ties the lowest index wins; if every distance is NaN, returns 0.
    fn nearest_centroid(&self, subspace: usize, sub_vec: &[f32]) -> usize {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;
        for (i, centroid) in self.codebooks[subspace].iter().enumerate() {
            let d = l2_sub(sub_vec, centroid);
            if d < best_dist {
                best_dist = d;
                best_idx = i;
            }
        }
        best_idx
    }
}
/// Squared Euclidean distance between `a` and the first `a.len()` elements of `b`.
///
/// No square root is taken; callers in this file only compare or accumulate
/// these values. Panics if `b` is shorter than `a`.
#[inline]
fn l2_sub(a: &[f32], b: &[f32]) -> f32 {
    // Slicing up front keeps the original out-of-bounds panic for a short `b`
    // and lets the loop below run without per-element bounds checks.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
/// Clusters `data` into at most `k` centroids of `dim` floats and returns them.
///
/// Uses k-means++ seeding followed by up to `max_iter` Lloyd iterations. `k` is
/// clamped to `data.len()`; empty input or `k == 0` yields no centroids. Every
/// point is assumed to hold at least `dim` floats (shorter points panic).
fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
    if data.is_empty() || k == 0 {
        return Vec::new();
    }
    let k = k.min(data.len());
    let mut centroids = kmeans_pp_seed(data, k);
    lloyd_refine(data, dim, &mut centroids, max_iter);
    centroids
}

/// Picks `k` initial centroids with k-means++: the first is `data[0]`, and each
/// further one is sampled with probability proportional to its squared distance
/// from the nearest centroid chosen so far.
fn kmeans_pp_seed(data: &[&[f32]], k: usize) -> Vec<Vec<f32>> {
    let n = data.len();
    // Fixed seed, so seeding is presumably reproducible across calls.
    let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
    let mut min_dists = vec![f32::MAX; n];
    centroids.push(data[0].to_vec());
    relax_min_dists(data, data[0], &mut min_dists);
    for _ in 1..k {
        let total: f64 = min_dists.iter().map(|&d| f64::from(d)).sum();
        let next_idx = if total < f64::EPSILON {
            // Every point already coincides with a centroid; nothing to sample.
            0
        } else {
            let target = rng.next_f64() * total;
            let mut acc = 0.0f64;
            min_dists
                .iter()
                .position(|&d| {
                    acc += f64::from(d);
                    acc >= target
                })
                // Rounding can leave `acc` just short of `target`; fall back to the last point.
                .unwrap_or(n - 1)
        };
        let chosen = data[next_idx];
        centroids.push(chosen.to_vec());
        relax_min_dists(data, chosen, &mut min_dists);
    }
    centroids
}

/// Lowers each `min_dists[i]` to the squared distance from `data[i]` to
/// `centroid` whenever that distance is smaller.
fn relax_min_dists(data: &[&[f32]], centroid: &[f32], min_dists: &mut [f32]) {
    for (point, best) in data.iter().zip(min_dists.iter_mut()) {
        let d = l2_sub(point, centroid);
        if d < *best {
            *best = d;
        }
    }
}

/// Runs Lloyd iterations in place. Each pass reassigns every point to its nearest
/// centroid, then moves each non-empty centroid to the mean of its members. It
/// stops early once an assignment pass changes nothing.
fn lloyd_refine(data: &[&[f32]], dim: usize, centroids: &mut [Vec<f32>], max_iter: usize) {
    let k = centroids.len();
    let mut assignments = vec![0usize; data.len()];
    // Accumulators are allocated once and cleared per iteration.
    let mut sums = vec![vec![0.0f32; dim]; k];
    let mut counts = vec![0usize; k];
    for iter in 0..max_iter {
        // The zero-filled `assignments` is not a real assignment, so the first pass
        // must always recompute the means. Otherwise a pass where every point lands
        // in cluster 0 (always true for k == 1) would return the seed point instead
        // of the cluster mean.
        let mut changed = iter == 0;
        for (point, slot) in data.iter().zip(assignments.iter_mut()) {
            let mut best = 0;
            let mut best_d = f32::MAX;
            for (c, centroid) in centroids.iter().enumerate() {
                let d = l2_sub(point, centroid);
                if d < best_d {
                    best_d = d;
                    best = c;
                }
            }
            if *slot != best {
                *slot = best;
                changed = true;
            }
        }
        if !changed {
            break;
        }
        for sum in &mut sums {
            sum.fill(0.0);
        }
        counts.fill(0);
        for (point, &c) in data.iter().zip(&assignments) {
            counts[c] += 1;
            for (acc, &x) in sums[c].iter_mut().zip(&point[..dim]) {
                *acc += x;
            }
        }
        for ((centroid, sum), &count) in centroids.iter_mut().zip(&sums).zip(&counts) {
            // Empty clusters keep their previous position.
            if count > 0 {
                let count = count as f32;
                for (dst, &s) in centroid.iter_mut().zip(sum) {
                    *dst = s / count;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 200 four-dimensional vectors in four groups of 50, offset from the
    /// centres 0, 10, 20 and 30.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        let mut vecs = Vec::new();
        for cluster in 0..4 {
            let center = cluster as f32 * 10.0;
            for i in 0..50 {
                vecs.push(vec![
                    center + (i as f32) * 0.1,
                    center + (i as f32) * 0.05,
                    center - (i as f32) * 0.1,
                    center + (i as f32) * 0.02,
                ]);
            }
        }
        vecs
    }

    /// Trains the configuration shared by these tests: dim 4, m 2, k 16, 10 iterations.
    fn train_codec(vecs: &[Vec<f32>]) -> PqCodec {
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        PqCodec::train(&refs, 4, 2, 16, 10)
    }

    #[test]
    fn encode_decode_roundtrip() {
        let vecs = make_clustered_data();
        let codec = train_codec(&vecs);
        for v in &vecs {
            let code = codec.encode(v);
            assert_eq!(code.len(), 2);
            let decoded = codec.decode(&code).unwrap();
            assert_eq!(decoded.len(), 4);
        }
    }

    #[test]
    fn distance_table_gives_correct_ordering() {
        let vecs = make_clustered_data();
        let codec = train_codec(&vecs);
        let codes: Vec<Vec<u8>> = vecs.iter().map(|v| codec.encode(v)).collect();
        let query = &[5.0, 5.0, 5.0, 5.0];
        let table = codec.build_distance_table(query).unwrap();

        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .map(|(i, c)| (i, codec.asymmetric_distance(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        let mut exact_dists: Vec<(usize, f32)> = vecs
            .iter()
            .enumerate()
            .map(|(i, v)| (i, l2_sub(query, v)))
            .collect();
        exact_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        // The PQ top-5 should mostly fall inside the exact top-10.
        let pq_top: std::collections::HashSet<usize> = pq_dists[..5].iter().map(|x| x.0).collect();
        let exact_top: std::collections::HashSet<usize> =
            exact_dists[..10].iter().map(|x| x.0).collect();
        let overlap = pq_top.intersection(&exact_top).count();
        assert!(overlap >= 3, "PQ recall too low: {overlap}/5 in top-10");
    }

    #[test]
    fn batch_encode() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);
        let batch = codec.encode_batch(&refs).unwrap();
        // m = 2 bytes per vector, 200 vectors.
        assert_eq!(batch.len(), 2 * 200);
    }

    #[test]
    fn pq_codec_golden_format() {
        let vecs = make_clustered_data();
        let codec = train_codec(&vecs);
        let bytes = codec.to_bytes().unwrap();
        assert_eq!(&bytes[0..6], b"NDPQ\0\0", "magic mismatch");
        assert_eq!(bytes[6], 1u8, "version must be 1");
        let restored = zerompk::from_msgpack::<PqCodec>(&bytes[7..])
            .expect("msgpack payload at offset 7 must decode");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.m, codec.m);
    }

    #[test]
    fn to_bytes_from_bytes_roundtrip() {
        let vecs = make_clustered_data();
        let codec = train_codec(&vecs);
        let restored = PqCodec::from_bytes(&codec.to_bytes().unwrap()).unwrap();
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.m, codec.m);
        assert_eq!(restored.k, codec.k);
        assert_eq!(restored.sub_dim, codec.sub_dim);
        // Identical codebooks must produce identical codes.
        for v in &vecs {
            assert_eq!(restored.encode(v), codec.encode(v));
        }
    }

    #[test]
    fn pq_version_mismatch_returns_error() {
        let mut crafted = b"NDPQ\0\0".to_vec();
        crafted.push(0u8);
        crafted.extend_from_slice(b"\x80");
        let err = PqCodec::from_bytes(&crafted).unwrap_err();
        assert!(
            matches!(
                err,
                VectorError::UnsupportedVersion {
                    found: 0,
                    expected: 1
                }
            ),
            "expected UnsupportedVersion, got: {err:?}"
        );
    }

    #[test]
    fn pq_invalid_magic_returns_error() {
        let bad: &[u8] = b"JUNK\0\0\x01some-payload";
        let err = PqCodec::from_bytes(bad).unwrap_err();
        assert!(
            matches!(err, VectorError::InvalidMagic),
            "expected InvalidMagic, got: {err:?}"
        );
    }

    #[test]
    fn pq_truncated_header_returns_invalid_magic() {
        // Correct magic but no version byte: shorter than the 7-byte header.
        let err = PqCodec::from_bytes(b"NDPQ\0\0").unwrap_err();
        assert!(
            matches!(err, VectorError::InvalidMagic),
            "expected InvalidMagic, got: {err:?}"
        );
    }
}