use crate::{Document, Result, SearchResult};
use parking_lot::{Mutex, RwLock};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
use std::cmp::{max, Reverse};
use std::collections::{BinaryHeap, HashSet};
/// Best-effort hint to pull the cache line containing `ptr` into L1.
///
/// # Safety
/// The prefetch instructions never fault on x86_64/aarch64, but `ptr`
/// should still be derived from a live allocation; on other
/// architectures this is a no-op.
#[inline(always)]
unsafe fn prefetch_read(ptr: *const u8) {
    #[cfg(target_arch = "x86_64")]
    {
        // T0 hint: fetch into all cache levels.
        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr as *const i8);
    }
    #[cfg(target_arch = "aarch64")]
    {
        // PRFM PLDL1KEEP: prefetch for load into L1, temporal.
        std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) ptr, options(nostack, preserves_flags));
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch support: discard the pointer.
        let _ = ptr;
    }
}
/// Prefetch `cache_lines` consecutive 64-byte cache lines starting at `ptr`.
///
/// # Safety
/// Same contract as [`prefetch_read`]; `ptr.add(i * 64)` must stay within
/// the provenance rules for every line index `i < cache_lines`.
#[inline(always)]
unsafe fn prefetch_embedding(ptr: *const u8, cache_lines: usize) {
    let mut line = 0usize;
    while line < cache_lines {
        prefetch_read(ptr.add(line * 64));
        line += 1;
    }
}
/// `f32` wrapper with a total order (via `f32::total_cmp`) so distances
/// can be stored in `BinaryHeap`s.
#[derive(Debug, Clone, Copy, PartialEq)]
struct OrderedFloat(f32);

impl Eq for OrderedFloat {}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        f32::total_cmp(&self.0, &other.0)
    }
}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Fixed-capacity visited set backed by a packed `u64` bitmap.
struct BitsetVisited {
    bits: Vec<u64>,
}

impl BitsetVisited {
    /// Allocate a bitmap able to track node ids `0..n`.
    fn new(n: usize) -> Self {
        let words = (n + 63) / 64;
        Self { bits: vec![0u64; words] }
    }

    #[inline(always)]
    fn is_visited(&self, node: usize) -> bool {
        let (word_idx, bit) = (node >> 6, node & 63);
        debug_assert!(
            word_idx < self.bits.len(),
            "BitsetVisited::is_visited: node {} out of bounds (capacity {})",
            node,
            self.bits.len() * 64
        );
        // SAFETY: callers only pass node ids below the capacity given to
        // `new` (checked in debug builds above).
        let word = unsafe { *self.bits.get_unchecked(word_idx) };
        (word >> bit) & 1 == 1
    }

    #[inline(always)]
    fn mark_visited(&mut self, node: usize) {
        let (word_idx, bit) = (node >> 6, node & 63);
        debug_assert!(
            word_idx < self.bits.len(),
            "BitsetVisited::mark_visited: node {} out of bounds (capacity {})",
            node,
            self.bits.len() * 64
        );
        // SAFETY: same in-bounds contract as `is_visited`.
        unsafe {
            *self.bits.get_unchecked_mut(word_idx) |= 1u64 << bit;
        }
    }

    /// Reset every bit to "unvisited".
    #[inline]
    fn clear(&mut self) {
        self.bits.fill(0);
    }
}
/// Reusable per-query scratch state: a visited bitmap plus the two heaps
/// used by the layer beam search (min-heap of frontier candidates, max-heap
/// of current best results).
pub struct SearchContext {
    visited: BitsetVisited,
    // Node capacity the visited bitmap was sized for.
    capacity: usize,
    // Min-heap (via `Reverse`) of (distance, node) still to expand.
    candidates: BinaryHeap<Reverse<(OrderedFloat, usize)>>,
    // Max-heap of the best (distance, node) pairs found so far.
    best: BinaryHeap<(OrderedFloat, usize)>,
}
impl SearchContext {
    /// Create scratch state sized for an index of `n` nodes.
    pub fn new(n: usize) -> Self {
        Self {
            visited: BitsetVisited::new(n),
            capacity: n,
            candidates: BinaryHeap::with_capacity(256),
            best: BinaryHeap::with_capacity(256),
        }
    }
    /// Clear all state so the context can serve a fresh query.
    #[inline]
    fn reset(&mut self) {
        self.visited.clear();
        self.candidates.clear();
        self.best.clear();
    }
    #[inline(always)]
    fn is_visited(&self, node: usize) -> bool {
        self.visited.is_visited(node)
    }
    #[inline(always)]
    fn mark_visited(&mut self, node: usize) {
        self.visited.mark_visited(node);
    }
}
/// How a bulk `HNSWIndex::build` constructs the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BuildStrategy {
    /// Insert points one at a time (default).
    #[default]
    Sequential,
    /// Layered batch insertion using rayon.
    Parallel,
    /// Choose based on input size (see `HNSWIndex::build`).
    Auto,
}
/// Tuning parameters for the HNSW graph.
#[derive(Debug, Clone)]
pub struct HNSWConfig {
    /// Max neighbors per node on layers >= 1.
    pub m: usize,
    /// Max neighbors per node on layer 0 (builders keep this at 2*m).
    pub m0: usize,
    /// Beam width used while inserting nodes.
    pub ef_construction: usize,
    /// Default beam width at query time (raised to `k` when `k` is larger).
    pub ef_search: usize,
    /// Level multiplier for the geometric level distribution.
    pub ml: f32,
    /// Apply the diversity heuristic when selecting neighbors.
    pub use_heuristic: bool,
    /// Also consider candidates' own neighbors during selection.
    pub extend_candidates: bool,
    /// Back-fill with pruned candidates when the heuristic under-fills.
    pub keep_pruned_connections: bool,
    /// Sequential, parallel, or size-dependent bulk build.
    pub build_strategy: BuildStrategy,
    /// RNG seed for reproducible bulk builds.
    pub seed: Option<u64>,
}
impl Default for HNSWConfig {
    /// Defaults: m = 32, m0 = 2*m, ef = 100 for both build and search,
    /// and the standard level multiplier 1/ln(m).
    fn default() -> Self {
        let m = 32;
        let m0 = m * 2;
        let ml = 1.0 / (m as f32).ln();
        Self {
            m,
            m0,
            ef_construction: 100,
            ef_search: 100,
            ml,
            use_heuristic: true,
            extend_candidates: false,
            keep_pruned_connections: true,
            build_strategy: BuildStrategy::default(),
            seed: None,
        }
    }
}
impl HNSWConfig {
    /// Disable the diversity heuristic; selection keeps the plain
    /// closest-m candidates instead.
    pub fn with_simple_selection(mut self) -> Self {
        self.use_heuristic = false;
        self
    }

    /// Enable widening the candidate pool with candidates' neighbors.
    pub fn with_extended_candidates(mut self) -> Self {
        self.extend_candidates = true;
        self
    }

    /// Set the query-time beam width.
    pub fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }

    /// Set the build-time beam width.
    pub fn with_ef_construction(mut self, ef: usize) -> Self {
        self.ef_construction = ef;
        self
    }

    /// Choose the bulk-build strategy.
    pub fn with_build_strategy(mut self, strategy: BuildStrategy) -> Self {
        self.build_strategy = strategy;
        self
    }

    /// Seed the builder RNG for reproducible graphs.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Set `m` and recompute the derived `m0` and `ml` values.
    ///
    /// # Panics
    /// Panics when `m > 127`, since the layer-0 degree `2*m` must fit in
    /// the `u8` neighbor counters.
    pub fn with_m(mut self, m: usize) -> Self {
        assert!(m <= 127, "m must be <= 127 (m0 = 2*m must fit in u8)");
        self.m = m;
        self.m0 = m * 2;
        self.ml = 1.0 / (m as f32).ln();
        self
    }
}
/// HNSW approximate-nearest-neighbor index over cosine distance.
///
/// Vector data is stored structure-of-arrays; layer-0 adjacency is
/// additionally mirrored into a flat cache (`connections_l0` +
/// `connections_l0_count`) that the search path reads.
pub struct HNSWIndex {
    embedding_dim: usize,
    config: HNSWConfig,
    /// Row-major flat storage: node i occupies `i*dim..(i+1)*dim`.
    embeddings: Vec<f32>,
    /// Precomputed vector norms, one per node.
    norms: Vec<f32>,
    /// `connections[node][layer]` = neighbor ids on that layer.
    connections: Vec<Vec<Vec<u32>>>,
    /// Flat layer-0 adjacency: `m0` slots per node.
    connections_l0: Vec<u32>,
    /// Valid-slot count per node for `connections_l0`.
    connections_l0_count: Vec<u8>,
    ids: Vec<String>,
    contents: Vec<String>,
    metadata: Vec<Option<serde_json::Value>>,
    /// Node where every search starts; `None` while empty.
    entry_point: Option<usize>,
    /// Highest populated layer.
    max_layer: usize,
}
impl HNSWIndex {
    /// Create an empty index for vectors of `embedding_dim` dimensions.
    ///
    /// # Panics
    /// Panics if `config.m > 127`, because the layer-0 degree `m0 = 2*m`
    /// must fit in the `u8` entries of `connections_l0_count`.
    pub fn new(embedding_dim: usize, config: HNSWConfig) -> Self {
        assert!(
            config.m <= 127,
            "m must be <= 127 (m0 = 2*m must fit in u8), got m={}",
            config.m
        );
        Self {
            embedding_dim,
            config,
            embeddings: Vec::new(),
            norms: Vec::new(),
            connections: Vec::new(),
            connections_l0: Vec::new(),
            connections_l0_count: Vec::new(),
            ids: Vec::new(),
            contents: Vec::new(),
            metadata: Vec::new(),
            entry_point: None,
            max_layer: 0,
        }
    }
    /// Create an empty index with the default configuration.
    pub fn with_defaults(embedding_dim: usize) -> Self {
        Self::new(embedding_dim, HNSWConfig::default())
    }
    /// Bulk-build an index from raw vectors; ids become `"0".."n-1"` and
    /// contents/metadata are left empty.
    ///
    /// # Panics
    /// Panics if the embeddings do not all share the same dimension.
    pub fn build(embeddings: Vec<Vec<f32>>, config: HNSWConfig) -> Self {
        if embeddings.is_empty() {
            return Self::new(0, config);
        }
        let expected_dim = embeddings[0].len();
        for (i, embedding) in embeddings.iter().enumerate() {
            assert!(
                embedding.len() == expected_dim,
                "All embeddings must have the same dimension: expected {}, got {} at index {}",
                expected_dim,
                embedding.len(),
                i
            );
        }
        let n = embeddings.len();
        let strategy = match config.build_strategy {
            BuildStrategy::Auto => {
                // NOTE(review): Auto chooses Parallel for SMALL inputs and
                // Sequential for large ones — the reverse of the common
                // heuristic. Confirm the threshold direction is intended.
                if n < 50_000 {
                    BuildStrategy::Parallel
                } else {
                    BuildStrategy::Sequential
                }
            }
            other => other,
        };
        match strategy {
            BuildStrategy::Sequential => Self::build_sequential(embeddings, config),
            // `Auto` cannot reach here (resolved above); the arm is defensive.
            BuildStrategy::Parallel | BuildStrategy::Auto => {
                Self::build_parallel(embeddings, config)
            }
        }
    }
    /// Sequential bulk build: draw each node's level up front, insert
    /// nodes from the highest level down so the entry point is set first.
    fn build_sequential(embeddings: Vec<Vec<f32>>, config: HNSWConfig) -> Self {
        let embedding_dim = embeddings[0].len();
        let n = embeddings.len();
        // Seedable for reproducible graphs; random otherwise.
        let seed = config.seed.unwrap_or_else(rand::random);
        let mut rng = StdRng::seed_from_u64(seed);
        let ml = config.ml;
        // Geometric level distribution: floor(-ln(u) * ml).
        let levels: Vec<usize> = (0..n)
            .map(|_| {
                let r: f32 = rng.gen::<f32>().max(f32::EPSILON);
                (-r.ln() * ml).floor() as usize
            })
            .collect();
        let _max_level = *levels.iter().max().unwrap_or(&0);
        // Insert highest-level nodes first.
        let mut sorted_indices: Vec<usize> = (0..n).collect();
        sorted_indices.sort_by(|&a, &b| levels[b].cmp(&levels[a]));
        let mut index = Self::new(embedding_dim, config);
        index.embeddings.reserve(n * embedding_dim);
        index.norms.reserve(n);
        index.connections.reserve(n);
        index.ids.reserve(n);
        index.contents.reserve(n);
        index.metadata.reserve(n);
        for &i in &sorted_indices {
            let level = levels[i];
            let node_id = index.len();
            let mut node_connections: Vec<Vec<u32>> = Vec::with_capacity(level + 1);
            for _ in 0..=level {
                node_connections.push(Vec::new());
            }
            let norm = crate::vector::simd::norm_simd(&embeddings[i]);
            index.embeddings.extend_from_slice(&embeddings[i]);
            index.norms.push(norm);
            index.connections.push(node_connections);
            // The id records the ORIGINAL input position of the vector.
            index.ids.push(i.to_string());
            index.contents.push(String::new());
            index.metadata.push(None);
            // First node becomes the entry point; nothing to link yet.
            if index.entry_point.is_none() {
                index.entry_point = Some(node_id);
                index.max_layer = level;
                continue;
            }
            index.insert_node(node_id, level);
            if level > index.max_layer {
                index.max_layer = level;
                index.entry_point = Some(node_id);
            }
        }
        // Materialize the flat layer-0 cache once the graph is complete.
        index.build_l0_cache();
        index
    }
    /// Number of vectors stored in the index.
    #[inline]
    pub fn len(&self) -> usize {
        self.ids.len()
    }
    /// True when no vectors are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.ids.is_empty()
    }
    /// Borrow node `node_id`'s embedding slice from the flat storage.
    #[inline]
    fn get_embedding(&self, node_id: usize) -> &[f32] {
        let start = node_id * self.embedding_dim;
        let end = start + self.embedding_dim;
        &self.embeddings[start..end]
    }
    /// Borrow node `node_id`'s layer-0 neighbors from the flat cache.
    #[inline(always)]
    fn get_neighbors_l0(&self, node_id: usize) -> &[u32] {
        let m0 = self.config.m0;
        let start = node_id * m0;
        let count = self.connections_l0_count[node_id] as usize;
        &self.connections_l0[start..start + count]
    }
fn build_l0_cache(&mut self) {
let n = self.len();
let m0 = self.config.m0;
debug_assert!(m0 <= 255, "m0 exceeds u8 capacity for connections_l0_count");
self.connections_l0 = vec![0u32; n * m0];
self.connections_l0_count = vec![0u8; n];
for i in 0..n {
if !self.connections[i].is_empty() {
let neighbors = &self.connections[i][0];
let count = neighbors.len().min(m0);
let start = i * m0;
self.connections_l0[start..start + count].copy_from_slice(&neighbors[..count]);
self.connections_l0_count[i] = count as u8;
}
}
}
fn extend_l0_cache_for_new_node(&mut self) {
let m0 = self.config.m0;
self.connections_l0
.resize(self.connections_l0.len() + m0, 0);
self.connections_l0_count.push(0);
}
    /// Insert one document (id, content, metadata, embedding).
    ///
    /// # Errors
    /// - `DimensionMismatch` when the embedding length differs from the
    ///   index dimension.
    /// - `InvalidInput` when the embedding contains NaN or infinity.
    pub fn add(&mut self, document: Document) -> Result<()> {
        if document.embedding.len() != self.embedding_dim {
            return Err(crate::RagError::DimensionMismatch {
                expected: self.embedding_dim,
                actual: document.embedding.len(),
            });
        }
        if document.embedding.iter().any(|v| !v.is_finite()) {
            return Err(crate::RagError::InvalidInput(
                "embedding contains non-finite values (NaN or Inf)".into(),
            ));
        }
        let node_id = self.len();
        // Geometric level draw per the HNSW construction scheme.
        let node_level = self.random_level();
        let mut node_connections: Vec<Vec<u32>> = Vec::with_capacity(node_level + 1);
        for _ in 0..=node_level {
            node_connections.push(Vec::new());
        }
        let norm = crate::vector::simd::norm_simd(&document.embedding);
        self.embeddings.extend_from_slice(&document.embedding);
        self.norms.push(norm);
        self.connections.push(node_connections);
        self.ids.push(document.id);
        self.contents.push(document.content);
        self.metadata.push(document.metadata);
        self.extend_l0_cache_for_new_node();
        // First node: becomes the entry point, no linking needed.
        if self.entry_point.is_none() {
            self.entry_point = Some(node_id);
            self.max_layer = node_level;
            return Ok(());
        }
        let affected = self.insert_node(node_id, node_level);
        // Refresh the flat layer-0 cache for every touched node.
        self.sync_l0_cache_for_nodes(&affected);
        if node_level > self.max_layer {
            self.max_layer = node_level;
            self.entry_point = Some(node_id);
        }
        Ok(())
    }
    /// Insert a bare (id, embedding) pair; content and metadata stay empty.
    ///
    /// # Errors
    /// Same as [`Self::add`]: `DimensionMismatch` or `InvalidInput` for
    /// non-finite values.
    pub fn add_embedding(&mut self, id: String, embedding: Vec<f32>) -> Result<()> {
        if embedding.len() != self.embedding_dim {
            return Err(crate::RagError::DimensionMismatch {
                expected: self.embedding_dim,
                actual: embedding.len(),
            });
        }
        if embedding.iter().any(|v| !v.is_finite()) {
            return Err(crate::RagError::InvalidInput(
                "embedding contains non-finite values (NaN or Inf)".into(),
            ));
        }
        let node_id = self.len();
        let node_level = self.random_level();
        let mut node_connections: Vec<Vec<u32>> = Vec::with_capacity(node_level + 1);
        for _ in 0..=node_level {
            node_connections.push(Vec::new());
        }
        let norm = crate::vector::simd::norm_simd(&embedding);
        self.embeddings.extend_from_slice(&embedding);
        self.norms.push(norm);
        self.connections.push(node_connections);
        self.ids.push(id);
        self.contents.push(String::new());
        self.metadata.push(None);
        self.extend_l0_cache_for_new_node();
        // First node: becomes the entry point, no linking needed.
        if self.entry_point.is_none() {
            self.entry_point = Some(node_id);
            self.max_layer = node_level;
            return Ok(());
        }
        let affected = self.insert_node(node_id, node_level);
        // Refresh the flat layer-0 cache for every touched node.
        self.sync_l0_cache_for_nodes(&affected);
        if node_level > self.max_layer {
            self.max_layer = node_level;
            self.entry_point = Some(node_id);
        }
        Ok(())
    }
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
let mut ctx = SearchContext::new(self.len());
self.search_with_context(query, k, &mut ctx)
}
pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
use rayon::prelude::*;
queries
.par_iter()
.map(|query| self.search(query, k))
.collect()
}
pub fn create_search_context(&self) -> SearchContext {
SearchContext::new(self.len())
}
pub fn search_with_context(
&self,
query: &[f32],
k: usize,
ctx: &mut SearchContext,
) -> Result<Vec<SearchResult>> {
if query.len() != self.embedding_dim {
return Err(crate::RagError::DimensionMismatch {
expected: self.embedding_dim,
actual: query.len(),
});
}
if self.is_empty() {
return Ok(Vec::new());
}
if ctx.capacity < self.len() {
*ctx = SearchContext::new(self.len());
}
let query_norm = crate::vector::simd::norm_simd(query);
let entry_point = self.entry_point.unwrap();
let mut current_nearest = vec![entry_point];
for layer in (1..=self.max_layer).rev() {
current_nearest = self.search_layer(query, ¤t_nearest, 1, layer, ctx, query_norm);
}
let ef = self.config.ef_search.max(k);
current_nearest = self.search_layer(query, ¤t_nearest, ef, 0, ctx, query_norm);
let mut results: Vec<SearchResult> = current_nearest
.iter()
.map(|&node_id| {
let dist = self.distance_to_node(query, node_id, query_norm);
let score = 1.0 - dist;
SearchResult {
id: self.ids[node_id].clone(),
content: self.contents[node_id].clone(),
score,
metadata: self.metadata[node_id].clone(),
}
})
.collect();
results.sort_by(|a, b| b.score.total_cmp(&a.score));
results.truncate(k);
Ok(results)
}
    /// Parallel batch search that reuses one `SearchContext` per rayon
    /// worker thread (thread-local), avoiding per-query allocation.
    pub fn search_batch_fast(
        &self,
        queries: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<SearchResult>>> {
        use rayon::prelude::*;
        use std::cell::RefCell;
        thread_local! {
            static CTX: RefCell<Option<SearchContext>> = const { RefCell::new(None) };
        }
        queries
            .par_iter()
            .map(|query| {
                CTX.with(|ctx| {
                    let mut ctx_ref = ctx.borrow_mut();
                    // (Re)build this thread's context if absent or undersized.
                    if ctx_ref.is_none() || ctx_ref.as_ref().unwrap().capacity < self.len() {
                        *ctx_ref = Some(SearchContext::new(self.len()));
                    }
                    self.search_with_context(query, k, ctx_ref.as_mut().unwrap())
                })
            })
            .collect()
    }
    /// Remove every vector and reset the index to its empty state.
    pub fn clear(&mut self) {
        self.embeddings.clear();
        self.norms.clear();
        self.connections.clear();
        self.connections_l0.clear();
        self.connections_l0_count.clear();
        self.ids.clear();
        self.contents.clear();
        self.metadata.clear();
        self.entry_point = None;
        self.max_layer = 0;
    }
    /// Rebuild owned `Document`s for every stored vector (clones all data).
    pub fn get_all_documents(&self) -> Vec<Document> {
        (0..self.len())
            .map(|i| Document {
                id: self.ids[i].clone(),
                content: self.contents[i].clone(),
                embedding: self.get_embedding(i).to_vec(),
                metadata: self.metadata[i].clone(),
            })
            .collect()
    }
    /// The active configuration.
    pub fn config(&self) -> &HNSWConfig {
        &self.config
    }
    /// Graph entry node, when at least one vector has been inserted.
    pub fn entry_point(&self) -> Option<usize> {
        self.entry_point
    }
    /// Highest populated layer of the graph.
    pub fn max_layer(&self) -> usize {
        self.max_layer
    }
    /// Dimensionality accepted by this index.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
    /// Draw a node level from the HNSW geometric distribution,
    /// floor(-ln(u) * ml).
    ///
    /// NOTE(review): uses `thread_rng`, so incremental `add` calls are not
    /// reproducible even when `config.seed` is set (only bulk builds honor
    /// the seed) — confirm this is intended.
    fn random_level(&self) -> usize {
        let mut rng = rand::thread_rng();
        // Clamp away 0 so ln() stays finite.
        let uniform: f32 = rng.gen::<f32>().max(f32::EPSILON);
        (-uniform.ln() * self.config.ml).floor() as usize
    }
fn insert_node(&mut self, node_id: usize, node_level: usize) -> Vec<usize> {
let entry_point = self.entry_point.unwrap();
let mut current_nearest = vec![entry_point];
let node_embedding = self.get_embedding(node_id).to_vec();
let query_norm = crate::vector::simd::norm_simd(&node_embedding);
let mut ctx = SearchContext::new(self.len());
for layer in (node_level + 1..=self.max_layer).rev() {
current_nearest = self.search_layer(
&node_embedding,
¤t_nearest,
1,
layer,
&mut ctx,
query_norm,
);
}
let mut l0_modified: Vec<usize> = vec![node_id];
for layer in (0..=node_level).rev() {
current_nearest = self.search_layer(
&node_embedding,
¤t_nearest,
self.config.ef_construction,
layer,
&mut ctx,
query_norm,
);
let m = if layer == 0 {
self.config.m0
} else {
self.config.m
};
let neighbors = self.select_neighbors(¤t_nearest, &node_embedding, m, layer);
for &neighbor_id in &neighbors {
let neighbor_u32 = neighbor_id as u32;
if !self.connections[node_id][layer].contains(&neighbor_u32) {
self.connections[node_id][layer].push(neighbor_u32);
}
if layer < self.connections[neighbor_id].len() {
let node_u32 = node_id as u32;
if !self.connections[neighbor_id][layer].contains(&node_u32) {
self.connections[neighbor_id][layer].push(node_u32);
}
let neighbor_m = if layer == 0 {
self.config.m0
} else {
self.config.m
};
if self.connections[neighbor_id][layer].len() > neighbor_m {
let neighbor_embedding = self.get_embedding(neighbor_id).to_vec();
let neighbor_connections: Vec<usize> = self.connections[neighbor_id][layer]
.iter()
.map(|&x| x as usize)
.collect();
let pruned = self.select_neighbors(
&neighbor_connections,
&neighbor_embedding,
neighbor_m,
layer,
);
self.connections[neighbor_id][layer] =
pruned.into_iter().map(|x| x as u32).collect();
}
if layer == 0 {
l0_modified.push(neighbor_id);
}
}
}
}
l0_modified
}
fn sync_l0_cache_for_nodes(&mut self, node_ids: &[usize]) {
let m0 = self.config.m0;
let needed = self.connections.len() * m0;
if self.connections_l0.len() < needed {
self.connections_l0.resize(needed, 0u32);
self.connections_l0_count
.resize(self.connections.len(), 0u8);
}
for &node_id in node_ids {
if node_id < self.connections.len() && !self.connections[node_id].is_empty() {
let neighbors = &self.connections[node_id][0];
let count = neighbors.len().min(m0);
let start = node_id * m0;
self.connections_l0[start..start + count].copy_from_slice(&neighbors[..count]);
for j in count..m0 {
self.connections_l0[start + j] = 0;
}
self.connections_l0_count[node_id] = count as u8;
}
}
}
    /// Beam search on one layer: starting from `entry_points`, greedily
    /// expand the closest frontier node, keeping the `ef` closest nodes
    /// seen so far; returns them sorted by ascending distance to `query`.
    #[inline]
    fn search_layer(
        &self,
        query: &[f32],
        entry_points: &[usize],
        ef: usize,
        layer: usize,
        ctx: &mut SearchContext,
        query_norm: f32,
    ) -> Vec<usize> {
        ctx.reset();
        for &ep in entry_points {
            let dist = self.distance_to_node(query, ep, query_norm);
            // `candidates` is a min-heap (via Reverse); `best` is a max-heap.
            ctx.candidates.push(Reverse((OrderedFloat(dist), ep)));
            ctx.best.push((OrderedFloat(dist), ep));
            ctx.mark_visited(ep);
        }
        while let Some(Reverse((current_dist, current_id))) = ctx.candidates.pop() {
            // Stop once the closest unexplored candidate is farther than
            // the worst of the ef best results.
            if ctx.best.len() >= ef {
                if let Some(&(furthest_dist, _)) = ctx.best.peek() {
                    if current_dist > furthest_dist {
                        break;
                    }
                }
            }
            // Layer 0 reads the flat cache once it has been built.
            let neighbors_l0_slice;
            let neighbors: &[u32] = if layer == 0 && !self.connections_l0.is_empty() {
                neighbors_l0_slice = self.get_neighbors_l0(current_id);
                neighbors_l0_slice
            } else if layer < self.connections[current_id].len() {
                &self.connections[current_id][layer]
            } else {
                &[]
            };
            if !neighbors.is_empty() {
                let n_neighbors = neighbors.len();
                const PREFETCH_AHEAD: usize = 2;
                const CACHE_LINES_PER_EMBEDDING: usize = 3;
                // Distances are first gathered into a fixed batch buffer so
                // heap updates do not interleave with the prefetched scans.
                let mut batch_buf: [(f32, usize); 64] = [(0.0, 0); 64];
                let mut batch_count = 0usize;
                let mut overflow = Vec::new();
                for (i, &neighbor_u32) in neighbors.iter().enumerate() {
                    let neighbor_id = neighbor_u32 as usize;
                    // SAFETY: prefetch hints only — the `wrapping_add`
                    // addresses derive from live Vec allocations and are
                    // never dereferenced by the program itself.
                    unsafe {
                        let lookahead = i + PREFETCH_AHEAD;
                        if lookahead < n_neighbors {
                            let ahead_id = neighbors[lookahead] as usize;
                            let ahead_ptr = self
                                .embeddings
                                .as_ptr()
                                .wrapping_add(ahead_id * self.embedding_dim)
                                as *const u8;
                            prefetch_embedding(ahead_ptr, CACHE_LINES_PER_EMBEDDING);
                            let bitset_ptr =
                                ctx.visited.bits.as_ptr().wrapping_add(ahead_id >> 6) as *const u8;
                            prefetch_read(bitset_ptr);
                            let norm_ptr = self.norms.as_ptr().wrapping_add(ahead_id) as *const u8;
                            prefetch_read(norm_ptr);
                        }
                    }
                    if !ctx.is_visited(neighbor_id) {
                        ctx.mark_visited(neighbor_id);
                        let dist = self.distance_to_node(query, neighbor_id, query_norm);
                        if batch_count < batch_buf.len() {
                            batch_buf[batch_count] = (dist, neighbor_id);
                            batch_count += 1;
                        } else {
                            overflow.push((dist, neighbor_id));
                        }
                    }
                }
                // Fold each computed distance into the two heaps.
                let mut consider = |dist: f32, neighbor_id: usize| {
                    let dist_ord = OrderedFloat(dist);
                    if ctx.best.len() < ef {
                        ctx.candidates.push(Reverse((dist_ord, neighbor_id)));
                        ctx.best.push((dist_ord, neighbor_id));
                    } else if let Some(&(furthest_dist, _)) = ctx.best.peek() {
                        if dist_ord < furthest_dist {
                            ctx.candidates.push(Reverse((dist_ord, neighbor_id)));
                            ctx.best.push((dist_ord, neighbor_id));
                            if ctx.best.len() > ef {
                                ctx.best.pop();
                            }
                        }
                    }
                };
                for &(dist, neighbor_id) in &batch_buf[..batch_count] {
                    consider(dist, neighbor_id);
                }
                for &(dist, neighbor_id) in &overflow {
                    consider(dist, neighbor_id);
                }
            }
        }
        // Drain the result heap and order by increasing distance.
        let mut results: Vec<(f32, usize)> = ctx
            .best
            .drain()
            .map(|(OrderedFloat(dist), id)| (dist, id))
            .collect();
        results.sort_by(|a, b| a.0.total_cmp(&b.0));
        results.into_iter().map(|(_, id)| id).collect()
    }
    /// Heuristic neighbor selection: a candidate is kept only when it is
    /// closer to the query than to every already-selected neighbor, which
    /// favors directionally diverse edges over a plain nearest-m pick.
    fn select_neighbors(
        &self,
        candidates: &[usize],
        query: &[f32],
        m: usize,
        layer: usize,
    ) -> Vec<usize> {
        if !self.config.use_heuristic {
            return self.select_neighbors_simple(candidates, query, m);
        }
        let mut working_candidates: Vec<usize> = candidates.to_vec();
        // Optionally widen the pool with the candidates' own neighbors.
        if self.config.extend_candidates {
            let mut seen: HashSet<usize> = candidates.iter().copied().collect();
            for &candidate in candidates {
                if layer < self.connections[candidate].len() {
                    for &neighbor_u32 in &self.connections[candidate][layer] {
                        let neighbor = neighbor_u32 as usize;
                        if seen.insert(neighbor) {
                            working_candidates.push(neighbor);
                        }
                    }
                }
            }
        }
        // Rank the pool by distance to the query, nearest first.
        let mut scored: Vec<(f32, usize)> = working_candidates
            .iter()
            .map(|&id| {
                let dist = self.distance(query, self.get_embedding(id));
                (dist, id)
            })
            .collect();
        scored.sort_by(|a, b| a.0.total_cmp(&b.0));
        let mut selected: Vec<usize> = Vec::with_capacity(m);
        let mut pruned: Vec<(f32, usize)> = Vec::new();
        for (dist_to_query, candidate_id) in scored {
            if selected.len() >= m {
                break;
            }
            let candidate_embedding = self.get_embedding(candidate_id);
            let mut is_good = true;
            for &selected_id in &selected {
                let selected_embedding = self.get_embedding(selected_id);
                let dist_to_selected = self.distance(candidate_embedding, selected_embedding);
                // Dominated: closer to a chosen neighbor than to the query,
                // so it adds little diversity.
                if dist_to_selected < dist_to_query {
                    is_good = false;
                    pruned.push((dist_to_query, candidate_id));
                    break;
                }
            }
            if is_good {
                selected.push(candidate_id);
            }
        }
        // Optionally back-fill with the closest pruned candidates.
        if self.config.keep_pruned_connections && selected.len() < m {
            for (_, pruned_id) in pruned {
                if selected.len() >= m {
                    break;
                }
                if !selected.contains(&pruned_id) {
                    selected.push(pruned_id);
                }
            }
        }
        selected
    }
#[inline]
fn select_neighbors_simple(&self, candidates: &[usize], query: &[f32], m: usize) -> Vec<usize> {
let mut scored: Vec<(f32, usize)> = candidates
.iter()
.map(|&id| {
let dist = self.distance(query, self.get_embedding(id));
(dist, id)
})
.collect();
scored.sort_by(|a, b| a.0.total_cmp(&b.0));
scored.truncate(m);
scored.into_iter().map(|(_, id)| id).collect()
}
    /// Cosine distance (1 - cosine similarity) between two raw vectors.
    #[inline]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        1.0 - crate::vector::simd::cosine_similarity_simd(a, b)
    }
    /// Cosine distance from `query` to a stored node, reusing the node's
    /// precomputed norm; zero-norm inputs are treated as distance 1.0.
    #[inline]
    fn distance_to_node(&self, query: &[f32], node_id: usize, query_norm: f32) -> f32 {
        let embedding = self.get_embedding(node_id);
        let norm_b = self.norms[node_id];
        if query_norm == 0.0 || norm_b == 0.0 {
            return 1.0;
        }
        crate::vector::simd::cosine_distance_prenorm(query, embedding, norm_b)
    }
    /// Bulk-build using a layered-batch parallel scheme: points are
    /// shuffled, split into geometrically shrinking batches (one per
    /// layer), and each batch is inserted with rayon after the layer
    /// above has been frozen into immutable `UpperNode`s.
    ///
    /// # Panics
    /// Panics on empty input or if `n >= u32::MAX`.
    pub fn build_parallel(embeddings: Vec<Vec<f32>>, config: HNSWConfig) -> Self {
        assert!(!embeddings.is_empty(), "Cannot build from empty embeddings");
        let embedding_dim = embeddings[0].len();
        let n = embeddings.len();
        if n == 1 {
            return Self::build_single(embeddings, config);
        }
        let ml = config.ml;
        let ef_construction = config.ef_construction;
        let seed = config.seed.unwrap_or_else(rand::random);
        let mut rng = StdRng::seed_from_u64(seed);
        // Per-layer population sizes: each layer keeps ~ml of the previous
        // one until fewer than M_MAX points would remain.
        let mut sizes = Vec::new();
        let mut num = n;
        loop {
            let next = (num as f32 * ml) as usize;
            if next < M_MAX {
                break;
            }
            sizes.push((num - next, num));
            num = next;
        }
        sizes.push((num, num));
        sizes.reverse();
        let num_batches = sizes.len();
        let top = LayerId(num_batches - 1);
        assert!(n < u32::MAX as usize);
        // Deterministic shuffle (given a seed): sort by random keys.
        let mut shuffled: Vec<(u32, usize)> = (0..n).map(|i| (rng.gen::<u32>(), i)).collect();
        shuffled.sort_unstable_by_key(|&(r, _)| r);
        let points: Vec<Vec<f32>> = shuffled
            .iter()
            .map(|&(_, idx)| embeddings[idx].clone())
            .collect();
        let mut ranges = Vec::with_capacity(num_batches);
        for (i, (size, cumulative)) in sizes.into_iter().enumerate() {
            let start = cumulative - size;
            let batch_id = LayerId(num_batches - i - 1);
            // Point 0 is the fixed entry point, so insertion starts at 1.
            ranges.push((batch_id, max(start, 1)..cumulative));
        }
        let zero: Vec<RwLock<ZeroNode>> =
            (0..n).map(|_| RwLock::new(ZeroNode::default())).collect();
        let mut layers: Vec<Vec<UpperNode>> = vec![Vec::new(); top.0];
        let pool = SearchPool::new(n);
        for (batch, range) in ranges {
            let end = range.end;
            // The (tiny) top batch runs sequentially; the rest in parallel.
            if batch.0 == top.0 {
                for i in range {
                    Self::par_insert(
                        PointId(i as u32),
                        batch,
                        &zero,
                        &layers,
                        &points,
                        &pool,
                        ef_construction,
                        top,
                    );
                }
            } else {
                range.into_par_iter().for_each(|i| {
                    Self::par_insert(
                        PointId(i as u32),
                        batch,
                        &zero,
                        &layers,
                        &points,
                        &pool,
                        ef_construction,
                        top,
                    );
                });
            }
            // Freeze this batch's links as the next (immutable) upper layer.
            if !batch.is_zero() {
                zero[..end]
                    .par_iter()
                    .map(|z| UpperNode::from_zero(&z.read()))
                    .collect_into_vec(&mut layers[batch.0 - 1]);
            }
        }
        Self::convert_parallel_to_index(zero, layers, points, shuffled, embedding_dim, config, top)
    }
    /// Insert one point into the parallel-build graph: greedy descent
    /// through the frozen upper layers, a beam search on the zero layer,
    /// then forward links plus reverse links under the neighbors' locks.
    #[allow(clippy::too_many_arguments)]
    fn par_insert(
        new: PointId,
        target_layer: LayerId,
        zero: &[RwLock<ZeroNode>],
        layers: &[Vec<UpperNode>],
        points: &[Vec<f32>],
        pool: &SearchPool,
        ef_construction: usize,
        top: LayerId,
    ) {
        let mut search = pool.pop();
        search.visited.reserve(points.len());
        let point = &points[new.as_usize()];
        search.reset();
        // Point 0 is the global entry point for every insertion.
        search.push(PointId(0), point, points);
        for cur in top.descend() {
            // Narrow (ef=1) descent above the target layer; wide at/below it.
            search.ef = if cur.0 <= target_layer.0 {
                ef_construction
            } else {
                1
            };
            if cur.0 > target_layer.0 {
                if cur.0 <= layers.len() && !layers[cur.0 - 1].is_empty() {
                    search.search_upper(point, &layers[cur.0 - 1], points, M_MAX);
                    search.cull();
                }
            } else {
                search.search_zero(point, zero, points, M0_MAX);
                break;
            }
        }
        let found = search.select_simple();
        // Write the forward links under this node's own lock.
        {
            let mut node = zero[new.as_usize()].write();
            for (i, candidate) in found.iter().take(M0_MAX).enumerate() {
                node.nearest[i] = candidate.pid;
            }
        }
        // Then push a reverse link into each neighbor's list.
        for candidate in found.iter().take(M0_MAX) {
            Self::add_reverse_connection(zero, points, new, candidate.pid);
        }
        pool.push(search);
    }
    /// Insert `new` into `neighbor`'s layer-0 list, which is kept sorted
    /// ascending by distance; the farthest entry is dropped when the
    /// fixed-capacity list is full.
    fn add_reverse_connection(
        zero: &[RwLock<ZeroNode>],
        points: &[Vec<f32>],
        new: PointId,
        neighbor: PointId,
    ) {
        let mut node = zero[neighbor.as_usize()].write();
        let neighbor_point = &points[neighbor.as_usize()];
        let new_dist = Self::parallel_distance(neighbor_point, &points[new.as_usize()]);
        let count = node.count();
        // Binary search for the insertion position, recomputing distances
        // (the list stores only PointIds, not distances).
        let pos = {
            let mut left = 0;
            let mut right = count;
            while left < right {
                let mid = (left + right) / 2;
                let mid_dist =
                    Self::parallel_distance(neighbor_point, &points[node.nearest[mid].as_usize()]);
                if mid_dist < new_dist {
                    left = mid + 1;
                } else {
                    right = mid;
                }
            }
            left
        };
        // Farther than every kept slot: nothing to do.
        if pos >= M0_MAX {
            return;
        }
        // Shift the tail right by one, clamped to the fixed capacity.
        let shift_end = count.min(M0_MAX - 1);
        for i in (pos..shift_end).rev() {
            node.nearest[i + 1] = node.nearest[i];
        }
        node.nearest[pos] = new;
    }
    /// Convert the parallel builder's lock-wrapped graph into the final
    /// flat `HNSWIndex`. Ids record the ORIGINAL input positions, since
    /// points were shuffled before insertion.
    fn convert_parallel_to_index(
        zero: Vec<RwLock<ZeroNode>>,
        layers: Vec<Vec<UpperNode>>,
        points: Vec<Vec<f32>>,
        shuffled: Vec<(u32, usize)>,
        embedding_dim: usize,
        config: HNSWConfig,
        top: LayerId,
    ) -> Self {
        let n = points.len();
        let num_layers = top.0 + 1;
        let zero_final: Vec<ZeroNode> = zero.into_iter().map(|n| n.into_inner()).collect();
        let mut connections: Vec<Vec<Vec<u32>>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut node_connections: Vec<Vec<u32>> = Vec::with_capacity(num_layers);
            let layer0: Vec<u32> = zero_final[i].iter().map(|p| p.as_usize() as u32).collect();
            node_connections.push(layer0);
            // A node belongs to an upper layer iff it falls inside that
            // layer's population prefix.
            for layer in &layers {
                if i < layer.len() {
                    let layer_conns: Vec<u32> =
                        layer[i].iter().map(|p| p.as_usize() as u32).collect();
                    node_connections.push(layer_conns);
                }
            }
            connections.push(node_connections);
        }
        let norms: Vec<f32> = points
            .iter()
            .map(|p| crate::vector::simd::norm_simd(p))
            .collect();
        let flat_embeddings: Vec<f32> = points.into_iter().flatten().collect();
        let ids: Vec<String> = shuffled.iter().map(|&(_, orig)| orig.to_string()).collect();
        let mut index = Self {
            embedding_dim,
            config,
            embeddings: flat_embeddings,
            norms,
            connections,
            connections_l0: Vec::new(),
            connections_l0_count: Vec::new(),
            ids,
            contents: vec![String::new(); n],
            metadata: vec![None; n],
            // Point 0 was the parallel builder's fixed entry point.
            entry_point: Some(0),
            max_layer: top.0,
        };
        index.build_l0_cache();
        index
    }
    /// Build the trivial one-point index (used by `build_parallel` when
    /// the input has exactly one vector).
    fn build_single(embeddings: Vec<Vec<f32>>, config: HNSWConfig) -> Self {
        let embedding_dim = embeddings[0].len();
        let norms = vec![crate::vector::simd::norm_simd(&embeddings[0])];
        let m0 = config.m0;
        Self {
            embedding_dim,
            config,
            embeddings: embeddings.into_iter().flatten().collect(),
            norms,
            connections: vec![vec![Vec::new()]],
            connections_l0: vec![0u32; m0],
            connections_l0_count: vec![0u8],
            ids: vec!["0".to_string()],
            contents: vec![String::new()],
            metadata: vec![None],
            entry_point: Some(0),
            max_layer: 0,
        }
    }
    /// Cosine-distance helper for the parallel builder (no precomputed norms).
    #[inline]
    fn parallel_distance(a: &[f32], b: &[f32]) -> f32 {
        1.0 - crate::vector::simd::cosine_similarity_simd(a, b)
    }
}
// Trait facade: delegate the generic `VectorIndex` API to the inherent
// methods defined above.
impl crate::index::VectorIndex for HNSWIndex {
    fn add(&mut self, document: Document) -> Result<()> {
        self.add(document)
    }
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        self.search(query, k)
    }
    fn len(&self) -> usize {
        self.len()
    }
    fn clear(&mut self) {
        self.clear()
    }
    fn embedding_dim(&self) -> usize {
        self.embedding_dim()
    }
}
// Snapshot facade: expose full document export through the trait.
impl crate::index::VectorIndexSnapshot for HNSWIndex {
    fn get_all_documents(&self) -> Vec<Document> {
        self.get_all_documents()
    }
}
/// Max layer-0 neighbors per node in the parallel builder.
const M0_MAX: usize = 64;
/// Max upper-layer neighbors per node in the parallel builder.
const M_MAX: usize = 32;
/// Sentinel marking an empty neighbor slot.
const INVALID: u32 = u32::MAX;

/// Dense index of a point in the parallel builder's shuffled point array.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PointId(u32);

impl PointId {
    /// Widen to `usize` for slice indexing.
    fn as_usize(self) -> usize {
        self.0 as usize
    }

    /// `false` only for the `INVALID` sentinel.
    fn is_valid(self) -> bool {
        self.0 != INVALID
    }
}
/// Fixed-size layer-0 adjacency record for the parallel builder; slots
/// after the last valid neighbor hold the `INVALID` sentinel.
#[derive(Clone)]
struct ZeroNode {
    nearest: [PointId; M0_MAX],
}

impl Default for ZeroNode {
    fn default() -> Self {
        Self {
            nearest: [PointId(INVALID); M0_MAX],
        }
    }
}

impl ZeroNode {
    /// Number of leading valid neighbor slots.
    fn count(&self) -> usize {
        self.iter().count()
    }

    /// Iterate the valid-neighbor prefix.
    fn iter(&self) -> impl Iterator<Item = PointId> + '_ {
        self.nearest.iter().copied().take_while(|p| p.is_valid())
    }
}
/// Fixed-size upper-layer adjacency record (`M_MAX` slots, padded with
/// the `INVALID` sentinel).
#[derive(Clone)]
struct UpperNode {
    nearest: [PointId; M_MAX],
}
impl Default for UpperNode {
    fn default() -> Self {
        Self {
            nearest: [PointId(INVALID); M_MAX],
        }
    }
}
impl UpperNode {
    /// Keep only the first `M_MAX` slots of a zero-layer record; the
    /// remaining slots stay `INVALID` from `default()`.
    fn from_zero(zero: &ZeroNode) -> Self {
        let mut node = Self::default();
        for (i, &pid) in zero.nearest.iter().take(M_MAX).enumerate() {
            node.nearest[i] = pid;
        }
        node
    }
    /// Iterate the valid-neighbor prefix.
    fn iter(&self) -> impl Iterator<Item = PointId> + '_ {
        self.nearest.iter().copied().take_while(|p| p.is_valid())
    }
}
/// Generation-stamped visited set: `clear` is O(1) (bump the generation);
/// the backing store is only zeroed when the `u8` generation would wrap.
struct Visited {
    store: Vec<u8>,
    generation: u8,
}

impl Visited {
    fn new(capacity: usize) -> Self {
        Self {
            store: vec![0; capacity],
            generation: 1,
        }
    }

    /// Invalidate every entry.
    fn clear(&mut self) {
        if self.generation == 255 {
            // Bumping would wrap into stale stamps: hard-reset instead.
            self.store.fill(0);
            self.generation = 1;
        } else {
            self.generation += 1;
        }
    }

    /// Returns true the first time `pid` is seen in this generation.
    fn insert(&mut self, pid: PointId) -> bool {
        let slot = &mut self.store[pid.as_usize()];
        if *slot == self.generation {
            return false;
        }
        *slot = self.generation;
        true
    }

    /// Ensure ids below `capacity` can be stamped.
    fn reserve(&mut self, capacity: usize) {
        if self.store.len() < capacity {
            self.store.resize(capacity, 0);
        }
    }
}
/// (distance, point) pair ordered by distance with the point id as a
/// tie-breaker, giving a total, deterministic ordering.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Candidate {
    distance: f32,
    pid: PointId,
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let by_distance = self.distance.total_cmp(&other.distance);
        by_distance.then_with(|| self.pid.cmp(&other.pid))
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
struct Search {
candidates: BinaryHeap<Reverse<Candidate>>,
nearest: Vec<Candidate>,
visited: Visited,
ef: usize,
}
impl Search {
fn new(capacity: usize) -> Self {
Self {
candidates: BinaryHeap::new(),
nearest: Vec::new(),
visited: Visited::new(capacity),
ef: 1,
}
}
fn reset(&mut self) {
self.candidates.clear();
self.nearest.clear();
self.visited.clear();
}
fn push(&mut self, pid: PointId, point: &[f32], points: &[Vec<f32>]) {
let distance = HNSWIndex::parallel_distance(point, &points[pid.as_usize()]);
let candidate = Candidate { distance, pid };
self.candidates.push(Reverse(candidate));
self.nearest.push(candidate);
self.visited.insert(pid);
}
fn cull(&mut self) {
self.candidates.clear();
for &candidate in &self.nearest {
self.candidates.push(Reverse(candidate));
}
self.visited.clear();
for c in &self.nearest {
self.visited.insert(c.pid);
}
}
fn search_zero(
&mut self,
point: &[f32],
layer: &[RwLock<ZeroNode>],
points: &[Vec<f32>],
num: usize,
) {
while let Some(Reverse(candidate)) = self.candidates.pop() {
if let Some(furthest) = self.nearest.last() {
if candidate.distance > furthest.distance && self.nearest.len() >= self.ef {
break;
}
}
let node = layer[candidate.pid.as_usize()].read();
for neighbor_pid in node.iter() {
if self.visited.insert(neighbor_pid) {
let distance =
HNSWIndex::parallel_distance(point, &points[neighbor_pid.as_usize()]);
let new_candidate = Candidate {
distance,
pid: neighbor_pid,
};
let dominated = self.nearest.len() >= self.ef
&& self
.nearest
.last()
.map(|f| distance > f.distance)
.unwrap_or(false);
if !dominated {
self.candidates.push(Reverse(new_candidate));
let pos = self
.nearest
.binary_search(&new_candidate)
.unwrap_or_else(|i| i);
if pos < self.ef {
self.nearest.insert(pos, new_candidate);
if self.nearest.len() > self.ef {
self.nearest.pop();
}
}
}
}
}
}
self.nearest.truncate(num);
}
/// Greedy best-first search over an upper (non-zero) layer.
///
/// Same algorithm as `search_zero`, but upper-layer nodes are plain
/// `UpperNode`s (no lock) and the layer slice may be shorter than the point
/// count, so out-of-range candidates are skipped rather than indexed.
fn search_upper(
    &mut self,
    point: &[f32],
    layer: &[UpperNode],
    points: &[Vec<f32>],
    num: usize,
) {
    if layer.is_empty() {
        return;
    }
    while let Some(Reverse(candidate)) = self.candidates.pop() {
        // Termination: result set is full and even the best remaining
        // candidate is further away than the current furthest result.
        if let Some(furthest) = self.nearest.last() {
            if candidate.distance > furthest.distance && self.nearest.len() >= self.ef {
                break;
            }
        }
        // Upper layers only contain a subset of points; skip candidates
        // that have no node at this layer.
        if candidate.pid.as_usize() >= layer.len() {
            continue;
        }
        let node = &layer[candidate.pid.as_usize()];
        for neighbor_pid in node.iter() {
            // `insert` returns true only the first time a point is seen.
            if self.visited.insert(neighbor_pid) {
                let distance =
                    HNSWIndex::parallel_distance(point, &points[neighbor_pid.as_usize()]);
                let new_candidate = Candidate {
                    distance,
                    pid: neighbor_pid,
                };
                // Dominated = result set already full AND this neighbor is
                // further than the current furthest result.
                let dominated = self.nearest.len() >= self.ef
                    && self
                        .nearest
                        .last()
                        .map(|f| distance > f.distance)
                        .unwrap_or(false);
                if !dominated {
                    self.candidates.push(Reverse(new_candidate));
                    // Keep `nearest` sorted; insert only inside the top-`ef`
                    // window and evict the furthest entry on overflow.
                    let pos = self
                        .nearest
                        .binary_search(&new_candidate)
                        .unwrap_or_else(|i| i);
                    if pos < self.ef {
                        self.nearest.insert(pos, new_candidate);
                        if self.nearest.len() > self.ef {
                            self.nearest.pop();
                        }
                    }
                }
            }
        }
    }
    self.nearest.truncate(num);
}
/// Exposes the current result set as a slice (sorted ascending by distance,
/// as maintained by the search loops).
fn select_simple(&self) -> &[Candidate] {
    self.nearest.as_slice()
}
}
/// Pool of reusable `Search` scratch states, shared between query threads.
struct SearchPool {
    // Idle `Search` instances; guarded by a parking_lot `Mutex`.
    pool: Mutex<Vec<Search>>,
    // Capacity passed to `Search::new` when the pool is empty and a fresh
    // scratch state must be allocated.
    capacity: usize,
}
impl SearchPool {
fn new(capacity: usize) -> Self {
Self {
pool: Mutex::new(Vec::new()),
capacity,
}
}
fn pop(&self) -> Search {
self.pool
.lock()
.pop()
.unwrap_or_else(|| Search::new(self.capacity))
}
fn push(&self, mut search: Search) {
search.reset();
self.pool.lock().push(search);
}
}
/// Index of a graph layer; layer 0 is the bottom (densest) layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct LayerId(usize);

impl LayerId {
    /// True for the bottom layer.
    fn is_zero(self) -> bool {
        matches!(self.0, 0)
    }

    /// Iterates from this layer down to layer 0, inclusive.
    fn descend(self) -> impl Iterator<Item = LayerId> {
        let top = self.0;
        (0..=top).map(LayerId).rev()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Document` with content derived deterministically from `id`.
    fn create_test_document(id: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: format!("Content for {}", id),
            embedding,
            metadata: None,
        }
    }

    /// Deterministic pseudo-random vector with components in [-1, 1).
    /// `Rng`/`SeedableRng` come in scope through `use super::*;`.
    fn generate_random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// Default config exposes the documented M/ef values and heuristic flags.
    #[test]
    fn test_hnsw_config_default() {
        let config = HNSWConfig::default();
        assert_eq!(config.m, 32);
        assert_eq!(config.m0, 64);
        assert_eq!(config.ef_construction, 100);
        assert_eq!(config.ef_search, 100);
        assert!((config.ml - (1.0 / 32_f32.ln())).abs() < 0.01);
        assert!(config.use_heuristic);
        assert!(!config.extend_candidates);
        assert!(config.keep_pruned_connections);
    }

    /// Builder methods chain and override the corresponding fields.
    #[test]
    fn test_hnsw_config_builders() {
        let config = HNSWConfig::default()
            .with_m(32)
            .with_ef_search(100)
            .with_ef_construction(400)
            .with_simple_selection()
            .with_extended_candidates();
        assert_eq!(config.m, 32);
        assert_eq!(config.m0, 64);
        assert_eq!(config.ef_search, 100);
        assert_eq!(config.ef_construction, 400);
        assert!(!config.use_heuristic);
        assert!(config.extend_candidates);
    }

    /// A freshly created index is empty and records its embedding dimension.
    #[test]
    fn test_hnsw_new() {
        let index = HNSWIndex::with_defaults(128);
        assert_eq!(index.embedding_dim, 128);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    /// Regression test: with 65 neighbors on the entry point, the best match
    /// sitting at position > 64 must still be reached by `search_layer`.
    #[test]
    fn search_layer_considers_neighbors_beyond_fixed_stack_batch() {
        let mut config = HNSWConfig::default().with_m(64);
        config.m0 = 128;
        let mut index = HNSWIndex::new(2, config);
        let total_nodes = 66usize;
        // Populate index internals by hand to control the graph layout.
        index.embeddings = Vec::with_capacity(total_nodes * 2);
        index.norms = Vec::with_capacity(total_nodes);
        index.connections = vec![vec![Vec::new()]; total_nodes];
        index.ids = Vec::with_capacity(total_nodes);
        index.contents = Vec::with_capacity(total_nodes);
        index.metadata = vec![None; total_nodes];
        index.entry_point = Some(0);
        index.max_layer = 0;
        for node in 0..total_nodes {
            // Only node 65 matches the query direction.
            if node == 65 {
                index.embeddings.extend_from_slice(&[1.0, 0.0]);
            } else {
                index.embeddings.extend_from_slice(&[0.0, 1.0]);
            }
            index.norms.push(1.0);
            index.ids.push(format!("doc-{node}"));
            index.contents.push(String::new());
        }
        index.connections[0][0] = (1..=65).map(|n| n as u32).collect();
        let mut ctx = index.create_search_context();
        let candidates = index.search_layer(&[1.0, 0.0], &[0], 66, 0, &mut ctx, 1.0);
        assert!(
            candidates.contains(&65),
            "best neighbor from position >64 should be considered"
        );
    }

    /// Adding one well-formed document succeeds and bumps the length.
    #[test]
    fn test_add_single_document() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        assert!(index.add(doc).is_ok());
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }

    /// Embeddings whose length differs from the index dimension are rejected.
    #[test]
    fn test_add_dimension_mismatch() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0]);
        assert!(index.add(doc).is_err());
    }

    /// Infinite components (either sign) are rejected on insert.
    #[test]
    fn rejects_inf_embedding() {
        let mut index = HNSWIndex::new(3, HNSWConfig::default());
        let doc = Document {
            id: "inf".to_string(),
            content: "test".to_string(),
            embedding: vec![f32::INFINITY, 0.0, 0.0],
            metadata: None,
        };
        assert!(index.add(doc).is_err());
        let doc_neg = Document {
            id: "neg_inf".to_string(),
            content: "test".to_string(),
            embedding: vec![0.0, f32::NEG_INFINITY, 0.0],
            metadata: None,
        };
        assert!(index.add(doc_neg).is_err());
    }

    /// A search context created early must stay usable after the index grows.
    #[test]
    fn search_context_resizes_after_add() {
        let mut index = HNSWIndex::new(3, HNSWConfig::default());
        index
            .add(Document {
                id: "a".into(),
                content: "a".into(),
                embedding: vec![1.0, 0.0, 0.0],
                metadata: None,
            })
            .unwrap();
        let mut ctx = index.create_search_context();
        for i in 0..10 {
            index
                .add(Document {
                    id: format!("doc-{i}"),
                    content: format!("content-{i}"),
                    embedding: vec![(i as f32) * 0.1, 1.0 - (i as f32) * 0.1, 0.0],
                    metadata: None,
                })
                .unwrap();
        }
        let results = index
            .search_with_context(&[1.0, 0.0, 0.0], 5, &mut ctx)
            .unwrap();
        assert!(!results.is_empty());
    }

    /// Searching an empty index yields an empty result set, not an error.
    #[test]
    fn test_search_empty_index() {
        let index = HNSWIndex::with_defaults(3);
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert!(results.is_empty());
    }

    /// A single indexed document is returned with a perfect score.
    #[test]
    fn test_search_single_document() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        index.add(doc).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        assert!((results[0].score - 1.0).abs() < 1e-6);
    }

    /// The aligned document ranks first among several candidates.
    #[test]
    fn test_search_multiple_documents() {
        let mut index = HNSWIndex::with_defaults(3);
        let docs = vec![
            create_test_document("doc1", vec![1.0, 0.0, 0.0]),
            create_test_document("doc2", vec![0.0, 1.0, 0.0]),
            create_test_document("doc3", vec![0.0, 0.0, 1.0]),
            create_test_document("doc4", vec![1.0, 1.0, 0.0]),
        ];
        for doc in docs {
            index.add(doc).unwrap();
        }
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");
        assert!(results[0].score > 0.9);
    }

    /// Querying with the exact stored embedding scores ~1.0.
    #[test]
    fn test_search_exact_match() {
        let mut index = HNSWIndex::with_defaults(3);
        let embedding = vec![0.5, 0.5, 0.7072];
        let doc = create_test_document("doc1", embedding.clone());
        index.add(doc).unwrap();
        let results = index.search(&embedding, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    /// `clear` removes every document and resets the empty flag.
    #[test]
    fn test_clear() {
        let mut index = HNSWIndex::with_defaults(3);
        for i in 0..5 {
            let doc = create_test_document(&format!("doc{}", i), vec![i as f32, 0.0, 0.0]);
            index.add(doc).unwrap();
        }
        assert_eq!(index.len(), 5);
        index.clear();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    /// 100 random vectors: results come back sorted by descending score.
    #[test]
    fn test_random_dataset_100_vectors() {
        let dim = 128;
        let mut index = HNSWIndex::with_defaults(dim);
        for i in 0..100 {
            let embedding = generate_random_vector(dim, i);
            let doc = create_test_document(&format!("doc{}", i), embedding);
            index.add(doc).unwrap();
        }
        assert_eq!(index.len(), 100);
        let query = generate_random_vector(dim, 9999);
        let results = index.search(&query, 10).unwrap();
        assert_eq!(results.len(), 10);
        for i in 0..results.len() - 1 {
            assert!(results[i].score >= results[i + 1].score);
        }
    }

    /// 1000 random vectors, several queries: ordering holds and scores stay
    /// inside the cosine range [-1, 1].
    #[test]
    fn test_random_dataset_1000_vectors() {
        let dim = 64;
        let mut index = HNSWIndex::with_defaults(dim);
        for i in 0..1000 {
            let embedding = generate_random_vector(dim, i);
            let doc = create_test_document(&format!("doc{}", i), embedding);
            index.add(doc).unwrap();
        }
        assert_eq!(index.len(), 1000);
        for seed in [111, 222, 333, 444, 555] {
            let query = generate_random_vector(dim, seed);
            let results = index.search(&query, 20).unwrap();
            assert_eq!(results.len(), 20);
            for i in 0..results.len() - 1 {
                assert!(results[i].score >= results[i + 1].score);
            }
            for result in &results {
                assert!(result.score >= -1.0 && result.score <= 1.0);
            }
        }
    }

    /// Docs 0..9 are blended toward the query; recall@10 must be >= 7.
    #[test]
    fn test_recall_with_known_neighbors() {
        let dim = 32;
        let mut index = HNSWIndex::with_defaults(dim);
        let query = generate_random_vector(dim, 0);
        for i in 0..100 {
            let mut embedding = generate_random_vector(dim, i + 1);
            if i < 10 {
                for j in 0..dim {
                    embedding[j] = query[j] * 0.9 + embedding[j] * 0.1;
                }
            }
            let doc = create_test_document(&format!("doc{}", i), embedding);
            index.add(doc).unwrap();
        }
        let results = index.search(&query, 10).unwrap();
        let mut recall_count = 0;
        for result in &results {
            let doc_num: usize = result.id.strip_prefix("doc").unwrap().parse().unwrap();
            if doc_num < 10 {
                recall_count += 1;
            }
        }
        assert!(recall_count >= 7, "Recall too low: {}/10", recall_count);
    }

    /// Queries with the wrong dimensionality are rejected with an error.
    #[test]
    fn test_search_dimension_mismatch() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        index.add(doc).unwrap();
        let query = vec![1.0, 0.0];
        assert!(index.search(&query, 1).is_err());
    }

    /// JSON metadata attached to a document survives the round trip.
    #[test]
    fn test_metadata_preservation() {
        let mut index = HNSWIndex::with_defaults(3);
        let mut doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        doc.metadata = Some(serde_json::json!({"category": "test", "priority": 5}));
        index.add(doc).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].metadata.is_some());
        let metadata = results[0].metadata.as_ref().unwrap();
        assert_eq!(metadata["category"], "test");
        assert_eq!(metadata["priority"], 5);
    }

    /// A NaN query must not panic, whatever result it produces.
    #[test]
    fn test_search_with_nan_query_does_not_panic() {
        let mut index = HNSWIndex::with_defaults(3);
        index
            .add(create_test_document("doc1", vec![1.0, 0.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("doc2", vec![0.0, 1.0, 0.0]))
            .unwrap();
        let query = vec![f32::NAN, 0.0, 0.0];
        let outcome = std::panic::catch_unwind(|| index.search(&query, 2));
        assert!(outcome.is_ok(), "search panicked when query contains NaN");
    }

    /// `build` panics when the input embeddings disagree on dimension.
    #[test]
    #[should_panic(expected = "All embeddings must have the same dimension")]
    fn test_build_rejects_mismatched_dimensions() {
        let _ = HNSWIndex::build(
            vec![vec![1.0, 0.0, 0.0], vec![1.0, 0.0]],
            HNSWConfig::default(),
        );
    }

    /// `add` rejects NaN components and mentions NaN in the error text.
    #[test]
    fn test_add_rejects_nan_embedding() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("nan_doc", vec![1.0, f32::NAN, 0.0]);
        let result = index.add(doc);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("NaN"),
            "Error should mention NaN, got: {}",
            err_msg
        );
    }

    /// `add_embedding` applies the same NaN validation as `add`.
    #[test]
    fn test_add_embedding_rejects_nan() {
        let mut index = HNSWIndex::with_defaults(3);
        let result = index.add_embedding("nan_vec".into(), vec![f32::NAN, 0.0, 0.0]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("NaN"),
            "Error should mention NaN, got: {}",
            err_msg
        );
    }

    /// An all-NaN embedding is rejected too.
    #[test]
    fn test_add_rejects_all_nan_embedding() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("all_nan", vec![f32::NAN, f32::NAN, f32::NAN]);
        assert!(index.add(doc).is_err());
    }

    /// Zero vectors are valid documents and never outrank a true match.
    #[test]
    fn test_zero_vector_accepted_and_searchable() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc_zero = create_test_document("zero", vec![0.0, 0.0, 0.0]);
        assert!(index.add(doc_zero).is_ok());
        let doc_normal = create_test_document("normal", vec![1.0, 0.0, 0.0]);
        assert!(index.add(doc_normal).is_ok());
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "normal");
    }

    /// A zero-vector query is handled without panicking.
    #[test]
    fn test_zero_vector_query_does_not_panic() {
        let mut index = HNSWIndex::with_defaults(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        index.add(doc).unwrap();
        let query = vec![0.0, 0.0, 0.0];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
    }

    /// Hammer `random_level` to make sure it never panics.
    #[test]
    fn test_random_level_never_panics() {
        let index = HNSWIndex::with_defaults(3);
        for _ in 0..10_000 {
            let _level = index.random_level();
        }
    }
}