use crate::distance as hnsw_distance;
use crate::hnsw::construction::select_neighbors;
use crate::hnsw::graph::NeighborhoodDiversification;
use crate::vamana::graph::VamanaIndex;
use crate::RetrieveError;
use rand::seq::IndexedRandom;
use smallvec::SmallVec;
/// Fills every node's neighbor list with randomly chosen other nodes.
///
/// Each node receives `min(ceil(ln(num_vectors)), num_vectors - 1)` distinct
/// neighbors, none of which is the node itself. Existing neighbor lists for
/// ids `0..num_vectors` are overwritten.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors.
fn initialize_random_graph(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    let n = index.num_vectors;
    if n == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    let min_degree = (n as f64).ln().ceil() as usize;
    // A node can have at most `n - 1` neighbors other than itself.
    let degree = min_degree.min(n - 1);
    // Draw one spare id so that the node's own id can be discarded when it
    // happens to be sampled, without building a per-node "everyone but me"
    // list (which made the whole pass quadratic in `n`).
    let draw = (degree + 1).min(n);

    let mut rng = rand::rng();
    let all_ids: Vec<u32> = (0..n as u32).collect();

    for i in 0..n {
        let self_id = i as u32;
        // `choose_multiple` yields distinct ids in random order, so dropping
        // `self_id` and keeping the first `degree` survivors is still a
        // uniform sample of the other nodes.
        let neighbors: SmallVec<[u32; 16]> = all_ids
            .choose_multiple(&mut rng, draw)
            .copied()
            .filter(|&id| id != self_id)
            .take(degree)
            .collect();
        index.neighbors[i] = neighbors;
    }

    Ok(())
}
/// Collects scored neighbor candidates for `current_id` by walking the
/// current graph breadth-first from its neighbor list.
///
/// Every direct neighbor is always included; further nodes are added in
/// breadth-first order until `index.params.ef_construction` candidates have
/// been gathered or the reachable graph is exhausted. Each returned pair is
/// `(node id, cosine_distance_normalized(current vector, node vector))`.
/// The node itself is never a candidate and no id appears twice.
#[cfg(feature = "vamana")]
fn gather_candidates(index: &VamanaIndex, current_id: usize) -> Vec<(u32, f32)> {
    use std::collections::{HashSet, VecDeque};

    let budget = index.params.ef_construction;
    let current_vector = index.get_vector(current_id);
    let score = |id: u32| {
        hnsw_distance::cosine_distance_normalized(current_vector, index.get_vector(id as usize))
    };

    // Nodes are marked when they are first discovered (not when they are
    // dequeued), so each one is scored and enqueued exactly once.
    let mut visited: HashSet<u32> = HashSet::with_capacity(budget);
    visited.insert(current_id as u32);

    let mut candidates: Vec<(u32, f32)> = Vec::with_capacity(budget);
    let mut frontier: VecDeque<u32> = VecDeque::new();

    // Seed with the direct neighbors regardless of the budget.
    for &neighbor_id in &index.neighbors[current_id] {
        if visited.insert(neighbor_id) {
            candidates.push((neighbor_id, score(neighbor_id)));
            frontier.push_back(neighbor_id);
        }
    }

    // Expand outwards until the candidate budget is used up.
    'expand: while let Some(node) = frontier.pop_front() {
        for &next_id in &index.neighbors[node as usize] {
            if candidates.len() >= budget {
                break 'expand;
            }
            if visited.insert(next_id) {
                candidates.push((next_id, score(next_id)));
                frontier.push_back(next_id);
            }
        }
    }

    candidates
}

/// Adds `current_id` to the neighbor list of every node in `selected` that
/// does not already contain it.
///
/// When an insertion pushes a list past `index.params.max_degree`, the list
/// is re-ranked by `cosine_distance_normalized` to its owner's vector and
/// only the `max_degree` closest entries are kept.
#[cfg(feature = "vamana")]
fn link_reverse_edges(index: &mut VamanaIndex, current_id: u32, selected: &[u32]) {
    let dim = index.dimension;
    let max_deg = index.params.max_degree;

    for &neighbor_id in selected {
        // `neighbors` and `vectors` are borrowed as separate fields so the
        // mutable list and the read-only vector storage can coexist.
        let reverse = &mut index.neighbors[neighbor_id as usize];
        if reverse.contains(&current_id) {
            continue;
        }
        reverse.push(current_id);
        if reverse.len() <= max_deg {
            continue;
        }

        let nstart = neighbor_id as usize * dim;
        let node_vec = &index.vectors[nstart..nstart + dim];
        let mut scored: Vec<(u32, f32)> = reverse
            .iter()
            .map(|&nid| {
                let s = nid as usize * dim;
                let v = &index.vectors[s..s + dim];
                (nid, hnsw_distance::cosine_distance_normalized(node_vec, v))
            })
            .collect();
        scored.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        scored.truncate(max_deg);
        *reverse = scored.iter().map(|(id, _)| *id).collect();
    }
}

/// Runs one refinement pass over every node, in id order.
///
/// For each node: gather candidates from the current graph, let
/// `select_neighbors` pick at most `index.params.max_degree` of them using
/// `diversification`, replace the node's neighbor list with that selection,
/// and then insert the matching reverse edges.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors.
#[cfg(feature = "vamana")]
fn refine_pass(
    index: &mut VamanaIndex,
    diversification: &NeighborhoodDiversification,
) -> Result<(), RetrieveError> {
    if index.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    for current_id in 0..index.num_vectors {
        let candidates = gather_candidates(index, current_id);
        let selected = select_neighbors(
            index.get_vector(current_id),
            &candidates,
            index.params.max_degree,
            &index.vectors,
            index.dimension,
            diversification,
        );
        // Own list first, reverse edges second (same order as before).
        index.neighbors[current_id] = SmallVec::from_slice(&selected);
        link_reverse_edges(index, current_id as u32, &selected);
    }

    Ok(())
}

/// Refines the graph using the `RelaxedRelative` diversification strategy,
/// with `alpha` taken from `index.params.alpha`.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors.
#[cfg(feature = "vamana")]
fn refine_with_rrnd(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    let strategy = NeighborhoodDiversification::RelaxedRelative {
        alpha: index.params.alpha,
    };
    refine_pass(index, &strategy)
}

/// Refines the graph using the `RelativeNeighborhood` diversification
/// strategy.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors.
#[cfg(feature = "vamana")]
fn refine_with_rnd(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    refine_pass(index, &NeighborhoodDiversification::RelativeNeighborhood)
}
/// Returns the id of the stored vector closest to the normalized mean of all
/// vectors, measured with `cosine_distance_normalized`.
///
/// Ties are resolved in favor of the lowest id. Returns `0` when the index is
/// empty or when no distance compares below infinity (for example, when every
/// distance is NaN).
// NOTE(review): `cosine_distance_normalized` presumably expects unit-length
// inputs; the centroid is normalized here, stored vectors are assumed to be
// normalized already — confirm against the index insertion path.
fn compute_medoid(index: &VamanaIndex) -> u32 {
    let n = index.num_vectors;
    if n == 0 {
        // Nothing to average; avoid the division by zero below.
        return 0;
    }
    let dim = index.dimension;

    // Component-wise mean of all stored vectors.
    let mut centroid = vec![0.0_f32; dim];
    for i in 0..n {
        let vec = index.get_vector(i);
        for (c, &v) in centroid.iter_mut().zip(vec.iter()) {
            *c += v;
        }
    }
    let inv_n = 1.0 / n as f32;
    for c in centroid.iter_mut() {
        *c *= inv_n;
    }
    centroid = crate::distance::normalize(&centroid);

    // Linear scan for the closest vector; the strict `<` keeps the first
    // minimum and never lets a NaN distance win.
    let mut best_id: u32 = 0;
    let mut best_dist = f32::INFINITY;
    for i in 0..n {
        let vec = index.get_vector(i);
        let dist = hnsw_distance::cosine_distance_normalized(&centroid, vec);
        if dist < best_dist {
            best_dist = dist;
            best_id = i as u32;
        }
    }
    best_id
}
/// Builds the graph for `index` in place.
///
/// The steps run in this order: store the medoid id in `index.medoid`, fill
/// the neighbor lists with random edges, run the `refine_with_rnd` pass, then
/// run the `refine_with_rrnd` pass.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] when the index holds no vectors, and
/// forwards any error produced by the individual steps.
pub fn construct_graph(index: &mut VamanaIndex) -> Result<(), RetrieveError> {
    match index.num_vectors {
        0 => Err(RetrieveError::EmptyIndex),
        _ => {
            let entry_point = compute_medoid(index);
            index.medoid = entry_point;
            initialize_random_graph(index)?;
            refine_with_rnd(index)?;
            refine_with_rrnd(index)
        }
    }
}