use crate::simd;
use crate::sng::graph::SNGIndex;
use crate::RetrieveError;
use std::collections::BinaryHeap;
/// A graph node waiting to be expanded during search.
///
/// Ordering is reversed so that `BinaryHeap` (a max-heap) pops the candidate
/// with the *smallest* distance first. Ties on distance are broken by the
/// smallest `id`, which makes the pop order deterministic.
///
/// Equality is defined through [`Ord::cmp`] rather than derived. This keeps
/// `PartialEq`, `Eq` and `Ord` mutually consistent, as the `Ord` contract
/// requires. A derived `PartialEq` would make a NaN-distance candidate unequal
/// to itself, even though `total_cmp` orders it `Equal`.
#[derive(Clone, Debug)]
struct SearchCandidate {
    /// Node / vector index inside the searched index.
    id: u32,
    /// Distance from the query; smaller means closer.
    distance: f32,
}

impl PartialEq for SearchCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for SearchCandidate {}

impl PartialOrd for SearchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Both keys are compared "backwards" (other vs self): the heap's greatest
        // element is the one with the smallest distance, then the smallest id.
        // `total_cmp` gives a total order over all f32 values, including NaN.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| other.id.cmp(&self.id))
    }
}
/// Approximate k-nearest-neighbour search over an [`SNGIndex`] graph.
///
/// The search runs in two phases:
/// 1. **Entry point.** The first `min(100, num_vectors)` vectors are scanned
///    linearly, and the closest one becomes the starting node.
/// 2. **Best-first traversal.** Nodes are popped from a min-distance heap.
///    Each popped node is appended to the result set, and its not-yet-visited
///    neighbours are scored and pushed. Expansion stops once `k` nodes have
///    been popped, or after `10 * ceil(ln(num_vectors))` expansions. If the
///    expansion budget runs out first, the already-scored candidates still in
///    the heap fill the result up to `k`.
///
/// Distance is `1.0 - dot(query, v)`. This is cosine distance only when the
/// query and the stored vectors are unit-normalised.
/// NOTE(review): confirm that normalisation is enforced by the caller/index.
///
/// Returns at most `k` `(id, distance)` pairs, sorted by ascending distance.
/// An empty index or `k == 0` yields an empty vector. Neighbour ids outside
/// `0..num_vectors` are ignored rather than scored.
///
/// # Errors
///
/// No path currently returns `Err`; the `Result` is part of the public API.
/// NOTE(review): `query.len()` is not validated against `index.dimension`.
pub fn search_sng(
    index: &SNGIndex,
    query: &[f32],
    k: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
    let num_vectors = index.num_vectors;
    if num_vectors == 0 || k == 0 {
        return Ok(Vec::new());
    }

    let (entry_id, entry_dist) = select_entry_point(index, query);

    thread_local! {
        // Per-thread visited marks, kept alive across queries. A node counts as
        // visited in the current query iff its mark equals the current epoch,
        // so "clearing" the set is just bumping the epoch.
        static VISITED: std::cell::RefCell<(Vec<u8>, u8)> =
            const { std::cell::RefCell::new((Vec::new(), 1)) };
    }

    VISITED.with(|cell| {
        // `epoch_counter` (not `gen`): `gen` is a reserved keyword in Rust 2024.
        let (marks, epoch_counter) = &mut *cell.borrow_mut();
        if marks.len() < num_vectors {
            // New slots start at 0, which never equals a live epoch (>= 1).
            marks.resize(num_vectors, 0);
        }
        match epoch_counter.checked_add(1) {
            Some(next) => *epoch_counter = next,
            None => {
                // The epoch wrapped, so stale marks could now collide: reset all.
                marks.fill(0);
                *epoch_counter = 1;
            }
        }
        let epoch = *epoch_counter;

        // Marks `id` as visited and returns true if it was not visited yet.
        // Ids outside `0..num_vectors` name no stored vector (and may hit
        // stale marks left by a larger index on this thread), so they are
        // reported as already visited and never scored.
        let mut visit = |id: u32| -> bool {
            let idx = id as usize;
            if idx >= num_vectors || marks[idx] == epoch {
                return false;
            }
            marks[idx] = epoch;
            true
        };

        let mut candidates = BinaryHeap::new();
        let mut results = Vec::with_capacity(k.min(num_vectors));
        visit(entry_id);
        candidates.push(SearchCandidate {
            id: entry_id,
            distance: entry_dist,
        });

        // Expansion budget ~ 10 * ceil(ln n). For n == 1 this is 0, which is
        // fine: the entry point is then the whole index.
        let max_expansions = (num_vectors as f32).ln().ceil() as usize * 10;
        let mut expansions = 0;

        while let Some(candidate) = candidates.pop() {
            results.push((candidate.id, candidate.distance));
            if results.len() >= k || expansions >= max_expansions {
                break;
            }
            expansions += 1;

            if let Some(neighbors) = index.neighbors.get(candidate.id as usize) {
                for (i, &neighbor_id) in neighbors.iter().enumerate() {
                    // Warm the cache for the next neighbour's vector while this
                    // one is scored. The offset mirrors a contiguous
                    // `id * dimension` layout (presumably how `get_vector`
                    // slices; verify). `wrapping_*` is fine: it is only a hint.
                    if let Some(&next_id) = neighbors.get(i + 1) {
                        let offset = (next_id as usize).wrapping_mul(index.dimension);
                        prefetch_for_read(index.vectors.as_ptr().wrapping_add(offset));
                    }
                    if !visit(neighbor_id) {
                        continue;
                    }
                    let neighbor_vec = index.get_vector(neighbor_id as usize);
                    let distance = 1.0 - simd::dot(query, neighbor_vec);
                    candidates.push(SearchCandidate {
                        id: neighbor_id,
                        distance,
                    });
                }
            }
        }

        // If the budget ran out before `k` nodes were popped, the heap still
        // holds candidates whose distances are already known. Use the closest
        // of them instead of returning a short result.
        while results.len() < k {
            match candidates.pop() {
                Some(c) => results.push((c.id, c.distance)),
                None => break,
            }
        }

        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        results.truncate(k);
        Ok(results)
    })
}

/// Linearly scans the first `min(100, num_vectors)` vectors and returns the
/// closest as `(id, distance)`. Returns `(0, INFINITY)` when no distance
/// compares below infinity (e.g. every distance is NaN).
fn select_entry_point(index: &SNGIndex, query: &[f32]) -> (u32, f32) {
    const ENTRY_SCAN_LIMIT: usize = 100;
    let mut best = (0u32, f32::INFINITY);
    for i in 0..index.num_vectors.min(ENTRY_SCAN_LIMIT) {
        let dist = 1.0 - simd::dot(query, index.get_vector(i));
        if dist < best.1 {
            best = (i as u32, dist);
        }
    }
    best
}

/// Issues a best-effort L1 read-prefetch hint for `ptr`. This is a no-op on
/// targets other than x86_64 and aarch64.
fn prefetch_for_read<T>(ptr: *const T) {
    let ptr = ptr.cast::<u8>();
    #[cfg(target_arch = "aarch64")]
    // SAFETY: `prfm` is only a hint. It does not architecturally access memory
    // and never faults, so any pointer value (even dangling) is acceptable.
    unsafe {
        std::arch::asm!(
            "prfm pldl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, preserves_flags)
        );
    }
    #[cfg(target_arch = "x86_64")]
    // SAFETY: SSE is part of the x86_64 baseline, and `prefetcht0` is a hint
    // that never faults, regardless of the address.
    unsafe {
        std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    let _ = ptr;
}