use crate::simd;
use crate::RetrieveError;
use smallvec::SmallVec;
use super::graph::NSWIndex;
use super::search::greedy_search;
/// Builds the neighbor lists of `index` by inserting its vectors one at a time.
///
/// Node `0` is installed as the entry point. Every later node is then linked
/// into the graph: `greedy_search` is run from the entry point using the new
/// node's vector, the first `params.m` returned candidates are chosen via
/// `select_neighbors`, and an edge is recorded in both directions. Whenever a
/// neighbor's list grows past `params.m_max` it is cut back with
/// `prune_neighbors`.
///
/// Any previous contents of `index.neighbors` are discarded.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors, and
/// propagates any error reported by `greedy_search`.
pub fn construct_graph(index: &mut NSWIndex) -> Result<(), RetrieveError> {
    if index.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    let n = index.num_vectors;
    let m = index.params.m;
    let m_max = index.params.m_max;
    // The value handed to the search is never smaller than the number of
    // links we intend to select from its result.
    let ef_construction = index.params.ef_construction.max(m);

    index.neighbors = vec![SmallVec::new(); n];
    index.entry_point = Some(0);

    for current_id in 1..n {
        // Neighbor ids are stored as `u32`.
        // NOTE(review): this cast assumes `num_vectors` fits in `u32`; confirm
        // that the index enforces this where vectors are added.
        let current_u32 = current_id as u32;

        let entry_point = index.entry_point.unwrap_or(0);
        let candidates = greedy_search(
            index.get_vector(current_id),
            entry_point,
            &index.neighbors,
            &index.vectors,
            index.dimension,
            ef_construction,
        )?;
        let selected = select_neighbors(&candidates, m);

        for &neighbor_id in &selected {
            let j = neighbor_id as usize;

            // Forward edge: current -> neighbor (skip duplicates).
            if !index.neighbors[current_id].contains(&neighbor_id) {
                index.neighbors[current_id].push(neighbor_id);
            }

            // Reverse edge: neighbor -> current (skip duplicates).
            if !index.neighbors[j].contains(&current_u32) {
                index.neighbors[j].push(current_u32);
            }

            // Keep the neighbor's degree bounded. Pruning may drop the reverse
            // edge that was just added; the forward edge is kept regardless.
            if index.neighbors[j].len() > m_max {
                prune_neighbors(
                    j,
                    m_max,
                    &index.vectors,
                    index.dimension,
                    &mut index.neighbors,
                );
            }
        }
    }

    Ok(())
}
/// Returns the ids of the first `m` entries of `candidates`, in their
/// existing order.
///
/// If fewer than `m` candidates are supplied, all of them are returned. The
/// `f32` component of each pair is ignored; no re-sorting is performed here.
fn select_neighbors(candidates: &[(u32, f32)], m: usize) -> Vec<u32> {
    let keep = m.min(candidates.len());
    let mut ids = Vec::with_capacity(keep);
    for &(id, _) in &candidates[..keep] {
        ids.push(id);
    }
    ids
}
fn prune_neighbors(
node_id: usize,
m_max: usize,
vectors: &[f32],
dimension: usize,
neighbors: &mut [SmallVec<[u32; 16]>],
) {
let node_vec_start = node_id * dimension;
let node_vec = &vectors[node_vec_start..node_vec_start + dimension];
let mut with_dist: Vec<(u32, f32)> = neighbors[node_id]
.iter()
.map(|&nb_id| {
let start = nb_id as usize * dimension;
let nb_vec = &vectors[start..start + dimension];
let dist = 1.0 - simd::dot(node_vec, nb_vec);
(nb_id, dist)
})
.collect();
with_dist.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
with_dist.truncate(m_max);
neighbors[node_id] = with_dist.into_iter().map(|(id, _)| id).collect();
}