use std::sync::Arc;
use iqdb_types::{IqdbError, Metadata, Result, VectorId};
use crate::graph::{NodeIdx, cap_at_layer, pick_layer};
use crate::index::HnswIndex;
use crate::search::{distance_between, distance_to, search_layer};
use crate::topk::{Scored, take_topk_sorted};
pub(crate) fn insert_node(
idx: &mut HnswIndex,
id: VectorId,
vector: Arc<[f32]>,
metadata: Option<Metadata>,
) -> Result<()> {
idx.check_dim(vector.len())?;
if idx.id_to_node.contains_key(&id) {
return Err(IqdbError::Duplicate);
}
let seq = idx.next_seq;
let next_seq = idx
.next_seq
.checked_add(1)
.ok_or(IqdbError::InvalidConfig {
reason: "HnswIndex insertion sequence counter overflowed u64",
})?;
idx.next_seq = next_seq;
let layer = pick_layer(&mut idx.rng, idx.m_l_inv);
let new_node: NodeIdx = idx.vectors.len() as NodeIdx;
idx.vectors.push(Arc::clone(&vector));
idx.ids.push(id.clone());
idx.metadata.push(metadata);
idx.seqs.push(seq);
idx.tombstoned.push(false);
idx.node_layer.push(layer);
let mut per_layer_adj: Vec<Vec<NodeIdx>> = Vec::with_capacity((layer as usize) + 1);
for lc in 0..=layer {
per_layer_adj.push(Vec::with_capacity(cap_at_layer(idx.cfg.m, lc)));
}
idx.layers.push(per_layer_adj);
let _prev = idx.id_to_node.insert(id, new_node);
let entry = match idx.entry {
Some(e) => e,
None => {
idx.entry = Some(new_node);
idx.top_layer = layer;
idx.live_count = idx.live_count.saturating_add(1);
return Ok(());
}
};
let entry_dist = distance_to(idx, &vector, entry)?;
let mut cur = Scored {
dist: entry_dist,
seq: idx.seqs[entry as usize],
node: entry,
};
let top = idx.top_layer;
if top > layer {
let mut lc = top;
while lc > layer {
let result_heap = search_layer(idx, &vector, &[cur], lc, 1)?;
if let Some(nearest) = result_heap.iter().min().copied() {
cur = nearest;
}
lc -= 1;
}
}
let mut entry_points: Vec<Scored> = vec![cur];
let start_lc = layer.min(top);
let mut lc = start_lc;
loop {
let result_heap = search_layer(idx, &vector, &entry_points, lc, idx.cfg.ef_construction)?;
let w_sorted = take_topk_sorted(result_heap, idx.cfg.ef_construction);
let live_candidates: Vec<Scored> = w_sorted
.iter()
.copied()
.filter(|s| !idx.tombstoned[s.node as usize] && s.node != new_node)
.collect();
let m_cap = cap_at_layer(idx.cfg.m, lc);
let chosen = select_heuristic(idx, &live_candidates, m_cap)?;
for s in &chosen {
idx.layers[new_node as usize][lc as usize].push(s.node);
idx.layers[s.node as usize][lc as usize].push(new_node);
if idx.layers[s.node as usize][lc as usize].len() > m_cap {
trim_neighbourhood(idx, s.node, lc, m_cap, Some(new_node))?;
}
}
entry_points = w_sorted;
if lc == 0 {
break;
}
lc -= 1;
}
if layer > idx.top_layer {
idx.entry = Some(new_node);
idx.top_layer = layer;
}
idx.live_count = idx.live_count.saturating_add(1);
Ok(())
}
/// Picks at most `m_max` neighbours out of `candidates`.
///
/// Candidates are visited in ascending `Scored` order. A candidate is set
/// aside when some already-kept neighbour is closer to it (per
/// `distance_between`) than the candidate's own `dist`; otherwise it is kept.
/// Visiting stops as soon as `m_max` neighbours are kept. If fewer than
/// `m_max` were kept, the set-aside candidates are appended in the order they
/// were rejected until the limit is reached.
///
/// Returns an empty vector when `m_max` is zero or `candidates` is empty.
///
/// # Errors
/// Propagates the first error returned by `distance_between`.
pub(crate) fn select_heuristic(
    idx: &HnswIndex,
    candidates: &[Scored],
    m_max: usize,
) -> Result<Vec<Scored>> {
    if candidates.is_empty() || m_max == 0 {
        return Ok(Vec::new());
    }

    let mut ordered = candidates.to_vec();
    ordered.sort();

    let mut kept: Vec<Scored> = Vec::with_capacity(m_max);
    let mut set_aside: Vec<Scored> = Vec::new();

    'next_candidate: for cand in ordered {
        if kept.len() >= m_max {
            break;
        }
        for near in &kept {
            // A kept neighbour that is closer to the candidate than the
            // candidate's own distance makes the candidate redundant.
            if distance_between(idx, cand.node, near.node)? < cand.dist {
                set_aside.push(cand);
                continue 'next_candidate;
            }
        }
        kept.push(cand);
    }

    // Backfill free slots with rejected candidates, earliest rejection first.
    let free_slots = m_max.saturating_sub(kept.len());
    kept.extend(set_aside.into_iter().take(free_slots));
    Ok(kept)
}
/// Shrinks the adjacency list of `node` at `layer` by re-running neighbour
/// selection over its current neighbours.
///
/// Every current neighbour is scored by `distance_between(node, neighbour)`
/// and handed to `select_heuristic` with limit `cap`; the selection replaces
/// the adjacency list.
///
/// When `pinned` names a node that is in the current adjacency list but was
/// not selected, it is added back. If the selection already holds `cap` or
/// more entries, the greatest entry by `Scored` ordering is removed first to
/// make room. A `pinned` node that is not a current neighbour is ignored.
///
/// # Errors
/// Propagates errors from `distance_between` and `select_heuristic`; the
/// adjacency list is left unchanged in that case.
fn trim_neighbourhood(
    idx: &mut HnswIndex,
    node: NodeIdx,
    layer: u8,
    cap: usize,
    pinned: Option<NodeIdx>,
) -> Result<()> {
    // Score the neighbours straight from the adjacency list: only shared
    // borrows of `idx` are needed here, so no copy of the list is required.
    let candidates = idx.layers[node as usize][layer as usize]
        .iter()
        .map(|&nb| {
            distance_between(idx, node, nb).map(|dist| Scored {
                dist,
                seq: idx.seqs[nb as usize],
                node: nb,
            })
        })
        .collect::<Result<Vec<Scored>>>()?;

    let mut chosen = select_heuristic(idx, &candidates, cap)?;

    if let Some(pin) = pinned {
        let already_in = chosen.iter().any(|s| s.node == pin);
        if !already_in {
            if let Some(pin_scored) = candidates.iter().copied().find(|s| s.node == pin) {
                if chosen.len() >= cap {
                    // Make room by evicting the worst-ranked selected entry.
                    if let Some((worst_idx, _)) =
                        chosen.iter().enumerate().max_by(|(_, a), (_, b)| a.cmp(b))
                    {
                        let _evicted = chosen.swap_remove(worst_idx);
                    }
                }
                chosen.push(pin_scored);
            }
        }
    }

    // Overwrite in place so the list keeps its existing allocation.
    let adjacency = &mut idx.layers[node as usize][layer as usize];
    adjacency.clear();
    adjacency.extend(chosen.iter().map(|s| s.node));
    Ok(())
}