use nodedb_codec::vector_quant::bbq::BbqCodec;
use nodedb_codec::vector_quant::codec::VectorCodec as _;
use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
use crate::{
rerank::codec::{CodecName, PreparedQuery, RerankCodec},
rerank::types::RerankError,
};
/// Serializes a prepared query into one flat little-endian byte payload.
///
/// Layout: the 4 bytes of `query_norm` come first, followed by 4 bytes for
/// each element of `centered` in order, giving `4 + 4 * centered.len()` bytes.
fn encode_payload(query_norm: f32, centered: &[f32]) -> Vec<u8> {
    let mut payload = Vec::with_capacity(4 + centered.len() * 4);
    payload.extend_from_slice(&query_norm.to_le_bytes());
    // Each f32 contributes its 4 little-endian bytes, in input order.
    payload.extend(centered.iter().flat_map(|value| value.to_le_bytes()));
    payload
}
/// Parses a byte payload back into `(query_norm, centered)`.
///
/// The payload must be exactly `4 + 4 * dim` bytes: a little-endian `f32`
/// norm followed by `dim` little-endian `f32` components.
///
/// # Errors
/// Returns `RerankError::BadInput` when the payload length does not match
/// the length implied by `dim`.
fn decode_payload(payload: &[u8], dim: usize) -> Result<(f32, Vec<f32>), RerankError> {
    let expected = 4 + dim * 4;
    let actual = payload.len();
    if actual != expected {
        return Err(RerankError::BadInput(format!(
            "bbq distance: payload len {} != expected {} for dim {}",
            actual, expected, dim
        )));
    }

    // The length check above guarantees at least 4 bytes, so this split cannot panic.
    let (norm_bytes, body) = payload.split_at(4);
    let query_norm = f32::from_le_bytes([
        norm_bytes[0],
        norm_bytes[1],
        norm_bytes[2],
        norm_bytes[3],
    ]);

    // `body` is exactly `dim * 4` bytes long, so `chunks_exact` leaves no remainder.
    let mut centered = Vec::with_capacity(dim);
    for word in body.chunks_exact(4) {
        centered.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }

    Ok((query_norm, centered))
}
/// Expands packed sign bits into a dense vector whose entries are `+scale`
/// or `-scale`, with `scale = residual_norm / sqrt(dim)` (or `0.0` when
/// `dim` is zero).
///
/// Bits are read most-significant-bit first within each byte: a set bit
/// yields `+scale`, a clear bit yields `-scale`.
///
/// # Panics
/// Panics if `packed` holds fewer than `ceil(dim / 8)` bytes.
#[inline]
fn bbq_dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
    let magnitude = match dim {
        0 => 0.0,
        n => residual_norm / (n as f32).sqrt(),
    };

    let mut recon = Vec::with_capacity(dim);
    for idx in 0..dim {
        let byte = packed[idx / 8];
        // 0x80 selects the first (most significant) bit; shift right for later ones.
        let mask = 0x80u8 >> (idx % 8);
        let value = if byte & mask != 0 {
            magnitude
        } else {
            -magnitude
        };
        recon.push(value);
    }
    recon
}
/// Oversample value that `BbqRerank::from_codec` assigns to the wrapper it
/// builds around an already-constructed codec.
// NOTE(review): the exact meaning of "oversample" is defined by
// `BbqCodec::calibrate`, which is not visible in this file — confirm there.
pub const DEFAULT_OVERSAMPLE: u8 = 4;
/// Rerank codec that wraps a `BbqCodec`.
///
/// The inner codec is absent until `train` succeeds (or the value is built
/// with `from_codec`); until then, operations that need it return
/// `RerankError::NotTrained`.
pub struct BbqRerank {
/// Underlying codec; `None` while untrained.
codec: Option<BbqCodec>,
/// Length every input vector, query, and training sample must have.
dim: usize,
/// Value forwarded to `BbqCodec::calibrate` when training.
oversample: u8,
}
impl BbqRerank {
    /// Creates an untrained reranker for vectors of length `dim`.
    ///
    /// `oversample` is stored and later handed to `BbqCodec::calibrate` by
    /// `train`. No codec exists until training succeeds.
    pub fn new(dim: usize, oversample: u8) -> Self {
        let codec = None;
        Self {
            codec,
            dim,
            oversample,
        }
    }

    /// Wraps an existing codec, taking the dimension from `codec.dim` and
    /// using [`DEFAULT_OVERSAMPLE`] as the oversample value.
    pub fn from_codec(codec: BbqCodec) -> Self {
        Self {
            // Read the dimension before `codec` is moved into the struct.
            dim: codec.dim,
            oversample: DEFAULT_OVERSAMPLE,
            codec: Some(codec),
        }
    }
}
impl RerankCodec for BbqRerank {
    /// Quantizes `v` with the trained codec and returns the encoded bytes.
    ///
    /// # Errors
    /// `RerankError::BadInput` if `v.len()` differs from the configured
    /// dimension; `RerankError::NotTrained` if no codec is available yet.
    /// The dimension check runs first.
    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
        let got = v.len();
        if got != self.dim {
            return Err(RerankError::BadInput(format!(
                "bbq encode: vector len {} != codec dim {}",
                got, self.dim
            )));
        }
        let Some(codec) = self.codec.as_ref() else {
            return Err(RerankError::NotTrained(
                "bbq: codec must be trained before encoding (call train() with a sample of vectors)"
                    .to_string(),
            ));
        };

        let quantized = codec.encode(v);
        Ok(quantized.as_ref().as_bytes().to_vec())
    }

    /// Prepares `q` with the trained codec and packs the resulting norm and
    /// centered vector into a `PreparedQuery::Bytes` payload.
    ///
    /// # Errors
    /// `RerankError::BadInput` if `q.len()` differs from the configured
    /// dimension; `RerankError::NotTrained` if no codec is available yet.
    /// The dimension check runs first.
    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
        let got = q.len();
        if got != self.dim {
            return Err(RerankError::BadInput(format!(
                "bbq prepare_query: query len {} != codec dim {}",
                got, self.dim
            )));
        }
        let Some(codec) = self.codec.as_ref() else {
            return Err(RerankError::NotTrained(
                "bbq: codec must be trained before prepare_query (call train() with a sample of vectors)"
                    .to_string(),
            ));
        };

        let prepared = codec.prepare_query(q);
        let payload = encode_payload(prepared.query_norm, &prepared.centered);
        Ok(PreparedQuery::Bytes(payload))
    }

    /// Computes the Euclidean distance between the centered query carried in
    /// `prepared` and the sign-based reconstruction of `encoded`.
    ///
    /// The query norm stored in the payload is decoded but not used here.
    ///
    /// # Errors
    /// `RerankError::BadInput` if `prepared` is not the `Bytes` variant, if
    /// its payload has the wrong length for this dimension, or if `encoded`
    /// cannot be parsed as a quantized vector.
    fn distance_prepared(
        &self,
        prepared: &PreparedQuery,
        encoded: &[u8],
    ) -> Result<f32, RerankError> {
        let PreparedQuery::Bytes(raw) = prepared else {
            return Err(RerankError::BadInput(
                "bbq distance: prepared query is not Bytes".to_string(),
            ));
        };
        let (_query_norm, centered) = decode_payload(raw.as_slice(), self.dim)?;

        // One sign bit per dimension, rounded up to whole bytes.
        let packed_len = self.dim.div_ceil(8);
        let view = match UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len) {
            Ok(view) => view,
            Err(e) => {
                return Err(RerankError::BadInput(format!(
                    "bbq distance: failed to parse encoded bytes: {e}"
                )));
            }
        };

        let residual_norm = view.header().residual_norm;
        let recon = bbq_dequantize(view.packed_bits(), residual_norm, self.dim);

        let squared: f32 = centered
            .iter()
            .zip(recon.iter())
            .map(|(&q, &r)| {
                let diff = q - r;
                diff * diff
            })
            .sum();
        Ok(squared.sqrt())
    }

    /// Identifies this codec as `CodecName::Bbq`.
    fn name(&self) -> CodecName {
        CodecName::Bbq
    }

    /// Serializes the trained codec via its own `to_bytes`.
    ///
    /// # Errors
    /// `RerankError::NotTrained` if no codec is available;
    /// `RerankError::BadInput` if the codec's serialization fails.
    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
        let Some(codec) = self.codec.as_ref() else {
            return Err(RerankError::NotTrained(
                "bbq sidecar serialize: codec not trained".to_string(),
            ));
        };
        codec
            .to_bytes()
            .map_err(|e| RerankError::BadInput(format!("bbq to_bytes: {e}")))
    }

    /// Calibrates a fresh codec from `samples` and installs it, replacing
    /// any codec held previously.
    ///
    /// # Errors
    /// `RerankError::BadInput` if `samples` is empty or if any sample's
    /// length differs from the configured dimension (the first offending
    /// sample is reported). The existing codec is left untouched on error.
    fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
        if samples.is_empty() {
            return Err(RerankError::BadInput(
                "bbq train: empty sample set".to_string(),
            ));
        }
        if let Some(bad) = samples.iter().find(|s| s.len() != self.dim) {
            return Err(RerankError::BadInput(format!(
                "bbq train: sample has len {} but codec dim is {}",
                bad.len(),
                self.dim
            )));
        }

        self.codec = Some(BbqCodec::calibrate(samples, self.dim, self.oversample));
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensionality shared by every test vector.
    const DIM: usize = 16;
    /// Number of vectors in the training sample.
    const N: usize = 64;

    /// Deterministic test vector: component `j` is `((i * 31 + j) % 100) / 100`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// Builds a reranker trained on `N` deterministic vectors of length `DIM`.
    fn trained() -> BbqRerank {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let prep = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec.distance_prepared(&prep, &enc).expect("distance");
        assert!(dist.is_finite(), "distance must be finite, got {dist}");
        assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let v = det_vec(0, DIM);
        let err = codec.encode(&v).unwrap_err();
        let msg = format!("{err}");
        // "trained" also covers any "not trained" wording.
        assert!(
            msg.contains("trained"),
            "expected 'trained' in error, got: {msg}"
        );
    }

    #[test]
    fn encode_wrong_dim_fails() {
        // The dimension check precedes the trained check, so an untrained
        // codec is enough to exercise it.
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let bad = det_vec(0, DIM + 1);
        let err = codec.encode(&bad).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("bad input") || msg.contains("dim"),
            "expected bad input error, got: {msg}"
        );
    }

    #[test]
    fn to_bytes_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.to_bytes().unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("trained"),
            "expected 'trained' in error, got: {msg}"
        );
    }

    #[test]
    fn train_with_empty_samples_fails() {
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.train(&[]).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("bad input") || msg.contains("empty"),
            "expected bad input error, got: {msg}"
        );
    }

    #[test]
    fn train_with_dim_mismatch_fails() {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let mut refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let bad = det_vec(0, DIM + 4);
        refs.push(bad.as_slice());
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.train(&refs).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("bad input") || msg.contains("dim"),
            "expected bad input error, got: {msg}"
        );
    }

    #[test]
    fn prepare_query_wrong_dim_fails() {
        let codec = trained();
        let bad = det_vec(0, DIM + 2);
        // `match` instead of `unwrap_err` so the Ok type need not be `Debug`.
        match codec.prepare_query(&bad) {
            Err(e) => {
                let msg = format!("{e}");
                assert!(
                    msg.contains("bad input") || msg.contains("dim"),
                    "expected bad input error, got: {msg}"
                );
            }
            Ok(_) => panic!("expected an error for wrong dim"),
        }
    }

    #[test]
    fn distance_prepared_wrong_variant_fails() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let bad_prepared = PreparedQuery::Raw(vec![0.0f32; DIM]);
        let err = codec.distance_prepared(&bad_prepared, &enc).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("Bytes"),
            "error message should mention Bytes variant, got: {msg}"
        );
    }

    #[test]
    fn distance_prepared_bad_payload_len_fails() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        // A valid payload is 4 + 4 * DIM bytes; 3 bytes can never match.
        let truncated = PreparedQuery::Bytes(vec![0u8; 3]);
        let err = codec.distance_prepared(&truncated, &enc).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("bad input") || msg.contains("payload"),
            "expected bad input error, got: {msg}"
        );
    }

    #[test]
    fn name_is_expected() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        assert_eq!(codec.name(), CodecName::Bbq);
    }
}