use ann_search_rs::cpu::nndescent::NNDescentQuery;
use ann_search_rs::gpu::nndescent_gpu::NNDescentGpu;
use ann_search_rs::prelude::*;
use ann_search_rs::{
build_exhaustive_index_gpu, build_ivf_index_gpu, build_nndescent_index_gpu,
query_exhaustive_index_gpu_self, query_ivf_index_gpu_self, query_nndescent_index_gpu_self,
};
use cubecl::prelude::*;
use faer::MatRef;
use rayon::prelude::*;
use crate::prelude::*;
/// GPU-accelerated nearest-neighbour search backends selectable in
/// `run_ann_search_gpu`.
///
/// The default backend is [`AnnSearchGpu::IvfGpu`]; it is also what
/// `run_ann_search_gpu` falls back to when the requested name is not recognised.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AnnSearchGpu {
    /// Backend served by `build_ivf_index_gpu` / `query_ivf_index_gpu_self`.
    #[default]
    IvfGpu,
    /// Backend served by `build_nndescent_index_gpu` /
    /// `query_nndescent_index_gpu_self`.
    NNDescentGpu,
    /// Backend served by `build_exhaustive_index_gpu` /
    /// `query_exhaustive_index_gpu_self`.
    ExhaustiveGpu,
}
/// Tuning parameters for the GPU nearest-neighbour searches run by
/// `run_ann_search_gpu`.
///
/// `T` is the numeric type of `delta` and `rho`. Optional fields left as `None`
/// are passed on as `None` to the `ann_search_rs` routines; presumably the
/// library then applies its own defaults (TODO confirm against `ann_search_rs`).
#[derive(Debug, Clone)]
pub struct NearestNeighbourParamsGpu<T> {
/// Distance metric name handed to every index builder; the `Default` impl uses
/// `"euclidean"`.
pub dist_metric: String,
/// Forwarded to `build_ivf_index_gpu`; presumably the number of IVF lists
/// (clusters) — confirm against the library.
pub n_list: Option<usize>,
/// Forwarded to `query_ivf_index_gpu_self`; presumably the number of lists
/// probed per query — confirm against the library.
pub n_probes: Option<usize>,
/// Forwarded to `build_nndescent_index_gpu`. NOTE(review): this is separate
/// from the `k` argument of `run_ann_search_gpu`, which sets how many
/// neighbours are returned.
pub k: Option<usize>,
/// Forwarded to `build_nndescent_index_gpu`; presumably the neighbourhood size
/// used while building the graph — confirm against the library.
pub k_build: Option<usize>,
/// Forwarded to `build_nndescent_index_gpu`; presumably the number of trees
/// used for initialisation — confirm against the library.
pub n_tree: Option<usize>,
/// Converted to `f32` and forwarded to `build_nndescent_index_gpu`; the
/// `Default` impl uses `0.001`. Presumably a convergence threshold — confirm.
pub delta: T,
/// Converted to `f32` (when set) and forwarded to `build_nndescent_index_gpu`;
/// presumably a sampling rate — confirm against the library.
pub rho: Option<T>,
/// First argument of `CagraGpuSearchParams::new` for the NNDescent query.
pub beam_width: Option<usize>,
/// Second argument of `CagraGpuSearchParams::new` for the NNDescent query.
pub max_beam_iters: Option<usize>,
/// Third argument of `CagraGpuSearchParams::new` for the NNDescent query.
pub n_entry_points: Option<usize>,
}
impl<T> NearestNeighbourParamsGpu<T> {
#[allow(clippy::too_many_arguments)]
pub fn new(
dist_metric: String,
n_list: Option<usize>,
n_probes: Option<usize>,
k: Option<usize>,
k_build: Option<usize>,
n_tree: Option<usize>,
delta: T,
rho: Option<T>,
beam_width: Option<usize>,
max_beam_iters: Option<usize>,
n_entry_points: Option<usize>,
) -> Self {
Self {
dist_metric,
n_list,
n_probes,
k,
k_build,
n_tree,
delta,
rho,
beam_width,
max_beam_iters,
n_entry_points,
}
}
}
impl<T> Default for NearestNeighbourParamsGpu<T>
where
    T: AnnSearchFloat,
{
    /// Baseline parameters: the `"euclidean"` metric, a `delta` of `0.001`, and
    /// every optional setting left as `None`.
    ///
    /// # Panics
    ///
    /// Panics if `0.001` cannot be converted into `T`.
    fn default() -> Self {
        let delta = T::from(0.001).unwrap();

        Self {
            dist_metric: String::from("euclidean"),
            delta,
            // Optional IVF settings.
            n_list: None,
            n_probes: None,
            // Optional NNDescent build settings.
            k: None,
            k_build: None,
            n_tree: None,
            rho: None,
            // Optional NNDescent query settings.
            beam_width: None,
            max_beam_iters: None,
            n_entry_points: None,
        }
    }
}
/// Maps a method name onto the corresponding [`AnnSearchGpu`] backend.
///
/// The name is lower-cased before comparison, so matching is case-insensitive.
/// Each backend accepts its name with or without the `_gpu` suffix:
/// `"exhaustive_gpu"`/`"exhaustive"`, `"ivf_gpu"`/`"ivf"` and
/// `"nndescent_gpu"`/`"nndescent"`.
///
/// Returns `None` for any other input; surrounding whitespace is not trimmed.
pub fn parse_ann_search_gpu(s: &str) -> Option<AnnSearchGpu> {
    let lowered = s.to_lowercase();
    let is_any = |aliases: &[&str]| aliases.contains(&lowered.as_str());

    if is_any(&["exhaustive_gpu", "exhaustive"]) {
        return Some(AnnSearchGpu::ExhaustiveGpu);
    }
    if is_any(&["ivf_gpu", "ivf"]) {
        return Some(AnnSearchGpu::IvfGpu);
    }
    if is_any(&["nndescent_gpu", "nndescent"]) {
        return Some(AnnSearchGpu::NNDescentGpu);
    }
    None
}
/// Runs a GPU-accelerated k-nearest-neighbour self-search over `data`.
///
/// An index is built over `data` with the backend named by `ann_type`, the
/// index is queried against itself for `k + 1` neighbours, and each point's own
/// entry is then removed so that at most `k` neighbours per point remain.
///
/// # Arguments
///
/// * `data` - Matrix to index and query (assumed one sample per row — TODO
///   confirm against `ann_search_rs`).
/// * `k` - Number of neighbours to keep per point.
/// * `ann_type` - Backend name understood by `parse_ann_search_gpu`. An
///   unrecognised name prints a notice to stdout and falls back to
///   `AnnSearchGpu::default()`.
/// * `params_nn` - Backend-specific tuning parameters.
/// * `device` - Device handle of the `R` runtime, handed to the index builder.
/// * `seed` - Seed forwarded to the IVF and NNDescent index builders.
/// * `verbose` - Verbosity level, interpreted by `parse_verbosity_level`.
///
/// # Returns
///
/// On success, a pair `(indices, distances)` holding, for every point, the
/// neighbour indices and the matching distances in the order the backend
/// returned them. A row can hold fewer than `k` entries if the backend
/// returned fewer candidates.
///
/// # Errors
///
/// Propagates any error raised while building or querying the index.
///
/// # Panics
///
/// Panics if the backend returns no distances.
#[cfg(feature = "gpu")]
pub fn run_ann_search_gpu<T, R>(
    data: MatRef<T>,
    k: usize,
    ann_type: String,
    params_nn: &NearestNeighbourParamsGpu<T>,
    device: R::Device,
    seed: usize,
    verbose: usize,
) -> ManifoldsKnnResults<T>
where
    T: AnnSearchFloat + AnnSearchGpuFloat,
    R: Runtime,
    NNDescentGpu<T, R>: NNDescentQuery<T>,
{
    let verbosity = parse_verbosity_level(verbose);

    // An unknown method name is not fatal: report it and use the default backend.
    let ann_search = parse_ann_search_gpu(&ann_type).unwrap_or_else(|| {
        println!(
            "Unrecognised GPU-accelerated approximate nearest neighbour method provided: {:?}. Default to GPU IVF.",
            ann_type
        );
        AnnSearchGpu::default()
    });

    // Every backend is queried for `k + 1` neighbours because a self-query is
    // expected to return each point as one of its own neighbours; that entry is
    // filtered out below.
    let (knn_indices_raw, knn_dist) = match ann_search {
        AnnSearchGpu::ExhaustiveGpu => {
            let index = build_exhaustive_index_gpu::<T, R>(data, &params_nn.dist_metric, device)?;
            query_exhaustive_index_gpu_self(&index, k + 1, true, verbosity.detailed_verbosity())?
        }
        AnnSearchGpu::IvfGpu => {
            let index = build_ivf_index_gpu::<T, R>(
                data,
                params_nn.n_list,
                None,
                &params_nn.dist_metric,
                seed,
                verbosity.detailed_verbosity(),
                device,
            )?;
            query_ivf_index_gpu_self(
                &index,
                k + 1,
                params_nn.n_probes,
                None,
                true,
                verbosity.normal_verbosity(),
            )?
        }
        AnnSearchGpu::NNDescentGpu => {
            let mut index = build_nndescent_index_gpu::<T, R>(
                data,
                &params_nn.dist_metric,
                params_nn.k,
                params_nn.k_build,
                None,
                params_nn.n_tree,
                // The builder takes optional `f32` values; a failed conversion
                // is passed on as `None`.
                params_nn.delta.to_f32(),
                params_nn.rho.and_then(|r| r.to_f32()),
                None,
                seed,
                verbosity.detailed_verbosity(),
                false,
                device,
            )?;
            let query_params = CagraGpuSearchParams::new(
                params_nn.beam_width,
                params_nn.max_beam_iters,
                params_nn.n_entry_points,
                None,
            );
            query_nndescent_index_gpu_self(&mut index, k + 1, Some(query_params), true)?
        }
    };

    // NOTE(review): the `true` flag passed to each query presumably requests
    // distances, which would make `None` an invariant violation — confirm.
    let knn_dist = knn_dist.expect("GPU kNN query returned no distances");

    // Drop each point's own index from its neighbour list and cap the list at `k`.
    let (knn_indices, knn_dist): (Vec<Vec<usize>>, Vec<Vec<T>>) = knn_indices_raw
        .into_par_iter()
        .zip(knn_dist.into_par_iter())
        .enumerate()
        .map(|(i, (idx, dist))| {
            idx.into_iter()
                .zip(dist)
                .filter(|(j, _)| *j != i)
                .take(k)
                .unzip()
        })
        .unzip();

    Ok((knn_indices, knn_dist))
}