/// Dot product of two equal-length vectors, clamped to `[-1.0, 1.0]`.
///
/// This equals cosine similarity only when both inputs are already L2-normalized; the
/// index code in this file normalizes vectors before storing or querying them.
/// Returns `0.0` when the slices differ in length or are empty.
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum();
    // Rounding can push the dot product of two unit vectors slightly outside [-1, 1].
    dot.clamp(-1.0, 1.0)
}
/// Scales `v` in place to unit Euclidean length.
///
/// A vector whose norm is not greater than `1e-10` is left untouched. This covers all-zero
/// vectors, which avoids dividing by (nearly) zero, and vectors whose norm is NaN.
#[inline]
fn l2_normalize(v: &mut [f32]) {
    const MIN_NORM: f32 = 1e-10;
    let squared: f32 = v.iter().map(|&c| c * c).sum();
    let norm = squared.sqrt();
    if norm > MIN_NORM {
        v.iter_mut().for_each(|c| *c /= norm);
    }
}
/// One vertex of the small-world graph held by `NswIndex`.
struct NswNode {
    /// Positions (in the owning index's node list) of the vertices this one links to.
    neighbors: Vec<usize>,
    /// The stored vector, as prepared by the index on insertion.
    vector: Vec<f32>,
    /// Caller-supplied identifier reported back in search results.
    id: usize,
}
/// Tuning parameters for [`NswIndex`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NswConfig {
    /// Maximum number of neighbour links kept for each node.
    pub max_connections: usize,
    /// Beam width (number of candidate results tracked) used by searches.
    pub ef_search: usize,
    /// Beam width (number of candidate results tracked) used while inserting.
    pub ef_construct: usize,
}

impl Default for NswConfig {
    /// `max_connections = 16`, `ef_search = 64`, `ef_construct = 32`.
    fn default() -> Self {
        Self {
            max_connections: 16,
            ef_search: 64,
            ef_construct: 32,
        }
    }
}
/// A single hit returned by [`NswIndex::search`].
#[derive(Debug, Clone, PartialEq)]
pub struct NswSearchResult {
    /// The caller-supplied id the matching vector was inserted with.
    pub id: usize,
    /// Similarity between the query and the matching vector (higher is closer).
    pub score: f32,
}
/// Navigable small-world graph over L2-normalized vectors, searched by cosine similarity.
pub struct NswIndex {
    /// Length every stored vector and every query is resized to.
    dim: usize,
    /// Construction and search parameters.
    config: NswConfig,
    /// Graph vertices; neighbour lists refer to positions in this vector.
    nodes: Vec<NswNode>,
    /// Counter used to rotate the entry node chosen for each insertion.
    entry_counter: usize,
}
impl NswIndex {
    /// Creates an empty index for vectors of dimension `dim`.
    ///
    /// Vectors passed to [`NswIndex::insert`] and queries passed to [`NswIndex::search`] are
    /// zero-padded or truncated to `dim` components before use.
    pub fn new(dim: usize, config: NswConfig) -> Self {
        Self {
            nodes: Vec::new(),
            config,
            dim,
            entry_counter: 0,
        }
    }

    /// Adds `vector` to the graph under the caller-chosen `id`.
    ///
    /// The vector is resized to [`NswIndex::dim`] (zero-padded or truncated) and
    /// L2-normalized. A beam search of width `ef_construct` collects candidate neighbours.
    /// The new node links to the `max_connections` most similar of them. Each of those
    /// neighbours gets a back-link and, if it now has too many links, is pruned back to its
    /// `max_connections` most similar neighbours.
    ///
    /// `id` is not checked for uniqueness: inserting the same `id` twice stores two nodes.
    pub fn insert(&mut self, id: usize, vector: Vec<f32>) {
        let mut v = vector;
        v.resize(self.dim, 0.0);
        l2_normalize(&mut v);

        let new_idx = self.nodes.len();
        if self.nodes.is_empty() {
            self.nodes.push(NswNode {
                id,
                vector: v,
                neighbors: Vec::new(),
            });
            self.entry_counter = 0;
            return;
        }

        // Rotate the construction entry point across the existing nodes.
        let entry = self.entry_counter % self.nodes.len();
        self.entry_counter = self.entry_counter.wrapping_add(1);

        let max_conn = self.config.max_connections;
        let mut candidates = self.greedy_search(&v, entry, self.config.ef_construct);
        // `greedy_search` returns candidates in arbitrary order. Rank them so the new node is
        // linked to its closest neighbours rather than to an arbitrary subset of the beam.
        Self::sort_by_score_desc(&mut candidates);
        let neighbor_indices: Vec<usize> = candidates
            .iter()
            .take(max_conn)
            .map(|&(node_idx, _)| node_idx)
            .collect();

        self.nodes.push(NswNode {
            id,
            vector: v,
            neighbors: neighbor_indices.clone(),
        });
        for nb_idx in neighbor_indices {
            self.nodes[nb_idx].neighbors.push(new_idx);
            if self.nodes[nb_idx].neighbors.len() > max_conn {
                self.prune_neighbors(nb_idx, max_conn);
            }
        }
    }

    /// Returns up to `top_k` stored vectors most similar to `query`, best first.
    ///
    /// The query is resized to [`NswIndex::dim`] and L2-normalized. A beam search of width
    /// `max(ef_search, top_k)` then starts from the first inserted node. Results are
    /// approximate: nodes the beam search does not reach are not returned.
    ///
    /// Returns an empty vector when the index is empty or `top_k == 0`.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<NswSearchResult> {
        if self.nodes.is_empty() || top_k == 0 {
            return Vec::new();
        }
        let mut q = query.to_vec();
        q.resize(self.dim, 0.0);
        l2_normalize(&mut q);

        // A beam narrower than `top_k` could never yield `top_k` results.
        let ef = self.config.ef_search.max(top_k);
        let mut candidates = self.greedy_search(&q, 0, ef);
        Self::sort_by_score_desc(&mut candidates);
        candidates.truncate(top_k);
        candidates
            .into_iter()
            .map(|(node_idx, score)| NswSearchResult {
                id: self.nodes[node_idx].id,
                score,
            })
            .collect()
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when no vector has been inserted.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Length every stored vector and query is resized to.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Best-first beam search starting from node `entry`.
    ///
    /// Keeps the `ef` (at least 1) most similar nodes seen so far. Each step expands the
    /// most promising unexpanded candidate. The search stops once the beam is full and that
    /// candidate scores below the worst kept result.
    ///
    /// Returns `(node index, score)` pairs in no particular order.
    fn greedy_search(&self, query: &[f32], entry: usize, ef: usize) -> Vec<(usize, f32)> {
        use std::cmp::{Ordering, Reverse};
        use std::collections::{BinaryHeap, HashSet};

        // `(score, node index)` ordered by score, ties broken by index. `total_cmp` makes
        // this a genuine total order, as `BinaryHeap` requires.
        struct Scored(f32, usize);
        impl PartialEq for Scored {
            fn eq(&self, other: &Self) -> bool {
                self.cmp(other) == Ordering::Equal
            }
        }
        impl Eq for Scored {}
        impl PartialOrd for Scored {
            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                Some(self.cmp(other))
            }
        }
        impl Ord for Scored {
            fn cmp(&self, other: &Self) -> Ordering {
                self.0
                    .total_cmp(&other.0)
                    .then_with(|| self.1.cmp(&other.1))
            }
        }

        // Score of the least similar kept result.
        fn worst(results: &BinaryHeap<Reverse<Scored>>) -> f32 {
            results.peek().map_or(f32::NEG_INFINITY, |Reverse(s)| s.0)
        }

        if self.nodes.is_empty() {
            return Vec::new();
        }
        let ef = ef.max(1);

        let entry_score = Self::similarity(query, &self.nodes[entry].vector);
        let mut visited: HashSet<usize> = HashSet::new();
        visited.insert(entry);
        // Max-heap: the most similar unexpanded candidate is on top.
        let mut frontier: BinaryHeap<Scored> = BinaryHeap::new();
        frontier.push(Scored(entry_score, entry));
        // Min-heap (via `Reverse`): the least similar kept result is on top.
        let mut results: BinaryHeap<Reverse<Scored>> = BinaryHeap::new();
        results.push(Reverse(Scored(entry_score, entry)));

        while let Some(Scored(score, node_idx)) = frontier.pop() {
            // The popped candidate is the best one left. If the beam is full and even this
            // candidate is worse than the worst kept result, no further expansion can help.
            if results.len() >= ef && score < worst(&results) {
                break;
            }
            for &nb_idx in &self.nodes[node_idx].neighbors {
                if !visited.insert(nb_idx) {
                    continue;
                }
                let nb_score = Self::similarity(query, &self.nodes[nb_idx].vector);
                if results.len() < ef || nb_score > worst(&results) {
                    frontier.push(Scored(nb_score, nb_idx));
                    results.push(Reverse(Scored(nb_score, nb_idx)));
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }

        results
            .into_iter()
            .map(|Reverse(Scored(score, node_idx))| (node_idx, score))
            .collect()
    }

    /// Similarity used throughout the graph: `cosine_sim` with NaN mapped to `-inf`.
    ///
    /// NaN can come from NaN or infinite input components. Mapping it to `-inf` ranks such
    /// vectors last and keeps every score comparison a total order.
    fn similarity(a: &[f32], b: &[f32]) -> f32 {
        let score = cosine_sim(a, b);
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }

    /// Sorts `(node index, score)` pairs by descending score.
    fn sort_by_score_desc(items: &mut [(usize, f32)]) {
        items.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    }

    /// Keeps only the `max_conn` neighbours of `node_idx` that are most similar to it.
    ///
    /// Only the outgoing list of `node_idx` is trimmed. A dropped neighbour keeps any link it
    /// has back to `node_idx`.
    fn prune_neighbors(&mut self, node_idx: usize, max_conn: usize) {
        let node = &self.nodes[node_idx];
        let mut scored: Vec<(usize, f32)> = node
            .neighbors
            .iter()
            .map(|&nb| (nb, Self::similarity(&node.vector, &self.nodes[nb].vector)))
            .collect();
        Self::sort_by_score_desc(&mut scored);
        scored.truncate(max_conn);
        self.nodes[node_idx].neighbors = scored.into_iter().map(|(nb, _)| nb).collect();
    }
}
/// An [`NswIndex`] that also stores a caller-supplied metadata value for each vector.
///
/// The type itself places no bound on `T`; bounds belong on the `impl` blocks that need them.
pub struct EmbeddingIndex<T> {
    /// The underlying similarity graph.
    graph: NswIndex,
    /// `(id, metadata)` pairs in insertion order.
    metadata: Vec<(usize, T)>,
    /// Id that will be handed to the next inserted vector.
    next_id: usize,
}
impl<T: Clone> EmbeddingIndex<T> {
    /// Creates an empty index for `dim`-dimensional embeddings using [`NswConfig::default`].
    pub fn new(dim: usize) -> Self {
        Self::new_with_config(dim, NswConfig::default())
    }

    /// Creates an empty index for `dim`-dimensional embeddings with the given graph tuning.
    pub fn new_with_config(dim: usize, config: NswConfig) -> Self {
        Self {
            graph: NswIndex::new(dim, config),
            metadata: Vec::new(),
            next_id: 0,
        }
    }

    /// Stores `vector` together with `meta` and returns the id assigned to it.
    ///
    /// Ids are assigned sequentially, starting at 0.
    pub fn insert(&mut self, vector: Vec<f32>, meta: T) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        self.graph.insert(id, vector);
        self.metadata.push((id, meta));
        id
    }

    /// Returns up to `top_k` hits for `query`, in the order produced by [`NswIndex::search`].
    ///
    /// Each hit comes with a reference to its metadata. Hits whose id has no recorded
    /// metadata are skipped.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(NswSearchResult, &T)> {
        self.graph
            .search(query, top_k)
            .into_iter()
            .filter_map(|hit| {
                let meta = self.meta_for(hit.id)?;
                Some((hit, meta))
            })
            .collect()
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.graph.len()
    }

    /// `true` when nothing has been inserted.
    pub fn is_empty(&self) -> bool {
        self.graph.is_empty()
    }

    /// Looks up the metadata recorded for `id`.
    ///
    /// `insert` hands out ids 0, 1, 2, … and pushes metadata in that same order, so `id` is
    /// normally also its position in `metadata`. That O(1) slot is checked first. The linear
    /// scan is only a fallback, e.g. if an earlier `insert` panicked after taking an id.
    fn meta_for(&self, id: usize) -> Option<&T> {
        match self.metadata.get(id) {
            Some((stored_id, meta)) if *stored_id == id => Some(meta),
            _ => self
                .metadata
                .iter()
                .find(|(stored_id, _)| *stored_id == id)
                .map(|(_, meta)| meta),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `values` scaled to unit length.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let mut v = values.to_vec();
        l2_normalize(&mut v);
        v
    }

    #[test]
    fn test_nsw_index_empty() {
        let idx = NswIndex::new(4, NswConfig::default());
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.dim(), 4);
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_nsw_index_single_insert() {
        let mut idx = NswIndex::new(4, NswConfig::default());
        idx.insert(0, vec![1.0, 0.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
        assert!(!idx.is_empty());
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_search_exact() {
        let mut idx = NswIndex::new(3, NswConfig::default());
        let v = unit_vec(&[1.0, 2.0, 3.0]);
        idx.insert(42, v.clone());
        let results = idx.search(&v, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 42);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_search_nearest() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        idx.insert(0, unit_vec(&[1.0, 0.0]));
        idx.insert(1, unit_vec(&[0.0, 1.0]));
        idx.insert(2, unit_vec(&[-1.0, 0.0]));
        let query = unit_vec(&[0.1, 0.9]);
        let results = idx.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].id, 1,
            "nearest should be y-axis vector, got id={}",
            results[0].id
        );
    }

    #[test]
    fn test_nsw_index_many_vectors() {
        let dim = 8;
        let config = NswConfig {
            max_connections: 8,
            ef_search: 32,
            ef_construct: 16,
        };
        let mut idx = NswIndex::new(dim, config);
        for i in 0..100usize {
            // Deterministic pseudo-random components from a 64-bit mix of (i, d).
            let mut v: Vec<f32> = (0..dim)
                .map(|d| {
                    let x = (i as u64)
                        .wrapping_mul(6364136223846793005u64)
                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407u64));
                    let x = x ^ (x >> 33);
                    let x = x.wrapping_mul(0xff51afd7ed558ccdu64);
                    let x = x ^ (x >> 33);
                    (x as i64) as f32 / i64::MAX as f32
                })
                .collect();
            l2_normalize(&mut v);
            idx.insert(i, v);
        }
        assert_eq!(idx.len(), 100);
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let results = idx.search(&query, 5);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        for w in results.windows(2) {
            assert!(
                w[0].score >= w[1].score - 1e-5,
                "scores not sorted: {} < {}",
                w[0].score,
                w[1].score
            );
        }
    }

    #[test]
    fn test_nsw_index_top_k_zero() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        idx.insert(0, vec![1.0, 0.0]);
        assert!(idx.search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_nsw_index_resizes_inputs_to_dim() {
        let mut idx = NswIndex::new(3, NswConfig::default());
        // A short vector is zero-padded; a long query is truncated to `dim`.
        idx.insert(7, vec![2.0]);
        let results = idx.search(&[1.0, 0.0, 0.0, 5.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 7);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_recovers_each_basis_vector() {
        let dim = 10;
        let mut idx = NswIndex::new(dim, NswConfig::default());
        for i in 0..dim {
            let mut v = vec![0.0f32; dim];
            v[i] = 1.0;
            idx.insert(i, v);
        }
        for i in 0..dim {
            let mut q = vec![0.0f32; dim];
            q[i] = 1.0;
            let results = idx.search(&q, 1);
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].id, i, "basis vector {i} not recovered");
        }
    }

    #[test]
    fn test_embedding_index_insert_and_search() {
        let mut idx: EmbeddingIndex<u32> = EmbeddingIndex::new(4);
        idx.insert(unit_vec(&[1.0, 0.0, 0.0, 0.0]), 100);
        idx.insert(unit_vec(&[0.0, 1.0, 0.0, 0.0]), 200);
        idx.insert(unit_vec(&[0.0, 0.0, 1.0, 0.0]), 300);
        let results = idx.search(&unit_vec(&[1.0, 0.0, 0.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(*results[0].1, 100u32);
    }

    #[test]
    fn test_embedding_index_metadata_returned() {
        let mut idx: EmbeddingIndex<String> = EmbeddingIndex::new(3);
        let id = idx.insert(unit_vec(&[1.0, 1.0, 0.0]), "hello world".to_string());
        assert_eq!(id, 0);
        let results = idx.search(&unit_vec(&[1.0, 1.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, &"hello world".to_string());
        assert!((results[0].0.score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_embedding_index_assigns_sequential_ids() {
        let mut idx: EmbeddingIndex<&str> = EmbeddingIndex::new(2);
        assert!(idx.is_empty());
        assert!(idx.search(&[1.0, 0.0], 3).is_empty());
        assert_eq!(idx.insert(vec![1.0, 0.0], "a"), 0);
        assert_eq!(idx.insert(vec![0.0, 1.0], "b"), 1);
        assert_eq!(idx.len(), 2);
        let results = idx.search(&[0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.id, 1);
        assert_eq!(*results[0].1, "b");
    }

    #[test]
    fn test_nsw_config_defaults() {
        let cfg = NswConfig::default();
        assert_eq!(cfg.max_connections, 16);
        assert_eq!(cfg.ef_search, 64);
        assert_eq!(cfg.ef_construct, 32);
    }
}