use crate::vector_quant::codec::VectorCodec;
use crate::vector_quant::hamming::hamming_distance;
use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
/// Newtype wrapper marking a [`UnifiedQuantizedVector`] as having been
/// produced by the BBQ (1-bit sign) codec.
pub struct BbqQuantized(pub UnifiedQuantizedVector);

impl AsRef<UnifiedQuantizedVector> for BbqQuantized {
    /// Borrows the underlying unified quantized-vector storage.
    #[inline]
    fn as_ref(&self) -> &UnifiedQuantizedVector {
        &self.0
    }
}
/// Pre-processed query state for asymmetric (query vs. quantized) scoring.
///
/// Built once per query by `prepare_query` so per-candidate distance
/// computation does not repeat the centering and sign-packing work.
pub struct BbqQuery {
    /// Query vector with the codec's centroid subtracted component-wise.
    pub centered: Vec<f32>,
    /// Sign bits of `centered`, packed MSB-first, one bit per dimension.
    pub signs: Vec<u8>,
    /// Euclidean norm of `centered`.
    pub query_norm: f32,
    /// Dot product of `centered` with its own sign vector, divided by
    /// `query_norm` (0.0 when the centered query is all-zero).
    pub query_dot_quantized: f32,
}
/// BBQ-style 1-bit vector codec: vectors are centered on a calibrated
/// centroid and only the sign of each centered component is stored,
/// together with scalar correction terms in the header.
pub struct BbqCodec {
    /// Dimensionality the codec was calibrated for.
    pub dim: usize,
    /// Component-wise mean of the calibration vectors; subtracted from
    /// every vector/query before sign-packing.
    centroid: Vec<f32>,
    /// Candidate over-fetch multiplier; stored but not read by any method
    /// in this file — presumably consumed by the search layer (TODO confirm).
    pub oversample: u8,
}
impl BbqCodec {
    /// Builds a codec whose centroid is the component-wise mean of
    /// `vectors`. With an empty calibration set the centroid stays all
    /// zeros (avoiding a divide-by-zero on the count).
    pub fn calibrate(vectors: &[&[f32]], dim: usize, oversample: u8) -> Self {
        let mut centroid = vec![0.0f32; dim];
        if !vectors.is_empty() {
            for v in vectors {
                for (acc, &x) in centroid.iter_mut().zip(v.iter()) {
                    *acc += x;
                }
            }
            let n = vectors.len() as f32;
            for acc in &mut centroid {
                *acc /= n;
            }
        }
        Self {
            dim,
            centroid,
            oversample,
        }
    }

    /// Writes `v - centroid` (component-wise, truncated to the shorter of
    /// the two) into `out`, reusing its allocation.
    fn center(&self, v: &[f32], out: &mut Vec<f32>) {
        out.clear();
        for (&x, &c) in v.iter().zip(self.centroid.iter()) {
            out.push(x - c);
        }
    }

    /// Packs the signs of `centered` into bits, MSB-first within each byte;
    /// a set bit means the component is non-negative.
    fn pack_signs(centered: &[f32]) -> Vec<u8> {
        let mut bits = vec![0u8; centered.len().div_ceil(8)];
        for (byte, chunk) in bits.iter_mut().zip(centered.chunks(8)) {
            for (pos, &x) in chunk.iter().enumerate() {
                if x >= 0.0 {
                    *byte |= 1 << (7 - pos);
                }
            }
        }
        bits
    }

    /// Euclidean (L2) norm of `v`.
    fn norm(v: &[f32]) -> f32 {
        let sum_sq: f32 = v.iter().map(|&x| x * x).sum();
        sum_sq.sqrt()
    }

    /// Inner product of `a` and `b`, truncated to the shorter slice.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        let mut acc = 0.0f32;
        for (&x, &y) in a.iter().zip(b.iter()) {
            acc += x * y;
        }
        acc
    }

    /// Reconstructs an approximate centered vector from packed sign bits:
    /// every component becomes `±residual_norm / sqrt(dim)` according to
    /// its stored sign bit.
    fn dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
        let scale = if dim == 0 {
            0.0
        } else {
            residual_norm / (dim as f32).sqrt()
        };
        let mut out = Vec::with_capacity(dim);
        for i in 0..dim {
            let positive = (packed[i / 8] & (1 << (7 - (i % 8)))) != 0;
            out.push(if positive { scale } else { -scale });
        }
        out
    }
}
impl VectorCodec for BbqCodec {
    type Quantized = BbqQuantized;
    type Query = BbqQuery;

    /// Encodes `v` as packed sign bits plus header correction terms.
    ///
    /// The vector is centered on the calibrated centroid, its sign bits
    /// are packed, and the residual norm, the normalized dot with the sign
    /// vector, and the centroid alignment are stored in the header for
    /// later distance estimation.
    ///
    /// (Fix: the previous text contained mojibake `¢ered` where
    /// `&centered` was intended, which did not compile.)
    fn encode(&self, v: &[f32]) -> BbqQuantized {
        let mut centered = Vec::with_capacity(self.dim);
        self.center(v, &mut centered);
        let packed = Self::pack_signs(&centered);
        let residual_norm = Self::norm(&centered);
        // dot(centered, sign(centered)) — computed via an explicit sign
        // vector to mirror `prepare_query`.
        let sign_fp: Vec<f32> = centered
            .iter()
            .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
            .collect();
        let dot_vs = Self::dot(&centered, &sign_fp);
        // Normalized projection of the residual onto its sign vector;
        // 0.0 for an all-zero residual to avoid dividing by zero.
        let dot_quantized = if residual_norm > 0.0 {
            dot_vs / residual_norm
        } else {
            0.0
        };
        let centroid_norm = Self::norm(&self.centroid);
        let dot_vc = Self::dot(&centered, &self.centroid);
        // Alignment of the residual with the centroid direction; stored as
        // the header's global scale.
        let query_alignment = if centroid_norm > 0.0 {
            dot_vc / centroid_norm
        } else {
            0.0
        };
        let header = QuantHeader {
            quant_mode: QuantMode::Bbq as u16,
            dim: self.dim as u16,
            global_scale: query_alignment,
            residual_norm,
            dot_quantized,
            outlier_bitmask: 0,
            reserved: [0u8; 8],
        };
        // BBQ stores no outliers, so construction is expected to succeed;
        // a failure here indicates a broken invariant, not a user error.
        let uqv = UnifiedQuantizedVector::new(header, &packed, &[]).expect(
            "BBQ encode: UnifiedQuantizedVector construction must succeed with no outliers",
        );
        BbqQuantized(uqv)
    }

    /// Pre-processes a query: centers it, packs its sign bits, and caches
    /// its norm and normalized sign-dot for distance estimation.
    fn prepare_query(&self, q: &[f32]) -> BbqQuery {
        let mut centered = Vec::with_capacity(self.dim);
        self.center(q, &mut centered);
        let signs = Self::pack_signs(&centered);
        let query_norm = Self::norm(&centered);
        let sign_fp: Vec<f32> = centered
            .iter()
            .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
            .collect();
        let dot_vs = Self::dot(&centered, &sign_fp);
        let query_dot_quantized = if query_norm > 0.0 {
            dot_vs / query_norm
        } else {
            0.0
        };
        BbqQuery {
            centered,
            signs,
            query_norm,
            query_dot_quantized,
        }
    }

    /// Estimated SQUARED Euclidean distance between two quantized vectors.
    ///
    /// The cosine between the two residuals is estimated from the Hamming
    /// distance of their sign bits (`1 - 2*ham/dim`) and plugged into
    /// `|q|² + |v|² - 2|q||v|·cos`; clamped at 0 because the estimate can
    /// dip slightly negative. NOTE(review): this returns a squared
    /// distance while `exact_asymmetric_distance` returns an unsquared
    /// one — fine for ranking within either method alone, but the two are
    /// not directly comparable; confirm callers do not mix them.
    fn fast_symmetric_distance(&self, q: &BbqQuantized, v: &BbqQuantized) -> f32 {
        let ham = hamming_distance(q.0.packed_bits(), v.0.packed_bits());
        let dot_estimate = 1.0 - 2.0 * ham as f32 / self.dim as f32;
        let q_n = q.0.header().residual_norm;
        let v_n = v.0.header().residual_norm;
        (q_n * q_n + v_n * v_n - 2.0 * q_n * v_n * dot_estimate).max(0.0)
    }

    /// Euclidean distance between the centered query and the candidate
    /// reconstructed from its sign bits ("exact" relative to the fast
    /// Hamming estimate; still approximate w.r.t. the original vector).
    fn exact_asymmetric_distance(&self, q: &BbqQuery, v: &BbqQuantized) -> f32 {
        let header = v.0.header();
        let recon = Self::dequantize(v.0.packed_bits(), header.residual_norm, self.dim);
        q.centered
            .iter()
            .zip(recon.iter())
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f32>()
            .sqrt()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector generator based on an LCG;
    /// constants match MMIX's `rand64`, so output is reproducible per seed.
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        const MUL: u64 = 6364136223846793005;
        const INC: u64 = 1442695040888963407;
        let mut state = seed.wrapping_mul(MUL).wrapping_add(INC);
        let mut out = Vec::with_capacity(dim);
        for _ in 0..dim {
            state = state.wrapping_mul(MUL).wrapping_add(INC);
            out.push(((state >> 33) as f32) / (u32::MAX as f32) * 4.0 - 2.0);
        }
        out
    }

    #[test]
    fn calibrate_centroid_mean() {
        let dim = 8;
        let ones = vec![1.0f32; dim];
        let threes = vec![3.0f32; dim];
        let twos = vec![2.0f32; dim];
        let refs: Vec<&[f32]> = vec![&ones, &threes, &twos];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        // Mean of 1, 3, 2 is exactly 2 in every component.
        for &x in &codec.centroid {
            assert!((x - 2.0).abs() < 1e-5, "expected centroid 2.0, got {x}");
        }
    }

    #[test]
    fn calibrate_empty_gives_zero_centroid() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert!(codec.centroid.iter().all(|&c| c == 0.0));
    }

    #[test]
    fn encode_packed_bits_length() {
        let dim = 128;
        let data: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs: Vec<&[f32]> = vec![data.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let encoded = codec.encode(&data);
        let expected_bytes = dim.div_ceil(8);
        assert_eq!(
            encoded.0.packed_bits().len(),
            expected_bytes,
            "packed bits length should be dim.div_ceil(8)"
        );
    }

    #[test]
    fn encode_odd_dim_packed_bits_length() {
        // 17 dims need ceil(17/8) = 3 bytes of sign bits.
        let dim = 17;
        let data: Vec<f32> = (0..dim).map(|i| i as f32 - 8.0).collect();
        let refs: Vec<&[f32]> = vec![data.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let encoded = codec.encode(&data);
        assert_eq!(encoded.0.packed_bits().len(), 3);
    }

    #[test]
    fn hamming_scalar_vs_self_zero() {
        let bits = vec![0b10101010u8, 0b11001100, 0b11110000];
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    #[test]
    fn hamming_scalar_known_distance() {
        let all_set = vec![0xFFu8];
        let all_clear = vec![0x00u8];
        assert_eq!(hamming_distance(&all_set, &all_clear), 8);
    }

    #[test]
    fn hamming_multi_byte_agreement() {
        // Complementary byte patterns differ in every one of 64*8 = 512 bits.
        let lhs: Vec<u8> = (0..64u8).collect();
        let rhs: Vec<u8> = lhs.iter().map(|&x| !x).collect();
        assert_eq!(hamming_distance(&lhs, &rhs), 512);
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..8).map(|seed| rand_vec(seed, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        for (i, vi) in vecs.iter().enumerate() {
            let qi = codec.encode(vi);
            let query = codec.prepare_query(vi);
            for (j, vj) in vecs.iter().enumerate() {
                let qj = codec.encode(vj);
                let sym = codec.fast_symmetric_distance(&qi, &qj);
                assert!(
                    sym.is_finite() && sym >= 0.0,
                    "fast_symmetric_distance({i},{j}) = {sym}"
                );
                let asym = codec.exact_asymmetric_distance(&query, &qj);
                assert!(
                    asym.is_finite() && asym >= 0.0,
                    "exact_asymmetric_distance({i},{j}) = {asym}"
                );
            }
        }
    }

    #[test]
    fn oversample_default_is_three() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert_eq!(codec.oversample, 3);
    }

    #[test]
    fn encode_quant_mode_is_bbq() {
        let dim = 16;
        let data: Vec<f32> = vec![1.0; dim];
        let refs: Vec<&[f32]> = vec![data.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let encoded = codec.encode(&data);
        assert_eq!(encoded.0.header().quant_mode, QuantMode::Bbq as u16);
    }
}