// ann_search_rs — crate root (lib.rs)
1//! Optimised vector searches in Rust originally designed for single cell
2//! applications, but now as additionally GPU-accelerated, quantised (with
3//! binary indices) vector searches leveraging Rust's performance under the
4//! hood.
5//!
6//! ## Feature flags
7#![doc = document_features::document_features!()]
8#![allow(clippy::needless_range_loop)] // I want these loops!
9#![warn(missing_docs)]
10
11#[cfg(feature = "mimalloc")]
12use mimalloc::MiMalloc;
13
14// MiMalloc for better allocations
15#[cfg(feature = "mimalloc")]
16#[global_allocator]
17static GLOBAL: MiMalloc = MiMalloc;
18
19pub mod cpu;
20pub mod prelude;
21pub mod utils;
22
23#[cfg(feature = "gpu")]
24pub mod gpu;
25
26#[cfg(feature = "quantised")]
27pub mod quantised;
28
29#[cfg(feature = "binary")]
30pub mod binary;
31
32use faer::MatRef;
33use rayon::prelude::*;
34
35use std::sync::{
36    atomic::{AtomicUsize, Ordering},
37    Arc,
38};
39use thousands::*;
40
41#[cfg(feature = "gpu")]
42use cubecl::prelude::*;
43#[cfg(feature = "quantised")]
44use std::ops::AddAssign;
45
46#[cfg(feature = "binary")]
47use bytemuck::Pod;
48#[cfg(feature = "binary")]
49use std::path::Path;
50
51use crate::cpu::{
52    annoy::*, ball_tree::*, exhaustive::*, hnsw::*, ivf::*, kd_forest::*, lsh::*, nndescent::*,
53    vamana::*,
54};
55use crate::prelude::*;
56
57#[cfg(feature = "binary")]
58use crate::binary::{exhaustive_binary::*, exhaustive_rabitq::*, ivf_binary::*, ivf_rabitq::*};
59#[cfg(feature = "gpu")]
60use crate::gpu::{exhaustive_gpu::*, ivf_gpu::*, nndescent_gpu::*};
61#[cfg(feature = "quantised")]
62use crate::quantised::{
63    exhaustive_bf16::*, exhaustive_opq::*, exhaustive_pq::*, exhaustive_sq8::*, ivf_bf16::*,
64    ivf_opq::*, ivf_pq::*, ivf_sq8::*,
65};
66
67////////////
68// Helper //
69////////////
70
71/// Helper function to execute parallel queries across samples
72///
73/// ### Params
74///
75/// * `n_samples` - Number of samples to query
76/// * `return_dist` - Whether to return distances alongside indices
77/// * `verbose` - Print progress information every 100,000 samples
78/// * `query_fn` - Closure that takes a sample index and returns (indices,
79///   distances)
80///
81/// ### Returns
82///
83/// A tuple of `(knn_indices, optional distances)`
84fn query_parallel<T, F>(
85    n_samples: usize,
86    return_dist: bool,
87    verbose: bool,
88    query_fn: F,
89) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
90where
91    T: Send,
92    F: Fn(usize) -> (Vec<usize>, Vec<T>) + Sync,
93{
94    let counter = Arc::new(AtomicUsize::new(0));
95
96    let results: Vec<(Vec<usize>, Vec<T>)> = (0..n_samples)
97        .into_par_iter()
98        .map(|i| {
99            let result = query_fn(i);
100            if verbose {
101                let count = counter.fetch_add(1, Ordering::Relaxed) + 1;
102                if count.is_multiple_of(100_000) {
103                    println!(
104                        "  Processed {} / {} samples.",
105                        count.separate_with_underscores(),
106                        n_samples.separate_with_underscores()
107                    );
108                }
109            }
110            result
111        })
112        .collect();
113
114    if return_dist {
115        let (indices, distances) = results.into_iter().unzip();
116        (indices, Some(distances))
117    } else {
118        let indices: Vec<Vec<usize>> = results.into_iter().map(|(idx, _)| idx).collect();
119        (indices, None)
120    }
121}
122
123/// Helper function to execute parallel queries with boolean flags
124///
125/// ### Params
126///
127/// * `n_samples` - Number of samples to query
128/// * `return_dist` - Whether to return distances alongside indices
129/// * `verbose` - Print progress information every 100,000 samples
130/// * `query_fn` - Closure that takes a sample index and returns (indices,
131///   distances, flag)
132///
133/// ### Returns
134///
135/// A tuple of `(knn_indices, optional distances)`
136///
137/// ### Note
138///
139/// This variant tracks boolean flags returned by the query function. If more
140/// than 1% of queries return true flags, a warning is printed. Used primarily
141/// for LSH queries where the flag indicates samples not represented in hash
142/// buckets.
143fn query_parallel_with_flags<T, F>(
144    n_samples: usize,
145    return_dist: bool,
146    verbose: bool,
147    query_fn: F,
148) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
149where
150    T: Send,
151    F: Fn(usize) -> (Vec<usize>, Vec<T>, bool) + Sync,
152{
153    let counter = Arc::new(AtomicUsize::new(0));
154
155    let results: Vec<(Vec<usize>, Vec<T>, bool)> = (0..n_samples)
156        .into_par_iter()
157        .map(|i| {
158            let result = query_fn(i);
159            if verbose {
160                let count = counter.fetch_add(1, Ordering::Relaxed) + 1;
161                if count.is_multiple_of(100_000) {
162                    println!(
163                        " Processed {} / {} samples.",
164                        count.separate_with_underscores(),
165                        n_samples.separate_with_underscores()
166                    );
167                }
168            }
169            result
170        })
171        .collect();
172
173    let mut random: usize = 0;
174    let mut indices: Vec<Vec<usize>> = Vec::with_capacity(results.len());
175    let mut distances: Vec<Vec<T>> = Vec::with_capacity(results.len());
176
177    for (idx, dist, rnd) in results {
178        if rnd {
179            random += 1;
180        }
181        indices.push(idx);
182        distances.push(dist);
183    }
184
185    if (random as f32) / (n_samples as f32) >= 0.01 {
186        println!("More than 1% of samples were not represented in the buckets.");
187        println!("Please verify underlying data");
188    }
189
190    if return_dist {
191        (indices, Some(distances))
192    } else {
193        (indices, None)
194    }
195}
196
197////////////////
198// Exhaustive //
199////////////////
200
/// Build an exhaustive index
///
/// ### Params
///
/// * `mat` - The initial matrix with samples x features
/// * `dist_metric` - Distance metric: "euclidean" or "cosine". Unrecognised
///   values fall back to the default metric.
///
/// ### Returns
///
/// The initialised `ExhaustiveIndex`
pub fn build_exhaustive_index<T>(mat: MatRef<T>, dist_metric: &str) -> ExhaustiveIndex<T>
where
    T: AnnSearchFloat,
{
    let metric = parse_ann_dist(dist_metric).unwrap_or_default();
    ExhaustiveIndex::new(mat, metric)
}
218
219/// Helper function to query a given exhaustive index
220///
221/// ### Params
222///
223/// * `query_mat` - The query matrix containing the samples × features
224/// * `index` - The exhaustive index
225/// * `k` - Number of neighbours to return
226/// * `return_dist` - Shall the distances be returned
227/// * `verbose` - Controls verbosity of the function
228///
229/// ### Returns
230///
231/// A tuple of `(knn_indices, optional distances)`
232pub fn query_exhaustive_index<T>(
233    query_mat: MatRef<T>,
234    index: &ExhaustiveIndex<T>,
235    k: usize,
236    return_dist: bool,
237    verbose: bool,
238) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
239where
240    T: AnnSearchFloat,
241{
242    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
243        index.query_row(query_mat.row(i), k)
244    })
245}
246
/// Helper function to self query an exhaustive index
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - The exhaustive index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_self<T>(
    index: &ExhaustiveIndex<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, return_dist, verbose)
}
272
273///////////
274// Annoy //
275///////////
276
/// Build an Annoy index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `dist_metric` - Distance metric: "euclidean" or "cosine". Unrecognised
///   values fall back to the default metric
/// * `n_trees` - Number of trees to use to build the index
/// * `seed` - Random seed for reproducibility
///
/// ### Return
///
/// The `AnnoyIndex`.
pub fn build_annoy_index<T>(
    mat: MatRef<T>,
    dist_metric: String,
    n_trees: usize,
    seed: usize,
) -> AnnoyIndex<T>
where
    T: AnnSearchFloat,
{
    let ann_dist = parse_ann_dist(&dist_metric).unwrap_or_default();

    AnnoyIndex::new(mat, n_trees, ann_dist, seed)
}
302
303/// Helper function to query a given Annoy index
304///
305/// ### Params
306///
307/// * `query_mat` - The query matrix containing the samples x features
308/// * `k` - Number of neighbours to return
309/// * `index` - The AnnoyIndex to query.
310/// * `search_budget` - Search budget per tree
311/// * `return_dist` - Shall the distances between the different points be
312///   returned
313/// * `verbose` - Controls verbosity of the function
314///
315/// ### Returns
316///
317/// A tuple of `(knn_indices, optional distances)`
318pub fn query_annoy_index<T>(
319    query_mat: MatRef<T>,
320    index: &AnnoyIndex<T>,
321    k: usize,
322    search_budget: Option<usize>,
323    return_dist: bool,
324    verbose: bool,
325) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
326where
327    T: AnnSearchFloat,
328{
329    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
330        index.query_row(query_mat.row(i), k, search_budget)
331    })
332}
333
/// Helper function to self query the Annoy index
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - The AnnoyIndex to query
/// * `k` - Number of neighbours to return
/// * `search_budget` - Search budget per tree
/// * `return_dist` - Shall the distances between the different points be
///   returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_annoy_self<T>(
    index: &AnnoyIndex<T>,
    k: usize,
    search_budget: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, search_budget, return_dist, verbose)
}
362
363//////////////
364// BallTree //
365//////////////
366
367/// Build a BallTree index
368///
369/// ### Params
370///
371/// * `mat` - The data matrix. Rows represent the samples, columns represent
372///   the embedding dimensions
373/// * `dist_metric` - Distance metric to use
374/// * `seed` - Random seed for reproducibility
375///
376/// ### Return
377///
378/// The `BallTreeIndex`.
379pub fn build_balltree_index<T>(mat: MatRef<T>, dist_metric: String, seed: usize) -> BallTreeIndex<T>
380where
381    T: AnnSearchFloat,
382{
383    let ann_dist = parse_ann_dist(&dist_metric).unwrap_or_default();
384    BallTreeIndex::new(mat, ann_dist, seed)
385}
386
387/// Helper function to query a given BallTree index
388///
389/// ### Params
390///
391/// * `query_mat` - The query matrix containing the samples x features
392/// * `k` - Number of neighbours to return
393/// * `index` - The BallTreeIndex to query
394/// * `search_budget` - Search budget (number of items to examine)
395/// * `return_dist` - Shall the distances between the different points be
396///   returned
397/// * `verbose` - Controls verbosity of the function
398///
399/// ### Returns
400///
401/// A tuple of `(knn_indices, optional distances)`
402pub fn query_balltree_index<T>(
403    query_mat: MatRef<T>,
404    index: &BallTreeIndex<T>,
405    k: usize,
406    search_budget: Option<usize>,
407    return_dist: bool,
408    verbose: bool,
409) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
410where
411    T: AnnSearchFloat,
412{
413    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
414        index.query_row(query_mat.row(i), k, search_budget)
415    })
416}
417
/// Helper function to self query the BallTree index
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - The BallTreeIndex to query
/// * `k` - Number of neighbours to return
/// * `search_budget` - Search budget (number of items to examine)
/// * `return_dist` - Shall the distances between the different points be
///   returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_balltree_self<T>(
    index: &BallTreeIndex<T>,
    k: usize,
    search_budget: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, search_budget, return_dist, verbose)
}
446
447//////////
448// HNSW //
449//////////
450
/// Build an HNSW index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions.
/// * `m` - Number of bidirectional connections per layer.
/// * `ef_construction` - Size of candidate list during construction.
/// * `dist_metric` - The distance metric to use. One of `"euclidean"` or
///   `"cosine"`.
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `HnswIndex`.
pub fn build_hnsw_index<T>(
    mat: MatRef<T>,
    m: usize,
    ef_construction: usize,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> HnswIndex<T>
where
    T: AnnSearchFloat,
    HnswIndex<T>: HnswState<T>,
{
    // Metric parsing happens inside `HnswIndex::build`, unlike the other
    // builders in this file which call `parse_ann_dist` here.
    HnswIndex::build(mat, m, ef_construction, dist_metric, seed, verbose)
}
480
481/// Helper function to query a given HNSW index
482///
483/// ### Params
484///
485/// * `query_mat` - The query matrix containing the samples x features
486/// * `index` - Reference to the built HNSW index
487/// * `k` - Number of neighbours to return
488/// * `ef_search` - Size of candidate list during search (higher = better
489///   recall, slower)
490/// * `return_dist` - Shall the distances between the different points be
491///   returned
492/// * `verbose` - Print progress information
493///
494/// ### Returns
495///
496/// A tuple of `(knn_indices, optional distances)`
497///
498/// ### Note
499///
500/// The distance metric is determined at index build time and cannot be changed
501/// during querying.
502pub fn query_hnsw_index<T>(
503    query_mat: MatRef<T>,
504    index: &HnswIndex<T>,
505    k: usize,
506    ef_search: usize,
507    return_dist: bool,
508    verbose: bool,
509) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
510where
511    T: AnnSearchFloat,
512    HnswIndex<T>: HnswState<T>,
513{
514    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
515        index.query_row(query_mat.row(i), k, ef_search)
516    })
517}
518
/// Helper function to self query the HNSW index
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - Reference to the built HNSW index
/// * `k` - Number of neighbours to return
/// * `ef_search` - Size of candidate list during search (higher = better
///   recall, slower)
/// * `return_dist` - Shall the distances between the different points be
///   returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_hnsw_self<T>(
    index: &HnswIndex<T>,
    k: usize,
    ef_search: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
    HnswIndex<T>: HnswState<T>,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, ef_search, return_dist, verbose)
}
550
551/////////
552// IVF //
553/////////
554
555/// Build an IVF index
556///
557/// ### Params
558///
559/// * `mat` - The data matrix. Rows represent the samples, columns represent
560///   the embedding dimensions
561/// * `nlist` - Number of clusters to create
562/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
563/// * `dist_metric` - The distance metric to use. One of `"euclidean"` or
564///   `"cosine"`
565/// * `seed` - Random seed for reproducibility
566/// * `verbose` - Print progress information during index construction
567///
568/// ### Return
569///
570/// The `IvfIndex`.
571pub fn build_ivf_index<T>(
572    mat: MatRef<T>,
573    nlist: Option<usize>,
574    max_iters: Option<usize>,
575    dist_metric: &str,
576    seed: usize,
577    verbose: bool,
578) -> IvfIndex<T>
579where
580    T: AnnSearchFloat,
581{
582    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
583
584    IvfIndex::build(mat, ann_dist, nlist, max_iters, seed, verbose)
585}
586
587/// Helper function to query a given IVF index
588///
589/// ### Params
590///
591/// * `query_mat` - The query matrix containing the samples x features
592/// * `index` - Reference to the built IVF index
593/// * `k` - Number of neighbours to return
594/// * `nprobe` - Number of clusters to search (defaults to min(nlist/10, 10))
595///   Higher values improve recall at the cost of speed
596/// * `return_dist` - Shall the distances between the different points be
597///   returned
598/// * `verbose` - Print progress information
599///
600/// ### Returns
601///
602/// A tuple of `(knn_indices, optional distances)`
603///
604/// ### Note
605///
606/// The distance metric is determined at index build time and cannot be changed
607/// during querying.
608pub fn query_ivf_index<T>(
609    query_mat: MatRef<T>,
610    index: &IvfIndex<T>,
611    k: usize,
612    nprobe: Option<usize>,
613    return_dist: bool,
614    verbose: bool,
615) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
616where
617    T: AnnSearchFloat,
618{
619    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
620        index.query_row(query_mat.row(i), k, nprobe)
621    })
622}
623
/// Helper function to self query an IVF index
///
/// This function will generate a full kNN graph based on the internal data. To
/// accelerate the process, it will leverage the information on the Voronoi
/// cells under the hood and query nearby cells per given internal vector.
///
/// ### Params
///
/// * `index` - Reference to the built IVF index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to min(nlist/10, 10))
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Shall the distances between the different points be
///   returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
///
/// ### Note
///
/// The distance metric is determined at index build time and cannot be changed
/// during querying.
pub fn query_ivf_self<T>(
    index: &IvfIndex<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, nprobe, return_dist, verbose)
}
661
662////////////
663// KdTree //
664////////////
665
/// Build a Kd-Tree forest index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `dist_metric` - Distance metric string ("euclidean" or "cosine");
///   unrecognised values fall back to the default metric
/// * `n_trees` - Number of trees to use to build the index
/// * `seed` - Random seed for reproducibility
///
/// ### Return
///
/// The `KdTreeIndex`.
pub fn build_kd_tree_index<T>(
    mat: MatRef<T>,
    dist_metric: String,
    n_trees: usize,
    seed: usize,
) -> KdTreeIndex<T>
where
    T: AnnSearchFloat,
{
    let ann_dist = parse_ann_dist(&dist_metric).unwrap_or_default();

    KdTreeIndex::new(mat, n_trees, ann_dist, seed)
}
694
695/// Helper function to query a given Kd-Tree index
696///
697/// ### Params
698///
699/// * `query_mat` - The query matrix containing the samples x features
700/// * `index` - The KdTreeIndex to query
701/// * `k` - Number of neighbours to return
702/// * `search_budget` - Search budget (total items to examine)
703/// * `return_dist` - Shall the distances between the different points be
704///   returned
705/// * `verbose` - Controls verbosity of the function
706///
707/// ### Returns
708///
709/// A tuple of `(knn_indices, optional distances)`
710pub fn query_kd_tree_index<T>(
711    query_mat: MatRef<T>,
712    index: &KdTreeIndex<T>,
713    k: usize,
714    search_budget: Option<usize>,
715    return_dist: bool,
716    verbose: bool,
717) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
718where
719    T: AnnSearchFloat,
720{
721    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
722        index.query_row(query_mat.row(i), k, search_budget)
723    })
724}
725
/// Helper function to self query the Kd-Tree index
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - The KdTreeIndex to query
/// * `k` - Number of neighbours to return
/// * `search_budget` - Search budget (total items to examine)
/// * `return_dist` - Shall the distances between the different points be
///   returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_kd_tree_self<T>(
    index: &KdTreeIndex<T>,
    k: usize,
    search_budget: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, search_budget, return_dist, verbose)
}
754
755/////////
756// LSH //
757/////////
758
759/// Build the LSH index
760///
761/// ### Params
762///
763/// * `mat` - The initial matrix with samples x features
764/// * `dist_metric` - Distance metric: "euclidean" or "cosine"
765/// * `num_tables` - Number of HashMaps to use (usually something 20 to 100)
766/// * `bits_per_hash` - How many bits per hash. Lower values (8) usually yield
767///   better Recall with higher query time; higher values (16) have worse Recall
768///   but faster query time
769/// * `seed` - Random seed for reproducibility
770///
771/// ### Returns
772///
773/// The ready LSH index for querying
774pub fn build_lsh_index<T>(
775    mat: MatRef<T>,
776    dist_metric: &str,
777    num_tables: usize,
778    bits_per_hash: usize,
779    seed: usize,
780) -> LSHIndex<T>
781where
782    T: AnnSearchFloat,
783{
784    let metric = parse_ann_dist(dist_metric).unwrap_or_default();
785    LSHIndex::new(mat, metric, num_tables, bits_per_hash, seed)
786}
787
788/// Helper function to query a given LSH index
789///
790/// This function will generate a full kNN graph based on the internal data.
791///
792/// ### Params
793///
794/// * `query_mat` - The query matrix containing the samples × features
795/// * `index` - The LSH index
796/// * `k` - Number of neighbours to return
797/// * `max_candidates` - Optional number to limit the candidate selection per
798///   given table. Makes the querying faster at cost of Recall.
799/// * `nprobe` - Number of additional buckets to probe per table. Will identify
800///   the closest hash tables and use bit flipping to investigate these.
801/// * `return_dist` - Shall the distances be returned
802/// * `verbose` - Controls verbosity of the function
803///
804/// ### Returns
805///
806/// A tuple of `(knn_indices, optional distances)`
807pub fn query_lsh_index<T>(
808    query_mat: MatRef<T>,
809    index: &LSHIndex<T>,
810    k: usize,
811    n_probe: usize,
812    max_candidates: Option<usize>,
813    return_dist: bool,
814    verbose: bool,
815) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
816where
817    T: AnnSearchFloat,
818{
819    query_parallel_with_flags(query_mat.nrows(), return_dist, verbose, |i| {
820        index.query_row(query_mat.row(i), k, max_candidates, n_probe)
821    })
822}
823
824/// Helper function to self query an LSH index
825///
826/// ### Params
827///
828/// * `index` - The LSH index
829/// * `k` - Number of neighbours to return
830/// * `max_candidates` - Optional number to limit the candidate selection per
831///   given table. Makes the querying faster at cost of Recall.
832/// * `n_probe` - Optional number of additional buckets to probe per table. Will
833///   identify the closest hash tables and use bit flipping to investigate
834///   these. Defaults to half the number of bits.
835/// * `return_dist` - Shall the distances be returned
836/// * `verbose` - Controls verbosity of the function
837///
838/// ### Returns
839///
840/// A tuple of `(knn_indices, optional distances)`
841pub fn query_lsh_self<T>(
842    index: &LSHIndex<T>,
843    k: usize,
844    n_probe: Option<usize>,
845    max_candidates: Option<usize>,
846    return_dist: bool,
847    verbose: bool,
848) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
849where
850    T: AnnSearchFloat,
851{
852    let n_probe = n_probe.unwrap_or(index.num_bits() / 2);
853
854    index.generate_knn(k, max_candidates, n_probe, return_dist, verbose)
855}
856
857///////////////
858// NNDescent //
859///////////////
860
/// Build an NNDescent index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions.
/// * `dist_metric` - The distance metric to use. One of `"euclidean"` or
///   `"cosine"`. Unrecognised values fall back to `"cosine"`.
/// * `delta` - Early stop criterium for the algorithm.
/// * `diversify_prob` - Probability of pruning redundant edges (1.0 = always
///   prune).
/// * `k` - Optional number of neighbours for the k-NN graph.
/// * `max_iter` - Optional maximum iterations for the algorithm.
/// * `max_candidates` - Optional cap on the candidate list used during graph
///   refinement — confirm exact semantics in `NNDescent::new`.
/// * `n_tree` - Optional number of trees used to initialise the graph —
///   confirm exact semantics in `NNDescent::new`.
/// * `seed` - Random seed for reproducibility.
/// * `verbose` - Controls verbosity of the algorithm.
///
/// ### Return
///
/// The `NNDescent` index.
#[allow(clippy::too_many_arguments)]
pub fn build_nndescent_index<T>(
    mat: MatRef<T>,
    dist_metric: &str,
    delta: T,
    diversify_prob: T,
    k: Option<usize>,
    max_iter: Option<usize>,
    max_candidates: Option<usize>,
    n_tree: Option<usize>,
    seed: usize,
    verbose: bool,
) -> NNDescent<T>
where
    T: AnnSearchFloat,
    NNDescent<T>: ApplySortedUpdates<T>,
    NNDescent<T>: NNDescentQuery<T>,
{
    // Falls back to cosine (not the crate default) when parsing fails.
    let metric = parse_ann_dist(dist_metric).unwrap_or(Dist::Cosine);
    NNDescent::new(
        mat,
        metric,
        k,
        max_candidates,
        max_iter,
        n_tree,
        delta,
        diversify_prob,
        seed,
        verbose,
    )
}
913
914/// Helper function to query a given NNDescent index
915///
916/// ### Params
917///
918/// * `query_mat` - The query matrix containing the samples x features
919/// * `index` - Reference to the built NNDescent index
920/// * `k` - Number of neighbours to return
921/// * `ef_search` -
922/// * `return_dist` - Shall the distances between the different points be
923///   returned
924/// * `verbose` - Print progress information
925///
926/// ### Returns
927///
928/// A tuple of `(knn_indices, optional distances)`
929///
930/// ### Note
931///
932/// The distance metric is determined at index build time and cannot be changed
933/// during querying.
934pub fn query_nndescent_index<T>(
935    query_mat: MatRef<T>,
936    index: &NNDescent<T>,
937    k: usize,
938    ef_search: Option<usize>,
939    return_dist: bool,
940    verbose: bool,
941) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
942where
943    T: AnnSearchFloat,
944    NNDescent<T>: ApplySortedUpdates<T>,
945    NNDescent<T>: NNDescentQuery<T>,
946{
947    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
948        index.query_row(query_mat.row(i), k, ef_search)
949    })
950}
951
/// Helper function to self query the NNDescent index
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - Reference to the built NNDescent index
/// * `k` - Number of neighbours to return
/// * `ef_search` - Optional size of the search candidate list; `None` falls
///   back to the index default — confirm in `NNDescentQuery`
/// * `return_dist` - Shall the distances between the different points be
///   returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
///
/// ### Note
///
/// The distance metric is determined at index build time and cannot be changed
/// during querying.
pub fn query_nndescent_self<T>(
    index: &NNDescent<T>,
    k: usize,
    ef_search: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
    NNDescent<T>: ApplySortedUpdates<T>,
    NNDescent<T>: NNDescentQuery<T>,
{
    // Thin delegation to the index's internal kNN-graph generation.
    index.generate_knn(k, ef_search, return_dist, verbose)
}
987
988////////////
989// Vamana //
990////////////
991
992/// Build a Vamana index
993///
994/// ### Params
995///
996/// * `mat` - The data matrix. Rows are samples, columns are dimensions.
997/// * `r` - Maximum out-degree (edges per node).
998/// * `l_build` - Beam width during construction.
999/// * `alpha_pass1` - Pruning alpha for pass 1 (typically 1.0).
1000/// * `alpha_pass2` - Pruning alpha for pass 2 (typically 1.2–1.5).
1001/// * `dist_metric` - One of `"euclidean"` or `"cosine"`.
1002/// * `seed` - Random seed for reproducibility.
1003///
1004/// ### Returns
1005///
1006/// The built `VamanaIndex`.
1007pub fn build_vamana_index<T>(
1008    mat: MatRef<T>,
1009    r: usize,
1010    l_build: usize,
1011    alpha_pass1: f32,
1012    alpha_pass2: f32,
1013    dist_metric: &str,
1014    seed: usize,
1015) -> VamanaIndex<T>
1016where
1017    T: AnnSearchFloat,
1018    VamanaIndex<T>: VamanaState<T>,
1019{
1020    let metric = parse_ann_dist(dist_metric).unwrap_or(Dist::Euclidean);
1021    VamanaIndex::build(mat, metric, r, l_build, alpha_pass1, alpha_pass2, seed)
1022}
1023
1024/// Query a Vamana index with an external query matrix
1025///
1026/// ### Params
1027///
1028/// * `query_mat` - Query matrix (samples × features).
1029/// * `index` - Reference to the built index.
1030/// * `k` - Number of neighbours to return.
1031/// * `ef_search` - Optional beam width override. Defaults to 100 inside the
1032///   index if `None`.
1033/// * `return_dist` - Whether to return distances.
1034/// * `verbose` - Print progress every 100,000 samples.
1035///
1036/// ### Returns
1037///
1038/// A tuple of `(knn_indices, optional distances)`.
1039pub fn query_vamana_index<T>(
1040    query_mat: MatRef<T>,
1041    index: &VamanaIndex<T>,
1042    k: usize,
1043    ef_search: Option<usize>,
1044    return_dist: bool,
1045    verbose: bool,
1046) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
1047where
1048    T: AnnSearchFloat,
1049    VamanaIndex<T>: VamanaState<T>,
1050{
1051    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
1052        index.query_row(query_mat.row(i), k, ef_search)
1053    })
1054}
1055
/// Self-query a Vamana index to generate a full kNN graph
///
/// Delegates entirely to the index's own `generate_knn`, which builds the
/// graph from the vectors stored at build time.
///
/// ### Params
///
/// * `index` - Reference to the built index.
/// * `k` - Number of neighbours to return.
/// * `ef_search` - Optional beam width override.
/// * `return_dist` - Whether to return distances.
/// * `verbose` - Print progress every 100,000 samples.
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`.
pub fn query_vamana_self<T>(
    index: &VamanaIndex<T>,
    k: usize,
    ef_search: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
    VamanaIndex<T>: VamanaState<T>,
{
    index.generate_knn(k, ef_search, return_dist, verbose)
}
1082
1083///////////////
1084// Quantised //
1085///////////////
1086
1087/////////////////////
1088// Exhaustive-BF16 //
1089/////////////////////
1090
#[cfg(feature = "quantised")]
/// Build an Exhaustive-BF16 index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `dist_metric` - Distance metric to use; unrecognised names fall back to
///   the default metric
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `ExhaustiveIndexBf16`.
pub fn build_exhaustive_bf16_index<T>(
    mat: MatRef<T>,
    dist_metric: &str,
    verbose: bool,
) -> ExhaustiveIndexBf16<T>
where
    T: AnnSearchFloat + Bf16Compatible,
{
    let metric = parse_ann_dist(dist_metric).unwrap_or_default();
    let n_samples = mat.nrows();
    if verbose {
        println!(
            "Building exhaustive BF16 index with {} samples",
            n_samples
        );
    }
    ExhaustiveIndexBf16::new(mat, metric)
}
1121
#[cfg(feature = "quantised")]
/// Query an Exhaustive-BF16 index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built Exhaustive-BF16 index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_bf16_index<T>(
    query_mat: MatRef<T>,
    index: &ExhaustiveIndexBf16<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Bf16Compatible,
{
    let n_queries = query_mat.nrows();
    // Exhaustive scan per query row, parallelised by the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1150
#[cfg(feature = "quantised")]
/// Helper function to self query a given Exhaustive-BF16 index
///
/// This function will generate a full kNN graph based on the internal data.
/// Delegates entirely to the index's own `generate_knn`.
///
/// ### Params
///
/// * `index` - Reference to the built Exhaustive-BF16 index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_bf16_self<T>(
    index: &ExhaustiveIndexBf16<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Bf16Compatible,
{
    index.generate_knn(k, return_dist, verbose)
}
1177
1178////////////////////
1179// Exhaustive-SQ8 //
1180////////////////////
1181
#[cfg(feature = "quantised")]
/// Build an Exhaustive-SQ8 index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `dist_metric` - Distance metric to use; unrecognised names fall back to
///   the default metric
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `ExhaustiveSq8Index`.
pub fn build_exhaustive_sq8_index<T>(
    mat: MatRef<T>,
    dist_metric: &str,
    verbose: bool,
) -> ExhaustiveSq8Index<T>
where
    T: AnnSearchFloat,
{
    let metric = parse_ann_dist(dist_metric).unwrap_or_default();
    let n_samples = mat.nrows();
    if verbose {
        println!("Building exhaustive SQ8 index with {} samples", n_samples);
    }
    ExhaustiveSq8Index::new(mat, metric)
}
1209
#[cfg(feature = "quantised")]
/// Query an Exhaustive-SQ8 index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built Exhaustive-SQ8 index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_sq8_index<T>(
    query_mat: MatRef<T>,
    index: &ExhaustiveSq8Index<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    let n_queries = query_mat.nrows();
    // Exhaustive scan per query row, parallelised by the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1238
#[cfg(feature = "quantised")]
/// Helper function to self query a given Exhaustive-SQ8 index
///
/// This function will generate a full kNN graph based on the internal data.
/// Delegates entirely to the index's own `generate_knn`.
///
/// ### Params
///
/// * `index` - Reference to the built Exhaustive-SQ8 index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_sq8_self<T>(
    index: &ExhaustiveSq8Index<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    index.generate_knn(k, return_dist, verbose)
}
1265
1266////////////////////
1267// Exhaustive-PQ //
1268////////////////////
1269
#[cfg(feature = "quantised")]
/// Build an Exhaustive-PQ index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `m` - Number of subspaces for product quantisation (dim must be divisible
///   by m)
/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
/// * `n_pq_centroids` - Number of centroids per subspace (defaults to 256 if None)
/// * `dist_metric` - Distance metric ("euclidean" or "cosine"); unrecognised
///   names fall back to the default metric
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `ExhaustivePqIndex`.
#[allow(clippy::too_many_arguments)]
pub fn build_exhaustive_pq_index<T>(
    mat: MatRef<T>,
    m: usize,
    max_iters: Option<usize>,
    n_pq_centroids: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> ExhaustivePqIndex<T>
where
    T: AnnSearchFloat,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    ExhaustivePqIndex::build(mat, m, ann_dist, max_iters, n_pq_centroids, seed, verbose)
}
1304
#[cfg(feature = "quantised")]
/// Query an Exhaustive-PQ index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built Exhaustive-PQ index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_pq_index<T>(
    query_mat: MatRef<T>,
    index: &ExhaustivePqIndex<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    let n_queries = query_mat.nrows();
    // One PQ table scan per query row, fanned out by the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1333
#[cfg(feature = "quantised")]
/// Helper function to self query an Exhaustive-PQ index
///
/// This function will generate a full kNN graph based on the internal data. To
/// note, during quantisation information is lost, hence, the quality of the
/// graph is reduced compared to other indices. Delegates to the index's own
/// `generate_knn`.
///
/// ### Params
///
/// * `index` - Reference to the built Exhaustive-PQ index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_pq_index_self<T>(
    index: &ExhaustivePqIndex<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    index.generate_knn(k, return_dist, verbose)
}
1362
1363////////////////////
1364// Exhaustive-OPQ //
1365////////////////////
1366
#[cfg(feature = "quantised")]
/// Build an Exhaustive-OPQ index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `m` - Number of subspaces for product quantisation (dim must be divisible
///   by m)
/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
/// * `n_pq_centroids` - Number of centroids per subspace (defaults to 256 if None)
/// * `dist_metric` - Distance metric ("euclidean" or "cosine")
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `ExhaustiveOpqIndex`.
#[allow(clippy::too_many_arguments)]
pub fn build_exhaustive_opq_index<T>(
    mat: MatRef<T>,
    m: usize,
    max_iters: Option<usize>,
    n_pq_centroids: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> ExhaustiveOpqIndex<T>
where
    T: AnnSearchFloat + AddAssign,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    ExhaustiveOpqIndex::build(mat, m, ann_dist, max_iters, n_pq_centroids, seed, verbose)
}
1401
#[cfg(feature = "quantised")]
/// Query an Exhaustive-OPQ index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built Exhaustive-OPQ index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_opq_index<T>(
    query_mat: MatRef<T>,
    index: &ExhaustiveOpqIndex<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + AddAssign,
{
    let n_queries = query_mat.nrows();
    // One OPQ table scan per query row, fanned out by the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1430
#[cfg(feature = "quantised")]
/// Helper function to self query an Exhaustive-OPQ index
///
/// This function will generate a full kNN graph based on the internal data. To
/// note, during quantisation information is lost, hence, the quality of the
/// graph is reduced compared to other indices.
///
/// ### Params
///
/// * `index` - Reference to the built Exhaustive-OPQ index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_opq_index_self<T>(
    index: &ExhaustiveOpqIndex<T>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + AddAssign,
{
    index.generate_knn(k, return_dist, verbose)
}
1459
1460//////////////
1461// IVF-BF16 //
1462//////////////
1463
#[cfg(feature = "quantised")]
/// Build an IVF-BF16 index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `nlist` - Optional number of cells to create. If not provided, defaults
///   to `sqrt(n)`.
/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
/// * `dist_metric` - Distance metric ("euclidean" or "cosine"); unrecognised
///   names fall back to the default metric
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `IvfIndexBf16`.
pub fn build_ivf_bf16_index<T>(
    mat: MatRef<T>,
    nlist: Option<usize>,
    max_iters: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> IvfIndexBf16<T>
where
    T: AnnSearchFloat + Bf16Compatible,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();

    IvfIndexBf16::build(mat, ann_dist, nlist, max_iters, seed, verbose)
}
1495
#[cfg(feature = "quantised")]
/// Query an IVF-BF16 index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built IVF-BF16 index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 20% of nlist).
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Whether the scores should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional scores)`
pub fn query_ivf_bf16_index<T>(
    query_mat: MatRef<T>,
    index: &IvfIndexBf16<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Bf16Compatible,
{
    let n_queries = query_mat.nrows();
    // Probe the nearest cells for each query row via the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k, nprobe);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1527
#[cfg(feature = "quantised")]
/// Helper function to self query a given IVF-BF16 index
///
/// This function will generate a full kNN graph based on the internal data. To
/// accelerate the process, it will leverage the internally quantised vectors
/// and the information on the Voronoi cells under the hood and query nearby
/// cells per given internal vector.
///
/// ### Params
///
/// * `index` - Reference to the built IVF-BF16 index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 20% of nlist)
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Shall the inner product scores be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional inner_product_scores)`
pub fn query_ivf_bf16_self<T>(
    index: &IvfIndexBf16<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Bf16Compatible,
{
    index.generate_knn(k, nprobe, return_dist, verbose)
}
1560
1561/////////////
1562// IVF-SQ8 //
1563/////////////
1564
#[cfg(feature = "quantised")]
/// Build an IVF-SQ8 index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `nlist` - Optional number of cells to create. If not provided, defaults
///   to `sqrt(n)`.
/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
/// * `dist_metric` - Distance metric ("euclidean" or "cosine"); unrecognised
///   names fall back to the default metric
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `IvfSq8Index`.
pub fn build_ivf_sq8_index<T>(
    mat: MatRef<T>,
    nlist: Option<usize>,
    max_iters: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> IvfSq8Index<T>
where
    T: AnnSearchFloat,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();

    IvfSq8Index::build(mat, nlist, ann_dist, max_iters, seed, verbose)
}
1596
#[cfg(feature = "quantised")]
/// Query an IVF-SQ8 index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built IVF-SQ8 index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 20% of nlist).
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_ivf_sq8_index<T>(
    query_mat: MatRef<T>,
    index: &IvfSq8Index<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    let n_queries = query_mat.nrows();
    // Probe the nearest cells for each query row via the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k, nprobe);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1628
#[cfg(feature = "quantised")]
/// Helper function to self query a given IVF-SQ8 index
///
/// This function will generate a full kNN graph based on the internal data. To
/// accelerate the process, it will leverage the internally quantised vectors
/// and the information on the Voronoi cells under the hood and query nearby
/// cells per given internal vector.
///
/// ### Params
///
/// * `index` - Reference to the built IVF-SQ8 index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 20% of nlist)
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_ivf_sq8_self<T>(
    index: &IvfSq8Index<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    index.generate_knn(k, nprobe, return_dist, verbose)
}
1661
1662////////////
1663// IVF-PQ //
1664////////////
1665
#[cfg(feature = "quantised")]
/// Build an IVF-PQ index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `nlist` - Optional number of IVF clusters to create (chosen internally
///   if `None`)
/// * `m` - Number of subspaces for product quantisation (dim must be divisible
///   by m)
/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
/// * `n_pq_centroids` - Number of centroids per subspace if `Some`
///   (presumably defaults to 256 as in the exhaustive PQ builder — TODO confirm)
/// * `dist_metric` - Distance metric ("euclidean" or "cosine")
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `IvfPqIndex`.
#[allow(clippy::too_many_arguments)]
pub fn build_ivf_pq_index<T>(
    mat: MatRef<T>,
    nlist: Option<usize>,
    m: usize,
    max_iters: Option<usize>,
    n_pq_centroids: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> IvfPqIndex<T>
where
    T: AnnSearchFloat,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();

    IvfPqIndex::build(
        mat,
        nlist,
        m,
        ann_dist,
        max_iters,
        n_pq_centroids,
        seed,
        verbose,
    )
}
1711
#[cfg(feature = "quantised")]
/// Query an IVF-PQ index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built IVF-PQ index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 15% of nlist).
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_ivf_pq_index<T>(
    query_mat: MatRef<T>,
    index: &IvfPqIndex<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    let n_queries = query_mat.nrows();
    // Probe the nearest cells for each query row via the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k, nprobe);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1743
#[cfg(feature = "quantised")]
/// Helper function to self query a IVF-PQ index
///
/// This function will generate a full kNN graph based on the internal data. To
/// note, during quantisation information is lost, hence, the quality of the
/// graph is reduced compared to other indices.
///
/// ### Params
///
/// * `index` - Reference to the built IVF-PQ index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 15% of nlist)
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_ivf_pq_index_self<T>(
    index: &IvfPqIndex<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat,
{
    index.generate_knn(k, nprobe, return_dist, verbose)
}
1775
1776/////////////
1777// IVF-OPQ //
1778/////////////
1779
#[cfg(feature = "quantised")]
/// Build an IVF-OPQ index
///
/// ### Params
///
/// * `mat` - The data matrix. Rows represent the samples, columns represent
///   the embedding dimensions
/// * `nlist` - Optional number of IVF clusters to create (chosen internally
///   if `None`)
/// * `m` - Number of subspaces for product quantisation (dim must be divisible
///   by m)
/// * `max_iters` - Maximum k-means iterations (defaults to 30 if None)
/// * `n_opq_centroids` - Number of centroids per subspace if `Some`
///   (default handled inside `IvfOpqIndex::build`)
/// * `n_opq_iter` - Number of OPQ rotation-training iterations if `Some`
///   (default handled inside `IvfOpqIndex::build`)
/// * `dist_metric` - Distance metric ("euclidean" or "cosine")
/// * `seed` - Random seed for reproducibility
/// * `verbose` - Print progress information during index construction
///
/// ### Return
///
/// The `IvfOpqIndex`.
#[allow(clippy::too_many_arguments)]
pub fn build_ivf_opq_index<T>(
    mat: MatRef<T>,
    nlist: Option<usize>,
    m: usize,
    max_iters: Option<usize>,
    n_opq_centroids: Option<usize>,
    n_opq_iter: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
) -> IvfOpqIndex<T>
where
    T: AnnSearchFloat + AddAssign,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();

    IvfOpqIndex::build(
        mat,
        nlist,
        m,
        ann_dist,
        max_iters,
        n_opq_iter,
        n_opq_centroids,
        seed,
        verbose,
    )
}
1827
#[cfg(feature = "quantised")]
/// Query an IVF-OPQ index with an external query matrix
///
/// ### Params
///
/// * `query_mat` - Query matrix (samples × features)
/// * `index` - Reference to the built IVF-OPQ index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 15% of nlist).
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Whether distances should be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_ivf_opq_index<T>(
    query_mat: MatRef<T>,
    index: &IvfOpqIndex<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + AddAssign,
{
    let n_queries = query_mat.nrows();
    // Probe the nearest cells for each query row via the shared driver.
    let per_row = |i: usize| index.query_row(query_mat.row(i), k, nprobe);
    query_parallel(n_queries, return_dist, verbose, per_row)
}
1859
#[cfg(feature = "quantised")]
/// Helper function to self query a IVF-OPQ index
///
/// This function will generate a full kNN graph based on the internal data. To
/// note, during quantisation information is lost, hence, the quality of the
/// graph is reduced compared to other indices.
///
/// ### Params
///
/// * `index` - Reference to the built IVF-OPQ index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of clusters to search (defaults to 15% of nlist)
///   Higher values improve recall at the cost of speed
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_ivf_opq_index_self<T>(
    index: &IvfOpqIndex<T>,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + AddAssign,
{
    index.generate_knn(k, nprobe, return_dist, verbose)
}
1891
1892/////////
1893// GPU //
1894/////////
1895
1896////////////////////
1897// Exhaustive GPU //
1898////////////////////
1899
#[cfg(feature = "gpu")]
/// Build an exhaustive GPU index
///
/// ### Params
///
/// * `mat` - The initial matrix with samples x features
/// * `dist_metric` - Distance metric: "euclidean" or "cosine"; unrecognised
///   names fall back to the default metric
/// * `device` - The GPU device to use
///
/// ### Returns
///
/// The initialised `ExhaustiveIndexGpu`
pub fn build_exhaustive_index_gpu<T, R>(
    mat: MatRef<T>,
    dist_metric: &str,
    device: R::Device,
) -> ExhaustiveIndexGpu<T, R>
where
    T: AnnSearchGpuFloat + AnnSearchFloat,
    R: Runtime,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    ExhaustiveIndexGpu::new(mat, ann_dist, device)
}
1924
#[cfg(feature = "gpu")]
/// Query the exhaustive GPU index
///
/// ### Params
///
/// * `query_mat` - The query matrix containing the samples × features
/// * `index` - The exhaustive GPU index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_index_gpu<T, R>(
    query_mat: MatRef<T>,
    index: &ExhaustiveIndexGpu<T, R>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchGpuFloat + AnnSearchFloat,
    R: Runtime,
{
    // Distances are always computed by the batched kernel; drop them here if
    // the caller did not ask for them.
    let (indices, distances) = index.query_batch(query_mat, k, verbose);
    (indices, return_dist.then_some(distances))
}
1957
#[cfg(feature = "gpu")]
/// Query the exhaustive GPU index itself
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - The exhaustive GPU index
/// * `k` - Number of neighbours to return
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Print progress information
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
pub fn query_exhaustive_index_gpu_self<T, R>(
    index: &ExhaustiveIndexGpu<T, R>,
    k: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchGpuFloat + AnnSearchFloat,
    R: Runtime,
{
    index.generate_knn(k, return_dist, verbose)
}
1985
1986//////////////
1987// IVF GPU //
1988//////////////
1989
#[cfg(feature = "gpu")]
/// Build an IVF index with batched GPU acceleration
///
/// ### Params
///
/// * `mat` - Data matrix [samples, features]
/// * `nlist` - Number of clusters (defaults to √n)
/// * `max_iters` - K-means iterations (defaults to 30)
/// * `dist_metric` - "euclidean" or "cosine"; unrecognised names fall back to
///   the default metric
/// * `seed` - Random seed
/// * `verbose` - Print progress
/// * `device` - GPU device
///
/// ### Returns
///
/// The built `IvfIndexGpu`.
pub fn build_ivf_index_gpu<T, R>(
    mat: MatRef<T>,
    nlist: Option<usize>,
    max_iters: Option<usize>,
    dist_metric: &str,
    seed: usize,
    verbose: bool,
    device: R::Device,
) -> IvfIndexGpu<T, R>
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    IvfIndexGpu::build(mat, ann_dist, nlist, max_iters, seed, verbose, device)
}
2018
#[cfg(feature = "gpu")]
/// Query an IVF GPU index
///
/// ### Params
///
/// * `query_mat` - Query matrix [samples, features]
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `nprobe` - Clusters to search (defaults to √nlist)
/// * `nquery` - Number of queries to load into the GPU.
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_ivf_index_gpu<T, R>(
    query_mat: MatRef<T>,
    index: &IvfIndexGpu<T, R>,
    k: usize,
    nprobe: Option<usize>,
    nquery: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
{
    // Distances are always produced by the batched kernel; drop them here if
    // the caller did not ask for them.
    let (indices, distances) = index.query_batch(query_mat, k, nprobe, nquery, verbose);
    (indices, return_dist.then_some(distances))
}
2056
#[cfg(feature = "gpu")]
/// Query an IVF GPU index itself
///
/// This function will generate a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `nprobe` - Clusters to search (defaults to √nlist)
/// * `nquery` - Number of queries to load into the GPU.
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_ivf_index_gpu_self<T, R>(
    index: &IvfIndexGpu<T, R>,
    k: usize,
    nprobe: Option<usize>,
    nquery: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
{
    index.generate_knn(k, nprobe, nquery, return_dist, verbose)
}
2089
2090///////////////////
2091// NNDescent GPU //
2092///////////////////
2093
#[cfg(feature = "gpu")]
/// Build an NNDescent index with GPU-accelerated graph construction
/// and CAGRA optimisation.
///
/// ### Params
///
/// * `mat` - Data matrix [samples, features]
/// * `dist_metric` - "euclidean" or "cosine"
/// * `k` - Final neighbours per node (default 30)
/// * `build_k` - Internal NNDescent degree before CAGRA pruning (default 2*k)
/// * `max_iters` - Maximum NNDescent iterations (default 15)
/// * `n_trees` - Annoy forest size (default auto)
/// * `delta` - Convergence threshold (default 0.001)
/// * `rho` - Sampling rate (default 0.5)
/// * `refine_knn` - Optional refinement parameter forwarded to
///   `NNDescentGpu::build` (semantics defined there — see its docs)
/// * `seed` - Random seed
/// * `verbose` - Print progress
/// * `retain_gpu` - Flag forwarded to `NNDescentGpu::build`; presumably
///   controls whether GPU-side state is kept after the build — confirm there
/// * `device` - GPU device
#[allow(clippy::too_many_arguments)]
pub fn build_nndescent_index_gpu<T, R>(
    mat: MatRef<T>,
    dist_metric: &str,
    k: Option<usize>,
    build_k: Option<usize>,
    max_iters: Option<usize>,
    n_trees: Option<usize>,
    delta: Option<f32>,
    rho: Option<f32>,
    refine_knn: Option<usize>,
    seed: usize,
    verbose: bool,
    retain_gpu: bool,
    device: R::Device,
) -> NNDescentGpu<T, R>
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
    NNDescentGpu<T, R>: NNDescentQuery<T>,
{
    // Unparsable metric strings fall back to the default distance.
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    NNDescentGpu::build(
        mat, ann_dist, k, build_k, max_iters, n_trees, delta, rho, refine_knn, seed, verbose,
        retain_gpu, device,
    )
}
2138
#[cfg(feature = "gpu")]
/// Query an NNDescent GPU index.
///
/// Batches of at least 32 queries with no explicit `ef_search` are routed to
/// the GPU batch search; everything else uses the parallel CPU beam search.
///
/// ### Params
///
/// * `query_mat` - Query matrix [samples, features]
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `ef_search` - Beam width (default auto)
/// * `query_params` - Optional GPU beam search parameters
/// * `return_dist` - Return distances
/// * `verbose` - Print progress
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_nndescent_index_gpu<T, R>(
    query_mat: MatRef<T>,
    index: &mut NNDescentGpu<T, R>,
    k: usize,
    ef_search: Option<usize>,
    query_params: Option<CagraGpuSearchParams>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
    NNDescentGpu<T, R>: NNDescentQuery<T>,
{
    let n_queries = query_mat.nrows();
    // Below this batch size the GPU dispatch overhead is not worth it.
    let gpu_batch_threshold = 32;

    if n_queries >= gpu_batch_threshold && ef_search.is_none() {
        if verbose {
            println!("  GPU batch query: {} vectors, k={}...", n_queries, k);
        }

        // Flatten the query matrix row-major for the GPU kernel.
        let queries_flat: Vec<T> = (0..n_queries)
            .flat_map(|i| {
                let row = query_mat.row(i);
                if row.col_stride() == 1 {
                    // SAFETY: `col_stride == 1` means the row is contiguous,
                    // so `row.as_ptr()` points at `row.ncols()` valid,
                    // initialised elements of type `T`.
                    unsafe { std::slice::from_raw_parts(row.as_ptr(), row.ncols()) }.to_vec()
                } else {
                    // Strided row: element-wise copy.
                    row.iter().cloned().collect()
                }
            })
            .collect();

        let (indices, distances) =
            index.query_batch_gpu(&queries_flat, n_queries, query_params, k, 42);

        (indices, return_dist.then_some(distances))
    } else {
        if verbose {
            println!(
                "  CPU beam search: {} vectors (ef={:?})...",
                n_queries, ef_search
            );
        }

        // Parallel per-row beam search; rayon's prelude is imported at file
        // scope, so no local `use` is needed here.
        let results: Vec<(Vec<usize>, Vec<T>)> = (0..n_queries)
            .into_par_iter()
            .map(|i| {
                let row = query_mat.row(i);
                index.query_row(row, k, ef_search)
            })
            .collect();

        if return_dist {
            let (indices, distances) = results.into_iter().unzip();
            (indices, Some(distances))
        } else {
            let indices = results.into_iter().map(|(idx, _)| idx).collect();
            (indices, None)
        }
    }
}
2223
#[cfg(feature = "gpu")]
/// Extract the internal kNN graph from an NNDescent GPU index.
///
/// No search is performed -- this simply reshapes the graph that
/// was already built during construction.
///
/// ### Params
///
/// * `index` - Reference to built index
/// * `return_dist` - Return distances
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn extract_nndescent_knn_gpu<T, R>(
    index: &NNDescentGpu<T, R>,
    return_dist: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
    NNDescentGpu<T, R>: NNDescentQuery<T>,
{
    // Pure accessor: reshapes the already-built graph, no search involved.
    index.extract_knn(return_dist)
}
2249
#[cfg(feature = "gpu")]
/// Self-query the NNDescent GPU index via GPU beam search.
///
/// Searches the CAGRA navigational graph for every vector in the index,
/// producing a full kNN graph. Results differ from `extract_nndescent_knn_gpu`
/// which returns the raw NNDescent graph without beam search refinement.
///
/// ### Params
///
/// * `index` - Mutable reference to built index
/// * `k` - Number of neighbours
/// * `query_params` - Optional GPU beam search parameters
/// * `return_dist` - Return distances
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_nndescent_index_gpu_self<T, R>(
    index: &mut NNDescentGpu<T, R>,
    k: usize,
    query_params: Option<CagraGpuSearchParams>,
    return_dist: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    R: Runtime,
    T: AnnSearchFloat + AnnSearchGpuFloat,
    NNDescentGpu<T, R>: NNDescentQuery<T>,
{
    let (indices, distances) = index.self_query_gpu(k, query_params, 42);
    // Only surface distances when requested.
    (indices, return_dist.then_some(distances))
}
2286
2287////////////
2288// Binary //
2289////////////
2290
2291///////////////////////
2292// Exhaustive Binary //
2293///////////////////////
2294
#[cfg(feature = "binary")]
/// Build an exhaustive binary index
///
/// This one can be only used for Cosine distance. There is no good hash
/// function that translates Euclidean distance to Hamming distance!
///
/// ### Params
///
/// * `mat` - The initial matrix with samples x features
/// * `n_bits` - Number of bits per binary code (must be multiple of 8)
/// * `seed` - Random seed for binariser
/// * `binary_init` - Initialisation method ("itq" or "random")
/// * `metric` - Distance metric for reranking (when save_store is true)
/// * `save_store` - Whether to save vector store for reranking
/// * `save_path` - Path to save vector store files (required if save_store is
///   true)
///
/// ### Returns
///
/// The initialised `ExhaustiveIndexBinary`
///
/// ### Errors
///
/// Returns `io::ErrorKind::InvalidInput` when `save_store` is true but no
/// `save_path` was supplied, and propagates any I/O error from writing the
/// vector store.
pub fn build_exhaustive_index_binary<T>(
    mat: MatRef<T>,
    n_bits: usize,
    seed: usize,
    binary_init: &str,
    metric: &str,
    save_store: bool,
    save_path: Option<impl AsRef<Path>>,
) -> std::io::Result<ExhaustiveIndexBinary<T>>
where
    T: AnnSearchFloat + Pod,
{
    let metric = parse_ann_dist(metric).unwrap_or_default();

    if save_store {
        // Report the missing path through `io::Result` instead of panicking:
        // the function already surfaces failures via its return type.
        let path = save_path.ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "save_path required when save_store is true",
            )
        })?;
        ExhaustiveIndexBinary::new_with_vector_store(mat, binary_init, n_bits, metric, seed, path)
    } else {
        Ok(ExhaustiveIndexBinary::new(mat, binary_init, n_bits, seed))
    }
}
2336
#[cfg(feature = "binary")]
/// Helper function to query a given exhaustive binary index
///
/// ### Params
///
/// * `query_mat` - The query matrix containing the samples × features
/// * `index` - The exhaustive binary index
/// * `k` - Number of neighbours to return
/// * `rerank` - Whether to use exact distance reranking (requires vector store)
/// * `rerank_factor` - Multiplier for candidate set size (only used if rerank
///   is true)
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)` where distances are Hamming (u32 converted to T) or exact distances (T)
pub fn query_exhaustive_index_binary<T>(
    query_mat: MatRef<T>,
    index: &ExhaustiveIndexBinary<T>,
    k: usize,
    rerank: bool,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Exact-distance reranking path (needs the vector store).
    if rerank {
        return query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
            index.query_row_reranking(query_mat.row(i), k, rerank_factor)
        });
    }

    if index.use_asymmetric() {
        // Asymmetric queries are sensible: distances come back as T directly.
        query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
            index.query_row_asymmetric(query_mat.row(i), k, rerank_factor)
        })
    } else {
        // Symmetric Hamming search yields u32 distances; widen them to T.
        let (indices, hamming) = query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
            index.query_row(query_mat.row(i), k)
        });
        let distances = hamming.map(|rows| {
            rows.into_iter()
                .map(|row| row.into_iter().map(|d| T::from_u32(d).unwrap()).collect())
                .collect()
        });
        (indices, distances)
    }
}
2395
#[cfg(feature = "binary")]
/// Query an exhaustive binary index against itself
///
/// Generates a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `rerank_factor` - Multiplier for candidate set (only used if vector store
///   available)
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_exhaustive_index_binary_self<T>(
    index: &ExhaustiveIndexBinary<T>,
    k: usize,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Thin wrapper: the index builds its own full kNN graph.
    index.generate_knn(k, rerank_factor, return_dist, verbose)
}
2425
2426////////////////
2427// IVF Binary //
2428////////////////
2429
#[cfg(feature = "binary")]
/// Build an IVF index with binary quantisation
///
/// ### Params
///
/// * `mat` - Data matrix [samples, features]
/// * `binarisation_init` - "itq" or "random"
/// * `n_bits` - Number of bits per code (multiple of 8)
/// * `nlist` - Number of clusters (defaults to √n)
/// * `max_iters` - K-means iterations (defaults to 30)
/// * `dist_metric` - "euclidean" or "cosine"
/// * `seed` - Random seed
/// * `save_store` - Whether to save vector store for reranking
/// * `save_path` - Path to save vector store files (required if save_store
///   is true)
/// * `verbose` - Print progress
///
/// ### Returns
///
/// Built IVF binary index
///
/// ### Errors
///
/// Returns `io::ErrorKind::InvalidInput` when `save_store` is true but no
/// `save_path` was supplied, and propagates any I/O error from writing the
/// vector store.
#[allow(clippy::too_many_arguments)]
pub fn build_ivf_index_binary<T>(
    mat: MatRef<T>,
    binarisation_init: &str,
    n_bits: usize,
    nlist: Option<usize>,
    max_iters: Option<usize>,
    dist_metric: &str,
    seed: usize,
    save_store: bool,
    save_path: Option<impl AsRef<Path>>,
    verbose: bool,
) -> std::io::Result<IvfIndexBinary<T>>
where
    T: AnnSearchFloat + Pod,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();

    if save_store {
        // Report the missing path through `io::Result` instead of panicking.
        let path = save_path.ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "save_path required when save_store is true",
            )
        })?;
        IvfIndexBinary::build_with_vector_store(
            mat,
            binarisation_init,
            n_bits,
            ann_dist,
            nlist,
            max_iters,
            seed,
            verbose,
            path,
        )
    } else {
        Ok(IvfIndexBinary::build(
            mat,
            binarisation_init,
            n_bits,
            ann_dist,
            nlist,
            max_iters,
            seed,
            verbose,
        ))
    }
}
2494
#[cfg(feature = "binary")]
/// Query an IVF binary index
///
/// ### Params
///
/// * `query_mat` - Query matrix [samples, features]
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `nprobe` - Clusters to search (defaults to √nlist)
/// * `rerank` - Whether to use exact distance reranking (requires vector store)
/// * `rerank_factor` - Multiplier for candidate set size (only used if rerank
///   is true)
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
#[allow(clippy::too_many_arguments)]
pub fn query_ivf_index_binary<T>(
    query_mat: MatRef<T>,
    index: &IvfIndexBinary<T>,
    k: usize,
    nprobe: Option<usize>,
    rerank: bool,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Exact-distance reranking path (needs the vector store).
    if rerank {
        return query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
            index.query_row_reranking(query_mat.row(i), k, nprobe, rerank_factor)
        });
    }

    if index.use_asymmetric() {
        // Asymmetric queries are sensible: distances come back as T directly.
        query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
            index.query_row_asymmetric(query_mat.row(i), k, nprobe, rerank_factor)
        })
    } else {
        // Symmetric Hamming search yields u32 distances; widen them to T.
        let (indices, hamming) = query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
            index.query_row(query_mat.row(i), k, nprobe)
        });
        let distances = hamming.map(|rows| {
            rows.into_iter()
                .map(|row| row.into_iter().map(|d| T::from_u32(d).unwrap()).collect())
                .collect()
        });
        (indices, distances)
    }
}
2552
#[cfg(feature = "binary")]
/// Query an IVF binary index against itself
///
/// Generates a full kNN graph based on the internal data.
///
/// ### Params
///
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `nprobe` - Clusters to search (defaults to √nlist)
/// * `rerank_factor` - Multiplier for candidate set (only used if vector store available)
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_ivf_index_binary_self<T>(
    index: &IvfIndexBinary<T>,
    k: usize,
    nprobe: Option<usize>,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Thin wrapper: the index builds its own full kNN graph.
    index.generate_knn(k, nprobe, rerank_factor, return_dist, verbose)
}
2583
2584///////////////////////
2585// Exhaustive RaBitQ //
2586///////////////////////
2587
#[cfg(feature = "binary")]
/// Build an exhaustive RaBitQ index
///
/// ### Params
///
/// * `mat` - The initial matrix with samples x features
/// * `n_clust_rabitq` - Number of clusters (None for automatic)
/// * `dist_metric` - "euclidean" or "cosine"
/// * `seed` - Random seed
/// * `save_store` - Whether to save vector store for reranking
/// * `save_path` - Path to save vector store files (required if save_store is
///   true)
///
/// ### Returns
///
/// The initialised `ExhaustiveIndexRaBitQ`
///
/// ### Errors
///
/// Returns `io::ErrorKind::InvalidInput` when `save_store` is true but no
/// `save_path` was supplied, and propagates any I/O error from writing the
/// vector store.
pub fn build_exhaustive_index_rabitq<T>(
    mat: MatRef<T>,
    n_clust_rabitq: Option<usize>,
    dist_metric: &str,
    seed: usize,
    save_store: bool,
    save_path: Option<impl AsRef<Path>>,
) -> std::io::Result<ExhaustiveIndexRaBitQ<T>>
where
    T: AnnSearchFloat + Pod,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    if save_store {
        // Report the missing path through `io::Result` instead of panicking.
        let path = save_path.ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "save_path required when save_store is true",
            )
        })?;
        ExhaustiveIndexRaBitQ::new_with_vector_store(mat, &ann_dist, n_clust_rabitq, seed, path)
    } else {
        Ok(ExhaustiveIndexRaBitQ::new(
            mat,
            &ann_dist,
            n_clust_rabitq,
            seed,
        ))
    }
}
2628
#[cfg(feature = "binary")]
/// Helper function to query a given exhaustive RaBitQ index
///
/// ### Params
///
/// * `query_mat` - The query matrix containing the samples × features
/// * `index` - The exhaustive RaBitQ index
/// * `k` - Number of neighbours to return
/// * `n_probe` - Number of clusters to search (None for default 20%)
/// * `rerank` - Whether to use exact distance reranking (requires vector store)
/// * `rerank_factor` - Multiplier for candidate set size (only used if rerank is true)
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
#[allow(clippy::too_many_arguments)]
pub fn query_exhaustive_index_rabitq<T>(
    query_mat: MatRef<T>,
    index: &ExhaustiveIndexRaBitQ<T>,
    k: usize,
    n_probe: Option<usize>,
    rerank: bool,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Both paths share the same parallel driver; only the per-row call
    // differs (reranked vs plain quantised search).
    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
        let row = query_mat.row(i);
        if rerank {
            index.query_row_reranking(row, k, n_probe, rerank_factor)
        } else {
            index.query_row(row, k, n_probe)
        }
    })
}
2670
#[cfg(feature = "binary")]
/// Query an exhaustive RaBitQ index against itself
///
/// Generates a full kNN graph based on the internal data.
/// Requires vector store to be available (use save_store=true when building).
///
/// ### Params
///
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `n_probe` - Number of clusters to search (None for default 20%)
/// * `rerank_factor` - Multiplier for candidate set size
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_exhaustive_index_rabitq_self<T>(
    index: &ExhaustiveIndexRaBitQ<T>,
    k: usize,
    n_probe: Option<usize>,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Thin wrapper: the index builds its own full kNN graph.
    index.generate_knn(k, n_probe, rerank_factor, return_dist, verbose)
}
2702
2703/////////////////
2704// IVF-RaBitQ  //
2705/////////////////
2706
#[cfg(feature = "binary")]
/// Build an IVF-RaBitQ index
///
/// ### Params
///
/// * `mat` - The initial matrix with samples x features
/// * `nlist` - Number of IVF cells (None for sqrt(n))
/// * `max_iters` - K-means iterations (None for 30)
/// * `dist_metric` - "euclidean" or "cosine"
/// * `seed` - Random seed
/// * `save_store` - Whether to save vector store for reranking
/// * `save_path` - Path to save vector store files (required if save_store is
///   true)
/// * `verbose` - Print progress during build
///
/// ### Returns
///
/// The initialised `IvfIndexRaBitQ`
///
/// ### Errors
///
/// Returns `io::ErrorKind::InvalidInput` when `save_store` is true but no
/// `save_path` was supplied, and propagates any I/O error from writing the
/// vector store.
#[allow(clippy::too_many_arguments)]
pub fn build_ivf_index_rabitq<T>(
    mat: MatRef<T>,
    nlist: Option<usize>,
    max_iters: Option<usize>,
    dist_metric: &str,
    seed: usize,
    save_store: bool,
    save_path: Option<impl AsRef<Path>>,
    verbose: bool,
) -> std::io::Result<IvfIndexRaBitQ<T>>
where
    T: AnnSearchFloat + Pod,
{
    let ann_dist = parse_ann_dist(dist_metric).unwrap_or_default();
    if save_store {
        // Report the missing path through `io::Result` instead of panicking.
        let path = save_path.ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "save_path required when save_store is true",
            )
        })?;
        IvfIndexRaBitQ::build_with_vector_store(
            mat, ann_dist, nlist, max_iters, seed, verbose, path,
        )
    } else {
        Ok(IvfIndexRaBitQ::build(
            mat, ann_dist, nlist, max_iters, seed, verbose,
        ))
    }
}
2751
#[cfg(feature = "binary")]
/// Helper function to query a given IVF-RaBitQ index
///
/// ### Params
///
/// * `query_mat` - The query matrix containing the samples × features
/// * `index` - The IVF-RaBitQ index
/// * `k` - Number of neighbours to return
/// * `nprobe` - Number of IVF cells to probe (None for sqrt(nlist))
/// * `rerank` - Whether to use exact distance reranking (requires vector store)
/// * `rerank_factor` - Multiplier for candidate set size (only used if rerank is true)
/// * `return_dist` - Shall the distances be returned
/// * `verbose` - Controls verbosity of the function
///
/// ### Returns
///
/// A tuple of `(knn_indices, optional distances)`
#[allow(clippy::too_many_arguments)]
pub fn query_ivf_index_rabitq<T>(
    query_mat: MatRef<T>,
    index: &IvfIndexRaBitQ<T>,
    k: usize,
    nprobe: Option<usize>,
    rerank: bool,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Both paths share the same parallel driver; only the per-row call
    // differs (reranked vs plain quantised search).
    query_parallel(query_mat.nrows(), return_dist, verbose, |i| {
        let row = query_mat.row(i);
        if rerank {
            index.query_row_reranking(row, k, nprobe, rerank_factor)
        } else {
            index.query_row(row, k, nprobe)
        }
    })
}
2793
#[cfg(feature = "binary")]
/// Query an IVF-RaBitQ index against itself
///
/// Generates a full kNN graph based on the internal data.
/// Requires vector store to be available (use save_store=true when building).
///
/// ### Params
///
/// * `index` - Reference to built index
/// * `k` - Number of neighbours
/// * `nprobe` - Number of IVF cells to probe (None for sqrt(nlist))
/// * `rerank_factor` - Multiplier for candidate set size
/// * `return_dist` - Return distances
/// * `verbose` - Controls verbosity
///
/// ### Returns
///
/// Tuple of (indices, optional distances)
pub fn query_ivf_index_rabitq_self<T>(
    index: &IvfIndexRaBitQ<T>,
    k: usize,
    nprobe: Option<usize>,
    rerank_factor: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>)
where
    T: AnnSearchFloat + Pod,
{
    // Thin wrapper: the index builds its own full kNN graph.
    index.generate_knn(k, nprobe, rerank_factor, return_dist, verbose)
}