use crate::hnsw::HnswIndex;
use anyhow::Result;
use std::collections::BinaryHeap;
/// A list of `(node id, distance)` pairs, as returned by
/// `HnswIndex::select_neighbors_heuristic_v2`.
pub type NeighborList = Vec<(usize, f32)>;
impl HnswIndex {
    /// Heuristic neighbor selection: picks up to `m` candidates that are close
    /// to the query and not "shadowed" by an already selected neighbor.
    ///
    /// The structure follows the SELECT-NEIGHBORS-HEURISTIC routine of the HNSW
    /// paper: candidates are visited closest-first, and a candidate `e` is
    /// rejected when some already selected neighbor `r` is strictly closer to
    /// `e` than `e` is to the query.
    ///
    /// # Arguments
    ///
    /// * `candidates` - `(node id, distance to the query)` pairs. The distances
    ///   are taken as given. If an id occurs more than once, only its first
    ///   occurrence is used.
    /// * `query_vector` - The query; only read when `extend_candidates` is set,
    ///   to compute distances to newly discovered neighbors.
    /// * `m` - Requested number of neighbors. It is capped at `config().m_l0`
    ///   on layer 0 and at `config().m` on every other layer.
    /// * `layer` - Layer whose connection lists are used for extension.
    /// * `extend_candidates` - Also consider the direct neighbors of every
    ///   input candidate. Neighbors whose distance cannot be computed are
    ///   skipped silently.
    /// * `keep_pruned` - Fill remaining free slots with the closest rejected
    ///   candidates.
    ///
    /// # Returns
    ///
    /// At most `min(m, layer cap)` entries sorted by ascending distance. The
    /// visible implementation always returns `Ok`.
    pub fn select_neighbors_heuristic_v2(
        &self,
        candidates: &[(usize, f32)],
        query_vector: &[f32],
        m: usize,
        layer: usize,
        extend_candidates: bool,
        keep_pruned: bool,
    ) -> Result<NeighborList> {
        use std::collections::HashSet;

        let m_max = if layer == 0 {
            self.config().m_l0
        } else {
            self.config().m
        };
        let effective_m = m.min(m_max);
        if candidates.is_empty() || effective_m == 0 {
            return Ok(Vec::new());
        }

        // `seen` mirrors the ids stored in the heap so membership checks are
        // O(1) instead of a linear scan over the heap. It also de-duplicates
        // the input, so no id can be selected twice.
        let mut seen: HashSet<usize> = HashSet::with_capacity(candidates.len());
        let mut w: BinaryHeap<RevCandidate> = candidates
            .iter()
            .filter(|&&(id, _)| seen.insert(id))
            .map(|&(id, dist)| RevCandidate { id, dist })
            .collect();

        if extend_candidates {
            // Only the caller-supplied candidates are expanded; neighbors found
            // here are not expanded again.
            for &(cand_id, _) in candidates {
                let node = match self.nodes().get(cand_id) {
                    Some(n) => n,
                    None => continue,
                };
                // Clamp to the highest layer this node has connections for.
                let effective_layer = layer.min(node.connections.len().saturating_sub(1));
                let neighbors = match node.get_connections(effective_layer) {
                    Some(c) => c,
                    None => continue,
                };
                for neighbor_id in neighbors.iter().copied() {
                    if !seen.insert(neighbor_id) {
                        continue;
                    }
                    // Best effort: a neighbor without a computable distance is
                    // simply not considered.
                    if let Ok(dist) = self.calculate_distance_from_slice(query_vector, neighbor_id)
                    {
                        w.push(RevCandidate {
                            id: neighbor_id,
                            dist,
                        });
                    }
                }
            }
        }

        let mut selected: NeighborList = Vec::with_capacity(effective_m);
        let mut pruned: NeighborList = Vec::new();
        while selected.len() < effective_m {
            // The heap yields the closest remaining candidate first.
            let RevCandidate {
                id: e,
                dist: dist_e,
            } = match w.pop() {
                Some(closest) => closest,
                None => break,
            };
            // A missing node-to-node distance never shadows a candidate.
            let shadowed = selected.iter().any(|&(r_id, _)| {
                matches!(
                    self.calculate_distance_between_nodes_pub(r_id, e),
                    Some(dist_r_e) if dist_r_e < dist_e
                )
            });
            if !shadowed {
                selected.push((e, dist_e));
            } else if keep_pruned {
                // Rejected candidates are only needed for the refill below.
                pruned.push((e, dist_e));
            }
        }

        if keep_pruned {
            // `total_cmp` is a total order even if a NaN distance slips in, so
            // the sort never sees an inconsistent comparator.
            pruned.sort_by(|a, b| a.1.total_cmp(&b.1));
            let free_slots = effective_m.saturating_sub(selected.len());
            selected.extend(pruned.into_iter().take(free_slots));
        }

        selected.sort_by(|a, b| a.1.total_cmp(&b.1));
        Ok(selected)
    }

    /// Shrinks the connection list of `node_id` on `level` with the neighbor
    /// selection heuristic once it exceeds the configured maximum
    /// (`config().m_l0` on level 0, `config().m` otherwise).
    ///
    /// Nothing happens when the node does not exist or its list is already
    /// within the limit. Otherwise:
    ///
    /// 1. the node's connection set is replaced by the selected neighbors,
    /// 2. every dropped neighbor gets its connection to `node_id` removed,
    /// 3. every kept neighbor gets a connection to `node_id` added.
    ///
    /// Connections whose distance cannot be computed are not offered to the
    /// heuristic and are therefore dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error from `select_neighbors_heuristic_v2`.
    pub fn prune_connections_heuristic(&mut self, node_id: usize, level: usize) -> Result<()> {
        use std::collections::HashSet;

        let max_connections = if level == 0 {
            self.config().m_l0
        } else {
            self.config().m
        };

        let (current_conns, query_vec) = match self.nodes().get(node_id) {
            Some(node) => {
                let conns: Vec<usize> = node
                    .get_connections(level)
                    .map(|c| c.iter().copied().collect())
                    .unwrap_or_default();
                // Check the limit before copying the vector: the common case is
                // that no pruning is required.
                if conns.len() <= max_connections {
                    return Ok(());
                }
                (conns, node.vector_data_f32.clone())
            }
            None => return Ok(()),
        };

        // No pre-sorting: the selection routine orders candidates itself.
        let candidates: Vec<(usize, f32)> = current_conns
            .iter()
            .filter_map(|&nb| {
                self.calculate_distance_from_slice(&query_vec, nb)
                    .ok()
                    .map(|d| (nb, d))
            })
            .collect();

        // Neither candidate extension nor pruned refill is used when pruning.
        let selected = self.select_neighbors_heuristic_v2(
            &candidates,
            &query_vec,
            max_connections,
            level,
            false,
            false,
        )?;

        let selected_set: HashSet<usize> = selected.iter().map(|&(id, _)| id).collect();
        let removed: Vec<usize> = current_conns
            .iter()
            .copied()
            .filter(|id| !selected_set.contains(id))
            .collect();

        if let Some(node) = self.nodes_mut().get_mut(node_id) {
            if let Some(conns) = node.get_connections_mut(level) {
                *conns = selected_set;
            }
        }
        // Drop the reverse link from every neighbor that was pruned away.
        for pruned_id in removed {
            if let Some(pruned_node) = self.nodes_mut().get_mut(pruned_id) {
                pruned_node.remove_connection(level, node_id);
            }
        }
        // Add the reverse link on every neighbor that was kept.
        for &(selected_id, _) in &selected {
            if let Some(selected_node) = self.nodes_mut().get_mut(selected_id) {
                selected_node.add_connection(level, node_id);
            }
        }
        Ok(())
    }

    /// Distance between an ad-hoc query slice and the stored vector of
    /// `node_id`, using the index's configured metric.
    ///
    /// # Errors
    ///
    /// Fails when `node_id` is not present in the index, or when the metric
    /// itself reports an error.
    fn calculate_distance_from_slice(&self, query: &[f32], node_id: usize) -> Result<f32> {
        match self.nodes().get(node_id) {
            Some(target) => {
                let probe = crate::Vector::new(query.to_vec());
                self.config().metric.distance(&probe, &target.vector)
            }
            None => Err(anyhow::anyhow!("Node {} not found", node_id)),
        }
    }

    /// Distance between the stored vectors of nodes `a` and `b`.
    ///
    /// Uses the index's configured metric, the same one used for query-to-node
    /// distances, so that the values are comparable with candidate distances
    /// in the neighbor selection heuristic.
    ///
    /// Returns `None` when either node is missing or the metric fails.
    pub fn calculate_distance_between_nodes_pub(&self, a: usize, b: usize) -> Option<f32> {
        let nodes = self.nodes();
        let node_a = nodes.get(a)?;
        let node_b = nodes.get(b)?;
        self.config()
            .metric
            .distance(&node_a.vector, &node_b.vector)
            .ok()
    }
}
/// Heap entry for neighbor selection.
///
/// The ordering is reversed on `dist` so that `std::collections::BinaryHeap`
/// (a max-heap) pops the candidate with the smallest distance first. Entries
/// with equal distance are ordered by `id`, the larger id ranking higher.
///
/// A NaN distance ranks below every real distance, so such an entry is popped
/// last. This keeps the ordering total, which `BinaryHeap` requires.
#[derive(Debug, Clone, Copy)]
struct RevCandidate {
    id: usize,
    dist: f32,
}

// Equality is derived from `cmp` so that `Eq` and `Ord` always agree, including
// for NaN distances where `f32`'s own `==` is not reflexive.
impl PartialEq for RevCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for RevCandidate {}

impl PartialOrd for RevCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for RevCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_distance = match (self.dist.is_nan(), other.dist.is_nan()) {
            (true, true) => Ordering::Equal,
            // NaN is the worst possible candidate.
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Reversed operands: the smaller distance is the "greater" entry.
            // `partial_cmp` cannot return `None` here; the fallback only
            // avoids a panic path.
            (false, false) => other
                .dist
                .partial_cmp(&self.dist)
                .unwrap_or(Ordering::Equal),
        };
        by_distance.then_with(|| self.id.cmp(&other.id))
    }
}
#[cfg(test)]
mod tests {
    use crate::hnsw::{HnswConfig, HnswIndex};
    use crate::{Vector, VectorIndex};

    /// Builds a CPU-only index with the default configuration and inserts the
    /// given `(uri, vector)` pairs in order.
    fn make_index(pairs: &[(&str, Vec<f32>)]) -> HnswIndex {
        let mut index = HnswIndex::new_cpu_only(HnswConfig::default());
        for (uri, data) in pairs.iter() {
            index
                .insert((*uri).to_string(), Vector::new(data.to_vec()))
                .expect("insert failed");
        }
        index
    }

    /// The result never holds more entries than the requested `m`.
    #[test]
    fn test_heuristic_returns_at_most_m() {
        let fixture = [
            ("a", vec![1.0, 0.0, 0.0]),
            ("b", vec![0.9, 0.1, 0.0]),
            ("c", vec![0.8, 0.2, 0.0]),
            ("d", vec![0.0, 1.0, 0.0]),
            ("e", vec![0.0, 0.0, 1.0]),
        ];
        let index = make_index(&fixture);

        let candidates: Vec<(usize, f32)> = (0..5).map(|i| (i, i as f32 * 0.1)).collect();
        let query = vec![1.0f32, 0.0, 0.0];

        let picked = index
            .select_neighbors_heuristic_v2(&candidates, &query, 3, 0, false, false)
            .expect("heuristic failed");

        let count = picked.len();
        assert!(count <= 3, "Got {} neighbors, expected <= 3", count);
    }

    /// The result is ordered by ascending distance.
    #[test]
    fn test_heuristic_sorted_by_distance() {
        let fixture = [
            ("a", vec![1.0, 0.0]),
            ("b", vec![0.7, 0.7]),
            ("c", vec![0.0, 1.0]),
            ("d", vec![-1.0, 0.0]),
        ];
        let index = make_index(&fixture);

        let candidates: Vec<(usize, f32)> = vec![(0, 0.1), (1, 0.3), (2, 0.8), (3, 1.9)];
        let query = vec![1.0f32, 0.0];

        let picked = index
            .select_neighbors_heuristic_v2(&candidates, &query, 4, 0, false, false)
            .expect("heuristic failed");

        for pair in picked.windows(2) {
            let (earlier, later) = (pair[0].1, pair[1].1);
            assert!(
                earlier <= later,
                "Results not sorted: {} > {}",
                earlier,
                later
            );
        }
    }

    /// With near-duplicate vectors the closest candidate still survives.
    #[test]
    fn test_heuristic_prunes_shadowed_candidates() {
        let fixture = [
            ("a", vec![1.0, 0.0, 0.0]),
            ("b", vec![0.9999, 0.001, 0.0]),
            ("c", vec![0.9998, 0.002, 0.0]),
            ("d", vec![0.0, 0.0, 1.0]),
        ];
        let index = make_index(&fixture);

        let candidates: Vec<(usize, f32)> = vec![
            (0, 0.001),
            (1, 0.002),
            (2, 0.003),
            (3, 0.9),
        ];
        let query = vec![1.0f32, 0.0, 0.0];

        let picked = index
            .select_neighbors_heuristic_v2(&candidates, &query, 2, 1, false, false)
            .expect("heuristic failed");

        assert!(picked.len() <= 2);
        let has_closest = picked.iter().map(|entry| entry.0).any(|id| id == 0);
        assert!(has_closest, "Closest node must always be selected");
    }

    /// Enabling `keep_pruned` never yields fewer neighbors than disabling it.
    #[test]
    fn test_keep_pruned_fills_slots() {
        let fixture = [
            ("a", vec![1.0, 0.0]),
            ("b", vec![0.9999, 0.001]),
            ("c", vec![0.9998, 0.002]),
            ("d", vec![0.9997, 0.003]),
        ];
        let index = make_index(&fixture);

        let candidates: Vec<(usize, f32)> = vec![(0, 0.01), (1, 0.02), (2, 0.03), (3, 0.04)];
        let query = vec![1.0f32, 0.0];

        let without_keep = index
            .select_neighbors_heuristic_v2(&candidates, &query, 4, 1, false, false)
            .expect("heuristic failed");
        let with_keep = index
            .select_neighbors_heuristic_v2(&candidates, &query, 4, 1, false, true)
            .expect("heuristic with keep failed");

        let (kept, plain) = (with_keep.len(), without_keep.len());
        assert!(kept >= plain, "keep_pruned should fill more slots");
    }

    /// Enabling `extend_candidates` never yields fewer neighbors than
    /// disabling it.
    #[test]
    fn test_extend_candidates_discovers_more() {
        let fixture = [
            ("origin", vec![0.0, 0.0]),
            ("mid", vec![0.5, 0.5]),
            ("far", vec![1.0, 1.0]),
        ];
        let index = make_index(&fixture);

        let candidates = vec![(0, 0.0f32), (1, 0.5)];
        let query = vec![0.0f32, 0.0];

        let without_ext = index
            .select_neighbors_heuristic_v2(&candidates, &query, 5, 0, false, false)
            .expect("failed");
        let with_ext = index
            .select_neighbors_heuristic_v2(&candidates, &query, 5, 0, true, false)
            .expect("failed");

        let (extended, plain) = (with_ext.len(), without_ext.len());
        assert!(
            extended >= plain,
            "extend_candidates should not reduce result count"
        );
    }

    /// An empty candidate list produces an empty result instead of an error.
    #[test]
    fn test_empty_candidates() {
        let index = make_index(&[("a", vec![1.0, 0.0])]);

        let picked = index
            .select_neighbors_heuristic_v2(&[], &[1.0, 0.0], 5, 0, false, false)
            .expect("empty candidates should not fail");

        assert!(picked.is_empty());
    }

    /// After pruning, node 0 holds at most `m_l0` connections on layer 0.
    #[test]
    fn test_prune_connections_heuristic() {
        let config = HnswConfig {
            m: 2,
            m_l0: 4,
            ..HnswConfig::default()
        };
        let mut index = HnswIndex::new_cpu_only(config);

        // Ten points on the unit circle.
        for i in 0..10usize {
            let angle = std::f32::consts::PI * 2.0 * i as f32 / 10.0;
            let point = Vector::new(vec![angle.cos(), angle.sin()]);
            index
                .insert(format!("node_{}", i), point)
                .expect("insert failed");
        }

        index
            .prune_connections_heuristic(0, 0)
            .expect("prune failed");

        let connections_at_0 = index
            .nodes()
            .first()
            .and_then(|node| node.get_connections(0))
            .map_or(0, |conns| conns.len());
        assert!(
            connections_at_0 <= 4,
            "Expected <= 4 connections (m_l0), got {}",
            connections_at_0
        );
    }
}