use ann_search_rs::cpu::hnsw::{HnswIndex, HnswState};
use ann_search_rs::cpu::nndescent::{ApplySortedUpdates, NNDescent, NNDescentQuery};
use ann_search_rs::prelude::*;
use ann_search_rs::*;
use faer::MatRef;
use num_traits::Float;
use rayon::prelude::*;
use std::default::Default;
/// Nearest-neighbour search back-ends that `run_ann_search` can dispatch to.
///
/// The default variant is [`AnnSearch::KmKnn`]; it is what callers get from
/// `AnnSearch::default()` and from `Option::<AnnSearch>::unwrap_or_default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AnnSearch {
    /// KMKNN back-end (`build_kmknn_index` / `query_kmknn_self`). Default.
    #[default]
    KmKnn,
    /// HNSW back-end (`build_hnsw_index` / `query_hnsw_self`).
    Hnsw,
    /// NN-Descent back-end (`build_nndescent_index` / `query_nndescent_self`).
    NNDescent,
    /// Annoy back-end (`build_annoy_index` / `query_annoy_self`).
    Annoy,
    /// IVF back-end (`build_ivf_index` / `query_ivf_self`).
    Ivf,
    /// Ball tree back-end (`build_balltree_index` / `query_balltree_self`).
    BallTree,
    /// Exhaustive back-end (`build_exhaustive_index` / `query_exhaustive_self`).
    Exhaustive,
}
/// Bundle of tuning parameters consumed by `run_ann_search`.
///
/// Each back-end reads only the fields relevant to it; the remaining fields
/// are ignored for that back-end. `T` is the floating-point type of the data.
#[derive(Debug, Clone)]
pub struct NearestNeighbourParams<T> {
/// Name of the distance metric; forwarded to every index builder.
pub dist_metric: String,
/// Annoy: forwarded to `build_annoy_index` (presumably the number of trees).
pub n_tree: usize,
/// Annoy: optional search budget forwarded to `query_annoy_self`.
pub search_budget: Option<usize>,
/// HNSW: `m` argument forwarded to `build_hnsw_index`.
pub m: usize,
/// HNSW: `ef_construction` argument forwarded to `build_hnsw_index`.
pub ef_construction: usize,
/// HNSW: `ef_search` argument forwarded to `query_hnsw_self`.
pub ef_search: usize,
/// NN-Descent: diversification probability forwarded to `build_nndescent_index`.
pub diversify_prob: T,
/// NN-Descent: `delta` argument forwarded to `build_nndescent_index`
/// (presumably the convergence threshold — confirm against the crate docs).
pub delta: T,
/// NN-Descent: optional budget forwarded to `query_nndescent_self`.
pub ef_budget: Option<usize>,
/// Ball tree: fraction of the number of data rows used as the query budget
/// (`budget = nrows * bt_budget`, truncated to an integer).
pub bt_budget: T,
/// IVF / KMKNN: optional list count forwarded to `build_ivf_index` and
/// `build_kmknn_index`.
pub n_list: Option<usize>,
/// IVF: optional probe count forwarded to `query_ivf_self`.
pub n_probes: Option<usize>,
}
impl<T> NearestNeighbourParams<T> {
#[allow(clippy::too_many_arguments)]
pub fn new(
dist_metric: String,
n_tree: usize,
search_budget: Option<usize>,
m: usize,
ef_construction: usize,
ef_search: usize,
diversify_prob: T,
delta: T,
ef_budget: Option<usize>,
bt_budget: T,
n_list: Option<usize>,
n_probes: Option<usize>,
) -> Self {
Self {
dist_metric,
n_tree,
search_budget,
m,
ef_construction,
ef_search,
diversify_prob,
delta,
ef_budget,
bt_budget,
n_list,
n_probes,
}
}
}
impl<T> Default for NearestNeighbourParams<T>
where
T: Float,
{
fn default() -> Self {
Self {
dist_metric: "euclidean".to_string(),
n_tree: 50,
search_budget: None,
m: 16,
ef_construction: 200,
ef_search: 100,
diversify_prob: T::from(0.0).unwrap(),
delta: T::from(0.001).unwrap(),
ef_budget: None,
bt_budget: T::from(0.1).unwrap(),
n_list: None,
n_probes: None,
}
}
}
/// Maps a back-end name onto the matching [`AnnSearch`] variant.
///
/// The comparison is case-insensitive: the input is lower-cased before it is
/// matched. No trimming is applied, so surrounding whitespace makes the name
/// unrecognised.
///
/// # Returns
///
/// `Some(variant)` for `"annoy"`, `"balltree"`, `"exhaustive"`, `"hnsw"`,
/// `"ivf"`, `"kmknn"` and `"nndescent"`; `None` for anything else.
pub fn parse_ann_search(s: &str) -> Option<AnnSearch> {
    let name = s.to_lowercase();
    let variant = match name.as_str() {
        "kmknn" => AnnSearch::KmKnn,
        "hnsw" => AnnSearch::Hnsw,
        "nndescent" => AnnSearch::NNDescent,
        "annoy" => AnnSearch::Annoy,
        "ivf" => AnnSearch::Ivf,
        "balltree" => AnnSearch::BallTree,
        "exhaustive" => AnnSearch::Exhaustive,
        _ => return None,
    };
    Some(variant)
}
/// Builds the requested nearest-neighbour index over `data` and queries the
/// index against itself.
///
/// `ann_type` is resolved with `parse_ann_search`; a name that is not
/// recognised silently falls back to `AnnSearch::default()`. Each back-end is
/// queried for `k + 1` neighbours and the first entry of every result row is
/// then discarded. NOTE(review): this assumes the back-end ranks each point
/// as its own first neighbour — confirm for the approximate back-ends.
///
/// # Arguments
///
/// * `data` - Matrix whose rows are indexed and queried (`data.nrows()` is
///   used as the point count for the Ball tree budget).
/// * `k` - Number of neighbours to keep per row.
/// * `ann_type` - Back-end name, matched case-insensitively.
/// * `params_nn` - Back-end tuning parameters.
/// * `seed` - Seed forwarded to the index builders that accept one.
/// * `verbose` - Verbosity flag forwarded to the builders and queries.
///
/// # Returns
///
/// `(indices, distances)`: one row per query result, each row holding the
/// back-end output minus its first entry.
///
/// # Panics
///
/// Panics if the back-end returns no distances, or if `bt_budget` cannot be
/// converted to `f32` (Ball tree only).
pub fn run_ann_search<T>(
    data: MatRef<T>,
    k: usize,
    ann_type: String,
    params_nn: &NearestNeighbourParams<T>,
    seed: usize,
    verbose: bool,
) -> (Vec<Vec<usize>>, Vec<Vec<T>>)
where
    T: AnnSearchFloat,
    HnswIndex<T>: HnswState<T>,
    NNDescent<T>: ApplySortedUpdates<T> + NNDescentQuery<T>,
{
    let ann_search = parse_ann_search(&ann_type).unwrap_or_default();

    // One extra neighbour is requested because the leading entry of every
    // row is dropped again below.
    let n_query = k + 1;

    let (knn_indices, knn_dist) = match ann_search {
        AnnSearch::Annoy => {
            let index =
                build_annoy_index(data, params_nn.dist_metric.clone(), params_nn.n_tree, seed);
            query_annoy_self(&index, n_query, params_nn.search_budget, true, verbose)
        }
        AnnSearch::Hnsw => {
            let index = build_hnsw_index(
                data,
                params_nn.m,
                params_nn.ef_construction,
                &params_nn.dist_metric,
                seed,
                verbose,
            );
            query_hnsw_self(&index, n_query, params_nn.ef_search, true, verbose)
        }
        AnnSearch::NNDescent => {
            let index = build_nndescent_index(
                data,
                &params_nn.dist_metric,
                params_nn.delta,
                params_nn.diversify_prob,
                None,
                None,
                None,
                None,
                seed,
                verbose,
            );
            query_nndescent_self(&index, n_query, params_nn.ef_budget, true, verbose)
        }
        AnnSearch::BallTree => {
            let index = build_balltree_index(data, params_nn.dist_metric.clone(), seed);
            // The budget is expressed as a fraction of the number of rows.
            let budget = (data.nrows() as f32 * params_nn.bt_budget.to_f32().unwrap()) as usize;
            query_balltree_self(&index, n_query, Some(budget), true, verbose)
        }
        AnnSearch::Exhaustive => {
            let index = build_exhaustive_index(data, &params_nn.dist_metric);
            query_exhaustive_self(&index, n_query, true, verbose)
        }
        AnnSearch::Ivf => {
            let index = build_ivf_index(
                data,
                params_nn.n_list,
                None,
                &params_nn.dist_metric,
                seed,
                verbose,
            );
            query_ivf_self(&index, n_query, params_nn.n_probes, true, verbose)
        }
        AnnSearch::KmKnn => {
            let index = build_kmknn_index(
                data,
                &params_nn.dist_metric,
                params_nn.n_list,
                None,
                seed,
                verbose,
            );
            query_kmknn_self(&index, n_query, true, verbose)
        }
    };

    let knn_dist = knn_dist.expect("ANN backend returned no distances for the self-query");

    // `skip(1)` drops the leading entry and, unlike `drain(1..)`, does not
    // panic when a back-end hands back an empty row.
    let knn_indices: Vec<Vec<usize>> = knn_indices
        .into_par_iter()
        .map(|row| row.into_iter().skip(1).collect())
        .collect();
    let knn_dist: Vec<Vec<T>> = knn_dist
        .into_par_iter()
        .map(|row| row.into_iter().skip(1).collect())
        .collect();

    (knn_indices, knn_dist)
}