use faer::{Mat, MatRef};
use faer_traits::ComplexField;
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use rand_distr::StandardNormal;
use std::iter::Sum;
use crate::binary::dist_binary::*;
use crate::prelude::*;
use crate::utils::k_means_utils::*;
/// Value handed to `train_centroids` when a `RaBitQQuantiser` is built;
/// judging by the name, the k-means iteration budget.
const RABITQ_K_MEANS_ITER: usize = 30;
/// A query encoded against one centroid, as produced by
/// `RaBitQEncoder::encode_query`.
///
/// The rotated, unit-length residual of the query is scalar-quantised so that
/// component `d` is approximately `lower + quantised[d] * width`.
#[repr(C)]
pub struct RaBitQQuery<T> {
/// One quantisation level per dimension, each in `0..=15`.
pub quantised: Vec<u8>,
/// Norm (via `compute_l2_norm`) of the query residual before it was scaled
/// to unit length.
pub dist_to_centroid: T,
/// Smallest component of the rotated unit residual; the value that level 0
/// stands for.
pub lower: T,
/// Quantisation step: `(max - min) / 15`, or `1` when the component range
/// is not above `T::epsilon()`.
pub width: T,
/// Sum of all entries of `quantised`.
pub sum_quantised: u32,
}
/// Result of `RaBitQEncoder::encode_vector`:
/// `(binary code, distance to centroid, dot correction, popcount)`.
pub type VecEncoding<T> = (Vec<u8>, T, T, u32);
/// Encoder that rotates unit-length residuals with a fixed random orthogonal
/// matrix and reduces them to sign bits (data vectors) or 4-bit levels
/// (queries).
pub struct RaBitQEncoder<T> {
/// Rotation matrix flattened row by row; row `i` occupies
/// `rotation[i * dim..(i + 1) * dim]`.
pub rotation: Vec<T>,
/// Dimensionality of the vectors this encoder handles.
pub dim: usize,
/// Bytes per binary code: `dim.div_ceil(8)`.
pub n_bytes: usize,
/// Distance metric; for `Dist::Cosine` queries are scaled to unit length
/// before the centroid is subtracted.
pub metric: Dist,
}
impl<T> RaBitQEncoder<T>
where
    T: Float + FromPrimitive + ToPrimitive + ComplexField + SimdDistance,
{
    /// Builds an encoder for `dim`-dimensional vectors.
    ///
    /// The rotation is the Q factor of the QR decomposition of a `dim x dim`
    /// matrix of standard-normal samples drawn from an RNG seeded with `seed`.
    pub fn new(dim: usize, metric: Dist, seed: u64) -> Self {
        Self {
            rotation: Self::generate_random_orthogonal(dim, seed),
            dim,
            n_bytes: dim.div_ceil(8),
            metric,
        }
    }

    /// Encodes a data vector relative to `centroid`.
    ///
    /// The residual `vec - centroid` is scaled to unit length, rotated and
    /// reduced to one sign bit per dimension.
    ///
    /// Returns `(binary, dist_to_centroid, dot_correction, popcount)` where
    /// * `binary` has `n_bytes` bytes; bit `d % 8` of byte `d / 8` is set when
    ///   rotated component `d` is `>= 0`,
    /// * `dist_to_centroid` is `compute_l2_norm` of the residual,
    /// * `dot_correction` is `compute_l1_norm` of the rotated unit residual,
    /// * `popcount` is the number of bits set in `binary`.
    ///
    /// When the residual norm is not above `T::epsilon()` an all-zero vector
    /// is rotated instead of the (undefined) unit residual.
    #[inline]
    pub fn encode_vector(&self, vec: &[T], centroid: &[T]) -> VecEncoding<T> {
        let res = T::subtract_simd(vec, centroid);
        let dist_to_centroid = compute_l2_norm(&res);

        let v_c: Vec<T> = if dist_to_centroid > T::epsilon() {
            res.iter().map(|&r| r / dist_to_centroid).collect()
        } else {
            vec![T::zero(); self.dim]
        };
        let v_c_rotated = self.apply_rotation(&v_c);

        let mut binary = vec![0u8; self.n_bytes];
        let mut popcount: u32 = 0;
        for (d, &component) in v_c_rotated.iter().enumerate() {
            // Zero counts as non-negative, so it sets the bit as well.
            if component >= T::zero() {
                binary[d / 8] |= 1u8 << (d % 8);
                popcount += 1;
            }
        }

        let dot_correction: T = compute_l1_norm(&v_c_rotated);
        (binary, dist_to_centroid, dot_correction, popcount)
    }

    /// Encodes a query relative to `centroid` using 16 quantisation levels
    /// per dimension.
    ///
    /// For `Dist::Cosine` the query is first scaled to unit length (left
    /// untouched when its norm is not above `T::epsilon()`). The residual to
    /// the centroid is then scaled to unit length, rotated, and every
    /// component is mapped to `round((x - lower) / width)` capped at 15, where
    /// `lower` is the smallest component and `width` is `(max - min) / 15`
    /// (or `1` when the range is not above `T::epsilon()`).
    ///
    /// A zero-dimensional encoder yields an empty `quantised` vector with
    /// `lower == 0` and `width == 1` instead of panicking.
    #[inline]
    pub fn encode_query(&self, query: &[T], centroid: &[T]) -> RaBitQQuery<T> {
        let query_norm: Vec<T> = match self.metric {
            Dist::Cosine => {
                let norm = compute_l2_norm(query);
                if norm > T::epsilon() {
                    query.iter().map(|&x| x / norm).collect()
                } else {
                    query.to_vec()
                }
            }
            Dist::Euclidean => query.to_vec(),
        };

        let res = T::subtract_simd(&query_norm, centroid);
        let dist_to_centroid = compute_l2_norm(&res);
        let q_c: Vec<T> = if dist_to_centroid > T::epsilon() {
            res.iter().map(|&r| r / dist_to_centroid).collect()
        } else {
            vec![T::zero(); self.dim]
        };
        let q_c_rotated = self.apply_rotation(&q_c);

        // Explicit `<` / `>` comparisons (rather than `Float::min`/`max`) keep
        // the original handling of NaN components.
        let first = q_c_rotated.first().copied().unwrap_or_else(T::zero);
        let (lower, upper) = q_c_rotated
            .iter()
            .skip(1)
            .fold((first, first), |(lo, hi), &value| {
                (
                    if value < lo { value } else { lo },
                    if value > hi { value } else { hi },
                )
            });

        let range = upper - lower;
        let width = if range > T::epsilon() {
            range / T::from_f32(15.0).unwrap()
        } else {
            T::one()
        };

        // Values that cannot be converted to `u8` (e.g. NaN) fall back to 0.
        let quantised: Vec<u8> = q_c_rotated
            .iter()
            .map(|&value| {
                ((value - lower) / width)
                    .round()
                    .to_u8()
                    .unwrap_or(0)
                    .min(15)
            })
            .collect();
        let sum_quantised: u32 = quantised.iter().map(|&level| u32::from(level)).sum();

        RaBitQQuery {
            quantised,
            dist_to_centroid,
            lower,
            width,
            sum_quantised,
        }
    }

    /// Multiplies the rotation matrix with `vec`: output element `i` is
    /// `dot_simd` of rotation row `i` and `vec`.
    #[inline]
    fn apply_rotation(&self, vec: &[T]) -> Vec<T> {
        let dim = self.dim;
        (0..dim)
            .map(|i| T::dot_simd(&self.rotation[i * dim..(i + 1) * dim], vec))
            .collect()
    }

    /// Generates the rotation used by the encoder.
    ///
    /// Fills a `dim x dim` matrix with standard-normal samples (row by row,
    /// from an RNG seeded with `seed`), takes the Q factor of its QR
    /// decomposition and returns it flattened row by row.
    fn generate_random_orthogonal(dim: usize, seed: u64) -> Vec<T> {
        let mut rng = StdRng::seed_from_u64(seed);

        // The sampling order (i outer, j inner) determines which sample lands
        // in which cell, so it must not change.
        let mut mat = Mat::<T>::zeros(dim, dim);
        for i in 0..dim {
            for j in 0..dim {
                let val: f64 = rng.sample(StandardNormal);
                mat[(i, j)] = T::from_f64(val).unwrap();
            }
        }

        let qr = mat.as_ref().qr();
        let q = qr.compute_Q();

        let mut rotation = Vec::with_capacity(dim * dim);
        for i in 0..dim {
            for j in 0..dim {
                rotation.push(q[(i, j)]);
            }
        }
        rotation
    }

    /// Size of the encoder struct plus the heap capacity of the rotation
    /// matrix, in bytes.
    pub fn memory_usage_bytes(&self) -> usize {
        std::mem::size_of_val(self) + self.rotation.capacity() * std::mem::size_of::<T>()
    }
}
/// Per-vector scalars stored next to each binary code.
///
/// The three fields are the non-binary parts of the tuple returned by
/// `RaBitQEncoder::encode_vector`.
#[repr(C)]
#[derive(Clone)]
pub struct RaBitQPackedVector<T> {
    /// Norm of the residual between the vector and its centroid.
    pub dist_to_centroid: T,
    /// `compute_l1_norm` of the rotated unit-length residual.
    pub dot_correction: T,
    /// Number of bits set in the vector's binary code.
    pub popcount: u32,
}

impl<T> RaBitQPackedVector<T> {
    /// Size in bytes of one packed entry, including any padding.
    #[inline]
    pub fn memory_usage_bytes() -> usize {
        std::mem::size_of::<RaBitQPackedVector<T>>()
    }
}
/// Flat, cluster-ordered storage for encoded vectors.
///
/// Vectors belonging to cluster `c` occupy positions
/// `offsets[c]..offsets[c + 1]` in `packed_vectors` and `vector_indices`, and
/// the matching byte range (positions times `n_bytes`) in `binary_codes`.
pub struct RaBitQStorage<T> {
/// Centroids flattened one after another; centroid `c` is
/// `centroids[c * dim..(c + 1) * dim]`.
pub centroids: Vec<T>,
/// One norm per centroid (filled via `compute_l2_norm` by
/// `build_rabitq_storage`).
pub centroids_norm: Vec<T>,
/// Binary codes, `n_bytes` bytes per vector, in storage order.
pub binary_codes: Vec<u8>,
/// Per-vector scalars, in storage order.
pub packed_vectors: Vec<RaBitQPackedVector<T>>,
/// For each storage position, the index of the vector in the input data.
pub vector_indices: Vec<usize>,
/// Cluster boundaries; has `nlist + 1` entries.
pub offsets: Vec<usize>,
/// Number of clusters.
pub nlist: usize,
/// Dimensionality of the vectors and centroids.
pub dim: usize,
/// Bytes per binary code: `dim.div_ceil(8)`.
pub n_bytes: usize,
}
impl<T: Float + FromPrimitive + Clone> RaBitQStorage<T> {
    /// Creates an empty storage with room reserved for `nlist` centroids and
    /// `n` vectors of dimension `dim`.
    ///
    /// Only capacity is reserved; every vector starts empty except `offsets`,
    /// which holds `nlist + 1` zeros.
    pub fn with_capacity(nlist: usize, n: usize, dim: usize) -> Self {
        let bytes_per_code = dim.div_ceil(8);
        Self {
            centroids: Vec::with_capacity(nlist * dim),
            centroids_norm: Vec::with_capacity(nlist),
            binary_codes: Vec::with_capacity(n * bytes_per_code),
            packed_vectors: Vec::with_capacity(n),
            vector_indices: Vec::with_capacity(n),
            offsets: vec![0; nlist + 1],
            nlist,
            dim,
            n_bytes: bytes_per_code,
        }
    }

    /// Returns the `dim` coordinates of centroid `cluster_idx`.
    #[inline]
    pub fn centroid(&self, cluster_idx: usize) -> &[T] {
        let first = cluster_idx * self.dim;
        let last = first + self.dim;
        &self.centroids[first..last]
    }

    /// Returns the concatenated binary codes of every vector in the cluster.
    #[inline]
    pub fn cluster_binary_codes(&self, cluster_idx: usize) -> &[u8] {
        let (first_vec, last_vec) = (self.offsets[cluster_idx], self.offsets[cluster_idx + 1]);
        &self.binary_codes[first_vec * self.n_bytes..last_vec * self.n_bytes]
    }

    /// Returns the binary code of the `local_idx`-th vector of the cluster.
    ///
    /// `local_idx` is not checked against the cluster size.
    #[inline]
    pub fn vector_binary(&self, cluster_idx: usize, local_idx: usize) -> &[u8] {
        let first_byte = (self.offsets[cluster_idx] + local_idx) * self.n_bytes;
        &self.binary_codes[first_byte..first_byte + self.n_bytes]
    }

    /// Returns the packed scalars of the `local_idx`-th vector of the cluster.
    ///
    /// `local_idx` is not checked against the cluster size.
    #[inline]
    pub fn get_vector_data(&self, cluster_idx: usize, local_idx: usize) -> &RaBitQPackedVector<T> {
        &self.packed_vectors[self.offsets[cluster_idx] + local_idx]
    }

    /// Returns the packed scalars of every vector in the cluster.
    #[inline]
    pub fn cluster_packed_data(&self, cluster_idx: usize) -> &[RaBitQPackedVector<T>] {
        let (first, last) = (self.offsets[cluster_idx], self.offsets[cluster_idx + 1]);
        &self.packed_vectors[first..last]
    }

    /// Iterates over the popcounts of the cluster's vectors.
    #[inline]
    pub fn cluster_popcounts(&self, cluster_idx: usize) -> impl Iterator<Item = u32> + '_ {
        let members = self.cluster_packed_data(cluster_idx);
        members.iter().map(|packed| packed.popcount)
    }

    /// Iterates over the centroid distances of the cluster's vectors.
    #[inline]
    pub fn cluster_dist_to_centroid(&self, cluster_idx: usize) -> impl Iterator<Item = T> + '_ {
        let members = self.cluster_packed_data(cluster_idx);
        members.iter().map(|packed| packed.dist_to_centroid)
    }

    /// Iterates over the dot corrections of the cluster's vectors.
    #[inline]
    pub fn cluster_dot_corrections(&self, cluster_idx: usize) -> impl Iterator<Item = T> + '_ {
        let members = self.cluster_packed_data(cluster_idx);
        members.iter().map(|packed| packed.dot_correction)
    }

    /// Returns the input-data indices of the cluster's vectors.
    #[inline]
    pub fn cluster_vector_indices(&self, cluster_idx: usize) -> &[usize] {
        let (first, last) = (self.offsets[cluster_idx], self.offsets[cluster_idx + 1]);
        &self.vector_indices[first..last]
    }

    /// Number of vectors assigned to the cluster.
    #[inline]
    pub fn cluster_size(&self, cluster_idx: usize) -> usize {
        let upper = self.offsets[cluster_idx + 1];
        let lower = self.offsets[cluster_idx];
        upper - lower
    }

    /// Total number of stored vectors.
    #[inline]
    pub fn n_vectors(&self) -> usize {
        self.vector_indices.len()
    }

    /// Size of the struct plus the heap capacity of all its vectors, in bytes.
    pub fn memory_usage_bytes(&self) -> usize {
        let scalar = std::mem::size_of::<T>();
        let index = std::mem::size_of::<usize>();
        let packed = std::mem::size_of::<RaBitQPackedVector<T>>();

        std::mem::size_of_val(self)
            + self.centroids.capacity() * scalar
            + self.centroids_norm.capacity() * scalar
            + self.binary_codes.capacity()
            + self.packed_vectors.capacity() * packed
            + self.vector_indices.capacity() * index
            + self.offsets.capacity() * index
    }
}
/// Encodes `n` vectors and lays them out cluster by cluster.
///
/// # Arguments
/// * `data` - `n` vectors of dimension `dim`, flattened one after another.
/// * `centroids` - `nlist` centroids of dimension `dim`, flattened likewise.
/// * `assignments` - cluster index of each vector; must have exactly `n`
///   entries, each below `nlist`.
/// * `encoder` - encoder used to produce each vector's binary code and
///   scalars.
///
/// Within a cluster, vectors keep the relative order they have in `data`.
///
/// # Panics
/// Panics if `assignments.len() != n`, if an assignment is `>= nlist`, if
/// `data` or `centroids` is shorter than the sizes above imply, or if the
/// encoder's code length differs from `dim.div_ceil(8)`.
pub fn build_rabitq_storage<T>(
    data: &[T],
    dim: usize,
    n: usize,
    centroids: &[T],
    nlist: usize,
    assignments: &[usize],
    encoder: &RaBitQEncoder<T>,
) -> RaBitQStorage<T>
where
    T: Float + FromPrimitive + ToPrimitive + ComplexField + Sum + SimdDistance + Clone,
{
    // A longer `assignments` slice would inflate the cluster counts and leave
    // the offsets inconsistent with the `n` vectors actually stored.
    assert_eq!(
        assignments.len(),
        n,
        "build_rabitq_storage: expected one cluster assignment per vector"
    );

    let n_bytes = dim.div_ceil(8);

    let centroids_norm: Vec<T> = (0..nlist)
        .map(|i| compute_l2_norm(&centroids[i * dim..(i + 1) * dim]))
        .collect();

    // Count members per cluster, then turn the counts into prefix sums so
    // that cluster `c` owns positions `offsets[c]..offsets[c + 1]`.
    let mut counts = vec![0usize; nlist];
    for &cluster_idx in assignments {
        counts[cluster_idx] += 1;
    }
    let mut offsets = vec![0usize; nlist + 1];
    for (i, &count) in counts.iter().enumerate() {
        offsets[i + 1] = offsets[i] + count;
    }

    // Next free position inside each cluster's range.
    let mut insert_pos = offsets[..nlist].to_vec();

    let mut storage = RaBitQStorage {
        centroids: centroids.to_vec(),
        centroids_norm,
        binary_codes: vec![0u8; n * n_bytes],
        packed_vectors: vec![
            RaBitQPackedVector {
                dist_to_centroid: T::zero(),
                dot_correction: T::zero(),
                popcount: 0,
            };
            n
        ],
        vector_indices: vec![0usize; n],
        offsets,
        nlist,
        dim,
        n_bytes,
    };

    for (vec_idx, &cluster_idx) in assignments.iter().enumerate() {
        let pos = insert_pos[cluster_idx];
        insert_pos[cluster_idx] += 1;

        let vec = &data[vec_idx * dim..(vec_idx + 1) * dim];
        let centroid = &centroids[cluster_idx * dim..(cluster_idx + 1) * dim];
        let (binary, dist, dot_corr, popcount) = encoder.encode_vector(vec, centroid);

        let byte_start = pos * n_bytes;
        storage.binary_codes[byte_start..byte_start + n_bytes].copy_from_slice(&binary);
        storage.packed_vectors[pos] = RaBitQPackedVector {
            dist_to_centroid: dist,
            dot_correction: dot_corr,
            popcount,
        };
        storage.vector_indices[pos] = vec_idx;
    }

    storage
}
/// A trained quantiser: the encoder together with the cluster-ordered storage
/// of every encoded vector.
pub struct RaBitQQuantiser<T> {
/// Encoder holding the rotation, dimensionality and metric.
pub encoder: RaBitQEncoder<T>,
/// Centroids plus the encoded vectors grouped by cluster.
pub storage: RaBitQStorage<T>,
}
impl<T> RaBitQQuantiser<T>
where
    T: Float + FromPrimitive + ToPrimitive + Send + Sync + Sum + ComplexField + SimdDistance,
{
    /// Trains centroids on `data` (one vector per row) and encodes every row.
    ///
    /// # Arguments
    /// * `data` - matrix with one vector per row.
    /// * `metric` - for `Dist::Cosine`, rows whose norm is above
    ///   `T::epsilon()` are scaled to unit length before clustering and
    ///   encoding.
    /// * `n_clusters` - number of clusters; defaults to
    ///   `ceil(0.5 * sqrt(n))`. The value is raised to at least 1 and then
    ///   capped at the number of rows.
    /// * `seed` - passed to `train_centroids` and, cast to `u64`, used to
    ///   generate the encoder's rotation.
    pub fn new(data: MatRef<T>, metric: &Dist, n_clusters: Option<usize>, seed: usize) -> Self {
        let n = data.nrows();
        let dim = data.ncols();
        let k = n_clusters
            .unwrap_or_else(|| ((n as f64).sqrt() * 0.5).ceil() as usize)
            .max(1)
            .min(n);

        // Flatten the rows, normalising them on the way for the cosine metric.
        let mut data_flat = Vec::with_capacity(n * dim);
        let mut data_norms = Vec::with_capacity(n);
        for i in 0..n {
            let row = data.row(i);
            let vec: Vec<T> = row.iter().cloned().collect();
            let norm = compute_l2_norm(&vec);
            data_norms.push(norm);
            match metric {
                Dist::Cosine => {
                    if norm > T::epsilon() {
                        data_flat.extend(vec.iter().map(|&x| x / norm));
                    } else {
                        data_flat.extend(vec);
                    }
                }
                Dist::Euclidean => {
                    data_flat.extend(vec);
                }
            }
        }

        // For the cosine metric a norm of one is reported for every row;
        // otherwise the norms measured above are used.
        let cluster_norms = if matches!(metric, Dist::Cosine) {
            vec![T::one(); n]
        } else {
            data_norms
        };

        let centroids_flat = train_centroids(
            &data_flat,
            dim,
            n,
            k,
            metric,
            RABITQ_K_MEANS_ITER,
            seed,
            false,
        );

        let centroid_norms: Vec<T> = (0..k)
            .map(|c| {
                let cent = &centroids_flat[c * dim..(c + 1) * dim];
                compute_l2_norm(cent)
            })
            .collect();

        let assignments = assign_all_parallel(
            &data_flat,
            &cluster_norms,
            dim,
            n,
            &centroids_flat,
            &centroid_norms,
            k,
            metric,
        );

        let encoder = RaBitQEncoder::new(dim, *metric, seed as u64);
        let storage = build_rabitq_storage(
            &data_flat,
            dim,
            n,
            &centroids_flat,
            k,
            &assignments,
            &encoder,
        );

        Self { encoder, storage }
    }

    /// Encodes `query` relative to the centroid of cluster `cluster_idx`.
    #[inline]
    pub fn encode_query(&self, query: &[T], cluster_idx: usize) -> RaBitQQuery<T> {
        let centroid = self.storage.centroid(cluster_idx);
        self.encoder.encode_query(query, centroid)
    }

    /// Number of clusters in the storage.
    pub fn n_clusters(&self) -> usize {
        self.storage.nlist
    }

    /// Number of encoded vectors.
    pub fn n_vectors(&self) -> usize {
        self.storage.n_vectors()
    }

    /// Combined memory footprint of the encoder and the storage, in bytes.
    pub fn memory_usage_bytes(&self) -> usize {
        self.encoder.memory_usage_bytes() + self.storage.memory_usage_bytes()
    }
}
/// Gives the `VectorDistanceRaBitQ` trait access to the quantiser's storage
/// and encoder.
impl<T: Float + FromPrimitive> VectorDistanceRaBitQ<T> for RaBitQQuantiser<T> {
    /// Borrows the cluster-ordered storage.
    fn storage(&self) -> &RaBitQStorage<T> {
        let Self { storage, .. } = self;
        storage
    }

    /// Borrows the encoder.
    fn encoder(&self) -> &RaBitQEncoder<T> {
        let Self { encoder, .. } = self;
        encoder
    }
}
/// Gives the `CentroidDistance` trait access to the quantiser's centroids and
/// their metadata.
impl<T: Float + FromPrimitive + Sum + SimdDistance> CentroidDistance<T> for RaBitQQuantiser<T> {
    /// Flattened centroid coordinates held by the storage.
    fn centroids(&self) -> &[T] {
        self.storage.centroids.as_slice()
    }

    /// Dimensionality recorded in the storage.
    fn dim(&self) -> usize {
        let storage = &self.storage;
        storage.dim
    }

    /// Number of clusters recorded in the storage.
    fn nlist(&self) -> usize {
        let storage = &self.storage;
        storage.nlist
    }

    /// Distance metric configured on the encoder.
    fn metric(&self) -> Dist {
        let encoder = &self.encoder;
        encoder.metric
    }

    /// Per-centroid norms held by the storage.
    fn centroids_norm(&self) -> &[T] {
        self.storage.centroids_norm.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Six 2-D points, flattened one after another.
    fn sample_data_2d() -> Vec<f32> {
        vec![
            1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 0.5, 0.5, -0.5, 0.5,
        ]
    }

    #[test]
    fn test_encoder_creation() {
        let encoder = RaBitQEncoder::<f32>::new(4, Dist::Euclidean, 42);
        assert_eq!(encoder.dim, 4);
        // 4 bits fit into a single byte.
        assert_eq!(encoder.n_bytes, 1);
        // 4 x 4 rotation matrix.
        assert_eq!(encoder.rotation.len(), 16);
    }

    #[test]
    fn test_rotation_orthogonality() {
        let dim = 8;
        let encoder = RaBitQEncoder::<f32>::new(dim, Dist::Euclidean, 42);
        // Rows must be orthonormal: row_i . row_j is 1 when i == j, else 0.
        for i in 0..dim {
            for j in 0..dim {
                let mut dot = 0.0;
                for k in 0..dim {
                    dot += encoder.rotation[i * dim + k] * encoder.rotation[j * dim + k];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(dot, expected, epsilon = 1e-5);
            }
        }
    }

    #[test]
    fn test_encode_vector_basic() {
        let encoder = RaBitQEncoder::<f32>::new(4, Dist::Euclidean, 42);
        let vec = vec![1.0, 0.0, 0.0, 0.0];
        let centroid = vec![0.0, 0.0, 0.0, 0.0];
        let (binary, dist, correction, _) = encoder.encode_vector(&vec, &centroid);
        assert_eq!(binary.len(), 1);
        assert_abs_diff_eq!(dist, 1.0, epsilon = 1e-5);
        assert!(correction > 0.0);
    }

    #[test]
    fn test_encode_vector_with_centroid() {
        let encoder = RaBitQEncoder::<f32>::new(4, Dist::Euclidean, 42);
        let vec = vec![2.0, 2.0, 0.0, 0.0];
        let centroid = vec![1.0, 1.0, 0.0, 0.0];
        let (_, dist, _, _) = encoder.encode_vector(&vec, &centroid);
        // The residual is (1, 1, 0, 0).
        let expected_dist = (1.0f32 + 1.0f32).sqrt();
        assert_abs_diff_eq!(dist, expected_dist, epsilon = 1e-5);
    }

    #[test]
    fn test_encode_query_int4_range() {
        let encoder = RaBitQEncoder::<f32>::new(8, Dist::Euclidean, 42);
        let query = vec![1.0; 8];
        let centroid = vec![0.0; 8];
        let encoded = encoder.encode_query(&query, &centroid);
        assert_eq!(encoded.quantised.len(), 8);
        for &val in &encoded.quantised {
            assert!(val <= 15);
        }
        assert_eq!(
            encoded.sum_quantised,
            encoded.quantised.iter().map(|&x| x as u32).sum::<u32>()
        );
    }

    #[test]
    fn test_encode_query_cosine_normalises() {
        let encoder = RaBitQEncoder::<f32>::new(4, Dist::Cosine, 42);
        // Norm 2; the cosine metric scales the query to unit length first.
        let query = vec![2.0, 0.0, 0.0, 0.0];
        let centroid = vec![0.0; 4];
        let encoded = encoder.encode_query(&query, &centroid);
        assert_abs_diff_eq!(encoded.dist_to_centroid, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_storage_creation() {
        let storage = RaBitQStorage::<f32>::with_capacity(10, 100, 8);
        assert_eq!(storage.nlist, 10);
        assert_eq!(storage.dim, 8);
        assert_eq!(storage.n_bytes, 1);
        assert_eq!(storage.offsets.len(), 11);
    }

    #[test]
    fn test_build_rabitq_storage() {
        let data = sample_data_2d();
        let dim = 2;
        let n = 6;
        let nlist = 2;
        let centroids = vec![0.5, 0.0, -0.5, 0.0];
        let assignments = vec![0, 0, 1, 1, 0, 1];
        let encoder = RaBitQEncoder::new(dim, Dist::Euclidean, 42);
        let storage =
            build_rabitq_storage(&data, dim, n, &centroids, nlist, &assignments, &encoder);
        assert_eq!(storage.nlist, 2);
        assert_eq!(storage.n_vectors(), 6);
        assert_eq!(storage.cluster_size(0), 3);
        assert_eq!(storage.cluster_size(1), 3);
        // 2 centroids x 2 dimensions.
        assert_eq!(storage.centroids.len(), 4);
        assert_eq!(storage.centroids_norm.len(), 2);
    }

    #[test]
    fn test_storage_accessors() {
        let data = sample_data_2d();
        let dim = 2;
        let n = 6;
        let nlist = 2;
        let centroids = vec![0.5, 0.0, -0.5, 0.0];
        let assignments = vec![0, 0, 1, 1, 0, 1];
        let encoder = RaBitQEncoder::new(dim, Dist::Euclidean, 42);
        let storage =
            build_rabitq_storage(&data, dim, n, &centroids, nlist, &assignments, &encoder);
        let centroid_0 = storage.centroid(0);
        assert_eq!(centroid_0.len(), dim);
        assert_abs_diff_eq!(centroid_0[0], 0.5, epsilon = 1e-5);
        let indices_0 = storage.cluster_vector_indices(0);
        assert_eq!(indices_0.len(), 3);
        let binary_0 = storage.cluster_binary_codes(0);
        // 3 vectors, 1 byte each (dim = 2).
        assert_eq!(binary_0.len(), 3);
    }

    #[test]
    fn test_quantiser_creation_euclidean() {
        let data = sample_data_2d();
        let mat = Mat::from_fn(6, 2, |i, j| data[i * 2 + j]);
        let quantiser = RaBitQQuantiser::new(mat.as_ref(), &Dist::Euclidean, Some(2), 42);
        assert_eq!(quantiser.n_clusters(), 2);
        assert_eq!(quantiser.n_vectors(), 6);
        assert_eq!(quantiser.encoder.dim, 2);
    }

    #[test]
    fn test_quantiser_creation_cosine() {
        let data = sample_data_2d();
        let mat = Mat::from_fn(6, 2, |i, j| data[i * 2 + j]);
        let quantiser = RaBitQQuantiser::new(mat.as_ref(), &Dist::Cosine, Some(2), 42);
        assert_eq!(quantiser.n_clusters(), 2);
        assert_eq!(quantiser.encoder.metric, Dist::Cosine);
    }

    #[test]
    fn test_quantiser_encode_query() {
        let data = sample_data_2d();
        let mat = Mat::from_fn(6, 2, |i, j| data[i * 2 + j]);
        let quantiser = RaBitQQuantiser::new(mat.as_ref(), &Dist::Euclidean, Some(2), 42);
        let query = vec![0.8, 0.2];
        let encoded = quantiser.encode_query(&query, 0);
        assert_eq!(encoded.quantised.len(), 2);
        assert!(encoded.dist_to_centroid >= 0.0);
        // 2 dimensions, at most 15 each.
        assert!(encoded.sum_quantised <= 30);
    }

    #[test]
    fn test_quantiser_default_nlist() {
        let data = sample_data_2d();
        let mat = Mat::from_fn(6, 2, |i, j| data[i * 2 + j]);
        let quantiser = RaBitQQuantiser::new(mat.as_ref(), &Dist::Euclidean, None, 42);
        assert!(quantiser.n_clusters() >= 1);
    }

    #[test]
    fn test_encode_zero_residual() {
        let encoder = RaBitQEncoder::<f32>::new(4, Dist::Euclidean, 42);
        let vec = vec![1.0, 2.0, 3.0, 4.0];
        let centroid = vec.clone();
        let (_, dist, _, _) = encoder.encode_vector(&vec, &centroid);
        assert_abs_diff_eq!(dist, 0.0, epsilon = 1e-5);
    }
}