use ann_search_rs::cpu::hnsw::{HnswIndex, HnswState};
use ann_search_rs::cpu::nndescent::{ApplySortedUpdates, NNDescent, NNDescentQuery};
use ann_search_rs::prelude::*;
use ann_search_rs::*;
use faer::MatRef;
use num_traits::Float;
use rayon::prelude::*;
use std::default::Default;
use crate::prelude::*;
/// Nearest-neighbour search backends that `run_ann_search` can dispatch to.
///
/// `KmKnn` is the default variant; it is also what `run_ann_search` falls
/// back to when the requested method name is not recognised.
#[derive(Debug, Clone, Copy, Default)]
pub enum AnnSearch {
    /// Uses `build_kmknn_index` / `query_kmknn_self`.
    #[default]
    KmKnn,
    /// Uses `build_hnsw_index` / `query_hnsw_self`.
    Hnsw,
    /// Uses `build_nndescent_index` / `query_nndescent_self`.
    NNDescent,
    /// Uses `build_annoy_index` / `query_annoy_self`.
    Annoy,
    /// Uses `build_ivf_index` / `query_ivf_self`.
    Ivf,
    /// Uses `build_balltree_index` / `query_balltree_self`.
    BallTree,
    /// Uses `build_exhaustive_index` / `query_exhaustive_self`.
    Exhaustive,
}
/// Tuning parameters for every backend supported by `run_ann_search`.
///
/// One struct carries the settings of all backends; each backend only reads
/// the fields noted below. `T` is the floating-point type of the data.
#[derive(Clone, Debug)]
pub struct NearestNeighbourParams<T> {
    /// Name of the distance metric; forwarded to every index builder.
    pub dist_metric: String,
    /// Annoy: forwarded to `build_annoy_index`.
    pub n_tree: usize,
    /// Annoy: optional budget forwarded to `query_annoy_self`.
    pub search_budget: Option<usize>,
    /// HNSW: forwarded to `build_hnsw_index`.
    pub m: usize,
    /// HNSW: forwarded to `build_hnsw_index`.
    pub ef_construction: usize,
    /// HNSW: forwarded to `query_hnsw_self`.
    pub ef_search: usize,
    /// NNDescent: forwarded to `build_nndescent_index`.
    pub diversify_prob: T,
    /// NNDescent: forwarded to `build_nndescent_index`.
    pub delta: T,
    /// NNDescent: optional budget forwarded to `query_nndescent_self`.
    pub ef_budget: Option<usize>,
    /// Ball tree: multiplied by the number of data rows to obtain the budget
    /// passed to `query_balltree_self`.
    pub bt_budget: T,
    /// IVF and KmKnn: forwarded to `build_ivf_index` / `build_kmknn_index`.
    pub n_list: Option<usize>,
    /// IVF: forwarded to `query_ivf_self`.
    pub n_probes: Option<usize>,
}
impl<T> NearestNeighbourParams<T> {
#[allow(clippy::too_many_arguments)]
pub fn new(
dist_metric: String,
n_tree: usize,
search_budget: Option<usize>,
m: usize,
ef_construction: usize,
ef_search: usize,
diversify_prob: T,
delta: T,
ef_budget: Option<usize>,
bt_budget: T,
n_list: Option<usize>,
n_probes: Option<usize>,
) -> Self {
Self {
dist_metric,
n_tree,
search_budget,
m,
ef_construction,
ef_search,
diversify_prob,
delta,
ef_budget,
bt_budget,
n_list,
n_probes,
}
}
}
impl<T> Default for NearestNeighbourParams<T>
where
T: Float,
{
fn default() -> Self {
Self {
dist_metric: "euclidean".to_string(),
n_tree: 50,
search_budget: None,
m: 16,
ef_construction: 200,
ef_search: 100,
diversify_prob: T::from(0.0).unwrap(),
delta: T::from(0.001).unwrap(),
ef_budget: None,
bt_budget: T::from(0.1).unwrap(),
n_list: None,
n_probes: None,
}
}
}
/// Parses the name of a nearest-neighbour search method.
///
/// The comparison is case-insensitive: the input is lower-cased before it is
/// compared against the known names. No trimming is performed.
///
/// # Arguments
///
/// * `s` - Method name: one of `"annoy"`, `"balltree"`, `"exhaustive"`,
///   `"hnsw"`, `"ivf"`, `"kmknn"` or `"nndescent"`.
///
/// # Returns
///
/// `Some(AnnSearch)` for a recognised name, `None` otherwise.
pub fn parse_ann_search(s: &str) -> Option<AnnSearch> {
    // Lower-case spelling of each method, paired with its variant.
    const KNOWN_METHODS: [(&str, AnnSearch); 7] = [
        ("annoy", AnnSearch::Annoy),
        ("balltree", AnnSearch::BallTree),
        ("exhaustive", AnnSearch::Exhaustive),
        ("hnsw", AnnSearch::Hnsw),
        ("ivf", AnnSearch::Ivf),
        ("kmknn", AnnSearch::KmKnn),
        ("nndescent", AnnSearch::NNDescent),
    ];

    let wanted = s.to_lowercase();
    KNOWN_METHODS
        .iter()
        .find(|(name, _)| *name == wanted.as_str())
        .map(|&(_, method)| method)
}
pub fn run_ann_search<T>(
data: MatRef<T>,
k: usize,
ann_type: String,
params_nn: &NearestNeighbourParams<T>,
seed: usize,
verbose: usize,
) -> ManifoldsKnnResults<T>
where
T: AnnSearchFloat,
HnswIndex<T>: HnswState<T>,
NNDescent<T>: ApplySortedUpdates<T> + NNDescentQuery<T>,
{
let verbosity = parse_verbosity_level(verbose);
let ann_search = parse_ann_search(&ann_type).unwrap_or_else(|| {
println!(
"Unrecognised approximate nearest neighbour method provided: {:?}. Default to KmKnn.",
ann_type
);
AnnSearch::default()
});
let (knn_indices, knn_dist) = match ann_search {
AnnSearch::Annoy => {
let index = build_annoy_index(data, ¶ms_nn.dist_metric, params_nn.n_tree, seed)?;
query_annoy_self(
&index,
k + 1,
params_nn.search_budget,
true,
verbosity.detailed_verbosity(),
)?
}
AnnSearch::Hnsw => {
let index = build_hnsw_index(
data,
params_nn.m,
params_nn.ef_construction,
¶ms_nn.dist_metric,
seed,
verbosity.detailed_verbosity(),
);
query_hnsw_self(
&index,
k + 1,
params_nn.ef_search,
true,
verbosity.normal_verbosity(),
)?
}
AnnSearch::NNDescent => {
let index = build_nndescent_index(
data,
¶ms_nn.dist_metric,
params_nn.delta,
params_nn.diversify_prob,
None, None,
None,
None,
seed,
verbosity.detailed_verbosity(),
)?;
query_nndescent_self(
&index,
k + 1,
params_nn.ef_budget,
true,
verbosity.normal_verbosity(),
)?
}
AnnSearch::BallTree => {
let index = build_balltree_index(data, ¶ms_nn.dist_metric, seed)?;
let budget = (data.nrows() as f32 * params_nn.bt_budget.to_f32().unwrap()) as usize;
query_balltree_self(
&index,
k + 1,
Some(budget),
true,
verbosity.normal_verbosity(),
)?
}
AnnSearch::Exhaustive => {
let index = build_exhaustive_index(data, ¶ms_nn.dist_metric);
query_exhaustive_self(&index, k + 1, true, verbosity.normal_verbosity())?
}
AnnSearch::Ivf => {
let index = build_ivf_index(
data,
params_nn.n_list,
None,
¶ms_nn.dist_metric,
seed,
verbosity.detailed_verbosity(),
)?;
query_ivf_self(
&index,
k + 1,
params_nn.n_probes,
true,
verbosity.normal_verbosity(),
)?
}
AnnSearch::KmKnn => {
let index = build_kmknn_index(
data,
¶ms_nn.dist_metric,
params_nn.n_list,
None,
seed,
verbosity.detailed_verbosity(),
)?;
query_kmknn_self(&index, k + 1, true, verbosity.normal_verbosity())?
}
};
let knn_dist = knn_dist.unwrap();
let knn_indices: Vec<Vec<usize>> = knn_indices
.into_par_iter()
.map(|mut v| v.drain(1..).collect())
.collect();
let knn_dist: Vec<Vec<T>> = knn_dist
.into_par_iter()
.map(|mut v| v.drain(1..).collect())
.collect();
Ok((knn_indices, knn_dist))
}