use std::marker::PhantomData;
use indicatif::{ProgressBar, ProgressStyle};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use crate::graph::{GraphTrait, GrowableGraph};
use crate::graph_index::GraphIndex;
use crate::quantizer::{IdentityQuantizer, Quantizer, QueryEvaluator};
use crate::DotProduct;
use crate::EuclideanDistance;
use crate::{hnsw_utils::*, Dataset, DistanceType, Float, GrowableDataset};
/// A serializable HNSW (Hierarchical Navigable Small World) index over a
/// quantized dataset `D`, quantizer `Q` and per-level graph representation `G`.
#[derive(Serialize, Deserialize)]
pub struct HNSW<D, Q, G> {
    // Level graphs stored top-down: the last entry is the ground graph
    // (level 0) that contains every vector; earlier entries are the
    // sparser upper levels.
    levels: Box<[G]>,
    // Maps local node ids of level 1 to global ids in the ground level.
    // Empty when the index has a single level.
    level1_to_level0_mapping: Box<[usize]>,
    // Quantized vector storage backing all distance computations.
    dataset: D,
    // `M`: max neighbors per node on upper levels (the ground level is
    // built with 2*M).
    num_neighbors_per_vec: usize,
    // Entry node id used to seed the greedy descent at the topmost level.
    entry_point: usize,
    // `Q` only parameterizes behavior; no quantizer value is stored here.
    _phantom: PhantomData<Q>,
}
/// Construction-time parameters for the HNSW builder.
pub struct HNSWBuildParams {
    /// Maximum number of neighbors per node on the upper levels
    /// (the ground level uses twice this value).
    pub num_neighbors_per_vec: usize,
    /// Beam width used while searching for neighbors during insertion.
    pub ef_construction: usize,
    /// Batch size used for the first round of parallel insertion.
    pub initial_build_batch_size: usize,
    /// Upper bound the batch size doubles towards during parallel insertion.
    pub max_build_batch_size: usize,
}

impl HNSWBuildParams {
    /// Creates a fully specified parameter set.
    #[must_use]
    pub fn new(
        num_neighbors_per_vec: usize,
        ef_construction: usize,
        initial_build_batch_size: usize,
        max_build_batch_size: usize,
    ) -> Self {
        Self {
            num_neighbors_per_vec,
            ef_construction,
            initial_build_batch_size,
            max_build_batch_size,
        }
    }
}

impl Default for HNSWBuildParams {
    /// Defaults: M = 16, ef_construction = 150, batch sizes 4..=320.
    fn default() -> Self {
        Self::new(16, 150, 4, 320)
    }
}
/// Query-time parameters for the HNSW index.
pub struct HNSWSearchParams {
    /// Beam width used during the ground-level search.
    pub ef_search: usize,
}

impl HNSWSearchParams {
    /// Creates search parameters with the given beam width.
    #[must_use]
    pub fn new(ef_search: usize) -> Self {
        Self { ef_search }
    }
}

impl Default for HNSWSearchParams {
    /// Defaults to `ef_search = 100`.
    fn default() -> Self {
        Self::new(100)
    }
}
// NOTE(review): this impl block is empty — presumably a leftover placeholder
// for methods that need the `GrowableDataset` bound; confirm whether it can
// be removed.
impl<D, Q, G> HNSW<D, Q, G>
where
    D: Dataset<Q> + GrowableDataset<Q>,
    Q: Quantizer<DatasetType = D>,
    G: GraphTrait,
{
}
impl<D, Q, G> HNSW<D, Q, G>
where
    D: Dataset<Q> + Sync,
    Q: Quantizer<InputItem: Float, DatasetType = D> + Sync,
    G: GraphTrait,
{
    /// Index of the highest level (0 when the index holds no levels at all).
    #[must_use]
    #[inline]
    pub fn max_level(&self) -> usize {
        // `saturating_sub` collapses the empty case to 0, exactly like the
        // explicit if/else it replaces.
        self.levels.len().saturating_sub(1)
    }

    /// Number of nodes in each level, in storage order (top level first,
    /// ground level last).
    #[must_use]
    pub fn nodes_per_level(&self) -> Vec<usize> {
        let mut counts = Vec::with_capacity(self.levels.len());
        for graph in self.levels.iter() {
            counts.push(graph.n_nodes());
        }
        counts
    }
}
impl<D, Q, G> GraphIndex<D, Q, G> for HNSW<D, Q, G>
where
D: Dataset<Q> + GrowableDataset<Q> + Sync,
Q: Quantizer<DatasetType = D>,
Q: Quantizer<InputItem: Float, DatasetType = D> + Sync,
G: GraphTrait,
{
type BuildParams = HNSWBuildParams;
type SearchParams = HNSWSearchParams;
#[inline]
fn n_vectors(&self) -> usize {
self.dataset.len()
}
#[inline]
fn dim(&self) -> usize {
self.dataset.dim()
}
fn print_space_usage_bytes(&self) {
let dataset_size = self.dataset.get_space_usage_bytes();
let quantizer_size = self.dataset.quantizer().get_space_usage_bytes();
let index_size = self
.levels
.iter()
.map(|g| g.get_space_usage_bytes())
.sum::<usize>();
let total_size = dataset_size + quantizer_size + index_size;
println!(
"[######] Space usage: Dataset: {dataset_size} bytes, Quantizer: {quantizer_size} bytes, Index: {index_size} bytes, Total: {total_size} bytes"
);
}
fn search<'a, QD, QQ>(
&'a self,
query: QD::DataType<'a>,
k: usize,
search_params: &Self::SearchParams,
) -> Vec<(f32, usize)>
where
QD: Dataset<QQ> + Sync + 'a,
QQ: Quantizer<DatasetType = QD> + Sync + 'a,
<Q as Quantizer>::Evaluator<'a>:
QueryEvaluator<'a, QueryType = <QD as Dataset<QQ>>::DataType<'a>>,
<Q as Quantizer>::InputItem: EuclideanDistance<<Q as Quantizer>::InputItem>
+ DotProduct<<Q as Quantizer>::InputItem>,
<Q as Quantizer>::InputItem: 'a,
{
let query_eval = self.dataset.query_evaluator(query);
let num_levels = self.levels.len();
let mut entry_node = Candidate(f32::MAX, self.entry_point);
if num_levels > 1 {
for level_graph in &self.levels[..num_levels - 1] {
entry_node =
level_graph.greedy_search_nearest(&self.dataset, &query_eval, entry_node);
}
}
let ground_graph = &self.levels[num_levels - 1];
let entry_global_id = if num_levels > 1 {
self.level1_to_level0_mapping[entry_node.id_vec()]
} else {
self.entry_point
};
let ground_entry_node = Candidate(entry_node.distance(), entry_global_id);
let mut topk = ground_graph.greedy_search_topk(
&self.dataset,
ground_entry_node,
&query_eval,
k,
search_params.ef_search,
);
if self.dataset.quantizer().distance() == DistanceType::DotProduct {
topk.iter_mut().for_each(|(dis, _)| *dis = -(*dis));
}
topk
}
fn build_from_dataset<'a, BD, IQ>(
source_dataset: &'a BD,
quantizer: Q,
build_params: &Self::BuildParams,
) -> Self
where
BD: Dataset<IQ> + Sync + 'a,
IQ: IdentityQuantizer<DatasetType = BD, T: Float> + Sync + 'a,
<IQ as Quantizer>::Evaluator<'a>:
QueryEvaluator<'a, QueryType = <BD as Dataset<IQ>>::DataType<'a>>,
D: GrowableDataset<Q, InputDataType<'a> = <BD as Dataset<IQ>>::DataType<'a>>,
<Q as Quantizer>::InputItem: 'a + Float,
{
let num_vectors = source_dataset.len();
let m = build_params.num_neighbors_per_vec;
let default_probabs =
compute_levels_probabilities(1.0 / (m as f32).ln(), num_vectors as f32);
let (levels_mapping, ids_sorted_by_level, cumulative_ids_per_level, max_level) =
compute_levels(&default_probabs, num_vectors);
let mut growable_levels: Vec<GrowableGraph> = Vec::with_capacity(max_level as usize + 1);
for i in (1..=max_level).rev() {
let mut graph = GrowableGraph::with_max_degree(m);
let num_nodes_in_level = levels_mapping[i as usize - 1].len();
graph.reserve(num_nodes_in_level);
graph.set_mapping(levels_mapping[i as usize - 1].clone());
growable_levels.push(graph);
}
let mut ground_graph = GrowableGraph::with_max_degree(2 * m);
ground_graph.reserve(num_vectors);
growable_levels.push(ground_graph);
let level1_to_level0_mapping = if max_level > 0 {
levels_mapping[0].clone()
} else {
Vec::new()
};
let entry_point_local_id = 0;
let entry_point_global_id = ids_sorted_by_level[0];
let pb = ProgressBar::new(num_vectors as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta}) - Building HNSW")
.unwrap()
.progress_chars("#>-"),
);
Self::insert_entry_point(&mut growable_levels, entry_point_global_id, max_level, &pb);
for level in (0..=max_level).rev() {
let start_index = cumulative_ids_per_level[max_level as usize - level as usize];
let start_index = if start_index == 0 { 1 } else { start_index };
let end_index = cumulative_ids_per_level[max_level as usize - level as usize + 1];
if start_index >= end_index {
continue;
}
let nodes_to_insert_slice = &ids_sorted_by_level[start_index..end_index];
if nodes_to_insert_slice.len() > 2 * build_params.max_build_batch_size {
Self::process_level_parallelly(
nodes_to_insert_slice,
level,
max_level,
m,
&mut growable_levels,
source_dataset,
build_params,
entry_point_local_id,
&level1_to_level0_mapping,
&ids_sorted_by_level,
&pb,
);
} else {
Self::process_level_sequentially(
nodes_to_insert_slice,
level,
max_level,
m,
&mut growable_levels,
source_dataset,
build_params,
entry_point_local_id,
&level1_to_level0_mapping,
&ids_sorted_by_level,
&pb,
);
}
}
pb.finish_with_message("HNSW build complete.");
let final_levels: Vec<G> = growable_levels
.into_iter()
.map(|g| G::from_growable_graph(&g))
.collect();
let mut dataset = D::new(quantizer, source_dataset.dim());
for id in 0..source_dataset.len() {
dataset.push(&source_dataset.get(id));
}
Self {
levels: final_levels.into_boxed_slice(),
level1_to_level0_mapping: level1_to_level0_mapping.into_boxed_slice(),
dataset,
num_neighbors_per_vec: m,
entry_point: entry_point_local_id,
_phantom: PhantomData,
}
}
}
/// Probability of a node being assigned to each HNSW level, truncated at the
/// first level whose probability drops below `1 / dataset_len` (a level that
/// is expected to hold fewer than one node).
#[must_use]
fn compute_levels_probabilities(level_mult: f32, dataset_len: f32) -> Vec<f32> {
    let cutoff = 1.0 / dataset_len;
    // Geometric decay: P(level) = e^(-level / mult) * (1 - e^(-1 / mult)).
    let scale = 1.0 - (-1.0 / level_mult).exp();
    (0..)
        .map(|level| (-(level as f32) / level_mult).exp() * scale)
        .take_while(|&proba| proba >= cutoff)
        .collect()
}
/// Samples a level for a new node from the per-level probability table.
///
/// Draws one uniform sample in `[0, 1)` and walks the table, subtracting
/// each level's probability until the residual falls inside a bucket.
/// The residual mass beyond the table lands on the last level.
#[must_use]
#[inline]
fn random_level(probabs_levels: &[f32], rng: &mut StdRng) -> u8 {
    let mut f: f32 = rng.gen_range(0.0..1.0);
    for (level, &prob) in probabs_levels.iter().enumerate() {
        if f < prob {
            return level as u8;
        }
        f -= prob;
    }
    // `saturating_sub` makes an empty table yield level 0 instead of
    // underflowing `usize` (panic in debug, wrap to 255 via `as u8` in
    // release builds).
    probabs_levels.len().saturating_sub(1) as u8
}
/// Randomly assigns each of `num_vectors` ids to an HNSW level drawn from
/// `default_probabs`, and derives the bookkeeping structures the builder
/// needs. Deterministic: both the shuffle and the level draws use a fixed
/// RNG seed (523).
///
/// Returns a tuple of:
/// - `levels_mapping`: for each upper level `l` in `1..=max_level`,
///   `levels_mapping[l - 1]` is the list of global ids present on that
///   level (each list is a prefix of `ids_sorted_by_level`);
/// - `ids_sorted_by_level`: all ids ordered from highest assigned level to
///   lowest;
/// - `cumulative_ids_per_level`: running node counts per level, top-down,
///   starting with 0;
/// - `max_level`: the highest non-empty level.
#[must_use]
#[inline]
fn compute_levels(
    // `&[f32]` instead of `&Vec<f32>`: idiomatic and callers coerce freely.
    default_probabs: &[f32],
    num_vectors: usize,
) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>, u8) {
    let mut rng = StdRng::seed_from_u64(523);
    // Shuffle so level membership is uncorrelated with the input order.
    let mut all_ids: Vec<usize> = (0..num_vectors).collect();
    all_ids.shuffle(&mut rng);
    let mut ids_per_level: Vec<Vec<usize>> = vec![Vec::new(); default_probabs.len() + 1];
    for &id in &all_ids {
        let level = random_level(default_probabs, &mut rng);
        ids_per_level[level as usize].push(id);
    }
    let max_level = ids_per_level
        .iter()
        .rposition(|level_nodes| !level_nodes.is_empty())
        .unwrap_or(0) as u8;
    // Concatenate ids from the top level down.
    let mut ids_sorted_by_level: Vec<usize> = Vec::with_capacity(num_vectors);
    for i in (0..=max_level).rev() {
        ids_sorted_by_level.extend(&ids_per_level[i as usize]);
    }
    // Prefix sums of level sizes, again top-down, with a leading 0.
    let mut cumulative_ids_per_level = Vec::with_capacity(max_level as usize + 2);
    cumulative_ids_per_level.push(0);
    let mut count = 0;
    for i in (0..=max_level).rev() {
        count += ids_per_level[i as usize].len();
        cumulative_ids_per_level.push(count);
    }
    // Every node on level l is also on all levels below it, so the member
    // list of level l is the prefix of ids with assigned level >= l.
    let mut levels_mapping: Vec<Vec<usize>> = Vec::with_capacity(max_level as usize);
    for i in 0..max_level as usize {
        let num_nodes_at_this_level_or_above = cumulative_ids_per_level[max_level as usize - i];
        let mapping_for_this_level: Vec<usize> =
            ids_sorted_by_level[0..num_nodes_at_this_level_or_above].to_vec();
        levels_mapping.push(mapping_for_this_level);
    }
    (
        levels_mapping,
        ids_sorted_by_level,
        cumulative_ids_per_level,
        max_level,
    )
}
impl<D, Q, G> HNSW<D, Q, G>
where
    D: Dataset<Q> + GrowableDataset<Q> + Sync,
    Q: Quantizer<InputItem: Float, DatasetType = D> + Sync,
    G: GraphTrait,
{
    /// Seeds every level graph with the single entry-point node before the
    /// main build loop starts.
    ///
    /// Upper-level graphs (indices `< max_level` in the top-down
    /// `growable_levels` slice) store the node under local id 0 together
    /// with its global-id mapping; the ground graph stores it directly
    /// under its global id. No links are created yet (all neighbor slices
    /// are empty).
    fn insert_entry_point(
        growable_levels: &mut [GrowableGraph],
        entry_point_global_id: usize,
        max_level: u8,
        pb: &ProgressBar,
    ) {
        for (i, graph) in growable_levels.iter_mut().enumerate() {
            if i < max_level as usize {
                graph.push_with_precomputed_reverse_links(Some(entry_point_global_id), &[], 0, &[]);
            } else {
                graph.push_with_precomputed_reverse_links(None, &[], entry_point_global_id, &[]);
            }
        }
        pb.inc(1);
        // Make the entry point visible to subsequent searches on the upper
        // levels; the ground graph is addressed by global ids and does not
        // use the local insertion counter.
        for graph in growable_levels.iter_mut().take(max_level as usize) {
            graph.advance_inserted_nodes(1);
        }
    }

    /// Inserts the given nodes one at a time (the `level` here is the
    /// highest level each of these nodes belongs to).
    ///
    /// Per node: (1) greedily descend the levels above `level` to refine an
    /// entry point, (2) search, prune and link the node into every upper
    /// level from `level` down to 1 using per-level local ids, (3) link it
    /// into the ground graph using global ids. Levels are stored top-down,
    /// so the graph for HNSW level `l` lives at slice index `max_level - l`.
    #[allow(clippy::too_many_arguments)]
    fn process_level_sequentially<'a, BD, IQ>(
        nodes_to_insert_slice: &[usize],
        level: u8,
        max_level: u8,
        m: usize,
        growable_levels: &mut [GrowableGraph],
        source_dataset: &'a BD,
        build_params: &HNSWBuildParams,
        entry_point_local_id: usize,
        level1_to_level0_mapping: &[usize],
        ids_sorted_by_level: &[usize],
        pb: &ProgressBar,
    ) where
        BD: Dataset<IQ> + Sync + 'a,
        IQ: IdentityQuantizer<DatasetType = BD, T: Float> + Sync + 'a,
        <IQ as Quantizer>::Evaluator<'a>:
            QueryEvaluator<'a, QueryType = <BD as Dataset<IQ>>::DataType<'a>>,
        <Q as Quantizer>::InputItem: 'a + Float,
    {
        for &global_id in nodes_to_insert_slice {
            // The node's own vector acts as the query during insertion.
            let query_eval = source_dataset.query_evaluator(source_dataset.get(global_id));
            let mut entry_node = Candidate(f32::MAX, entry_point_local_id);
            if level > 0 {
                // Phase 1: greedy descent through the levels this node does
                // not belong to, refining the entry point.
                for current_level in ((level + 1)..=max_level).rev() {
                    let graph_idx = max_level as usize - current_level as usize;
                    entry_node = growable_levels[graph_idx].greedy_search_nearest(
                        source_dataset,
                        &query_eval,
                        entry_node,
                    );
                }
                // Phase 2: insert into every upper level the node belongs to.
                for current_level in (1..=level).rev() {
                    let graph_idx = max_level as usize - current_level as usize;
                    let graph = &mut growable_levels[graph_idx];
                    // The next free local slot in this level's graph.
                    let local_id = graph.inserted_nodes();
                    let (forward, reverse, new_entry) = graph.find_and_prune_neighbors(
                        source_dataset,
                        &query_eval,
                        entry_node,
                        build_params.ef_construction,
                        m,
                        local_id,
                    );
                    graph.push_with_precomputed_reverse_links(
                        Some(global_id),
                        &forward,
                        local_id,
                        &reverse,
                    );
                    graph.advance_inserted_nodes(1);
                    // Continue the descent from the best candidate found here.
                    entry_node = new_entry;
                }
            }
            // Phase 3: insert into the ground graph. Translate the level-1
            // local entry id into a global id first.
            let ground_graph = &mut growable_levels[max_level as usize];
            let ground_entry_global_id = if max_level > 0 {
                level1_to_level0_mapping[entry_node.id_vec()]
            } else {
                ids_sorted_by_level[0]
            };
            let dist = query_eval.compute_distance(source_dataset, ground_entry_global_id);
            let ground_entry_node = Candidate(dist, ground_entry_global_id);
            // Ground level is built with a doubled max degree (2 * m).
            let (ground_neighbors, ground_reverse_links, _) = ground_graph
                .find_and_prune_neighbors(
                    source_dataset,
                    &query_eval,
                    ground_entry_node,
                    build_params.ef_construction,
                    2 * m,
                    global_id,
                );
            ground_graph.push_with_precomputed_reverse_links(
                None,
                &ground_neighbors,
                global_id,
                &ground_reverse_links,
            );
            pb.inc(1);
        }
    }

    /// Batched, parallel variant of `process_level_sequentially` for large
    /// levels.
    ///
    /// Each batch runs in two phases: a parallel, read-only search phase
    /// that computes forward/reverse links for every node against the graph
    /// state from before the batch, and a sequential commit phase that
    /// pushes the links. Local ids are precomputed from a snapshot of each
    /// graph's inserted-node count plus the node's offset in the batch, so
    /// both phases agree on the ids. The batch size doubles from
    /// `initial_build_batch_size` up to `max_build_batch_size`.
    #[allow(clippy::too_many_arguments)]
    fn process_level_parallelly<'a, BD, IQ>(
        nodes_to_insert_slice: &[usize],
        level: u8,
        max_level: u8,
        m: usize,
        growable_levels: &mut [GrowableGraph],
        source_dataset: &'a BD,
        build_params: &HNSWBuildParams,
        entry_point_local_id: usize,
        level1_to_level0_mapping: &[usize],
        ids_sorted_by_level: &[usize],
        pb: &ProgressBar,
    ) where
        BD: Dataset<IQ> + Sync + 'a,
        IQ: IdentityQuantizer<DatasetType = BD, T: Float> + Sync + 'a,
        <IQ as Quantizer>::Evaluator<'a>:
            QueryEvaluator<'a, QueryType = <BD as Dataset<IQ>>::DataType<'a>>,
        <Q as Quantizer>::InputItem: 'a + Float,
    {
        let mut current_batch_size = build_params.initial_build_batch_size;
        let max_batch_size = build_params.max_build_batch_size;
        // Snapshot of each graph's inserted count at the start of this
        // level; local ids for new nodes are derived from it.
        let level_start_local_ids: Vec<usize> =
            growable_levels.iter().map(|g| g.inserted_nodes()).collect();
        let mut processed_nodes = 0;
        while processed_nodes < nodes_to_insert_slice.len() {
            let remaining_nodes = nodes_to_insert_slice.len() - processed_nodes;
            let actual_batch_size = current_batch_size.min(remaining_nodes);
            let batch =
                &nodes_to_insert_slice[processed_nodes..processed_nodes + actual_batch_size];
            // Phase 1 (parallel): search-only; graphs are read immutably.
            let insertion_data: Vec<_> = batch
                .par_iter()
                .enumerate()
                .map(|(i, &global_id)| {
                    let query_eval = source_dataset.query_evaluator(source_dataset.get(global_id));
                    let mut entry_node = Candidate(f32::MAX, entry_point_local_id);
                    // Links computed for the upper levels, from `level`
                    // down to 1.
                    let mut upper_level_data = Vec::new();
                    if level > 0 {
                        for current_level in ((level + 1)..=max_level).rev() {
                            let graph_idx = max_level as usize - current_level as usize;
                            entry_node = growable_levels[graph_idx].greedy_search_nearest(
                                source_dataset,
                                &query_eval,
                                entry_node,
                            );
                        }
                        for current_level in (1..=level).rev() {
                            let graph_idx = max_level as usize - current_level as usize;
                            let graph = &growable_levels[graph_idx];
                            // Provisional local id; the commit phase uses
                            // the same formula.
                            let local_id = level_start_local_ids[graph_idx] + processed_nodes + i;
                            let (forward, reverse, new_entry) = graph.find_and_prune_neighbors(
                                source_dataset,
                                &query_eval,
                                entry_node,
                                build_params.ef_construction,
                                m,
                                local_id,
                            );
                            upper_level_data.push((forward, reverse));
                            entry_node = new_entry;
                        }
                    }
                    // Ground-level search, addressed by global ids.
                    let ground_graph = &growable_levels[max_level as usize];
                    let ground_entry_global_id = if max_level > 0 {
                        level1_to_level0_mapping[entry_node.id_vec()]
                    } else {
                        ids_sorted_by_level[0]
                    };
                    let dist = query_eval.compute_distance(source_dataset, ground_entry_global_id);
                    let ground_entry_node = Candidate(dist, ground_entry_global_id);
                    let (ground_neighbors, ground_reverse_links, _) = ground_graph
                        .find_and_prune_neighbors(
                            source_dataset,
                            &query_eval,
                            ground_entry_node,
                            build_params.ef_construction,
                            2 * m,
                            global_id,
                        );
                    (
                        global_id,
                        upper_level_data,
                        (ground_neighbors, ground_reverse_links),
                    )
                })
                .collect();
            // Phase 2 (sequential): commit the precomputed links.
            for (i, (global_id, upper_level_data, ground_data)) in
                insertion_data.into_iter().enumerate()
            {
                // `upper_level_data` was pushed from `level` down to 1;
                // reverse it so `level_idx + 1` recovers the HNSW level.
                for (level_idx, (forward, reverse)) in
                    upper_level_data.into_iter().rev().enumerate()
                {
                    let hnsw_level = level_idx + 1;
                    let graph_idx = max_level as usize - hnsw_level;
                    let graph = &mut growable_levels[graph_idx];
                    // Same local-id formula as the search phase.
                    let local_id = level_start_local_ids[graph_idx] + processed_nodes + i;
                    graph.push_with_precomputed_reverse_links(
                        Some(global_id),
                        &forward,
                        local_id,
                        &reverse,
                    );
                }
                let (forward, reverse) = ground_data;
                let ground_graph = &mut growable_levels[max_level as usize];
                ground_graph
                    .push_with_precomputed_reverse_links(None, &forward, global_id, &reverse);
            }
            // Only upper-level graphs track a local insertion counter.
            for current_level in (1..=level).rev() {
                let graph_idx = max_level as usize - current_level as usize;
                growable_levels[graph_idx].advance_inserted_nodes(actual_batch_size);
            }
            processed_nodes += actual_batch_size;
            pb.inc(actual_batch_size as u64);
            // Grow the batch geometrically up to the configured maximum.
            if current_batch_size < max_batch_size {
                current_batch_size = (current_batch_size * 2).min(max_batch_size);
            }
        }
    }
}