use super::super::layer::NodeId;
use super::super::ordered_float::OrderedFloat;
use super::search_pools::{
acquire_candidate_heap, acquire_result_heap, acquire_visited_set, release_candidate_heap,
release_result_heap, release_visited_set, BitVecVisited, CandidateHeap, ResultHeap,
};
use crate::perf_optimizations::ContiguousVectors;
use smallvec::SmallVec;
use std::cmp::Reverse;
/// Mutable working set for a single graph search: the frontier of nodes to expand,
/// the results kept so far, the visited marks, and early-termination bookkeeping.
///
/// The two heaps and the visited set are obtained through the `acquire_*` pool helpers in
/// `SearchState::new`; the `Drop` impl hands them back through the matching `release_*`
/// helpers when they still hold storage.
pub(super) struct SearchState {
/// Nodes still to be expanded, stored as `Reverse((OrderedFloat(distance), node))`.
/// NOTE(review): presumably a max-heap, so `Reverse` surfaces the nearest candidate
/// first — confirm against the `CandidateHeap` alias.
pub(super) candidates: CandidateHeap,
/// Results kept so far, stored as `(OrderedFloat(distance), node)`. The code in this file
/// treats `peek()` as the furthest (worst) kept result.
pub(super) results: ResultHeap,
/// Node ids that have already been seen. `insert` is used as "mark, and tell me whether
/// the id was new".
pub(super) visited: BitVecVisited,
/// Number of consecutive `update_stagnation(false)` calls since the last improvement.
pub(super) stagnation_count: usize,
/// Cached distance of the worst kept result, so hot paths can avoid a `results.peek()`.
/// Starts at `f32::MAX`.
pub(super) cached_furthest: f32,
}
impl SearchState {
    /// Builds an empty search state on top of pooled buffers.
    ///
    /// `capacity_hint` is forwarded to `acquire_visited_set`; the two heaps are acquired
    /// without a size hint. The stagnation counter starts at zero and `cached_furthest`
    /// at `f32::MAX`.
    pub(super) fn new(capacity_hint: usize) -> Self {
        // Acquire in the same order the fields are declared.
        let candidates = acquire_candidate_heap();
        let results = acquire_result_heap();
        let visited = acquire_visited_set(capacity_hint);
        Self {
            candidates,
            results,
            visited,
            stagnation_count: 0,
            cached_furthest: f32::MAX,
        }
    }

    /// Unconditionally records `node` at distance `dist` as both a candidate and a result,
    /// refreshes `cached_furthest` from the top of the result heap, and marks `node` visited.
    ///
    /// No `ef` bound is applied here; the result heap simply grows by one entry.
    #[inline]
    pub(super) fn push_candidate(&mut self, node: NodeId, dist: f32) {
        self.candidates.push(Reverse((OrderedFloat(dist), node)));
        self.results.push((OrderedFloat(dist), node));
        self.cached_furthest = match self.results.peek() {
            Some(top) => top.0 .0,
            None => f32::MAX,
        };
        self.visited.insert(node);
    }

    /// Decides whether the search loop should stop.
    ///
    /// Returns `true` when either
    /// * the candidate distance `c_dist` is strictly greater than `cached_furthest` while at
    ///   least `ef` results are held, or
    /// * `stagnation_limit` is non-zero and `stagnation_count` has reached it.
    ///
    /// A `stagnation_limit` of `0` disables the stagnation check.
    #[inline]
    pub(super) fn should_terminate(&self, c_dist: f32, ef: usize, stagnation_limit: usize) -> bool {
        let beyond_worst_kept = c_dist > self.cached_furthest && self.results.len() >= ef;
        let stalled = stagnation_limit > 0 && self.stagnation_count >= stagnation_limit;
        beyond_worst_kept || stalled
    }

    /// Resets the stagnation counter when `improved` is `true`, otherwise bumps it by one.
    #[inline]
    pub(super) fn update_stagnation(&mut self, improved: bool) {
        self.stagnation_count = if improved {
            0
        } else {
            self.stagnation_count + 1
        };
    }

    /// Consumes the state and returns the kept results as `(node, distance)` pairs ordered by
    /// ascending distance (compared with `f32::total_cmp`).
    ///
    /// With `limit == None` the whole vector is sorted with a stable sort. With
    /// `limit == Some(k)` ordering is delegated to `crate::index::top_k_partial_sort`.
    /// NOTE(review): that helper is not visible here — presumably it only guarantees the
    /// first `k` positions; confirm whether it also truncates.
    pub(super) fn into_sorted_results(mut self, limit: Option<usize>) -> Vec<(NodeId, f32)> {
        // `SearchState` implements `Drop`, so the heap cannot be moved out of `self`
        // directly; swap it for the default value instead.
        let heap = std::mem::take(&mut self.results);
        let mut ordered: Vec<(NodeId, f32)> = heap
            .into_iter()
            .map(|(dist, node)| (node, dist.0))
            .collect();
        let by_distance = |a: &(NodeId, f32), b: &(NodeId, f32)| a.1.total_cmp(&b.1);
        match limit {
            Some(k) => {
                crate::index::top_k_partial_sort(&mut ordered, k, by_distance);
            }
            None => {
                ordered.sort_by(by_distance);
            }
        }
        ordered
    }
}
impl Drop for SearchState {
    /// Hands the pooled buffers back through the `release_*` helpers.
    ///
    /// Each buffer is first swapped out for its default value. A buffer is only released when
    /// it still owns storage (non-empty `words` for the visited set, non-zero capacity for
    /// the heaps); otherwise it is simply dropped. This also covers a heap that was already
    /// taken by `into_sorted_results`.
    fn drop(&mut self) {
        let visited_set = std::mem::take(&mut self.visited);
        let visited_has_storage = !visited_set.words.is_empty();
        if visited_has_storage {
            release_visited_set(visited_set);
        }

        let candidate_heap = std::mem::take(&mut self.candidates);
        let candidates_have_storage = candidate_heap.capacity() > 0;
        if candidates_have_storage {
            release_candidate_heap(candidate_heap);
        }

        let result_heap = std::mem::take(&mut self.results);
        let results_have_storage = result_heap.capacity() > 0;
        if results_have_storage {
            release_result_heap(result_heap);
        }
    }
}
/// How many positions ahead of the neighbor currently being examined
/// `gather_unvisited_neighbors` issues `vectors.prefetch` for; it is also the number of
/// leading neighbors prefetched before the loop starts.
const GATHER_PREFETCH_AHEAD: usize = 2;
/// Collects the neighbors that have not been visited yet, paired with their vectors.
///
/// Every id in `neighbors` is passed to `visited.insert`; ids for which it returns `true`
/// are appended to the returned batch, in input order, together with the slice returned by
/// `vectors.get_unchecked`.
///
/// When `use_prefetch` is `true`, `vectors.prefetch` is called for the first
/// `GATHER_PREFETCH_AHEAD` neighbors up front and then, while examining position `i`, for
/// the neighbor at `i + GATHER_PREFETCH_AHEAD` (if any). Prefetches are issued whether or
/// not the target has already been visited.
#[inline]
pub(super) fn gather_unvisited_neighbors<'a>(
    neighbors: &[NodeId],
    visited: &mut BitVecVisited,
    vectors: &'a ContiguousVectors,
    use_prefetch: bool,
) -> SmallVec<[(NodeId, &'a [f32]); 32]> {
    let mut batch = SmallVec::new();

    // With prefetching disabled this view is empty, so neither loop below prefetches.
    let prefetchable: &[NodeId] = if use_prefetch { neighbors } else { &[] };

    // Warm-up: the first few neighbors have no earlier iteration to prefetch them.
    for &upcoming in prefetchable.iter().take(GATHER_PREFETCH_AHEAD) {
        vectors.prefetch(upcoming);
    }

    for (i, &neighbor) in neighbors.iter().enumerate() {
        if let Some(&ahead) = prefetchable.get(i + GATHER_PREFETCH_AHEAD) {
            vectors.prefetch(ahead);
        }

        // `insert` doubles as the membership test: `false` means already seen.
        if !visited.insert(neighbor) {
            continue;
        }

        debug_assert!(
            neighbor < vectors.len(),
            "neighbor {neighbor} out of bounds (len {})",
            vectors.len()
        );
        // SAFETY: requires `neighbor < vectors.len()`. That is only checked in debug builds
        // (assert above); release builds rely on the graph never storing an out-of-range id.
        // NOTE(review): confirm that invariant is enforced where neighbor lists are built.
        let row = unsafe { vectors.get_unchecked(neighbor) };
        batch.push((neighbor, row));
    }

    batch
}
/// Merges a batch of freshly computed distances into the search state.
///
/// `batch[i]` is paired with `distances[i]`; if the slices differ in length the surplus
/// entries of the longer one are ignored. An entry is admitted when its distance is strictly
/// below `state.cached_furthest` or when fewer than `ef` results are held. Admitted entries
/// are pushed onto both the candidate and the result heap; if that takes the result heap
/// past `ef`, its top entry is popped.
///
/// Returns `true` if at least one entry was admitted.
#[inline]
pub(super) fn process_batch_results(
    batch: &[(NodeId, &[f32])],
    distances: &[f32],
    ef: usize,
    state: &mut SearchState,
) -> bool {
    let mut improved = false;

    for (entry, &dist) in batch.iter().zip(distances) {
        let node_id = entry.0;

        let has_room = state.results.len() < ef;
        let beats_worst_kept = dist < state.cached_furthest;
        if !(beats_worst_kept || has_room) {
            continue;
        }

        state
            .candidates
            .push(Reverse((OrderedFloat(dist), node_id)));
        state.results.push((OrderedFloat(dist), node_id));

        if state.results.len() > ef {
            // Over budget: evict the top entry and re-read the new worst from the heap.
            state.results.pop();
            state.cached_furthest = match state.results.peek() {
                Some(top) => top.0 .0,
                None => f32::MAX,
            };
        } else if dist > state.cached_furthest {
            // Nothing was evicted, so the worst kept distance can only grow — and only
            // when this entry is the new worst.
            state.cached_furthest = dist;
        }

        improved = true;
    }

    improved
}