use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::distance::Distance;
use crate::heap::DistId;
/// Read-only view of `f32` vector data that lives inside a memory-mapped file.
pub(crate) struct MmapBacking {
/// Shared handle to the mapping; held so the mapping outlives `ptr`.
pub(crate) _mmap: Arc<memmap2::Mmap>,
/// Pointer to the first `f32` of the vector region inside the mapping.
pub(crate) ptr: *const f32,
/// Number of `f32` elements (not bytes) reachable through `ptr`.
pub(crate) len: usize,
}
// NOTE(review): the raw pointer makes this type !Send/!Sync by default. These
// impls are presumably sound because the pointer targets an immutable mapping
// that `_mmap` keeps alive — confirm nothing writes through the mapping.
unsafe impl Send for MmapBacking {}
unsafe impl Sync for MmapBacking {}
/// Flat storage for fixed-dimension `f32` vectors, either heap-owned or
/// backed by a memory-mapped file.
pub(crate) struct VecStore {
/// Heap storage: vectors laid out back to back, `dim` floats each.
/// Unused (empty) when the store is created via `from_mmap`.
pub(crate) data: Vec<f32>,
/// Memory-mapped storage; when `Some`, reads come from here and `push` panics.
pub(crate) mmap: Option<MmapBacking>,
/// Number of floats per vector; `0` means "not yet known".
pub(crate) dim: usize,
}
impl VecStore {
/// Creates an empty heap-backed store with room reserved for `capacity`
/// vectors of `dim` floats each.
///
/// A `dim` of 0 is treated as 1 for the reservation only, so a capacity hint
/// still reserves something before the real dimension is known.
pub(crate) fn new(dim: usize, capacity: usize) -> Self {
    let floats_per_vector = dim.max(1);
    let reserved_floats = capacity.saturating_mul(floats_per_vector);
    Self {
        dim,
        mmap: None,
        data: Vec::with_capacity(reserved_floats),
    }
}
/// Creates a read-only store over `n` vectors of `dim` floats that start
/// `vec_offset` bytes into `mmap`.
///
/// # Panics
/// Panics if `n * dim` (or its size in bytes) overflows `usize`, if the
/// vector region does not lie entirely inside the mapping, or if the region
/// does not start on an `f32`-aligned address. Rejecting these up front
/// keeps later slice construction over the mapping sound.
pub(crate) fn from_mmap(
    mmap: Arc<memmap2::Mmap>,
    vec_offset: usize,
    n: usize,
    dim: usize,
) -> Self {
    let len = n
        .checked_mul(dim)
        .expect("vector count * dimension overflows usize");
    let byte_len = len
        .checked_mul(std::mem::size_of::<f32>())
        .expect("vector region size in bytes overflows usize");
    let end = vec_offset
        .checked_add(byte_len)
        .expect("vector region end overflows usize");
    assert!(
        end <= mmap.len(),
        "vector region (offset {vec_offset}, {byte_len} bytes) exceeds mapping of {} bytes",
        mmap.len()
    );
    // SAFETY: `vec_offset <= end <= mmap.len()`, so the offset pointer stays
    // within (or one past the end of) the mapped allocation.
    let ptr = unsafe { mmap.as_ptr().add(vec_offset) } as *const f32;
    assert!(
        (ptr as usize) % std::mem::align_of::<f32>() == 0,
        "vector region at offset {vec_offset} is not aligned for f32"
    );
    Self {
        data: Vec::new(),
        mmap: Some(MmapBacking { _mmap: mmap, ptr, len }),
        dim,
    }
}
/// Appends one vector to the heap storage.
///
/// # Panics
/// Panics if the store is memory-mapped. In debug builds it also panics if
/// `v` does not have exactly `dim` elements.
pub(crate) fn push(&mut self, v: Vec<f32>) {
    let heap_backed = self.mmap.is_none();
    assert!(
        heap_backed,
        "cannot insert into a memory-mapped (read-only) index"
    );
    debug_assert_eq!(v.len(), self.dim);
    // Consumes `v`; the floats land contiguously after the existing vectors.
    self.data.extend(v);
}
/// Returns the `dim` floats of vector `id`.
///
/// # Panics
/// Panics if vector `id` lies outside the stored data. This holds for both
/// heap-backed and memory-mapped stores, so an invalid id can never read
/// outside the mapping.
#[inline(always)]
pub(crate) fn get(&self, id: usize) -> &[f32] {
    let floats: &[f32] = match &self.mmap {
        None => self.data.as_slice(),
        // SAFETY: relies on the `MmapBacking` invariant that `ptr` points to
        // `len` readable, aligned `f32`s inside the mapping kept alive by
        // `_mmap`, which lives as long as `self`.
        Some(mb) => unsafe { std::slice::from_raw_parts(mb.ptr, mb.len) },
    };
    let start = id * self.dim;
    // Checked slicing: an out-of-range id panics instead of reading past the data.
    &floats[start..start + self.dim]
}
/// Number of vectors currently stored; `0` while the dimension is still 0.
#[inline(always)]
pub(crate) fn len(&self) -> usize {
    match (self.dim, &self.mmap) {
        // Guard against division by zero before the dimension is known.
        (0, _) => 0,
        (dim, Some(mb)) => mb.len / dim,
        (dim, None) => self.data.len() / dim,
    }
}
/// Returns the raw bytes of all stored vectors in native `f32` layout.
///
/// Covers both storage kinds. Previously a memory-mapped store returned an
/// empty slice even though `len()` reported vectors.
pub(crate) fn as_bytes(&self) -> &[u8] {
    let floats: &[f32] = match &self.mmap {
        None => self.data.as_slice(),
        // SAFETY: relies on the `MmapBacking` invariant that `ptr` points to
        // `len` readable, aligned `f32`s inside the mapping kept alive by
        // `_mmap`, which lives as long as `self`.
        Some(mb) => unsafe { std::slice::from_raw_parts(mb.ptr, mb.len) },
    };
    // SAFETY: any initialised `f32` slice may be viewed as bytes; `u8` has
    // alignment 1 and the byte length is exactly the size of the slice.
    unsafe {
        std::slice::from_raw_parts(
            floats.as_ptr() as *const u8,
            std::mem::size_of_val(floats),
        )
    }
}
}
/// Visited set for graph traversals that is reset by bumping a generation
/// counter instead of clearing memory.
///
/// A node counts as visited in the current traversal when its stamp equals
/// `current`.
struct VisitedTracker {
    /// Per-node generation stamp; `0` is never used as a live generation.
    stamps: Vec<u32>,
    /// Generation of the traversal in progress.
    current: u32,
}

impl VisitedTracker {
    /// Creates a tracker with slots for `capacity` nodes, none of them visited.
    fn new(capacity: usize) -> Self {
        let stamps = vec![0u32; capacity];
        Self { stamps, current: 1 }
    }

    /// Starts a new traversal: every node becomes unvisited again.
    #[inline]
    fn begin(&mut self) {
        match self.current.checked_add(1) {
            Some(next) => self.current = next,
            None => {
                // The counter wrapped: stale stamps could collide with future
                // generations, so wipe them and restart at 1.
                self.stamps.fill(0);
                self.current = 1;
            }
        }
    }

    /// Marks `id` as visited and returns `true` if this is its first visit in
    /// the current traversal. The stamp table grows when `id` is out of range.
    #[inline]
    fn visit(&mut self, id: usize) -> bool {
        if self.stamps.len() <= id {
            self.stamps.resize(id * 2 + 1, 0);
        }
        let generation = self.current;
        let slot = &mut self.stamps[id];
        let first_visit = *slot != generation;
        *slot = generation;
        first_visit
    }
}
/// Reusable working memory for one layer search: a min-heap of candidates
/// still to expand and a bounded max-heap of the best results so far.
struct Scratch {
    /// Candidates to expand, closest first (min-heap via `Reverse`).
    candidates: BinaryHeap<Reverse<DistId>>,
    /// Best results so far; the heap top is the current worst of them.
    results: BinaryHeap<DistId>,
    /// Maximum number of entries kept in `results`.
    results_cap: usize,
    /// Output of the last `finish` call, in ascending heap order.
    pub out: Vec<DistId>,
}

impl Scratch {
    /// Allocates buffers sized for a search with beam width `ef`.
    fn new(ef: usize) -> Self {
        let candidates = BinaryHeap::with_capacity(ef * 2 + 1);
        let results = BinaryHeap::with_capacity(ef + 1);
        let out = Vec::with_capacity(ef);
        Self { candidates, results, results_cap: ef, out }
    }

    /// Resets both heaps for a new search keeping at most `ef` results.
    /// `out` is left untouched until the next `finish`.
    #[inline]
    fn begin(&mut self, ef: usize) {
        self.results_cap = ef;
        self.results.clear();
        self.candidates.clear();
    }

    /// Registers an entry point; treated exactly like any other candidate.
    #[inline]
    fn push_entry(&mut self, d: DistId) {
        self.push_candidate(d);
    }

    /// Queues `d` for expansion and records it as a result, evicting the
    /// current worst result if the cap is exceeded.
    #[inline]
    fn push_candidate(&mut self, d: DistId) {
        self.candidates.push(Reverse(d));
        self.results.push(d);
        let over_cap = self.results.len() > self.results_cap;
        if over_cap {
            self.results.pop();
        }
    }

    /// Removes and returns the closest unexpanded candidate, if any.
    #[inline]
    fn pop_candidate(&mut self) -> Option<DistId> {
        match self.candidates.pop() {
            Some(Reverse(closest)) => Some(closest),
            None => None,
        }
    }

    /// Distance of the worst result currently kept, or `None` when empty.
    #[inline]
    fn worst_result_dist(&self) -> Option<f32> {
        let worst = self.results.peek()?;
        Some(worst.dist)
    }

    /// Number of results currently kept.
    #[inline]
    fn results_len(&self) -> usize {
        self.results.len()
    }

    /// Drains the result heap into `out`, ordered from best to worst.
    fn finish(&mut self) {
        let out = &mut self.out;
        let heap = &mut self.results;
        out.clear();
        // The max-heap yields worst-first; reverse afterwards for best-first.
        while let Some(worst) = heap.pop() {
            out.push(worst);
        }
        out.reverse();
    }
}
/// How an existing node's neighbour list is shrunk when it exceeds its link
/// limit during insertion.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum PruneStrategy {
/// Sort the links by stored distance and keep the closest ones.
#[default]
Simple,
/// Re-select the links with the diversity heuristic.
Heuristic,
}
/// Build-time parameters of the index.
#[derive(Clone, Debug)]
pub struct Config {
    /// Maximum links per node on layers above 0.
    pub m: usize,
    /// Maximum links per node on layer 0; `None` means `2 * m`.
    pub m0: Option<usize>,
    /// Beam width used while searching for neighbours during insertion.
    pub ef_construction: usize,
    /// Select a new node's neighbours with the heuristic instead of taking
    /// the closest candidates.
    pub use_heuristic: bool,
    /// Heuristic selection also considers the candidates' own neighbours.
    pub extend_candidates: bool,
    /// Heuristic selection back-fills free slots with rejected candidates.
    pub keep_pruned: bool,
    /// Strategy used to shrink an over-full neighbour list.
    pub prune_strategy: PruneStrategy,
    /// Expected number of vectors; used only to pre-size buffers.
    pub capacity: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            capacity: 0,
            m: 16,
            m0: None,
            ef_construction: 200,
            prune_strategy: PruneStrategy::Simple,
            use_heuristic: true,
            keep_pruned: true,
            extend_candidates: false,
        }
    }
}

impl Config {
    /// Link limit on layer 0: the explicit `m0`, or twice `m` when unset.
    #[inline]
    pub(crate) fn m0(&self) -> usize {
        let doubled = 2 * self.m;
        match self.m0 {
            Some(explicit) => explicit,
            None => doubled,
        }
    }

    /// Link limit for `layer`.
    #[inline]
    fn max_links(&self, layer: usize) -> usize {
        match layer {
            0 => self.m0(),
            _ => self.m,
        }
    }

    /// Level-generation scale factor, `1 / ln(m)`.
    #[inline]
    fn m_l(&self) -> f64 {
        let ln_m = (self.m as f64).ln();
        1.0 / ln_m
    }
}
/// One hit returned by `Hnsw::search`.
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult {
/// Id of the matching vector, as returned by `Hnsw::insert`.
pub id: usize,
/// Distance between the query and that vector under the index metric.
pub distance: f32,
}
/// Layered proximity-graph index over `f32` vectors, generic over the
/// distance metric `D`.
pub struct Hnsw<D: Distance> {
/// Build parameters supplied at construction.
pub(crate) config: Config,
/// Metric used for every distance computation.
pub(crate) metric: D,
/// Storage for the indexed vectors, addressed by node id.
pub(crate) vec_store: VecStore,
/// Adjacency lists: `connections[node][layer]` holds `(neighbour id, distance)` pairs.
pub(crate) connections: Vec<Vec<Vec<(u32, f32)>>>,
/// `(node id, top layer)` where searches start; `None` while the index is empty.
pub(crate) entry_point: Option<(usize, usize)>,
/// Random source used to draw the level of each inserted node.
rng: SmallRng,
/// Dimension shared by all vectors; set by the first insert.
pub(crate) dim: Option<usize>,
/// Visited-set reused across the layer searches performed by `insert`.
visited: VisitedTracker,
/// Heap buffers reused across the layer searches performed by `insert`.
scratch: Scratch,
/// Entry points handed to the next layer search during `insert`.
ep_buf: Vec<DistId>,
/// Neighbours chosen by the most recent selection step, as `(id, distance)`.
select_buf: Vec<(usize, f32)>,
/// Candidates rejected by the heuristic, kept for optional back-filling.
pruned_buf: Vec<(usize, f32)>,
/// Working copy of a neighbour list while it is being pruned.
prune_buf: Vec<(usize, f32)>,
}
impl<D: Distance> Hnsw<D> {
/// Creates an empty index whose level generator is seeded from entropy.
///
/// # Panics
/// Panics if `config.m < 2` or `config.ef_construction < config.m`.
pub fn new(config: Config, metric: D) -> Self {
    assert!(config.m >= 2, "M must be at least 2");
    assert!(config.ef_construction >= config.m,
        "ef_construction should be ≥ M for good recall");
    let (cap, ef, m) = (config.capacity, config.ef_construction, config.m);
    // All neighbour work buffers share the same initial capacity.
    let neighbour_cap = m * 2 + 2;
    Self {
        vec_store: VecStore::new(0, cap),
        connections: Vec::with_capacity(cap),
        visited: VisitedTracker::new(cap.max(64)),
        scratch: Scratch::new(ef),
        ep_buf: Vec::with_capacity(ef),
        select_buf: Vec::with_capacity(neighbour_cap),
        pruned_buf: Vec::with_capacity(neighbour_cap),
        prune_buf: Vec::with_capacity(neighbour_cap),
        rng: SmallRng::from_entropy(),
        entry_point: None,
        dim: None,
        config,
        metric,
    }
}
/// Reassembles an index from already-built components.
///
/// No validation is performed here; the work buffers are sized from the
/// number of stored vectors and from `config`.
pub(crate) fn from_parts(
    config: Config,
    metric: D,
    vec_store: VecStore,
    connections: Vec<Vec<Vec<(u32, f32)>>>,
    entry_point: Option<(usize, usize)>,
    dim: Option<usize>,
) -> Self {
    let stored = vec_store.len();
    let (ef, m) = (config.ef_construction, config.m);
    let neighbour_cap = m * 2 + 2;
    Self {
        visited: VisitedTracker::new(stored.max(64)),
        scratch: Scratch::new(ef),
        ep_buf: Vec::with_capacity(ef),
        select_buf: Vec::with_capacity(neighbour_cap),
        pruned_buf: Vec::with_capacity(neighbour_cap),
        prune_buf: Vec::with_capacity(neighbour_cap),
        rng: SmallRng::from_entropy(),
        config,
        metric,
        vec_store,
        connections,
        entry_point,
        dim,
    }
}
/// Creates an empty index whose level generator is seeded with `seed`, so
/// the sequence of drawn levels is reproducible.
///
/// # Panics
/// Panics if `config.m < 2`.
pub fn new_with_seed(config: Config, metric: D, seed: u64) -> Self {
    assert!(config.m >= 2);
    let (cap, ef, m) = (config.capacity, config.ef_construction, config.m);
    let neighbour_cap = m * 2 + 2;
    Self {
        vec_store: VecStore::new(0, cap),
        connections: Vec::with_capacity(cap),
        visited: VisitedTracker::new(cap.max(64)),
        scratch: Scratch::new(ef),
        ep_buf: Vec::with_capacity(ef),
        select_buf: Vec::with_capacity(neighbour_cap),
        pruned_buf: Vec::with_capacity(neighbour_cap),
        prune_buf: Vec::with_capacity(neighbour_cap),
        rng: SmallRng::seed_from_u64(seed),
        entry_point: None,
        dim: None,
        config,
        metric,
    }
}
/// Inserts `vector` into the index and returns its id.
///
/// Ids are assigned sequentially: the id is the number of vectors stored
/// before the call. The first inserted vector fixes the dimension.
///
/// # Panics
/// Panics if `vector` is empty, if its length differs from the index
/// dimension, if the new id would not fit in `u32` (the width used for
/// stored links), or if the index is memory-mapped (read-only).
pub fn insert(&mut self, vector: Vec<f32>) -> usize {
    let dim = vector.len();
    // With dimension 0 the store always reports length 0, so every insert
    // would receive id 0 and corrupt the graph.
    assert!(dim > 0, "cannot insert an empty (zero-dimensional) vector");
    match self.dim {
        None => {
            self.dim = Some(dim);
            self.vec_store.dim = dim;
        }
        Some(d) => assert_eq!(d, dim,
            "all vectors must have the same dimension (expected {d}, got {dim})"),
    }

    let q = self.vec_store.len();
    // Links are stored as u32; refuse ids that would silently truncate.
    assert!(q <= u32::MAX as usize, "index is full: node ids must fit in u32");
    let q_u32 = q as u32;
    let q_level = self.random_level();
    self.vec_store.push(vector);

    let mut conn: Vec<Vec<(u32, f32)>> = Vec::with_capacity(q_level + 1);
    for l in 0..=q_level {
        conn.push(Vec::with_capacity(self.config.max_links(l)));
    }
    self.connections.push(conn);
    if self.visited.stamps.len() <= q {
        self.visited.stamps.resize(q * 2 + 1, 0);
    }

    let (ep_id, ep_level) = match self.entry_point {
        None => {
            // First node: it becomes the entry point and has no neighbours.
            self.entry_point = Some((q, q_level));
            return q;
        }
        Some(x) => x,
    };

    // Greedy descent (ef = 1) through the layers above the new node's level.
    self.ep_buf.clear();
    self.ep_buf.push(DistId::new(self.dist(q, ep_id), ep_id));
    for layer in (q_level + 1..=ep_level).rev() {
        self.search_layer_node(q, 1, layer);
        std::mem::swap(&mut self.ep_buf, &mut self.scratch.out);
    }

    // Snapshot of the selected neighbours. Heuristic pruning overwrites
    // `select_buf`, so the selection is copied first. A growable buffer is
    // used because the selection may hold up to `m0` entries, which can
    // exceed any fixed stack array size.
    let mut edges: Vec<(u32, f32)> =
        Vec::with_capacity(self.config.m0().max(self.config.m));
    let ef = self.config.ef_construction;
    let prune_strategy = self.config.prune_strategy;
    let top = ep_level.min(q_level);
    for layer in (0..=top).rev() {
        let m_max = self.config.max_links(layer);
        self.search_layer_node(q, ef, layer);
        if self.config.use_heuristic {
            self.select_neighbours_heuristic(q, m_max, layer);
        } else {
            self.select_neighbours_simple(m_max);
        }

        edges.clear();
        // Neighbour ids come from existing nodes, all smaller than `q`, so
        // the narrowing cast is lossless given the check above.
        edges.extend(self.select_buf.iter().map(|&(nb, dist)| (nb as u32, dist)));

        // Add the bidirectional links.
        for &(nb_u32, dist_q_nb) in &edges {
            let nb = nb_u32 as usize;
            self.connections[q][layer].push((nb_u32, dist_q_nb));
            self.connections[nb][layer].push((q_u32, dist_q_nb));
        }

        // Shrink any neighbour list that now exceeds its limit.
        for &(nb_u32, _) in &edges {
            let nb = nb_u32 as usize;
            if self.connections[nb][layer].len() > m_max {
                match prune_strategy {
                    PruneStrategy::Simple => {
                        let links = &mut self.connections[nb][layer];
                        links.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                        links.truncate(m_max);
                    }
                    PruneStrategy::Heuristic => {
                        self.prune_connections_heuristic(nb, layer, m_max);
                    }
                }
            }
        }

        // This layer's results seed the search on the next layer down.
        std::mem::swap(&mut self.ep_buf, &mut self.scratch.out);
    }

    if q_level > ep_level {
        self.entry_point = Some((q, q_level));
    }
    q
}
pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
assert!(k > 0, "k must be > 0");
let ef = ef.max(k);
let (ep_id, ep_level) = match self.entry_point {
None => return Vec::new(),
Some(x) => x,
};
let mut visited = VisitedTracker::new(self.vec_store.len());
let mut scratch = Scratch::new(ef);
let mut ep = Vec::with_capacity(ef);
let ep_dist = self.metric.distance(query, self.vec_store.get(ep_id));
ep.push(DistId::new(ep_dist, ep_id));
for layer in (1..=ep_level).rev() {
Self::do_search_layer(
&self.vec_store, &self.connections, &self.metric,
&mut visited, &mut scratch, query, &ep, 1, layer,
);
std::mem::swap(&mut ep, &mut scratch.out);
}
Self::do_search_layer(
&self.vec_store, &self.connections, &self.metric,
&mut visited, &mut scratch, query, &ep, ef, 0,
);
scratch.out.truncate(k);
scratch.out.iter()
.map(|d| SearchResult { id: d.id, distance: d.dist })
.collect()
}
/// Number of vectors in the index.
#[inline]
pub fn len(&self) -> usize {
    self.vec_store.len()
}

/// `true` when the index holds no vectors.
#[inline]
pub fn is_empty(&self) -> bool {
    let stored = self.vec_store.len();
    stored == 0
}

/// Returns the stored vector with the given id.
#[inline]
pub fn get_vector(&self, id: usize) -> &[f32] {
    let store = &self.vec_store;
    store.get(id)
}

/// Dimension of the indexed vectors, or `None` before it has been set.
#[inline]
pub fn dim(&self) -> Option<usize> {
    self.dim
}

/// Top layer of the entry point, or `None` when there is no entry point.
pub fn max_level(&self) -> Option<usize> {
    match self.entry_point {
        Some((_, level)) => Some(level),
        None => None,
    }
}
/// Draws the top layer for a new node as `floor(-ln(u) * m_l)`, where `u` is
/// a fresh random `f64` sample.
fn random_level(&mut self) -> usize {
    let scale = self.config.m_l();
    // Clamp away from zero so the logarithm stays finite.
    let u: f64 = self.rng.gen::<f64>().max(f64::MIN_POSITIVE);
    let level = (-u.ln() * scale).floor();
    level as usize
}
/// Distance between the stored vectors `a` and `b` under the index metric.
#[inline]
fn dist(&self, a: usize, b: usize) -> f32 {
    let lhs = self.vec_store.get(a);
    let rhs = self.vec_store.get(b);
    self.metric.distance(lhs, rhs)
}
fn search_layer_node(&mut self, q: usize, ef: usize, layer: usize) {
let vec_store = &self.vec_store;
let connections = &self.connections;
let metric = &self.metric;
let visited = &mut self.visited;
let scratch = &mut self.scratch;
let ep = &self.ep_buf;
let q_vec = vec_store.get(q);
visited.begin();
scratch.begin(ef);
for &ep_d in ep {
if visited.visit(ep_d.id) { scratch.push_entry(ep_d); }
}
loop {
let c = match scratch.pop_candidate() { Some(c) => c, None => break };
let worst = match scratch.worst_result_dist() { Some(d) => d, None => break };
if c.dist > worst { break; }
if let Some(nb_list) = connections.get(c.id).and_then(|nc| nc.get(layer)) {
for &(nb_u32, _) in nb_list {
let nb = nb_u32 as usize;
if visited.visit(nb) {
let nb_dist = metric.distance(q_vec, vec_store.get(nb));
let cur_worst = scratch.worst_result_dist().unwrap_or(f32::INFINITY);
if nb_dist < cur_worst || scratch.results_len() < ef {
scratch.push_candidate(DistId::new(nb_dist, nb));
}
}
}
}
}
scratch.finish();
}
/// Best-first search of one graph layer.
///
/// Starting from `entry_points`, repeatedly expands the closest unexpanded
/// candidate and keeps at most `ef` results. Stops when the closest
/// candidate is farther than the worst kept result. Results are written to
/// `scratch.out`, best first.
fn do_search_layer(
    vec_store: &VecStore,
    connections: &[Vec<Vec<(u32, f32)>>],
    metric: &D,
    visited: &mut VisitedTracker,
    scratch: &mut Scratch,
    query: &[f32],
    entry_points: &[DistId],
    ef: usize,
    layer: usize,
) {
    visited.begin();
    scratch.begin(ef);
    for seed in entry_points.iter().copied() {
        if visited.visit(seed.id) {
            scratch.push_entry(seed);
        }
    }

    while let Some(closest) = scratch.pop_candidate() {
        // No kept results, or the frontier is already worse than all of them.
        let exhausted = match scratch.worst_result_dist() {
            Some(worst) => closest.dist > worst,
            None => true,
        };
        if exhausted {
            break;
        }

        // Nodes without a link list on this layer contribute nothing.
        let neighbours = match connections
            .get(closest.id)
            .and_then(|levels| levels.get(layer))
        {
            Some(list) => list,
            None => continue,
        };

        for &(nb_u32, _) in neighbours {
            let nb = nb_u32 as usize;
            if !visited.visit(nb) {
                continue;
            }
            let nb_dist = metric.distance(query, vec_store.get(nb));
            let cur_worst = scratch.worst_result_dist().unwrap_or(f32::INFINITY);
            let has_room = scratch.results_len() < ef;
            if has_room || nb_dist < cur_worst {
                scratch.push_candidate(DistId::new(nb_dist, nb));
            }
        }
    }

    scratch.finish();
}
/// Fills `select_buf` with the first `m` entries of the last search result
/// (`scratch.out`), i.e. the closest candidates.
fn select_neighbours_simple(&mut self, m: usize) {
    self.select_buf.clear();
    let closest = self.scratch.out.iter().take(m);
    self.select_buf.extend(closest.map(|cand| (cand.id, cand.dist)));
}
/// Fills `select_buf` with up to `m` neighbours for node `q` on `layer`,
/// chosen from the last search result (`scratch.out`) with the diversity
/// heuristic.
///
/// A candidate is accepted only if it is closer to `q` than to every
/// neighbour accepted so far. With `extend_candidates` the candidates'
/// own neighbours are considered too; with `keep_pruned` free slots are
/// back-filled with rejected candidates in the order they were rejected.
fn select_neighbours_heuristic(&mut self, q: usize, m: usize, layer: usize) {
    self.select_buf.clear();
    self.pruned_buf.clear();

    let n_cands = self.scratch.out.len();
    if n_cands <= m && !self.config.extend_candidates {
        // Everything fits: take all candidates as they are.
        self.select_buf
            .extend(self.scratch.out.iter().map(|d| (d.id, d.dist)));
        return;
    }

    let ext_buf: Vec<DistId>;
    let cands: &[DistId] = if self.config.extend_candidates {
        let mut tmp: Vec<DistId> = self.scratch.out.to_vec();
        // Ids already in the candidate pool. Updated as neighbours are added
        // so a node reachable from several candidates is added once only;
        // duplicates would otherwise become duplicate links once rejected
        // candidates are back-filled.
        let mut seen_ids: std::collections::HashSet<usize> =
            self.scratch.out.iter().map(|d| d.id).collect();
        // A node must never be selected as its own neighbour.
        seen_ids.insert(q);
        for &d in &self.scratch.out {
            if let Some(nb_list) = self.connections.get(d.id).and_then(|nc| nc.get(layer)) {
                for &(nb_u32, _) in nb_list {
                    let nb = nb_u32 as usize;
                    if seen_ids.insert(nb) {
                        tmp.push(DistId::new(self.dist(q, nb), nb));
                    }
                }
            }
        }
        tmp.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
        ext_buf = tmp;
        &ext_buf
    } else {
        &self.scratch.out
    };

    for cand in cands {
        if self.select_buf.len() >= m {
            break;
        }
        let e_dist = cand.dist;
        let e_id = cand.id;
        let mut accept = true;
        for &(s_id, s_dist_q) in &self.select_buf {
            // Cheap skip: a selected node more than twice as far from `q` as
            // the candidate is not compared against it.
            if s_dist_q > 2.0 * e_dist {
                continue;
            }
            let d_es = self.metric.distance(
                self.vec_store.get(e_id),
                self.vec_store.get(s_id),
            );
            if d_es <= e_dist {
                accept = false;
                break;
            }
        }
        if accept {
            self.select_buf.push((e_id, e_dist));
        } else if self.config.keep_pruned {
            self.pruned_buf.push((e_id, e_dist));
        }
    }

    if self.config.keep_pruned {
        let needed = m.saturating_sub(self.select_buf.len());
        self.select_buf
            .extend(self.pruned_buf.iter().take(needed).copied());
    }
}
/// Rebuilds the link list of `node_id` on `layer` so it holds at most
/// `m_max` entries, chosen with the diversity heuristic from the current
/// links sorted by stored distance.
///
/// Overwrites `prune_buf`, `select_buf` and `pruned_buf`.
fn prune_connections_heuristic(&mut self, node_id: usize, layer: usize, m_max: usize) {
    // Working copy of the current links, closest first.
    self.prune_buf.clear();
    self.prune_buf.extend(
        self.connections[node_id][layer]
            .iter()
            .map(|&(nb, dist)| (nb as usize, dist)),
    );
    self.prune_buf.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));

    self.select_buf.clear();
    self.pruned_buf.clear();
    for &(e_id, e_dist) in &self.prune_buf {
        if self.select_buf.len() >= m_max {
            break;
        }
        let mut accept = true;
        for &(s_id, s_dist_node) in &self.select_buf {
            // Cheap skip: kept links more than twice as far away as the
            // candidate are not compared against it.
            if s_dist_node > 2.0 * e_dist {
                continue;
            }
            let d_es = self.metric.distance(
                self.vec_store.get(e_id),
                self.vec_store.get(s_id),
            );
            if d_es <= e_dist {
                accept = false;
                break;
            }
        }
        if accept {
            self.select_buf.push((e_id, e_dist));
        } else if self.config.keep_pruned {
            self.pruned_buf.push((e_id, e_dist));
        }
    }

    if self.config.keep_pruned {
        // Back-fill free slots with rejected links, in rejection order.
        let needed = m_max.saturating_sub(self.select_buf.len());
        self.select_buf
            .extend(self.pruned_buf.iter().take(needed).copied());
    }

    let links = &mut self.connections[node_id][layer];
    links.clear();
    links.extend(self.select_buf.iter().map(|&(id, dist)| (id as u32, dist)));
}
/// Collects per-layer node and directed-edge counts for the whole graph.
///
/// The reported `max_level` is the entry point's level, or 0 when there is
/// no entry point.
pub fn stats(&self) -> IndexStats {
    let max_level = match self.entry_point {
        Some((_, level)) => level,
        None => 0,
    };
    let layers = max_level + 1;
    let mut layer_counts = vec![0usize; layers];
    let mut layer_edges = vec![0usize; layers];

    let per_layer_links = self
        .connections
        .iter()
        .flat_map(|node_layers| node_layers.iter().enumerate());
    for (level, links) in per_layer_links {
        layer_counts[level] += 1;
        layer_edges[level] += links.len();
    }

    IndexStats {
        num_vectors: self.vec_store.len(),
        max_level,
        layer_counts,
        layer_edges,
    }
}
}
/// Summary of the graph structure, as produced by `Hnsw::stats`.
#[derive(Debug)]
pub struct IndexStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Highest layer reported for the index.
    pub max_level: usize,
    /// Number of nodes present on each layer, indexed by layer.
    pub layer_counts: Vec<usize>,
    /// Number of directed edges on each layer, indexed by layer.
    pub layer_edges: Vec<usize>,
}

impl std::fmt::Display for IndexStats {
    /// Writes a header followed by one line per layer, top layer first.
    ///
    /// All fields are public, so the per-layer vectors may be shorter than
    /// `max_level + 1`; missing entries are printed as 0 instead of
    /// panicking inside a formatting call.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "HNSW index — {} vectors", self.num_vectors)?;
        writeln!(f, " Max level : {}", self.max_level)?;
        for l in (0..=self.max_level).rev() {
            let nodes = self.layer_counts.get(l).copied().unwrap_or(0);
            let edges = self.layer_edges.get(l).copied().unwrap_or(0);
            writeln!(f, " Layer {:>3} : {:>6} nodes, {:>7} directed edges",
                l, nodes, edges)?;
        }
        Ok(())
    }
}