nodedb_vector/rerank/codecs/
bbq.rs1use nodedb_codec::vector_quant::bbq::BbqCodec;
19use nodedb_codec::vector_quant::codec::VectorCodec as _;
20use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
21
22use crate::{
23 rerank::codec::{CodecName, PreparedQuery, RerankCodec},
24 rerank::types::RerankError,
25};
26
/// Serializes a prepared BBQ query into the byte payload carried by
/// `PreparedQuery::Bytes`.
///
/// Layout: `query_norm` as a little-endian `f32`, followed by every element
/// of `centered` as a little-endian `f32`. The result is always
/// `4 + 4 * centered.len()` bytes long.
fn encode_payload(query_norm: f32, centered: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 + 4 * centered.len());
    out.extend_from_slice(&query_norm.to_le_bytes());
    out.extend(centered.iter().flat_map(|value| value.to_le_bytes()));
    out
}
38
39fn decode_payload(payload: &[u8], dim: usize) -> Result<(f32, Vec<f32>), RerankError> {
40 let expected = 4 + dim * 4;
41 if payload.len() != expected {
42 return Err(RerankError::BadInput(format!(
43 "bbq distance: payload len {} != expected {} for dim {}",
44 payload.len(),
45 expected,
46 dim
47 )));
48 }
49 let query_norm = f32::from_le_bytes(
50 payload[..4]
51 .try_into()
52 .expect("slice of 4 bytes always converts to [u8;4]"),
53 );
54 let centered: Vec<f32> = payload[4..]
55 .chunks_exact(4)
56 .map(|b| f32::from_le_bytes(b.try_into().expect("chunks_exact(4) always 4 bytes")))
57 .collect();
58 Ok((query_norm, centered))
59}
60
/// Reconstructs an approximate centered vector from packed BBQ sign bits.
///
/// Bit `i` is taken MSB-first from byte `packed[i / 8]`. A set bit becomes
/// `+scale` and a clear bit becomes `-scale`, with
/// `scale = residual_norm / sqrt(dim)`. For `dim > 0` the result therefore has
/// Euclidean norm `|residual_norm|`. Returns an empty vector when `dim == 0`.
///
/// # Panics
///
/// Panics if `packed` is shorter than `dim.div_ceil(8)` bytes. The check runs
/// once, up front, with a descriptive message. It replaces an opaque
/// index-out-of-bounds panic partway through the loop.
#[inline]
fn bbq_dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
    let needed = dim.div_ceil(8);
    assert!(
        packed.len() >= needed,
        "bbq_dequantize: packed bits have {} bytes but dim {} needs {}",
        packed.len(),
        dim,
        needed
    );
    if dim == 0 {
        return Vec::new();
    }
    let scale = residual_norm / (dim as f32).sqrt();
    // Walk each byte from its most significant bit down, so that element `i`
    // equals `(packed[i / 8] >> (7 - i % 8)) & 1`. `take(dim)` drops the
    // padding bits of the final byte. Slicing to `needed` up front lets the
    // iterator run without per-element bounds checks.
    let bits = packed[..needed]
        .iter()
        .flat_map(|&byte| (0..8u32).rev().map(move |shift| (byte >> shift) & 1));
    let mut out = Vec::with_capacity(dim);
    out.extend(bits.take(dim).map(|bit| if bit != 0 { scale } else { -scale }));
    out
}
82
/// Oversample value assigned by [`BbqRerank::from_codec`].
///
/// The reranker forwards its oversample value to `BbqCodec::calibrate`
/// whenever `train` is called.
pub const DEFAULT_OVERSAMPLE: u8 = 4;
87
88pub struct BbqRerank {
97 codec: Option<BbqCodec>,
98 dim: usize,
99 oversample: u8,
100}
101
102impl BbqRerank {
103 pub fn new(dim: usize, oversample: u8) -> Self {
108 Self {
109 codec: None,
110 dim,
111 oversample,
112 }
113 }
114
115 pub fn from_codec(codec: BbqCodec) -> Self {
117 let dim = codec.dim;
118 Self {
119 codec: Some(codec),
120 dim,
121 oversample: DEFAULT_OVERSAMPLE,
122 }
123 }
124}
125
126impl RerankCodec for BbqRerank {
127 fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
133 if v.len() != self.dim {
134 return Err(RerankError::BadInput(format!(
135 "bbq encode: vector len {} != codec dim {}",
136 v.len(),
137 self.dim
138 )));
139 }
140 let codec = self.codec.as_ref().ok_or_else(|| {
141 RerankError::NotTrained(
142 "bbq: codec must be trained before encoding (call train() with a sample of vectors)"
143 .to_string(),
144 )
145 })?;
146 let quantized = codec.encode(v);
147 Ok(quantized.as_ref().as_bytes().to_vec())
148 }
149
150 fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
156 if q.len() != self.dim {
157 return Err(RerankError::BadInput(format!(
158 "bbq prepare_query: query len {} != codec dim {}",
159 q.len(),
160 self.dim
161 )));
162 }
163 let codec = self.codec.as_ref().ok_or_else(|| {
164 RerankError::NotTrained(
165 "bbq: codec must be trained before prepare_query (call train() with a sample of vectors)"
166 .to_string(),
167 )
168 })?;
169 let query = codec.prepare_query(q);
170 Ok(PreparedQuery::Bytes(encode_payload(
171 query.query_norm,
172 &query.centered,
173 )))
174 }
175
176 fn distance_prepared(
185 &self,
186 prepared: &PreparedQuery,
187 encoded: &[u8],
188 ) -> Result<f32, RerankError> {
189 let payload = match prepared {
190 PreparedQuery::Bytes(b) => b.as_slice(),
191 _ => {
192 return Err(RerankError::BadInput(
193 "bbq distance: prepared query is not Bytes".to_string(),
194 ));
195 }
196 };
197
198 let (_query_norm, centered) = decode_payload(payload, self.dim)?;
199
200 let packed_len = self.dim.div_ceil(8);
201 let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
202 RerankError::BadInput(format!("bbq distance: failed to parse encoded bytes: {e}"))
203 })?;
204
205 let header = uqv_ref.header();
206 let recon = bbq_dequantize(uqv_ref.packed_bits(), header.residual_norm, self.dim);
207 let dist = centered
208 .iter()
209 .zip(recon.iter())
210 .map(|(&a, &b)| (a - b) * (a - b))
211 .sum::<f32>()
212 .sqrt();
213 Ok(dist)
214 }
215
216 fn name(&self) -> CodecName {
217 CodecName::Bbq
218 }
219
220 fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
221 let codec = self.codec.as_ref().ok_or_else(|| {
222 RerankError::NotTrained("bbq sidecar serialize: codec not trained".to_string())
223 })?;
224 codec
225 .to_bytes()
226 .map_err(|e| RerankError::BadInput(format!("bbq to_bytes: {e}")))
227 }
228
229 fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
238 if samples.is_empty() {
239 return Err(RerankError::BadInput(
240 "bbq train: empty sample set".to_string(),
241 ));
242 }
243 for s in samples {
244 if s.len() != self.dim {
245 return Err(RerankError::BadInput(format!(
246 "bbq train: sample has len {} but codec dim is {}",
247 s.len(),
248 self.dim
249 )));
250 }
251 }
252 let codec = BbqCodec::calibrate(samples, self.dim, self.oversample);
253 self.codec = Some(codec);
254 Ok(())
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 16;
    const N: usize = 64;

    /// Deterministic test vector: component `j` of vector `i` is
    /// `((i * 31 + j) % 100) / 100`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// A `BbqRerank` trained on `N` deterministic vectors of length `DIM`.
    fn trained() -> BbqRerank {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let prep = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec.distance_prepared(&prep, &enc).expect("distance");
        assert!(dist.is_finite(), "distance must be finite, got {dist}");
        assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn prepare_query_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        match codec.prepare_query(&det_vec(0, DIM)) {
            Err(RerankError::NotTrained(_)) => {}
            Err(e) => panic!("expected NotTrained, got: {e}"),
            Ok(_) => panic!("expected an error before training"),
        }
    }

    #[test]
    fn to_bytes_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.to_bytes().unwrap_err();
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn to_bytes_after_train_succeeds() {
        let bytes = trained().to_bytes().expect("trained codec must serialize");
        assert!(!bytes.is_empty(), "serialized codec must not be empty");
    }

    #[test]
    fn train_with_empty_samples_fails() {
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.train(&[]).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn train_with_dim_mismatch_fails() {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let mut refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let bad = det_vec(0, DIM + 4);
        refs.push(bad.as_slice());
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.train(&refs).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn encode_wrong_dim_fails() {
        let codec = trained();
        let err = codec.encode(&det_vec(0, DIM + 1)).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn prepare_query_wrong_dim_fails() {
        let codec = trained();
        let bad = det_vec(0, DIM + 2);
        match codec.prepare_query(&bad) {
            Err(RerankError::BadInput(_)) => {}
            Err(e) => panic!("expected BadInput, got: {e}"),
            Ok(_) => panic!("expected an error for wrong dim"),
        }
    }

    #[test]
    fn distance_prepared_wrong_variant_fails() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        let bad_prepared = PreparedQuery::Raw(vec![0.0f32; DIM]);
        let err = codec.distance_prepared(&bad_prepared, &enc).unwrap_err();
        let msg = format!("{err}");
        assert!(
            matches!(err, RerankError::BadInput(_)) && msg.contains("Bytes"),
            "error should be BadInput mentioning the Bytes variant, got: {msg}"
        );
    }

    #[test]
    fn distance_prepared_malformed_payload_fails() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        let short = PreparedQuery::Bytes(vec![0u8; 3]);
        let err = codec.distance_prepared(&short, &enc).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn payload_roundtrip() {
        let centered = [0.25f32, -3.0, 7.0];
        let payload = encode_payload(1.5, &centered);
        assert_eq!(payload.len(), 4 + 4 * centered.len());
        let (norm, decoded) = decode_payload(&payload, centered.len()).expect("valid payload");
        assert_eq!(norm, 1.5);
        assert_eq!(decoded, centered.to_vec());
    }

    #[test]
    fn decode_payload_rejects_wrong_length() {
        let payload = encode_payload(1.0, &[1.0, 2.0]);
        assert!(matches!(
            decode_payload(&payload, 3),
            Err(RerankError::BadInput(_))
        ));
        assert!(matches!(
            decode_payload(&payload[..payload.len() - 1], 2),
            Err(RerankError::BadInput(_))
        ));
    }

    #[test]
    fn dequantize_reads_bits_msb_first() {
        // scale = 2 / sqrt(4) = 1
        assert_eq!(
            bbq_dequantize(&[0b1010_0000], 2.0, 4),
            vec![1.0, -1.0, 1.0, -1.0]
        );
        // scale = 3 / sqrt(9) = 1; the 9th bit is the MSB of the second byte.
        assert_eq!(bbq_dequantize(&[0xFF, 0x80], 3.0, 9), vec![1.0; 9]);
        assert!(bbq_dequantize(&[], 1.0, 0).is_empty());
    }

    #[test]
    #[should_panic]
    fn dequantize_short_packed_panics() {
        let _ = bbq_dequantize(&[0xFF], 1.0, 9);
    }

    #[test]
    fn name_is_expected() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        assert_eq!(codec.name(), CodecName::Bbq);
    }
}