use crate::vector_quant::codec::VectorCodec;
use crate::vector_quant::hamming::hamming_distance;
use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
/// Advances a 64-bit xorshift generator by one step.
///
/// The new value is written back to `state` and also returned. The three
/// shift/xor stages use the shift amounts 13 (left), 7 (right) and 17 (left).
///
/// A `state` of zero is a fixed point: every stage xors zero with zero, so the
/// generator keeps returning zero.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let stage1 = *state ^ (*state << 13);
    let stage2 = stage1 ^ (stage1 >> 7);
    let stage3 = stage2 ^ (stage2 << 17);
    *state = stage3;
    stage3
}
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// Powers of two are returned unchanged, and `0` maps to `1`.
#[inline]
fn next_pow2(n: usize) -> usize {
    // `next_power_of_two` already returns `n` itself when `n` is a power of two.
    n.next_power_of_two()
}
/// Applies an unnormalised Walsh–Hadamard-style butterfly transform in place.
///
/// Each pass combines elements `half` apart into their sum and difference,
/// doubling `half` until it reaches the buffer length. No scaling is applied,
/// so running the transform twice multiplies every element by `buf.len()`.
///
/// `buf.len()` must be a power of two; this is only checked in debug builds.
fn wht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two());

    let mut half = 1usize;
    while half < len {
        // Each group of `2 * half` elements is processed independently.
        for base in (0..len).step_by(half * 2) {
            for lo in base..base + half {
                let hi = lo + half;
                let (x, y) = (buf[lo], buf[hi]);
                buf[lo] = x + y;
                buf[hi] = x - y;
            }
        }
        half *= 2;
    }
}
/// Packs the signs of the first `dim` values of `rotated` into a bit vector.
///
/// The result has `ceil(dim / 8)` bytes. Bit `i % 8` of byte `i / 8` is set
/// when `rotated[i]` is strictly negative; zero, negative zero and NaN leave
/// the bit clear. If `rotated` holds fewer than `dim` values, the bits for the
/// missing positions stay clear.
fn sign_pack(rotated: &[f32], dim: usize) -> Vec<u8> {
    let mut bits = vec![0u8; dim.div_ceil(8)];
    let used = rotated.len().min(dim);

    // One output byte per group of eight input values, least significant bit first.
    for (byte, group) in bits.iter_mut().zip(rotated[..used].chunks(8)) {
        *byte = group.iter().enumerate().fold(0u8, |acc, (bit, &value)| {
            if value < 0.0 {
                acc | (1 << bit)
            } else {
                acc
            }
        });
    }
    bits
}
/// Expands a packed sign bit vector into `dim` values of `+1.0` / `-1.0`.
///
/// Bit `i % 8` of byte `i / 8` selects the sign of element `i`: a set bit
/// yields `-1.0`, a clear bit yields `1.0`.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `ceil(dim / 8)` bytes.
fn sign_unpack(packed: &[u8], dim: usize) -> Vec<f32> {
    let mut signs = Vec::with_capacity(dim);
    for idx in 0..dim {
        let byte = packed[idx >> 3];
        let negative = (byte >> (idx & 7)) & 1 == 1;
        signs.push(if negative { -1.0f32 } else { 1.0f32 });
    }
    signs
}
/// Sign-bit vector codec: inputs are centred on a calibrated centroid, passed
/// through a seeded sign-flip + butterfly transform, and stored as one sign bit
/// per component alongside a small header.
pub struct RaBitQCodec {
/// Number of vector components the codec encodes.
pub dim: usize,
/// Component-wise mean of the calibration vectors (all zeros when none were
/// supplied); subtracted from every input before it is transformed.
centroid: Vec<f32>,
/// Seed of the `xorshift64` stream that chooses the per-component sign flips.
rotation_seed: u64,
/// When `true`, `exact_asymmetric_distance` subtracts the encoded vector's
/// stored `dot_quantized` header value from its estimate. `calibrate`
/// initialises this to `false`.
pub bias_correct: bool,
}
impl RaBitQCodec {
    /// Builds a codec whose centroid is the component-wise mean of `vectors`.
    ///
    /// With no calibration vectors the centroid is all zeros. A vector shorter
    /// than `dim` only contributes to its leading components (the mean still
    /// divides by the total vector count); components beyond `dim` are ignored.
    /// `bias_correct` starts out as `false`.
    pub fn calibrate(vectors: &[&[f32]], dim: usize, rotation_seed: u64) -> Self {
        let mut centroid = vec![0.0f32; dim];
        if !vectors.is_empty() {
            let count = vectors.len() as f32;
            for v in vectors {
                for (acc, &vi) in centroid.iter_mut().zip(v.iter()) {
                    *acc += vi;
                }
            }
            centroid.iter_mut().for_each(|x| *x /= count);
        }
        Self {
            dim,
            centroid,
            rotation_seed,
            bias_correct: false,
        }
    }

    /// Draws the next sign flip (`1.0` or `-1.0`) from the xorshift stream.
    ///
    /// Shared by the forward transform and by `encode_inner` so both always
    /// consume the seed stream in exactly the same way.
    #[inline]
    fn next_flip(seed: &mut u64) -> f32 {
        if xorshift64(seed) & 1 == 0 {
            1.0
        } else {
            -1.0
        }
    }

    /// Applies the seeded transform to `v` and returns `dim` values.
    ///
    /// The first `dim` components of `v` are multiplied by pseudo-random signs
    /// derived from `rotation_seed`, zero-padded to the next power of two, run
    /// through the unnormalised butterfly transform, and truncated back to
    /// `dim` entries. Missing components (when `v` is shorter than `dim`) are
    /// treated as zero.
    pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
        let dim = self.dim;
        let mut seed = self.rotation_seed;
        let mut buf = vec![0.0f32; next_pow2(dim)];
        for (slot, &vi) in buf.iter_mut().zip(v.iter().take(dim)) {
            *slot = vi * Self::next_flip(&mut seed);
        }
        wht_inplace(&mut buf);
        buf.truncate(dim);
        buf
    }

    /// Encodes `v` into the unified quantized layout.
    ///
    /// The header stores the norm of the residual `v - centroid` (in both
    /// `global_scale` and `residual_norm`) and `dot_quantized`: the dot product
    /// of the residual with the sign vector after that sign vector has been
    /// sent back through the butterfly transform and the same sign flips,
    /// divided by the residual norm (or `0.0` for a zero residual). The payload
    /// is the packed sign bits of the transformed residual.
    ///
    /// # Panics
    ///
    /// Panics if `dim` does not fit in the header's `u16` dimension field, or
    /// if the layout constructor rejects the header/payload.
    fn encode_inner(&self, v: &[f32]) -> UnifiedQuantizedVector {
        let dim = self.dim;
        // The header stores the dimension as u16; a plain `as` cast would wrap
        // silently and produce a header that disagrees with the payload.
        assert!(
            dim <= usize::from(u16::MAX),
            "RaBitQ encode: dim {} does not fit in the header's u16 dim field",
            dim
        );

        let residual: Vec<f32> = v
            .iter()
            .zip(self.centroid.iter())
            .map(|(&vi, &ci)| vi - ci)
            .collect();
        let residual_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();

        let rotated = self.apply_rotation(&residual);
        let packed = sign_pack(&rotated, dim);

        // Send the +/-1 sign vector back through the transform and the flips.
        let mut sign_buf = vec![0.0f32; next_pow2(dim)];
        sign_buf[..dim].copy_from_slice(&sign_unpack(&packed, dim));
        wht_inplace(&mut sign_buf);
        let mut seed = self.rotation_seed;
        for s in sign_buf.iter_mut().take(dim) {
            *s *= Self::next_flip(&mut seed);
        }

        let dot_raw: f32 = residual
            .iter()
            .zip(sign_buf.iter().take(dim))
            .map(|(&r, &s)| r * s)
            .sum();
        let dot_quantized = if residual_norm > 0.0 {
            dot_raw / residual_norm
        } else {
            0.0
        };

        let header = QuantHeader {
            quant_mode: QuantMode::RaBitQ as u16,
            // Lossless: range-checked by the assert above.
            dim: dim as u16,
            global_scale: residual_norm,
            residual_norm,
            dot_quantized,
            outlier_bitmask: 0,
            reserved: [0u8; 8],
        };
        UnifiedQuantizedVector::new(header, &packed, &[])
            .expect("RaBitQ encode: layout construction must succeed")
    }
}
/// A vector encoded by [`RaBitQCodec`]: a newtype over the unified quantized
/// layout (header plus packed sign bits).
pub struct RaBitQQuantized(UnifiedQuantizedVector);
impl AsRef<UnifiedQuantizedVector> for RaBitQQuantized {
#[inline]
fn as_ref(&self) -> &UnifiedQuantizedVector {
&self.0
}
}
/// A query prepared for asymmetric distance estimation against encoded vectors.
pub struct RaBitQQuery {
/// Packed sign bits of the transformed query residual (a set bit marks a
/// strictly negative component), in the same bit layout as encoded vectors.
pub rotated_signs: Vec<u8>,
/// Euclidean norm of the query residual (query minus the codec centroid).
pub query_norm: f32,
}
impl VectorCodec for RaBitQCodec {
    type Quantized = RaBitQQuantized;
    type Query = RaBitQQuery;

    /// Encodes `v` into its sign-bit representation.
    fn encode(&self, v: &[f32]) -> Self::Quantized {
        let layout = self.encode_inner(v);
        RaBitQQuantized(layout)
    }

    /// Centres `q` on the codec centroid, transforms it, and keeps the packed
    /// sign bits together with the residual norm.
    fn prepare_query(&self, q: &[f32]) -> Self::Query {
        let mut residual = Vec::with_capacity(q.len().min(self.centroid.len()));
        for (&qi, &ci) in q.iter().zip(self.centroid.iter()) {
            residual.push(qi - ci);
        }
        let squared_norm: f32 = residual.iter().map(|x| x * x).sum();
        let rotated = self.apply_rotation(&residual);
        RaBitQQuery {
            rotated_signs: sign_pack(&rotated, self.dim),
            query_norm: squared_norm.sqrt(),
        }
    }

    /// Estimates the squared distance between two encoded vectors from the
    /// Hamming distance of their sign bits and their stored residual norms.
    /// The result is clamped to be non-negative.
    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
        let q_norm = q.0.header().residual_norm;
        let v_norm = v.0.header().residual_norm;
        let differing = hamming_distance(q.0.packed_bits(), v.0.packed_bits());
        let estimate = estimate_sq_distance(q_norm, v_norm, differing as f32, self.dim as f32);
        estimate.max(0.0)
    }

    /// Estimates the squared distance between a prepared query and an encoded
    /// vector.
    ///
    /// Despite the name, the query side contributes only its sign bits and its
    /// residual norm. When `bias_correct` is set, the encoded vector's stored
    /// `dot_quantized` is subtracted before clamping at zero.
    fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
        let header = v.0.header();
        let differing = hamming_distance(&q.rotated_signs, v.0.packed_bits());
        let base = estimate_sq_distance(
            q.query_norm,
            header.residual_norm,
            differing as f32,
            self.dim as f32,
        );
        let adjusted = if self.bias_correct {
            base - header.dot_quantized
        } else {
            base
        };
        adjusted.max(0.0)
    }
}

/// Law-of-cosines estimate `a² + b² − 2ab·c`, where the cosine `c` is derived
/// from the fraction of differing sign bits as `1 − 2·hamming / dim`.
///
/// The result is not clamped; callers clamp it to zero themselves.
#[inline]
fn estimate_sq_distance(norm_a: f32, norm_b: f32, hamming: f32, dim: f32) -> f32 {
    let dot_estimate = 1.0 - 2.0 * hamming / dim;
    norm_a * norm_a + norm_b * norm_b - 2.0 * norm_a * norm_b * dot_estimate
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector of length `dim` derived from `seed`.
    ///
    /// The seed's lowest bit is forced to 1 so the xorshift state is never
    /// zero; as a consequence seeds `2k` and `2k + 1` yield the same vector.
    fn random_vec(seed: u64, dim: usize) -> Vec<f32> {
        let mut state = seed | 1;
        let mut out = Vec::with_capacity(dim);
        for _ in 0..dim {
            let raw = xorshift64(&mut state);
            out.push((raw as f32 / u64::MAX as f32) * 2.0 - 1.0);
        }
        out
    }

    /// Calibrates a codec on `samples` vectors generated from seeds `0..samples`.
    fn calibrated_codec(samples: u64, dim: usize, rotation_seed: u64) -> RaBitQCodec {
        let vecs: Vec<Vec<f32>> = (0..samples).map(|i| random_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        RaBitQCodec::calibrate(&refs, dim, rotation_seed)
    }

    #[test]
    fn apply_rotation_different_seeds_differ() {
        let dim = 64;
        let input: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let first = RaBitQCodec::calibrate(&[], dim, 0xDEAD_BEEF_1234_5678).apply_rotation(&input);
        let second = RaBitQCodec::calibrate(&[], dim, 0xCAFE_BABE_0000_0001).apply_rotation(&input);
        let differ = first
            .iter()
            .zip(second.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(differ, "different seeds must produce different rotations");
    }

    #[test]
    fn encode_roundtrip_preserves_residual_norm() {
        let dim = 128;
        let codec = calibrated_codec(16, dim, 42);
        let encoded = codec.encode(&random_vec(99, dim));
        let header = encoded.0.header();
        let norm = header.residual_norm;
        assert!(norm.is_finite() && norm >= 0.0);
        // The encoder writes the residual norm into both header fields.
        assert!((header.global_scale - norm).abs() < 1e-6);
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 64;
        let codec = calibrated_codec(8, dim, 7);
        let v1 = codec.encode(&random_vec(100, dim));
        let v2 = codec.encode(&random_vec(200, dim));

        let sym = codec.fast_symmetric_distance(&v1, &v2);
        assert!(sym.is_finite() && sym >= 0.0, "sym distance: {sym}");

        let query = codec.prepare_query(&random_vec(300, dim));
        let asym = codec.exact_asymmetric_distance(&query, &v2);
        assert!(asym.is_finite() && asym >= 0.0, "asym distance: {asym}");
    }

    #[test]
    fn calibrate_identical_vectors_zero_residual() {
        let dim = 32;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        // Sixteen copies of the same vector: the centroid equals the vector.
        let refs = [v.as_slice(); 16];
        let codec = RaBitQCodec::calibrate(&refs, dim, 1);
        let encoded = codec.encode(&v);
        assert!(
            encoded.0.header().residual_norm < 1e-5,
            "residual_norm should be ~0 for vector equal to centroid"
        );
    }
}