pub mod dist;
pub mod graph_utils;
pub mod heap_structs;
pub mod k_means_utils;
pub mod parallelism;
pub mod traits;
pub mod tree_utils;
use faer::MatRef;
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use rustc_hash::FxHashSet;
use std::collections::BinaryHeap;
use std::iter::Sum;
use crate::prelude::*;
/// Flattened row-major matrix data: `(values, n_rows, n_cols)`.
/// Element `(i, j)` of the source matrix lives at `values[i * n_cols + j]`.
pub type FlattenData<T> = (Vec<T>, usize, usize);
/// Copy a `faer` matrix into a flat, row-major `Vec`.
///
/// Returns `(data, n_rows, n_cols)`; rows are appended in order, so the
/// element at `(i, j)` ends up at `data[i * n_cols + j]`.
pub fn matrix_to_flat<T>(mat: MatRef<T>) -> FlattenData<T>
where
    T: Float,
{
    let (rows, cols) = (mat.nrows(), mat.ncols());
    // Preallocate the exact size up front to avoid regrowth while copying.
    let mut flat = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        // `Float` implies `Copy`, so a plain copy of each row element suffices.
        flat.extend(mat.row(r).iter().copied());
    }
    (flat, rows, cols)
}
/// Validation helpers for approximate nearest-neighbour indices.
///
/// Supplies a brute-force `exhaustive_query` baseline and a sampled
/// `validate_index` recall estimate of the approximate index against it.
pub trait KnnValidation<T>: VectorDistance<T>
where
    T: Float + FromPrimitive + ToPrimitive + Send + Sync + Sum + SimdDistance,
{
    /// Approximate k-NN query under test; returns `(original_ids, distances)`.
    fn query_for_validation(&self, query_vec: &[T], k: usize) -> (Vec<usize>, Vec<T>);
    /// Number of vectors stored in the index.
    fn n(&self) -> usize;
    /// Distance metric the index was built with.
    fn metric(&self) -> Dist;
    /// Mapping from internal row index to the caller's original vector id.
    fn original_ids(&self) -> &[usize];

    /// Exact k-NN by scanning every stored vector.
    ///
    /// Returns `(indices, distances)` sorted by ascending distance, with
    /// `indices` mapped through [`Self::original_ids`]. At most
    /// `min(k, n())` results are returned; empty when `k == 0`.
    fn exhaustive_query(&self, query_vec: &[T], k: usize) -> (Vec<usize>, Vec<T>) {
        let n_vectors = self.n();
        let k = k.min(n_vectors);
        // Guard: with k == 0 the eviction branch below would peek an empty
        // heap and panic on the unwrap.
        if k == 0 {
            return (Vec::new(), Vec::new());
        }
        // Max-heap of the best k candidates seen so far; the root is the
        // current worst, so any closer candidate evicts it.
        let mut heap: BinaryHeap<(OrderedFloat<T>, usize)> = BinaryHeap::with_capacity(k + 1);
        match self.metric() {
            Dist::Euclidean => {
                for idx in 0..n_vectors {
                    let dist = self.euclidean_distance_to_query(idx, query_vec);
                    if heap.len() < k {
                        heap.push((OrderedFloat(dist), idx));
                    } else if dist < heap.peek().unwrap().0 .0 {
                        // k >= 1 and heap.len() == k here, so peek() is Some.
                        heap.pop();
                        heap.push((OrderedFloat(dist), idx));
                    }
                }
            }
            Dist::Cosine => {
                // Compute the query norm once; stored-vector norms are the
                // responsibility of `cosine_distance_to_query`.
                let query_norm = query_vec
                    .iter()
                    .map(|v| *v * *v)
                    .fold(T::zero(), |a, b| a + b)
                    .sqrt();
                for idx in 0..n_vectors {
                    let dist = self.cosine_distance_to_query(idx, query_vec, query_norm);
                    if heap.len() < k {
                        heap.push((OrderedFloat(dist), idx));
                    } else if dist < heap.peek().unwrap().0 .0 {
                        heap.pop();
                        heap.push((OrderedFloat(dist), idx));
                    }
                }
            }
        }
        // Heap iteration order is arbitrary; sort ascending by distance.
        let mut results: Vec<_> = heap.into_iter().collect();
        results.sort_unstable_by_key(|&(dist, _)| dist);
        let (distances, indices): (Vec<_>, Vec<_>) = results
            .into_iter()
            .map(|(OrderedFloat(dist), idx)| (dist, self.original_ids()[idx]))
            .unzip();
        (indices, distances)
    }

    /// Estimate recall@k of the approximate index.
    ///
    /// Samples `no_samples` stored vectors (default 1000, clamped to `n()`,
    /// drawn with replacement from a seeded RNG) as queries, and compares
    /// `query_for_validation` against `exhaustive_query`. Returns the mean
    /// recall in `[0, 1]`, or 0.0 when there is nothing to sample or `k == 0`.
    fn validate_index(&self, k: usize, seed: usize, no_samples: Option<usize>) -> f64 {
        let no_samples = no_samples.unwrap_or(1000).min(self.n());
        // Guard against 0/0 -> NaN when the index is empty or k == 0.
        if no_samples == 0 || k == 0 {
            return 0.0;
        }
        let mut rng = StdRng::seed_from_u64(seed as u64);
        let query_indices: Vec<usize> = (0..no_samples)
            .map(|_| rng.random_range(0..self.n()))
            .collect();
        let mut total_recall = 0.0;
        for &query_idx in &query_indices {
            let start = query_idx * self.dim();
            let query_vec = &self.vectors_flat()[start..start + self.dim()];
            let (approx_indices, _) = self.query_for_validation(query_vec, k);
            let (true_indices, _) = self.exhaustive_query(query_vec, k);
            let approx_set: FxHashSet<_> = approx_indices.into_iter().collect();
            let matches = true_indices
                .iter()
                .filter(|idx| approx_set.contains(idx))
                .count();
            // Normalize by the achievable ground-truth size, not the raw k,
            // so recall is not unfairly penalized when k > n(). Non-empty
            // here because k >= 1 and n() >= no_samples >= 1.
            total_recall += matches as f64 / true_indices.len() as f64;
        }
        total_recall / no_samples as f64
    }
}
/// Hint the CPU to prefetch the cache line at `ptr` for reading.
///
/// Best-effort: emits a prefetch hint on x86_64 and aarch64 and is a
/// no-op on every other architecture. Prefetch instructions are pure
/// hints — they do not write memory and do not fault — so this is safe
/// to call with any pointer value.
#[inline(always)]
pub fn prefetch_read<T>(ptr: *const T) {
    #[cfg(target_arch = "x86_64")]
    // SAFETY: `_mm_prefetch` is only a hint; it cannot fault even on an
    // invalid address and has no observable side effects.
    unsafe {
        core::arch::x86_64::_mm_prefetch(ptr as *const i8, core::arch::x86_64::_MM_HINT_T0);
    }
    #[cfg(target_arch = "aarch64")]
    // SAFETY: `prfm pldl1keep` is a hint instruction: it performs no
    // write, raises no fault for bad addresses, and preserves flags, as
    // declared in the asm options.
    unsafe {
        core::arch::asm!(
            "prfm pldl1keep, [{}]",
            in(reg) ptr,
            options(readonly, preserves_flags, nostack)
        );
    }
    // Explicitly consume `ptr` on other targets to avoid an
    // unused-variable warning where neither cfg branch compiles in.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let _ = ptr;
}