nodedb_codec/vector_quant/
rabitq.rs1use crate::error::CodecError;
33use crate::vector_quant::codec::VectorCodec;
34use crate::vector_quant::codec_envelope;
35use crate::vector_quant::hamming::hamming_distance;
36use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
37use serde::{Deserialize, Serialize};
38
/// Advance a xorshift64 PRNG state (shifts 13, 7, 17) and return the new state.
///
/// A state of 0 is a fixed point: it stays 0 forever.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let shifted_left = *state ^ (*state << 13);
    let shifted_right = shifted_left ^ (shifted_left >> 7);
    let next = shifted_right ^ (shifted_right << 17);
    *state = next;
    next
}
51
/// Smallest power of two that is `>= n` (returns `n` itself when it already is one,
/// and 1 for `n == 0`).
#[inline]
fn next_pow2(n: usize) -> usize {
    // `next_power_of_two` already maps powers of two to themselves and 0 to 1.
    n.next_power_of_two()
}
63
/// In-place, unnormalized fast Walsh–Hadamard transform.
///
/// `buf.len()` must be a power of two (checked only in debug builds). Because no
/// `1/√n` scaling is applied, the output norm is `√n` times the input norm and
/// applying the transform twice multiplies the input by `n`.
fn wht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two());
    let mut half = 1usize;
    while half < len {
        // Butterfly each pair (base + k, base + k + half) within blocks of 2·half.
        for base in (0..len).step_by(half * 2) {
            for lo in base..base + half {
                let hi = lo + half;
                let (x, y) = (buf[lo], buf[hi]);
                buf[lo] = x + y;
                buf[hi] = x - y;
            }
        }
        half *= 2;
    }
}
85
/// Pack the signs of the first `dim` entries of `rotated` into a bitmap.
///
/// Bit `i % 8` of byte `i / 8` is set when entry `i` is strictly negative
/// (`-0.0` and NaN count as non-negative). The output is always
/// `dim.div_ceil(8)` bytes; entries missing from a short `rotated` stay 0.
fn sign_pack(rotated: &[f32], dim: usize) -> Vec<u8> {
    let mut packed = vec![0u8; dim.div_ceil(8)];
    let negatives = rotated
        .iter()
        .take(dim)
        .enumerate()
        .filter(|&(_, &x)| x < 0.0);
    for (idx, _) in negatives {
        packed[idx >> 3] |= 1u8 << (idx & 7);
    }
    packed
}
100
/// Expand the first `dim` bits of a [`sign_pack`]-style bitmap into ±1.0 values:
/// a set bit (bit `i % 8` of byte `i / 8`) yields `-1.0`, a clear bit `1.0`.
///
/// Panics if `packed` is shorter than `dim.div_ceil(8)` bytes.
fn sign_unpack(packed: &[u8], dim: usize) -> Vec<f32> {
    let mut signs = Vec::with_capacity(dim);
    for idx in 0..dim {
        let negative = ((packed[idx >> 3] >> (idx & 7)) & 1) == 1;
        signs.push(if negative { -1.0f32 } else { 1.0f32 });
    }
    signs
}
113
/// 1-bit RaBitQ vector codec.
///
/// Vectors are centred on `centroid`, passed through a seeded pseudo-random
/// rotation (±1 sign flips followed by a Walsh–Hadamard transform), and reduced
/// to one sign bit per dimension.
///
/// NOTE(review): the codec is serialized via serde and zerompk MessagePack; field
/// order may be part of the wire format — confirm before reordering fields.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub struct RaBitQCodec {
    /// Dimensionality of the vectors this codec encodes.
    pub dim: usize,
    /// Mean of the calibration vectors; subtracted from every vector and query
    /// before rotation.
    centroid: Vec<f32>,
    /// Seed for the xorshift64 sign-flip sequence used by the rotation.
    rotation_seed: u64,
    /// When true, `exact_asymmetric_distance` subtracts the encoded vector's
    /// stored `dot_quantized` from its distance estimate. `calibrate` sets false.
    pub bias_correct: bool,
}
130
131impl RaBitQCodec {
132 pub fn calibrate(vectors: &[&[f32]], dim: usize, rotation_seed: u64) -> Self {
141 let centroid = if vectors.is_empty() {
142 vec![0.0f32; dim]
143 } else {
144 let n = vectors.len() as f32;
145 let mut c = vec![0.0f32; dim];
146 for v in vectors {
147 for (ci, &vi) in c.iter_mut().zip(v.iter()) {
148 *ci += vi;
149 }
150 }
151 c.iter_mut().for_each(|x| *x /= n);
152 c
153 };
154 Self {
155 dim,
156 centroid,
157 rotation_seed,
158 bias_correct: false,
159 }
160 }
161
162 pub const ENVELOPE_MAGIC: &'static [u8; codec_envelope::MAGIC_LEN] = b"NDRBQ";
164
165 pub const ENVELOPE_VERSION: u8 = 1;
167
168 pub fn to_bytes(&self) -> Result<Vec<u8>, CodecError> {
170 codec_envelope::encode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, self)
171 }
172
173 pub fn from_bytes(buf: &[u8]) -> Result<Self, CodecError> {
175 codec_envelope::decode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, buf)
176 }
177
178 pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
186 let dim = self.dim;
187 let pow2 = next_pow2(dim);
188
189 let mut seed = self.rotation_seed;
191 let mut buf = vec![0.0f32; pow2];
192 for (i, &vi) in v.iter().take(dim).enumerate() {
193 let flip = if xorshift64(&mut seed) & 1 == 0 {
194 1.0f32
195 } else {
196 -1.0f32
197 };
198 buf[i] = vi * flip;
199 }
200 wht_inplace(&mut buf);
203 buf.truncate(dim);
204 buf
205 }
206
207 fn encode_inner(&self, v: &[f32]) -> UnifiedQuantizedVector {
218 let dim = self.dim;
219
220 let residual: Vec<f32> = v
222 .iter()
223 .zip(self.centroid.iter())
224 .map(|(&vi, &ci)| vi - ci)
225 .collect();
226
227 let residual_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
229
230 let rotated = self.apply_rotation(&residual);
232
233 let packed = sign_pack(&rotated, dim);
235
236 let signs_fp = sign_unpack(&packed, dim);
239 let pow2 = next_pow2(dim);
241 let mut sign_buf = vec![0.0f32; pow2];
242 for (i, &s) in signs_fp.iter().enumerate() {
243 sign_buf[i] = s;
244 }
245 wht_inplace(&mut sign_buf);
246 let mut seed = self.rotation_seed;
248 #[allow(clippy::needless_range_loop)]
249 for i in 0..dim {
250 let flip = if xorshift64(&mut seed) & 1 == 0 {
251 1.0f32
252 } else {
253 -1.0f32
254 };
255 sign_buf[i] *= flip;
256 }
257 let dot_raw: f32 = residual
258 .iter()
259 .zip(sign_buf.iter().take(dim))
260 .map(|(&r, &s)| r * s)
261 .sum();
262 let dot_quantized = if residual_norm > 0.0 {
263 dot_raw / residual_norm
264 } else {
265 0.0
266 };
267
268 let header = QuantHeader {
269 quant_mode: QuantMode::RaBitQ as u16,
270 dim: dim as u16,
271 global_scale: residual_norm,
272 residual_norm,
273 dot_quantized,
274 outlier_bitmask: 0,
275 reserved: [0u8; 8],
276 };
277
278 UnifiedQuantizedVector::new(header, &packed, &[])
279 .expect("RaBitQ encode: layout construction must succeed")
280 }
281}
282
/// A vector encoded by [`RaBitQCodec`]: the unified quantized layout holding the
/// packed sign bits and a header with the residual norm and `dot_quantized`.
pub struct RaBitQQuantized(UnifiedQuantizedVector);
287
288impl AsRef<UnifiedQuantizedVector> for RaBitQQuantized {
289 #[inline]
290 fn as_ref(&self) -> &UnifiedQuantizedVector {
291 &self.0
292 }
293}
294
/// A query prepared by `RaBitQCodec::prepare_query` for asymmetric distance
/// estimation against encoded vectors.
pub struct RaBitQQuery {
    /// Sign bits of the centred, rotated query: bit `i % 8` of byte `i / 8` is set
    /// when rotated component `i` is negative.
    pub rotated_signs: Vec<u8>,
    /// Euclidean norm of the centred query (computed before rotation).
    pub query_norm: f32,
}
302
303impl VectorCodec for RaBitQCodec {
306 type Quantized = RaBitQQuantized;
307 type Query = RaBitQQuery;
308
309 fn encode(&self, v: &[f32]) -> Self::Quantized {
310 RaBitQQuantized(self.encode_inner(v))
311 }
312
313 fn prepare_query(&self, q: &[f32]) -> Self::Query {
314 let dim = self.dim;
315 let residual: Vec<f32> = q
316 .iter()
317 .zip(self.centroid.iter())
318 .map(|(&qi, &ci)| qi - ci)
319 .collect();
320 let query_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
321 let rotated = self.apply_rotation(&residual);
322 let rotated_signs = sign_pack(&rotated, dim);
323 RaBitQQuery {
324 rotated_signs,
325 query_norm,
326 }
327 }
328
329 fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
336 let qh = q.0.header();
337 let vh = v.0.header();
338 let qb = q.0.packed_bits();
339 let vb = v.0.packed_bits();
340 let h = hamming_distance(qb, vb);
341 let dim = self.dim as f32;
342 let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
343 let approx = qh.residual_norm * qh.residual_norm + vh.residual_norm * vh.residual_norm
344 - 2.0 * qh.residual_norm * vh.residual_norm * dot_estimate;
345 approx.max(0.0)
346 }
347
348 fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
357 let vh = v.0.header();
358 let vb = v.0.packed_bits();
359 let h = hamming_distance(&q.rotated_signs, vb);
360 let dim = self.dim as f32;
361 let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
362 let mut approx = q.query_norm * q.query_norm + vh.residual_norm * vh.residual_norm
363 - 2.0 * q.query_norm * vh.residual_norm * dot_estimate;
364 if self.bias_correct {
365 approx -= vh.dot_quantized;
366 }
367 approx.max(0.0)
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [-1, 1].
    fn random_vec(seed: u64, dim: usize) -> Vec<f32> {
        // `| 1` keeps the xorshift state non-zero; zero is a fixed point.
        let mut s = seed | 1;
        (0..dim)
            .map(|_| {
                let v = xorshift64(&mut s);
                (v as f32 / u64::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Codec calibrated on `count` pseudo-random vectors seeded 0..count.
    fn calibrated_codec(dim: usize, count: u64, rotation_seed: u64) -> RaBitQCodec {
        let vecs: Vec<Vec<f32>> = (0..count).map(|i| random_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        RaBitQCodec::calibrate(&refs, dim, rotation_seed)
    }

    fn norm_sq(x: &[f32]) -> f32 {
        x.iter().map(|a| a * a).sum::<f32>()
    }

    #[test]
    fn to_bytes_from_bytes_roundtrip() {
        let dim = 64;
        let codec = calibrated_codec(dim, 4, 0xABCD_1234_5678_EF01);
        let bytes = codec.to_bytes().expect("to_bytes should succeed");
        let restored = RaBitQCodec::from_bytes(&bytes).expect("from_bytes should succeed");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.rotation_seed, codec.rotation_seed);
        assert_eq!(restored.bias_correct, codec.bias_correct);
        assert_eq!(restored.centroid.len(), codec.centroid.len());
        for (a, b) in restored.centroid.iter().zip(codec.centroid.iter()) {
            assert!((a - b).abs() < 1e-6, "centroid mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn from_bytes_rejects_bad_magic() {
        let mut bytes = b"WRONG".to_vec();
        bytes.push(1);
        bytes.extend_from_slice(&[0u8; 4]);
        assert!(RaBitQCodec::from_bytes(&bytes).is_err());
    }

    #[test]
    fn from_bytes_rejects_bad_version() {
        let codec = RaBitQCodec::calibrate(&[], 4, 1);
        let mut bytes = codec.to_bytes().unwrap();
        // The version byte immediately follows the magic prefix; derive the offset
        // instead of hard-coding it, and check we are corrupting the right byte.
        let version_at = RaBitQCodec::ENVELOPE_MAGIC.len();
        assert_eq!(bytes[version_at], RaBitQCodec::ENVELOPE_VERSION);
        bytes[version_at] = 99;
        assert!(RaBitQCodec::from_bytes(&bytes).is_err());
    }

    #[test]
    fn apply_rotation_different_seeds_differ() {
        let dim = 64;
        let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let codec_a = RaBitQCodec::calibrate(&[], dim, 0xDEAD_BEEF_1234_5678);
        let codec_b = RaBitQCodec::calibrate(&[], dim, 0xCAFE_BABE_0000_0001);
        let rot_a = codec_a.apply_rotation(&v);
        let rot_b = codec_b.apply_rotation(&v);
        let differ = rot_a
            .iter()
            .zip(rot_b.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(differ, "different seeds must produce different rotations");
    }

    #[test]
    fn apply_rotation_scales_norm_for_pow2_dim() {
        // Sign flips preserve the norm and the unnormalized WHT scales it by √dim
        // when dim is a power of two, so ‖R v‖² = dim · ‖v‖².
        let dim = 64;
        let v = random_vec(5, dim);
        let codec = RaBitQCodec::calibrate(&[], dim, 0x1234_5678);
        let rot = codec.apply_rotation(&v);
        assert_eq!(rot.len(), dim);
        let expected = dim as f32 * norm_sq(&v);
        let got = norm_sq(&rot);
        assert!(
            (got - expected).abs() <= 1e-3 * expected,
            "rotated norm² {got} vs expected {expected}"
        );
    }

    #[test]
    fn encode_roundtrip_preserves_residual_norm() {
        let dim = 128;
        let codec = calibrated_codec(dim, 16, 42);
        let v = random_vec(99, dim);
        let q = codec.encode(&v);
        let h = q.0.header();
        let expected = v
            .iter()
            .zip(&codec.centroid)
            .map(|(&a, &c)| (a - c) * (a - c))
            .sum::<f32>()
            .sqrt();
        assert!(h.residual_norm.is_finite() && h.residual_norm >= 0.0);
        assert!(
            (h.residual_norm - expected).abs() <= 1e-5 * expected.max(1.0),
            "residual_norm {} vs expected {expected}",
            h.residual_norm
        );
        assert!((h.global_scale - h.residual_norm).abs() < 1e-6);
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 64;
        let codec = calibrated_codec(dim, 8, 7);
        let v1 = codec.encode(&random_vec(100, dim));
        let v2 = codec.encode(&random_vec(200, dim));
        let sym = codec.fast_symmetric_distance(&v1, &v2);
        assert!(sym.is_finite() && sym >= 0.0, "sym distance: {sym}");
        let q = codec.prepare_query(&random_vec(300, dim));
        let asym = codec.exact_asymmetric_distance(&q, &v2);
        assert!(asym.is_finite() && asym >= 0.0, "asym distance: {asym}");
    }

    #[test]
    fn self_distance_is_near_zero() {
        // A vector compared with its own encoding has Hamming distance 0 and equal
        // norms, so both estimators should collapse to ~0.
        let dim = 64;
        let codec = calibrated_codec(dim, 8, 7);
        let v = random_vec(123, dim);
        let enc = codec.encode(&v);
        let sym = codec.fast_symmetric_distance(&enc, &enc);
        assert!(sym < 1e-3, "self sym distance: {sym}");
        let asym = codec.exact_asymmetric_distance(&codec.prepare_query(&v), &enc);
        assert!(asym < 1e-3, "self asym distance: {asym}");
    }

    #[test]
    fn calibrate_identical_vectors_zero_residual() {
        let dim = 32;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs = vec![v.as_slice(); 16];
        let codec = RaBitQCodec::calibrate(&refs, dim, 1);
        let q = codec.encode(&v);
        assert!(
            q.0.header().residual_norm < 1e-5,
            "residual_norm should be ~0 for vector equal to centroid"
        );
    }
}