use crate::hnsw::graph::HNSWIndex;
use crate::RetrieveError;
use qntz::rabitq::{QuantizedVector, RaBitQConfig, RaBitQQuantizer};
/// HNSW graph index whose traversal is scored with RaBitQ-quantized codes
/// instead of full-precision vectors.
///
/// Full-precision vectors, doc ids and graph layers live in the wrapped
/// [`HNSWIndex`]; quantization state is kept alongside it in this struct.
pub struct SymphonyQGIndex {
    /// Full-precision graph index (vectors, doc ids, layers).
    index: HNSWIndex,
    /// RaBitQ settings used when the quantizer is created during build.
    rabitq_config: RaBitQConfig,
    /// Seed handed to the quantizer constructor.
    seed: u64,
    /// Fitted quantizer; stays `None` until quantization has run on a non-empty index.
    quantizer: Option<RaBitQQuantizer>,
    /// One quantized code per stored vector, addressed by internal id.
    codes: Vec<QuantizedVector>,
    /// Set once quantization has run; searches are rejected while this is `false`.
    quantized_built: bool,
}
impl SymphonyQGIndex {
pub fn new(dimension: usize, m: usize, m_max: usize) -> Result<Self, RetrieveError> {
Self::with_config(dimension, m, m_max, RaBitQConfig::bits4(), 42)
}
pub fn with_config(
dimension: usize,
m: usize,
m_max: usize,
rabitq_config: RaBitQConfig,
seed: u64,
) -> Result<Self, RetrieveError> {
let index = HNSWIndex::new(dimension, m, m_max)?;
Ok(Self {
index,
codes: Vec::new(),
quantizer: None,
rabitq_config,
seed,
quantized_built: false,
})
}
pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
self.index.add_slice(doc_id, vector)
}
pub fn build(&mut self) -> Result<(), RetrieveError> {
self.index.build()?;
self.quantize_vectors()?;
Ok(())
}
fn quantize_vectors(&mut self) -> Result<(), RetrieveError> {
let n = self.index.num_vectors;
if n == 0 {
self.quantized_built = true;
return Ok(());
}
let dim = self.index.dimension;
let mut quantizer = RaBitQQuantizer::with_config(dim, self.seed, self.rabitq_config)
.map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ init: {e}")))?;
quantizer
.fit(&self.index.vectors, n)
.map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ fit: {e}")))?;
let mut codes = Vec::with_capacity(n);
for i in 0..n {
let vec = self.index.get_vector(i);
let qv = quantizer
.quantize(vec)
.map_err(|e| RetrieveError::InvalidParameter(format!("RaBitQ quantize: {e}")))?;
codes.push(qv);
}
self.quantizer = Some(quantizer);
self.codes = codes;
self.quantized_built = true;
Ok(())
}
pub fn search(
&self,
query: &[f32],
k: usize,
ef: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
self.check_search_ready(query)?;
let results = self.search_quantized_graph(query, ef)?;
let mut output: Vec<(u32, f32)> = results
.into_iter()
.take(k)
.map(|(internal_id, dist)| (self.index.doc_ids[internal_id as usize], dist))
.collect();
output.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
Ok(output)
}
pub fn search_reranked(
&self,
query: &[f32],
k: usize,
ef: usize,
rerank_pool: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
self.check_search_ready(query)?;
let pool = rerank_pool.max(k);
let candidates = self.search_quantized_graph(query, ef.max(pool))?;
let mut reranked: Vec<(u32, f32)> = candidates
.into_iter()
.take(pool)
.map(|(internal_id, _approx_dist)| {
let vec = self.index.get_vector(internal_id as usize);
let exact_dist = crate::distance::cosine_distance_normalized(query, vec);
(self.index.doc_ids[internal_id as usize], exact_dist)
})
.collect();
reranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
reranked.truncate(k);
Ok(reranked)
}
pub fn len(&self) -> usize {
self.index.num_vectors
}
pub fn is_empty(&self) -> bool {
self.index.num_vectors == 0
}
pub fn inner(&self) -> &HNSWIndex {
&self.index
}
fn check_search_ready(&self, query: &[f32]) -> Result<(), RetrieveError> {
if !self.index.is_built() {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if !self.quantized_built {
return Err(RetrieveError::InvalidParameter(
"quantization not built (call build())".into(),
));
}
if query.len() != self.index.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.index.dimension,
});
}
if self.index.num_vectors == 0 {
return Err(RetrieveError::EmptyIndex);
}
Ok(())
}
fn rotate_query(&self, query: &[f32]) -> Result<Vec<f32>, RetrieveError> {
self.quantizer
.as_ref()
.ok_or_else(|| {
RetrieveError::InvalidParameter("quantizer must be set after build".into())
})?
.rotate_query(query)
.map_err(|e| RetrieveError::InvalidParameter(format!("rotate query: {e}")))
}
fn search_quantized_graph(
&self,
query: &[f32],
ef: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
let rotated_query = self.rotate_query(query)?;
let codes = &self.codes;
let (entry_point, entry_layer) = self.index.entry_point().unwrap_or((0, 0));
let mut current = entry_point;
let mut current_dist = approx_dist_sqr(&rotated_query, &codes[current as usize]);
for layer_idx in (1..=entry_layer).rev() {
if layer_idx >= self.index.layers.len() {
continue;
}
let layer = &self.index.layers[layer_idx];
let mut changed = true;
while changed {
changed = false;
let neighbors = layer.get_neighbors(current);
for &neighbor_id in neighbors.iter() {
let dist = approx_dist_sqr(&rotated_query, &codes[neighbor_id as usize]);
if dist < current_dist {
current_dist = dist;
current = neighbor_id;
changed = true;
}
}
}
}
if self.index.layers.is_empty() {
return Ok(Vec::new());
}
let base_layer = &self.index.layers[0];
let dist_fn = |_q: &[f32], node_id: u32| -> f32 {
approx_dist_sqr(&rotated_query, &codes[node_id as usize])
};
Ok(crate::hnsw::search::greedy_search_layer_custom(
query,
current,
base_layer,
&self.index.vectors,
self.index.dimension,
ef,
&dist_fn,
))
}
}
/// Estimates the squared L2 distance between a query and a quantized code.
///
/// The query must already have been passed through the quantizer's
/// `rotate_query`. The work is delegated to RaBitQ's pre-rotated estimator.
#[inline]
fn approx_dist_sqr(query_rot: &[f32], code: &QuantizedVector) -> f32 {
    RaBitQQuantizer::approximate_l2_sqr_prerotated(query_rot, code)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random unit vector derived from `seed`.
    fn make_normalized_vector(seed: usize, dim: usize) -> Vec<f32> {
        let v: Vec<f32> = (0..dim)
            .map(|j| ((seed * dim + j) as f32 * 0.618_034).fract() * 2.0 - 1.0)
            .collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        v.iter().map(|x| x / norm).collect()
    }

    /// Builds an `m = m_max = 8` index over `n` deterministic unit vectors with doc ids `0..n`.
    fn build_index(n: usize, dim: usize) -> SymphonyQGIndex {
        let mut index = SymphonyQGIndex::new(dim, 8, 8).unwrap();
        for i in 0..n {
            index
                .add_slice(i as u32, &make_normalized_vector(i, dim))
                .unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    fn test_symphony_qg_basic() {
        let dim = 32;
        let index = build_index(200, dim);
        let q = make_normalized_vector(0, dim);
        let results = index.search_reranked(&q, 5, 32, 50).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0, "self-query should return doc_id 0");
    }

    #[test]
    fn test_search_sorted_and_bounded() {
        let dim = 32;
        let n = 200;
        let k = 5;
        let index = build_index(n, dim);
        let q = make_normalized_vector(7, dim);
        let results = index.search(&q, k, 32).unwrap();
        assert!(!results.is_empty(), "search returned no results");
        assert!(results.len() <= k, "search returned more than k results");
        assert!(
            results
                .windows(2)
                .all(|w| w[0].1.total_cmp(&w[1].1).is_le()),
            "results not sorted by ascending distance: {results:?}"
        );
        assert!(
            results.iter().all(|&(id, _)| (id as usize) < n),
            "result contains an unknown doc id: {results:?}"
        );
    }

    #[test]
    fn test_search_requires_build() {
        let dim = 32;
        let mut index = SymphonyQGIndex::new(dim, 8, 8).unwrap();
        for i in 0..10 {
            index
                .add_slice(i as u32, &make_normalized_vector(i, dim))
                .unwrap();
        }
        let q = make_normalized_vector(0, dim);
        assert!(index.search(&q, 3, 16).is_err());
        assert!(index.search_reranked(&q, 3, 16, 8).is_err());
    }

    #[test]
    fn test_search_rejects_dimension_mismatch() {
        let dim = 32;
        let index = build_index(50, dim);
        let bad_query = vec![0.0_f32; dim - 1];
        let err = index.search(&bad_query, 5, 32).unwrap_err();
        assert!(
            matches!(err, RetrieveError::DimensionMismatch { .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_distance_matches_qntz() {
        let dim = 32;
        let n = 50;
        let seed = 42;
        let config = RaBitQConfig::bits4();
        let vectors: Vec<Vec<f32>> = (0..n).map(|i| make_normalized_vector(i, dim)).collect();
        let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
        let mut quantizer = RaBitQQuantizer::with_config(dim, seed, config).unwrap();
        quantizer.fit(&flat, n).unwrap();
        let codes: Vec<QuantizedVector> = vectors
            .iter()
            .map(|v| quantizer.quantize(v).unwrap())
            .collect();
        let query = &vectors[0];
        let qntz_dist = quantizer.approximate_l2_sqr(query, &codes[1]).unwrap();
        let rotated = quantizer.rotate_query(query).unwrap();
        let prerotated_dist = RaBitQQuantizer::approximate_l2_sqr_prerotated(&rotated, &codes[1]);
        let diff = (qntz_dist - prerotated_dist).abs();
        assert!(
            diff < 1e-4,
            "distance mismatch: qntz={qntz_dist}, prerotated={prerotated_dist}, diff={diff}"
        );
    }

    #[test]
    fn test_symphony_qg_recall() {
        let dim = 256;
        let n = 300;
        let mut index =
            SymphonyQGIndex::with_config(dim, 16, 16, RaBitQConfig::bits4(), 42).unwrap();
        let vectors: Vec<Vec<f32>> = (0..n).map(|i| make_normalized_vector(i, dim)).collect();
        for (i, v) in vectors.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        let hits = vectors
            .iter()
            .enumerate()
            .filter(|(i, v)| {
                let results = index.search_reranked(v, 1, 200, 100).unwrap();
                results.first().map(|(id, _)| *id) == Some(*i as u32)
            })
            .count();
        let recall = hits as f64 / n as f64;
        assert!(
            recall > 0.5,
            "reranked self-search recall too low: {recall:.2} ({hits}/{n})"
        );
    }
}