use crate::hnsw::graph::HNSWIndex;
use crate::RetrieveError;
use qntz::rabitq::{QuantizedVector, RaBitQConfig, RaBitQQuantizer};
/// HNSW index whose graph traversal runs on RaBitQ-quantized distances
/// (SymphonyQG-style: approximate navigation, optional exact rerank).
pub struct SymphonyQGIndex {
    // Underlying HNSW graph plus the raw vectors.
    index: HNSWIndex,
    // One RaBitQ code per stored vector, indexed by internal id.
    codes: Vec<QuantizedVector>,
    // Set by `build()`; required at search time to rotate queries.
    quantizer: Option<RaBitQQuantizer>,
    // Quantizer settings applied when `build()` runs.
    rabitq_config: RaBitQConfig,
    // Seed for the quantizer's random rotation.
    seed: u64,
    // True once `quantize_vectors()` has completed.
    quantized_built: bool,
}
impl SymphonyQGIndex {
    /// Convenience constructor: default 4-bit RaBitQ config, fixed seed 42.
    pub fn new(dimension: usize, m: usize, m_max: usize) -> Result<Self, RetrieveError> {
        Self::with_config(dimension, m, m_max, RaBitQConfig::bits4(), 42)
    }

    /// Creates the underlying HNSW graph with `m`/`m_max` neighbor limits and
    /// stores the quantizer settings; quantization itself happens in `build()`.
    pub fn with_config(
        dimension: usize,
        m: usize,
        m_max: usize,
        rabitq_config: RaBitQConfig,
        seed: u64,
    ) -> Result<Self, RetrieveError> {
        let index = HNSWIndex::new(dimension, m, m_max)?;
        Ok(Self {
            index,
            codes: Vec::new(),
            quantizer: None,
            rabitq_config,
            seed,
            quantized_built: false,
        })
    }

    /// Same as `with_config` but takes a full `HNSWParams` (metric,
    /// ef_construction, seed, ...).
    pub fn with_hnsw_params(
        dimension: usize,
        params: super::graph::HNSWParams,
        rabitq_config: RaBitQConfig,
        seed: u64,
    ) -> Result<Self, RetrieveError> {
        let index = HNSWIndex::with_params(dimension, params)?;
        Ok(Self {
            index,
            codes: Vec::new(),
            quantizer: None,
            rabitq_config,
            seed,
            quantized_built: false,
        })
    }

    /// Buffers a vector; the graph is constructed in `build()`.
    pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        self.index.add_slice(doc_id, vector)
    }

    /// Builds the HNSW graph, then quantizes every stored vector with RaBitQ.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        self.index.build()?;
        self.quantize_vectors()?;
        Ok(())
    }

    /// Fits a RaBitQ quantizer on all stored vectors and caches one code per
    /// vector. On an empty index this only marks quantization as done.
    fn quantize_vectors(&mut self) -> Result<(), RetrieveError> {
        let n = self.index.num_vectors;
        if n == 0 {
            self.quantized_built = true;
            return Ok(());
        }
        let dim = self.index.dimension;
        let mut quantizer = RaBitQQuantizer::with_config(dim, self.seed, self.rabitq_config)
            .map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ init: {e}")))?;
        quantizer
            .fit(&self.index.vectors, n)
            .map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ fit: {e}")))?;
        let mut codes = Vec::with_capacity(n);
        for i in 0..n {
            let vec = self.index.get_vector(i);
            let qv = quantizer
                .quantize(vec)
                .map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ quantize: {e}")))?;
            codes.push(qv);
        }
        self.quantizer = Some(quantizer);
        self.codes = codes;
        self.quantized_built = true;
        Ok(())
    }

    /// Approximate k-NN search using quantized distances only (no exact
    /// rerank). Returned distances are RaBitQ estimates, sorted ascending.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        self.check_search_ready(query)?;
        let results = self.search_quantized_graph(query, ef)?;
        // NOTE(review): the top-k cut happens before the sort, so this assumes
        // `search_quantized_graph` already yields candidates in ascending
        // distance order -- confirm against greedy_search_layer_custom.
        let mut output: Vec<(u32, f32)> = results
            .into_iter()
            .take(k)
            .map(|(internal_id, dist)| (self.index.doc_ids[internal_id as usize], dist))
            .collect();
        output.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        Ok(output)
    }

    /// Quantized graph search followed by exact-distance reranking of the top
    /// `rerank_pool` candidates (pool is clamped to at least `k`).
    pub fn search_reranked(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
        rerank_pool: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        self.check_search_ready(query)?;
        let pool = rerank_pool.max(k);
        let candidates = self.search_quantized_graph(query, ef.max(pool))?;
        let dist_fn = self.index.dist_fn();
        let mut reranked: Vec<(u32, f32)> = candidates
            .into_iter()
            .take(pool)
            .map(|(internal_id, _approx_dist)| {
                // Recompute with the index's configured exact metric.
                let vec = self.index.get_vector(internal_id as usize);
                let exact_dist = dist_fn(query, vec);
                (self.index.doc_ids[internal_id as usize], exact_dist)
            })
            .collect();
        reranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        reranked.truncate(k);
        Ok(reranked)
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.index.num_vectors
    }

    /// True when no vectors have been added.
    pub fn is_empty(&self) -> bool {
        self.index.num_vectors == 0
    }

    /// Read-only access to the wrapped HNSW index.
    pub fn inner(&self) -> &HNSWIndex {
        &self.index
    }

    /// Validates build state, quantization state, query dimension, and
    /// non-emptiness before any search.
    fn check_search_ready(&self, query: &[f32]) -> Result<(), RetrieveError> {
        if !self.index.is_built() {
            return Err(RetrieveError::InvalidParameter(
                "index must be built before search".into(),
            ));
        }
        if !self.quantized_built {
            return Err(RetrieveError::InvalidParameter(
                "quantization not built (call build())".into(),
            ));
        }
        if query.len() != self.index.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.index.dimension,
            });
        }
        if self.index.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        Ok(())
    }

    /// Applies the quantizer's rotation to a raw query so it can be compared
    /// against pre-rotated codes.
    fn rotate_query(&self, query: &[f32]) -> Result<Vec<f32>, RetrieveError> {
        self.quantizer
            .as_ref()
            .ok_or_else(|| {
                RetrieveError::InvalidParameter("quantizer must be set after build".into())
            })?
            .rotate_query(query)
            .map_err(|e| RetrieveError::InvalidParameter(format!("rotate query: {e}")))
    }

    /// Greedy descent through the upper HNSW layers followed by a beam search
    /// on the base layer, all scored with quantized (approximate) distances.
    /// Returns (internal_id, approx_dist) pairs.
    fn search_quantized_graph(
        &self,
        query: &[f32],
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        let rotated_query = self.rotate_query(query)?;
        let codes = &self.codes;
        let (entry_point, entry_layer) = self.index.entry_point().unwrap_or((0, 0));
        let mut current = entry_point;
        let mut current_dist = approx_dist_sqr(&rotated_query, &codes[current as usize]);
        // Hill-climb on each upper layer until no neighbor improves.
        for layer_idx in (1..=entry_layer).rev() {
            if layer_idx >= self.index.layers.len() {
                continue;
            }
            let layer = &self.index.layers[layer_idx];
            let mut changed = true;
            while changed {
                changed = false;
                let neighbors = layer.get_neighbors(current);
                for &neighbor_id in neighbors.iter() {
                    let dist = approx_dist_sqr(&rotated_query, &codes[neighbor_id as usize]);
                    if dist < current_dist {
                        current_dist = dist;
                        current = neighbor_id;
                        changed = true;
                    }
                }
            }
        }
        if self.index.layers.is_empty() {
            return Ok(Vec::new());
        }
        let base_layer = &self.index.layers[0];
        // The helper scores nodes via this closure; the raw query argument it
        // passes is ignored in favor of the pre-rotated query.
        let dist_fn = |_q: &[f32], node_id: u32| -> f32 {
            approx_dist_sqr(&rotated_query, &codes[node_id as usize])
        };
        Ok(crate::hnsw::search::greedy_search_layer_custom(
            query,
            current,
            base_layer,
            &self.index.vectors,
            self.index.dimension,
            ef,
            &dist_fn,
        ))
    }
}
/// Approximate squared L2 distance between an already-rotated query and a
/// RaBitQ-quantized vector (thin wrapper over the qntz primitive).
#[inline]
fn approx_dist_sqr(rotated_query: &[f32], qv: &QuantizedVector) -> f32 {
    RaBitQQuantizer::approximate_l2_sqr_prerotated(rotated_query, qv)
}
/// Per-edge correction scalars produced by RaBitQ edge quantization
/// (see `build_edge_codes`); consumed by the `approx_dist_vr_*` estimators.
#[derive(Clone, Copy)]
struct EdgeScalars {
    // Additive term of the distance estimator.
    f_add: f32,
    // Multiplicative rescale applied to the code inner product.
    f_rescale: f32,
    // Inner product of the rotated parent vector with the edge codes;
    // subtracted in the estimator so only the residual contribution remains.
    ip_u_rot_codes: f32,
}
/// SymphonyQG variant that quantizes *edges* (neighbor residuals) instead of
/// vectors: each base-layer edge stores a packed code plus correction
/// scalars, and search estimates d(q, neighbor) as d(q, parent) + edge term.
pub struct SymphonyQGVRIndex {
    // Underlying HNSW graph plus the raw vectors (raw vectors may be freed
    // by `compact()`).
    index: HNSWIndex,
    // One entry per base-layer edge, laid out per `neighbor_offsets`.
    edge_scalars: Vec<EdgeScalars>,
    // Packed per-edge codes, `packed_dim` bytes per edge, same layout.
    packed_codes: Vec<u8>,
    // Bytes occupied by one edge's packed code.
    packed_dim: usize,
    // Bits per quantized component (from the RaBitQ config).
    total_bits: usize,
    // Additive constant recentering unsigned codes (computed in `new`).
    cb: f32,
    // Prefix sums in edge slots: edges of node i live at
    // [offsets[i], offsets[i + 1]); multiply by `packed_dim` for bytes.
    neighbor_offsets: Vec<u32>,
    // Set by `build()`; rotates queries and data at build/search time.
    quantizer: Option<RaBitQQuantizer>,
    rabitq_config: RaBitQConfig,
    seed: u64,
    dimension: usize,
    // True once `build()` has completed.
    built: bool,
}
impl SymphonyQGVRIndex {
    /// Creates an edge-quantized ("VR") SymphonyQG index.
    ///
    /// Precomputes the packed-code layout: `cb` recenters unsigned codes
    /// (lookup value = code + cb) and `packed_dim` is the byte width of one
    /// edge's packed code.
    pub fn new(
        dimension: usize,
        params: super::graph::HNSWParams,
        rabitq_config: RaBitQConfig,
        seed: u64,
    ) -> Result<Self, RetrieveError> {
        let index = HNSWIndex::with_params(dimension, params)?;
        let total_bits = rabitq_config.total_bits;
        let ex_bits = total_bits.saturating_sub(1);
        // cb = -(2^(total_bits-1) - 0.5), e.g. -7.5 for 4-bit codes.
        let cb = -((1u32 << ex_bits) as f32 - 0.5);
        let codes_per_byte = 8 / total_bits.max(1);
        // NOTE(review): this byte count only agrees with `pack_codes` for
        // total_bits == 1 (dim/8 bytes) or == 4 (dim/2 bytes). Any other bit
        // width makes `pack_codes` emit a different byte count and the
        // per-edge slices of `packed_codes` would misalign -- confirm
        // RaBitQConfig restricts total_bits to {1, 4}.
        let packed_dim = dimension.div_ceil(codes_per_byte);
        Ok(Self {
            index,
            edge_scalars: Vec::new(),
            packed_codes: Vec::new(),
            packed_dim,
            total_bits,
            cb,
            neighbor_offsets: Vec::new(),
            quantizer: None,
            rabitq_config,
            seed,
            dimension,
            built: false,
        })
    }

    /// Buffers a vector; graph and edge codes are produced in `build()`.
    pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        self.index.add_slice(doc_id, vector)
    }

    /// Builds the HNSW graph, then quantizes every base-layer edge.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        self.index.build()?;
        self.build_edge_codes()?;
        self.built = true;
        Ok(())
    }

    /// Rotates all stored vectors once, then quantizes each base-layer edge
    /// (neighbor - parent residual, formed in rotated space) into packed
    /// codes plus correction scalars, flattened per `neighbor_offsets`.
    fn build_edge_codes(&mut self) -> Result<(), RetrieveError> {
        let n = self.index.num_vectors;
        if n == 0 {
            return Ok(());
        }
        let dim = self.dimension;
        let mut quantizer = RaBitQQuantizer::with_config(dim, self.seed, self.rabitq_config)
            .map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ init: {e}")))?;
        // Zero centroid: edge quantization operates on raw residuals.
        quantizer
            .set_centroid(vec![0.0f32; dim])
            .map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ centroid: {e}")))?;
        // Rotate every stored vector up front (in parallel when enabled) so
        // residuals can be taken directly in rotated space.
        let rotated_flat = {
            #[cfg(feature = "parallel")]
            {
                use rayon::prelude::*;
                let vectors = &self.index.vectors;
                let mut flat = vec![0.0f32; n * dim];
                flat.par_chunks_mut(dim)
                    .enumerate()
                    .try_for_each(|(i, chunk)| {
                        let v = &vectors[i * dim..(i + 1) * dim];
                        let r = quantizer
                            .rotate_query(v)
                            .map_err(|e| RetrieveError::InvalidParameter(format!("rotate: {e}")))?;
                        chunk.copy_from_slice(&r);
                        Ok::<_, RetrieveError>(())
                    })?;
                flat
            }
            #[cfg(not(feature = "parallel"))]
            {
                let mut flat = vec![0.0f32; n * dim];
                for i in 0..n {
                    let v = self.index.get_vector(i);
                    let r = quantizer
                        .rotate_query(v)
                        .map_err(|e| RetrieveError::InvalidParameter(format!("rotate: {e}")))?;
                    flat[i * dim..(i + 1) * dim].copy_from_slice(&r);
                }
                flat
            }
        };
        if self.index.layers.is_empty() {
            self.quantizer = Some(quantizer);
            return Ok(());
        }
        let base_layer = &self.index.layers[0];
        let layer_len = base_layer.len();
        let total_edges: usize = (0..layer_len as u32)
            .map(|id| base_layer.get_neighbors(id).len())
            .sum();
        let packed_dim = self.packed_dim;
        let total_bits = self.total_bits;
        // Quantizes all outgoing edges of one node; returns (scalars, codes).
        let quantize_node = |node_id: u32| -> Result<(Vec<EdgeScalars>, Vec<u8>), RetrieveError> {
            let neighbors = base_layer.get_neighbors(node_id);
            let u_rot = &rotated_flat[node_id as usize * dim..(node_id as usize + 1) * dim];
            let mut scalars = Vec::with_capacity(neighbors.len());
            let mut codes = Vec::with_capacity(neighbors.len() * packed_dim);
            for &neighbor_id in neighbors.iter() {
                let v_rot =
                    &rotated_flat[neighbor_id as usize * dim..(neighbor_id as usize + 1) * dim];
                // Residual (neighbor - parent) in rotated space.
                let rotated_residual: Vec<f32> = v_rot
                    .iter()
                    .zip(u_rot.iter())
                    .map(|(&v, &u)| v - u)
                    .collect();
                let edge = quantizer
                    .quantize_edge_prerotated(u_rot, &rotated_residual)
                    .map_err(|e| RetrieveError::InvalidParameter(format!("quantize edge: {e}")))?;
                pack_codes(&edge.quantized.codes, total_bits, dim, &mut codes);
                scalars.push(EdgeScalars {
                    f_add: edge.quantized.f_add,
                    f_rescale: edge.quantized.f_rescale,
                    ip_u_rot_codes: edge.ip_parent_rot_codes,
                });
            }
            Ok((scalars, codes))
        };
        let per_node: Vec<(Vec<EdgeScalars>, Vec<u8>)> = {
            #[cfg(feature = "parallel")]
            {
                use rayon::prelude::*;
                (0..layer_len as u32)
                    .into_par_iter()
                    .map(quantize_node)
                    .collect::<Result<Vec<_>, _>>()?
            }
            #[cfg(not(feature = "parallel"))]
            {
                (0..layer_len as u32)
                    .map(quantize_node)
                    .collect::<Result<Vec<_>, _>>()?
            }
        };
        // Flatten per-node results; offsets[i] is the first edge slot of
        // node i (slot index; multiply by packed_dim for byte addressing).
        let mut edge_scalars = Vec::with_capacity(total_edges);
        let mut packed_codes = Vec::with_capacity(total_edges * packed_dim);
        let mut neighbor_offsets = Vec::with_capacity(layer_len + 1);
        for (scalars, codes) in per_node {
            neighbor_offsets.push(edge_scalars.len() as u32);
            edge_scalars.extend(scalars);
            packed_codes.extend(codes);
        }
        neighbor_offsets.push(edge_scalars.len() as u32);
        self.edge_scalars = edge_scalars;
        self.packed_codes = packed_codes;
        self.neighbor_offsets = neighbor_offsets;
        self.quantizer = Some(quantizer);
        Ok(())
    }

    /// k-NN search; returns (doc_id, approximate distance) pairs.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        let internal = self.search_internal(query, k, ef)?;
        Ok(internal
            .into_iter()
            .map(|(id, d)| (self.index.doc_ids[id as usize], d))
            .collect())
    }

    /// Core search over internal ids.
    ///
    /// Upper layers are descended with exact distances; the base layer uses
    /// the edge-aware estimate d(q, parent) + quantized edge correction.
    /// NOTE(review): returns Ok(empty) when the index is not built, unlike
    /// `SymphonyQGIndex::search` which errors -- confirm this asymmetry.
    fn search_internal(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built || self.index.num_vectors == 0 || self.index.layers.is_empty() {
            return Ok(Vec::new());
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        let quantizer = self.quantizer.as_ref().ok_or_else(|| {
            RetrieveError::InvalidParameter("quantizer not built (call build())".into())
        })?;
        let rotated_query = quantizer
            .rotate_query(query)
            .map_err(|e| RetrieveError::InvalidParameter(format!("rotate query: {e}")))?;
        let (entry_point, entry_layer) = self.index.entry_point().unwrap_or((0, 0));
        let dist_fn_exact = self.index.dist_fn();
        let mut current = entry_point;
        let mut current_dist = dist_fn_exact(query, self.index.get_vector(current as usize));
        // Greedy hill-climb through upper layers using exact distances.
        for layer_idx in (1..=entry_layer).rev() {
            if layer_idx >= self.index.layers.len() {
                continue;
            }
            let layer = &self.index.layers[layer_idx];
            let mut changed = true;
            while changed {
                changed = false;
                let neighbors = layer.get_neighbors(current);
                for &neighbor_id in neighbors.iter() {
                    let dist = dist_fn_exact(query, self.index.get_vector(neighbor_id as usize));
                    if dist < current_dist {
                        current_dist = dist;
                        current = neighbor_id;
                        changed = true;
                    }
                }
            }
        }
        let base_layer = &self.index.layers[0];
        let edge_scalars = &self.edge_scalars;
        let packed = &self.packed_codes;
        let neighbor_offsets = &self.neighbor_offsets;
        let packed_dim = self.packed_dim;
        let vectors = &self.index.vectors;
        let dim = self.dimension;
        let lut = nibble_lut(self.cb);
        // Memoized exact squared-L2 from query to each visited parent
        // (RefCell because the scoring closure below must be Fn).
        let parent_dist_cache: std::cell::RefCell<std::collections::HashMap<u32, f32>> =
            std::cell::RefCell::new(std::collections::HashMap::with_capacity(ef * 2));
        let entry_vec = &vectors[current as usize * dim..(current as usize + 1) * dim];
        let entry_dist: f32 = query
            .iter()
            .zip(entry_vec.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        parent_dist_cache.borrow_mut().insert(current, entry_dist);
        let total_bits = self.total_bits;
        // Estimated distance to the `slot`-th neighbor of `parent_id`:
        // exact d(q, parent) + quantized edge correction, clamped at 0.
        let dist_fn = |parent_id: u32, _neighbor_id: u32, slot: usize| -> f32 {
            let base_offset = neighbor_offsets[parent_id as usize] as usize;
            let offset = base_offset + slot;
            // Software prefetch of the packed code two edge slots ahead.
            if offset + 2 < edge_scalars.len() {
                let ptr = packed.as_ptr().wrapping_add((offset + 2) * packed_dim);
                crate::hnsw::search::prefetch_read_data(ptr as *const f32);
            }
            let scalars = &edge_scalars[offset];
            let codes = &packed[offset * packed_dim..(offset + 1) * packed_dim];
            let edge_approx = if total_bits >= 4 {
                approx_dist_vr_packed(&rotated_query, codes, scalars, &lut)
            } else {
                approx_dist_vr_binary(&rotated_query, codes, scalars)
            };
            let q_parent_dist = {
                let cache = parent_dist_cache.borrow();
                if let Some(&d) = cache.get(&parent_id) {
                    d
                } else {
                    // Release the shared borrow before re-borrowing mutably.
                    drop(cache);
                    let parent_vec =
                        &vectors[parent_id as usize * dim..(parent_id as usize + 1) * dim];
                    let d: f32 = query
                        .iter()
                        .zip(parent_vec.iter())
                        .map(|(a, b)| (a - b) * (a - b))
                        .sum();
                    parent_dist_cache.borrow_mut().insert(parent_id, d);
                    d
                }
            };
            (q_parent_dist + edge_approx).max(0.0)
        };
        let results = crate::hnsw::search::greedy_search_layer_edge_aware(
            current,
            entry_dist,
            base_layer,
            self.index.num_vectors,
            ef,
            &dist_fn,
        );
        // NOTE(review): top-k cut precedes the sort; assumes results arrive
        // in ascending estimated-distance order -- confirm in the helper.
        let mut output: Vec<(u32, f32)> = results.into_iter().take(k).collect();
        output.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        Ok(output)
    }

    /// Edge-aware search followed by exact-distance reranking of the top
    /// `rerank_pool` candidates. Needs raw vectors, so it is rejected after
    /// `compact()`.
    pub fn search_reranked(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
        rerank_pool: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built || self.index.num_vectors == 0 {
            return Ok(Vec::new());
        }
        if self.is_compacted() {
            return Err(RetrieveError::InvalidParameter(
                "search_reranked unavailable after compact() -- use search() instead".into(),
            ));
        }
        let pool = rerank_pool.max(k);
        let candidates = self.search_internal(query, pool, ef.max(pool))?;
        let dist_fn = self.index.dist_fn();
        let mut reranked: Vec<(u32, f32)> = candidates
            .into_iter()
            .take(pool)
            .map(|(internal_id, _approx_dist)| {
                let vec = self.index.get_vector(internal_id as usize);
                let exact_dist = dist_fn(query, vec);
                (self.index.doc_ids[internal_id as usize], exact_dist)
            })
            .collect();
        reranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        reranked.truncate(k);
        Ok(reranked)
    }

    // Unused placeholder; always returns 0.0.
    #[allow(dead_code)]
    fn approx_dist_vr_entry(&self, _rotated_query: &[f32], _entry_id: u32) -> f32 {
        0.0
    }

    /// Frees the raw vector storage to shrink memory.
    /// NOTE(review): `search_internal` still slices `index.vectors` for the
    /// entry point and the parent-distance cache, so `search()` will panic
    /// after `compact()` -- even though the `search_reranked` error message
    /// suggests `search()` remains usable. Confirm intended usage.
    pub fn compact(&mut self) {
        self.index.vectors.clear();
        self.index.vectors.shrink_to_fit();
    }

    /// True when raw vectors were freed via `compact()` but data remains.
    pub fn is_compacted(&self) -> bool {
        self.index.vectors.is_empty() && self.index.num_vectors > 0
    }

    /// Parallel multi-query search (one rayon task per query).
    #[cfg(feature = "parallel")]
    pub fn search_batch(
        &self,
        queries: &[&[f32]],
        k: usize,
        ef: usize,
    ) -> Result<Vec<Vec<(u32, f32)>>, RetrieveError> {
        use rayon::prelude::*;
        queries.par_iter().map(|q| self.search(q, k, ef)).collect()
    }

    /// Parallel multi-query reranked search.
    #[cfg(feature = "parallel")]
    pub fn search_reranked_batch(
        &self,
        queries: &[&[f32]],
        k: usize,
        ef: usize,
        rerank_pool: usize,
    ) -> Result<Vec<Vec<(u32, f32)>>, RetrieveError> {
        use rayon::prelude::*;
        queries
            .par_iter()
            .map(|q| self.search_reranked(q, k, ef, rerank_pool))
            .collect()
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.index.num_vectors
    }

    /// True when no vectors have been added.
    pub fn is_empty(&self) -> bool {
        self.index.num_vectors == 0
    }

    /// Read-only access to the wrapped HNSW index.
    pub fn inner(&self) -> &HNSWIndex {
        &self.index
    }

    /// Rough byte accounting of the index's major buffers.
    pub fn memory_usage_bytes(&self) -> VRMemoryReport {
        let vectors = self.index.vectors.len() * 4;
        let codes = self.packed_codes.len();
        let scalars = self.edge_scalars.len() * std::mem::size_of::<EdgeScalars>();
        let offsets = self.neighbor_offsets.len() * 4;
        // Estimate only: assumes ~16 neighbor slots of 4 bytes per node per
        // layer rather than measuring the actual adjacency storage.
        let graph = self
            .index
            .layers
            .iter()
            .map(|l| l.len() * 16 * 4)
            .sum::<usize>();
        VRMemoryReport {
            vectors_bytes: vectors,
            packed_codes_bytes: codes,
            edge_scalars_bytes: scalars,
            graph_bytes: graph,
            offsets_bytes: offsets,
        }
    }
}
/// Byte-level memory accounting for a `SymphonyQGVRIndex`, broken down by
/// buffer category (see `memory_usage_bytes`).
pub struct VRMemoryReport {
    pub vectors_bytes: usize,
    pub packed_codes_bytes: usize,
    pub edge_scalars_bytes: usize,
    pub graph_bytes: usize,
    pub offsets_bytes: usize,
}

impl VRMemoryReport {
    /// Total bytes across every tracked category.
    pub fn total(&self) -> usize {
        [
            self.vectors_bytes,
            self.packed_codes_bytes,
            self.edge_scalars_bytes,
            self.graph_bytes,
            self.offsets_bytes,
        ]
        .into_iter()
        .sum()
    }
}
/// Appends the packed form of `dim` quantization codes to `out`.
///
/// For 4-bit-or-wider configs, two codes are packed per byte (high nibble
/// first; codes are masked to their low 4 bits, odd tail padded with 0).
/// Narrower configs are packed one bit per code, MSB first, any nonzero
/// code setting its bit.
#[inline]
fn pack_codes(codes: &[u16], total_bits: usize, dim: usize, out: &mut Vec<u8>) {
    if total_bits >= 4 {
        let mut idx = 0;
        while idx < dim {
            let hi = (codes[idx] & 0x0F) as u8;
            let lo = if idx + 1 < dim {
                (codes[idx + 1] & 0x0F) as u8
            } else {
                0
            };
            out.push((hi << 4) | lo);
            idx += 2;
        }
    } else {
        for start in (0..dim).step_by(8) {
            let mut byte = 0u8;
            let end = dim.min(start + 8);
            for (bit, &code) in codes[start..end].iter().enumerate() {
                if code != 0 {
                    byte |= 1 << (7 - bit);
                }
            }
            out.push(byte);
        }
    }
}
/// Lookup table mapping a 4-bit code value `i` to its recentered float
/// `i + cb` (with `cb` typically negative, e.g. -7.5 for 4-bit codes).
#[inline]
fn nibble_lut(cb: f32) -> [f32; 16] {
    std::array::from_fn(|i| i as f32 + cb)
}
/// Approximate edge-distance contribution for 4-bit (nibble) packed codes.
///
/// Unpacks two codes per byte (high nibble first), maps each through `lut`
/// (code value + cb), accumulates the inner product with the rotated query,
/// then applies the edge's affine correction. Clamped at zero.
#[inline]
fn approx_dist_vr_packed(
    rotated_query: &[f32],
    packed: &[u8],
    scalars: &EdgeScalars,
    lut: &[f32; 16],
) -> f32 {
    let dim = rotated_query.len();
    let mut ip = 0.0f32;
    // One packed byte covers two consecutive query components.
    for (pair, &byte) in rotated_query.chunks_exact(2).zip(packed.iter()) {
        ip += pair[0] * lut[(byte >> 4) as usize] + pair[1] * lut[(byte & 0x0F) as usize];
    }
    // Odd dimension: the last byte carries a single code in its high nibble.
    if dim % 2 == 1 {
        ip += rotated_query[dim - 1] * lut[(packed[dim / 2] >> 4) as usize];
    }
    (scalars.f_add + scalars.f_rescale * (ip - scalars.ip_u_rot_codes)).max(0.0)
}
/// Approximate edge-distance contribution for 1-bit (sign) packed codes.
///
/// `ip` equals the inner product with codes remapped from {0, 1} to
/// {-0.5, +0.5}: the sum over positive-coded components minus half the total
/// component sum. The edge's affine correction is then applied, clamped at 0.
#[inline]
fn approx_dist_vr_binary(rotated_query: &[f32], packed: &[u8], scalars: &EdgeScalars) -> f32 {
    let mut sum_positive = 0.0f32;
    let mut sum_all = 0.0f32;
    for (idx, &q) in rotated_query.iter().enumerate() {
        sum_all += q;
        // MSB-first layout: component idx lives in byte idx/8, bit 7-(idx%8).
        if packed[idx / 8] & (1 << (7 - (idx % 8))) != 0 {
            sum_positive += q;
        }
    }
    let ip = sum_positive - 0.5 * sum_all;
    (scalars.f_add + scalars.f_rescale * (ip - scalars.ip_u_rot_codes)).max(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random direction (golden-ratio fractional hash),
    /// scaled to unit L2 norm.
    fn make_normalized_vector(seed: usize, dim: usize) -> Vec<f32> {
        let v: Vec<f32> = (0..dim)
            .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 2.0 - 1.0)
            .collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        v.iter().map(|x| x / norm).collect()
    }

    // Smoke test: reranked self-query must return the query's own doc id.
    #[test]
    fn test_symphony_qg_basic() {
        let dim = 32;
        let n = 200;
        let mut index = SymphonyQGIndex::new(dim, 8, 8).unwrap();
        for i in 0..n {
            index
                .add_slice(i as u32, &make_normalized_vector(i, dim))
                .unwrap();
        }
        index.build().unwrap();
        let q = make_normalized_vector(0, dim);
        let results = index.search_reranked(&q, 5, 32, 50).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0, "self-query should return doc_id 0");
    }

    // The pre-rotated fast path must agree with qntz's reference estimator.
    #[test]
    fn test_distance_matches_qntz() {
        let dim = 32;
        let n = 50;
        let seed = 42;
        let config = RaBitQConfig::bits4();
        let vectors: Vec<Vec<f32>> = (0..n).map(|i| make_normalized_vector(i, dim)).collect();
        let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
        let mut quantizer = RaBitQQuantizer::with_config(dim, seed, config).unwrap();
        quantizer.fit(&flat, n).unwrap();
        let codes: Vec<QuantizedVector> = vectors
            .iter()
            .map(|v| quantizer.quantize(v).unwrap())
            .collect();
        let query = &vectors[0];
        let qntz_dist = quantizer.approximate_l2_sqr(query, &codes[1]).unwrap();
        let rotated = quantizer.rotate_query(query).unwrap();
        let prerotated_dist = RaBitQQuantizer::approximate_l2_sqr_prerotated(&rotated, &codes[1]);
        let diff = (qntz_dist - prerotated_dist).abs();
        assert!(
            diff < 1e-4,
            "distance mismatch: qntz={qntz_dist}, prerotated={prerotated_dist}, diff={diff}"
        );
    }

    // Reranked self-search recall over the whole (normalized) corpus.
    #[test]
    fn test_symphony_qg_recall() {
        let dim = 256;
        let n = 300;
        let mut index =
            SymphonyQGIndex::with_config(dim, 16, 16, RaBitQConfig::bits4(), 42).unwrap();
        let vectors: Vec<Vec<f32>> = (0..n).map(|i| make_normalized_vector(i, dim)).collect();
        for (i, v) in vectors.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        let mut hits = 0;
        for (i, v) in vectors.iter().enumerate() {
            let results = index.search_reranked(v, 1, 200, 100).unwrap();
            if results.first().map(|(id, _)| *id) == Some(i as u32) {
                hits += 1;
            }
        }
        let recall = hits as f64 / n as f64;
        assert!(
            recall > 0.5,
            "reranked self-search recall too low: {recall:.2} ({hits}/{n})"
        );
    }

    // Diagnostic: how well RaBitQ preserves L2 ordering on deliberately
    // unnormalized data; prints factor ranges, asserts minimal correlation.
    #[test]
    fn test_rabitq_distance_correlation_unnormalized() {
        let dim = 128;
        let n = 100;
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|seed| {
                let v: Vec<f32> = (0..dim)
                    .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 2.0 - 1.0)
                    .collect();
                let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                // Norms vary in [5, 10) so the data is deliberately unnormalized.
                let target_norm = 5.0 + (seed as f32 % 5.0);
                v.iter().map(|x| x * target_norm / norm).collect()
            })
            .collect();
        let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
        let mut quantizer = RaBitQQuantizer::with_config(dim, 42, RaBitQConfig::bits4()).unwrap();
        quantizer.fit(&flat, n).unwrap();
        let codes: Vec<QuantizedVector> = vectors
            .iter()
            .map(|v| quantizer.quantize(v).unwrap())
            .collect();
        let query = &vectors[0];
        let rotated = quantizer.rotate_query(query).unwrap();
        // Exact distances, skipping the query itself (index 0).
        let mut true_dists: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| {
                let d: f32 = query
                    .iter()
                    .zip(v.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum();
                (i, d)
            })
            .collect();
        true_dists.sort_by(|a, b| a.1.total_cmp(&b.1));
        let mut approx_dists: Vec<(usize, f32)> = (1..n)
            .map(|i| {
                let d = RaBitQQuantizer::approximate_l2_sqr_prerotated(&rotated, &codes[i]);
                (i, d)
            })
            .collect();
        approx_dists.sort_by(|a, b| a.1.total_cmp(&b.1));
        let true_top10: std::collections::HashSet<usize> =
            true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let approx_top10: std::collections::HashSet<usize> =
            approx_dists.iter().take(10).map(|(i, _)| *i).collect();
        let overlap = true_top10.intersection(&approx_top10).count();
        eprintln!(
            "RaBitQ unnormalized recall@10: {}/10 (true top-10 vs approx top-10)",
            overlap
        );
        eprintln!(
            " f_add range: {:.1}..{:.1}",
            codes.iter().map(|c| c.f_add).fold(f32::INFINITY, f32::min),
            codes
                .iter()
                .map(|c| c.f_add)
                .fold(f32::NEG_INFINITY, f32::max),
        );
        eprintln!(
            " f_rescale range: {:.4}..{:.4}",
            codes
                .iter()
                .map(|c| c.f_rescale)
                .fold(f32::INFINITY, f32::min),
            codes
                .iter()
                .map(|c| c.f_rescale)
                .fold(f32::NEG_INFINITY, f32::max),
        );
        eprintln!(
            " residual_norm range: {:.2}..{:.2}",
            codes
                .iter()
                .map(|c| c.residual_norm)
                .fold(f32::INFINITY, f32::min),
            codes
                .iter()
                .map(|c| c.residual_norm)
                .fold(f32::NEG_INFINITY, f32::max),
        );
        if overlap <= 2 {
            eprintln!("WARNING: RaBitQ distance approximation is broken for unnormalized vectors");
            eprintln!(" The correction factors (f_add, f_rescale) scale with ||residual||^2,");
            eprintln!(
                " which varies wildly for unnormalized data, drowning the discriminative IP."
            );
        }
        assert!(
            overlap >= 1 || n < 20,
            "RaBitQ has zero correlation with true L2 on unnormalized data"
        );
    }

    // L2 metric + unnormalized data: reranked self-query must still find
    // its own id (catches rerank using the wrong distance metric).
    #[test]
    fn test_symphony_qg_l2_unnormalized() {
        use crate::distance::DistanceMetric;
        use crate::hnsw::graph::HNSWParams;
        let dim = 64;
        let n = 200;
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|seed| {
                let v: Vec<f32> = (0..dim)
                    .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 2.0 - 1.0)
                    .collect();
                let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                let target_norm = 5.0 + (seed as f32 % 5.0);
                v.iter().map(|x| x * target_norm / norm).collect()
            })
            .collect();
        let params = HNSWParams {
            m: 16,
            m_max: 32,
            ef_construction: 200,
            metric: DistanceMetric::L2,
            seed: Some(42),
            ..Default::default()
        };
        let mut index =
            SymphonyQGIndex::with_hnsw_params(dim, params, RaBitQConfig::bits4(), 42).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        let q = &vectors[0];
        let raw_results = index.search(q, 10, 100).unwrap();
        eprintln!(
            "L2 raw quantized: {} results, top={:?}",
            raw_results.len(),
            raw_results.first()
        );
        let reranked_results = index.search_reranked(q, 10, 100, 50).unwrap();
        eprintln!(
            "L2 reranked: {} results, top={:?}",
            reranked_results.len(),
            reranked_results.first()
        );
        assert!(!raw_results.is_empty(), "raw search returned no results");
        assert_eq!(
            reranked_results[0].0, 0,
            "self-query should return doc_id 0 (got {}), \
            likely rerank uses wrong distance metric",
            reranked_results[0].0
        );
    }

    // VR (edge-quantized) index on unnormalized L2 data: self-query hit plus
    // recall@10 against brute-force ground truth.
    #[test]
    fn test_symphony_qg_vr_l2_unnormalized() {
        use crate::distance::DistanceMetric;
        use crate::hnsw::graph::HNSWParams;
        let dim = 64;
        let n = 200;
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|seed| {
                let v: Vec<f32> = (0..dim)
                    .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 2.0 - 1.0)
                    .collect();
                let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                let target_norm = 5.0 + (seed as f32 % 5.0);
                v.iter().map(|x| x * target_norm / norm).collect()
            })
            .collect();
        let params = HNSWParams {
            m: 16,
            m_max: 32,
            ef_construction: 200,
            metric: DistanceMetric::L2,
            seed: Some(42),
            ..Default::default()
        };
        let mut index = SymphonyQGVRIndex::new(dim, params, RaBitQConfig::bits4(), 42).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        let q = &vectors[0];
        let raw_results = index.search(q, 10, 100).unwrap();
        assert!(!raw_results.is_empty(), "VR raw search returned no results");
        let reranked = index.search_reranked(q, 10, 100, 50).unwrap();
        assert_eq!(
            reranked[0].0, 0,
            "VR reranked self-query should return doc_id 0 (got {})",
            reranked[0].0
        );
        // Brute-force ground truth for recall@10.
        let mut gt: Vec<(u32, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| {
                let d: f32 = q.iter().zip(v.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
                (i as u32, d)
            })
            .collect();
        gt.sort_by(|a, b| a.1.total_cmp(&b.1));
        let gt_set: std::collections::HashSet<u32> =
            gt.iter().take(10).map(|(id, _)| *id).collect();
        let result_set: std::collections::HashSet<u32> =
            reranked.iter().map(|(id, _)| *id).collect();
        let overlap = gt_set.intersection(&result_set).count();
        assert!(
            overlap >= 5,
            "VR L2 recall@10 too low: {}/10 overlap with brute-force",
            overlap
        );
    }

    // Coarse performance guard: build must finish within 60s; timings and
    // memory are printed for inspection only.
    #[test]
    fn test_vr_build_and_search_timing() {
        use crate::distance::DistanceMetric;
        use crate::hnsw::graph::HNSWParams;
        let dim = 128;
        let n = 5000;
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|seed| {
                let v: Vec<f32> = (0..dim)
                    .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 2.0 - 1.0)
                    .collect();
                let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                let target_norm = 5.0 + (seed as f32 % 5.0);
                v.iter().map(|x| x * target_norm / norm).collect()
            })
            .collect();
        let params = HNSWParams {
            m: 16,
            m_max: 32,
            ef_construction: 100,
            metric: DistanceMetric::L2,
            seed: Some(42),
            ..Default::default()
        };
        let mut index = SymphonyQGVRIndex::new(dim, params, RaBitQConfig::bits4(), 42).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        let t0 = std::time::Instant::now();
        index.build().unwrap();
        let build_ms = t0.elapsed().as_millis();
        let queries: Vec<&[f32]> = vectors.iter().take(100).map(|v| v.as_slice()).collect();
        let t0 = std::time::Instant::now();
        for q in &queries {
            let _ = index.search(q, 10, 50);
        }
        let search_ms = t0.elapsed().as_millis();
        let qps = 100.0 / (search_ms as f64 / 1000.0);
        let mem = index.memory_usage_bytes();
        eprintln!(
            "VR timing (n={n}, d={dim}): build={build_ms}ms, search={}ms ({qps:.0} QPS), \
            mem={:.1}MB (codes={:.1}MB, vectors={:.1}MB)",
            search_ms,
            mem.total() as f64 / 1e6,
            mem.packed_codes_bytes as f64 / 1e6,
            mem.vectors_bytes as f64 / 1e6,
        );
        assert!(
            build_ms < 60_000,
            "VR build took {build_ms}ms (>60s) for only {n} vectors at d={dim}"
        );
    }

    // Kendall-tau-style check: per-edge approximate distances must rank
    // neighbors roughly like exact L2, or beam search cannot navigate.
    #[test]
    fn test_vr_edge_distance_correlation() {
        use crate::distance::DistanceMetric;
        use crate::hnsw::graph::HNSWParams;
        let dim = 128;
        let n = 2000;
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|seed| {
                (0..dim)
                    .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 20.0 - 10.0)
                    .collect()
            })
            .collect();
        let params = HNSWParams {
            m: 16,
            m_max: 32,
            ef_construction: 100,
            metric: DistanceMetric::L2,
            seed: Some(42),
            ..Default::default()
        };
        let mut index = SymphonyQGVRIndex::new(dim, params, RaBitQConfig::bits4(), 42).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        let quantizer = index.quantizer.as_ref().unwrap();
        let base_layer = &index.index.layers[0];
        let packed_dim = index.packed_dim;
        let lut = nibble_lut(index.cb);
        let mut concordant = 0usize;
        let mut discordant = 0usize;
        let mut total_pairs = 0usize;
        for qi in 0..50 {
            // Query = a stored vector plus small deterministic jitter.
            let base_id = (qi * 37) % n;
            let base_vec = index.index.get_vector(base_id);
            let query: Vec<f32> = base_vec
                .iter()
                .enumerate()
                .map(|(j, &v)| v + ((qi * dim + j) as f32 * 0.314_159).fract() * 2.0 - 1.0)
                .collect();
            let query = query.as_slice();
            let node_id = base_id as u32;
            let rotated_query = quantizer.rotate_query(query).unwrap();
            let neighbors = base_layer.get_neighbors(node_id);
            if neighbors.len() < 2 {
                continue;
            }
            let base_offset = index.neighbor_offsets[node_id as usize] as usize;
            let mut exact_dists: Vec<(u32, f32)> = Vec::new();
            let mut approx_dists: Vec<(u32, f32)> = Vec::new();
            for (slot, &nbr_id) in neighbors.iter().enumerate() {
                let nbr_vec = index.index.get_vector(nbr_id as usize);
                let exact_l2: f32 = query
                    .iter()
                    .zip(nbr_vec.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum();
                let offset = base_offset + slot;
                let scalars = &index.edge_scalars[offset];
                let codes = &index.packed_codes[offset * packed_dim..(offset + 1) * packed_dim];
                let approx = approx_dist_vr_packed(&rotated_query, codes, scalars, &lut);
                exact_dists.push((nbr_id, exact_l2));
                approx_dists.push((nbr_id, approx));
            }
            // Count order-agreement over all neighbor pairs.
            for i in 0..exact_dists.len() {
                for j in (i + 1)..exact_dists.len() {
                    let exact_order = exact_dists[i].1.total_cmp(&exact_dists[j].1);
                    let approx_order = approx_dists[i].1.total_cmp(&approx_dists[j].1);
                    if exact_order == approx_order {
                        concordant += 1;
                    } else {
                        discordant += 1;
                    }
                    total_pairs += 1;
                }
            }
        }
        let tau = if total_pairs > 0 {
            (concordant as f64 - discordant as f64) / total_pairs as f64
        } else {
            0.0
        };
        eprintln!(
            "VR edge distance correlation (d={dim}, n={n}): \
            tau={tau:.3}, concordant={concordant}, discordant={discordant}, pairs={total_pairs}"
        );
        assert!(
            tau > 0.3,
            "VR per-edge distance has weak correlation with exact L2: tau={tau:.3}. \
            The beam search cannot navigate effectively. \
            Likely cause: correction factors mix rotated/unrotated spaces."
        );
    }
}