use super::VectorAccessor;
use super::compute_distance;
use crate::index::vector::HnswConfig;
use grafeo_common::types::NodeId;
use ordered_float::OrderedFloat;
use parking_lot::RwLock;
use rand::{RngExt, SeedableRng};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;
/// A graph node paired with its distance to the current query.
///
/// The ordering is *reversed* on `distance` (ids are ignored) so that a
/// `BinaryHeap<Neighbor>` behaves as a min-heap: `pop()` yields the closest
/// entry first. `OrderedFloat` supplies a total order over the `f32` values.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Neighbor {
    id: NodeId,
    distance: f32,
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // A smaller distance must compare as "greater" so the max-heap pops it first.
        let mine = OrderedFloat(self.distance);
        let theirs = OrderedFloat(other.distance);
        mine.cmp(&theirs).reverse()
    }
}
/// A graph node paired with its distance to the query, ordered naturally on
/// `distance` (ids are ignored). In a `BinaryHeap<FurthestCandidate>` this makes
/// `peek()`/`pop()` return the *furthest* entry, which is the one evicted when a
/// bounded result set overflows.
#[derive(Clone, Copy, Debug, PartialEq)]
struct FurthestCandidate {
    id: NodeId,
    distance: f32,
}

impl Eq for FurthestCandidate {}

impl PartialOrd for FurthestCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for FurthestCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (OrderedFloat(self.distance), OrderedFloat(other.distance));
        Ord::cmp(&lhs, &rhs)
    }
}
/// Adjacency lists of one node in the HNSW graph.
///
/// `neighbors[layer]` holds the ids this node links to on `layer`; a node
/// inserted at level `L` carries `L + 1` lists (layers `0..=L`).
#[derive(Clone, Debug)]
struct HnswNode {
    neighbors: Vec<Vec<NodeId>>,
}
/// Hierarchical Navigable Small World (HNSW) graph over node ids.
///
/// Only the graph topology is stored here; vectors are fetched on demand
/// through the `VectorAccessor` passed to each operation.
pub struct HnswIndex {
    /// Build/search parameters (dimensions, metric, m, m_max, ef,
    /// ef_construction, ml, alpha).
    config: HnswConfig,
    /// Adjacency lists keyed by node id.
    nodes: RwLock<HashMap<NodeId, HnswNode>>,
    /// Node every search and insert starts from; `None` while empty.
    entry_point: RwLock<Option<NodeId>>,
    /// Top layer the descent starts from; set when an insert raises the entry point.
    max_level: RwLock<usize>,
    /// Randomness for level assignment; deterministic when built via `with_seed`.
    rng: RwLock<rand::rngs::StdRng>,
}
impl HnswIndex {
#[must_use]
pub fn new(config: HnswConfig) -> Self {
Self {
config,
nodes: RwLock::new(HashMap::new()),
entry_point: RwLock::new(None),
max_level: RwLock::new(0),
rng: RwLock::new(rand::rngs::StdRng::from_rng(&mut rand::rng())),
}
}
/// Creates an empty index whose node map is pre-sized for `capacity` entries.
#[must_use]
pub fn with_capacity(config: HnswConfig, capacity: usize) -> Self {
    let nodes = HashMap::with_capacity(capacity);
    let rng = rand::rngs::StdRng::from_rng(&mut rand::rng());
    Self {
        config,
        nodes: RwLock::new(nodes),
        entry_point: RwLock::new(None),
        max_level: RwLock::new(0),
        rng: RwLock::new(rng),
    }
}
/// Creates an empty index whose level RNG is seeded with `seed`, so the
/// sequence of levels drawn by successive inserts is reproducible.
#[must_use]
pub fn with_seed(config: HnswConfig, seed: u64) -> Self {
    let rng = rand::rngs::StdRng::seed_from_u64(seed);
    Self {
        config,
        nodes: RwLock::new(HashMap::new()),
        entry_point: RwLock::new(None),
        max_level: RwLock::new(0),
        rng: RwLock::new(rng),
    }
}
/// Returns the configuration this index was built with.
#[must_use]
pub fn config(&self) -> &HnswConfig {
    &self.config
}

/// Returns the number of nodes currently in the graph.
#[must_use]
pub fn len(&self) -> usize {
    let nodes = self.nodes.read();
    nodes.len()
}

/// Returns `true` when the graph holds no nodes.
#[must_use]
pub fn is_empty(&self) -> bool {
    let nodes = self.nodes.read();
    nodes.is_empty()
}
/// Estimates the heap bytes owned by the graph's adjacency storage.
///
/// Counts every hash-map slot as key + value + one extra byte (presumably the
/// table's per-slot control metadata — an approximation), plus the allocated
/// capacity of each node's layer table and of each per-layer neighbor list.
/// Vectors are not counted; they are held behind the accessor.
#[must_use]
pub fn heap_memory_bytes(&self) -> usize {
    let nodes = self.nodes.read();
    let slot_bytes = std::mem::size_of::<NodeId>() + std::mem::size_of::<HnswNode>() + 1;
    let table_bytes = nodes.capacity() * slot_bytes;
    let adjacency_bytes: usize = nodes
        .values()
        .map(|node| {
            let layer_table = node.neighbors.capacity() * std::mem::size_of::<Vec<NodeId>>();
            let edge_lists: usize = node
                .neighbors
                .iter()
                .map(|layer| layer.capacity() * std::mem::size_of::<NodeId>())
                .sum();
            layer_table + edge_lists
        })
        .sum();
    table_bytes + adjacency_bytes
}
pub fn insert(&self, id: NodeId, vector: &[f32], accessor: &impl VectorAccessor) {
assert_eq!(
vector.len(),
self.config.dimensions,
"Vector dimensions mismatch: expected {}, got {}",
self.config.dimensions,
vector.len()
);
let level = self.random_level();
let node = HnswNode {
neighbors: vec![Vec::new(); level + 1],
};
let mut nodes = self.nodes.write();
let mut entry_point = self.entry_point.write();
let mut max_level = self.max_level.write();
if entry_point.is_none() {
nodes.insert(id, node);
*entry_point = Some(id);
*max_level = level;
return;
}
let ep = entry_point.expect("entry_point confirmed Some above");
let current_max_level = *max_level;
nodes.insert(id, node);
let mut current_ep = ep;
for lc in (level + 1..=current_max_level).rev() {
current_ep = self.search_layer_single(&nodes, accessor, vector, current_ep, lc);
}
for lc in (0..=level.min(current_max_level)).rev() {
let m_max = if lc == 0 {
self.config.m_max
} else {
self.config.m
};
let neighbors = self.search_layer(
&nodes,
accessor,
vector,
current_ep,
self.config.ef_construction,
lc,
);
let selected = self.select_neighbors_heuristic(accessor, &neighbors, m_max);
if let Some(new_node) = nodes.get_mut(&id) {
new_node.neighbors[lc].clone_from(&selected);
}
let mut needs_pruning: Vec<NodeId> = Vec::new();
for &neighbor_id in &selected {
if let Some(neighbor) = nodes.get_mut(&neighbor_id)
&& neighbor.neighbors.len() > lc
{
neighbor.neighbors[lc].push(id);
if neighbor.neighbors[lc].len() > m_max {
needs_pruning.push(neighbor_id);
}
}
}
let mut prune_data: Vec<(NodeId, Vec<(NodeId, f32)>)> = Vec::new();
for neighbor_id in &needs_pruning {
if let Some(neighbor) = nodes.get(neighbor_id)
&& neighbor.neighbors.len() > lc
{
let Some(base_vec) = accessor.get_vector(*neighbor_id) else {
continue;
};
let distances: Vec<(NodeId, f32)> = neighbor.neighbors[lc]
.iter()
.map(|&nid| {
let dist = accessor
.get_vector(nid)
.map_or(f32::MAX, |v| self.vector_distance(&base_vec, &v));
(nid, dist)
})
.collect();
prune_data.push((*neighbor_id, distances));
}
}
for (neighbor_id, distances) in prune_data {
if let Some(neighbor) = nodes.get_mut(&neighbor_id)
&& neighbor.neighbors.len() > lc
{
Self::prune_neighbors_with_distances(
&mut neighbor.neighbors[lc],
&distances,
m_max,
);
}
}
if !selected.is_empty() {
current_ep = selected[0];
}
}
if level > current_max_level {
*entry_point = Some(id);
*max_level = level;
}
}
/// Returns up to `k` nearest neighbors of `query` using the configured `ef`.
///
/// Results are `(id, distance)` pairs in ascending distance order. Panics
/// under the same conditions as `search_with_ef`.
#[must_use]
pub fn search(
    &self,
    query: &[f32],
    k: usize,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    let ef = self.config.ef;
    self.search_with_ef(query, k, ef, accessor)
}
/// Returns up to `k` nearest neighbors of `query`, keeping a beam of
/// `max(ef, k)` candidates on the bottom layer.
///
/// The search descends greedily from the entry point through layers
/// `max_level..=1`, then runs a beam search on layer 0. The result is
/// `(id, distance)` pairs sorted by ascending distance; an empty index yields
/// an empty vector.
///
/// # Panics
///
/// Panics if `query.len()` differs from the configured dimensions.
#[must_use]
pub fn search_with_ef(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    assert_eq!(
        query.len(),
        self.config.dimensions,
        "Query dimensions mismatch: expected {}, got {}",
        self.config.dimensions,
        query.len()
    );
    let nodes = self.nodes.read();
    let entry_point = self.entry_point.read();
    let max_level = *self.max_level.read();
    let start = match *entry_point {
        Some(ep) if !nodes.is_empty() => ep,
        _ => return Vec::new(),
    };
    let bottom_ep = (1..=max_level).rev().fold(start, |ep, lc| {
        self.search_layer_single(&nodes, accessor, query, ep, lc)
    });
    let beam = ef.max(k);
    self.search_layer(&nodes, accessor, query, bottom_ep, beam, 0)
        .into_iter()
        .take(k)
        .map(|hit| (hit.id, hit.distance))
        .collect()
}
/// Filtered k-NN search that widens `ef` according to how selective the
/// allowlist is.
///
/// Selectivity is the ratio of allowlist size to index size, clamped to
/// `[0.01, 1.0]`. The beam width is `ceil(ef / selectivity)`, capped at the
/// index size and raised to at least `k`. As a result, `ef` grows at most
/// 100x for tiny allowlists and is never shrunk by an oversized allowlist.
/// An empty allowlist returns an empty result immediately.
///
/// # Panics
///
/// Panics (via `search_with_ef_and_filter`) if `query.len()` differs from the
/// configured dimensions.
#[must_use]
pub fn search_with_filter(
    &self,
    query: &[f32],
    k: usize,
    allowlist: &HashSet<NodeId>,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    if allowlist.is_empty() {
        return Vec::new();
    }
    let total = self.nodes.read().len();
    let selectivity = if total == 0 {
        1.0
    } else {
        // An allowlist larger than the index (e.g. holding ids that were never
        // inserted) would otherwise push selectivity above 1 and shrink `ef`
        // below the configured value, lowering recall.
        (allowlist.len() as f64 / total as f64).clamp(0.01, 1.0)
    };
    let ef_scaled = ((self.config.ef as f64 / selectivity).ceil() as usize)
        .min(total)
        .max(k);
    self.search_with_ef_and_filter(query, k, ef_scaled, allowlist, accessor)
}
/// k-NN search restricted to ids in `allowlist`, keeping a beam of
/// `max(ef, k)` on layer 0.
///
/// The upper layers are descended without filtering, since they only choose
/// the starting node. On the bottom layer the walk may pass through any node,
/// but only allowed ids are collected. Results are `(id, distance)` pairs in
/// ascending distance order. An empty allowlist or empty index yields an
/// empty vector.
///
/// # Panics
///
/// Panics if the allowlist is non-empty and `query.len()` differs from the
/// configured dimensions.
#[must_use]
pub fn search_with_ef_and_filter(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    allowlist: &HashSet<NodeId>,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    if allowlist.is_empty() {
        return Vec::new();
    }
    assert_eq!(
        query.len(),
        self.config.dimensions,
        "Query dimensions mismatch: expected {}, got {}",
        self.config.dimensions,
        query.len()
    );
    let nodes = self.nodes.read();
    let entry_point = self.entry_point.read();
    let max_level = *self.max_level.read();
    let start = match *entry_point {
        Some(ep) if !nodes.is_empty() => ep,
        _ => return Vec::new(),
    };
    let bottom_ep = (1..=max_level).rev().fold(start, |ep, lc| {
        self.search_layer_single(&nodes, accessor, query, ep, lc)
    });
    let beam = ef.max(k);
    self.search_layer_filtered(&nodes, accessor, query, bottom_ep, beam, 0, allowlist)
        .into_iter()
        .take(k)
        .map(|hit| (hit.id, hit.distance))
        .collect()
}
/// Removes `id` from the graph, returning `false` if it was not indexed.
///
/// Every edge pointing at `id` is dropped. If `id` was the entry point, the
/// remaining node with the most layers is promoted, and `max_level` is reset
/// to that node's top layer (0 once the index is empty). Later searches and
/// inserts therefore start from a node that spans every layer they descend
/// through. Neighbor lists are not repaired, so heavy deletion may reduce
/// recall.
pub fn remove(&self, id: NodeId) -> bool {
    // Same lock order as `insert`: nodes -> entry_point -> max_level.
    let mut nodes = self.nodes.write();
    let mut entry_point = self.entry_point.write();
    let mut max_level = self.max_level.write();
    if nodes.remove(&id).is_none() {
        return false;
    }
    for node in nodes.values_mut() {
        for layer in &mut node.neighbors {
            layer.retain(|&n| n != id);
        }
    }
    if *entry_point == Some(id) {
        // Promote the tallest remaining node; an arbitrary node could sit
        // below `max_level` and leave the upper layers unreachable.
        let promoted = nodes
            .iter()
            .max_by_key(|(_, node)| node.neighbors.len())
            .map(|(&nid, node)| (nid, node.neighbors.len().saturating_sub(1)));
        *entry_point = promoted.map(|(nid, _)| nid);
        *max_level = promoted.map_or(0, |(_, top)| top);
    }
    true
}
/// Returns `true` if `id` currently has a node in the graph.
#[must_use]
pub fn contains(&self, id: NodeId) -> bool {
    let nodes = self.nodes.read();
    nodes.contains_key(&id)
}
/// Draws the top layer for a new node as `floor(-ln(r) * ml)`, where `r` is a
/// uniform sample from the index's RNG.
fn random_level(&self) -> usize {
    let mut rng = self.rng.write();
    let r: f64 = rng.random();
    // The uniform f64 sample lies in [0, 1) and can be exactly 0.0. Then
    // -ln(0) = +inf, the cast saturates to usize::MAX, and `level + 1` in
    // `insert` overflows. Clamping to the smallest positive normal keeps the
    // level finite and leaves every other draw unchanged.
    let r = r.max(f64::MIN_POSITIVE);
    (-r.ln() * self.config.ml).floor() as usize
}
/// Greedy walk on `layer`. Starting at `ep`, it repeatedly moves to the
/// closest neighbor that is nearer to `query` than the current node, and
/// returns the node where no neighbor improves.
///
/// A node missing from `nodes`, or with no list for `layer`, is a dead end.
fn search_layer_single(
    &self,
    nodes: &HashMap<NodeId, HnswNode>,
    accessor: &impl VectorAccessor,
    query: &[f32],
    ep: NodeId,
    layer: usize,
) -> NodeId {
    let mut current = ep;
    let mut current_dist = self.node_distance(accessor, query, ep);
    loop {
        let mut changed = false;
        if let Some(node) = nodes.get(&current)
            && let Some(neighbors) = node.neighbors.get(layer)
        {
            for &neighbor in neighbors {
                let dist = self.node_distance(accessor, query, neighbor);
                if dist < current_dist {
                    current = neighbor;
                    current_dist = dist;
                    changed = true;
                }
            }
        }
        if !changed {
            break;
        }
    }
    current
}
/// Beam search over one `layer`, starting from `ep`.
///
/// Two heaps drive the search:
/// - `candidates`: a min-heap of nodes still to expand;
/// - `results`: a max-heap holding the best `ef` nodes seen so far.
///
/// Expansion stops when no candidates remain, or when the closest unexpanded
/// candidate is farther than the worst kept result while `ef` results are
/// already held. Returns the kept results sorted by ascending distance.
fn search_layer(
    &self,
    nodes: &HashMap<NodeId, HnswNode>,
    accessor: &impl VectorAccessor,
    query: &[f32],
    ep: NodeId,
    ef: usize,
    layer: usize,
) -> Vec<Neighbor> {
    let ep_dist = self.node_distance(accessor, query, ep);
    let mut candidates: BinaryHeap<Neighbor> = BinaryHeap::new();
    candidates.push(Neighbor {
        id: ep,
        distance: ep_dist,
    });
    let mut results: BinaryHeap<FurthestCandidate> = BinaryHeap::new();
    results.push(FurthestCandidate {
        id: ep,
        distance: ep_dist,
    });
    let mut visited: HashSet<NodeId> =
        HashSet::with_capacity(nodes.len().min(ef.saturating_mul(2)));
    visited.insert(ep);
    while let Some(current) = candidates.pop() {
        if let Some(furthest) = results.peek()
            && current.distance > furthest.distance
            && results.len() >= ef
        {
            break;
        }
        let Some(neighbors) = nodes
            .get(&current.id)
            .and_then(|node| node.neighbors.get(layer))
        else {
            continue;
        };
        for &neighbor in neighbors {
            // `insert` returns false for already-visited ids: one hash lookup.
            if !visited.insert(neighbor) {
                continue;
            }
            let dist = self.node_distance(accessor, query, neighbor);
            let should_add =
                results.len() < ef || results.peek().is_none_or(|f| dist < f.distance);
            if should_add {
                candidates.push(Neighbor {
                    id: neighbor,
                    distance: dist,
                });
                results.push(FurthestCandidate {
                    id: neighbor,
                    distance: dist,
                });
                while results.len() > ef {
                    results.pop();
                }
            }
        }
    }
    let mut result_vec: Vec<Neighbor> = results
        .into_iter()
        .map(|fc| Neighbor {
            id: fc.id,
            distance: fc.distance,
        })
        .collect();
    // Stable sort, same ordering as comparing `OrderedFloat` distances.
    result_vec.sort_by_key(|n| OrderedFloat(n.distance));
    result_vec
}
/// Beam search over `layer` that navigates the whole graph but returns only
/// ids contained in `allowlist`.
///
/// `best_seen` is unfiltered and capped at `ef`. It drives both expansion and
/// the stopping rule, so disallowed nodes still serve as stepping stones.
/// `results` keeps the best `ef` allowed nodes. Returns those results sorted
/// by ascending distance.
// The parameters mirror `search_layer` plus the allowlist; bundling them into
// a struct would only obscure this private call site.
#[allow(clippy::too_many_arguments)]
fn search_layer_filtered(
    &self,
    nodes: &HashMap<NodeId, HnswNode>,
    accessor: &impl VectorAccessor,
    query: &[f32],
    ep: NodeId,
    ef: usize,
    layer: usize,
    allowlist: &HashSet<NodeId>,
) -> Vec<Neighbor> {
    let ep_dist = self.node_distance(accessor, query, ep);
    let mut candidates: BinaryHeap<Neighbor> = BinaryHeap::new();
    candidates.push(Neighbor {
        id: ep,
        distance: ep_dist,
    });
    let mut best_seen: BinaryHeap<FurthestCandidate> = BinaryHeap::new();
    best_seen.push(FurthestCandidate {
        id: ep,
        distance: ep_dist,
    });
    let mut results: BinaryHeap<FurthestCandidate> = BinaryHeap::new();
    if allowlist.contains(&ep) {
        results.push(FurthestCandidate {
            id: ep,
            distance: ep_dist,
        });
    }
    let mut visited: HashSet<NodeId> =
        HashSet::with_capacity(nodes.len().min(ef.saturating_mul(4)));
    visited.insert(ep);
    while let Some(current) = candidates.pop() {
        if let Some(furthest) = best_seen.peek()
            && current.distance > furthest.distance
            && best_seen.len() >= ef
        {
            break;
        }
        let Some(neighbors) = nodes
            .get(&current.id)
            .and_then(|node| node.neighbors.get(layer))
        else {
            continue;
        };
        for &neighbor in neighbors {
            if !visited.insert(neighbor) {
                continue;
            }
            let dist = self.node_distance(accessor, query, neighbor);
            let should_explore =
                best_seen.len() < ef || best_seen.peek().is_none_or(|f| dist < f.distance);
            if should_explore {
                candidates.push(Neighbor {
                    id: neighbor,
                    distance: dist,
                });
                best_seen.push(FurthestCandidate {
                    id: neighbor,
                    distance: dist,
                });
                while best_seen.len() > ef {
                    best_seen.pop();
                }
            }
            // Allowed nodes are recorded even when not worth expanding.
            if allowlist.contains(&neighbor) {
                let should_add =
                    results.len() < ef || results.peek().is_none_or(|f| dist < f.distance);
                if should_add {
                    results.push(FurthestCandidate {
                        id: neighbor,
                        distance: dist,
                    });
                    while results.len() > ef {
                        results.pop();
                    }
                }
            }
        }
    }
    let mut result_vec: Vec<Neighbor> = results
        .into_iter()
        .map(|fc| Neighbor {
            id: fc.id,
            distance: fc.distance,
        })
        .collect();
    result_vec.sort_by_key(|n| OrderedFloat(n.distance));
    result_vec
}
/// Picks up to `m` diverse neighbors from `candidates`, visited in the order
/// given (ascending distance when they come from `search_layer`).
///
/// A candidate is skipped as "covered" when some already-kept neighbor lies
/// closer to it than `alpha` times the candidate's own distance. Candidates
/// whose vector the accessor cannot supply are skipped as well.
fn select_neighbors_heuristic(
    &self,
    accessor: &impl VectorAccessor,
    candidates: &[Neighbor],
    m: usize,
) -> Vec<NodeId> {
    let alpha = self.config.alpha;
    let mut kept: Vec<(NodeId, Arc<[f32]>)> = Vec::with_capacity(m);
    for candidate in candidates {
        if kept.len() >= m {
            break;
        }
        let Some(vector) = accessor.get_vector(candidate.id) else {
            continue;
        };
        let threshold = alpha * candidate.distance;
        let is_covered = kept
            .iter()
            .any(|(_, other)| self.vector_distance(&vector, other) < threshold);
        if !is_covered {
            kept.push((candidate.id, vector));
        }
    }
    kept.into_iter().map(|(id, _)| id).collect()
}
/// Shrinks `neighbors` to the `m` entries with the smallest distances.
///
/// `distances` supplies one `(id, distance)` pair per neighbor and becomes the
/// source of the new list. Ties keep their order from `distances`, because the
/// sort is stable. Lists already within `m` are left untouched.
fn prune_neighbors_with_distances(
    neighbors: &mut Vec<NodeId>,
    distances: &[(NodeId, f32)],
    m: usize,
) {
    if neighbors.len() > m {
        let mut ranked = distances.to_vec();
        ranked.sort_by_key(|&(_, dist)| OrderedFloat(dist));
        *neighbors = ranked.into_iter().take(m).map(|(id, _)| id).collect();
    }
}
/// Distance between two raw vectors under the configured metric.
#[inline]
fn vector_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    let metric = self.config.metric;
    compute_distance(a, b, metric)
}
/// Distance from `query` to node `id`, or `f32::MAX` when the accessor has no
/// vector for it, so unresolvable nodes rank last.
fn node_distance(&self, accessor: &impl VectorAccessor, query: &[f32], id: NodeId) -> f32 {
    match accessor.get_vector(id) {
        Some(stored) => self.vector_distance(query, &stored),
        None => f32::MAX,
    }
}
/// Inserts every `(id, vector)` pair, in iteration order, via `insert`.
///
/// # Panics
///
/// Panics on the first vector whose length differs from the configured
/// dimensions; pairs processed before it stay inserted.
pub fn batch_insert<'a, I>(&self, vectors: I, accessor: &impl VectorAccessor)
where
    I: IntoIterator<Item = (NodeId, &'a [f32])>,
{
    vectors
        .into_iter()
        .for_each(|(id, vector)| self.insert(id, vector, accessor));
}
/// Runs `search` for every query and returns the result lists in query order.
///
/// With the `parallel` feature enabled, the queries are spread over rayon's
/// thread pool.
#[must_use]
pub fn batch_search(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    accessor: &impl VectorAccessor,
) -> Vec<Vec<(NodeId, f32)>> {
    let run_one = |query: &Vec<f32>| self.search(query, k, accessor);
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries.par_iter().map(run_one).collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries.iter().map(run_one).collect()
    }
}
/// Like `batch_search`, but takes borrowed query slices instead of owned vectors.
///
/// With the `parallel` feature enabled, the queries are spread over rayon's
/// thread pool.
#[must_use]
pub fn batch_search_slices(
    &self,
    queries: &[&[f32]],
    k: usize,
    accessor: &impl VectorAccessor,
) -> Vec<Vec<(NodeId, f32)>> {
    let run_one = |query: &&[f32]| self.search(query, k, accessor);
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries.par_iter().map(run_one).collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries.iter().map(run_one).collect()
    }
}
/// Runs `search_with_ef` with the same `k` and `ef` for every query, returning
/// the result lists in query order (parallel under the `parallel` feature).
#[must_use]
pub fn batch_search_with_ef(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    ef: usize,
    accessor: &impl VectorAccessor,
) -> Vec<Vec<(NodeId, f32)>> {
    let run_one = |query: &Vec<f32>| self.search_with_ef(query, k, ef, accessor);
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries.par_iter().map(run_one).collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries.iter().map(run_one).collect()
    }
}
/// Runs `search_with_filter` against the shared `allowlist` for every query,
/// returning the result lists in query order (parallel under the `parallel`
/// feature).
#[must_use]
pub fn batch_search_with_filter(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    allowlist: &HashSet<NodeId>,
    accessor: &impl VectorAccessor,
) -> Vec<Vec<(NodeId, f32)>> {
    let run_one = |query: &Vec<f32>| self.search_with_filter(query, k, allowlist, accessor);
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries.par_iter().map(run_one).collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries.iter().map(run_one).collect()
    }
}
/// Runs `search_with_ef_and_filter` with a fixed `ef` and shared `allowlist`
/// for every query, returning the result lists in query order (parallel under
/// the `parallel` feature).
#[must_use]
pub fn batch_search_with_ef_and_filter(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    ef: usize,
    allowlist: &HashSet<NodeId>,
    accessor: &impl VectorAccessor,
) -> Vec<Vec<(NodeId, f32)>> {
    let run_one =
        |query: &Vec<f32>| self.search_with_ef_and_filter(query, k, ef, allowlist, accessor);
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries.par_iter().map(run_one).collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries.iter().map(run_one).collect()
    }
}
}
impl std::fmt::Debug for HnswIndex {
    /// Prints the config, node count, and current `max_level`; the adjacency
    /// lists and RNG state are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.len();
        let max_level = *self.max_level.read();
        f.debug_struct("HnswIndex")
            .field("config", &self.config)
            .field("len", &len)
            .field("max_level", &max_level)
            .finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::index::vector::DistanceMetric;
fn create_test_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
(0..n)
.map(|i| {
(0..dim)
.map(|j| ((i * dim + j) as f32) / (n * dim) as f32)
.collect()
})
.collect()
}
fn make_accessor(map: &HashMap<NodeId, Arc<[f32]>>) -> impl VectorAccessor + '_ {
move |id: NodeId| -> Option<Arc<[f32]>> { map.get(&id).cloned() }
}
#[test]
fn test_hnsw_empty() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
let accessor = make_accessor(&map);
assert!(index.is_empty());
assert_eq!(index.len(), 0);
assert!(
index
.search(&[0.0, 0.0, 0.0, 0.0], 10, &accessor)
.is_empty()
);
}
#[test]
fn test_hnsw_single_insert() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
let v: Arc<[f32]> = vec![0.1, 0.2, 0.3, 0.4].into();
map.insert(NodeId::new(1), v.clone());
let accessor = make_accessor(&map);
index.insert(NodeId::new(1), &v, &accessor);
assert_eq!(index.len(), 1);
assert!(index.contains(NodeId::new(1)));
assert!(!index.contains(NodeId::new(2)));
let results = index.search(&[0.1, 0.2, 0.3, 0.4], 1, &accessor);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, NodeId::new(1));
assert!(results[0].1 < 0.001); }
#[test]
fn test_hnsw_multiple_inserts() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::with_seed(config, 42);
let vectors = create_test_vectors(100, 4);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64 + 1);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64 + 1), vec, &accessor);
}
assert_eq!(index.len(), 100);
let query = &vectors[50];
let results = index.search(query, 5, &accessor);
assert_eq!(results.len(), 5);
assert_eq!(results[0].0, NodeId::new(51));
assert!(results[0].1 < 0.001);
}
#[test]
fn test_hnsw_search_returns_sorted() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::with_seed(config, 42);
let vectors = create_test_vectors(50, 4);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64 + 1);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64 + 1), vec, &accessor);
}
let query = [0.5, 0.5, 0.5, 0.5];
let results = index.search(&query, 10, &accessor);
for i in 1..results.len() {
assert!(results[i - 1].1 <= results[i].1);
}
}
#[test]
fn test_hnsw_remove() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(1), vec![0.1, 0.2, 0.3, 0.4].into());
map.insert(NodeId::new(2), vec![0.5, 0.6, 0.7, 0.8].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(1), &[0.1, 0.2, 0.3, 0.4], &accessor);
index.insert(NodeId::new(2), &[0.5, 0.6, 0.7, 0.8], &accessor);
assert_eq!(index.len(), 2);
assert!(index.remove(NodeId::new(1)));
assert_eq!(index.len(), 1);
assert!(!index.contains(NodeId::new(1)));
assert!(index.contains(NodeId::new(2)));
assert!(!index.remove(NodeId::new(1)));
}
#[test]
fn test_hnsw_cosine_metric() {
let config = HnswConfig::new(4, DistanceMetric::Cosine);
let index = HnswIndex::with_seed(config, 42);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(1), vec![1.0, 0.0, 0.0, 0.0].into());
map.insert(NodeId::new(2), vec![0.0, 1.0, 0.0, 0.0].into());
map.insert(NodeId::new(3), vec![0.707, 0.707, 0.0, 0.0].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(1), &[1.0, 0.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(2), &[0.0, 1.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(3), &[0.707, 0.707, 0.0, 0.0], &accessor);
let results = index.search(&[0.9, 0.1, 0.0, 0.0], 3, &accessor);
assert_eq!(results[0].0, NodeId::new(1));
}
#[test]
fn test_hnsw_ef_parameter() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::with_seed(config, 42);
let vectors = create_test_vectors(100, 4);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64 + 1);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64 + 1), vec, &accessor);
}
let query = [0.5, 0.5, 0.5, 0.5];
let results_low = index.search_with_ef(&query, 5, 10, &accessor);
let results_high = index.search_with_ef(&query, 5, 100, &accessor);
assert_eq!(results_low.len(), 5);
assert_eq!(results_high.len(), 5);
assert!(results_high[0].1 <= results_low[0].1);
}
#[test]
#[should_panic(expected = "Vector dimensions mismatch")]
fn test_hnsw_dimension_mismatch_insert() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
let accessor = make_accessor(&map);
index.insert(NodeId::new(1), &[0.1, 0.2, 0.3], &accessor); }
#[test]
#[should_panic(expected = "Query dimensions mismatch")]
fn test_hnsw_dimension_mismatch_search() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(1), vec![0.1, 0.2, 0.3, 0.4].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(1), &[0.1, 0.2, 0.3, 0.4], &accessor);
let _ = index.search(&[0.1, 0.2, 0.3], 1, &accessor); }
#[test]
fn test_hnsw_batch_insert() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::with_seed(config, 42);
let vectors = create_test_vectors(100, 4);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64 + 1);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
let pairs: Vec<_> = vectors
.iter()
.enumerate()
.map(|(i, v)| (NodeId::new(i as u64 + 1), v.as_slice()))
.collect();
index.batch_insert(pairs, &accessor);
assert_eq!(index.len(), 100);
let results = index.search(&vectors[50], 5, &accessor);
assert_eq!(results.len(), 5);
assert_eq!(results[0].0, NodeId::new(51));
}
#[test]
fn test_hnsw_batch_search() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::with_seed(config, 42);
let vectors = create_test_vectors(100, 4);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64 + 1);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64 + 1), vec, &accessor);
}
let queries: Vec<Vec<f32>> = (0..5).map(|i| vectors[i * 20].clone()).collect();
let all_results = index.batch_search(&queries, 3, &accessor);
assert_eq!(all_results.len(), 5);
for (i, results) in all_results.iter().enumerate() {
assert_eq!(results.len(), 3);
assert_eq!(results[0].0, NodeId::new((i * 20 + 1) as u64));
assert!(results[0].1 < 0.001);
}
}
#[test]
fn test_hnsw_batch_search_with_ef() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::with_seed(config, 42);
let vectors = create_test_vectors(100, 4);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64 + 1);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64 + 1), vec, &accessor);
}
let queries: Vec<Vec<f32>> = vec![vectors[25].clone(), vectors[75].clone()];
let results = index.batch_search_with_ef(&queries, 5, 100, &accessor);
assert_eq!(results.len(), 2);
assert_eq!(results[0].len(), 5);
assert_eq!(results[1].len(), 5);
}
#[test]
fn test_hnsw_batch_search_empty_index() {
let config = HnswConfig::new(4, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
let accessor = make_accessor(&map);
let queries = vec![vec![0.0f32, 0.0, 0.0, 0.0]];
let results = index.batch_search(&queries, 10, &accessor);
assert_eq!(results.len(), 1);
assert!(results[0].is_empty());
}
fn brute_force_knn(
vectors: &[Vec<f32>],
query: &[f32],
k: usize,
metric: DistanceMetric,
) -> Vec<usize> {
let mut dists: Vec<(usize, f32)> = vectors
.iter()
.enumerate()
.map(|(i, v)| (i, crate::index::vector::compute_distance(query, v, metric)))
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
dists.into_iter().take(k).map(|(i, _)| i).collect()
}
#[test]
fn test_hnsw_recall_euclidean() {
let n = 1000;
let dim = 20;
let k = 10;
let num_queries = 100;
let mut seed: u64 = 12345;
let mut rand_f32 = || -> f32 {
seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
((seed >> 33) as f32) / (u32::MAX as f32)
};
let vectors: Vec<Vec<f32>> = (0..n)
.map(|_| (0..dim).map(|_| rand_f32()).collect())
.collect();
let config = HnswConfig::new(dim, DistanceMetric::Euclidean).with_m(16);
let index = HnswIndex::with_seed(config, 42);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64), vec, &accessor);
}
let queries: Vec<Vec<f32>> = (0..num_queries)
.map(|_| (0..dim).map(|_| rand_f32()).collect())
.collect();
let mut total_recall = 0.0f64;
for query in &queries {
let ground_truth = brute_force_knn(&vectors, query, k, DistanceMetric::Euclidean);
let gt_set: std::collections::HashSet<u64> =
ground_truth.iter().map(|&i| i as u64).collect();
let results = index.search_with_ef(query, k, 50, &accessor);
let found: std::collections::HashSet<u64> =
results.iter().map(|(id, _)| id.as_u64()).collect();
let overlap = gt_set.intersection(&found).count();
total_recall += overlap as f64 / k as f64;
}
let avg_recall = total_recall / num_queries as f64;
assert!(
avg_recall >= 0.90,
"Recall {avg_recall:.3} is below 0.90 threshold at M=16/ef=50"
);
}
#[test]
fn test_hnsw_recall_cosine() {
let n = 500;
let dim = 20;
let k = 10;
let num_queries = 50;
let mut seed: u64 = 67890;
let mut rand_f32 = || -> f32 {
seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
((seed >> 33) as f32) / (u32::MAX as f32)
};
let vectors: Vec<Vec<f32>> = (0..n)
.map(|_| (0..dim).map(|_| rand_f32()).collect())
.collect();
let config = HnswConfig::new(dim, DistanceMetric::Cosine).with_m(16);
let index = HnswIndex::with_seed(config, 42);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
for (i, vec) in vectors.iter().enumerate() {
let id = NodeId::new(i as u64);
let arc: Arc<[f32]> = vec.as_slice().into();
map.insert(id, arc);
}
let accessor = make_accessor(&map);
for (i, vec) in vectors.iter().enumerate() {
index.insert(NodeId::new(i as u64), vec, &accessor);
}
let queries: Vec<Vec<f32>> = (0..num_queries)
.map(|_| (0..dim).map(|_| rand_f32()).collect())
.collect();
let mut total_recall = 0.0f64;
for query in &queries {
let ground_truth = brute_force_knn(&vectors, query, k, DistanceMetric::Cosine);
let gt_set: std::collections::HashSet<u64> =
ground_truth.iter().map(|&i| i as u64).collect();
let results = index.search_with_ef(query, k, 50, &accessor);
let found: std::collections::HashSet<u64> =
results.iter().map(|(id, _)| id.as_u64()).collect();
let overlap = gt_set.intersection(&found).count();
total_recall += overlap as f64 / k as f64;
}
let avg_recall = total_recall / num_queries as f64;
assert!(
avg_recall >= 0.90,
"Cosine recall {avg_recall:.3} is below 0.90 threshold at M=16/ef=50"
);
}
#[test]
fn test_diversity_pruning_prevents_clustering() {
let dim = 4;
let config = HnswConfig::new(dim, DistanceMetric::Euclidean).with_m(4);
let index = HnswIndex::with_seed(config, 42);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(0), vec![0.0, 0.0, 0.0, 0.0].into());
map.insert(NodeId::new(1), vec![0.01, 0.0, 0.0, 0.0].into());
map.insert(NodeId::new(2), vec![0.02, 0.0, 0.0, 0.0].into());
map.insert(NodeId::new(3), vec![0.03, 0.0, 0.0, 0.0].into());
map.insert(NodeId::new(4), vec![0.04, 0.0, 0.0, 0.0].into());
map.insert(NodeId::new(5), vec![0.0, 1.0, 0.0, 0.0].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(0), &[0.0, 0.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(1), &[0.01, 0.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(2), &[0.02, 0.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(3), &[0.03, 0.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(4), &[0.04, 0.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(5), &[0.0, 1.0, 0.0, 0.0], &accessor);
let results = index.search(&[0.0, 0.9, 0.0, 0.0], 1, &accessor);
assert_eq!(results[0].0, NodeId::new(5));
}
#[test]
fn test_single_vector() {
let config = HnswConfig::new(3, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(0), vec![1.0, 0.0, 0.0].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(0), &[1.0, 0.0, 0.0], &accessor);
let results = index.search(&[1.0, 0.0, 0.0], 1, &accessor);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, NodeId::new(0));
assert!(results[0].1 < 0.01);
}
#[test]
fn test_search_k_larger_than_index() {
let config = HnswConfig::new(3, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(0), vec![1.0, 0.0, 0.0].into());
map.insert(NodeId::new(1), vec![0.0, 1.0, 0.0].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(0), &[1.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(1), &[0.0, 1.0, 0.0], &accessor);
let results = index.search(&[1.0, 0.0, 0.0], 10, &accessor);
assert_eq!(results.len(), 2);
}
#[test]
fn test_empty_index_search() {
let config = HnswConfig::new(3, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
let accessor = make_accessor(&map);
let results = index.search(&[1.0, 0.0, 0.0], 5, &accessor);
assert!(results.is_empty());
}
#[test]
fn test_remove_and_search() {
let config = HnswConfig::new(3, DistanceMetric::Euclidean);
let index = HnswIndex::new(config);
let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
map.insert(NodeId::new(0), vec![1.0, 0.0, 0.0].into());
map.insert(NodeId::new(1), vec![0.0, 1.0, 0.0].into());
map.insert(NodeId::new(2), vec![0.0, 0.0, 1.0].into());
let accessor = make_accessor(&map);
index.insert(NodeId::new(0), &[1.0, 0.0, 0.0], &accessor);
index.insert(NodeId::new(1), &[0.0, 1.0, 0.0], &accessor);
index.insert(NodeId::new(2), &[0.0, 0.0, 1.0], &accessor);
index.remove(NodeId::new(1));
let results = index.search(&[0.0, 1.0, 0.0], 3, &accessor);
assert!(results.iter().all(|(id, _)| *id != NodeId::new(1)));
assert_eq!(results.len(), 2);
}
#[test]
fn test_duplicate_insert() {
    // Inserting the same id twice (with a different vector the second time)
    // must leave a single entry, and that id must be found via the second
    // vector's accessor.
    let index = HnswIndex::new(HnswConfig::new(3, DistanceMetric::Euclidean));
    let id = NodeId::new(0);

    let mut first_store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    first_store.insert(id, Arc::from(vec![1.0_f32, 0.0, 0.0]));
    let first_accessor = make_accessor(&first_store);
    index.insert(id, &[1.0, 0.0, 0.0], &first_accessor);

    let mut second_store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    second_store.insert(id, Arc::from(vec![0.0_f32, 1.0, 0.0]));
    let second_accessor = make_accessor(&second_store);
    index.insert(id, &[0.0, 1.0, 0.0], &second_accessor);

    assert_eq!(index.len(), 1);
    let hits = index.search(&[0.0, 1.0, 0.0], 1, &second_accessor);
    assert_eq!(hits[0].0, id);
}
#[test]
fn test_search_with_ef_zero() {
    // ef = 0 is a degenerate beam width: the call must complete and return at
    // most the single stored node.
    let index = HnswIndex::new(HnswConfig::new(3, DistanceMetric::Euclidean));
    let point = [1.0_f32, 0.0, 0.0];
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    store.insert(NodeId::new(0), Arc::from(&point[..]));
    let accessor = make_accessor(&store);
    index.insert(NodeId::new(0), &point, &accessor);
    let hits = index.search_with_ef(&point, 1, 0, &accessor);
    assert!(hits.len() <= 1);
}
#[test]
fn test_all_metrics_search() {
    // For each of the four metrics below, querying with node 0's own vector
    // must return both nodes with node 0 ranked first.
    let metrics = [
        DistanceMetric::Cosine,
        DistanceMetric::Euclidean,
        DistanceMetric::DotProduct,
        DistanceMetric::Manhattan,
    ];
    let x_axis = [1.0_f32, 0.0, 0.0];
    let y_axis = [0.0_f32, 1.0, 0.0];
    for metric in metrics {
        let index = HnswIndex::new(HnswConfig::new(3, metric));
        let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
        store.insert(NodeId::new(0), Arc::from(&x_axis[..]));
        store.insert(NodeId::new(1), Arc::from(&y_axis[..]));
        let accessor = make_accessor(&store);
        index.insert(NodeId::new(0), &x_axis, &accessor);
        index.insert(NodeId::new(1), &y_axis, &accessor);
        let hits = index.search(&x_axis, 2, &accessor);
        assert_eq!(hits.len(), 2, "Failed for metric {metric:?}");
        assert_eq!(
            hits[0].0,
            NodeId::new(0),
            "Closest not correct for metric {metric:?}"
        );
    }
}
#[test]
fn test_batch_search_consistency() {
    // The stored vectors are the three unit axes and each query equals one of
    // them, so the i-th result list must be headed by node i.
    let index = HnswIndex::new(HnswConfig::new(3, DistanceMetric::Euclidean));
    let queries: Vec<Vec<f32>> = (0..3)
        .map(|axis| {
            let mut unit = vec![0.0_f32; 3];
            unit[axis] = 1.0;
            unit
        })
        .collect();
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (raw, unit) in queries.iter().enumerate() {
        store.insert(NodeId::new(raw as u64), unit.as_slice().into());
    }
    let accessor = make_accessor(&store);
    for (raw, unit) in queries.iter().enumerate() {
        index.insert(NodeId::new(raw as u64), unit, &accessor);
    }
    let batch_results = index.batch_search(&queries, 1, &accessor);
    assert_eq!(batch_results.len(), 3);
    for (expected, hits) in (0u64..).zip(batch_results.iter()) {
        assert_eq!(hits[0].0, NodeId::new(expected));
    }
}
#[test]
fn test_with_capacity_constructor() {
    // Pre-sizing must not make the index look populated: it reports zero nodes
    // until the first insert, then exactly one.
    let index = HnswIndex::with_capacity(HnswConfig::new(3, DistanceMetric::Euclidean), 100);
    assert!(index.is_empty());
    assert_eq!(index.len(), 0);

    let point = [1.0_f32, 0.0, 0.0];
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    store.insert(NodeId::new(0), Arc::from(&point[..]));
    let accessor = make_accessor(&store);
    index.insert(NodeId::new(0), &point, &accessor);

    assert!(!index.is_empty());
    assert_eq!(index.len(), 1);
}
#[test]
fn test_high_m_value() {
    // An M far above the number of stored nodes must still allow a search to
    // return both of them.
    let index = HnswIndex::new(HnswConfig::new(3, DistanceMetric::Euclidean).with_m(64));
    let fixtures = [
        (NodeId::new(0), [1.0_f32, 0.0, 0.0]),
        (NodeId::new(1), [0.0_f32, 1.0, 0.0]),
    ];
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for &(id, coords) in &fixtures {
        store.insert(id, coords.to_vec().into());
    }
    let accessor = make_accessor(&store);
    for (id, coords) in &fixtures {
        index.insert(*id, coords, &accessor);
    }
    let hits = index.search(&[1.0, 0.0, 0.0], 2, &accessor);
    assert_eq!(hits.len(), 2);
}
#[test]
fn test_filtered_search_returns_only_allowlisted() {
    // Only even ids in 1..=50 are allowed; the filtered search must return at
    // least one and at most k hits, all drawn from that set.
    let index = HnswIndex::with_seed(HnswConfig::new(4, DistanceMetric::Euclidean), 42);
    let vectors = create_test_vectors(50, 4);
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (v, raw) in vectors.iter().zip(1u64..) {
        store.insert(NodeId::new(raw), v.as_slice().into());
    }
    let accessor = make_accessor(&store);
    for (v, raw) in vectors.iter().zip(1u64..) {
        index.insert(NodeId::new(raw), v, &accessor);
    }
    let even_ids: HashSet<NodeId> = (1..=50u64)
        .filter(|raw| raw % 2 == 0)
        .map(NodeId::new)
        .collect();
    let hits = index.search_with_filter(&vectors[25], 5, &even_ids, &accessor);
    assert!(!hits.is_empty());
    assert!(hits.len() <= 5);
    for (id, _) in &hits {
        assert!(even_ids.contains(id), "Result {id:?} not in allowlist");
    }
}
#[test]
fn test_filtered_search_empty_allowlist() {
    // With nothing allowed there are no admissible candidates, so the filtered
    // search must return nothing even though the graph is populated.
    let index = HnswIndex::with_seed(HnswConfig::new(4, DistanceMetric::Euclidean), 42);
    let vectors = create_test_vectors(20, 4);
    let ids: Vec<NodeId> = (1..=vectors.len() as u64).map(NodeId::new).collect();
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (id, v) in ids.iter().zip(&vectors) {
        store.insert(*id, v.as_slice().into());
    }
    let accessor = make_accessor(&store);
    for (id, v) in ids.iter().zip(&vectors) {
        index.insert(*id, v, &accessor);
    }
    let nothing_allowed: HashSet<NodeId> = HashSet::new();
    let hits = index.search_with_filter(&vectors[5], 5, &nothing_allowed, &accessor);
    assert!(hits.is_empty());
}
#[test]
fn test_filtered_search_full_allowlist_matches_unfiltered() {
    // When every node is allowed the filter is a no-op: with the same ef, the
    // filtered search must return the same ids in the same order as the
    // unfiltered one.
    let index = HnswIndex::with_seed(HnswConfig::new(4, DistanceMetric::Euclidean), 42);
    let vectors = create_test_vectors(50, 4);
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (offset, v) in vectors.iter().enumerate() {
        store.insert(NodeId::new(offset as u64 + 1), v.as_slice().into());
    }
    let accessor = make_accessor(&store);
    for (offset, v) in vectors.iter().enumerate() {
        index.insert(NodeId::new(offset as u64 + 1), v, &accessor);
    }
    let everyone: HashSet<NodeId> = (1..=50).map(NodeId::new).collect();
    let query = vectors[25].as_slice();
    let unfiltered_ids: Vec<NodeId> = index
        .search_with_ef(query, 5, 200, &accessor)
        .iter()
        .map(|(id, _)| *id)
        .collect();
    let filtered_ids: Vec<NodeId> = index
        .search_with_ef_and_filter(query, 5, 200, &everyone, &accessor)
        .iter()
        .map(|(id, _)| *id)
        .collect();
    assert_eq!(unfiltered_ids, filtered_ids);
}
#[test]
fn test_filtered_search_single_allowlisted_node() {
    // A one-element allowlist must yield exactly that node, even though k asks
    // for five results.
    let index = HnswIndex::with_seed(HnswConfig::new(4, DistanceMetric::Euclidean), 42);
    let vectors = create_test_vectors(50, 4);
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (raw, v) in (1u64..).zip(vectors.iter()) {
        store.insert(NodeId::new(raw), v.as_slice().into());
    }
    let accessor = make_accessor(&store);
    for (raw, v) in (1u64..).zip(vectors.iter()) {
        index.insert(NodeId::new(raw), v, &accessor);
    }
    let only = NodeId::new(30);
    let allowlist: HashSet<NodeId> = std::iter::once(only).collect();
    let hits = index.search_with_filter(&vectors[25], 5, &allowlist, &accessor);
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, only);
}
#[test]
fn test_filtered_search_sorted_by_distance() {
    // Filtered results must come back in non-decreasing distance order. The
    // result is also required to be non-empty and bounded by k; without that,
    // the ordering check below would pass vacuously on an empty result.
    let config = HnswConfig::new(4, DistanceMetric::Euclidean);
    let index = HnswIndex::with_seed(config, 42);
    let vectors = create_test_vectors(100, 4);
    let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (i, vec) in vectors.iter().enumerate() {
        map.insert(NodeId::new(i as u64 + 1), vec.as_slice().into());
    }
    let accessor = make_accessor(&map);
    for (i, vec) in vectors.iter().enumerate() {
        index.insert(NodeId::new(i as u64 + 1), vec, &accessor);
    }
    let k = 10;
    // Every third id (33 of 100 nodes) is admissible.
    let allowlist: HashSet<NodeId> =
        (1..=100).filter(|i| i % 3 == 0).map(NodeId::new).collect();
    let results = index.search_with_filter(&[0.5, 0.5, 0.5, 0.5], k, &allowlist, &accessor);
    assert!(!results.is_empty(), "filtered search returned no results");
    assert!(results.len() <= k);
    for pair in results.windows(2) {
        assert!(
            pair[0].1 <= pair[1].1,
            "results not sorted by distance: {} > {}",
            pair[0].1,
            pair[1].1
        );
    }
}
#[test]
fn test_batch_filtered_search() {
    // Both queries share one allowlist (ids 1..=50). The first query is node
    // 11's vector (inside the allowed range) and the second is node 71's
    // (outside it); neither result list may contain a disallowed id.
    let index = HnswIndex::with_seed(HnswConfig::new(4, DistanceMetric::Euclidean), 42);
    let vectors = create_test_vectors(100, 4);
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (raw, v) in (1u64..).zip(vectors.iter()) {
        store.insert(NodeId::new(raw), v.as_slice().into());
    }
    let accessor = make_accessor(&store);
    for (raw, v) in (1u64..).zip(vectors.iter()) {
        index.insert(NodeId::new(raw), v, &accessor);
    }
    let lower_half: HashSet<NodeId> = (1..=50).map(NodeId::new).collect();
    let queries: Vec<Vec<f32>> = [10usize, 70].iter().map(|&i| vectors[i].clone()).collect();
    let per_query = index.batch_search_with_filter(&queries, 5, &lower_half, &accessor);
    assert_eq!(per_query.len(), 2);
    let leaked = per_query
        .iter()
        .flatten()
        .any(|(id, _)| !lower_half.contains(id));
    assert!(!leaked);
}
#[test]
fn test_filtered_search_ef_scaling() {
    // Recall check under a selective filter. Every fifth node (20%) is
    // allowlisted. The filtered top-k must share at least 60% of its ids with a
    // brute-force top-k computed over the allowlist alone.
    let n = 500;
    let dim = 8;
    let k = 10;
    let config = HnswConfig::new(dim, DistanceMetric::Euclidean).with_m(16);
    let index = HnswIndex::with_seed(config, 42);

    // Deterministic LCG so the fixture is reproducible across runs.
    // `seed >> 33` keeps at most 31 bits, so dividing by `u32::MAX` yields
    // values in [0, 0.5], not [0, 1). The test only needs a fixed spread of
    // points, so the generator is intentionally left unchanged.
    let mut seed: u64 = 99999;
    let mut rand_f32 = || -> f32 {
        seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        ((seed >> 33) as f32) / (u32::MAX as f32)
    };
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|_| (0..dim).map(|_| rand_f32()).collect())
        .collect();

    let mut map: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (i, vec) in vectors.iter().enumerate() {
        map.insert(NodeId::new(i as u64), vec.as_slice().into());
    }
    let accessor = make_accessor(&map);
    for (i, vec) in vectors.iter().enumerate() {
        index.insert(NodeId::new(i as u64), vec, &accessor);
    }

    let allowlist: HashSet<NodeId> = (0..n)
        .filter(|i| i % 5 == 0)
        .map(|i| NodeId::new(i as u64))
        .collect();
    // Drawn after the vectors, so the query continues the same LCG stream.
    let query: Vec<f32> = (0..dim).map(|_| rand_f32()).collect();

    // Exact (brute-force) distances from the query to every allowlisted node.
    let mut ground_truth: Vec<(u64, f32)> = allowlist
        .iter()
        .map(|id| {
            let dist = crate::index::vector::compute_distance(
                &query,
                &vectors[id.as_u64() as usize],
                DistanceMetric::Euclidean,
            );
            (id.as_u64(), dist)
        })
        .collect();
    // `total_cmp` is a total order, so the sort cannot panic the way
    // `partial_cmp(..).unwrap()` would on a NaN distance.
    ground_truth.sort_by(|a, b| a.1.total_cmp(&b.1));
    let gt_set: HashSet<u64> = ground_truth.iter().take(k).map(|(id, _)| *id).collect();

    let results = index.search_with_filter(&query, k, &allowlist, &accessor);
    let found: HashSet<u64> = results.iter().map(|(id, _)| id.as_u64()).collect();
    let overlap = gt_set.intersection(&found).count();
    // Normalise by the ground-truth size so the metric stays meaningful even if
    // the allowlist ever holds fewer than k nodes.
    let recall = overlap as f64 / gt_set.len() as f64;
    assert!(
        recall >= 0.60,
        "Filtered recall {recall:.3} is below 0.60 threshold (20% selectivity)"
    );
}
#[test]
fn test_filtered_search_cosine() {
    // Node 1 is the closest cosine match for the query but is not allowlisted.
    // Among the allowed nodes, the diagonal node 3 must rank ahead of the
    // orthogonal node 2.
    let index = HnswIndex::with_seed(HnswConfig::new(4, DistanceMetric::Cosine), 42);
    let fixtures: [(u64, [f32; 4]); 3] = [
        (1, [1.0, 0.0, 0.0, 0.0]),
        (2, [0.0, 1.0, 0.0, 0.0]),
        (3, [0.707, 0.707, 0.0, 0.0]),
    ];
    let mut store: HashMap<NodeId, Arc<[f32]>> = HashMap::new();
    for (raw, coords) in &fixtures {
        store.insert(NodeId::new(*raw), Arc::from(&coords[..]));
    }
    let accessor = make_accessor(&store);
    for (raw, coords) in &fixtures {
        index.insert(NodeId::new(*raw), coords, &accessor);
    }
    let allowlist: HashSet<NodeId> = (2..=3).map(NodeId::new).collect();
    let hits = index.search_with_filter(&[0.9, 0.1, 0.0, 0.0], 2, &allowlist, &accessor);
    assert!(!hits.is_empty());
    assert!(hits.iter().all(|(id, _)| allowlist.contains(id)));
    assert_eq!(hits[0].0, NodeId::new(3));
}
}