use std::collections::{BinaryHeap, HashSet};
#[cfg(feature = "hnsw")]
mod distance_impl {
    // Thin shim: re-exports the crate's cosine distance function so the rest of
    // this file can import it from one local, feature-gated place.
    pub use crate::hnsw::distance::cosine_distance;
}
#[cfg(feature = "hnsw")]
use distance_impl::cosine_distance;
/// An id paired with its distance, ordered so that a max-heap
/// (`BinaryHeap`) pops the *smallest* distance first.
///
/// Equality and ordering are defined by the same comparison, so
/// `a == b` holds exactly when `a.cmp(&b)` is `Ordering::Equal`:
///
/// * distances are compared with `f32::total_cmp`, which is a total order
///   (NaN compares equal to an identical NaN, and `-0.0` sorts before `0.0`);
/// * ties on distance are broken by `id`, the lower id being popped first,
///   which makes heap order deterministic.
#[derive(Clone, Debug)]
pub(crate) struct Candidate {
    /// Identifier carried alongside the distance.
    pub(crate) id: u32,
    /// Distance value; smaller values have higher heap priority.
    pub(crate) distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `PartialEq`, `Eq` and `Ord` can never disagree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance
            .total_cmp(&other.distance)
            // Reversed: the smallest distance must be the heap's maximum.
            .reverse()
            // Tie-break (also reversed) so that the lower id wins.
            .then_with(|| other.id.cmp(&self.id))
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Bookkeeping for a best-first traversal: a priority queue of pending
/// candidates, the set of ids already handed out, and simple progress counters.
pub(crate) struct SearchState {
    /// Pending candidates; `Candidate`'s ordering decides which one pops first.
    candidates: BinaryHeap<Candidate>,
    /// Ids that `pop_candidate` has already returned.
    visited: HashSet<u32>,
    /// Smallest distance among the candidates returned so far
    /// (`f32::INFINITY` until the first pop).
    best_distance: f32,
    /// Number of consecutive pops whose distance was not strictly smaller
    /// than `best_distance`; reset to 0 whenever a pop improves it.
    no_improvement_count: usize,
}

impl SearchState {
    /// Creates an empty state without reserving any capacity.
    // Not referenced in the visible part of this file, hence the lint allowance.
    #[allow(dead_code)]
    fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty state with room for `ef * 2` entries reserved in both
    /// the candidate queue and the visited set.
    pub(crate) fn with_capacity(ef: usize) -> Self {
        let reserve = ef * 2;
        Self {
            candidates: BinaryHeap::with_capacity(reserve),
            visited: HashSet::with_capacity(reserve),
            best_distance: f32::INFINITY,
            no_improvement_count: 0,
        }
    }

    /// Queues `id` at the given `distance` unless that id was already visited.
    ///
    /// An id that is queued but not yet popped may be queued again; the
    /// duplicates are filtered out by `pop_candidate`.
    pub(crate) fn add_candidate(&mut self, id: u32, distance: f32) {
        if self.visited.contains(&id) {
            return;
        }
        self.candidates.push(Candidate { id, distance });
    }

    /// Pops the highest-priority candidate that has not been visited yet,
    /// marks it visited and updates the progress counters.
    ///
    /// Returns `None` once the queue holds no unvisited candidate.
    pub(crate) fn pop_candidate(&mut self) -> Option<Candidate> {
        loop {
            let next = self.candidates.pop()?;
            // `insert` returns false when the id was already present, i.e. this
            // entry is a stale duplicate of something already returned.
            if !self.visited.insert(next.id) {
                continue;
            }
            if next.distance < self.best_distance {
                self.best_distance = next.distance;
                self.no_improvement_count = 0;
            } else {
                self.no_improvement_count += 1;
            }
            return Some(next);
        }
    }
}
#[cfg(feature = "hnsw")]
/// Best-first search of one graph layer for the ids closest to `query`.
///
/// Starting from `entry_point`, the search repeatedly expands the closest
/// unexpanded node, scoring each newly seen neighbour with `cosine_distance`
/// and keeping a bounded set of the best results found so far.
///
/// * `query` - the query vector.
/// * `entry_point` - id of the node the search starts from.
/// * `layer` - supplies each node's neighbours via `get_neighbors`.
/// * `vectors` - flat storage; the vector for id `i` is read with
///   `get_vector(vectors, dimension, i)`.
/// * `dimension` - number of `f32` components per stored vector.
/// * `ef` - maximum number of results kept while searching.
///
/// Returns `(id, distance)` pairs sorted by ascending distance (using
/// `f32::total_cmp`). The result holds at most `ef` pairs, except that the
/// entry point is always recorded, so `ef == 0` still yields one pair.
///
/// # Panics
///
/// Panics if a visited id addresses a range outside `vectors`.
pub fn greedy_search_layer(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
) -> Vec<(u32, f32)> {
    use std::cmp::{Ordering, Reverse};
    use std::collections::{BinaryHeap, HashSet};

    /// A node id with its distance to the query, ordered by distance alone.
    #[derive(PartialEq)]
    struct Scored {
        id: u32,
        distance: f32,
    }
    impl Eq for Scored {}
    impl Ord for Scored {
        fn cmp(&self, other: &Self) -> Ordering {
            self.distance.total_cmp(&other.distance)
        }
    }
    impl PartialOrd for Scored {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    // Distance from the query to the stored vector of node `id`.
    let distance_to =
        |id: u32| cosine_distance(query, get_vector(vectors, dimension, id as usize));

    // `frontier` pops the closest unexpanded node (min-heap via `Reverse`);
    // `nearest` keeps the best results with the farthest one on top (max-heap).
    let mut frontier: BinaryHeap<Reverse<Scored>> = BinaryHeap::with_capacity(ef * 2);
    let mut nearest: BinaryHeap<Scored> = BinaryHeap::with_capacity(ef + 1);
    let mut seen = HashSet::with_capacity(ef * 2);

    let entry_distance = distance_to(entry_point);
    frontier.push(Reverse(Scored {
        id: entry_point,
        distance: entry_distance,
    }));
    nearest.push(Scored {
        id: entry_point,
        distance: entry_distance,
    });
    seen.insert(entry_point);

    while let Some(Reverse(current)) = frontier.pop() {
        let farthest = nearest.peek().map_or(f32::INFINITY, |r| r.distance);
        // The closest remaining node is already farther than the worst kept
        // result and the result set is full, so the search stops here.
        if nearest.len() >= ef && current.distance > farthest {
            break;
        }

        let neighbors = layer.get_neighbors(current.id);
        for &neighbor_id in neighbors.iter() {
            // `insert` returns false for ids that were scored before.
            if !seen.insert(neighbor_id) {
                continue;
            }
            let neighbor_distance = distance_to(neighbor_id);
            let farthest = nearest.peek().map_or(f32::INFINITY, |r| r.distance);
            let admit = nearest.len() < ef || neighbor_distance < farthest;
            if admit {
                frontier.push(Reverse(Scored {
                    id: neighbor_id,
                    distance: neighbor_distance,
                }));
                nearest.push(Scored {
                    id: neighbor_id,
                    distance: neighbor_distance,
                });
                // Evict the farthest result once the bound is exceeded.
                if nearest.len() > ef {
                    nearest.pop();
                }
            }
        }
    }

    let mut ranked: Vec<(u32, f32)> = nearest
        .into_iter()
        .map(|scored| (scored.id, scored.distance))
        .collect();
    ranked.sort_by(|(_, left), (_, right)| left.total_cmp(right));
    ranked
}
/// Returns the `idx`-th vector of a flat buffer that stores vectors back to
/// back, `dimension` components each.
///
/// # Panics
///
/// Panics if `idx * dimension + dimension` exceeds `vectors.len()`.
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    let offset = idx * dimension;
    &vectors[offset..offset + dimension]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A heap of `Candidate`s must yield them in ascending distance order.
    #[test]
    fn test_candidate_ordering() {
        let mut heap = BinaryHeap::new();
        for (id, distance) in [(0, 0.5), (1, 0.1), (2, 0.3)] {
            heap.push(Candidate { id, distance });
        }

        for expected in [0.1, 0.3, 0.5] {
            assert_eq!(heap.pop().unwrap().distance, expected);
        }
    }
}