use std::cell::RefCell;
use std::collections::{BinaryHeap, HashSet};
/// Best-effort hint asking the CPU to pull the cache line containing `ptr`
/// into the L1 data cache ahead of an upcoming read.
///
/// This is only a performance hint: the pointer is never dereferenced by the
/// program and the call has no observable effect on results. On targets other
/// than x86_64 and aarch64 it compiles to nothing (hence `unused_variables`).
#[inline(always)]
#[allow(unsafe_code, unused_variables)]
fn prefetch_read_data(ptr: *const f32) {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        let byte_ptr = ptr.cast::<i8>();
        // SAFETY: PREFETCH instructions are hints that never fault, even for
        // unmapped addresses, and SSE (required by `_mm_prefetch`) is part of
        // the x86_64 baseline.
        unsafe {
            _mm_prefetch(byte_ptr, _MM_HINT_T0);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `prfm` is a hint that neither faults nor writes memory; the
        // asm uses no stack and leaves the condition flags untouched.
        unsafe {
            std::arch::asm!(
                "prfm pldl1keep, [{addr}]",
                addr = in(reg) ptr,
                options(nostack, preserves_flags)
            );
        }
    }
}
#[cfg(feature = "hnsw")]
use crate::distance::cosine_distance_normalized as cosine_distance;
/// Graphs with at most this many nodes track visited ids in a dense,
/// generation-stamped byte array; larger graphs fall back to a hash set.
const DENSE_VISITED_THRESHOLD: usize = 100 * 1_000;

/// Record of which node ids have already been reached during one search.
enum VisitedSet {
    /// One byte per node. A node counts as visited when its byte equals
    /// `generation`; `0` is never a live generation, so it means "unmarked".
    Dense { marks: Vec<u8>, generation: u8 },
    /// Explicit set of visited ids, used above `DENSE_VISITED_THRESHOLD`.
    Sparse(HashSet<u32>),
}

impl VisitedSet {
    /// Creates an empty set suited to a graph of `num_nodes` nodes.
    /// `capacity_hint` only pre-sizes the sparse representation.
    fn new(num_nodes: usize, capacity_hint: usize) -> Self {
        if num_nodes > DENSE_VISITED_THRESHOLD {
            return VisitedSet::Sparse(HashSet::with_capacity(capacity_hint));
        }
        VisitedSet::Dense {
            marks: vec![0u8; num_nodes],
            generation: 1,
        }
    }

    /// Forgets all visited ids while keeping the allocated storage.
    ///
    /// The dense form normally just advances `generation`, which invalidates
    /// every existing mark at once. Only when the `u8` counter would wrap are
    /// the marks zeroed and the counter restarted at 1.
    fn clear(&mut self) {
        match self {
            VisitedSet::Sparse(set) => set.clear(),
            VisitedSet::Dense { marks, generation } => match generation.checked_add(1) {
                Some(bumped) => *generation = bumped,
                None => {
                    marks.fill(0);
                    *generation = 1;
                }
            },
        }
    }

    /// Reports whether `id` was recorded since the last clear. In the dense
    /// form, ids beyond the marks array are reported as not visited.
    #[cfg(test)]
    #[inline]
    fn contains(&self, id: u32) -> bool {
        match self {
            VisitedSet::Sparse(set) => set.contains(&id),
            VisitedSet::Dense { marks, generation } => marks.get(id as usize) == Some(generation),
        }
    }

    /// Records `id` and returns `true` if it had not been recorded since the
    /// last clear.
    ///
    /// In the dense form, an id beyond the marks array trips a debug
    /// assertion. In release builds it is reported as newly inserted and not
    /// stored.
    #[inline]
    fn insert(&mut self, id: u32) -> bool {
        match self {
            VisitedSet::Sparse(set) => set.insert(id),
            VisitedSet::Dense { marks, generation } => {
                let slot_count = marks.len();
                let idx = id as usize;
                debug_assert!(
                    idx < slot_count,
                    "VisitedSet::insert: id {} out of bounds (capacity {})",
                    id,
                    slot_count
                );
                match marks.get_mut(idx) {
                    None => true,
                    Some(stamp) if *stamp == *generation => false,
                    Some(stamp) => {
                        *stamp = *generation;
                        true
                    }
                }
            }
        }
    }

    /// Readies the set for a search over `num_nodes` nodes.
    ///
    /// If the current representation still matches the dense/sparse decision
    /// for `num_nodes`, it is reused: dense marks are grown if needed, then
    /// everything is cleared. Otherwise the set is rebuilt from scratch with
    /// `VisitedSet::new`.
    fn prepare(&mut self, num_nodes: usize, capacity_hint: usize) {
        let wants_dense = num_nodes <= DENSE_VISITED_THRESHOLD;
        match self {
            VisitedSet::Dense { marks, .. } if wants_dense => {
                if num_nodes > marks.len() {
                    marks.resize(num_nodes, 0);
                }
            }
            VisitedSet::Sparse(_) if !wants_dense => {}
            _ => {
                *self = VisitedSet::new(num_nodes, capacity_hint);
                return;
            }
        }
        self.clear();
    }
}
thread_local! {
    /// Scratch visited set shared by every search on this thread, so repeated
    /// searches reuse its marks array / hash set instead of reallocating.
    static THREAD_VISITED: RefCell<VisitedSet> = const {
        RefCell::new(VisitedSet::Dense { generation: 1, marks: Vec::new() })
    };
}

/// Runs `f` with this thread's scratch `VisitedSet`, after preparing it for a
/// graph of `num_nodes` nodes (`capacity_hint` is forwarded to `prepare`).
///
/// # Panics
/// Panics if invoked re-entrantly from inside `f`, because the thread-local
/// `RefCell` is still mutably borrowed at that point.
fn with_visited_set<F, R>(num_nodes: usize, capacity_hint: usize, f: F) -> R
where
    F: FnOnce(&mut VisitedSet) -> R,
{
    THREAD_VISITED.with(|slot| {
        let mut scratch = slot.borrow_mut();
        scratch.prepare(num_nodes, capacity_hint);
        f(&mut *scratch)
    })
}
/// Frontier entry for the best-first layer search.
///
/// `Ord` is reversed on `distance` (compared with [`f32::total_cmp`]). As a
/// result, Rust's max-[`BinaryHeap`] acts as a min-heap and pops the nearest
/// candidate first.
///
/// Equality uses the same comparison, so `a == b` holds exactly when
/// `a.cmp(&b) == Ordering::Equal`, as the `Eq`/`Ord` contracts require.
/// `id` is payload only and does not participate in comparisons.
struct MinCandidate {
    id: u32,
    distance: f32,
}
impl PartialEq for MinCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for MinCandidate {}
impl Ord for MinCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed: the smaller distance compares as "greater" so it pops first.
        other.distance.total_cmp(&self.distance)
    }
}
impl PartialOrd for MinCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Entry in the bounded result set of the layer search.
///
/// `Ord` follows `distance` (compared with [`f32::total_cmp`]), so a
/// max-[`BinaryHeap`] keeps the *worst* (farthest) result on top, where it can
/// be inspected and evicted cheaply.
///
/// Equality uses the same comparison, so `a == b` holds exactly when
/// `a.cmp(&b) == Ordering::Equal`, as the `Eq`/`Ord` contracts require.
/// `id` is payload only and does not participate in comparisons.
struct MaxResult {
    id: u32,
    distance: f32,
}
impl PartialEq for MaxResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for MaxResult {}
impl Ord for MaxResult {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
impl PartialOrd for MaxResult {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Best-first beam search over a single HNSW layer, starting at `entry_point`.
///
/// Distances are `cosine_distance(query, row)`. The search keeps the `ef`
/// closest nodes found so far; the entry point is always counted, so at least
/// one result is returned. It expands frontier nodes nearest-first and stops
/// once the nearest unexpanded node is farther than the worst kept result
/// while `ef` results are held.
///
/// `vectors` is read as a flat buffer of `dimension`-length rows, with row `i`
/// belonging to node id `i`.
///
/// Returns `(node_id, distance)` pairs sorted by ascending distance.
///
/// # Panics
/// Panics if `dimension` is zero, or if `entry_point` or any neighbor reached
/// during the search has no complete row in `vectors`.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
) -> Vec<(u32, f32)> {
    let node_count = vectors.len() / dimension;
    with_visited_set(node_count, ef * 2, |seen| {
        // Min-heap of nodes still to expand; max-heap of the best `ef` so far.
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut kept: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);

        let start_distance =
            cosine_distance(query, get_vector(vectors, dimension, entry_point as usize));
        frontier.push(MinCandidate {
            id: entry_point,
            distance: start_distance,
        });
        kept.push(MaxResult {
            id: entry_point,
            distance: start_distance,
        });
        seen.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let worst_kept = kept.peek().map_or(f32::INFINITY, |r| r.distance);
            // `current` is the nearest unexpanded node; if even it is worse than
            // a full result set's worst entry, nothing left can improve it.
            if kept.len() >= ef && current.distance > worst_kept {
                break;
            }

            let adjacency = layer.get_neighbors(current.id);
            let degree = adjacency.len();
            for (slot, &neighbor) in adjacency.iter().enumerate() {
                // Warm the cache for the next neighbor's row while scoring this one.
                if slot + 1 < degree {
                    let upcoming = adjacency[slot + 1] as usize;
                    if upcoming < node_count {
                        prefetch_read_data(vectors.as_ptr().wrapping_add(upcoming * dimension));
                    }
                }

                if !seen.insert(neighbor) {
                    continue;
                }

                let distance =
                    cosine_distance(query, get_vector(vectors, dimension, neighbor as usize));
                let worst_kept = kept.peek().map_or(f32::INFINITY, |r| r.distance);
                if kept.len() < ef || distance < worst_kept {
                    frontier.push(MinCandidate {
                        id: neighbor,
                        distance,
                    });
                    kept.push(MaxResult {
                        id: neighbor,
                        distance,
                    });
                    if kept.len() > ef {
                        kept.pop();
                    }
                }
            }
        }

        let mut ranked: Vec<(u32, f32)> = kept.into_iter().map(|r| (r.id, r.distance)).collect();
        ranked.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        ranked
    })
}
/// Variant of `greedy_search_layer` that may stop early.
///
/// Every computed distance, including the entry point's, is fed to an
/// `EarlyTerminationOracle` built from `k` and a clone of `config`. After each
/// node expansion, the search stops if the oracle reports
/// `should_terminate()` and at least `k` results are held. Otherwise the
/// search and its stopping rule are identical to `greedy_search_layer`.
///
/// Returns the `(node_id, distance)` pairs sorted by ascending distance,
/// together with `oracle.num_evaluated()`. NOTE(review): presumably the number
/// of distances the oracle observed — confirm against `crate::adaptive`.
///
/// # Panics
/// Panics if `dimension` is zero, or if `entry_point` or any neighbor reached
/// during the search has no complete row in `vectors`.
#[cfg(all(feature = "hnsw", feature = "experimental"))]
pub fn greedy_search_layer_adaptive(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
    k: usize,
    config: &crate::adaptive::AdaptiveConfig,
) -> (Vec<(u32, f32)>, usize) {
    use crate::adaptive::EarlyTerminationOracle;
    let node_count = vectors.len() / dimension;
    with_visited_set(node_count, ef * 2, |seen| {
        // Min-heap of nodes still to expand; max-heap of the best `ef` so far.
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut kept: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);
        let mut oracle = EarlyTerminationOracle::new(k, config.clone());

        let start_distance =
            cosine_distance(query, get_vector(vectors, dimension, entry_point as usize));
        oracle.observe(start_distance);
        frontier.push(MinCandidate {
            id: entry_point,
            distance: start_distance,
        });
        kept.push(MaxResult {
            id: entry_point,
            distance: start_distance,
        });
        seen.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let worst_kept = kept.peek().map_or(f32::INFINITY, |r| r.distance);
            // `current` is the nearest unexpanded node; if even it is worse than
            // a full result set's worst entry, nothing left can improve it.
            if kept.len() >= ef && current.distance > worst_kept {
                break;
            }

            let adjacency = layer.get_neighbors(current.id);
            let degree = adjacency.len();
            for (slot, &neighbor) in adjacency.iter().enumerate() {
                // Warm the cache for the next neighbor's row while scoring this one.
                if slot + 1 < degree {
                    let upcoming = adjacency[slot + 1] as usize;
                    if upcoming < node_count {
                        prefetch_read_data(vectors.as_ptr().wrapping_add(upcoming * dimension));
                    }
                }

                if !seen.insert(neighbor) {
                    continue;
                }

                let distance =
                    cosine_distance(query, get_vector(vectors, dimension, neighbor as usize));
                oracle.observe(distance);
                let worst_kept = kept.peek().map_or(f32::INFINITY, |r| r.distance);
                if kept.len() < ef || distance < worst_kept {
                    frontier.push(MinCandidate {
                        id: neighbor,
                        distance,
                    });
                    kept.push(MaxResult {
                        id: neighbor,
                        distance,
                    });
                    if kept.len() > ef {
                        kept.pop();
                    }
                }
            }

            // The oracle is queried first, on every expansion, exactly as before.
            if oracle.should_terminate() && kept.len() >= k {
                break;
            }
        }

        let evaluated = oracle.num_evaluated();
        let mut ranked: Vec<(u32, f32)> = kept.into_iter().map(|r| (r.id, r.distance)).collect();
        ranked.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        (ranked, evaluated)
    })
}
/// Borrows row `idx` of the flat buffer `vectors`, where every row holds
/// `dimension` elements (row `idx` spans `idx * dimension .. (idx + 1) * dimension`).
///
/// With `dimension == 0` this returns an empty slice for any `idx`.
///
/// # Panics
/// Panics if the row does not lie entirely within `vectors`. The offset
/// arithmetic saturates instead of wrapping. As a result, an oversized `idx`
/// always ends in this out-of-range panic, even in release builds, instead of
/// silently aliasing some other in-bounds row.
#[inline]
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    // A saturated `start`/`end` equals usize::MAX, which always exceeds
    // `vectors.len()`, so the slice below panics rather than mis-indexing.
    let start = idx.saturating_mul(dimension);
    let end = start.saturating_add(dimension);
    &vectors[start..end]
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_candidate_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(MinCandidate {
            id: 0,
            distance: 0.5,
        });
        heap.push(MinCandidate {
            id: 1,
            distance: 0.1,
        });
        heap.push(MinCandidate {
            id: 2,
            distance: 0.3,
        });
        assert_eq!(heap.pop().unwrap().distance, 0.1);
        assert_eq!(heap.pop().unwrap().distance, 0.3);
        assert_eq!(heap.pop().unwrap().distance, 0.5);
    }

    #[test]
    fn test_result_ordering() {
        // The result heap must surface the farthest entry first so the search
        // can compare against and evict the current worst result.
        let mut heap = BinaryHeap::new();
        heap.push(MaxResult {
            id: 0,
            distance: 0.5,
        });
        heap.push(MaxResult {
            id: 1,
            distance: 0.1,
        });
        heap.push(MaxResult {
            id: 2,
            distance: 0.3,
        });
        assert_eq!(heap.pop().unwrap().distance, 0.5);
        assert_eq!(heap.pop().unwrap().distance, 0.3);
        assert_eq!(heap.pop().unwrap().distance, 0.1);
    }

    #[test]
    fn test_visited_set_dense() {
        let mut v = VisitedSet::new(100, 10);
        assert!(!v.contains(5));
        assert!(v.insert(5));
        assert!(v.contains(5));
        assert!(!v.insert(5));
    }

    #[test]
    fn test_visited_set_dense_clear() {
        let mut v = VisitedSet::new(100, 10);
        assert!(v.insert(5));
        assert!(v.contains(5));
        v.clear();
        assert!(!v.contains(5));
        assert!(v.insert(5));
    }

    #[test]
    fn test_visited_set_dense_generation_overflow() {
        let mut v = VisitedSet::new(100, 10);
        if let VisitedSet::Dense {
            ref mut generation, ..
        } = v
        {
            *generation = u8::MAX;
        }
        assert!(v.insert(5));
        assert!(v.contains(5));
        v.clear();
        assert!(!v.contains(5));
        assert!(v.insert(5));
        assert!(v.contains(5));
    }

    #[test]
    fn test_visited_set_sparse() {
        let mut v = VisitedSet::new(DENSE_VISITED_THRESHOLD + 1, 10);
        assert!(!v.contains(42));
        assert!(v.insert(42));
        assert!(v.contains(42));
        assert!(!v.insert(42));
    }

    #[test]
    fn test_visited_set_sparse_clear() {
        let mut v = VisitedSet::new(DENSE_VISITED_THRESHOLD + 1, 10);
        assert!(v.insert(42));
        v.clear();
        assert!(!v.contains(42));
        assert!(v.insert(42));
    }

    #[test]
    fn test_visited_set_prepare_grows_dense_and_clears() {
        let mut v = VisitedSet::new(2, 1);
        assert!(v.insert(1));
        v.prepare(8, 1);
        assert!(matches!(v, VisitedSet::Dense { .. }));
        assert!(!v.contains(1));
        // Id 7 only fits because `prepare` grew the marks array.
        assert!(v.insert(7));
        assert!(v.contains(7));
    }

    #[test]
    fn test_visited_set_prepare_switches_representation() {
        let mut v = VisitedSet::new(4, 1);
        v.prepare(DENSE_VISITED_THRESHOLD + 1, 8);
        assert!(matches!(v, VisitedSet::Sparse(_)));
        assert!(v.insert(42));
        v.prepare(DENSE_VISITED_THRESHOLD + 1, 8);
        assert!(!v.contains(42));
        v.prepare(16, 8);
        assert!(matches!(v, VisitedSet::Dense { .. }));
        assert!(!v.contains(42));
        assert!(v.insert(15));
    }

    #[test]
    fn test_get_vector_rows() {
        let flat = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(get_vector(&flat, 3, 0), &[0.0, 1.0, 2.0]);
        assert_eq!(get_vector(&flat, 3, 1), &[3.0, 4.0, 5.0]);
        assert_eq!(get_vector(&flat, 2, 2), &[4.0, 5.0]);
    }

    #[test]
    #[should_panic]
    fn test_get_vector_out_of_range_panics() {
        let flat = [0.0f32; 4];
        let _ = get_vector(&flat, 2, 2);
    }
}