use nodedb_codec::vector_quant::{
codec::{AdcLut, VectorCodec},
layout::{QuantHeader, QuantMode, UnifiedQuantizedVector},
};
use crate::quantize::pq::PqCodec;
pub struct PqQuantized(pub UnifiedQuantizedVector);
impl AsRef<UnifiedQuantizedVector> for PqQuantized {
#[inline]
fn as_ref(&self) -> &UnifiedQuantizedVector {
&self.0
}
}
pub struct PqQuery {
pub distance_table: Vec<Vec<f32>>,
pub raw: Vec<f32>,
}
#[inline]
fn packed_bits_of(q: &PqQuantized) -> &[u8] {
q.0.packed_bits()
}
impl VectorCodec for PqCodec {
type Quantized = PqQuantized;
type Query = PqQuery;
fn encode(&self, v: &[f32]) -> Self::Quantized {
let codes = self.encode(v);
let header = QuantHeader {
quant_mode: QuantMode::Pq as u16,
dim: self.dim as u16,
global_scale: 0.0,
residual_norm: 0.0,
dot_quantized: 0.0,
outlier_bitmask: 0,
reserved: [0; 8],
};
let uqv = UnifiedQuantizedVector::new(header, &codes, &[])
.expect("PqCodec::encode: layout construction is infallible (no outliers)");
PqQuantized(uqv)
}
fn prepare_query(&self, q: &[f32]) -> Self::Query {
let distance_table = self.build_distance_table(q).expect(
"PqCodec::prepare_query: build_distance_table failed; \
if a governor is attached ensure it has sufficient budget",
);
PqQuery {
distance_table,
raw: q.to_vec(),
}
}
fn adc_lut(&self, q: &Self::Query) -> Option<AdcLut> {
let m = self.m as u16;
let k = self.k as u16;
let mut lut = AdcLut::new(m, k);
for (sub, sub_table) in q.distance_table.iter().enumerate() {
for (centroid, &dist) in sub_table.iter().enumerate() {
let idx = sub * self.k + centroid;
lut.table[idx] = dist;
}
}
Some(lut)
}
#[inline]
fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
let dq_a = self
.decode(packed_bits_of(q))
.expect("PqCodec::fast_symmetric_distance: decode failed");
let dq_b = self
.decode(packed_bits_of(v))
.expect("PqCodec::fast_symmetric_distance: decode failed");
dq_a.iter()
.zip(dq_b.iter())
.map(|(&a, &b)| {
let d = a - b;
d * d
})
.sum()
}
#[inline]
fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
self.asymmetric_distance(&q.distance_table, packed_bits_of(v))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_codec() -> PqCodec {
let vecs: Vec<Vec<f32>> = (0..80)
.map(|i| {
let c = (i / 20) as f32 * 5.0;
vec![
c + (i as f32) * 0.1,
c - (i as f32) * 0.05,
c + 1.0,
c - 1.0,
]
})
.collect();
let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
PqCodec::train(&refs, 4, 2, 8, 10)
}
#[test]
fn encode_packed_bits_matches_raw_encode() {
let codec = make_codec();
let v = vec![2.0f32, 1.0, 3.0, -1.0];
let raw = codec.encode(&v);
let quantized = <PqCodec as VectorCodec>::encode(&codec, &v);
assert_eq!(quantized.as_ref().packed_bits(), raw.as_slice());
}
#[test]
fn fast_symmetric_distance_is_non_negative_finite() {
let codec = make_codec();
let a = <PqCodec as VectorCodec>::encode(&codec, &[0.5, 0.1, 1.0, -0.5]);
let b = <PqCodec as VectorCodec>::encode(&codec, &[5.0, 4.0, 6.0, 4.5]);
let d = codec.fast_symmetric_distance(&a, &b);
assert!(d.is_finite(), "expected finite distance, got {d}");
assert!(d >= 0.0, "expected non-negative distance, got {d}");
}
#[test]
fn exact_asymmetric_distance_is_non_negative_finite() {
let codec = make_codec();
let q = codec.prepare_query(&[0.5, 0.1, 1.0, -0.5]);
let v = <PqCodec as VectorCodec>::encode(&codec, &[5.0, 4.0, 6.0, 4.5]);
let d = codec.exact_asymmetric_distance(&q, &v);
assert!(d.is_finite(), "expected finite distance, got {d}");
assert!(d >= 0.0, "expected non-negative distance, got {d}");
}
#[test]
fn adc_lut_has_correct_shape() {
let codec = make_codec();
let q = codec.prepare_query(&[0.5, 0.1, 1.0, -0.5]);
let lut =
<PqCodec as VectorCodec>::adc_lut(&codec, &q).expect("PqCodec must produce an AdcLut");
assert_eq!(lut.subspace_count, codec.m as u16);
assert_eq!(lut.centroids_per_subspace, codec.k as u16);
assert_eq!(lut.table.len(), codec.m * codec.k);
assert!(lut.table.iter().all(|v| v.is_finite()));
}
fn use_vector_codec<C: VectorCodec>(c: &C, q: &[f32], v: &[f32]) -> f32 {
let qv = c.encode(v);
let qq = c.prepare_query(q);
c.fast_symmetric_distance(&qv, &qv) + c.exact_asymmetric_distance(&qq, &qv)
}
#[test]
fn trait_bounds_compile() {
let codec = make_codec();
let result = use_vector_codec(&codec, &[0.5, 0.1, 1.0, -0.5], &[5.0, 4.0, 6.0, 4.5]);
assert!(result.is_finite());
}
}