use super::graph::*;
use super::simd::{
VisitedList, inner_product_distance, inner_product_distance_batch_4, l2_distance,
l2_distance_batch_4,
};
/// Query-time parameters for HNSW search.
///
/// Of these fields, the `search_hnsw*` functions only read `ef_search`.
/// NOTE(review): the remaining knobs are presumably consumed elsewhere
/// (e.g. by a higher-level index) — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct SearchParams {
    /// Size of the level-0 candidate/result list. The search actually uses
    /// `max(ef_search, top_k)`. Defaults to 64.
    pub ef_search: usize,
    /// Defaults to 1. NOTE(review): not read by the `search_hnsw*` functions.
    pub beam_size: usize,
    /// Defaults to 0.0. NOTE(review): not read by the `search_hnsw*` functions.
    pub prune_ratio: f64,
    /// Defaults to `true`. NOTE(review): not read by the `search_hnsw*` functions.
    pub recompute_embeddings: bool,
    /// Defaults to [`PruningStrategy::Global`]. NOTE(review): not read by the
    /// `search_hnsw*` functions.
    pub pruning_strategy: PruningStrategy,
    /// Defaults to 0. NOTE(review): not read by the `search_hnsw*` functions.
    pub batch_size: usize,
    /// Defaults to `true`. NOTE(review): not read by the `search_hnsw*` functions.
    pub check_relative_distance: bool,
}

/// Selects a pruning mode.
///
/// NOTE(review): no variant is interpreted by the search code alongside this
/// type; the semantics of each mode live elsewhere — confirm with the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PruningStrategy {
    Global,
    Local,
    Proportional,
}

impl Default for SearchParams {
    /// Builds the default parameter set (`ef_search = 64`, global pruning,
    /// recomputation and relative-distance checks enabled).
    fn default() -> Self {
        let ef_search = 64;
        let beam_size = 1;
        let prune_ratio = 0.0;
        let recompute_embeddings = true;
        let pruning_strategy = PruningStrategy::Global;
        let batch_size = 0;
        let check_relative_distance = true;
        Self {
            ef_search,
            beam_size,
            prune_ratio,
            recompute_embeddings,
            pruning_strategy,
            batch_size,
            check_relative_distance,
        }
    }
}
/// Binary min-heap of `(distance, id)` pairs kept as two parallel arrays;
/// ordering comparisons read only `dis`.
///
/// Invariant: `len <= dis.len() == ids.len()`. Slots at or beyond `len` are
/// stale leftovers that `push` overwrites, which lets `clear` keep the
/// allocation for reuse across searches.
pub(crate) struct FlatMinHeap {
    /// Distances; `dis[..len]` satisfies the min-heap property.
    dis: Vec<f32>,
    /// Ids, index-aligned with `dis`.
    ids: Vec<u32>,
    /// Number of live entries.
    len: usize,
}
impl FlatMinHeap {
    /// Creates an empty heap with room for `capacity` entries before reallocating.
    #[inline]
    pub(crate) fn new(capacity: usize) -> Self {
        Self {
            dis: Vec::with_capacity(capacity),
            ids: Vec::with_capacity(capacity),
            len: 0,
        }
    }
    /// Removes all entries in O(1); the backing storage is kept for reuse.
    #[inline(always)]
    pub(crate) fn clear(&mut self) {
        self.len = 0;
    }
    /// Returns `true` if the heap holds no entries.
    #[inline(always)]
    pub(crate) fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Returns the entry with the smallest distance without removing it.
    ///
    /// # Panics
    /// Panics if the heap is empty.
    #[inline(always)]
    pub(crate) fn peek(&self) -> (f32, u32) {
        // A real assert, not `debug_assert!`: this is a safe fn, so an empty
        // heap must never reach the unchecked reads below in release builds.
        assert!(self.len > 0, "FlatMinHeap::peek called on an empty heap");
        // SAFETY: 0 < len <= dis.len() == ids.len(), so index 0 is in bounds.
        unsafe { (*self.dis.get_unchecked(0), *self.ids.get_unchecked(0)) }
    }
    /// Inserts `(dis, id)`, reusing a stale slot when one is available.
    #[inline]
    pub(crate) fn push(&mut self, dis: f32, id: u32) {
        let pos = self.len;
        if pos == self.dis.len() {
            self.dis.push(dis);
            self.ids.push(id);
        } else {
            // SAFETY: pos <= dis.len() by the invariant and pos != dis.len()
            // here, so pos < dis.len() == ids.len().
            unsafe {
                *self.dis.get_unchecked_mut(pos) = dis;
                *self.ids.get_unchecked_mut(pos) = id;
            }
        }
        self.len += 1;
        self.sift_up(pos);
    }
    /// Removes and returns the entry with the smallest distance.
    ///
    /// # Panics
    /// Panics if the heap is empty.
    #[inline]
    pub(crate) fn pop(&mut self) -> (f32, u32) {
        // Without this, `len -= 1` would wrap in release builds and the
        // unchecked accesses below would read far out of bounds.
        assert!(self.len > 0, "FlatMinHeap::pop called on an empty heap");
        // SAFETY: index 0 and the new `len` (the old last index) are both
        // below the old `len`, which is <= dis.len() == ids.len().
        unsafe {
            let dis = *self.dis.get_unchecked(0);
            let id = *self.ids.get_unchecked(0);
            self.len -= 1;
            if self.len > 0 {
                *self.dis.get_unchecked_mut(0) = *self.dis.get_unchecked(self.len);
                *self.ids.get_unchecked_mut(0) = *self.ids.get_unchecked(self.len);
                self.sift_down(0);
            }
            (dis, id)
        }
    }
    /// Moves the entry at `pos` toward the root while it is smaller than its
    /// parent. Requires `pos < len`.
    #[inline]
    fn sift_up(&mut self, mut pos: usize) {
        debug_assert!(pos < self.len);
        // SAFETY: pos < len <= dis.len(), and every parent index is < pos.
        unsafe {
            let d = *self.dis.get_unchecked(pos);
            let id = *self.ids.get_unchecked(pos);
            while pos > 0 {
                let parent = (pos - 1) >> 1;
                let pd = *self.dis.get_unchecked(parent);
                if d < pd {
                    *self.dis.get_unchecked_mut(pos) = pd;
                    *self.ids.get_unchecked_mut(pos) = *self.ids.get_unchecked(parent);
                    pos = parent;
                } else {
                    break;
                }
            }
            *self.dis.get_unchecked_mut(pos) = d;
            *self.ids.get_unchecked_mut(pos) = id;
        }
    }
    /// Moves the entry at `pos` toward the leaves while a child is smaller.
    /// Requires `pos < len`.
    #[inline]
    fn sift_down(&mut self, mut pos: usize) {
        debug_assert!(pos < self.len);
        let n = self.len;
        // SAFETY: pos < n <= dis.len(); children are only read after being
        // checked against n.
        unsafe {
            let d = *self.dis.get_unchecked(pos);
            let id = *self.ids.get_unchecked(pos);
            loop {
                let left = 2 * pos + 1;
                if left >= n {
                    break;
                }
                let right = left + 1;
                let mut smallest = left;
                if right < n && *self.dis.get_unchecked(right) < *self.dis.get_unchecked(left) {
                    smallest = right;
                }
                let sd = *self.dis.get_unchecked(smallest);
                if sd < d {
                    *self.dis.get_unchecked_mut(pos) = sd;
                    *self.ids.get_unchecked_mut(pos) = *self.ids.get_unchecked(smallest);
                    pos = smallest;
                } else {
                    break;
                }
            }
            *self.dis.get_unchecked_mut(pos) = d;
            *self.ids.get_unchecked_mut(pos) = id;
        }
    }
}
/// Binary max-heap of `(distance, id)` pairs kept as two parallel arrays;
/// ordering comparisons read only `dis`.
///
/// Invariant: `len <= dis.len() == ids.len()`. Slots at or beyond `len` are
/// stale leftovers that `push` overwrites, which lets `clear` keep the
/// allocation for reuse across searches.
pub(crate) struct FlatMaxHeap {
    /// Distances; `dis[..len]` satisfies the max-heap property (except after
    /// `drain_sorted`, see there).
    dis: Vec<f32>,
    /// Ids, index-aligned with `dis`.
    ids: Vec<u32>,
    /// Number of live entries.
    len: usize,
}
impl FlatMaxHeap {
    /// Creates an empty heap with room for `capacity` entries before reallocating.
    #[inline]
    pub(crate) fn new(capacity: usize) -> Self {
        Self {
            dis: Vec::with_capacity(capacity),
            ids: Vec::with_capacity(capacity),
            len: 0,
        }
    }
    /// Removes all entries in O(1); the backing storage is kept for reuse.
    #[inline(always)]
    pub(crate) fn clear(&mut self) {
        self.len = 0;
    }
    /// Number of live entries.
    #[inline(always)]
    pub(crate) fn len(&self) -> usize {
        self.len
    }
    /// Returns the largest distance currently held.
    ///
    /// # Panics
    /// Panics if the heap is empty.
    #[inline(always)]
    pub(crate) fn peek_max_dis(&self) -> f32 {
        // A real assert: this is a safe fn, so release builds must not reach
        // the unchecked read on an empty heap.
        assert!(self.len > 0, "FlatMaxHeap::peek_max_dis called on an empty heap");
        // SAFETY: 0 < len <= dis.len(), so index 0 is in bounds.
        unsafe { *self.dis.get_unchecked(0) }
    }
    /// Removes and returns the entry with the largest distance.
    ///
    /// # Panics
    /// Panics if the heap is empty.
    #[inline]
    pub(crate) fn pop_max(&mut self) -> (f32, u32) {
        // Without this, `len -= 1` would wrap in release builds and the
        // unchecked accesses below would read far out of bounds.
        assert!(self.len > 0, "FlatMaxHeap::pop_max called on an empty heap");
        // SAFETY: index 0 and the new `len` (the old last index) are both
        // below the old `len`, which is <= dis.len() == ids.len().
        unsafe {
            let dis = *self.dis.get_unchecked(0);
            let id = *self.ids.get_unchecked(0);
            self.len -= 1;
            if self.len > 0 {
                *self.dis.get_unchecked_mut(0) = *self.dis.get_unchecked(self.len);
                *self.ids.get_unchecked_mut(0) = *self.ids.get_unchecked(self.len);
                self.sift_down(0);
            }
            (dis, id)
        }
    }
    /// Inserts `(dis, id)`, reusing a stale slot when one is available.
    #[inline]
    pub(crate) fn push(&mut self, dis: f32, id: u32) {
        let pos = self.len;
        if pos == self.dis.len() {
            self.dis.push(dis);
            self.ids.push(id);
        } else {
            // SAFETY: pos <= dis.len() by the invariant and pos != dis.len()
            // here, so pos < dis.len() == ids.len().
            unsafe {
                *self.dis.get_unchecked_mut(pos) = dis;
                *self.ids.get_unchecked_mut(pos) = id;
            }
        }
        self.len += 1;
        self.sift_up(pos);
    }
    /// Overwrites the current maximum with `(dis, id)` and restores heap order.
    ///
    /// # Panics
    /// Panics if the heap is empty.
    #[inline]
    pub(crate) fn replace_max(&mut self, dis: f32, id: u32) {
        assert!(self.len > 0, "FlatMaxHeap::replace_max called on an empty heap");
        self.dis[0] = dis;
        self.ids[0] = id;
        self.sift_down(0);
    }
    /// Heap-sorts the live entries in place and returns `(ids, distances)` in
    /// ascending distance order.
    ///
    /// Afterwards `len()` is unchanged, but the storage is sorted ascending,
    /// which is not a valid max-heap: call `clear` before pushing or popping again.
    pub(crate) fn drain_sorted(&mut self) -> (&[u32], &[f32]) {
        let n = self.len;
        // In-place heapsort: repeatedly move the max to the end of the
        // shrinking heap region.
        while self.len > 1 {
            self.len -= 1;
            self.dis.swap(0, self.len);
            self.ids.swap(0, self.len);
            self.sift_down(0);
        }
        self.len = n;
        (&self.ids[..n], &self.dis[..n])
    }
    /// Moves the entry at `pos` toward the root while it is larger than its
    /// parent. Requires `pos < len`.
    #[inline]
    fn sift_up(&mut self, mut pos: usize) {
        debug_assert!(pos < self.len);
        // SAFETY: pos < len <= dis.len(), and every parent index is < pos.
        unsafe {
            let d = *self.dis.get_unchecked(pos);
            let id = *self.ids.get_unchecked(pos);
            while pos > 0 {
                let parent = (pos - 1) >> 1;
                let pd = *self.dis.get_unchecked(parent);
                if d > pd {
                    *self.dis.get_unchecked_mut(pos) = pd;
                    *self.ids.get_unchecked_mut(pos) = *self.ids.get_unchecked(parent);
                    pos = parent;
                } else {
                    break;
                }
            }
            *self.dis.get_unchecked_mut(pos) = d;
            *self.ids.get_unchecked_mut(pos) = id;
        }
    }
    /// Moves the entry at `pos` toward the leaves while a child is larger.
    /// Requires `pos < len`.
    #[inline]
    fn sift_down(&mut self, mut pos: usize) {
        debug_assert!(pos < self.len);
        let n = self.len;
        // SAFETY: pos < n <= dis.len(); children are only read after being
        // checked against n.
        unsafe {
            let d = *self.dis.get_unchecked(pos);
            let id = *self.ids.get_unchecked(pos);
            loop {
                let left = 2 * pos + 1;
                if left >= n {
                    break;
                }
                let right = left + 1;
                let mut largest = left;
                if right < n && *self.dis.get_unchecked(right) > *self.dis.get_unchecked(left) {
                    largest = right;
                }
                let ld = *self.dis.get_unchecked(largest);
                if ld > d {
                    *self.dis.get_unchecked_mut(pos) = ld;
                    *self.ids.get_unchecked_mut(pos) = *self.ids.get_unchecked(largest);
                    pos = largest;
                } else {
                    break;
                }
            }
            *self.dis.get_unchecked_mut(pos) = d;
            *self.ids.get_unchecked_mut(pos) = id;
        }
    }
}
/// Nested label and distance lists for a set of searches.
///
/// NOTE(review): not constructed by the search functions in this module;
/// presumably `labels[q]` / `distances[q]` hold the `(labels, distances)` pair
/// a `search_hnsw*` call returns for query `q` — confirm with the code that
/// builds it.
#[derive(Debug)]
pub struct SearchResults {
    /// Neighbor ids, one inner `Vec` per query (presumably).
    pub labels: Vec<Vec<usize>>,
    /// Distances, one inner `Vec` per query (presumably), aligned with `labels`.
    pub distances: Vec<Vec<f32>>,
}
/// Scratch state that can be reused across queries, so repeated searches do
/// not reallocate the visited set or the two heaps.
pub struct SearchBuffers {
    /// Visited marks; must cover every node id of the graph being searched.
    visited: VisitedList,
    /// Min-heap of nodes still to expand, closest first.
    candidates: FlatMinHeap,
    /// Max-heap of the best results found so far, worst on top.
    results: FlatMaxHeap,
}
impl SearchBuffers {
    /// Allocates buffers for searching a graph with `ntotal` nodes.
    ///
    /// The heaps start without reserved capacity and grow on first use.
    pub fn new(ntotal: usize) -> Self {
        let visited = VisitedList::new(ntotal);
        let candidates = FlatMinHeap::new(0);
        let results = FlatMaxHeap::new(0);
        Self {
            visited,
            candidates,
            results,
        }
    }
}
/// Finds the `top_k` nearest neighbors of `query` using the stored row-major
/// `vectors` (`graph.ntotal` rows of `graph.dimensions` floats).
///
/// Allocates fresh scratch buffers for this single call; use
/// [`search_hnsw_buf`] to reuse them across queries. Returns
/// `(labels, distances)` sorted by ascending distance.
pub fn search_hnsw(
    graph: &HnswGraph,
    query: &[f32],
    top_k: usize,
    vectors: &[f32],
    params: &SearchParams,
) -> (Vec<usize>, Vec<f32>) {
    search_hnsw_buf(
        graph,
        query,
        top_k,
        vectors,
        params,
        &mut SearchBuffers::new(graph.ntotal),
    )
}
/// Like [`search_hnsw`], but reuses the caller-provided `buffers`.
///
/// Picks the distance kernels from `graph.config.distance_metric`: `L2` uses
/// the L2 kernels, and every other metric uses the inner-product kernels.
pub fn search_hnsw_buf(
    graph: &HnswGraph,
    query: &[f32],
    top_k: usize,
    vectors: &[f32],
    params: &SearchParams,
    buffers: &mut SearchBuffers,
) -> (Vec<usize>, Vec<f32>) {
    let use_l2 = matches!(
        graph.config.distance_metric,
        crate::index::DistanceMetric::L2
    );
    if use_l2 {
        search_hnsw_inner(
            graph,
            query,
            top_k,
            vectors,
            params,
            buffers,
            l2_distance,
            l2_distance_batch_4,
        )
    } else {
        search_hnsw_inner(
            graph,
            query,
            top_k,
            vectors,
            params,
            buffers,
            inner_product_distance,
            inner_product_distance_batch_4,
        )
    }
}
/// Returns row `id` of the row-major matrix `vectors`, whose rows are `dim`
/// floats wide.
///
/// # Safety
/// The caller must guarantee `id * dim + dim <= vectors.len()`; this is only
/// checked in debug builds.
#[inline(always)]
unsafe fn get_vec(vectors: &[f32], id: usize, dim: usize) -> &[f32] {
    let start = id * dim;
    let end = start + dim;
    debug_assert!(end <= vectors.len());
    // SAFETY: the caller guarantees `start..end` lies within `vectors`.
    unsafe { vectors.get_unchecked(start..end) }
}
#[allow(clippy::too_many_arguments)]
fn search_hnsw_inner<D, B>(
graph: &HnswGraph,
query: &[f32],
top_k: usize,
vectors: &[f32],
params: &SearchParams,
buffers: &mut SearchBuffers,
dist_fn: D,
dist_batch_4: B,
) -> (Vec<usize>, Vec<f32>)
where
D: Fn(&[f32], &[f32]) -> f32,
B: Fn(&[f32], &[f32], &[f32], &[f32], &[f32]) -> [f32; 4],
{
let d = graph.dimensions;
let ef = params.ef_search.max(top_k);
let SearchBuffers {
visited,
candidates,
results,
} = buffers;
assert!(vectors.len() >= graph.ntotal * d);
assert!(visited.len() >= graph.ntotal);
let mut curr = graph.entry_point as usize;
let mut d_curr = unsafe { dist_fn(query, get_vec(vectors, curr, d)) };
for level in (1..=graph.max_level as usize).rev() {
loop {
let mut changed = false;
let neighbors = graph.get_neighbors(curr, level);
for &nb in neighbors {
if nb < 0 {
break;
}
let nb = nb as usize;
let d_nb = unsafe { dist_fn(query, get_vec(vectors, nb, d)) };
if d_nb < d_curr {
curr = nb;
d_curr = d_nb;
changed = true;
}
}
if !changed {
break;
}
}
}
candidates.clear();
results.clear();
visited.reset();
let d_entry = unsafe { dist_fn(query, get_vec(vectors, curr, d)) };
candidates.push(d_entry, curr as u32);
results.push(d_entry, curr as u32);
visited.set(curr);
let mut worst_dist = d_entry;
let mut saved: [u32; 4] = [0; 4];
while !candidates.is_empty() {
let (cand_dist, cand_id) = candidates.pop();
if results.len() >= ef && cand_dist > worst_dist {
break;
}
let neighbors = graph.get_neighbors(cand_id as usize, 0);
for &nb in neighbors {
if nb < 0 {
break;
}
visited.prefetch_l2(nb as usize);
}
let mut counter = 0;
for &nb in neighbors {
if nb < 0 {
break;
}
if !visited.check_and_set(nb as usize) {
continue;
}
saved[counter] = nb as u32;
counter += 1;
unsafe {
let vptr = vectors.as_ptr().add(nb as usize * d) as *const u8;
#[cfg(target_arch = "aarch64")]
{
std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) vptr, options(nostack, preserves_flags));
std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) vptr.add(64), options(nostack, preserves_flags));
}
#[cfg(target_arch = "x86_64")]
{
std::arch::x86_64::_mm_prefetch(
vptr as *const i8,
std::arch::x86_64::_MM_HINT_T0,
);
std::arch::x86_64::_mm_prefetch(
vptr.add(64) as *const i8,
std::arch::x86_64::_MM_HINT_T0,
);
}
}
if counter == 4 {
let dists = unsafe {
dist_batch_4(
query,
get_vec(vectors, saved[0] as usize, d),
get_vec(vectors, saved[1] as usize, d),
get_vec(vectors, saved[2] as usize, d),
get_vec(vectors, saved[3] as usize, d),
)
};
for k in 0..4 {
let nb_id = saved[k];
let d_nb = dists[k];
if results.len() < ef {
candidates.push(d_nb, nb_id);
results.push(d_nb, nb_id);
if results.len() == ef {
worst_dist = results.peek_max_dis();
}
} else if d_nb < worst_dist {
candidates.push(d_nb, nb_id);
results.replace_max(d_nb, nb_id);
worst_dist = results.peek_max_dis();
}
}
counter = 0;
}
}
for &nb_id in &saved[..counter] {
let d_nb = unsafe { dist_fn(query, get_vec(vectors, nb_id as usize, d)) };
if results.len() < ef {
candidates.push(d_nb, nb_id);
results.push(d_nb, nb_id);
if results.len() == ef {
worst_dist = results.peek_max_dis();
}
} else if d_nb < worst_dist {
candidates.push(d_nb, nb_id);
results.replace_max(d_nb, nb_id);
worst_dist = results.peek_max_dis();
}
}
}
let (sorted_ids, sorted_dis) = results.drain_sorted();
let n = top_k.min(sorted_ids.len());
let labels: Vec<usize> = sorted_ids[..n].iter().map(|&id| id as usize).collect();
let distances: Vec<f32> = sorted_dis[..n].to_vec();
(labels, distances)
}
/// Finds the `top_k` nearest neighbors of `query` without stored vectors:
/// every distance comes from `compute_distance`.
///
/// `compute_distance(ids, query, out)` must write the distance from `query` to
/// node `ids[i]` into `out[i]`; it is always called with
/// `out.len() == ids.len()`. Allocates fresh scratch buffers for this single
/// call; use [`search_hnsw_recompute_buf`] to reuse them.
pub fn search_hnsw_recompute<F>(
    graph: &HnswGraph,
    query: &[f32],
    top_k: usize,
    params: &SearchParams,
    compute_distance: F,
) -> (Vec<usize>, Vec<f32>)
where
    F: FnMut(&[usize], &[f32], &mut [f32]),
{
    search_hnsw_recompute_buf(
        graph,
        query,
        top_k,
        params,
        &mut SearchBuffers::new(graph.ntotal),
        compute_distance,
    )
}
/// Like [`search_hnsw_recompute`], but reuses the caller-provided `buffers`.
///
/// Phase 1 greedily descends from `graph.entry_point` through levels
/// `max_level..=1`, each hop moving to the *closest* neighbor that improves
/// on the current node. Phase 2 runs a best-first search on level 0 with a
/// result list of size `ef = max(params.ef_search, top_k)`, scoring all newly
/// visited neighbors of an expanded node in one `compute_distance` call.
/// Negative neighbor ids are skipped.
///
/// Returns at most `top_k` `(labels, distances)`, sorted by ascending distance.
/// An empty graph yields two empty vectors without calling `compute_distance`.
///
/// # Panics
/// Panics if `buffers` cover fewer than `graph.ntotal` nodes.
pub fn search_hnsw_recompute_buf<F>(
    graph: &HnswGraph,
    query: &[f32],
    top_k: usize,
    params: &SearchParams,
    buffers: &mut SearchBuffers,
    mut compute_distance: F,
) -> (Vec<usize>, Vec<f32>)
where
    F: FnMut(&[usize], &[f32], &mut [f32]),
{
    // Nothing to search, and no valid entry node to score.
    if graph.ntotal == 0 {
        return (Vec::new(), Vec::new());
    }
    let ef = params.ef_search.max(top_k);
    let max_neighbors = graph.neighbors_at_level(0);
    let SearchBuffers {
        visited,
        candidates,
        results,
    } = buffers;
    // Same precondition as the stored-vector path: the visited list must
    // cover every node id of this graph.
    assert!(visited.len() >= graph.ntotal);
    let mut node_buf: Vec<usize> = Vec::with_capacity(max_neighbors + 1);
    let mut dist_buf = vec![0.0f32; max_neighbors + 1];

    // Score the entry node once; its distance is carried through every hop
    // and into the level-0 search instead of being recomputed.
    let mut curr = graph.entry_point as usize;
    node_buf.push(curr);
    compute_distance(&node_buf, query, &mut dist_buf[..1]);
    let mut curr_dist = dist_buf[0];

    // Phase 1: greedy descent through the upper layers.
    for level in (1..=graph.max_level as usize).rev() {
        loop {
            let neighbors = graph.get_neighbors(curr, level);
            node_buf.clear();
            for &nb in neighbors {
                if nb >= 0 {
                    node_buf.push(nb as usize);
                }
            }
            if node_buf.is_empty() {
                break;
            }
            let n = node_buf.len();
            compute_distance(&node_buf, query, &mut dist_buf[..n]);
            // `curr_dist` is updated on every improvement, so the hop lands on
            // the closest neighbor rather than the last one that beats `curr`.
            let mut changed = false;
            for (&nb, &dist) in node_buf.iter().zip(&dist_buf[..n]) {
                if dist < curr_dist {
                    curr = nb;
                    curr_dist = dist;
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }
    }

    // Phase 2: best-first search on level 0.
    candidates.clear();
    results.clear();
    visited.reset();
    candidates.push(curr_dist, curr as u32);
    results.push(curr_dist, curr as u32);
    visited.set(curr);
    let mut worst_dist = curr_dist;
    while !candidates.is_empty() {
        let (cand_dist, cand_id) = candidates.pop();
        // The closest unexpanded candidate is worse than every kept result.
        if results.len() >= ef && cand_dist > worst_dist {
            break;
        }
        let neighbors = graph.get_neighbors(cand_id as usize, 0);
        node_buf.clear();
        for &nb in neighbors {
            if nb >= 0 {
                let nb = nb as usize;
                if visited.check_and_set(nb) {
                    node_buf.push(nb);
                }
            }
        }
        if node_buf.is_empty() {
            continue;
        }
        let n = node_buf.len();
        compute_distance(&node_buf, query, &mut dist_buf[..n]);
        for (&nb, &d_nb) in node_buf.iter().zip(&dist_buf[..n]) {
            let nb = nb as u32;
            if results.len() < ef {
                candidates.push(d_nb, nb);
                results.push(d_nb, nb);
                if results.len() == ef {
                    worst_dist = results.peek_max_dis();
                }
            } else if d_nb < worst_dist {
                candidates.push(d_nb, nb);
                results.replace_max(d_nb, nb);
                worst_dist = results.peek_max_dis();
            }
        }
    }

    let (sorted_ids, sorted_dis) = results.drain_sorted();
    let n = top_k.min(sorted_ids.len());
    let labels: Vec<usize> = sorted_ids[..n].iter().map(|&id| id as usize).collect();
    let distances: Vec<f32> = sorted_dis[..n].to_vec();
    (labels, distances)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::build::build_hnsw;
    use ndarray::Array2;
    /// Stored-vector search on four 3-d points. The query equals point 0, so
    /// point 0 must come first at distance 0; point 3 is the next closest.
    #[test]
    fn test_search_with_stored_vectors() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.0],
        )
        .unwrap();
        let config = HnswConfig {
            m: 4,
            ef_construction: 16,
            ef_search: 16,
            distance_metric: crate::index::DistanceMetric::L2,
            is_compact: false,
            is_recompute: false,
            seed: None,
        };
        let graph = build_hnsw(&data, &config, None).unwrap();
        let flat_vectors: Vec<f32> = data.iter().copied().collect();
        let query = vec![1.0, 0.0, 0.0];
        let params = SearchParams {
            ef_search: 16,
            ..Default::default()
        };
        let (labels, distances) = search_hnsw(&graph, &query, 2, &flat_vectors, &params);
        assert_eq!(labels.len(), 2);
        assert_eq!(distances.len(), 2);
        assert_eq!(labels, vec![0, 3]);
        assert!((distances[0] - 0.0).abs() < 1e-6);
        assert!(distances[0] <= distances[1]);
    }
    /// Same scenario as above, but every distance comes from a caller-supplied
    /// squared-L2 closure instead of stored vectors.
    #[test]
    fn test_search_with_recompute() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.0],
        )
        .unwrap();
        let config = HnswConfig {
            m: 4,
            ef_construction: 16,
            ef_search: 16,
            distance_metric: crate::index::DistanceMetric::L2,
            is_compact: false,
            is_recompute: true,
            seed: None,
        };
        let graph = build_hnsw(&data, &config, None).unwrap();
        let flat_vectors: Vec<f32> = data.iter().copied().collect();
        let d = 3;
        let query = vec![1.0, 0.0, 0.0];
        let params = SearchParams {
            ef_search: 16,
            recompute_embeddings: true,
            ..Default::default()
        };
        let (labels, distances) =
            search_hnsw_recompute(&graph, &query, 2, &params, |node_ids, q, out| {
                for (i, &id) in node_ids.iter().enumerate() {
                    let vec = &flat_vectors[id * d..(id + 1) * d];
                    out[i] = vec
                        .iter()
                        .zip(q.iter())
                        .map(|(a, b)| (a - b) * (a - b))
                        .sum();
                }
            });
        assert_eq!(labels.len(), 2);
        assert_eq!(labels, vec![0, 3]);
        assert!((distances[0] - 0.0).abs() < 1e-6);
        assert!((distances[1] - 0.5).abs() < 1e-6);
    }
}