use std::cell::RefCell;
use std::collections::{BinaryHeap, HashSet};
/// Issues a best-effort cache prefetch hint for the memory at `ptr`.
///
/// On x86_64 this emits `_mm_prefetch` with the T0 hint; on aarch64 it emits
/// `prfm pldl1keep`. On every other target the body is empty and the call is
/// a no-op (hence the `unused_variables` allowance).
#[inline(always)]
#[allow(unsafe_code, unused_variables)]
pub(crate) fn prefetch_read_data(ptr: *const f32) {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        // SAFETY: the pointer is only handed to a prefetch hint and is never
        // dereferenced by this function.
        // NOTE(review): relies on the prefetch instruction not faulting for
        // addresses outside the allocation — confirm against the ISA manual.
        unsafe { _mm_prefetch::<_MM_HINT_T0>(ptr.cast::<i8>()) };
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: same reasoning as the x86_64 branch; the asm only issues a
        // prefetch hint for the address held in `ptr`.
        unsafe {
            std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) ptr, options(nostack, preserves_flags));
        }
    }
}
/// Largest node count for which the dense, byte-per-node representation is used.
const DENSE_VISITED_THRESHOLD: usize = 4_000_000;

/// Tracks which node ids have already been seen during a search.
pub(crate) enum VisitedSet {
    /// One byte per node; a node is visited when its byte equals `generation`.
    /// Bumping `generation` therefore clears the set without touching `marks`.
    Dense { marks: Vec<u8>, generation: u8 },
    /// Hash set of visited ids, used when the node count exceeds the threshold.
    Sparse(HashSet<u32>),
}

impl VisitedSet {
    /// Builds an empty set sized for `num_nodes`; `capacity_hint` only affects
    /// the sparse representation.
    fn new(num_nodes: usize, capacity_hint: usize) -> Self {
        if num_nodes > DENSE_VISITED_THRESHOLD {
            return VisitedSet::Sparse(HashSet::with_capacity(capacity_hint));
        }
        VisitedSet::Dense {
            marks: vec![0u8; num_nodes],
            generation: 1,
        }
    }

    /// Forgets every visited id while keeping the allocated storage.
    fn clear(&mut self) {
        match self {
            VisitedSet::Sparse(ids) => ids.clear(),
            VisitedSet::Dense { marks, generation } => match generation.checked_add(1) {
                Some(bumped) => *generation = bumped,
                None => {
                    // The counter would wrap: wipe the stale marks and restart
                    // at 1 so that the zeroed bytes never read as "visited".
                    marks.fill(0);
                    *generation = 1;
                }
            },
        }
    }

    /// Reports whether `id` is currently marked as visited.
    #[cfg(test)]
    #[inline]
    fn contains(&self, id: u32) -> bool {
        match self {
            VisitedSet::Sparse(ids) => ids.contains(&id),
            VisitedSet::Dense { marks, generation } => marks.get(id as usize) == Some(generation),
        }
    }

    /// Marks `id` as visited and returns `true` when it was not visited before.
    ///
    /// For the dense representation an out-of-range `id` trips a debug
    /// assertion; in builds without debug assertions it is reported as newly
    /// inserted without being recorded.
    #[inline]
    pub(crate) fn insert(&mut self, id: u32) -> bool {
        let (marks, generation) = match self {
            VisitedSet::Sparse(ids) => return ids.insert(id),
            VisitedSet::Dense { marks, generation } => (marks, *generation),
        };
        let idx = id as usize;
        debug_assert!(
            idx < marks.len(),
            "VisitedSet::insert: id {} out of bounds (capacity {})",
            id,
            marks.len()
        );
        match marks.get_mut(idx) {
            Some(mark) => std::mem::replace(mark, generation) != generation,
            None => true,
        }
    }

    /// Readies the set for a new search over `num_nodes` nodes, reusing the
    /// existing storage when the representation still fits and rebuilding it
    /// otherwise.
    fn prepare(&mut self, num_nodes: usize, capacity_hint: usize) {
        let wants_dense = num_nodes <= DENSE_VISITED_THRESHOLD;
        let reusable = match self {
            VisitedSet::Dense { marks, .. } if wants_dense => {
                // Only ever grow: a larger buffer from an earlier search is kept.
                if marks.len() < num_nodes {
                    marks.resize(num_nodes, 0);
                }
                true
            }
            VisitedSet::Dense { .. } => false,
            VisitedSet::Sparse(_) => !wants_dense,
        };
        if reusable {
            self.clear();
        } else {
            *self = VisitedSet::new(num_nodes, capacity_hint);
        }
    }
}
thread_local! {
    /// Per-thread visited set that is reused from one search to the next.
    static THREAD_VISITED: RefCell<VisitedSet> = const {
        RefCell::new(VisitedSet::Dense { marks: Vec::new(), generation: 1 })
    };
}

/// Runs `f` with this thread's visited set after preparing it for a graph of
/// `num_nodes` nodes, and returns whatever `f` returns.
///
/// The set stays mutably borrowed for the whole duration of `f`, so calling
/// `with_visited_set` again from inside `f` on the same thread panics on the
/// second `borrow_mut`.
pub(crate) fn with_visited_set<F, R>(num_nodes: usize, capacity_hint: usize, f: F) -> R
where
    F: FnOnce(&mut VisitedSet) -> R,
{
    THREAD_VISITED.with(|slot| {
        let mut guard = slot.borrow_mut();
        guard.prepare(num_nodes, capacity_hint);
        f(&mut *guard)
    })
}
/// Search-frontier entry ordered so that a `BinaryHeap` pops the *smallest*
/// distance first.
///
/// Equality is defined through [`Ord::cmp`], i.e. on `distance` under
/// `f32::total_cmp`. This keeps `PartialEq`/`Eq` consistent with the ordering
/// (as the `Ord` contract requires) and makes equality reflexive even when the
/// distance is NaN. `id` is payload and takes no part in comparisons.
#[derive(Debug, Clone, Copy)]
pub(crate) struct MinCandidate {
    pub(crate) id: u32,
    pub(crate) distance: f32,
}

impl PartialEq for MinCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for MinCandidate {}

impl Ord for MinCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: a smaller distance compares as
        // "greater", which turns the max-heap into a min-heap on distance.
        other.distance.total_cmp(&self.distance)
    }
}

impl PartialOrd for MinCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Result entry ordered so that a `BinaryHeap` keeps the *largest* distance on
/// top, making the current worst result cheap to peek at and evict.
///
/// Equality is defined through [`Ord::cmp`], i.e. on `distance` under
/// `f32::total_cmp`. This keeps `PartialEq`/`Eq` consistent with the ordering
/// (as the `Ord` contract requires) and makes equality reflexive even when the
/// distance is NaN. `id` is payload and takes no part in comparisons.
#[derive(Debug, Clone, Copy)]
pub(crate) struct MaxResult {
    pub(crate) id: u32,
    pub(crate) distance: f32,
}

impl PartialEq for MaxResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for MaxResult {}

impl Ord for MaxResult {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}

impl PartialOrd for MaxResult {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Scores the first `count` ids of `batch_ids` against `query` and merges them
/// into the search heaps.
///
/// All distances are computed first; each id is then admitted when `results`
/// holds fewer than `ef` entries or the id is strictly closer than the current
/// worst result. An admitted id is pushed to both heaps, after which `results`
/// is trimmed back to `ef` entries.
///
/// `count` must not exceed the batch size of four, otherwise this panics.
#[inline]
fn flush_batch(
    query: &[f32],
    batch_ids: &[u32; 4],
    count: usize,
    vectors: &[f32],
    dimension: usize,
    dist_fn: fn(&[f32], &[f32]) -> f32,
    candidates: &mut std::collections::BinaryHeap<MinCandidate>,
    results: &mut std::collections::BinaryHeap<MaxResult>,
    ef: usize,
) {
    let pending = &batch_ids[..count];

    // Pass 1: distance computations only, kept apart from the heap updates.
    let mut scored = [0.0f32; 4];
    for (slot, &id) in scored.iter_mut().zip(pending) {
        *slot = dist_fn(query, get_vector(vectors, dimension, id as usize));
    }

    // Pass 2: admission. The worst distance is re-read for every id because
    // an earlier admission in this batch may already have tightened it.
    for (&id, &distance) in pending.iter().zip(scored.iter()) {
        let worst = results.peek().map_or(f32::INFINITY, |r| r.distance);
        let admit = results.len() < ef || distance < worst;
        if !admit {
            continue;
        }
        candidates.push(MinCandidate { id, distance });
        results.push(MaxResult { id, distance });
        if results.len() > ef {
            results.pop();
        }
    }
}
/// Variant of the batch flush where the distance callback receives the node id
/// instead of a vector slice.
///
/// Scores the first `count` ids of `batch_ids`, then admits each one when
/// `results` holds fewer than `ef` entries or the id is strictly closer than
/// the current worst result. An admitted id is pushed to both heaps, after
/// which `results` is trimmed back to `ef` entries.
///
/// `count` must not exceed the batch size of four, otherwise this panics.
#[inline]
fn flush_batch_custom<F: Fn(&[f32], u32) -> f32>(
    query: &[f32],
    batch_ids: &[u32; 4],
    count: usize,
    dist_fn: &F,
    candidates: &mut std::collections::BinaryHeap<MinCandidate>,
    results: &mut std::collections::BinaryHeap<MaxResult>,
    ef: usize,
) {
    let pending = &batch_ids[..count];

    // Pass 1: evaluate every distance before touching the heaps.
    let mut scored = [0.0f32; 4];
    for (slot, &id) in scored.iter_mut().zip(pending) {
        *slot = dist_fn(query, id);
    }

    // Pass 2: admission against the (possibly just updated) worst result.
    for (&id, &distance) in pending.iter().zip(scored.iter()) {
        let worst = results.peek().map_or(f32::INFINITY, |r| r.distance);
        let admit = results.len() < ef || distance < worst;
        if !admit {
            continue;
        }
        candidates.push(MinCandidate { id, distance });
        results.push(MaxResult { id, distance });
        if results.len() > ef {
            results.pop();
        }
    }
}
/// Best-first search of one graph layer, starting from a single entry point.
///
/// The closest unexpanded candidate is popped repeatedly; its not-yet-visited
/// neighbours are prefetched and scored against `query` in batches of up to
/// four through `flush_batch`, which also bounds the result heap to `ef`.
/// The search stops once the closest remaining candidate is farther than the
/// worst retained result while at least `ef` results are held.
///
/// Returns `(id, distance)` pairs sorted by ascending distance.
///
/// # Panics
///
/// Panics if `dimension` is zero, or if `entry_point` or a newly visited
/// neighbour id does not address a complete vector inside `vectors`.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
    dist_fn: fn(&[f32], &[f32]) -> f32,
) -> Vec<(u32, f32)> {
    let num_vectors = vectors.len() / dimension;
    with_visited_set(num_vectors, ef * 2, |visited| {
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut best: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);

        // Seed both heaps with the entry point.
        let entry_distance = dist_fn(query, get_vector(vectors, dimension, entry_point as usize));
        frontier.push(MinCandidate {
            id: entry_point,
            distance: entry_distance,
        });
        best.push(MaxResult {
            id: entry_point,
            distance: entry_distance,
        });
        visited.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
            if best.len() >= ef && current.distance > bound {
                break;
            }

            let neighbors = layer.get_neighbors(current.id);
            let mut pending = [0u32; 4];
            let mut filled = 0usize;
            for &neighbor_id in neighbors.iter() {
                if !visited.insert(neighbor_id) {
                    continue;
                }
                pending[filled] = neighbor_id;
                filled += 1;

                // Warm the cache for this vector before the batch is scored.
                let slot = neighbor_id as usize;
                if slot < num_vectors {
                    let head = vectors.as_ptr().wrapping_add(slot * dimension);
                    prefetch_read_data(head);
                    if dimension > 16 {
                        prefetch_read_data(head.wrapping_add(16));
                    }
                }

                if filled == pending.len() {
                    flush_batch(
                        query,
                        &pending,
                        filled,
                        vectors,
                        dimension,
                        dist_fn,
                        &mut frontier,
                        &mut best,
                        ef,
                    );
                    filled = 0;
                }
            }
            // Score whatever is left of a partially filled batch.
            if filled != 0 {
                flush_batch(
                    query,
                    &pending,
                    filled,
                    vectors,
                    dimension,
                    dist_fn,
                    &mut frontier,
                    &mut best,
                    ef,
                );
            }
        }

        let mut output: Vec<(u32, f32)> = best.into_iter().map(|r| (r.id, r.distance)).collect();
        output.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        output
    })
}
/// Best-first search of one graph layer seeded from several entry points.
///
/// Every distinct entry point is scored and placed in both heaps (the result
/// heap being trimmed to `ef`). The closest unexpanded candidate is then popped
/// repeatedly and each unvisited neighbour is scored individually, with
/// prefetch hints issued for the next neighbour and for the one four positions
/// ahead.
///
/// Returns `(id, distance)` pairs sorted by ascending distance, or an empty
/// vector when `entry_points` is empty.
///
/// # Panics
///
/// Panics if `entry_points` is non-empty and `dimension` is zero, or if an
/// entry point or a newly visited neighbour id does not address a complete
/// vector inside `vectors`.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer_multi_entry(
    query: &[f32],
    entry_points: &[u32],
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
    dist_fn: fn(&[f32], &[f32]) -> f32,
) -> Vec<(u32, f32)> {
    if entry_points.is_empty() {
        return Vec::new();
    }
    let num_vectors = vectors.len() / dimension;
    with_visited_set(num_vectors, ef * 2, |visited| {
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut best: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);

        // Seed phase: duplicate entry points are skipped via the visited set.
        for &seed in entry_points {
            if !visited.insert(seed) {
                continue;
            }
            let seed_distance = dist_fn(query, get_vector(vectors, dimension, seed as usize));
            frontier.push(MinCandidate {
                id: seed,
                distance: seed_distance,
            });
            best.push(MaxResult {
                id: seed,
                distance: seed_distance,
            });
            if best.len() > ef {
                best.pop();
            }
        }

        while let Some(current) = frontier.pop() {
            let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
            if best.len() >= ef && current.distance > bound {
                break;
            }

            let neighbors = layer.get_neighbors(current.id);
            let neighbor_count = neighbors.len();
            for (i, &neighbor_id) in neighbors.iter().enumerate() {
                // Near prefetch: the neighbour handled on the next iteration.
                if i + 1 < neighbor_count {
                    let upcoming = neighbors[i + 1] as usize;
                    if upcoming < num_vectors {
                        let head = vectors.as_ptr().wrapping_add(upcoming * dimension);
                        prefetch_read_data(head);
                        prefetch_read_data(head.wrapping_add(16));
                    }
                }
                // Far prefetch: the neighbour four positions ahead.
                if i + 4 < neighbor_count {
                    let distant = neighbors[i + 4] as usize;
                    if distant < num_vectors {
                        prefetch_read_data(vectors.as_ptr().wrapping_add(distant * dimension));
                    }
                }

                if !visited.insert(neighbor_id) {
                    continue;
                }
                let distance = dist_fn(query, get_vector(vectors, dimension, neighbor_id as usize));
                let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
                let admit = best.len() < ef || distance < bound;
                if admit {
                    frontier.push(MinCandidate {
                        id: neighbor_id,
                        distance,
                    });
                    best.push(MaxResult {
                        id: neighbor_id,
                        distance,
                    });
                    if best.len() > ef {
                        best.pop();
                    }
                }
            }
        }

        let mut output: Vec<(u32, f32)> = best.into_iter().map(|r| (r.id, r.distance)).collect();
        output.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        output
    })
}
/// Best-first search of one graph layer using an id-based distance callback.
///
/// `dist_fn` receives `query` and a node id rather than a vector slice;
/// `vectors` and `dimension` are only used here to size the visited set and to
/// compute prefetch addresses. Unvisited neighbours are scored in batches of up
/// to four through `flush_batch_custom`, which also bounds the result heap to
/// `ef`.
///
/// Returns `(id, distance)` pairs sorted by ascending distance.
///
/// # Panics
///
/// Panics if `dimension` is zero.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer_custom<F: Fn(&[f32], u32) -> f32>(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
    dist_fn: &F,
) -> Vec<(u32, f32)> {
    let num_vectors = vectors.len() / dimension;
    with_visited_set(num_vectors, ef * 2, |visited| {
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut best: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);

        // Seed both heaps with the entry point.
        let entry_distance = dist_fn(query, entry_point);
        frontier.push(MinCandidate {
            id: entry_point,
            distance: entry_distance,
        });
        best.push(MaxResult {
            id: entry_point,
            distance: entry_distance,
        });
        visited.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
            if best.len() >= ef && current.distance > bound {
                break;
            }

            let neighbors = layer.get_neighbors(current.id);
            let mut pending = [0u32; 4];
            let mut filled = 0usize;
            for &neighbor_id in neighbors.iter() {
                if !visited.insert(neighbor_id) {
                    continue;
                }
                pending[filled] = neighbor_id;
                filled += 1;

                // Prefetch hint for the neighbour's slot in `vectors`.
                let slot = neighbor_id as usize;
                if slot < num_vectors {
                    let head = vectors.as_ptr().wrapping_add(slot * dimension);
                    prefetch_read_data(head);
                    if dimension > 16 {
                        prefetch_read_data(head.wrapping_add(16));
                    }
                }

                if filled == pending.len() {
                    flush_batch_custom(
                        query,
                        &pending,
                        filled,
                        dist_fn,
                        &mut frontier,
                        &mut best,
                        ef,
                    );
                    filled = 0;
                }
            }
            // Score whatever is left of a partially filled batch.
            if filled != 0 {
                flush_batch_custom(
                    query,
                    &pending,
                    filled,
                    dist_fn,
                    &mut frontier,
                    &mut best,
                    ef,
                );
            }
        }

        let mut output: Vec<(u32, f32)> = best.into_iter().map(|r| (r.id, r.distance)).collect();
        output.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        output
    })
}
/// Best-first search of one graph layer where distances are evaluated per edge.
///
/// `dist_fn` is called as `dist_fn(parent_id, neighbor_id, slot)`, where `slot`
/// is the neighbour's position in the parent's neighbour list. The entry
/// point's distance is supplied by the caller as `entry_dist`, and
/// `num_vectors` sizes the visited set.
///
/// Returns `(id, distance)` pairs sorted by ascending distance.
#[cfg(feature = "ivf_rabitq")]
pub fn greedy_search_layer_edge_aware<F: Fn(u32, u32, usize) -> f32>(
    entry_point: u32,
    entry_dist: f32,
    layer: &crate::hnsw::graph::Layer,
    num_vectors: usize,
    ef: usize,
    dist_fn: &F,
) -> Vec<(u32, f32)> {
    with_visited_set(num_vectors, ef * 2, |visited| {
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut best: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);

        // Seed both heaps with the caller-provided entry distance.
        frontier.push(MinCandidate {
            id: entry_point,
            distance: entry_dist,
        });
        best.push(MaxResult {
            id: entry_point,
            distance: entry_dist,
        });
        visited.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
            if best.len() >= ef && current.distance > bound {
                break;
            }

            let neighbors = layer.get_neighbors(current.id);
            for (slot, &neighbor_id) in neighbors.iter().enumerate() {
                if !visited.insert(neighbor_id) {
                    continue;
                }
                let distance = dist_fn(current.id, neighbor_id, slot);
                let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
                let admit = best.len() < ef || distance < bound;
                if admit {
                    frontier.push(MinCandidate {
                        id: neighbor_id,
                        distance,
                    });
                    best.push(MaxResult {
                        id: neighbor_id,
                        distance,
                    });
                    if best.len() > ef {
                        best.pop();
                    }
                }
            }
        }

        let mut output: Vec<(u32, f32)> = best.into_iter().map(|r| (r.id, r.distance)).collect();
        output.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        output
    })
}
/// Best-first search of one graph layer with oracle-driven early termination.
///
/// Every computed distance (entry point included) is reported to an
/// `EarlyTerminationOracle` built from `k` and a clone of `config`. After each
/// candidate has been expanded, the search stops early when the oracle asks for
/// termination and at least `k` results are held; otherwise the usual
/// `ef`-bounded stopping rule applies.
///
/// Returns the `(id, distance)` pairs sorted by ascending distance, together
/// with the oracle's `num_evaluated()` count.
///
/// # Panics
///
/// Panics if `dimension` is zero, or if `entry_point` or a newly visited
/// neighbour id does not address a complete vector inside `vectors`.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer_adaptive(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
    k: usize,
    config: &crate::adaptive::AdaptiveConfig,
    dist_fn: fn(&[f32], &[f32]) -> f32,
) -> (Vec<(u32, f32)>, usize) {
    use crate::adaptive::EarlyTerminationOracle;
    let num_vectors = vectors.len() / dimension;
    with_visited_set(num_vectors, ef * 2, |visited| {
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut best: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);
        let mut oracle = EarlyTerminationOracle::new(k, config.clone());

        // Seed both heaps with the entry point and let the oracle see it.
        let entry_distance = dist_fn(query, get_vector(vectors, dimension, entry_point as usize));
        oracle.observe(entry_distance);
        frontier.push(MinCandidate {
            id: entry_point,
            distance: entry_distance,
        });
        best.push(MaxResult {
            id: entry_point,
            distance: entry_distance,
        });
        visited.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
            if best.len() >= ef && current.distance > bound {
                break;
            }

            let neighbors = layer.get_neighbors(current.id);
            let neighbor_count = neighbors.len();
            for (i, &neighbor_id) in neighbors.iter().enumerate() {
                // Near prefetch: the neighbour handled on the next iteration.
                if i + 1 < neighbor_count {
                    let upcoming = neighbors[i + 1] as usize;
                    if upcoming < num_vectors {
                        let head = vectors.as_ptr().wrapping_add(upcoming * dimension);
                        prefetch_read_data(head);
                        prefetch_read_data(head.wrapping_add(16));
                    }
                }
                // Far prefetch: the neighbour four positions ahead.
                if i + 4 < neighbor_count {
                    let distant = neighbors[i + 4] as usize;
                    if distant < num_vectors {
                        prefetch_read_data(vectors.as_ptr().wrapping_add(distant * dimension));
                    }
                }

                if !visited.insert(neighbor_id) {
                    continue;
                }
                let distance = dist_fn(query, get_vector(vectors, dimension, neighbor_id as usize));
                oracle.observe(distance);
                let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
                let admit = best.len() < ef || distance < bound;
                if admit {
                    frontier.push(MinCandidate {
                        id: neighbor_id,
                        distance,
                    });
                    best.push(MaxResult {
                        id: neighbor_id,
                        distance,
                    });
                    if best.len() > ef {
                        best.pop();
                    }
                }
            }

            // The oracle is consulted once per expanded candidate.
            if oracle.should_terminate() && best.len() >= k {
                break;
            }
        }

        let num_evaluated = oracle.num_evaluated();
        let mut output: Vec<(u32, f32)> = best.into_iter().map(|r| (r.id, r.distance)).collect();
        output.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        (output, num_evaluated)
    })
}
/// Returns the `idx`-th vector of a flat, row-major buffer in which every
/// vector occupies `dimension` consecutive floats.
///
/// # Panics
///
/// Panics if the requested range does not lie entirely within `vectors`.
#[inline]
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    let offset = idx * dimension;
    let row = offset..offset + dimension;
    &vectors[row]
}
/// Best-first search of one graph layer that uses a probabilistic routing test
/// to skip full distance computations.
///
/// Once `ef` results are held, a newly visited neighbour is only scored with
/// `dist_fn` when `prt.should_compute_full_distance` says so; while fewer than
/// `ef` results are held the test is not consulted. Each full evaluation of a
/// neighbour is reported to `tfb` as a true positive when the neighbour is
/// admitted to the results and as a false positive otherwise.
///
/// Returns the `(id, distance)` pairs sorted by ascending distance, together
/// with the number of `dist_fn` calls made (entry point included).
///
/// # Panics
///
/// Panics if `dimension` is zero, or if `entry_point` or a fully evaluated
/// neighbour id does not address a complete vector inside `vectors`.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer_prt(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
    dist_fn: fn(&[f32], &[f32]) -> f32,
    prt: &crate::prt::ProbabilisticRoutingTest,
    query_proj: &[f32],
    tfb: &mut crate::prt::TestFeedbackBuffer,
) -> (Vec<(u32, f32)>, usize) {
    let num_vectors = vectors.len() / dimension;
    let mut full_dist_count: usize = 0;
    let found = with_visited_set(num_vectors, ef * 2, |visited| {
        let mut frontier: BinaryHeap<MinCandidate> = BinaryHeap::with_capacity(ef * 2);
        let mut best: BinaryHeap<MaxResult> = BinaryHeap::with_capacity(ef + 1);

        // Seed both heaps with the entry point.
        let entry_distance = dist_fn(query, get_vector(vectors, dimension, entry_point as usize));
        full_dist_count += 1;
        frontier.push(MinCandidate {
            id: entry_point,
            distance: entry_distance,
        });
        best.push(MaxResult {
            id: entry_point,
            distance: entry_distance,
        });
        visited.insert(entry_point);

        while let Some(current) = frontier.pop() {
            let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
            if best.len() >= ef && current.distance > bound {
                break;
            }

            let neighbors = layer.get_neighbors(current.id);
            let neighbor_count = neighbors.len();
            for (i, &neighbor_id) in neighbors.iter().enumerate() {
                // Prefetch hint for the neighbour handled on the next iteration.
                if i + 1 < neighbor_count {
                    let upcoming = neighbors[i + 1] as usize;
                    if upcoming < num_vectors {
                        prefetch_read_data(vectors.as_ptr().wrapping_add(upcoming * dimension));
                    }
                }

                if !visited.insert(neighbor_id) {
                    continue;
                }

                let bound = best.peek().map_or(f32::INFINITY, |r| r.distance);
                // The routing test is consulted only once the result heap is full.
                if best.len() >= ef
                    && !prt.should_compute_full_distance(query_proj, neighbor_id, bound, tfb)
                {
                    continue;
                }

                let distance = dist_fn(query, get_vector(vectors, dimension, neighbor_id as usize));
                full_dist_count += 1;

                let admit = best.len() < ef || distance < bound;
                if !admit {
                    tfb.record_false_positive();
                    continue;
                }
                tfb.record_true_positive();
                frontier.push(MinCandidate {
                    id: neighbor_id,
                    distance,
                });
                best.push(MaxResult {
                    id: neighbor_id,
                    distance,
                });
                if best.len() > ef {
                    best.pop();
                }
            }
        }

        let mut output: Vec<(u32, f32)> = best.into_iter().map(|r| (r.id, r.distance)).collect();
        output.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        output
    });
    (found, full_dist_count)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A heap of `MinCandidate`s must hand back the closest entry first.
    #[test]
    fn test_candidate_ordering() {
        let mut heap = BinaryHeap::new();
        for &(id, distance) in [(0u32, 0.5f32), (1, 0.1), (2, 0.3)].iter() {
            heap.push(MinCandidate { id, distance });
        }
        for &expected in [0.1f32, 0.3, 0.5].iter() {
            assert_eq!(heap.pop().unwrap().distance, expected);
        }
    }

    /// Dense set: the first insert reports "new", the second reports "seen".
    #[test]
    fn test_visited_set_dense() {
        let mut set = VisitedSet::new(100, 10);
        assert!(!set.contains(5));
        assert!(set.insert(5));
        assert!(set.contains(5));
        assert!(!set.insert(5));
    }

    /// Dense set: `clear` forgets earlier inserts.
    #[test]
    fn test_visited_set_dense_clear() {
        let mut set = VisitedSet::new(100, 10);
        assert!(set.insert(5));
        assert!(set.contains(5));
        set.clear();
        assert!(!set.contains(5));
        assert!(set.insert(5));
    }

    /// Dense set: clearing at the maximum generation must not resurrect marks.
    #[test]
    fn test_visited_set_dense_generation_overflow() {
        let mut set = VisitedSet::new(100, 10);
        // Force the counter to its maximum so the next `clear` has to wrap.
        if let VisitedSet::Dense { generation, .. } = &mut set {
            *generation = u8::MAX;
        }
        assert!(set.insert(5));
        assert!(set.contains(5));
        set.clear();
        assert!(!set.contains(5));
        assert!(set.insert(5));
        assert!(set.contains(5));
    }

    /// Sparse set: the first insert reports "new", the second reports "seen".
    #[test]
    fn test_visited_set_sparse() {
        let mut set = VisitedSet::new(DENSE_VISITED_THRESHOLD + 1, 10);
        assert!(!set.contains(42));
        assert!(set.insert(42));
        assert!(set.contains(42));
        assert!(!set.insert(42));
    }

    /// Sparse set: `clear` forgets earlier inserts.
    #[test]
    fn test_visited_set_sparse_clear() {
        let mut set = VisitedSet::new(DENSE_VISITED_THRESHOLD + 1, 10);
        assert!(set.insert(42));
        set.clear();
        assert!(!set.contains(42));
        assert!(set.insert(42));
    }
}