pub mod metric;
mod node;
use metric::Metric;
use ndarray::Array1;
use node::HNSWNode;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::HashSet;
/// HNSW-style nearest-neighbour graph over `Array1<f32>` embeddings.
///
/// Nodes are stored in `nodes` and are referenced throughout the index by
/// their position in that vector.
pub struct HNSW {
    // --- configuration, copied from `HNSWInitializer` by `HNSW::new` ---
    /// Distance function used to compare embeddings.
    metric: Metric,
    /// Neighbour count requested when linking a newly inserted node.
    m: usize,
    m_max: usize,
    ef_construction: usize,
    max_level: usize,
    /// Multiplier applied to `-ln(U)` when drawing a node's random level.
    norm: f32,

    // --- graph state ---
    /// All inserted nodes, in insertion order.
    nodes: Vec<HNSWNode>,
    /// Position in `nodes` used as the starting point for searches, once set.
    entry: Option<usize>,
    /// Current top level tracked by the index.
    level: usize,
}
/// Configuration consumed by `HNSW::new`.
///
/// Derives `Default`, so callers can set just the fields they care about and
/// fill in the rest with `..Default::default()`.
#[derive(Default)]
pub struct HNSWInitializer {
    /// Distance function used to compare embeddings.
    pub metric: Metric,
    /// Neighbour count requested when linking a newly inserted node.
    pub m: usize,
    pub m_max: usize,
    pub ef_construction: usize,
    pub max_level: usize,
    /// Multiplier applied to `-ln(U)` when drawing a node's random level.
    pub norm: f32,
    /// NOTE(review): a freshly constructed `HNSW` holds no nodes, so a `Some`
    /// value here cannot refer to a stored node.
    pub entry: Option<usize>,
}
impl HNSW {
pub fn len(&self) -> usize {
self.nodes.len()
}
pub(crate) fn new(hnsw_init: HNSWInitializer) -> HNSW {
HNSW {
nodes: Vec::new(),
max_level: hnsw_init.max_level,
ef_construction: hnsw_init.ef_construction,
m: hnsw_init.m,
m_max: hnsw_init.m_max,
norm: hnsw_init.norm,
entry: hnsw_init.entry,
level: 0,
metric: hnsw_init.metric,
}
}
fn random_level(&self) -> usize {
let mut rng = rand::thread_rng();
let random: f32 = rng.gen_range(0.0..1.0);
(-random.ln() * self.norm).floor() as usize
}
fn set_entry(&mut self, entry: usize) {
self.entry = Some(entry);
}
fn get_entry(&self) -> usize {
self.entry.unwrap()
}
pub async fn insert(&mut self, embedding: &Array1<f32>, id: usize) {
let mut node = HNSWNode::new(embedding.clone(), id);
let node_level = self.random_level();
node.level = node_level;
if self.entry.is_none() {
self.set_entry(node.id);
self.nodes.push(node);
return;
}
let mut closest_node = self.get_entry();
for level in (node_level..=self.level).rev() {
closest_node = self.greedy_search(node.clone(), closest_node, level).await;
}
let node_id = self.nodes.len();
self.nodes.push(node);
for level in 0..=node_level {
let neighbors = self
.get_neighbors(self.nodes[node_id].clone(), closest_node, level)
.await;
self.connect_neighbors(node_id, neighbors.clone(), level)
.await;
}
if node_level > self.level {
self.level = node_level;
self.set_entry(node_id);
}
}
async fn greedy_search(&self, target: HNSWNode, entry_point: usize, level: usize) -> usize {
let mut closest_node = entry_point;
let mut closest_dist = self.distance(&self.nodes[entry_point].embedding, &target.embedding);
loop {
let mut improved = false;
for &neighbor in &self.nodes[closest_node].neighbors {
let dist = self.distance(&self.nodes[neighbor].embedding, &target.embedding);
if dist < closest_dist {
closest_node = neighbor;
closest_dist = dist;
improved = true;
}
}
if !improved {
break;
}
}
closest_node
}
async fn get_neighbors(
&self,
target: HNSWNode,
entry_point: usize,
level: usize,
) -> Vec<usize> {
let mut candidates = vec![entry_point];
let mut neighbors = Vec::new();
while candidates.len() > 0 && neighbors.len() < self.m {
let candidate = candidates.pop().unwrap();
let dist = self.distance(&self.nodes[candidate].embedding, &target.embedding);
neighbors.push(candidate);
for &neighbor in &self.nodes[candidate].neighbors {
candidates.push(neighbor);
}
}
neighbors
}
async fn connect_neighbors(&mut self, node_id: usize, neighbors: Vec<usize>, level: usize) {
if node_id >= self.nodes.len() {
panic!("Node ID is out of bounds");
}
for &neighbor_id in &neighbors {
if neighbor_id >= self.nodes.len() {
panic!("Neighbor ID is out of bounds");
}
self.nodes[neighbor_id].neighbors.push(node_id);
self.nodes[node_id].neighbors.push(neighbor_id);
}
}
pub async fn search(&self, query_emb: Array1<f32>, id: usize, k: usize) -> Vec<usize> {
let node = HNSWNode::new(query_emb.clone(), id);
if self.entry.is_none() {
return Vec::new();
}
let mut closest_node = self.get_entry();
for level in (0..=self.level).rev() {
closest_node = self.greedy_search(node.clone(), closest_node, level).await;
}
let neighbors = self.get_neighbors(node, closest_node, 0).await;
neighbors
}
fn distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
self.metric.distance(emb1, emb2)
}
}