use alloc::vec;
use alloc::vec::Vec;
use crate::tier::TemperatureTier;
use crate::traits::Quantizer;
/// Number of sign-flip + Hadamard rounds used by [`RabitqQuantizer::train`].
pub const DEFAULT_ROUNDS: u8 = 3;
/// Bytes of per-code metadata serialized after the packed sign bits: the residual
/// `norm` and the `dot_corr` factor, each written as a little-endian `f32`.
pub const CORRECTION_BYTES: usize = 2 * core::mem::size_of::<f32>();
/// One-bit-per-coordinate vector quantizer in the style of RaBitQ.
///
/// Inputs are shifted by `centroid`, zero-padded to `padded_dim`, and passed through
/// `rounds` rounds of pseudo-random sign flips followed by a normalized fast
/// Walsh-Hadamard transform; only the signs of the result (plus two correction
/// floats) are stored.
#[derive(Clone, Debug)]
pub struct RabitqQuantizer {
/// Length of the vectors accepted by `encode_code` and `prepare_query`.
pub dim: usize,
/// `dim` rounded up to the next power of two; the length of every rotated vector.
pub padded_dim: usize,
/// Seed mixed into the per-round sign-flip pattern; different seeds give different rotations.
pub seed: u64,
/// Number of sign-flip + Hadamard rounds applied by `rotate` (clamped to >= 1 by `with_centroid`).
pub rounds: u8,
/// Offset subtracted from every input before rotation (`dim` entries); `train` sets it
/// to the per-coordinate mean of the training vectors.
pub centroid: Vec<f32>,
}
/// Compressed form of one vector, as produced by `RabitqQuantizer::encode_code`.
#[derive(Clone, Debug, PartialEq)]
pub struct RabitqCode {
/// Packed sign bits of the rotated residual, one per coordinate of `padded_dim`,
/// least-significant bit first within each byte; a set bit means the coordinate was `>= 0`.
pub bits: Vec<u8>,
/// Euclidean norm of the rotated residual (input minus centroid, then rotated).
pub norm: f32,
/// `sum(|x_i|) / (norm * sqrt(padded_dim))` over the rotated residual, used to rescale
/// sign-based inner-product estimates; `1.0` when `norm` is at most `f32::EPSILON`.
pub dot_corr: f32,
}
impl RabitqCode {
    /// Serialized size of this code in bytes: the packed sign bits followed by
    /// [`CORRECTION_BYTES`] of trailing `f32` metadata.
    #[inline]
    pub fn stored_bytes(&self) -> usize {
        let packed_sign_bytes = self.bits.len();
        packed_sign_bytes + CORRECTION_BYTES
    }
}
/// A query preprocessed once by `RabitqQuantizer::prepare_query` so it can be scored
/// against many codes with `RabitqQuantizer::estimate_l2_sq`.
#[derive(Clone, Debug)]
pub struct RabitqQuery {
/// Query minus the centroid, rotated by the quantizer; has `padded_dim` entries.
pub rotated: Vec<f32>,
/// Squared Euclidean norm of `rotated`.
pub norm_sq: f32,
}
/// SplitMix64 finalizer: advances `x` by the golden-ratio increment and scrambles
/// the result into a well-mixed 64-bit output.
#[inline]
fn splitmix64(x: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;
    let state = x.wrapping_add(GOLDEN_GAMMA);
    let state = (state ^ (state >> 30)).wrapping_mul(MIX_A);
    let state = (state ^ (state >> 27)).wrapping_mul(MIX_B);
    state ^ (state >> 31)
}
/// Smallest power of two that is `>= n`, treating `n == 0` as `1`.
#[inline]
fn next_pow2(n: usize) -> usize {
    if n <= 1 {
        1
    } else {
        n.next_power_of_two()
    }
}
/// In-place, unnormalized fast Walsh-Hadamard transform.
///
/// Each pass combines pairs `(v[lo], v[lo + span])` into `(a + b, a - b)` for
/// doubling `span`. `v.len()` must be a power of two (or 0); other lengths
/// index past the end and panic.
fn fwht(v: &mut [f32]) {
    let len = v.len();
    let mut span = 1;
    while span < len {
        for base in (0..len).step_by(span * 2) {
            for lo in base..base + span {
                let hi = lo + span;
                let (a, b) = (v[lo], v[hi]);
                v[lo] = a + b;
                v[hi] = a - b;
            }
        }
        span *= 2;
    }
}
impl RabitqQuantizer {
    /// Builds a quantizer whose centroid is the per-coordinate mean of `vectors`.
    ///
    /// The mean is accumulated in `f64` before being narrowed to `f32`. The quantizer
    /// uses [`DEFAULT_ROUNDS`] rotation rounds seeded by `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `vectors` is empty, if the first vector is empty, or if any vector's
    /// length differs from the first one's.
    pub fn train(vectors: &[&[f32]], seed: u64) -> Self {
        assert!(!vectors.is_empty(), "need at least one training vector");
        let dim = vectors[0].len();
        assert!(dim > 0, "vector dimensionality must be > 0");
        let mut sums = vec![0.0f64; dim];
        for v in vectors {
            assert_eq!(v.len(), dim, "dimension mismatch in training data");
            for (acc, &x) in sums.iter_mut().zip(v.iter()) {
                *acc += f64::from(x);
            }
        }
        let inv_n = 1.0 / vectors.len() as f64;
        let centroid: Vec<f32> = sums.iter().map(|&s| (s * inv_n) as f32).collect();
        Self::with_centroid(dim, centroid, seed, DEFAULT_ROUNDS)
    }

    /// Builds a quantizer from an explicit centroid.
    ///
    /// `padded_dim` becomes `dim` rounded up to a power of two (the Hadamard transform
    /// length). `rounds` is clamped to at least 1.
    ///
    /// # Panics
    ///
    /// Panics if `centroid.len() != dim`.
    pub fn with_centroid(dim: usize, centroid: Vec<f32>, seed: u64, rounds: u8) -> Self {
        assert_eq!(centroid.len(), dim, "centroid length must equal dim");
        Self {
            dim,
            padded_dim: next_pow2(dim),
            seed,
            rounds: rounds.max(1),
            centroid,
        }
    }

    /// Returns the 64 pseudo-random sign bits that `round` uses for coordinates
    /// `64 * block .. 64 * block + 64`; bit `k` governs coordinate `64 * block + k`.
    #[inline]
    fn sign_word(&self, round: u8, block: usize) -> u64 {
        splitmix64(
            self.seed
                ^ u64::from(round).wrapping_mul(0xA076_1D64_78BD_642F)
                ^ (block as u64).wrapping_mul(0xE703_7ED1_A0B4_28DB),
        )
    }

    /// Negates every coordinate of `buf` whose sign bit is set for `round`.
    ///
    /// One hash word covers 64 coordinates, so the hash runs once per 64-element chunk
    /// instead of once per coordinate.
    fn apply_sign_flips(&self, round: u8, buf: &mut [f32]) {
        for (block, chunk) in buf.chunks_mut(64).enumerate() {
            let word = self.sign_word(round, block);
            for (bit, x) in chunk.iter_mut().enumerate() {
                if (word >> bit) & 1 == 1 {
                    *x = -*x;
                }
            }
        }
    }

    /// Applies the seeded random orthonormal rotation to `v`.
    ///
    /// `v` is zero-padded to `padded_dim`. Each round applies a pseudo-random sign
    /// flip followed by a Walsh-Hadamard transform scaled by `1/sqrt(padded_dim)`, so
    /// every round (and hence the whole rotation) preserves norms and inner products.
    ///
    /// # Panics
    ///
    /// Panics if `v.len() > padded_dim`.
    pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
        debug_assert!(v.len() <= self.padded_dim);
        let mut buf = vec![0.0f32; self.padded_dim];
        buf[..v.len()].copy_from_slice(v);
        let scale = 1.0 / (self.padded_dim as f32).sqrt();
        for round in 0..self.rounds {
            self.apply_sign_flips(round, &mut buf);
            fwht(&mut buf);
            buf.iter_mut().for_each(|x| *x *= scale);
        }
        buf
    }

    /// Inverts [`Self::rotate`] for a vector of length `padded_dim`.
    ///
    /// Each round is undone in reverse order: the scaled Hadamard transform is its own
    /// inverse, and so is a sign flip. The padding coordinates are returned as well.
    pub fn rotate_inverse(&self, v: &[f32]) -> Vec<f32> {
        debug_assert_eq!(v.len(), self.padded_dim);
        let mut buf = v.to_vec();
        let scale = 1.0 / (self.padded_dim as f32).sqrt();
        for round in (0..self.rounds).rev() {
            fwht(&mut buf);
            buf.iter_mut().for_each(|x| *x *= scale);
            self.apply_sign_flips(round, &mut buf);
        }
        buf
    }

    /// Subtracts the centroid from `v` and rotates the residual.
    fn center_and_rotate(&self, v: &[f32]) -> Vec<f32> {
        let centered: Vec<f32> = v
            .iter()
            .zip(self.centroid.iter())
            .map(|(&x, &c)| x - c)
            .collect();
        self.rotate(&centered)
    }

    /// Quantizes `vector` to one sign bit per rotated coordinate plus two correction floats.
    ///
    /// `norm` is the length of the rotated residual. `dot_corr` is
    /// `sum(|x|) / (norm * sqrt(padded_dim))`, the inner product between the unit residual
    /// and its normalized sign vector. It falls back to `1.0` when the residual is
    /// (near-)zero, and is otherwise floored at `f32::EPSILON` so later divisions stay finite.
    ///
    /// # Panics
    ///
    /// Panics if `vector.len() != dim`.
    pub fn encode_code(&self, vector: &[f32]) -> RabitqCode {
        assert_eq!(vector.len(), self.dim, "vector dimension mismatch");
        let rotated = self.center_and_rotate(vector);
        let mut norm_sq = 0.0f32;
        let mut abs_sum = 0.0f32;
        let mut bits = vec![0u8; self.padded_dim.div_ceil(8)];
        for (d, &x) in rotated.iter().enumerate() {
            norm_sq += x * x;
            abs_sum += x.abs();
            // Bits are packed LSB-first; a set bit records a non-negative coordinate.
            if x >= 0.0 {
                bits[d / 8] |= 1 << (d % 8);
            }
        }
        let norm = norm_sq.sqrt();
        let dot_corr = if norm > f32::EPSILON {
            (abs_sum / (norm * (self.padded_dim as f32).sqrt())).max(f32::EPSILON)
        } else {
            1.0
        };
        RabitqCode {
            bits,
            norm,
            dot_corr,
        }
    }

    /// Centers and rotates `query` once so it can be scored against many codes.
    ///
    /// # Panics
    ///
    /// Panics if `query.len() != dim`.
    pub fn prepare_query(&self, query: &[f32]) -> RabitqQuery {
        assert_eq!(query.len(), self.dim, "query dimension mismatch");
        let rotated = self.center_and_rotate(query);
        let norm_sq = rotated.iter().map(|&x| x * x).sum();
        RabitqQuery { rotated, norm_sq }
    }

    /// Estimates the squared L2 distance between the prepared query and the vector behind `code`.
    ///
    /// Expands `||o||^2 + ||q||^2 - 2<o, q>` in the centered, rotated space. The inner
    /// product is estimated from the sign bits as
    /// `norm * <sign/sqrt(D), q> / dot_corr`. The result is not clamped, so it can
    /// come out slightly negative for near-identical vectors.
    pub fn estimate_l2_sq(&self, query: &RabitqQuery, code: &RabitqCode) -> f32 {
        debug_assert_eq!(
            query.rotated.len(),
            self.padded_dim,
            "query was prepared by a quantizer with a different padded_dim"
        );
        debug_assert!(
            code.bits.len() * 8 >= query.rotated.len(),
            "code holds fewer sign bits than the query has coordinates"
        );
        let mut signed_sum = 0.0f32;
        for (d, &x) in query.rotated.iter().enumerate() {
            if (code.bits[d / 8] >> (d % 8)) & 1 == 1 {
                signed_sum += x;
            } else {
                signed_sum -= x;
            }
        }
        let est_ip = code.norm * (signed_sum / (self.padded_dim as f32).sqrt()) / code.dot_corr;
        code.norm * code.norm + query.norm_sq - 2.0 * est_ip
    }

    /// Serialized size of one code: packed sign bits plus [`CORRECTION_BYTES`].
    #[inline]
    pub fn stored_bytes_per_vector(&self) -> usize {
        self.padded_dim.div_ceil(8) + CORRECTION_BYTES
    }

    /// Ratio between raw `f32` storage (`dim * 4` bytes) and the serialized code size.
    #[inline]
    pub fn compression_ratio(&self) -> f32 {
        (self.dim * 4) as f32 / self.stored_bytes_per_vector() as f32
    }

    /// Serializes `code` as its sign bytes followed by `norm` and `dot_corr` (little-endian `f32`).
    pub fn code_to_bytes(&self, code: &RabitqCode) -> Vec<u8> {
        let mut out = Vec::with_capacity(code.stored_bytes());
        out.extend_from_slice(&code.bits);
        out.extend_from_slice(&code.norm.to_le_bytes());
        out.extend_from_slice(&code.dot_corr.to_le_bytes());
        out
    }

    /// Parses a code written by [`Self::code_to_bytes`].
    ///
    /// Returns `None` when `data` is shorter than [`Self::stored_bytes_per_vector`].
    /// Any bytes beyond that length are ignored.
    pub fn code_from_bytes(&self, data: &[u8]) -> Option<RabitqCode> {
        let nbits = self.padded_dim.div_ceil(8);
        if data.len() < nbits + CORRECTION_BYTES {
            return None;
        }
        let bits = data[..nbits].to_vec();
        let norm = f32::from_le_bytes(data[nbits..nbits + 4].try_into().ok()?);
        let dot_corr = f32::from_le_bytes(data[nbits + 4..nbits + 8].try_into().ok()?);
        Some(RabitqCode {
            bits,
            norm,
            dot_corr,
        })
    }
}
impl Quantizer for RabitqQuantizer {
    /// Serializes `vector` as packed sign bits followed by the `norm` and
    /// `dot_corr` correction floats.
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let code = self.encode_code(vector);
        self.code_to_bytes(&code)
    }

    /// Reconstructs an approximation of the vector serialized in `codes`.
    ///
    /// Every rotated coordinate is set to `+/- norm * dot_corr / sqrt(padded_dim)`
    /// according to its sign bit. The result is then rotated back, truncated to `dim`
    /// coordinates, and shifted by the centroid. Input too short to hold a code decodes
    /// to an all-zero vector of length `dim`.
    fn decode(&self, codes: &[u8]) -> Vec<f32> {
        let Some(code) = self.code_from_bytes(codes) else {
            return vec![0.0; self.dim];
        };
        let magnitude = code.norm * code.dot_corr / (self.padded_dim as f32).sqrt();
        let signed: Vec<f32> = (0..self.padded_dim)
            .map(|d| {
                if (code.bits[d / 8] >> (d % 8)) & 1 == 1 {
                    magnitude
                } else {
                    -magnitude
                }
            })
            .collect();
        self.rotate_inverse(&signed)
            .into_iter()
            .take(self.dim)
            .zip(self.centroid.iter())
            .map(|(r, &c)| r + c)
            .collect()
    }

    /// Reports this quantizer as belonging to the cold tier.
    fn tier(&self) -> TemperatureTier {
        TemperatureTier::Cold
    }

    /// Length of the vectors this quantizer encodes and decodes.
    fn dim(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector whose coordinates lie on a uniform grid in `[-0.5, 0.5)`.
    fn lcg_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut x = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
        (0..dim)
            .map(|_| {
                x = x
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                // The top 24 bits fit exactly in an f32 mantissa, giving a value in [0, 1)
                // that is then centered on zero.
                ((x >> 40) as f32) / ((1u32 << 24) as f32) - 0.5
            })
            .collect()
    }

    /// Trains a quantizer on `n` deterministic vectors of length `dim` and returns both.
    fn make_quantizer(dim: usize, n: usize) -> (RabitqQuantizer, Vec<Vec<f32>>) {
        let data: Vec<Vec<f32>> = (0..n).map(|i| lcg_vector(dim, i as u64)).collect();
        let refs: Vec<&[f32]> = data.iter().map(|v| v.as_slice()).collect();
        (RabitqQuantizer::train(&refs, 0xDEAD_BEEF), data)
    }

    /// Exact squared Euclidean distance.
    fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    #[test]
    fn rotation_is_orthonormal_and_deterministic() {
        let (rq, data) = make_quantizer(100, 8);
        assert_eq!(rq.padded_dim, 128);
        for v in &data {
            let r1 = rq.rotate(v);
            let r2 = rq.rotate(v);
            assert_eq!(r1, r2, "rotation must be deterministic");
            let norm_in: f32 = v.iter().map(|x| x * x).sum();
            let norm_out: f32 = r1.iter().map(|x| x * x).sum();
            assert!(
                (norm_in - norm_out).abs() < 1e-3 * norm_in.max(1.0),
                "rotation must preserve norms: {norm_in} vs {norm_out}"
            );
            let back = rq.rotate_inverse(&r1);
            for (d, (&orig, &rec)) in v.iter().zip(back.iter()).enumerate() {
                assert!(
                    (orig - rec).abs() < 1e-4,
                    "dim {d}: {orig} != {rec} after inverse rotation"
                );
            }
            for &pad in &back[v.len()..] {
                assert!(pad.abs() < 1e-4, "padding must invert to ~0");
            }
        }
    }

    #[test]
    fn rotation_preserves_inner_products() {
        let (rq, data) = make_quantizer(64, 4);
        let a = &data[0];
        let b = &data[1];
        let ip: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let ra = rq.rotate(a);
        let rb = rq.rotate(b);
        let rip: f32 = ra.iter().zip(rb.iter()).map(|(x, y)| x * y).sum();
        assert!((ip - rip).abs() < 1e-3, "ip {ip} vs rotated ip {rip}");
    }

    #[test]
    fn different_seeds_give_different_rotations() {
        let v = lcg_vector(32, 7);
        let a = RabitqQuantizer::with_centroid(32, vec![0.0; 32], 1, DEFAULT_ROUNDS);
        let b = RabitqQuantizer::with_centroid(32, vec![0.0; 32], 2, DEFAULT_ROUNDS);
        assert_ne!(a.rotate(&v), b.rotate(&v));
    }

    #[test]
    fn code_round_trip_bytes() {
        let (rq, data) = make_quantizer(48, 16);
        for v in &data {
            let code = rq.encode_code(v);
            let bytes = rq.code_to_bytes(&code);
            assert_eq!(bytes.len(), rq.stored_bytes_per_vector());
            let back = rq.code_from_bytes(&bytes).expect("decode");
            assert_eq!(back, code);
        }
        let code = rq.encode_code(&data[0]);
        let bytes = rq.code_to_bytes(&code);
        assert!(rq.code_from_bytes(&bytes[..bytes.len() - 1]).is_none());
        assert!(rq.code_from_bytes(&[]).is_none());
    }

    #[test]
    fn decode_reconstruction_beats_naive_sign_bits() {
        let (rq, data) = make_quantizer(128, 64);
        let mut rabitq_err = 0.0f64;
        let mut naive_err = 0.0f64;
        for v in &data {
            let rec = rq.decode(&rq.encode(v));
            rabitq_err += l2_sq(v, &rec) as f64;
            let bits = crate::binary::encode_binary(v);
            let nrec = crate::binary::decode_binary(&bits, v.len());
            naive_err += l2_sq(v, &nrec) as f64;
        }
        assert!(
            rabitq_err < naive_err,
            "RaBitQ reconstruction error {rabitq_err} must beat naive {naive_err}"
        );
    }

    #[test]
    fn estimator_correlates_with_true_distances() {
        let dim = 128;
        let (rq, data) = make_quantizer(dim, 200);
        let codes: Vec<RabitqCode> = data.iter().map(|v| rq.encode_code(v)).collect();
        let mut est = Vec::new();
        let mut truth = Vec::new();
        for qi in 0..20u64 {
            let q = lcg_vector(dim, 5_000 + qi);
            let prepared = rq.prepare_query(&q);
            for (v, code) in data.iter().zip(codes.iter()) {
                est.push(rq.estimate_l2_sq(&prepared, code) as f64);
                truth.push(l2_sq(&q, v) as f64);
            }
        }
        let n = est.len() as f64;
        let me = est.iter().sum::<f64>() / n;
        let mt = truth.iter().sum::<f64>() / n;
        let mut cov = 0.0;
        let mut ve = 0.0;
        let mut vt = 0.0;
        for (&e, &t) in est.iter().zip(truth.iter()) {
            cov += (e - me) * (t - mt);
            ve += (e - me) * (e - me);
            vt += (t - mt) * (t - mt);
        }
        let corr = cov / (ve.sqrt() * vt.sqrt());
        #[cfg(feature = "std")]
        std::eprintln!("estimator/true distance correlation (128d): {corr:.4}");
        assert!(
            corr > 0.8,
            "estimator correlation {corr:.3} too weak (expected > 0.8)"
        );
        let mean_rel: f64 = est
            .iter()
            .zip(truth.iter())
            .map(|(&e, &t)| ((e - t) / t.max(1e-9)).abs())
            .sum::<f64>()
            / n;
        #[cfg(feature = "std")]
        std::eprintln!("estimator mean relative distance error (128d): {mean_rel:.4}");
        assert!(
            mean_rel < 0.25,
            "mean relative error {mean_rel:.3} too large"
        );
    }

    #[test]
    fn compression_ratio_targets() {
        let rq128 = RabitqQuantizer::with_centroid(128, vec![0.0; 128], 1, DEFAULT_ROUNDS);
        assert_eq!(rq128.padded_dim, 128);
        assert_eq!((rq128.dim * 4) / (rq128.padded_dim / 8), 32);
        assert!(rq128.compression_ratio() >= 20.0);
        let rq1024 = RabitqQuantizer::with_centroid(1024, vec![0.0; 1024], 1, DEFAULT_ROUNDS);
        assert!(rq1024.compression_ratio() >= 30.0);
    }

    #[test]
    fn zero_residual_vector_is_safe() {
        let (rq, _) = make_quantizer(16, 4);
        let code = rq.encode_code(&rq.centroid.clone());
        assert!(code.norm <= 1e-6);
        let q = rq.prepare_query(&lcg_vector(16, 99));
        let est = rq.estimate_l2_sq(&q, &code);
        assert!(est.is_finite());
    }
}