use crate::simd;
use crate::RetrieveError;
use smallvec::SmallVec;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use super::graph::NSWIndex;
use super::search::greedy_search;
/// Builds the NSW neighbor lists for every vector in `index`, inserting one node at a time.
///
/// Any existing neighbor lists are discarded and node 0 becomes the entry point.
/// Each later node is linked to the first `params.m` candidates returned by
/// `greedy_search` over the graph built so far, and every link is mirrored on the
/// neighbor's side. A neighbor whose list grows beyond `params.m_max` is pruned
/// immediately.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors, and
/// propagates any error reported by `greedy_search`.
pub fn construct_graph(index: &mut NSWIndex) -> Result<(), RetrieveError> {
    if index.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    let n = index.num_vectors;
    let m = index.params.m;
    let m_max = index.params.m_max;
    // The search beam must be at least as wide as the number of links we want to pick.
    let ef_construction = index.params.ef_construction.max(m);

    index.neighbors = vec![SmallVec::new(); n];
    index.entry_point = Some(0);

    for current_id in 1..n {
        let entry_point = index.entry_point.unwrap_or(0);
        let candidates = greedy_search(
            index.get_vector(current_id),
            entry_point,
            &index.neighbors,
            &index.vectors,
            index.dimension,
            ef_construction,
        )?;

        // NOTE(review): node ids are stored as u32, so this assumes num_vectors
        // fits in u32 — confirm that the index enforces it on insertion.
        let current_u32 = current_id as u32;

        for neighbor_id in select_neighbors(&candidates, m) {
            let j = neighbor_id as usize;

            // Forward edge: current -> neighbor.
            if !index.neighbors[current_id].contains(&neighbor_id) {
                index.neighbors[current_id].push(neighbor_id);
            }
            // Mirrored edge: neighbor -> current.
            if !index.neighbors[j].contains(&current_u32) {
                index.neighbors[j].push(current_u32);
            }
            // Only the neighbor's list can outgrow the cap here, so prune that side.
            if index.neighbors[j].len() > m_max {
                prune_neighbors(
                    j,
                    m_max,
                    &index.vectors,
                    index.dimension,
                    &mut index.neighbors,
                );
            }
        }
    }

    Ok(())
}
/// Returns the ids of the first `m` entries of `candidates`, in their original order.
///
/// When fewer than `m` candidates are supplied, all of their ids are returned.
/// The distance component of each pair is ignored.
fn select_neighbors(candidates: &[(u32, f32)], m: usize) -> Vec<u32> {
    let keep = m.min(candidates.len());
    let mut chosen = Vec::with_capacity(keep);
    for entry in &candidates[..keep] {
        chosen.push(entry.0);
    }
    chosen
}
/// Builds the NSW neighbor lists for `index`, searching for candidates in parallel batches.
///
/// Construction runs in two phases:
///
/// 1. A sequential warm-up inserts the first `2 * ef_construction` nodes (or all of
///    them, if there are fewer) one at a time, pruning any neighbor list that grows
///    beyond `params.m_max` right away.
/// 2. The remaining nodes are processed in batches of `batch_size` (a value of 0 is
///    treated as 1). Within a batch every node is searched in parallel against the
///    graph as it stood before the batch; the resulting links are then merged
///    sequentially and every list longer than `params.m_max` is pruned in parallel.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors, and
/// propagates any error reported by `greedy_search` in either phase.
#[cfg(feature = "parallel")]
pub fn construct_graph_parallel(
    index: &mut NSWIndex,
    batch_size: usize,
) -> Result<(), RetrieveError> {
    if index.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    let n = index.num_vectors;
    let m = index.params.m;
    let m_max = index.params.m_max;
    // The search beam must be at least as wide as the number of links we want to pick.
    let ef_construction = index.params.ef_construction.max(m);
    let dim = index.dimension;

    index.neighbors = vec![SmallVec::new(); n];
    index.entry_point = Some(0);

    // Phase 1: sequential warm-up so that the parallel searches start from a
    // graph that already has some edges.
    let sequential_count = ef_construction.saturating_mul(2).min(n);
    for current_id in 1..sequential_count {
        let entry_point = index.entry_point.unwrap_or(0);
        let candidates = greedy_search(
            index.get_vector(current_id),
            entry_point,
            &index.neighbors,
            &index.vectors,
            dim,
            ef_construction,
        )?;

        // NOTE(review): node ids are stored as u32, so this assumes num_vectors
        // fits in u32 — confirm that the index enforces it on insertion.
        let current_u32 = current_id as u32;

        for neighbor_id in select_neighbors(&candidates, m) {
            let j = neighbor_id as usize;
            if !index.neighbors[current_id].contains(&neighbor_id) {
                index.neighbors[current_id].push(neighbor_id);
            }
            if !index.neighbors[j].contains(&current_u32) {
                index.neighbors[j].push(current_u32);
            }
            // Enforce the degree cap during warm-up as well; otherwise an index
            // that fits entirely in this phase would never be pruned.
            if index.neighbors[j].len() > m_max {
                prune_neighbors(j, m_max, &index.vectors, dim, &mut index.neighbors);
            }
        }
    }

    // Phase 2: batched insertion.
    let batch_sz = batch_size.max(1);
    for batch_start in (sequential_count..n).step_by(batch_sz) {
        // Saturate so that a huge `batch_size` cannot overflow the end bound.
        let batch_end = batch_start.saturating_add(batch_sz).min(n);
        let entry_point = index.entry_point.unwrap_or(0);

        // Search every node of the batch against the pre-batch graph. The first
        // search error aborts construction instead of silently leaving the node
        // without any links.
        let results: Vec<(usize, Vec<u32>)> = (batch_start..batch_end)
            .into_par_iter()
            .map(|current_id| {
                greedy_search(
                    index.get_vector(current_id),
                    entry_point,
                    &index.neighbors,
                    &index.vectors,
                    dim,
                    ef_construction,
                )
                .map(|candidates| (current_id, select_neighbors(&candidates, m)))
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Merge the links sequentially; both directions are recorded.
        for (current_id, selected) in &results {
            let current_u32 = *current_id as u32;
            for &neighbor_id in selected {
                let j = neighbor_id as usize;
                if !index.neighbors[*current_id].contains(&neighbor_id) {
                    index.neighbors[*current_id].push(neighbor_id);
                }
                if !index.neighbors[j].contains(&current_u32) {
                    index.neighbors[j].push(current_u32);
                }
            }
        }

        // Collect every node whose list now exceeds the cap and recompute its
        // list in parallel; the new lists are written back afterwards because
        // the parallel closure only holds shared access to the graph.
        let overweight: Vec<usize> = (0..n)
            .filter(|&id| index.neighbors[id].len() > m_max)
            .collect();
        if !overweight.is_empty() {
            let pruned: Vec<(usize, SmallVec<[u32; 16]>)> = overweight
                .par_iter()
                .map(|&node_id| {
                    let node_start = node_id * dim;
                    let node_vec = &index.vectors[node_start..node_start + dim];
                    let mut with_dist: Vec<(u32, f32)> = index.neighbors[node_id]
                        .iter()
                        .map(|&nb_id| {
                            let start = nb_id as usize * dim;
                            let nb_vec = &index.vectors[start..start + dim];
                            (nb_id, 1.0 - simd::dot(node_vec, nb_vec))
                        })
                        .collect();
                    // Keep the `m_max` neighbors with the smallest distance.
                    with_dist.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                    with_dist.truncate(m_max);
                    (node_id, with_dist.into_iter().map(|(id, _)| id).collect())
                })
                .collect();
            for (node_id, new_neighbors) in pruned {
                index.neighbors[node_id] = new_neighbors;
            }
        }
    }

    Ok(())
}
/// Shrinks `neighbors[node_id]` to at most `m_max` entries, keeping the closest ones.
///
/// Each neighbor is scored with `1.0 - simd::dot(node, neighbor)`; the list is
/// replaced by the ids with the smallest scores, ordered from smallest to largest.
/// `vectors` is read as a flat buffer with `dimension` floats per node id.
/// (Presumably the vectors are unit-normalized so the score is a cosine
/// distance — verify against the indexing code.)
fn prune_neighbors(
    node_id: usize,
    m_max: usize,
    vectors: &[f32],
    dimension: usize,
    neighbors: &mut [SmallVec<[u32; 16]>],
) {
    let anchor_start = node_id * dimension;
    let anchor = &vectors[anchor_start..anchor_start + dimension];

    // Score every current neighbor against the anchor node.
    let mut ranked: Vec<(u32, f32)> = Vec::with_capacity(neighbors[node_id].len());
    for &candidate in neighbors[node_id].iter() {
        let offset = candidate as usize * dimension;
        let candidate_vec = &vectors[offset..offset + dimension];
        ranked.push((candidate, 1.0 - simd::dot(anchor, candidate_vec)));
    }

    // Ascending by score; `total_cmp` gives a total order even if a score is NaN.
    ranked.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));

    neighbors[node_id] = ranked.iter().take(m_max).map(|&(id, _)| id).collect();
}