use crate::simd;
use crate::store::types::{Distance, Id};
use anyhow::Result;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A single indexed vector together with its per-layer adjacency sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Node {
    /// Identifier under which this node is stored in the index's node map.
    id: Id,
    /// The stored vector; `insert` only accepts vectors of the index dimension.
    vector: Vec<f32>,
    /// Highest layer this node participates in (drawn randomly at insertion).
    layer: usize,
    /// `neighbors[l]` holds the ids this node links to on layer `l`; one set per layer
    /// `0..=layer`. Links are directed, and readers tolerate ids that are missing from the
    /// node map.
    neighbors: Vec<HashSet<Id>>,
}
/// A node id paired with its distance to the current query.
///
/// The ordering is *reversed* on distance, so a plain `BinaryHeap<Candidate>` pops the
/// nearest candidate first. NaN distances are ordered as farther than every real distance.
/// This keeps the ordering total, as `Ord` requires, and pushes degenerate results to the back
/// instead of silently corrupting heap invariants.
#[derive(Clone)]
struct Candidate {
    id: Id,
    distance: f32,
}

impl Candidate {
    /// Ascending distance order in which all NaNs compare equal to each other and greater
    /// than any non-NaN value (regardless of the NaN's sign bit).
    fn distance_order(a: f32, b: f32) -> Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that equality is reflexive even for NaN, as `Eq` demands.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arguments swapped: a smaller distance makes the "greater" candidate.
        Self::distance_order(other.distance, self.distance)
    }
}
/// In-memory Hierarchical Navigable Small World (HNSW) index.
///
/// Nodes live in a flat map keyed by id, and graph links are stored per node, per layer.
/// The struct derives `Serialize`/`Deserialize`, so its field names and order are part of
/// its serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmHnsw {
    /// Every indexed vector, keyed by its id.
    nodes: HashMap<Id, Node>,
    /// Node where every search starts; `None` only while the index is empty.
    entry_point: Option<Id>,
    /// Highest node layer in the graph; updated together with `entry_point`.
    max_layer: usize,
    /// Required length of every inserted vector and every query.
    dimension: usize,
    /// Distance function used both for construction and for search.
    metric: Distance,
    /// Neighbour budget per node on layers above 0.
    m: usize,
    /// Neighbour budget per node on layer 0 (`2 * m` when built via `with_params`).
    m_max: usize,
    /// Candidate-list size (beam width) used while inserting.
    ef_construction: usize,
    /// Multiplier applied to `-ln(u)` when drawing a node's random top layer.
    ml: f32,
}
impl WasmHnsw {
    /// Creates an empty index for `dimension`-length vectors using cosine distance,
    /// `m = 16` and `ef_construction = 200`.
    pub fn new(dimension: usize) -> Self {
        Self::with_params(dimension, Distance::Cosine, 16, 200)
    }

    /// Creates an empty index with an explicit metric, per-layer link budget `m`
    /// (layer 0 allows `2 * m`) and insertion beam width `ef_construction`.
    pub fn with_params(
        dimension: usize,
        metric: Distance,
        m: usize,
        ef_construction: usize,
    ) -> Self {
        Self {
            nodes: HashMap::new(),
            entry_point: None,
            max_layer: 0,
            dimension,
            metric,
            m,
            m_max: m * 2,
            ef_construction,
            // NOTE(review): the HNSW paper recommends 1/ln(m). With 1/ln(2), each layer holds
            // about half the nodes of the one below, which builds a taller hierarchy.
            ml: 1.0 / (2.0_f32).ln(),
        }
    }

    /// Inserts `vector` under `id` and links it into every layer up to a randomly drawn level.
    ///
    /// Re-inserting an existing `id` replaces the previous vector and all of its links.
    ///
    /// # Errors
    /// Returns an error if `vector.len()` differs from the index dimension.
    pub fn insert(&mut self, id: Id, vector: Vec<f32>) -> Result<()> {
        if vector.len() != self.dimension {
            anyhow::bail!(
                "Vector dimension mismatch: expected {}, got {}",
                self.dimension,
                vector.len()
            );
        }
        // Detach any previous node with this id first. Otherwise, stale links that point at
        // `id` lead the search straight back to the new node and create self-loops.
        if self.nodes.contains_key(&id) {
            self.remove(&id)?;
        }

        let layer = self.random_layer();
        let new_node = Node {
            id: id.clone(),
            vector: vector.clone(),
            layer,
            neighbors: (0..=layer).map(|_| HashSet::new()).collect(),
        };

        let ep = match self.entry_point.clone() {
            Some(ep) => ep,
            None => {
                // First node: it becomes the entry point and has nothing to link to.
                self.entry_point = Some(id.clone());
                self.max_layer = layer;
                self.nodes.insert(id, new_node);
                return Ok(());
            }
        };
        self.nodes.insert(id.clone(), new_node);

        let ep_dist = self.compute_distance(&vector, &self.nodes[&ep].vector);
        let mut nearest = vec![Candidate {
            id: ep,
            distance: ep_dist,
        }];

        // Greedy descent (beam width 1) through the layers above the new node's level.
        for lc in (layer + 1..=self.max_layer).rev() {
            nearest = self.search_layer(&vector, &nearest, 1, lc);
        }

        // Link only on layers that already exist. Above the current `max_layer` there are no
        // other nodes, so connecting there would only create one-way links to nodes that do
        // not participate in that layer.
        for lc in (0..=layer.min(self.max_layer)).rev() {
            let candidates = self.search_layer(&vector, &nearest, self.ef_construction, lc);
            let max_conn = if lc == 0 { self.m_max } else { self.m };
            let neighbors_to_add = self.select_neighbors(&vector, &candidates, max_conn);
            for neighbor_id in &neighbors_to_add {
                if *neighbor_id == id {
                    continue;
                }
                if let Some(links) = self
                    .nodes
                    .get_mut(&id)
                    .and_then(|node| node.neighbors.get_mut(lc))
                {
                    links.insert(neighbor_id.clone());
                }
                let mut needs_prune = false;
                if let Some(links) = self
                    .nodes
                    .get_mut(neighbor_id)
                    .and_then(|node| node.neighbors.get_mut(lc))
                {
                    links.insert(id.clone());
                    needs_prune = links.len() > max_conn;
                }
                if needs_prune {
                    self.prune_connections(neighbor_id, lc, max_conn);
                }
            }
            nearest = candidates;
        }

        if layer > self.max_layer {
            self.max_layer = layer;
            self.entry_point = Some(id);
        }
        Ok(())
    }

    /// Returns up to `k` `(id, distance)` pairs nearest to `query`, sorted by ascending
    /// distance. `ef_search` is the layer-0 beam width (raised to at least `k`).
    ///
    /// # Errors
    /// Returns an error if `query.len()` differs from the index dimension.
    pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Result<Vec<(Id, f32)>> {
        if query.len() != self.dimension {
            anyhow::bail!(
                "Query dimension mismatch: expected {}, got {}",
                self.dimension,
                query.len()
            );
        }
        let ep = match &self.entry_point {
            Some(ep) => ep.clone(),
            None => return Ok(Vec::new()),
        };
        let ep_dist = self.compute_distance(query, &self.nodes[&ep].vector);
        let mut nearest = vec![Candidate {
            id: ep,
            distance: ep_dist,
        }];
        for lc in (1..=self.max_layer).rev() {
            nearest = self.search_layer(query, &nearest, 1, lc);
        }
        let candidates = self.search_layer(query, &nearest, ef_search.max(k), 0);
        Ok(candidates
            .into_iter()
            .take(k)
            .map(|c| (c.id, c.distance))
            .collect())
    }

    /// Removes `id` and every link that points at it. Unknown ids are ignored.
    ///
    /// Links are directed: after pruning, another node can still point at `id` even though
    /// `id` no longer points back. Every node's link sets are therefore scanned, which costs
    /// time proportional to the total number of per-layer link sets.
    pub fn remove(&mut self, id: &str) -> Result<()> {
        if self.nodes.remove(id).is_none() {
            return Ok(());
        }
        for node in self.nodes.values_mut() {
            for links in node.neighbors.iter_mut() {
                links.remove(id);
            }
        }
        if self.entry_point.as_deref() == Some(id) {
            // Promote a node from the highest remaining layer (None once the index is empty).
            self.entry_point = self
                .nodes
                .iter()
                .max_by_key(|(_, n)| n.layer)
                .map(|(key, _)| key.clone());
            self.max_layer = self.nodes.values().map(|n| n.layer).max().unwrap_or(0);
        }
        Ok(())
    }

    /// Beam search on a single layer, starting from `entry_points`.
    ///
    /// Returns at most `num_closest` (minimum 1) candidates, sorted by ascending distance.
    fn search_layer(
        &self,
        query: &[f32],
        entry_points: &[Candidate],
        num_closest: usize,
        layer: usize,
    ) -> Vec<Candidate> {
        use std::cmp::Reverse;

        let num_closest = num_closest.max(1);
        let mut visited: HashSet<Id> = HashSet::new();
        // `Candidate` orders by reversed distance, so this heap pops the nearest first.
        let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
        // `Reverse` flips that back, so `peek`/`pop` expose the *farthest* kept result.
        let mut best: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();

        for ep in entry_points {
            if visited.insert(ep.id.clone()) {
                frontier.push(ep.clone());
                best.push(Reverse(ep.clone()));
            }
        }
        while best.len() > num_closest {
            best.pop();
        }

        while let Some(current) = frontier.pop() {
            let farthest = best.peek().map_or(f32::INFINITY, |r| r.0.distance);
            if current.distance > farthest {
                // Every remaining frontier entry is farther than the worst kept result.
                break;
            }
            let links = match self
                .nodes
                .get(&current.id)
                .and_then(|node| node.neighbors.get(layer))
            {
                Some(links) => links,
                None => continue,
            };
            for neighbor_id in links {
                if !visited.insert(neighbor_id.clone()) {
                    continue;
                }
                let neighbor = match self.nodes.get(neighbor_id) {
                    Some(neighbor) => neighbor,
                    None => continue,
                };
                let dist = self.compute_distance(query, &neighbor.vector);
                // Re-read the bound: earlier neighbours in this loop may have tightened it.
                let farthest = best.peek().map_or(f32::INFINITY, |r| r.0.distance);
                if best.len() < num_closest || dist < farthest {
                    let cand = Candidate {
                        id: neighbor_id.clone(),
                        distance: dist,
                    };
                    frontier.push(cand.clone());
                    best.push(Reverse(cand));
                    if best.len() > num_closest {
                        best.pop();
                    }
                }
            }
        }

        let mut result: Vec<Candidate> = best.into_iter().map(|Reverse(c)| c).collect();
        result.sort_by(|a, b| Self::cmp_distance(a.distance, b.distance));
        result
    }

    /// Picks the first `m` ids from `candidates`, which callers pass sorted by ascending
    /// distance (as `search_layer` returns them).
    fn select_neighbors(&self, _query: &[f32], candidates: &[Candidate], m: usize) -> Vec<Id> {
        candidates.iter().take(m).map(|c| c.id.clone()).collect()
    }

    /// Keeps only the `max_conn` nearest links of `node_id` on `layer`. Links to ids that no
    /// longer exist are dropped as a side effect.
    fn prune_connections(&mut self, node_id: &Id, layer: usize, max_conn: usize) {
        let to_keep: HashSet<Id> = {
            let node = match self.nodes.get(node_id) {
                Some(node) if layer < node.neighbors.len() => node,
                _ => return,
            };
            let mut scored: Vec<Candidate> = node.neighbors[layer]
                .iter()
                .filter_map(|neighbor_id| {
                    self.nodes.get(neighbor_id).map(|neighbor| Candidate {
                        id: neighbor_id.clone(),
                        distance: self.compute_distance(&node.vector, &neighbor.vector),
                    })
                })
                .collect();
            scored.sort_by(|a, b| Self::cmp_distance(a.distance, b.distance));
            scored.into_iter().take(max_conn).map(|c| c.id).collect()
        };
        if let Some(links) = self
            .nodes
            .get_mut(node_id)
            .and_then(|node| node.neighbors.get_mut(layer))
        {
            *links = to_keep;
        }
    }

    /// Distance between `a` and `b` under the configured metric; smaller means closer.
    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.metric {
            Distance::Cosine => 1.0 - simd::cosine_similarity_simd(a, b),
            Distance::Euclidean => simd::euclidean_distance_simd(a, b),
            // Negated so that a larger dot product ranks as closer.
            Distance::DotProduct => -simd::dot_product_simd(a, b),
            Distance::Manhattan => simd::manhattan_distance_simd(a, b),
            Distance::Hamming => simd::hamming_distance_simd(a, b),
            Distance::Jaccard => 1.0 - simd::jaccard_similarity_simd(a, b),
            Distance::Chebyshev => simd::chebyshev_distance_simd(a, b),
            Distance::Canberra => simd::canberra_distance_simd(a, b),
            Distance::BrayCurtis => simd::braycurtis_distance_simd(a, b),
        }
    }

    /// Ascending distance comparison in which every NaN sorts after all real distances.
    fn cmp_distance(a: f32, b: f32) -> Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Draws a node's top layer as `floor(-ln(u) * ml)` for uniform `u`.
    fn random_layer(&self) -> usize {
        // `gen::<f32>()` samples [0, 1). Flipping it to (0, 1] keeps `ln` finite. A zero draw
        // would yield `inf`, which saturates to `usize::MAX` in the cast below and makes
        // `insert` try to allocate that many neighbour sets.
        let u: f32 = 1.0 - rand::thread_rng().gen::<f32>();
        (-u.ln() * self.ml).floor() as usize
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when no vectors are stored.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// All stored ids, in unspecified order.
    pub fn ids(&self) -> Vec<Id> {
        self.nodes.keys().cloned().collect()
    }

    /// Removes every node and resets the entry point.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.entry_point = None;
        self.max_layer = 0;
    }

    /// Summarises node/edge counts and how many nodes reach each layer.
    pub fn stats(&self) -> HnswStats {
        let total_edges: usize = self
            .nodes
            .values()
            .flat_map(|n| &n.neighbors)
            .map(|layer| layer.len())
            .sum();
        let layer_distribution: Vec<usize> = (0..=self.max_layer)
            .map(|l| self.nodes.values().filter(|n| n.layer >= l).count())
            .collect();
        HnswStats {
            num_nodes: self.nodes.len(),
            num_edges: total_edges,
            max_layer: self.max_layer,
            layer_distribution,
            m: self.m,
            ef_construction: self.ef_construction,
        }
    }

    /// Builds a graph visualiser with one node per vector and one weighted edge per link
    /// whose target still exists. Each node carries a preview of up to its first 3 components.
    pub fn to_visualizer(&self) -> crate::graph_viz::HnswVisualizer {
        use crate::graph_viz::{GraphEdge, GraphNode, HnswVisualizer};
        let mut graph_nodes = Vec::new();
        for (id, node) in &self.nodes {
            let degree = node.neighbors.iter().map(|layer| layer.len()).sum();
            let vector_preview = if node.vector.len() >= 3 {
                Some(vec![node.vector[0], node.vector[1], node.vector[2]])
            } else if !node.vector.is_empty() {
                Some(node.vector.clone())
            } else {
                None
            };
            graph_nodes.push(GraphNode {
                id: id.clone(),
                layer: node.layer,
                degree,
                vector_preview,
            });
        }
        let mut graph_edges = Vec::new();
        for (id, node) in &self.nodes {
            for (layer_idx, neighbors_at_layer) in node.neighbors.iter().enumerate() {
                for neighbor_id in neighbors_at_layer {
                    if let Some(neighbor_node) = self.nodes.get(neighbor_id) {
                        let distance = self.compute_distance(&node.vector, &neighbor_node.vector);
                        graph_edges.push(GraphEdge {
                            source: id.clone(),
                            target: neighbor_id.clone(),
                            layer: layer_idx,
                            weight: Some(distance),
                        });
                    }
                }
            }
        }
        HnswVisualizer::new(graph_nodes, graph_edges, self.max_layer + 1)
    }
}
/// Snapshot of an index's shape, as produced by `WasmHnsw::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswStats {
    /// Number of stored vectors.
    pub num_nodes: usize,
    /// Total number of directed links, summed over every node and layer.
    pub num_edges: usize,
    /// Highest layer in the graph.
    pub max_layer: usize,
    /// `layer_distribution[l]` is the number of nodes whose top layer is at least `l`.
    pub layer_distribution: Vec<usize>,
    /// Neighbour budget on layers above 0.
    pub m: usize,
    /// Beam width used while inserting.
    pub ef_construction: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index via `WasmHnsw::new(3)` and inserts each `(id, vector)` pair in order.
    fn index_from(points: &[(&str, [f32; 3])]) -> WasmHnsw {
        let mut index = WasmHnsw::new(3);
        for (name, coords) in points {
            index.insert(String::from(*name), coords.to_vec()).unwrap();
        }
        index
    }

    #[test]
    fn test_wasm_hnsw_basic() {
        let index = index_from(&[
            ("v1", [1.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0]),
            ("v3", [0.0, 0.0, 1.0]),
            ("v4", [1.0, 1.0, 0.0]),
        ]);
        assert_eq!(index.len(), 4);

        let hits = index.search(&[1.0, 0.1, 0.0], 2, 50).unwrap();
        assert_eq!(hits.len(), 2);
        let top = &hits[0];
        assert!(top.1 < 0.3, "First result distance should be small: {}", top.1);
        assert!(
            top.0 == "v1" || top.0 == "v4",
            "First result should be v1 or v4, got: {}",
            top.0
        );
    }

    #[test]
    fn test_wasm_hnsw_remove() {
        let mut index = WasmHnsw::new(2);
        let points: [[f32; 2]; 3] = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        for (i, coords) in points.iter().enumerate() {
            index.insert(format!("v{}", i + 1), coords.to_vec()).unwrap();
        }
        assert_eq!(index.len(), 3);

        let before = index.search(&[3.0, 4.0], 1, 50).unwrap();
        assert_eq!(before[0].0, "v2");

        index.remove("v2").unwrap();
        assert_eq!(index.len(), 2);

        let after = index.search(&[3.0, 4.0], 3, 50).unwrap();
        assert!(!after.is_empty());
        assert!(after.len() <= 2);
        for (id, _) in after.iter() {
            assert_ne!(id, "v2");
        }
    }

    #[test]
    fn test_wasm_hnsw_large() {
        const DIM: usize = 128;
        let mut rng = rand::thread_rng();
        let mut random_vector = || -> Vec<f32> { (0..DIM).map(|_| rng.gen::<f32>()).collect() };

        let mut index = WasmHnsw::with_params(DIM, Distance::Cosine, 16, 200);
        for i in 0..1000 {
            index.insert(format!("v{}", i), random_vector()).unwrap();
        }
        assert_eq!(index.len(), 1000);

        let query = random_vector();
        let results = index.search(&query, 10, 50).unwrap();
        assert_eq!(results.len(), 10);
        for pair in results.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }
    }

    #[test]
    fn test_dimension_validation() {
        let mut index = WasmHnsw::new(3);
        assert!(index.insert("v1".to_string(), vec![1.0, 2.0]).is_err());
        assert!(index.insert("v1".to_string(), vec![1.0, 2.0, 3.0]).is_ok());
    }

    #[test]
    fn test_visualization() {
        let index = index_from(&[
            ("v1", [0.1, 0.1, 0.1]),
            ("v2", [0.2, 0.1, 0.1]),
            ("v3", [0.8, 0.9, 0.8]),
            ("v4", [0.5, 0.5, 0.5]),
        ]);
        let viz = index.to_visualizer();
        assert_eq!(viz.node_count(), 4);
        assert!(viz.edge_count() > 0);

        let dot = viz.export_dot().unwrap();
        assert!(dot.contains("digraph HNSW"));
        for label in ["v1", "v2", "v3", "v4"].iter() {
            assert!(dot.contains(*label));
        }

        let json = viz.export_json().unwrap();
        assert!(json.contains("nodes"));
        assert!(json.contains("links"));

        assert!(viz.export_cytoscape().unwrap().contains("data"));

        let stats = viz.statistics();
        assert_eq!(stats.node_count, 4);
        assert!(stats.edge_count > 0);
        assert!(stats.layer_count > 0);

        assert!(viz.export_statistics_text().contains("HNSW Graph Statistics"));
        assert_eq!(viz.sample(2).node_count(), 2);
    }
}