use wide::f32x8;
#[cfg(target_arch = "x86_64")]
use crate::superfile::vector::simd_dispatch::avx512_enabled;
/// Number of sign bits packed into one byte of a quantized code.
const BITS_PER_CODE_BYTE: usize = 8;
/// Number of distinct byte values, i.e. the number of rows in the sign lookup table.
const SIGN_TABLE_BYTE_PATTERNS: usize = 256;
/// Sign multiplied into a query component whose code bit is set.
const RABITQ_POSITIVE_SIGN: f32 = 1.0;
/// Sign multiplied into a query component whose code bit is clear.
const RABITQ_NEGATIVE_SIGN: f32 = -1.0;
/// Coefficient of the positive-lane sum in `2 * pos - q_total`, the form used by the
/// AVX-512 kernel. Only referenced on x86_64, hence the conditional `dead_code` allowance.
#[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
const RABITQ_DOT_POS_COEFF: f32 = 2.0;
/// Number of `f32` lanes the AVX-512 kernel consumes per 512-bit load.
/// Only referenced on x86_64, hence the conditional `dead_code` allowance.
#[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
const AVX512_F32_LANES: usize = 16;
/// One-bit-per-dimension sign quantizer.
///
/// Encodes a vector as a packed bit string (bit set when the component is strictly
/// positive) and estimates dot products between a full-precision query and such a code.
#[derive(Debug, Clone)]
pub struct BitQuantizer {
/// Number of dimensions of the vectors this quantizer encodes.
pub dim: usize,
/// Lookup table with one row of 8 signs per byte value: entry `b * 8 + bit` is `+1.0`
/// when bit `bit` (least-significant first) of `b` is set and `-1.0` otherwise.
sign_table: Box<[f32; SIGN_TABLE_BYTE_PATTERNS * BITS_PER_CODE_BYTE]>,
}
impl BitQuantizer {
    /// Creates a quantizer for `dim`-dimensional vectors and precomputes the sign table.
    ///
    /// Row `b` of the table holds, for each of the 8 bits of byte value `b`
    /// (least-significant bit first), `+1.0` if the bit is set and `-1.0` otherwise.
    pub fn new(dim: usize) -> Self {
        let mut table = Box::new([0.0f32; SIGN_TABLE_BYTE_PATTERNS * BITS_PER_CODE_BYTE]);
        for (pattern, row) in table.chunks_exact_mut(BITS_PER_CODE_BYTE).enumerate() {
            for (bit, sign) in row.iter_mut().enumerate() {
                *sign = if (pattern >> bit) & 1 == 1 {
                    RABITQ_POSITIVE_SIGN
                } else {
                    RABITQ_NEGATIVE_SIGN
                };
            }
        }
        Self {
            dim,
            sign_table: table,
        }
    }

    /// Returns the number of bytes in one code: `dim` bits rounded up to whole bytes.
    #[inline]
    pub fn code_bytes(&self) -> usize {
        self.dim.div_ceil(BITS_PER_CODE_BYTE)
    }

    /// Packs the signs of `rotated` into `out`, one bit per dimension.
    ///
    /// Bit `i % 8` of `out[i / 8]` is set when `rotated[i] > 0.0`; zero (and NaN, for
    /// which the comparison is false) encode as a clear bit. When `dim` is not a
    /// multiple of 8 the unused high bits of the last byte are written as zero.
    ///
    /// # Panics
    ///
    /// Panics if `rotated` has fewer than `dim` elements or `out` has fewer than
    /// [`code_bytes`](Self::code_bytes) bytes. Debug builds additionally require the
    /// lengths to match exactly.
    #[inline]
    pub fn encode_rotated_into(&self, rotated: &[f32], out: &mut [u8]) {
        debug_assert_eq!(rotated.len(), self.dim);
        debug_assert_eq!(out.len(), self.code_bytes());
        // Re-slice once so the loops below run without per-element bounds checks.
        let rotated = &rotated[..self.dim];
        let out = &mut out[..self.code_bytes()];
        let (full, partial) = out.split_at_mut(self.dim / BITS_PER_CODE_BYTE);

        let zero = f32x8::ZERO;
        let mut blocks = rotated.chunks_exact(BITS_PER_CODE_BYTE);
        for (block, packed) in blocks.by_ref().zip(full.iter_mut()) {
            let lane: [f32; BITS_PER_CODE_BYTE] = block
                .try_into()
                .expect("chunks_exact yields blocks of exactly 8 elements");
            *packed = f32x8::from(lane).simd_gt(zero).to_bitmask() as u8;
        }

        // `partial` holds exactly one byte when `dim` is not a multiple of 8, none otherwise.
        if let Some(last) = partial.first_mut() {
            let mut byte: u8 = 0;
            for (i, &x) in blocks.remainder().iter().enumerate() {
                if x > 0.0 {
                    byte |= 1u8 << i;
                }
            }
            *last = byte;
        }
    }

    /// Estimates the dot product between `q_rot` and the sign vector stored in `code`.
    ///
    /// The result is the sum of `q_rot[i] * s_i`, where `s_i` is `+1.0` when bit `i` of
    /// `code` is set and `-1.0` otherwise. The sum of `q_rot` is only computed when the
    /// AVX-512 kernel, the only consumer of that total, is going to run.
    #[inline]
    pub fn estimate_dot_rotated(&self, q_rot: &[f32], code: &[u8]) -> f32 {
        #[cfg(target_arch = "x86_64")]
        if avx512_enabled() {
            let q_total: f32 = q_rot.iter().sum();
            return self.estimate_dot_rotated_with_total(q_rot, code, q_total);
        }
        debug_assert_eq!(q_rot.len(), self.dim);
        debug_assert_eq!(code.len(), self.code_bytes());
        estimate_dot_rotated_wide(&self.sign_table, q_rot, code, self.dim)
    }

    /// Same as [`estimate_dot_rotated`](Self::estimate_dot_rotated), but takes the sum
    /// of `q_rot` from the caller so it can be computed once and reused across codes.
    ///
    /// `q_total` must equal the sum of `q_rot`. It is only consulted on the AVX-512
    /// path, which returns `2 * pos - q_total`; the portable path ignores it.
    ///
    /// # Panics
    ///
    /// Panics if `q_rot` has fewer than `dim` elements or `code` is too short for
    /// `dim` bits. Debug builds additionally require the lengths to match exactly.
    #[inline]
    pub fn estimate_dot_rotated_with_total(&self, q_rot: &[f32], code: &[u8], q_total: f32) -> f32 {
        debug_assert_eq!(q_rot.len(), self.dim);
        debug_assert_eq!(code.len(), self.code_bytes());
        #[cfg(target_arch = "x86_64")]
        if avx512_enabled() {
            // The kernel reads `q_rot` through raw pointers, so the length must be
            // enforced in release builds too before crossing into unsafe code.
            assert!(
                q_rot.len() >= self.dim,
                "q_rot has {} elements but the quantizer dimension is {}",
                q_rot.len(),
                self.dim
            );
            // SAFETY: `q_rot` holds at least `dim` elements (checked above) and `code`
            // is only accessed through bounds-checked indexing inside the kernel.
            // NOTE(review): this relies on `avx512_enabled()` returning true only when
            // the CPU supports AVX-512F; confirm against `simd_dispatch`.
            return unsafe { estimate_dot_rotated_avx512(q_rot, code, q_total, self.dim) };
        }
        // The portable kernel derives the result from the signs alone.
        let _ = q_total;
        estimate_dot_rotated_wide(&self.sign_table, q_rot, code, self.dim)
    }
}
/// Portable dot-product estimator built on 8-lane `f32x8` arithmetic.
///
/// Returns the sum of `q_rot[i] * s_i` for `i < dim`, where `s_i` is `+1.0` when bit
/// `i` of `code` is set and `-1.0` otherwise. Whole code bytes are expanded to eight
/// signs through `sign_table`; the dimensions left over when `dim` is not a multiple
/// of 8 are handled one at a time after the vector accumulator has been reduced.
///
/// Panics if `q_rot` has fewer than `dim` elements or `code` is too short for `dim` bits.
#[inline]
fn estimate_dot_rotated_wide(
    sign_table: &[f32; SIGN_TABLE_BYTE_PATTERNS * BITS_PER_CODE_BYTE],
    q_rot: &[f32],
    code: &[u8],
    dim: usize,
) -> f32 {
    let whole_bytes = dim / BITS_PER_CODE_BYTE;
    let (body, rest) = q_rot[..dim].split_at(whole_bytes * BITS_PER_CODE_BYTE);

    // Vector part: one table row of signs per code byte, accumulated lane-wise.
    let mut lanes = f32x8::ZERO;
    let blocks = body.chunks_exact(BITS_PER_CODE_BYTE);
    for (block, &pattern) in blocks.zip(&code[..whole_bytes]) {
        let row = usize::from(pattern) * BITS_PER_CODE_BYTE;
        let signs: [f32; BITS_PER_CODE_BYTE] = sign_table[row..row + BITS_PER_CODE_BYTE]
            .try_into()
            .expect("slice [b*8..b*8+8] has length 8");
        let query: [f32; BITS_PER_CODE_BYTE] = block
            .try_into()
            .expect("slice [byte_idx*8..byte_idx*8+8] has length 8");
        lanes += f32x8::from(query) * f32x8::from(signs);
    }
    let mut total: f32 = lanes.reduce_add();

    // Scalar part: the partial last byte, if any, decoded bit by bit.
    if !rest.is_empty() {
        let pattern = usize::from(code[whole_bytes]);
        for (offset, &component) in rest.iter().enumerate() {
            let sign = match (pattern >> offset) & 1 {
                1 => RABITQ_POSITIVE_SIGN,
                _ => RABITQ_NEGATIVE_SIGN,
            };
            total += component * sign;
        }
    }
    total
}
/// AVX-512 dot-product estimator using masked adds instead of a sign table.
///
/// Accumulates `pos`, the sum of `q_rot[i]` over every `i < dim` whose bit in `code`
/// is set, and returns `2 * pos - q_total`. When `q_total` is the sum of `q_rot`,
/// this equals the signed sum (set bits counted as `+1`, clear bits as `-1`),
/// because `sum_set - sum_clear = 2 * sum_set - sum_all`.
///
/// The input is consumed in three stages: 16 lanes at a time with a 16-bit mask,
/// then at most one 8-lane block, then the remaining dimensions one by one.
///
/// # Safety
///
/// The CPU must support AVX-512F. No other AVX-512 extension is required.
///
/// # Panics
///
/// Panics if `q_rot` has fewer than `dim` elements or `code` is too short for `dim`
/// bits. Debug builds additionally require the lengths to match exactly.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn estimate_dot_rotated_avx512(q_rot: &[f32], code: &[u8], q_total: f32, dim: usize) -> f32 {
    use std::arch::x86_64::*;
    debug_assert_eq!(q_rot.len(), dim);
    debug_assert_eq!(code.len(), dim.div_ceil(BITS_PER_CODE_BYTE));
    // The vector loads below go through raw pointers and are not bounds-checked, so
    // pin the length here: after this line `q_rot.len() == dim` in release builds too.
    let q_rot = &q_rot[..dim];
    unsafe {
        let mut pos_sum = _mm512_setzero_ps();
        let mut i: usize = 0;
        while i + AVX512_F32_LANES <= dim {
            let byte = i / BITS_PER_CODE_BYTE;
            // Two consecutive code bytes form the mask for 16 lanes, low byte first.
            let bits = u16::from_le_bytes([code[byte], code[byte + 1]]);
            // SAFETY: `i + 16 <= dim == q_rot.len()`, so all 16 lanes are in bounds.
            let q = _mm512_loadu_ps(q_rot.as_ptr().add(i));
            // Lanes whose mask bit is clear keep their previous accumulator value.
            pos_sum = _mm512_mask_add_ps(pos_sum, bits, pos_sum, q);
            i += AVX512_F32_LANES;
        }
        let mut pos: f32 = _mm512_reduce_add_ps(pos_sum);
        if i + BITS_PER_CODE_BYTE <= dim {
            // Zero-extending the byte leaves mask bits 8..16 clear, so the masked load
            // reads at most `q_rot[i..i + 8]` and zeroes every other lane. This stays
            // within AVX-512F; the 256-bit masked move would also need AVX-512VL.
            let bits = u16::from(code[i / BITS_PER_CODE_BYTE]);
            // SAFETY: `i + 8 <= dim == q_rot.len()`; masked-off lanes are not accessed.
            let selected = _mm512_maskz_loadu_ps(bits, q_rot.as_ptr().add(i));
            pos += _mm512_reduce_add_ps(selected);
            i += BITS_PER_CODE_BYTE;
        }
        // Fewer than 8 dimensions remain; decode them bit by bit.
        while i < dim {
            let bit = ((code[i / BITS_PER_CODE_BYTE] >> (i % BITS_PER_CODE_BYTE)) & 1) != 0;
            if bit {
                pos += q_rot[i];
            }
            i += 1;
        }
        RABITQ_DOT_POS_COEFF * pos - q_total
    }
}
// Unit tests for `BitQuantizer`: code sizing, sign encoding, dot-product estimation,
// agreement between the portable and AVX-512 kernels, and an opt-in microbenchmark.
#[cfg(test)]
mod tests {
use super::*;
/// Returns true when `a` and `b` differ by strictly less than `eps`.
fn approx(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
// Dimensions that are multiples of 8 need exactly `dim / 8` code bytes.
#[test]
fn code_bytes_for_byte_aligned_dims() {
for &dim in &[8, 16, 32, 64, 128, 256, 384, 768, 1024] {
assert_eq!(BitQuantizer::new(dim).code_bytes(), dim / 8);
}
}
// Dimensions that are not multiples of 8 round up to the next whole byte.
#[test]
fn code_bytes_for_non_aligned_dims_rounds_up() {
assert_eq!(BitQuantizer::new(1).code_bytes(), 1);
assert_eq!(BitQuantizer::new(7).code_bytes(), 1);
assert_eq!(BitQuantizer::new(9).code_bytes(), 2);
assert_eq!(BitQuantizer::new(15).code_bytes(), 2);
assert_eq!(BitQuantizer::new(17).code_bytes(), 3);
}
// Eight strictly positive components set all eight bits.
#[test]
fn encode_all_positive_sets_every_bit() {
let q = BitQuantizer::new(8);
let v = vec![1.0; 8];
let mut out = vec![0u8; 1];
q.encode_rotated_into(&v, &mut out);
assert_eq!(out, vec![0xFF]);
}
// Eight negative components leave every bit clear.
#[test]
fn encode_all_negative_clears_every_bit() {
let q = BitQuantizer::new(8);
let v = vec![-1.0; 8];
let mut out = vec![0u8; 1];
q.encode_rotated_into(&v, &mut out);
assert_eq!(out, vec![0x00]);
}
// The encoder compares with a strict `> 0`, so exact zeros encode as clear bits.
#[test]
fn encode_zero_is_negative() {
let q = BitQuantizer::new(8);
let v = vec![0.0; 8];
let mut out = vec![0u8; 1];
q.encode_rotated_into(&v, &mut out);
assert_eq!(out, vec![0x00]);
}
// Dimension `i` maps to bit `i` of the byte (least-significant bit first).
#[test]
fn encode_single_positive_dim_sets_one_bit() {
let q = BitQuantizer::new(8);
for i in 0..8 {
let mut v = vec![-1.0; 8];
v[i] = 1.0;
let mut out = vec![0u8; 1];
q.encode_rotated_into(&v, &mut out);
assert_eq!(out, vec![1u8 << i], "dim {i}");
}
}
// With dim 12 the last four dimensions land in a second, partially used byte:
// dimension 0 is bit 0 of byte 0, dimension 11 is bit 3 of byte 1.
#[test]
fn encode_non_aligned_dim_uses_partial_byte() {
let q = BitQuantizer::new(12);
let mut v = vec![-1.0; 12];
v[0] = 1.0;
v[11] = 1.0;
let mut out = vec![0u8; 2];
q.encode_rotated_into(&v, &mut out);
assert_eq!(out, vec![0x01, 0x08]); }
// A code built from the query itself matches every sign, so the estimate is the
// sum of absolute values of the query.
#[test]
fn estimate_query_against_self_returns_l1_sum_of_query() {
let q = BitQuantizer::new(8);
let q_rot = vec![3.0, -1.0, 2.0, -4.0, 5.0, -6.0, 7.0, -2.0];
let mut code = vec![0u8; 1];
q.encode_rotated_into(&q_rot, &mut code);
let est = q.estimate_dot_rotated(&q_rot, &code);
let expected: f32 = q_rot.iter().map(|x| x.abs()).sum();
assert!(approx(est, expected, 1e-5));
}
// A code built from the negated query mismatches every sign, so the estimate is
// the negated sum of absolute values.
#[test]
fn estimate_query_against_opposite_returns_negative_sum() {
let q = BitQuantizer::new(8);
let q_rot = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let neg = q_rot.iter().map(|&x| -x).collect::<Vec<_>>();
let mut code = vec![0u8; 1];
q.encode_rotated_into(&neg, &mut code);
let est = q.estimate_dot_rotated(&q_rot, &code);
let expected: f32 = -q_rot.iter().map(|x| x.abs()).sum::<f32>();
assert!(approx(est, expected, 1e-5));
}
// dim 12 exercises the partial last byte; all components are positive, so every
// bit is set and the estimate equals the plain sum of the query.
#[test]
fn estimate_handles_tail_dim() {
let q = BitQuantizer::new(12);
let q_rot: Vec<f32> = (1..=12).map(|i| i as f32).collect();
let mut code = vec![0u8; 2];
q.encode_rotated_into(&q_rot, &mut code);
let est = q.estimate_dot_rotated(&q_rot, &code);
let expected: f32 = q_rot.iter().sum(); assert!(approx(est, expected, 1e-5));
}
// An all-zero query yields exactly zero regardless of the code contents.
#[test]
fn estimate_zero_query_yields_zero() {
let q = BitQuantizer::new(16);
let q_rot = vec![0.0; 16];
let any_code = vec![0xAAu8; 2];
assert_eq!(q.estimate_dot_rotated(&q_rot, &any_code), 0.0);
}
// With an all-ones query the estimate is (set bits) - (clear bits):
// 0xFF gives 8, 0x00 gives -8, and 0x0F gives 4 - 4 = 0.
#[test]
fn estimate_is_unbiased_indicator_of_alignment() {
let q = BitQuantizer::new(8);
let q_rot = vec![1.0; 8];
let code_all = vec![0xFFu8];
let code_none = vec![0x00u8];
let code_half = vec![0x0Fu8];
assert!(approx(q.estimate_dot_rotated(&q_rot, &code_all), 8.0, 1e-5));
assert!(approx(
q.estimate_dot_rotated(&q_rot, &code_none),
-8.0,
1e-5
));
assert!(approx(
q.estimate_dot_rotated(&q_rot, &code_half),
0.0,
1e-5
));
}
// The sign table holds 8 entries for each of the 256 byte values, independent of `dim`.
#[test]
fn sign_table_has_correct_size() {
let q = BitQuantizer::new(128);
assert_eq!(q.sign_table.len(), 256 * 8);
}
// Compile-time check that `BitQuantizer` implements `Clone`.
#[test]
fn quantizer_is_clone() {
let q = BitQuantizer::new(64);
let _q2 = q.clone();
}
/// Builds a deterministic pseudo-random vector of length `dim`: each index is
/// multiplied by 2654435761 (wrapping), offset by `seed`, reinterpreted as `i32`,
/// and scaled by 1e-6, giving a mix of positive and negative values.
fn fake_vec(dim: usize, seed: u32) -> Vec<f32> {
(0..dim)
.map(|i| {
let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(seed)) as i32;
(x as f32) * 1e-6
})
.collect()
}
/// Encodes `fake_vec(quant.dim, seed)` with `quant` and returns the resulting code.
fn fake_code(quant: &BitQuantizer, seed: u32) -> Vec<u8> {
let d_vec = fake_vec(quant.dim, seed);
let mut code = vec![0u8; quant.code_bytes()];
quant.encode_rotated_into(&d_vec, &mut code);
code
}
// Cross-checks the AVX-512 kernel against the portable kernel over dimensions that
// hit every stage of the AVX-512 kernel: 16-lane blocks, the single 8-lane block,
// and the scalar tail. Returns early when AVX-512 is not enabled on the host.
#[test]
#[cfg(target_arch = "x86_64")]
fn estimate_avx512_matches_wide_across_lengths() {
if !avx512_enabled() {
eprintln!("estimate_avx512_matches_wide_across_lengths: skipped, no AVX-512");
return;
}
for dim in [
1usize, 7, 8, 15, 16, 17, 23, 24, 31, 32, 40, 48, 64, 96, 128, 384, 768,
] {
let q = BitQuantizer::new(dim);
let q_rot = fake_vec(dim, 0xC0DE);
let code = fake_code(&q, 0xD0DE);
let q_total: f32 = q_rot.iter().sum();
let want = estimate_dot_rotated_wide(&q.sign_table, &q_rot, &code, dim);
let got = unsafe { estimate_dot_rotated_avx512(&q_rot, &code, q_total, dim) };
// The kernels sum in different orders, so allow a relative term plus a small
// absolute term that grows with sqrt(dim).
let tol = 1e-4 * want.abs().max(1.0) + 1e-5 * (dim as f32).sqrt();
assert!(
(want - got).abs() <= tol,
"dim {dim}: avx512 {got} vs wide {want} (tol {tol})"
);
}
}
// `estimate_dot_rotated` and `estimate_dot_rotated_with_total` (given the same sum)
// must return bit-identical results, hence the exact `assert_eq!`.
#[test]
fn estimate_inline_and_precomputed_total_agree() {
for &dim in &[16usize, 32, 33, 64, 384] {
let q = BitQuantizer::new(dim);
let q_rot = fake_vec(dim, 0xFEED);
let code = fake_code(&q, 0xBABE);
let inline = q.estimate_dot_rotated(&q_rot, &code);
let q_total: f32 = q_rot.iter().sum();
let precomp = q.estimate_dot_rotated_with_total(&q_rot, &code, q_total);
assert_eq!(
inline, precomp,
"dim {dim}: inline {inline} vs precomp {precomp}"
);
}
}
// Opt-in microbenchmark (`#[ignore]`): times both kernels per dimension and prints
// a markdown table to stderr. Returns early when AVX-512 is not enabled on the host.
#[test]
#[ignore]
#[cfg(target_arch = "x86_64")]
fn avx512_microbench_estimate_dot_rotated() {
if !avx512_enabled() {
eprintln!("avx512_microbench: skipped, no AVX-512 on this host");
return;
}
use std::{hint::black_box, time::Instant};
eprintln!();
eprintln!("### RaBitQ estimator — AVX-512 mask-add vs wide sign-table (ns per call)\n");
eprintln!("| kernel | dim | wide ns | avx512 ns | speedup |");
eprintln!("|--------|----:|--------:|----------:|--------:|");
for &dim in &[128usize, 384, 768, 1024, 1536] {
let q = BitQuantizer::new(dim);
let q_rot = fake_vec(dim, 0xC0DE);
let code = fake_code(&q, 0xD0DE);
let q_total: f32 = q_rot.iter().sum();
// Iteration count shrinks as `dim` grows, with a floor of 50_000 iterations.
let iters: u32 = (10_000_000u64 / (dim as u64).max(1)).max(50_000) as u32;
// Warm-up for the portable kernel: a tenth of the timed iterations, at least 64.
for _ in 0..(iters / 10).max(64) {
black_box(estimate_dot_rotated_wide(
black_box(&q.sign_table),
black_box(&q_rot),
black_box(&code),
black_box(dim),
));
}
// Timed run of the portable kernel.
let t = Instant::now();
for _ in 0..iters {
black_box(estimate_dot_rotated_wide(
black_box(&q.sign_table),
black_box(&q_rot),
black_box(&code),
black_box(dim),
));
}
let wide_ns = t.elapsed().as_secs_f64() * 1e9 / iters as f64;
// Warm-up for the AVX-512 kernel, same count as above.
for _ in 0..(iters / 10).max(64) {
black_box(unsafe {
estimate_dot_rotated_avx512(
black_box(&q_rot),
black_box(&code),
black_box(q_total),
black_box(dim),
)
});
}
// Timed run of the AVX-512 kernel.
let t = Instant::now();
for _ in 0..iters {
black_box(unsafe {
estimate_dot_rotated_avx512(
black_box(&q_rot),
black_box(&code),
black_box(q_total),
black_box(dim),
)
});
}
let avx_ns = t.elapsed().as_secs_f64() * 1e9 / iters as f64;
eprintln!(
"| `quant::estimate_dot_rotated` | {dim} | {:>7.1} | {:>7.1} | {:>5.2}× |",
wide_ns,
avx_ns,
wide_ns / avx_ns,
);
}
}
}