#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
use super::graph::{GraphError, HnswIndex, NodeId, VectorId};
use crate::quantization::variable::BinaryVector;
use crate::simd::popcount::simd_popcount_xor;
use crate::storage::binary::BinaryVectorStorage;
use crate::storage::VectorStorage;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
/// A graph node encountered during binary-quantized search, paired with its
/// Hamming distance to the query.
///
/// Candidates order by `distance` first and by the raw node id second. The
/// tie-break keeps `Ord` consistent with `Eq` (two candidates compare `Equal`
/// only when they are the same candidate) and makes heap eviction among
/// equally distant nodes deterministic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct BqCandidate {
    /// Hamming distance between the query and this node's binary vector.
    distance: u32,
    /// Node the distance was measured against.
    node_id: NodeId,
}

impl PartialOrd for BqCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for BqCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance
            .cmp(&other.distance)
            // Without this tie-break, distinct nodes at the same distance
            // would compare `Equal` while `==` reports them as different.
            .then_with(|| self.node_id.0.cmp(&other.node_id.0))
    }
}
impl HnswIndex {
/// Approximate nearest-neighbor search over the binary-quantized vectors.
///
/// The query is quantized to a binary vector, the graph is traversed using
/// Hamming distance, and each hit is reported as
/// `1.0 - hamming / dimensions`, so `1.0` means every bit matched.
/// Hits are returned closest-first and there are at most `k` of them.
///
/// `_storage` is accepted but not read by this method.
///
/// # Errors
///
/// * [`GraphError::BqNotEnabled`] when the index has no binary storage.
/// * [`GraphError::DimensionMismatch`] when `query.len()` differs from the
///   configured dimensionality.
/// * [`GraphError::Quantization`] when the query cannot be quantized.
/// * Any error raised while traversing the graph.
pub fn search_bq(
    &self,
    query: &[f32],
    k: usize,
    _storage: &VectorStorage,
) -> Result<Vec<(VectorId, f32)>, GraphError> {
    let Some(bq_storage) = self.bq_storage.as_ref() else {
        return Err(GraphError::BqNotEnabled);
    };

    let expected = self.config.dimensions as usize;
    let actual = query.len();
    if actual != expected {
        return Err(GraphError::DimensionMismatch { expected, actual });
    }

    // Nothing has been inserted yet, so there is nothing to rank.
    if self.entry_point.is_none() {
        return Ok(Vec::new());
    }

    let query_bq = match BinaryVector::quantize(query) {
        Ok(bits) => bits,
        Err(e) => return Err(GraphError::Quantization(e.to_string())),
    };

    let hits = self.search_bq_internal(&query_bq, k, bq_storage)?;

    // Normalize the bit-difference count into a similarity score.
    let total_bits = expected as f32;
    Ok(hits
        .into_iter()
        .map(|(id, hamming)| (id, 1.0 - (hamming as f32 / total_bits)))
        .collect())
}
/// Binary-quantized search followed by a rescoring pass.
///
/// Fetches `k * rescore_factor` candidates through [`Self::search_bq`]
/// (a factor of `0` is treated as `1`, and the product saturates instead of
/// overflowing), hands them to `rescore_top_k` together with the original
/// query and `storage`, and converts every distance `d` it returns into the
/// similarity `1 / (1 + d)`.
///
/// # Errors
///
/// Propagates any error from [`Self::search_bq`].
pub fn search_bq_rescored(
    &self,
    query: &[f32],
    k: usize,
    rescore_factor: usize,
    storage: &VectorStorage,
) -> Result<Vec<(VectorId, f32)>, GraphError> {
    use super::rescore::rescore_top_k;

    let factor = if rescore_factor == 0 { 1 } else { rescore_factor };
    let fetch_count = k.saturating_mul(factor);

    let coarse = self.search_bq(query, fetch_count, storage)?;
    let refined = rescore_top_k(&coarse, query, storage, k);

    let mut results = Vec::new();
    for (id, distance) in refined {
        results.push((id, 1.0 / (1.0 + distance)));
    }
    Ok(results)
}
/// Rescored binary-quantized search with an over-fetch factor of 5.
///
/// Equivalent to calling [`Self::search_bq_rescored`] with
/// `rescore_factor = 5`.
///
/// # Errors
///
/// Propagates any error from [`Self::search_bq_rescored`].
pub fn search_bq_rescored_default(
    &self,
    query: &[f32],
    k: usize,
    storage: &VectorStorage,
) -> Result<Vec<(VectorId, f32)>, GraphError> {
    const DEFAULT_RESCORE_FACTOR: usize = 5;
    self.search_bq_rescored(query, k, DEFAULT_RESCORE_FACTOR, storage)
}
/// Rescored binary-quantized search with an over-fetch factor of 15.
///
/// Equivalent to calling [`Self::search_bq_rescored`] with
/// `rescore_factor = 15`, i.e. three times as many coarse candidates as
/// [`Self::search_bq_rescored_default`].
///
/// # Errors
///
/// Propagates any error from [`Self::search_bq_rescored`].
pub fn search_bq_high_recall(
    &self,
    query: &[f32],
    k: usize,
    storage: &VectorStorage,
) -> Result<Vec<(VectorId, f32)>, GraphError> {
    const HIGH_RECALL_RESCORE_FACTOR: usize = 15;
    self.search_bq_rescored(query, k, HIGH_RECALL_RESCORE_FACTOR, storage)
}
fn search_bq_internal(
&self,
query_bq: &BinaryVector,
k: usize,
bq_storage: &BinaryVectorStorage,
) -> Result<Vec<(VectorId, u32)>, GraphError> {
let entry = self.entry_point.ok_or(GraphError::BqNotEnabled)?;
let mut current = entry;
for layer in (1..=self.max_layer).rev() {
current = self.greedy_bq(current, query_bq, layer, bq_storage)?;
}
let candidates = self.search_layer_bq(current, query_bq, k, bq_storage)?;
Ok(candidates)
}
/// Greedy walk on one layer towards the node closest to the query.
///
/// Beginning at `start`, repeatedly scans the neighbors of the current best
/// node and moves to any neighbor with a strictly smaller Hamming distance.
/// Neighbors that cannot be resolved, are marked deleted, or whose
/// `max_layer` is below `layer` are skipped. The walk ends when a full scan
/// brings no improvement, or when the current node itself does not reach
/// `layer`.
///
/// # Errors
///
/// * [`GraphError::NodeIdOutOfBounds`] when the current node cannot be
///   resolved.
/// * Any error from the neighbor lookup or the distance computation.
fn greedy_bq(
    &self,
    start: NodeId,
    query_bq: &BinaryVector,
    layer: u8,
    bq_storage: &BinaryVectorStorage,
) -> Result<NodeId, GraphError> {
    let mut best = start;
    let mut best_dist = self.hamming_to_node(query_bq, best, bq_storage)?;

    loop {
        let Some(node) = self.get_node(best) else {
            return Err(GraphError::NodeIdOutOfBounds);
        };
        if node.max_layer < layer {
            return Ok(best);
        }

        // Distances only ever shrink during a scan, so comparing against the
        // value at the start of the round tells us whether anything moved.
        let round_start_dist = best_dist;

        for neighbor in self.get_neighbors_at_layer(node, layer)? {
            match self.get_node(neighbor) {
                Some(n) if n.deleted == 0 && n.max_layer >= layer => {}
                _ => continue,
            }

            let dist = self.hamming_to_node(query_bq, neighbor, bq_storage)?;
            if dist < best_dist {
                best = neighbor;
                best_dist = dist;
            }
        }

        if best_dist >= round_start_dist {
            return Ok(best);
        }
    }
}
/// Best-first search on layer 0 using Hamming distance.
///
/// Maintains a min-heap of nodes still to expand (`candidates`) and a
/// max-heap of the best `ef` live nodes found so far (`results`), where
/// `ef = max(config.ef_search, k)`. Nodes marked deleted are still expanded
/// so the walk can pass through them, but they are never placed in
/// `results`. Expansion stops once the closest unexpanded candidate is
/// farther than the worst kept result while `results` is full.
///
/// Returns at most `k` `(VectorId, hamming_distance)` pairs sorted by
/// ascending distance.
///
/// # Errors
///
/// * [`GraphError::NodeIdOutOfBounds`] when `entry` cannot be resolved.
/// * Any error from the neighbor lookup or the distance computation.
fn search_layer_bq(
    &self,
    entry: NodeId,
    query_bq: &BinaryVector,
    k: usize,
    bq_storage: &BinaryVectorStorage,
) -> Result<Vec<(VectorId, u32)>, GraphError> {
    // Widen `ef_search` rather than narrowing `k`: casting `k` down to `u32`
    // would silently truncate very large requests.
    let ef = (self.config.ef_search as usize).max(k);

    let mut visited: HashSet<NodeId> = HashSet::new();
    let mut candidates: BinaryHeap<Reverse<BqCandidate>> = BinaryHeap::new();
    let mut results: BinaryHeap<BqCandidate> = BinaryHeap::new();

    let entry_dist = self.hamming_to_node(query_bq, entry, bq_storage)?;
    let entry_node = self.get_node(entry).ok_or(GraphError::NodeIdOutOfBounds)?;
    let seed = BqCandidate {
        distance: entry_dist,
        node_id: entry,
    };
    visited.insert(entry);
    candidates.push(Reverse(seed));
    if entry_node.deleted == 0 {
        results.push(seed);
    }

    while let Some(Reverse(candidate)) = candidates.pop() {
        // Until `results` is full every candidate is worth expanding.
        let worst_dist = if results.len() >= ef {
            results.peek().map_or(u32::MAX, |c| c.distance)
        } else {
            u32::MAX
        };
        if candidate.distance > worst_dist {
            break;
        }

        let Some(node) = self.get_node(candidate.node_id) else {
            continue;
        };

        for neighbor in self.get_neighbors_at_layer(node, 0)? {
            // `insert` returns false when the node was already visited.
            if !visited.insert(neighbor) {
                continue;
            }
            let Some(neighbor_node) = self.get_node(neighbor) else {
                continue;
            };

            let dist = self.hamming_to_node(query_bq, neighbor, bq_storage)?;
            let should_add =
                results.len() < ef || dist < results.peek().map_or(u32::MAX, |c| c.distance);
            if !should_add {
                continue;
            }

            let found = BqCandidate {
                distance: dist,
                node_id: neighbor,
            };
            candidates.push(Reverse(found));

            // Deleted nodes stay traversable but are never reported.
            if neighbor_node.deleted == 0 {
                results.push(found);
                while results.len() > ef {
                    results.pop();
                }
            }
        }
    }

    // Translate node ids to vector ids, then keep the k closest.
    let mut result_vec: Vec<_> = results
        .into_iter()
        .filter_map(|c| {
            let node = self.nodes.get(c.node_id.0 as usize)?;
            Some((node.vector_id, c.distance))
        })
        .collect();
    result_vec.sort_by_key(|(_, d)| *d);
    result_vec.truncate(k);
    Ok(result_vec)
}
/// Hamming distance between the quantized query and the binary vector stored
/// for `node`.
///
/// The binary storage is looked up with the node's raw id widened to `u64`,
/// and the distance is the popcount of the XOR of the two bit patterns.
///
/// # Errors
///
/// * [`GraphError::NodeIdOutOfBounds`] when the node id is not below
///   `self.nodes.len()`.
/// * [`GraphError::InvalidVectorId`] when the binary storage has no entry for
///   that id.
fn hamming_to_node(
    &self,
    query_bq: &BinaryVector,
    node: NodeId,
    bq_storage: &BinaryVectorStorage,
) -> Result<u32, GraphError> {
    let raw_id = node.0;
    if (raw_id as usize) >= self.nodes.len() {
        return Err(GraphError::NodeIdOutOfBounds);
    }

    match bq_storage.get_raw(u64::from(raw_id)) {
        Some(stored_bits) => Ok(simd_popcount_xor(query_bq.data(), stored_bits)),
        None => Err(GraphError::InvalidVectorId),
    }
}
/// Returns the neighbor ids of `node`.
///
/// `layer` is accepted so call sites can state which layer they are
/// traversing, but it does not influence the lookup: the result always comes
/// from `get_neighbors(node)`. The previous implementation branched on
/// `layer == 0` with identical arms, which had the same effect.
///
/// NOTE(review): upper-layer traversal therefore sees the same neighbor list
/// as layer 0 — confirm that `get_neighbors` is meant to be layer-agnostic.
///
/// # Errors
///
/// Propagates any error from `get_neighbors`.
fn get_neighbors_at_layer(
    &self,
    node: &super::graph::HnswNode,
    layer: u8,
) -> Result<Vec<NodeId>, GraphError> {
    // Intentionally unused for now; see the note above.
    let _ = layer;
    self.get_neighbors(node)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::config::HnswConfig;

    /// Builds an empty 128-dimensional index with BQ enabled, plus its storage.
    fn bq_fixture() -> (HnswIndex, VectorStorage) {
        let config = HnswConfig::new(128);
        let storage = VectorStorage::new(&config, None);
        let index = HnswIndex::with_bq(config, &storage).unwrap();
        (index, storage)
    }

    /// Searching an index with no vectors yields no hits and no error.
    #[test]
    fn test_search_bq_empty_index() {
        let (index, storage) = bq_fixture();
        let query = vec![1.0f32; 128];

        let hits = index.search_bq(&query, 10, &storage).unwrap();

        assert_eq!(hits.len(), 0);
    }

    /// An index built without BQ rejects BQ searches.
    #[test]
    fn test_search_bq_not_enabled_error() {
        let config = HnswConfig::new(128);
        let storage = VectorStorage::new(&config, None);
        let index = HnswIndex::new(config, &storage).unwrap();
        let query = vec![1.0f32; 128];

        match index.search_bq(&query, 10, &storage) {
            Err(GraphError::BqNotEnabled) => {}
            _ => panic!("expected GraphError::BqNotEnabled"),
        }
    }

    /// A query of the wrong length is rejected with both lengths reported.
    #[test]
    fn test_search_bq_dimension_mismatch() {
        let (index, storage) = bq_fixture();
        let short_query = vec![1.0f32; 64];

        let outcome = index.search_bq(&short_query, 10, &storage);

        assert!(matches!(
            outcome,
            Err(GraphError::DimensionMismatch {
                expected: 128,
                actual: 64
            })
        ));
    }

    /// A single stored vector is returned with near-perfect similarity to itself.
    #[test]
    fn test_search_bq_single_vector() {
        let (mut index, mut storage) = bq_fixture();
        let stored = vec![1.0f32; 128];
        index.insert_bq(&stored, &mut storage).unwrap();

        let query = vec![1.0f32; 128];
        let hits = index.search_bq(&query, 10, &storage).unwrap();

        assert_eq!(hits.len(), 1);
        assert!(hits[0].1 > 0.99);
    }

    /// The exact-match vector ranks first among an opposite and a half-matching one.
    #[test]
    fn test_search_bq_finds_most_similar() {
        let (mut index, mut storage) = bq_fixture();

        let all_positive = vec![1.0f32; 128];
        let all_negative = vec![-1.0f32; 128];
        let half_and_half: Vec<f32> = (0..128)
            .map(|i| if i < 64 { 1.0 } else { -1.0 })
            .collect();
        index.insert_bq(&all_positive, &mut storage).unwrap();
        index.insert_bq(&all_negative, &mut storage).unwrap();
        index.insert_bq(&half_and_half, &mut storage).unwrap();

        let query = vec![1.0f32; 128];
        let hits = index.search_bq(&query, 3, &storage).unwrap();

        assert_eq!(hits.len(), 3);
        // The first inserted vector is expected to carry id 1.
        assert_eq!(hits[0].0 .0, 1);
        assert!((hits[0].1 - 1.0).abs() < 0.01);
    }

    /// With ten stored vectors, `k = 5` returns five hits in non-increasing similarity.
    #[test]
    fn test_search_bq_multiple_vectors() {
        let (mut index, mut storage) = bq_fixture();

        for i in 0..10 {
            let v: Vec<f32> = (0..128)
                .map(|j| ((i * 128 + j) % 256) as f32 / 128.0 - 1.0)
                .collect();
            index.insert_bq(&v, &mut storage).unwrap();
        }

        let query: Vec<f32> = (0..128).map(|j| (j % 256) as f32 / 128.0 - 1.0).collect();
        let hits = index.search_bq(&query, 5, &storage).unwrap();

        assert_eq!(hits.len(), 5);
        assert!(hits.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }
}