use ahash::AHashMap;
use serde::{Deserialize, Serialize};
/// Adjacency storage for an HNSW (hierarchical navigable small world) graph.
///
/// Nodes live in dense storage: `nodes[i]` holds the per-level neighbor lists of the
/// document whose id is `index_to_id[i]`, and `id_to_index` maps that id back to `i`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswGraph {
    /// Id of the graph's entry node, if any (presumably where searches begin — confirm with callers).
    pub entry_point: Option<u64>,
    /// Highest level recorded for the graph; not maintained by this type's own methods.
    pub max_level: usize,
    /// Document id -> position in `index_to_id` / `nodes`.
    id_to_index: AHashMap<u64, usize>,
    /// Position -> document id; inverse of `id_to_index`.
    index_to_id: Vec<u64>,
    /// Per node, per level, the neighbor document ids (outer index is the node position,
    /// middle index is the level).
    nodes: Vec<Vec<Vec<u64>>>,
    /// Construction parameter; presumably HNSW's `M` (not read by this type's methods).
    pub m: usize,
    /// Construction parameter; presumably HNSW's `M_max` for upper levels.
    pub m_max: usize,
    /// Construction parameter; presumably HNSW's `M_max0` for level 0.
    pub m_max_0: usize,
    /// Construction parameter; presumably HNSW's `efConstruction`.
    pub ef_construction: usize,
    /// Construction parameter; presumably HNSW's level normalization factor `m_L`.
    pub level_mult: f64,
}
impl HnswGraph {
    /// Builds a graph from a map of document id to per-level neighbor lists.
    ///
    /// Each entry of `nodes_map` becomes one node. The outer `Vec` of a value is indexed
    /// by level, and each inner `Vec` holds the neighbor ids at that level. Internal
    /// indices follow the map's iteration order, which is unspecified for a `HashMap`.
    /// Use [`HnswGraph::sorted_nodes`] when a deterministic order is needed.
    // The argument list mirrors the struct's fields one-to-one, so the lint is expected.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        entry_point: Option<u64>,
        max_level: usize,
        nodes_map: std::collections::HashMap<u64, Vec<Vec<u64>>>,
        m: usize,
        m_max: usize,
        m_max_0: usize,
        ef_construction: usize,
        level_mult: f64,
    ) -> Self {
        let capacity = nodes_map.len();
        let mut id_to_index = AHashMap::with_capacity(capacity);
        let mut index_to_id = Vec::with_capacity(capacity);
        let mut nodes = Vec::with_capacity(capacity);
        // `index` equals the position the node is pushed to in `nodes` / `index_to_id`.
        for (index, (doc_id, layers)) in nodes_map.into_iter().enumerate() {
            id_to_index.insert(doc_id, index);
            index_to_id.push(doc_id);
            nodes.push(layers);
        }
        Self {
            entry_point,
            max_level,
            id_to_index,
            index_to_id,
            nodes,
            m,
            m_max,
            m_max_0,
            ef_construction,
            level_mult,
        }
    }

    /// Returns the neighbor list of `doc_id` at `level`.
    ///
    /// Returns `None` if the node is unknown or has no layer at `level`.
    pub fn get_neighbors(&self, doc_id: u64, level: usize) -> Option<&Vec<u64>> {
        let &index = self.id_to_index.get(&doc_id)?;
        self.nodes.get(index).and_then(|levels| levels.get(level))
    }

    /// Replaces the neighbor list of `doc_id` at `level`.
    ///
    /// If `doc_id` is not in the graph yet, it is added. If the node has fewer than
    /// `level + 1` layers, it is first extended with empty layers so that `level` exists.
    /// The write is therefore always stored.
    pub fn set_neighbors(&mut self, doc_id: u64, level: usize, neighbors: Vec<u64>) {
        let index = self.get_or_create_index(doc_id);
        let layers = &mut self.nodes[index];
        if layers.len() <= level {
            // Grow instead of dropping the write. A node created by `get_or_create_index`
            // starts with zero layers, so dropping would leave it permanently empty.
            layers.resize_with(level + 1, Vec::new);
        }
        layers[level] = neighbors;
    }

    /// Returns the dense index of `doc_id`.
    ///
    /// If the id is absent, a node with zero layers is appended first.
    fn get_or_create_index(&mut self, doc_id: u64) -> usize {
        if let Some(&index) = self.id_to_index.get(&doc_id) {
            return index;
        }
        let index = self.nodes.len();
        self.id_to_index.insert(doc_id, index);
        self.index_to_id.push(doc_id);
        self.nodes.push(Vec::new());
        index
    }

    /// Returns `true` if `doc_id` has a node in the graph.
    pub fn contains_node(&self, doc_id: &u64) -> bool {
        self.id_to_index.contains_key(doc_id)
    }

    /// Returns the number of nodes stored in the graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns all per-level neighbor lists of `doc_id`, or `None` if the node is unknown.
    pub fn get_node_layers(&self, doc_id: &u64) -> Option<&Vec<Vec<u64>>> {
        let &index = self.id_to_index.get(doc_id)?;
        self.nodes.get(index)
    }

    /// Iterates over `(doc_id, layers)` pairs in internal storage order.
    pub fn iter_nodes(&self) -> impl Iterator<Item = (u64, &Vec<Vec<u64>>)> {
        self.index_to_id.iter().copied().zip(self.nodes.iter())
    }

    /// Consumes the graph and yields owned `(doc_id, layers)` pairs in internal storage order.
    pub fn into_iter_nodes(self) -> impl Iterator<Item = (u64, Vec<Vec<u64>>)> {
        self.index_to_id.into_iter().zip(self.nodes)
    }

    /// Returns all `(doc_id, layers)` pairs ordered by ascending `doc_id`.
    pub fn sorted_nodes(&self) -> Vec<(u64, &Vec<Vec<u64>>)> {
        let mut pairs: Vec<_> = self.iter_nodes().collect();
        // Ids are unique, so an unstable sort gives the same order as a stable one
        // without the extra allocation.
        pairs.sort_unstable_by_key(|&(id, _)| id);
        pairs
    }
}