nodedb_vector/rerank/codecs/
rabitq.rs1use nodedb_codec::vector_quant::codec::VectorCodec as _;
19use nodedb_codec::vector_quant::hamming::hamming_distance;
20use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
21use nodedb_codec::vector_quant::rabitq::{RaBitQCodec, RaBitQQuery};
22
23use crate::{
24 rerank::codec::{CodecName, PreparedQuery, RerankCodec},
25 rerank::types::RerankError,
26};
27
28fn encode_payload(query: &RaBitQQuery) -> Vec<u8> {
31 let mut buf = Vec::with_capacity(4 + query.rotated_signs.len());
33 buf.extend_from_slice(&query.query_norm.to_le_bytes());
34 buf.extend_from_slice(&query.rotated_signs);
35 buf
36}
37
38fn decode_payload(payload: &[u8], dim: usize) -> Result<(f32, Vec<u8>), RerankError> {
39 let sign_len = dim.div_ceil(8);
40 let expected = 4 + sign_len;
41 if payload.len() != expected {
42 return Err(RerankError::BadInput(format!(
43 "rabitq distance: payload len {} != expected {} for dim {}",
44 payload.len(),
45 expected,
46 dim
47 )));
48 }
49 let query_norm = f32::from_le_bytes(
50 payload[..4]
51 .try_into()
52 .expect("slice of 4 bytes always converts to [u8;4]"),
53 );
54 Ok((query_norm, payload[4..].to_vec()))
55}
56
/// Rotation seed used when the caller has no specific seed to supply.
///
/// `RaBitQRerank::from_codec` stores this value as the instance's
/// `rotation_seed`, and callers may pass it to `RaBitQRerank::new`.
pub const DEFAULT_ROTATION_SEED: u64 = 0x00C0_FFEE_00C0_FFEE;
61
62pub struct RaBitQRerank {
71 codec: Option<RaBitQCodec>,
72 dim: usize,
73 rotation_seed: u64,
74}
75
76impl RaBitQRerank {
77 pub fn new(dim: usize, rotation_seed: u64) -> Self {
82 Self {
83 codec: None,
84 dim,
85 rotation_seed,
86 }
87 }
88
89 pub fn from_codec(codec: RaBitQCodec) -> Self {
91 let dim = codec.dim;
92 Self {
93 codec: Some(codec),
94 dim,
95 rotation_seed: DEFAULT_ROTATION_SEED,
96 }
97 }
98}
99
100impl RerankCodec for RaBitQRerank {
101 fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
107 if v.len() != self.dim {
108 return Err(RerankError::BadInput(format!(
109 "rabitq encode: vector len {} != codec dim {}",
110 v.len(),
111 self.dim
112 )));
113 }
114 let codec = self.codec.as_ref().ok_or_else(|| {
115 RerankError::NotTrained(
116 "rabitq: codec must be trained before encoding (call train() with a sample of vectors)"
117 .to_string(),
118 )
119 })?;
120 let quantized = codec.encode(v);
121 Ok(quantized.as_ref().as_bytes().to_vec())
122 }
123
124 fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
130 if q.len() != self.dim {
131 return Err(RerankError::BadInput(format!(
132 "rabitq prepare_query: query len {} != codec dim {}",
133 q.len(),
134 self.dim
135 )));
136 }
137 let codec = self.codec.as_ref().ok_or_else(|| {
138 RerankError::NotTrained(
139 "rabitq: codec must be trained before prepare_query (call train() with a sample of vectors)"
140 .to_string(),
141 )
142 })?;
143 let query = codec.prepare_query(q);
144 Ok(PreparedQuery::Bytes(encode_payload(&query)))
145 }
146
147 fn distance_prepared(
157 &self,
158 prepared: &PreparedQuery,
159 encoded: &[u8],
160 ) -> Result<f32, RerankError> {
161 let payload = match prepared {
162 PreparedQuery::Bytes(b) => b.as_slice(),
163 _ => {
164 return Err(RerankError::BadInput(
165 "rabitq distance: prepared query is not Bytes".to_string(),
166 ));
167 }
168 };
169
170 let (query_norm, rotated_signs) = decode_payload(payload, self.dim)?;
171
172 let packed_len = self.dim.div_ceil(8);
173 let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
174 RerankError::BadInput(format!(
175 "rabitq distance: failed to parse encoded bytes: {e}"
176 ))
177 })?;
178
179 let vh = uqv_ref.header();
180 let vb = uqv_ref.packed_bits();
181 let h = hamming_distance(&rotated_signs, vb);
182 let dim = self.dim as f32;
183 let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
184 let approx = query_norm * query_norm + vh.residual_norm * vh.residual_norm
185 - 2.0 * query_norm * vh.residual_norm * dot_estimate;
186 Ok(approx.max(0.0))
187 }
188
189 fn name(&self) -> CodecName {
190 CodecName::RaBitQ
191 }
192
193 fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
194 let codec = self.codec.as_ref().ok_or_else(|| {
195 RerankError::NotTrained("rabitq sidecar serialize: codec not trained".to_string())
196 })?;
197 codec
198 .to_bytes()
199 .map_err(|e| RerankError::BadInput(format!("rabitq to_bytes: {e}")))
200 }
201
202 fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
211 if samples.is_empty() {
212 return Err(RerankError::BadInput(
213 "rabitq train: empty sample set".to_string(),
214 ));
215 }
216 for s in samples {
217 if s.len() != self.dim {
218 return Err(RerankError::BadInput(format!(
219 "rabitq train: sample has len {} but codec dim is {}",
220 s.len(),
221 self.dim
222 )));
223 }
224 }
225 let codec = RaBitQCodec::calibrate(samples, self.dim, self.rotation_seed);
226 self.codec = Some(codec);
227 Ok(())
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensionality used by every test vector.
    const DIM: usize = 16;
    /// Number of vectors in the training sample.
    const N: usize = 64;

    /// Builds a deterministic vector: component `j` of vector `i` is
    /// `((i * 31 + j) % 100) / 100`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// Returns a codec trained on `N` deterministic vectors of length `DIM`.
    fn trained() -> RaBitQRerank {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let prep = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec.distance_prepared(&prep, &enc).expect("distance");
        assert!(dist.is_finite(), "distance must be finite, got {dist}");
        assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
    }

    #[test]
    fn prepare_query_payload_has_expected_length() {
        let codec = trained();
        match codec.prepare_query(&det_vec(1, DIM)) {
            // 4-byte norm followed by one sign bit per dimension.
            Ok(PreparedQuery::Bytes(bytes)) => assert_eq!(bytes.len(), 4 + DIM.div_ceil(8)),
            Ok(_) => panic!("expected PreparedQuery::Bytes"),
            Err(e) => panic!("prepare_query failed: {e}"),
        }
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(&err, RerankError::NotTrained(msg) if msg.contains("trained")),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn encode_wrong_dim_fails() {
        // The length check precedes the trained check, so an untrained
        // codec still reports BadInput for a wrong-length vector.
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.encode(&det_vec(0, DIM + 1)).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(msg) if msg.contains("dim")),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn to_bytes_before_train_returns_not_trained() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.to_bytes().unwrap_err();
        assert!(
            matches!(&err, RerankError::NotTrained(msg) if msg.contains("not trained")),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn train_with_empty_samples_fails() {
        let mut codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.train(&[]).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(msg) if msg.contains("empty")),
            "expected BadInput about an empty sample set, got: {err}"
        );
    }

    #[test]
    fn train_with_dim_mismatch_fails() {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let mut refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let bad = det_vec(0, DIM + 4);
        refs.push(bad.as_slice());
        let mut codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.train(&refs).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(msg) if msg.contains("dim")),
            "expected BadInput about the dimension, got: {err}"
        );

        // A rejected training run must not install a codec.
        let after = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(&after, RerankError::NotTrained(_)),
            "codec must stay untrained after a failed train, got: {after}"
        );
    }

    #[test]
    fn prepare_query_wrong_dim_fails() {
        let codec = trained();
        let bad = det_vec(0, DIM + 2);
        match codec.prepare_query(&bad) {
            Err(RerankError::BadInput(msg)) => assert!(
                msg.contains("dim"),
                "expected a dimension error, got: {msg}"
            ),
            Err(other) => panic!("expected BadInput, got: {other}"),
            Ok(_) => panic!("expected an error for wrong dim"),
        }
    }

    #[test]
    fn distance_prepared_wrong_variant_fails() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let bad_prepared = PreparedQuery::Raw(vec![0.0f32; DIM]);
        let err = codec.distance_prepared(&bad_prepared, &enc).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(msg) if msg.contains("not Bytes")),
            "error should mention the Bytes variant, got: {err}"
        );
    }

    #[test]
    fn distance_prepared_rejects_truncated_payload() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        let truncated = PreparedQuery::Bytes(vec![0u8; 3]);
        let err = codec.distance_prepared(&truncated, &enc).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(msg) if msg.contains("payload len")),
            "expected a payload length error, got: {err}"
        );
    }

    #[test]
    fn decode_payload_rejects_wrong_length() {
        // DIM = 16 needs 4 + 2 bytes; try one short and one long.
        assert!(matches!(
            decode_payload(&[0u8; 5], DIM),
            Err(RerankError::BadInput(_))
        ));
        assert!(matches!(
            decode_payload(&[0u8; 7], DIM),
            Err(RerankError::BadInput(_))
        ));
    }

    #[test]
    fn decode_payload_splits_norm_and_signs() {
        let mut payload = 1.5f32.to_le_bytes().to_vec();
        payload.extend_from_slice(&[0xAA, 0x55]);
        let (norm, signs) = decode_payload(&payload, DIM).expect("well-formed payload");
        assert_eq!(norm, 1.5);
        assert_eq!(signs, vec![0xAA, 0x55]);
    }

    #[test]
    fn name_is_expected() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        assert_eq!(codec.name(), CodecName::RaBitQ);
    }
}