use crate::error::Error;
use serde::{Deserialize, Serialize};
/// Per-vector scalars stored alongside a RaBitQ sign code.
///
/// NOTE(review): this type derives serde traits; with a positional format such
/// as postcard (used by `RaBitQIndex::save`), field order and types are part of
/// the encoded representation — keep them stable.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RaBitQCorrection {
    /// Euclidean norm of the encoded vector after subtracting the index
    /// centroid; `0.0` for vectors within `f32::EPSILON` of the centroid.
    pub vector_norm: f32,
    /// Inner product between the rotated unit vector and its `±1/√dim` sign
    /// code, as computed by `RaBitQIndex::encode`; `1.0` for vectors within
    /// `f32::EPSILON` of the centroid.
    ///
    /// NOTE(review): no distance routine in this file reads this value — the
    /// estimate uses only `vector_norm`. Confirm whether that is intended.
    pub quantization_ip: f32,
}
/// A vector compressed by `RaBitQIndex::encode`: one sign bit per dimension
/// plus the scalar corrections recorded at encode time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQVector {
    /// Sign bits of the centered, normalized, rotated vector packed into
    /// `dimension.div_ceil(64)` words: bit `i` is bit `i % 64` of word `i / 64`
    /// and is set when rotated component `i` is `>= 0.0`. Padding bits are zero.
    pub bits: Vec<u64>,
    /// Norm and quantization inner product captured during encoding.
    pub correction: RaBitQCorrection,
}
/// Trained state of a RaBitQ quantizer: a rotation matrix and the centroid that
/// inputs are centered on before encoding.
///
/// Persisted with postcard by `save`/`load` (behind the `persistence` feature);
/// postcard encodes fields positionally, so keep field order stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQIndex {
    /// `dimension × dimension` matrix stored row-major: row `i` occupies
    /// `rotation[i * dimension..(i + 1) * dimension]`. `train` fills it with a
    /// seeded Gram–Schmidt–orthonormalized matrix.
    pub rotation: Vec<f32>,
    /// Vector subtracted from every input before normalization; `train` sets it
    /// to the mean of the training vectors. Length `dimension`.
    pub centroid: Vec<f32>,
    /// Length of the vectors this index accepts.
    pub dimension: usize,
}
/// Packs the signs of the first `dim` entries of `values` into 64-bit words.
///
/// Entry `i` maps to bit `i % 64` of word `i / 64`; the bit is set when the
/// entry is `>= 0.0` (so `-0.0` sets it and `NaN` does not). The result always
/// has `dim.div_ceil(64)` words: entries missing from a short `values` slice and
/// the padding bits past `dim` stay zero.
#[must_use]
pub(crate) fn signs_to_bits(values: &[f32], dim: usize) -> Vec<u64> {
    let mut words = vec![0u64; dim.div_ceil(64)];
    values
        .iter()
        .take(dim)
        .enumerate()
        .filter(|&(_, &x)| x >= 0.0)
        .for_each(|(idx, _)| words[idx / 64] |= 1u64 << (idx % 64));
    words
}
/// Multiplies the row-major `dim × dim` matrix `rotation` by `vector`.
///
/// Output element `i` is `dot_product_native(row_i, vector)`, where `row_i` is
/// `rotation[i * dim..(i + 1) * dim]`. The result has exactly `dim` elements.
///
/// NOTE(review): behaviour when `vector.len() != dim` depends on
/// `dot_product_native` — confirm callers always pass `dim` elements.
///
/// # Panics
/// Panics if `rotation` holds fewer than `dim * dim` elements.
#[must_use]
pub(crate) fn apply_rotation_flat(rotation: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
    let mut product = Vec::with_capacity(dim);
    for row in 0..dim {
        let offset = row * dim;
        let row_slice = &rotation[offset..offset + dim];
        product.push(crate::simd_native::dot_product_native(row_slice, vector));
    }
    product
}
/// Inner product of two sign codes, each read as a vector of `±1/√dim`
/// entries (bit set → `+`, bit clear → `-`).
///
/// With `matching` equal bits among the `dim` meaningful ones, the result is
/// `(matching - differing) / dim = (2·matching - dim) / dim`, which lies in
/// `[-1, 1]`. Only the first `num_words` words of each slice are compared.
///
/// # Panics
/// Panics if either slice is shorter than `num_words`. In debug builds, also
/// panics if the codes differ in more than `dim` bits (i.e. padding bits set).
fn xor_popcount_ip(q_bits: &[u64], enc_bits: &[u64], num_words: usize, dim: usize) -> f32 {
    let differing_bits =
        crate::simd_native::hamming_binary_native(&q_bits[..num_words], &enc_bits[..num_words]);
    // Dimensions are vector lengths and never approach u32::MAX in practice.
    #[allow(clippy::cast_possible_truncation)]
    let dim_u32 = dim as u32;
    debug_assert!(
        differing_bits <= dim_u32,
        "differing_bits ({differing_bits}) > dim ({dim}): \
         signs_to_bits must zero padding bits"
    );
    // Saturate instead of subtracting directly: a malformed code (e.g. a
    // deserialized one with padding bits set) would otherwise wrap around in
    // release builds and yield an enormous, meaningless similarity.
    let matching_bits = dim_u32.saturating_sub(differing_bits);
    #[allow(clippy::cast_precision_loss)]
    let d_f = dim as f32;
    #[allow(clippy::cast_precision_loss)]
    let matching_f = matching_bits as f32;
    2.0f32.mul_add(matching_f, -d_f) / d_f
}
/// A query that has been centered, normalized, rotated and sign-quantized once,
/// so it can be compared against many encoded vectors without redoing that work.
pub(crate) struct PreparedQuery {
    /// Rotated unit-length query at full precision (`dimension` elements).
    pub(crate) rotated: Vec<f32>,
    /// Sign bits of `rotated`, packed as produced by `signs_to_bits`.
    pub(crate) bits: Vec<u64>,
    /// Number of 64-bit words in `bits` (`dimension.div_ceil(64)`).
    pub(crate) num_words: usize,
    /// Euclidean norm of the centered query.
    pub(crate) norm: f32,
    /// Squared Euclidean norm of the centered query.
    pub(crate) norm_sq: f32,
}
impl RaBitQIndex {
    /// Centers `vector` on the centroid, normalizes it, rotates it and packs
    /// the signs of the result.
    ///
    /// Returns `None` when the centered vector's Euclidean norm is below
    /// `f32::EPSILON`, since it cannot then be normalized.
    ///
    /// NOTE(review): the length of `vector` is not checked here — `zip` stops
    /// at the shorter of `vector` and `centroid`. Only `encode` validates it.
    pub(crate) fn prepare_query(&self, vector: &[f32]) -> Option<PreparedQuery> {
        let dim = self.dimension;
        let centered: Vec<f32> = vector
            .iter()
            .zip(&self.centroid)
            .map(|(&v, &c)| v - c)
            .collect();
        let norm_sq: f32 = centered.iter().map(|&x| x * x).sum();
        let norm = norm_sq.sqrt();
        if norm < f32::EPSILON {
            return None;
        }
        let unit: Vec<f32> = centered.into_iter().map(|x| x / norm).collect();
        let rotated = apply_rotation_flat(&self.rotation, &unit, dim);
        Some(PreparedQuery {
            bits: signs_to_bits(&rotated, dim),
            num_words: dim.div_ceil(64),
            rotated,
            norm,
            norm_sq,
        })
    }

    /// Estimated Euclidean distance between a prepared query and an encoded
    /// vector; see [`Self::distance_from_prepared_slice`].
    pub(crate) fn distance_from_prepared(&self, pq: &PreparedQuery, encoded: &RaBitQVector) -> f32 {
        let RaBitQVector { bits, correction } = encoded;
        self.distance_from_prepared_slice(pq, bits, *correction)
    }

    /// Estimated Euclidean distance between a prepared query and a code given
    /// as raw `bits` plus its `correction`.
    ///
    /// Uses `‖q−o‖² = ‖q_c‖² + ‖o_c‖² − 2·‖q_c‖·‖o_c‖·ip`, where `q_c`/`o_c` are
    /// the centroid-relative vectors and `ip` is the sign-code inner product
    /// from `xor_popcount_ip`. Negative estimates are clamped to zero before
    /// taking the square root.
    pub(crate) fn distance_from_prepared_slice(
        &self,
        pq: &PreparedQuery,
        bits: &[u64],
        correction: RaBitQCorrection,
    ) -> f32 {
        let code_ip = xor_popcount_ip(&pq.bits, bits, pq.num_words, self.dimension);
        let data_norm = correction.vector_norm;
        let cross_term = pq.norm * data_norm * code_ip;
        let squared = data_norm.mul_add(data_norm, pq.norm_sq) - 2.0 * cross_term;
        squared.max(0.0).sqrt()
    }

    /// Compresses `vector` into a sign code plus correction scalars.
    ///
    /// A vector within `f32::EPSILON` of the centroid is encoded as all-zero
    /// bits with `vector_norm = 0.0` and `quantization_ip = 1.0`.
    ///
    /// # Errors
    /// Returns `Error::InvalidQuantizerConfig` when `vector.len()` differs from
    /// `self.dimension`.
    pub fn encode(&self, vector: &[f32]) -> Result<RaBitQVector, Error> {
        if vector.len() != self.dimension {
            return Err(Error::InvalidQuantizerConfig(format!(
                "RaBitQ encode: expected dimension {}, got {}",
                self.dimension,
                vector.len()
            )));
        }
        match self.prepare_query(vector) {
            None => Ok(RaBitQVector {
                bits: vec![0u64; self.dimension.div_ceil(64)],
                correction: RaBitQCorrection {
                    vector_norm: 0.0,
                    quantization_ip: 1.0,
                },
            }),
            Some(pq) => {
                #[allow(clippy::cast_precision_loss)]
                let scale = 1.0 / (self.dimension as f32).sqrt();
                // <sign_code/√dim, rotated>, accumulated in index order.
                let qip = pq
                    .rotated
                    .iter()
                    .take(self.dimension)
                    .enumerate()
                    .fold(0.0f32, |acc, (i, &rv)| {
                        let bit_set = (pq.bits[i / 64] >> (i % 64)) & 1 == 1;
                        let sign: f32 = if bit_set { 1.0 } else { -1.0 };
                        (sign * scale).mul_add(rv, acc)
                    });
                Ok(RaBitQVector {
                    bits: pq.bits,
                    correction: RaBitQCorrection {
                        vector_norm: pq.norm,
                        quantization_ip: qip,
                    },
                })
            }
        }
    }

    /// Estimated Euclidean distance between `query` and the vector behind
    /// `encoded`.
    ///
    /// If `query` is within `f32::EPSILON` of the centroid, the stored
    /// `vector_norm` (the encoded vector's distance from the centroid) is
    /// returned directly.
    #[must_use]
    pub fn distance(&self, query: &[f32], encoded: &RaBitQVector) -> f32 {
        match self.prepare_query(query) {
            Some(pq) => self.distance_from_prepared(&pq, encoded),
            None => encoded.correction.vector_norm,
        }
    }

    /// [`Self::distance`] for many codes, preparing `query` only once.
    #[must_use]
    pub fn batch_distance(&self, query: &[f32], encoded: &[RaBitQVector]) -> Vec<f32> {
        let prepared = self.prepare_query(query);
        encoded
            .iter()
            .map(|ev| match &prepared {
                Some(pq) => self.distance_from_prepared(pq, ev),
                None => ev.correction.vector_norm,
            })
            .collect()
    }
}
#[cfg(feature = "persistence")]
impl RaBitQIndex {
    /// Fits a RaBitQ index to `vectors`.
    ///
    /// The centroid is the arithmetic mean of the training vectors. The
    /// rotation is a `dimension × dimension` orthogonal matrix generated
    /// deterministically from `seed`.
    ///
    /// # Errors
    /// Returns `Error::InvalidQuantizerConfig` if `vectors` is empty, the first
    /// vector has length zero, or the vectors do not all share one length.
    pub fn train(vectors: &[Vec<f32>], seed: u64) -> Result<Self, Error> {
        if vectors.is_empty() {
            return Err(Error::InvalidQuantizerConfig(
                "cannot train RaBitQ with empty dataset".into(),
            ));
        }
        let dimension = vectors[0].len();
        if dimension == 0 {
            return Err(Error::InvalidQuantizerConfig(
                "vectors must have non-zero dimension".into(),
            ));
        }
        if !vectors.iter().all(|v| v.len() == dimension) {
            return Err(Error::InvalidQuantizerConfig(
                "all vectors must share the same dimension".into(),
            ));
        }
        let mut centroid = vec![0.0f32; dimension];
        for v in vectors {
            for (ci, &vi) in centroid.iter_mut().zip(v.iter()) {
                *ci += vi;
            }
        }
        // Precision loss for very large counts only perturbs the mean slightly.
        #[allow(clippy::cast_precision_loss)]
        let inv_n = 1.0 / vectors.len() as f32;
        for x in &mut centroid {
            *x *= inv_n;
        }
        let rotation = generate_orthogonal_matrix(dimension, seed);
        Ok(Self {
            rotation,
            centroid,
            dimension,
        })
    }

    /// Serializes the index with postcard into `dir/rabitq.idx`.
    ///
    /// The data is first written to `dir/rabitq.idx.tmp` and flushed to stable
    /// storage, then renamed over the final path. A reader therefore never
    /// observes a half-written `rabitq.idx`.
    ///
    /// # Errors
    /// Returns `Error::Io` if serialization fails or any file operation fails.
    pub fn save(&self, dir: &std::path::Path) -> Result<(), Error> {
        use std::io::Write as _;

        let data = postcard::to_allocvec(self).map_err(|e| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("failed to serialize RaBitQ index: {e}"),
            ))
        })?;
        let tmp_path = dir.join("rabitq.idx.tmp");
        let final_path = dir.join("rabitq.idx");
        {
            let mut file = std::fs::File::create(&tmp_path)?;
            file.write_all(&data)?;
            // Without this, the rename below can become durable before the
            // file contents, leaving an empty/truncated index after a crash.
            file.sync_all()?;
        }
        std::fs::rename(&tmp_path, &final_path)?;
        Ok(())
    }

    /// Loads an index previously written by [`Self::save`] from `dir`.
    ///
    /// Returns `Ok(None)` when `dir/rabitq.idx` does not exist.
    ///
    /// # Errors
    /// Returns `Error::Io` if the file cannot be read, cannot be deserialized,
    /// or describes an inconsistent index (zero dimension, or centroid/rotation
    /// lengths that do not match the dimension).
    pub fn load(dir: &std::path::Path) -> Result<Option<Self>, Error> {
        let path = dir.join("rabitq.idx");
        // Read directly and treat NotFound as "no index", avoiding a race
        // between an existence check and the read.
        let data = match std::fs::read(&path) {
            Ok(data) => data,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e.into()),
        };
        let index: Self = postcard::from_bytes(&data).map_err(|e| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("failed to deserialize RaBitQ index: {e}"),
            ))
        })?;
        index.validate_loaded()?;
        Ok(Some(index))
    }

    /// Rejects a deserialized index whose fields are mutually inconsistent.
    ///
    /// Otherwise, slicing the rotation matrix during encoding or querying would
    /// panic (or silently misbehave) on a corrupted or foreign file.
    fn validate_loaded(&self) -> Result<(), Error> {
        let expected_rotation_len = self.dimension.checked_mul(self.dimension);
        let consistent = self.dimension != 0
            && self.centroid.len() == self.dimension
            && expected_rotation_len == Some(self.rotation.len());
        if consistent {
            return Ok(());
        }
        Err(Error::Io(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "malformed RaBitQ index: dimension {}, centroid length {}, rotation length {}",
                self.dimension,
                self.centroid.len(),
                self.rotation.len()
            ),
        )))
    }
}
/// Builds a `dim × dim` orthogonal matrix, row-major, from a seeded RNG.
///
/// `dim` random columns with entries drawn uniformly from `[-1, 1)` are
/// orthonormalized with modified Gram–Schmidt. Each column is normalized and
/// then immediately projected out of every later column. Element `(i, j)` of
/// the result is component `i` of basis column `j`.
///
/// NOTE(review): a column whose residual norm is `<= f32::EPSILON` is left
/// unnormalized, so the result would not be orthogonal in that (astronomically
/// unlikely) case.
#[cfg(feature = "persistence")]
fn generate_orthogonal_matrix(dim: usize, seed: u64) -> Vec<f32> {
    use rand::{Rng, SeedableRng};
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut columns: Vec<Vec<f32>> = (0..dim)
        .map(|_| (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect())
        .collect();
    for j in 0..dim {
        // Split so column `j` can be read while the later columns are mutated
        // in place. This avoids allocating a temporary projection vector for
        // every (j, k) pair.
        let (head, tail) = columns.split_at_mut(j + 1);
        let basis = &mut head[j];
        let norm: f32 = basis.iter().map(|&x| x * x).sum::<f32>().sqrt();
        if norm > f32::EPSILON {
            for x in basis.iter_mut() {
                *x /= norm;
            }
        }
        let basis = &head[j];
        for column in tail.iter_mut() {
            let dot: f32 = basis
                .iter()
                .zip(column.iter())
                .map(|(&a, &b)| a * b)
                .sum();
            for (c, &b) in column.iter_mut().zip(basis.iter()) {
                *c -= dot * b;
            }
        }
    }
    let mut rotation = vec![0.0f32; dim * dim];
    for (j, column) in columns.iter().enumerate() {
        for (i, &value) in column.iter().enumerate() {
            rotation[i * dim + j] = value;
        }
    }
    rotation
}
#[cfg(test)]
#[path = "rabitq_tests.rs"]
mod tests;