use crate::distance::cosine_distance_normalized;
use crate::distance::FloatOrd;
use crate::RetrieveError;
use smallvec::SmallVec;
use std::collections::BinaryHeap;
/// Tuning knobs for building and querying a [`FingerIndex`].
#[derive(Clone, Debug)]
pub struct FingerParams {
    /// Cap on the number of out-neighbors a node keeps after pruning.
    pub max_degree: usize,
    /// Beam width used by the graph search that collects neighbor candidates during `build`.
    pub ef_construction: usize,
    /// Default beam width for `search`; the effective width is never smaller than `k`.
    pub ef_search: usize,
    /// Pruning factor used by `rng_prune`; `FingerIndex::new` rejects values below 1.0.
    pub alpha: f32,
}

impl Default for FingerParams {
    /// Defaults: degree 32, construction beam 200, search beam 100, alpha 1.2.
    fn default() -> Self {
        let (max_degree, ef_construction, ef_search) = (32, 200, 100);
        let alpha = 1.2;
        FingerParams {
            max_degree,
            ef_construction,
            ef_search,
            alpha,
        }
    }
}
/// Graph-based approximate nearest-neighbor index.
///
/// Vectors are appended with `add`/`add_slice`, the graph is constructed once by
/// `build`, and only after that can `search`/`search_with_ef` be called.
pub struct FingerIndex {
/// Required length of every stored vector and every query.
dimension: usize,
/// Construction and search parameters supplied to `new`.
params: FingerParams,
/// Set to `true` at the end of `build`; `add_slice` refuses new vectors afterwards.
built: bool,
/// Flat storage of all vectors; vector `i` occupies
/// `[i * dimension, (i + 1) * dimension)`. Each vector is scaled to unit length
/// on insertion when its norm exceeds 1e-10, otherwise stored as given.
vectors: Vec<f32>,
/// Number of vectors added so far.
num_vectors: usize,
/// Caller-supplied document id for each internal position.
doc_ids: Vec<u32>,
/// Adjacency list (internal positions) for each stored vector.
neighbors: Vec<SmallVec<[u32; 16]>>,
/// Internal position used as the entry point of every graph search.
medoid: u32,
/// Direction computed by `compute_medoid_and_direction` (normalized centroid, or
/// vector 0 when the centroid is near zero); stored vectors and queries are
/// projected onto it.
direction: Vec<f32>,
/// Dot product of each stored vector with `direction`, indexed by internal position.
projections: Vec<f32>,
}
impl FingerIndex {
/// Creates an empty, unbuilt index for vectors of length `dimension`.
///
/// # Errors
///
/// Returns [`RetrieveError::InvalidParameter`] when `dimension` is zero or when
/// `params.alpha` is NaN or below 1.0.
pub fn new(dimension: usize, params: FingerParams) -> Result<Self, RetrieveError> {
    if dimension == 0 {
        return Err(RetrieveError::InvalidParameter(
            "dimension must be > 0".into(),
        ));
    }
    // `NaN < 1.0` is false, so NaN needs its own check; left in place it would make
    // every comparison against `alpha` in `rng_prune` evaluate to false.
    if params.alpha.is_nan() || params.alpha < 1.0 {
        return Err(RetrieveError::InvalidParameter(
            "alpha must be >= 1.0".into(),
        ));
    }
    Ok(Self {
        dimension,
        params,
        built: false,
        vectors: Vec::new(),
        num_vectors: 0,
        doc_ids: Vec::new(),
        neighbors: Vec::new(),
        medoid: 0,
        direction: Vec::new(),
        projections: Vec::new(),
    })
}
/// Adds an owned vector under `doc_id`.
///
/// Thin wrapper around `add_slice`, which performs all validation and storage.
pub fn add(&mut self, doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
    let values: &[f32] = vector.as_slice();
    self.add_slice(doc_id, values)
}
/// Appends `vector` to the index under `doc_id`.
///
/// The vector is stored scaled to unit length when its Euclidean norm exceeds
/// 1e-10; otherwise it is stored unchanged.
///
/// # Errors
///
/// * [`RetrieveError::InvalidParameter`] if `build` has already completed.
/// * [`RetrieveError::DimensionMismatch`] if `vector.len()` differs from the index dimension.
pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
    if self.built {
        return Err(RetrieveError::InvalidParameter(
            "cannot add after build".into(),
        ));
    }

    let given = vector.len();
    if given != self.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: given,
            doc_dim: self.dimension,
        });
    }

    let length = vector.iter().map(|&c| c * c).sum::<f32>().sqrt();
    // Written as `length > 1e-10` so that a NaN norm takes the store-as-is branch.
    if length > 1e-10 {
        let scaled = vector.iter().map(|&c| c / length);
        self.vectors.extend(scaled);
    } else {
        self.vectors.extend_from_slice(vector);
    }

    self.doc_ids.push(doc_id);
    self.num_vectors += 1;
    Ok(())
}
/// Builds the search graph over all vectors added so far.
///
/// Steps: pick the medoid and projection direction, precompute every vector's
/// projection, build an initial k-NN graph, refine each node's neighbor list with
/// a beam search followed by `rng_prune` (inserting reverse edges as it goes), and
/// finally run `ensure_connectivity`. Calling `build` on an already built index is
/// a no-op that returns `Ok(())`.
///
/// # Errors
///
/// Returns [`RetrieveError::EmptyIndex`] if no vectors have been added.
pub fn build(&mut self) -> Result<(), RetrieveError> {
    if self.built {
        return Ok(());
    }
    if self.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }
    let n = self.num_vectors;

    let (medoid, direction) = self.compute_medoid_and_direction();
    self.medoid = medoid;
    self.direction = direction;
    self.projections = self
        .vectors
        .chunks_exact(self.dimension)
        .map(|v| {
            v.iter()
                .zip(self.direction.iter())
                .map(|(a, b)| a * b)
                .sum::<f32>()
        })
        .collect();

    self.build_knn_graph();

    let max_deg = self.params.max_degree;
    for i in 0..n {
        let node = i as u32;

        let vi = self.get_vector(i);
        let mut candidates = self.beam_search_internal(vi, self.params.ef_construction);
        // The search normally finds `i` itself at distance ~0. Left in, it would be
        // selected first and become a self-loop, and because it is exactly as far from
        // every other candidate as the query is, it can occlude all of them.
        candidates.retain(|&(id, _)| id != node);
        let selected = self.rng_prune(vi, &candidates);

        self.neighbors[i] = selected.iter().map(|&(id, _)| id).collect();

        // Reverse edges: make each chosen neighbor point back at `i`.
        for &(neighbor_id, _) in &selected {
            let nid = neighbor_id as usize;
            if self.neighbors[nid].contains(&node) {
                continue;
            }
            if self.neighbors[nid].len() < max_deg {
                self.neighbors[nid].push(node);
                continue;
            }
            // Neighbor list is full: re-prune its current neighbors plus `i`,
            // never offering the neighbor itself as a candidate.
            let nv = self.get_vector(nid);
            let rev_cands: Vec<(u32, f32)> = self.neighbors[nid]
                .iter()
                .copied()
                .chain(std::iter::once(node))
                .filter(|&id| id != neighbor_id)
                .map(|id| {
                    (
                        id,
                        cosine_distance_normalized(nv, self.get_vector(id as usize)),
                    )
                })
                .collect();
            let pruned = self.rng_prune(nv, &rev_cands);
            self.neighbors[nid] = pruned.iter().map(|&(id, _)| id).collect();
        }
    }

    self.ensure_connectivity();
    self.built = true;
    Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.dimension,
});
}
let norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
let q_norm: Vec<f32> = if norm > 1e-10 {
query.iter().map(|x| x / norm).collect()
} else {
query.to_vec()
};
let query_proj: f32 = q_norm
.iter()
.zip(self.direction.iter())
.map(|(a, b)| a * b)
.sum();
let ef = self.params.ef_search.max(k);
let results = self.beam_search_with_pruning(&q_norm, query_proj, ef);
Ok(results
.into_iter()
.take(k)
.map(|(id, dist)| (self.doc_ids[id as usize], dist))
.collect())
}
/// Returns up to `k` stored vectors closest to `query` as `(doc_id, distance)`
/// pairs ordered by ascending distance, exploring the graph with a beam of
/// `max(ef_search, k, 1)` candidates.
///
/// The query is scaled to unit length when its norm exceeds 1e-10; otherwise it is
/// used as given.
///
/// # Errors
///
/// * [`RetrieveError::InvalidParameter`] if the index has not been built.
/// * [`RetrieveError::DimensionMismatch`] if `query.len()` differs from the index dimension.
pub fn search_with_ef(
    &self,
    query: &[f32],
    k: usize,
    ef_search: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
    if !self.built {
        return Err(RetrieveError::InvalidParameter(
            "index must be built before search".into(),
        ));
    }
    if query.len() != self.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: query.len(),
            doc_dim: self.dimension,
        });
    }

    let norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
    let q_norm: Vec<f32> = if norm > 1e-10 {
        query.iter().map(|x| x / norm).collect()
    } else {
        query.to_vec()
    };
    let query_proj: f32 = q_norm
        .iter()
        .zip(self.direction.iter())
        .map(|(a, b)| a * b)
        .sum();

    // Never hand the beam search a zero-width beam: with `k == 0` and
    // `ef_search == 0` it would index `candidates[ef - 1]` out of bounds.
    let ef = ef_search.max(k).max(1);
    let results = self.beam_search_with_pruning(&q_norm, query_proj, ef);

    Ok(results
        .into_iter()
        .take(k)
        .map(|(id, dist)| (self.doc_ids[id as usize], dist))
        .collect())
}
/// Number of vectors that have been added to the index.
pub fn len(&self) -> usize {
    self.num_vectors
}
/// `true` when no vectors have been added yet.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Borrows the stored vector at internal position `idx`.
///
/// Panics (slice indexing) if `idx` addresses storage past the end of `vectors`.
#[inline]
fn get_vector(&self, idx: usize) -> &[f32] {
    let dim = self.dimension;
    let begin = idx * dim;
    let end = begin + dim;
    &self.vectors[begin..end]
}
/// Computes the search entry point and the projection direction.
///
/// The direction is the mean of all stored vectors scaled to unit length; when the
/// mean's norm is at most 1e-10 (or NaN) it falls back to a copy of vector 0. The
/// medoid is the internal position of the stored vector with the smallest
/// `cosine_distance_normalized` to that direction (first one wins on ties).
///
/// Returns `(medoid, direction)`.
fn compute_medoid_and_direction(&self) -> (u32, Vec<f32>) {
    let dim = self.dimension;
    let n = self.num_vectors;

    let mut centroid = vec![0.0f32; dim];
    for i in 0..n {
        for (acc, &val) in centroid.iter_mut().zip(self.get_vector(i)) {
            *acc += val;
        }
    }
    for c in &mut centroid {
        *c /= n as f32;
    }

    let c_norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
    let direction: Vec<f32> = if c_norm > 1e-10 {
        centroid.iter().map(|x| x / c_norm).collect()
    } else {
        self.get_vector(0).to_vec()
    };

    // Rank against the unit-length direction rather than the raw centroid: the
    // distance helper is the `_normalized` variant, and scaling one argument by a
    // positive constant does not change which vector has the largest dot product.
    let mut best = 0u32;
    let mut best_d = f32::INFINITY;
    for i in 0..n {
        let d = cosine_distance_normalized(&direction, self.get_vector(i));
        if d < best_d {
            best_d = d;
            best = i as u32;
        }
    }
    (best, direction)
}
/// Builds the initial k-NN graph, choosing the strategy by index size:
/// exhaustive comparison for up to 1000 vectors, NN-descent above that.
fn build_knn_graph(&mut self) {
    match self.num_vectors {
        0..=1000 => self.build_knn_graph_bruteforce(),
        _ => self.build_knn_graph_nndescent(),
    }
}
/// Fills `neighbors` with each vector's closest other vectors, found by comparing
/// every pair. Each list holds at most `min(max_degree, n - 1)` entries, ordered by
/// ascending distance.
fn build_knn_graph_bruteforce(&mut self) {
    let total = self.num_vectors;
    let keep = total.saturating_sub(1).min(self.params.max_degree);
    self.neighbors = vec![SmallVec::new(); total];

    for node in 0..total {
        let anchor = self.get_vector(node);
        let mut scored: Vec<(u32, f32)> = Vec::with_capacity(total.saturating_sub(1));
        for other in (0..total).filter(|&o| o != node) {
            let d = cosine_distance_normalized(anchor, self.get_vector(other));
            scored.push((other as u32, d));
        }
        scored.sort_unstable_by(|l, r| l.1.total_cmp(&r.1));
        scored.truncate(keep);
        self.neighbors[node] = scored.into_iter().map(|(id, _)| id).collect();
    }
}
/// Fills `neighbors` using `crate::graph_utils::build_knn_graph_nndescent`, giving it
/// the vector count, `max_degree` as `k`, and a distance callback over stored vectors.
fn build_knn_graph_nndescent(&mut self) {
    let n = self.num_vectors;
    let k = self.params.max_degree;
    let dim = self.dimension;
    let flat = &self.vectors;
    self.neighbors = crate::graph_utils::build_knn_graph_nndescent(n, k, |a, b| {
        let lhs = &flat[a * dim..(a + 1) * dim];
        let rhs = &flat[b * dim..(b + 1) * dim];
        cosine_distance_normalized(lhs, rhs)
    });
}
/// Selects at most `max_degree` diverse neighbors from `candidates`.
///
/// `candidates` holds `(internal id, distance to the node being pruned)` pairs.
/// They are visited in ascending distance; a candidate is dropped when some already
/// selected neighbor `s` occludes it, i.e. `alpha * d(s, candidate) <= d(node, candidate)`.
/// With `alpha == 1.0` this is the plain relative-neighborhood rule; larger `alpha`
/// makes occlusion harder and therefore keeps more (longer) edges.
///
/// `_query_vec` is unused; the candidate distances already encode the query.
///
/// Returns the kept `(id, distance)` pairs in ascending distance order.
fn rng_prune(&self, _query_vec: &[f32], candidates: &[(u32, f32)]) -> Vec<(u32, f32)> {
    let mut sorted = candidates.to_vec();
    // Break distance ties by id so that repeated entries for the same id end up
    // adjacent (and `dedup_by_key` can remove them) and the order is deterministic.
    sorted.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
    sorted.dedup_by_key(|c| c.0);

    let max_deg = self.params.max_degree;
    let alpha = self.params.alpha;
    // Cap the reservation by the candidate count: `max_degree` may be set very large
    // to mean "unbounded", which must not translate into a huge allocation.
    let mut selected: Vec<(u32, f32)> = Vec::with_capacity(max_deg.min(sorted.len()));

    for &(cand_id, cand_dist) in &sorted {
        if selected.len() >= max_deg {
            break;
        }
        let cand_vec = self.get_vector(cand_id as usize);
        let occluded = selected.iter().any(|&(sel_id, _)| {
            let inter_dist =
                cosine_distance_normalized(self.get_vector(sel_id as usize), cand_vec);
            alpha * inter_dist <= cand_dist
        });
        if !occluded {
            selected.push((cand_id, cand_dist));
        }
    }
    selected
}
/// Best-first graph search from the medoid, used while building the index.
///
/// Expands the closest unexpanded node until the frontier is empty, until that node
/// is more than 1.5x farther than the `ef`-th best candidate found so far, or until
/// more than `10 * ef` nodes have been visited. `ef` is treated as at least 1.
///
/// Returns every scored node as `(internal id, distance)` in ascending distance
/// order; the list is not truncated to `ef`.
fn beam_search_internal(&self, query: &[f32], ef: usize) -> Vec<(u32, f32)> {
    let n = self.num_vectors;
    if n == 0 {
        return Vec::new();
    }
    // A zero-width beam would make `candidates[ef - 1]` below underflow.
    let ef = ef.max(1);
    // Saturate so that a huge `ef` simply disables the visit budget.
    let visit_budget = ef.saturating_mul(10);

    // Per-thread visited marks, reused across calls. A slot counts as visited in the
    // current call when it equals the current generation stamp.
    thread_local! {
        static VISITED: std::cell::RefCell<(Vec<u8>, u8)> =
            const { std::cell::RefCell::new((Vec::new(), 1)) };
    }
    VISITED.with(|cell| {
        let mut state = cell.borrow_mut();
        let (marks, stamp) = &mut *state;
        if marks.len() < n {
            marks.resize(n, 0);
        }
        match stamp.checked_add(1) {
            Some(next) => *stamp = next,
            None => {
                // Stamp wrapped: clear all marks and restart the sequence.
                marks.fill(0);
                *stamp = 1;
            }
        }
        let generation = *stamp;
        // Returns true the first time `id` is seen in this call (and for ids beyond
        // the mark table), false for ids already marked.
        let mut visited_insert = |id: u32| -> bool {
            let idx = id as usize;
            if idx < marks.len() && marks[idx] != generation {
                marks[idx] = generation;
                true
            } else {
                idx >= marks.len()
            }
        };

        let mut frontier: BinaryHeap<std::cmp::Reverse<(FloatOrd, u32)>> = BinaryHeap::new();
        let mut candidates: Vec<(u32, f32)> = Vec::new();

        let entry = self.medoid;
        let entry_dist = cosine_distance_normalized(query, self.get_vector(entry as usize));
        visited_insert(entry);
        frontier.push(std::cmp::Reverse((FloatOrd(entry_dist), entry)));
        candidates.push((entry, entry_dist));
        let mut visited_count = 1usize;

        while let Some(std::cmp::Reverse((FloatOrd(cur_dist), cur_id))) = frontier.pop() {
            if candidates.len() >= ef {
                // Sort so that index `ef - 1` holds the ef-th best distance.
                candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                if cur_dist > candidates[ef - 1].1 * 1.5 {
                    break;
                }
            }

            let neighbors = &self.neighbors[cur_id as usize];
            for (i, &nb) in neighbors.iter().enumerate() {
                if i + 1 < neighbors.len() {
                    // Hint the cache about the next neighbor's vector.
                    let next_id = neighbors[i + 1] as usize;
                    let ptr = self.vectors.as_ptr().wrapping_add(next_id * self.dimension);
                    // SAFETY: `prfm` is only a prefetch hint; it performs no architectural
                    // memory access, so it is harmless whatever address `ptr` holds.
                    #[cfg(target_arch = "aarch64")]
                    unsafe {
                        std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) ptr, options(nostack, preserves_flags));
                    }
                    // SAFETY: `_mm_prefetch` is only a prefetch hint and never faults on
                    // its address; SSE is part of the x86_64 baseline.
                    #[cfg(target_arch = "x86_64")]
                    unsafe {
                        std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
                    }
                    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
                    let _ = ptr;
                }
                if visited_insert(nb) {
                    visited_count += 1;
                    let d = cosine_distance_normalized(query, self.get_vector(nb as usize));
                    candidates.push((nb, d));
                    frontier.push(std::cmp::Reverse((FloatOrd(d), nb)));
                }
            }

            if visited_count > visit_budget {
                break;
            }
        }

        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        candidates.dedup_by_key(|c| c.0);
        candidates
    })
}
/// Best-first graph search from the medoid that skips distance computations using
/// the precomputed projections.
///
/// `query_proj` is the dot product of `query` with `self.direction`. A neighbor is
/// skipped without computing its distance when
/// `0.5 * (query_proj - projection)^2` already exceeds the current `ef`-th best
/// distance. Termination matches `beam_search_internal`: empty frontier, closest
/// unexpanded node more than 1.5x farther than the `ef`-th best, or more than
/// `10 * ef` visited nodes. `ef` is treated as at least 1.
///
/// Returns every scored node as `(internal id, distance)` in ascending distance
/// order; the list is not truncated to `ef`.
fn beam_search_with_pruning(
    &self,
    query: &[f32],
    query_proj: f32,
    ef: usize,
) -> Vec<(u32, f32)> {
    let n = self.num_vectors;
    if n == 0 {
        return Vec::new();
    }
    // A zero-width beam would make `candidates[ef - 1]` below underflow.
    let ef = ef.max(1);
    // Saturate so that a huge `ef` simply disables the visit budget.
    let visit_budget = ef.saturating_mul(10);

    // Per-thread visited marks, reused across calls. A slot counts as visited in the
    // current call when it equals the current generation stamp.
    thread_local! {
        static VISITED_P: std::cell::RefCell<(Vec<u8>, u8)> =
            const { std::cell::RefCell::new((Vec::new(), 1)) };
    }
    VISITED_P.with(|cell| {
        let mut state = cell.borrow_mut();
        let (marks, stamp) = &mut *state;
        if marks.len() < n {
            marks.resize(n, 0);
        }
        match stamp.checked_add(1) {
            Some(next) => *stamp = next,
            None => {
                // Stamp wrapped: clear all marks and restart the sequence.
                marks.fill(0);
                *stamp = 1;
            }
        }
        let generation = *stamp;
        // Returns true the first time `id` is seen in this call (and for ids beyond
        // the mark table), false for ids already marked.
        let mut visited_insert = |id: u32| -> bool {
            let idx = id as usize;
            if idx < marks.len() && marks[idx] != generation {
                marks[idx] = generation;
                true
            } else {
                idx >= marks.len()
            }
        };

        let mut frontier: BinaryHeap<std::cmp::Reverse<(FloatOrd, u32)>> = BinaryHeap::new();
        let mut candidates: Vec<(u32, f32)> = Vec::new();
        // ef-th best distance so far; infinite until `ef` candidates exist.
        let mut worst_dist = f32::INFINITY;

        let entry = self.medoid;
        let entry_dist = cosine_distance_normalized(query, self.get_vector(entry as usize));
        visited_insert(entry);
        frontier.push(std::cmp::Reverse((FloatOrd(entry_dist), entry)));
        candidates.push((entry, entry_dist));
        let mut visited_count = 1usize;

        while let Some(std::cmp::Reverse((FloatOrd(cur_dist), cur_id))) = frontier.pop() {
            if candidates.len() >= ef {
                // Sort so that index `ef - 1` holds the ef-th best distance.
                candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                let cutoff = candidates[ef - 1].1;
                worst_dist = cutoff;
                if cur_dist > cutoff * 1.5 {
                    break;
                }
            }

            let neighbors = &self.neighbors[cur_id as usize];
            for (i, &nb) in neighbors.iter().enumerate() {
                if !visited_insert(nb) {
                    continue;
                }
                visited_count += 1;

                // Projection gap along `direction`. NOTE(review): treating lb^2 / 2 as a
                // lower bound on the distance assumes `cosine_distance_normalized`
                // returns 1 - dot for unit vectors and that `direction` is unit length
                // - confirm against the helper.
                let lb = (query_proj - self.projections[nb as usize]).abs();
                if lb * lb * 0.5 > worst_dist {
                    continue;
                }

                if i + 1 < neighbors.len() {
                    // Hint the cache about the next neighbor's vector.
                    let next_id = neighbors[i + 1] as usize;
                    let ptr = self.vectors.as_ptr().wrapping_add(next_id * self.dimension);
                    // SAFETY: `prfm` is only a prefetch hint; it performs no architectural
                    // memory access, so it is harmless whatever address `ptr` holds.
                    #[cfg(target_arch = "aarch64")]
                    unsafe {
                        std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) ptr, options(nostack, preserves_flags));
                    }
                    // SAFETY: `_mm_prefetch` is only a prefetch hint and never faults on
                    // its address; SSE is part of the x86_64 baseline.
                    #[cfg(target_arch = "x86_64")]
                    unsafe {
                        std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
                    }
                    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
                    let _ = ptr;
                }

                let d = cosine_distance_normalized(query, self.get_vector(nb as usize));
                candidates.push((nb, d));
                frontier.push(std::cmp::Reverse((FloatOrd(d), nb)));
                if candidates.len() >= ef {
                    // Keep the pruning threshold current for the remaining neighbors.
                    candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                    worst_dist = candidates[ef - 1].1;
                }
            }

            if visited_count > visit_budget {
                break;
            }
        }

        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        candidates.dedup_by_key(|c| c.0);
        candidates
    })
}
/// Hands the adjacency lists, the medoid and a distance callback over stored vectors
/// to `crate::graph_utils::ensure_connectivity`.
fn ensure_connectivity(&mut self) {
    let dim = self.dimension;
    let flat = &self.vectors;
    let entry = self.medoid;
    crate::graph_utils::ensure_connectivity(&mut self.neighbors, entry, |a, b| {
        let lhs = &flat[a * dim..(a + 1) * dim];
        let rhs = &flat[b * dim..(b + 1) * dim];
        cosine_distance_normalized(lhs, rhs)
    });
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Deterministic test data: `n * dim` values from a 64-bit linear congruential
    /// generator seeded with `seed`. The top 31 bits of each state are scaled by
    /// 2^-31 and shifted down by one, so every value lies in [-1.0, 0.0].
    fn make_vectors(n: usize, dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        let mut values = Vec::with_capacity(n * dim);
        for _ in 0..n * dim {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            values.push(((state >> 33) as f32 / (1u64 << 31) as f32) - 1.0);
        }
        values
    }

    /// Parameters shared by the small test indexes.
    fn small_params() -> FingerParams {
        FingerParams {
            max_degree: 16,
            ef_construction: 64,
            ef_search: 50,
            alpha: 1.2,
        }
    }

    /// Builds an index over `make_vectors(n, dim, seed)`, using row `i` as doc id `i`.
    fn make_index(n: usize, dim: usize, seed: u64) -> FingerIndex {
        let data = make_vectors(n, dim, seed);
        let mut index = FingerIndex::new(dim, small_params()).unwrap();
        for (row, chunk) in data.chunks(dim).enumerate().take(n) {
            index.add_slice(row as u32, chunk).unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    fn build_and_search() {
        let (dim, n) = (16, 40);
        let index = make_index(n, dim, 42);
        let data = make_vectors(n, dim, 42);

        let results = index.search(&data[0..dim], 5).unwrap();

        assert!(!results.is_empty());
        assert!(results.iter().any(|&(id, _)| id == 0));
    }

    #[test]
    fn self_search_recall() {
        let (dim, n) = (16, 30);
        let data = make_vectors(n, dim, 7);
        let index = make_index(n, dim, 7);

        // Count the vectors that come back as their own nearest neighbor.
        let hits = (0..n)
            .filter(|&i| {
                let query = &data[i * dim..(i + 1) * dim];
                let results = index.search(query, 1).unwrap();
                results.first().map(|&(id, _)| id) == Some(i as u32)
            })
            .count();

        let recall = hits as f64 / n as f64;
        assert!(
            recall > 0.6,
            "self-search recall too low: {recall:.2} ({hits}/{n})"
        );
    }

    // Despite the name, this checks result quality only: every vector must appear
    // among its own top-3 results when searched through the projection-pruned path.
    #[test]
    fn projection_reduces_computations() {
        let (dim, n) = (8, 20);
        let data = make_vectors(n, dim, 55);
        let index = make_index(n, dim, 55);

        for (i, query) in data.chunks(dim).enumerate().take(n) {
            let results = index.search(query, 3).unwrap();
            assert!(
                results.iter().any(|&(id, _)| id == i as u32),
                "vector {i} not in its own top-3 results"
            );
        }
    }

    #[test]
    fn connectivity() {
        let (dim, n) = (8, 20);
        let index = make_index(n, dim, 123);

        // Depth-first walk over out-edges starting at the medoid.
        let start = index.medoid as usize;
        let mut seen = vec![false; n];
        seen[start] = true;
        let mut pending = vec![start];
        while let Some(node) = pending.pop() {
            for next in index.neighbors[node].iter().map(|&nb| nb as usize) {
                if !seen[next] {
                    seen[next] = true;
                    pending.push(next);
                }
            }
        }

        let reachable = seen.iter().filter(|&&flag| flag).count();
        assert_eq!(reachable, n, "not all nodes reachable from medoid");
    }

    #[test]
    fn empty_index_errors() {
        let mut index = FingerIndex::new(8, FingerParams::default()).unwrap();
        assert!(index.build().is_err());
    }

    #[test]
    fn dimension_mismatch_errors() {
        let mut index = FingerIndex::new(8, FingerParams::default()).unwrap();
        let outcome = index.add_slice(0, &[1.0, 2.0, 3.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn add_after_build_errors() {
        let dim = 4;
        let data = make_vectors(5, dim, 1);
        let mut index = FingerIndex::new(dim, FingerParams::default()).unwrap();
        for i in 0..5usize {
            index
                .add_slice(i as u32, &data[i * dim..(i + 1) * dim])
                .unwrap();
        }
        index.build().unwrap();

        let outcome = index.add_slice(99, &data[0..dim]);
        assert!(outcome.is_err());
    }
}