use rand::{rngs::StdRng, Rng, SeedableRng};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQCodebook {
pub dim: usize,
pub seed: u64,
#[serde(skip)]
proj: Vec<f32>, }
impl RaBitQCodebook {
    /// Creates a codebook for `dim`-dimensional vectors and immediately builds
    /// its projection matrix from `seed`.
    pub fn new(dim: usize, seed: u64) -> Self {
        let mut cb = Self {
            dim,
            seed,
            proj: Vec::new(),
        };
        cb.rebuild_proj();
        cb
    }

    /// (Re)generates the `dim × dim` projection matrix deterministically from
    /// `seed`.
    ///
    /// Entries are drawn as `gen::<f32>() * 2 - 1`, i.e. in `[-1, 1)`, and the
    /// columns are then orthonormalised in place with modified Gram–Schmidt
    /// (each projection is subtracted from the already-updated column). The
    /// matrix is stored row-major: element `(row, col)` is
    /// `proj[row * dim + col]`.
    ///
    /// Stored codes depend on the exact floating-point values produced here, so
    /// the evaluation order should not be changed casually: near-zero projected
    /// components could flip sign and make old codes incomparable with new ones.
    pub fn rebuild_proj(&mut self) {
        let dim = self.dim;
        let mut rng = StdRng::seed_from_u64(self.seed);
        let mut proj = vec![0.0f32; dim * dim];
        for x in proj.iter_mut() {
            *x = rng.gen::<f32>() * 2.0 - 1.0;
        }
        for col in 0..dim {
            // Remove the component along every previously finished column.
            for prev in 0..col {
                let dot: f32 = (0..dim)
                    .map(|row| proj[row * dim + col] * proj[row * dim + prev])
                    .sum();
                for row in 0..dim {
                    let p = proj[row * dim + prev];
                    proj[row * dim + col] -= dot * p;
                }
            }
            let norm: f32 = (0..dim)
                .map(|row| proj[row * dim + col] * proj[row * dim + col])
                .sum::<f32>()
                .sqrt();
            // Clamp so a numerically degenerate column does not divide by zero.
            let inv = 1.0 / norm.max(1e-12);
            for row in 0..dim {
                proj[row * dim + col] *= inv;
            }
        }
        self.proj = proj;
    }

    /// Returns `true` when the projection matrix has the expected `dim * dim`
    /// length, i.e. it has been built by [`Self::rebuild_proj`].
    pub fn is_ready(&self) -> bool {
        self.proj.len() == self.dim * self.dim
    }

    /// Multiplies `v` by the projection matrix and returns `proj · v`.
    ///
    /// # Panics
    /// Panics if the projection matrix has not been built (see
    /// [`Self::is_ready`]) or if `v.len() != self.dim`. The length check is a
    /// hard assert because a mismatched vector would otherwise be silently
    /// truncated by `zip` in release builds, producing meaningless codes.
    pub fn project(&self, v: &[f32]) -> Vec<f32> {
        assert!(
            self.is_ready(),
            "RaBitQCodebook projection matrix is not built; call rebuild_proj() first"
        );
        assert_eq!(
            v.len(),
            self.dim,
            "vector length does not match codebook dimension"
        );
        let dim = self.dim;
        (0..dim)
            .map(|i| {
                let row = &self.proj[i * dim..(i + 1) * dim];
                row.iter().zip(v).map(|(a, b)| a * b).sum::<f32>()
            })
            .collect()
    }

    /// Shared front half of [`Self::encode`] and [`Self::prepare_query`].
    ///
    /// Normalises `v` to unit length (vectors with norm ≤ 1e-12 are projected
    /// unchanged), projects it, and computes the scale factor
    /// `Σ|p_i| / √dim` (0 when `dim == 0`, instead of `0/0 = NaN`).
    ///
    /// Returns `(norm_of_v, projection, scale)`.
    fn normalized_projection(&self, v: &[f32]) -> (f32, Vec<f32>, f32) {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let projected = if norm > 1e-12 {
            let unit: Vec<f32> = v.iter().map(|x| x / norm).collect();
            self.project(&unit)
        } else {
            self.project(v)
        };
        let scale = if self.dim == 0 {
            0.0
        } else {
            projected.iter().map(|x| x.abs()).sum::<f32>() / (self.dim as f32).sqrt()
        };
        (norm, projected, scale)
    }

    /// Encodes `v` as the sign bits of its normalised projection, together with
    /// the original norm of `v` and the projection's scale factor.
    ///
    /// # Panics
    /// Same conditions as [`Self::project`].
    pub fn encode(&self, v: &[f32]) -> RaBitQVec {
        let (norm, projected, scale) = self.normalized_projection(v);
        RaBitQVec {
            code: bits_from_signs(&projected),
            norm,
            scale,
        }
    }

    /// Prepares a query: returns its normalised projection and scale factor,
    /// computed exactly as in [`Self::encode`] (the query norm is discarded).
    ///
    /// # Panics
    /// Same conditions as [`Self::project`].
    pub fn prepare_query(&self, q: &[f32]) -> (Vec<f32>, f32) {
        let (_, projected, scale) = self.normalized_projection(q);
        (projected, scale)
    }

    /// Scores `entry` against pre-packed query sign bits `b_q`.
    ///
    /// Returns `(1 − 2·h/dim) · q_scale · entry.scale`, where `h` is the Hamming
    /// distance between `b_q` and `entry.code`. This is intended as an estimate
    /// of the inner product of the *normalised* inputs; `entry.norm` is not
    /// applied. Returns 0 for a zero-dimensional codebook.
    pub fn estimate_ip_binary(&self, b_q: &[u8], q_scale: f32, entry: &RaBitQVec) -> f32 {
        debug_assert_eq!(
            b_q.len(),
            entry.code.len(),
            "query and entry codes have different lengths"
        );
        if self.dim == 0 {
            return 0.0;
        }
        let hamming: u32 = b_q
            .iter()
            .zip(&entry.code)
            .map(|(a, b)| (a ^ b).count_ones())
            .sum();
        (1.0 - 2.0 * hamming as f32 / self.dim as f32) * q_scale * entry.scale
    }

    /// Convenience wrapper around [`Self::estimate_ip_binary`] that packs the
    /// query's sign bits on every call. When scoring many entries against the
    /// same query, pack once with [`bits_from_signs`] and call
    /// `estimate_ip_binary` directly.
    pub fn estimate_ip(&self, q_proj: &[f32], q_scale: f32, entry: &RaBitQVec) -> f32 {
        let b_q = bits_from_signs(q_proj);
        self.estimate_ip_binary(&b_q, q_scale, entry)
    }
}
/// One vector encoded by `RaBitQCodebook::encode`.
// Field names and order form the serialised format; do not reorder or rename.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQVec {
/// Packed sign bits of the projected, normalised vector: bit `i % 8` of byte
/// `i / 8` is set when projected component `i` is strictly positive.
pub code: Vec<u8>,
/// Euclidean norm of the original (un-normalised) input vector.
pub norm: f32,
/// Scale factor `Σ|p_i| / √dim` of the projected vector `p`; multiplied into
/// the inner-product estimate by the codebook's estimators.
pub scale: f32,
}
/// Packs the signs of `v` into an LSB-first bitset.
///
/// Bit `i % 8` of byte `i / 8` is set iff `v[i] > 0.0`; zero, negative zero and
/// NaN all yield a cleared bit. The output has `ceil(v.len() / 8)` bytes, and
/// any unused high bits of the final byte are zero.
pub fn bits_from_signs(v: &[f32]) -> Vec<u8> {
    v.chunks(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .filter(|&(_, &value)| value > 0.0)
                .fold(0u8, |byte, (bit, _)| byte | (1u8 << bit))
        })
        .collect()
}
/// Encodes every vector in `vectors` with `codebook`, in parallel on the rayon
/// thread pool. The returned encodings are in the same order as the input.
pub fn encode_batch(codebook: &RaBitQCodebook, vectors: &[Vec<f32>]) -> Vec<RaBitQVec> {
    let encode_one = |vector: &Vec<f32>| codebook.encode(vector);
    vectors.par_iter().map(encode_one).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` and a manual `rebuild_proj` with the same parameters must agree.
    #[test]
    fn codebook_rebuild_is_deterministic() {
        let cb1 = RaBitQCodebook::new(16, 42);
        let mut cb2 = RaBitQCodebook {
            dim: 16,
            seed: 42,
            proj: vec![],
        };
        cb2.rebuild_proj();
        assert_eq!(cb1.proj, cb2.proj);
    }

    /// Encoding the same vector twice must produce identical codes.
    #[test]
    fn encode_decode_roundtrip_similar_vectors() {
        let dim = 32usize;
        let cb = RaBitQCodebook::new(dim, 99);
        let v: Vec<f32> = (0..dim).map(|i| (i as f32).cos()).collect();
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let v: Vec<f32> = v.iter().map(|x| x / norm).collect();
        let e1 = cb.encode(&v);
        let e2 = cb.encode(&v);
        assert_eq!(e1.code, e2.code);
    }

    #[test]
    fn ip_estimate_identical_vectors() {
        let dim = 64usize;
        let cb = RaBitQCodebook::new(dim, 7);
        let v: Vec<f32> = (0..dim)
            .map(|i| if i % 3 == 0 { 1.0 } else { -0.5 })
            .collect();
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let v: Vec<f32> = v.iter().map(|x| x / norm).collect();
        let entry = cb.encode(&v);
        let (q_proj, q_scale) = cb.prepare_query(&v);
        let ip = cb.estimate_ip(&q_proj, q_scale, &entry);
        assert!(
            ip > 0.4,
            "expected IP estimate > 0.4 for identical unit vectors, got {ip}"
        );

        // The first basis vector has cosine ≈ 0.18 with `v`, so scoring it
        // against the same query must give a strictly lower estimate.
        let v2: Vec<f32> = (0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
        let entry2 = cb.encode(&v2);
        let ip_diff = cb.estimate_ip(&q_proj, q_scale, &entry2);
        assert!(
            ip_diff < ip,
            "expected dissimilar vector to score below identical one: {ip_diff} vs {ip}"
        );
    }

    #[test]
    fn ip_estimate_orthogonal_vectors() {
        let dim = 128usize;
        let cb = RaBitQCodebook::new(dim, 13);
        let mut a = vec![0.0f32; dim];
        let mut b = vec![0.0f32; dim];
        a[0] = 1.0;
        b[1] = 1.0;
        let entry = cb.encode(&b);
        let (q_proj, q_scale) = cb.prepare_query(&a);
        let ip = cb.estimate_ip(&q_proj, q_scale, &entry);
        assert!(
            ip.abs() < 0.3,
            "expected IP estimate ≈ 0 for orthogonal vectors, got {ip}"
        );
    }

    /// `estimate_ip` must be exactly `estimate_ip_binary` on pre-packed bits.
    #[test]
    fn estimate_ip_matches_binary_variant() {
        let dim = 24usize;
        let cb = RaBitQCodebook::new(dim, 3);
        let v: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.7).sin()).collect();
        let q: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.3).cos()).collect();
        let entry = cb.encode(&v);
        let (q_proj, q_scale) = cb.prepare_query(&q);
        let b_q = bits_from_signs(&q_proj);
        assert_eq!(
            cb.estimate_ip(&q_proj, q_scale, &entry),
            cb.estimate_ip_binary(&b_q, q_scale, &entry)
        );
    }

    #[test]
    fn bits_from_signs_basic() {
        let v = vec![1.0f32, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let code = bits_from_signs(&v);
        assert_eq!(code.len(), 1);
        assert_eq!(code[0], 0x55);
    }

    /// Non-positive values (including -0.0 and NaN) clear their bit, and a
    /// length that is not a multiple of 8 yields a zero-padded final byte.
    #[test]
    fn bits_from_signs_partial_byte_and_non_positive() {
        let v = [0.0f32, -0.0, f32::NAN, 2.0, 0.5, -3.0, 0.0, 0.0, 1.0, -1.0];
        let code = bits_from_signs(&v);
        assert_eq!(code, vec![0b0001_1000, 0b0000_0001]);
    }
}