use crate::hnsw::construction::select_neighbors;
use crate::hnsw::distance as hnsw_distance;
use crate::hnsw::graph::NeighborhoodDiversification;
use crate::vamana::graph::VamanaIndex;
use crate::RetrieveError;
use rand::seq::IndexedRandom;
use smallvec::SmallVec;
/// Seeds every node of `index` with a random set of out-neighbors.
///
/// Each node receives `ceil(ln(num_vectors))` distinct neighbors, capped at the
/// number of other nodes, drawn without replacement and never including the
/// node itself. Any previous content of `index.neighbors[i]` is overwritten.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when `index.num_vectors` is zero.
fn initialize_random_graph(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    let num_vectors = index.num_vectors;
    if num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    let min_degree = (num_vectors as f64).ln().ceil() as usize;

    // One shared pool of "slots" 0..n-1 replaces the per-node filtered copy of
    // all ids. For node `i`, slot `s` stands for id `s` when `s < i` and for
    // id `s + 1` otherwise, which enumerates exactly the ids other than `i`
    // in ascending order -- the same sequence the filtered copy would hold.
    // NOTE(review): neighbor ids are stored as u32, so this cast assumes
    // num_vectors fits in u32 -- confirm the index enforces that elsewhere.
    let slots: Vec<u32> = (0..(num_vectors - 1) as u32).collect();
    let degree = min_degree.min(slots.len());

    let mut rng = rand::rng();
    for i in 0..num_vectors {
        let own_id = i as u32;
        let neighbors: SmallVec<[u32; 16]> = slots
            .choose_multiple(&mut rng, degree)
            .map(|&slot| if slot < own_id { slot } else { slot + 1 })
            .collect();
        index.neighbors[i] = neighbors;
    }
    Ok(())
}
/// Gathers the candidate pool used to re-select the neighbors of `current_id`.
///
/// The pool starts with the node's current neighbors (always included, even if
/// they exceed the budget) and is then grown breadth-first through
/// `index.neighbors` until it holds `index.params.ef_construction` entries or
/// no undiscovered node is reachable. Each entry pairs a node id with its
/// cosine distance to the current node's vector.
///
/// Nodes are marked visited at discovery time, so every id appears in the pool
/// at most once and `current_id` itself never appears.
#[cfg(feature = "vamana")]
fn collect_refinement_candidates(index: &VamanaIndex, current_id: usize) -> Vec<(u32, f32)> {
    use std::collections::{HashSet, VecDeque};

    let budget = index.params.ef_construction;
    let current_vector = index.get_vector(current_id);

    let mut candidates: Vec<(u32, f32)> = Vec::with_capacity(budget);
    let mut visited: HashSet<u32> = HashSet::with_capacity(budget);
    let mut frontier: VecDeque<u32> = VecDeque::new();
    visited.insert(current_id as u32);

    // Seed with the direct neighbors. `insert` returning false filters out
    // self-loops and repeated ids in the adjacency list.
    for &neighbor_id in &index.neighbors[current_id] {
        if visited.insert(neighbor_id) {
            let neighbor_vec = index.get_vector(neighbor_id as usize);
            let dist = hnsw_distance::cosine_distance(current_vector, neighbor_vec);
            candidates.push((neighbor_id, dist));
            frontier.push_back(neighbor_id);
        }
    }

    // Breadth-first expansion: FIFO order keeps closer hops ahead of farther
    // ones, and the budget check stops the walk as soon as the pool is full.
    'expand: while let Some(explore_id) = frontier.pop_front() {
        for &neighbor_id in &index.neighbors[explore_id as usize] {
            if candidates.len() >= budget {
                break 'expand;
            }
            if visited.insert(neighbor_id) {
                let neighbor_vec = index.get_vector(neighbor_id as usize);
                let dist = hnsw_distance::cosine_distance(current_vector, neighbor_vec);
                candidates.push((neighbor_id, dist));
                frontier.push_back(neighbor_id);
            }
        }
    }

    candidates
}

/// Runs one refinement pass over all nodes using the given diversification
/// strategy.
///
/// Nodes are processed in id order. For each one, a candidate pool is gathered
/// and handed to `select_neighbors` together with `index.params.max_degree`;
/// the returned ids replace the node's adjacency list immediately, so later
/// nodes in the same pass traverse the already-refined lists of earlier ones.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when `index.num_vectors` is zero.
#[cfg(feature = "vamana")]
fn refine_neighbors(
    index: &mut VamanaIndex,
    diversification: &NeighborhoodDiversification,
) -> Result<(), RetrieveError> {
    if index.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    for current_id in 0..index.num_vectors {
        let candidates = collect_refinement_candidates(&*index, current_id);
        let selected = select_neighbors(
            index.get_vector(current_id),
            &candidates,
            index.params.max_degree,
            &index.vectors,
            index.dimension,
            diversification,
        );
        index.neighbors[current_id] = SmallVec::from_vec(selected);
    }
    Ok(())
}

/// Refinement pass using [`NeighborhoodDiversification::RelaxedRelative`] with
/// `alpha` taken from `index.params.alpha`.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when `index.num_vectors` is zero.
#[cfg(feature = "vamana")]
fn refine_with_rrnd(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    let strategy = NeighborhoodDiversification::RelaxedRelative {
        alpha: index.params.alpha,
    };
    refine_neighbors(index, &strategy)
}

/// Refinement pass using [`NeighborhoodDiversification::RelativeNeighborhood`].
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when `index.num_vectors` is zero.
#[cfg(feature = "vamana")]
fn refine_with_rnd(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    refine_neighbors(index, &NeighborhoodDiversification::RelativeNeighborhood)
}
/// Builds the graph for `index` in three sequential stages.
///
/// 1. `initialize_random_graph` assigns every node a random neighbor list.
/// 2. `refine_with_rrnd` re-selects neighbors with the relaxed-relative
///    strategy.
/// 3. `refine_with_rnd` re-selects neighbors with the relative-neighborhood
///    strategy.
///
/// The first stage that fails aborts the build and its error is returned
/// unchanged.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when `index.num_vectors` is zero, or
/// whatever error one of the stages reports.
pub fn construct_graph(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    match index.num_vectors {
        0 => Err(RetrieveError::EmptyIndex),
        _ => {
            initialize_random_graph(index)?;
            refine_with_rrnd(index)?;
            refine_with_rnd(index)
        }
    }
}