use std::collections::{BinaryHeap, HashSet};
use nodedb_codec::vector_quant::codec::VectorCodec;
use crate::distance::scalar::l2_squared;
use crate::vamana::graph::VamanaGraph;
use crate::vamana::node_fetcher::NodeFetcher;
/// One hit produced by [`beam_search`] or [`rerank`].
#[derive(Clone, Debug, PartialEq)]
pub struct BeamSearchResult {
    /// External identifier of the matched node (as reported by the graph).
    pub id: u64,
    /// Distance from the query to this node; smaller means closer.
    ///
    /// `beam_search` fills this with the codec's asymmetric distance, while
    /// `rerank` overwrites it with the squared L2 distance on FP32 vectors.
    pub distance: f32,
}
/// A graph node index paired with its distance to the query.
///
/// The ordering is deliberately inverted: a candidate with a *smaller*
/// distance compares as *greater*, so a max-oriented `BinaryHeap<Candidate>`
/// pops the closest candidate first. Ties on distance are broken in favour of
/// the smaller index.
///
/// NaN distances are treated as farther than every real distance (and equal
/// to each other), which keeps the ordering a genuine total order and keeps
/// `PartialEq`/`Eq` consistent with `Ord`.
#[derive(Clone, Copy)]
struct Candidate {
    dist: f32,
    idx: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality is reflexive even for NaN distances.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose to invert the order (closer == greater).
        other
            .dist
            .partial_cmp(&self.dist)
            // `partial_cmp` is `None` only when at least one side is NaN; rank
            // NaN as the farthest distance instead of "equal to everything".
            .unwrap_or_else(|| other.dist.is_nan().cmp(&self.dist.is_nan()))
            .then(other.idx.cmp(&self.idx))
    }
}
/// Runs a best-first beam search over `graph` and returns up to `k` hits.
///
/// Starting from `graph.entry`, the closest unexpanded candidate is popped,
/// recorded in a bounded result set of size `max(l_search, k)`, and its
/// not-yet-visited neighbors are scored with
/// `codec.exact_asymmetric_distance` against `quantized`. The search stops
/// when the frontier is exhausted or when the closest remaining candidate is
/// farther than the worst retained result.
///
/// `fetcher` is only used to issue `prefetch_batch` for the neighbor indices
/// of each expansion before they are scored.
///
/// # Returns
///
/// Hits sorted by ascending distance (NaN distances last, ties broken by node
/// index), each carrying `graph.external_id` of the node. The result is empty
/// when `k == 0`, when the graph is empty, or when the entry point has no
/// entry in `quantized`. Neighbors whose index lies outside `quantized` are
/// skipped.
pub fn beam_search<C, F>(
    graph: &VamanaGraph,
    query: &C::Query,
    codec: &C,
    quantized: &[C::Quantized],
    fetcher: &mut F,
    k: usize,
    l_search: usize,
) -> Vec<BeamSearchResult>
where
    C: VectorCodec,
    F: NodeFetcher,
{
    // With k == 0 nothing can be returned; bailing out also avoids walking
    // the whole graph when `l_search` is 0 too (the beam would never fill).
    if k == 0 || graph.is_empty() {
        return Vec::new();
    }
    // Checked lookup: covers both an empty `quantized` and an entry index
    // that is out of range, mirroring the bounds check applied to neighbors.
    let Some(entry_code) = quantized.get(graph.entry) else {
        return Vec::new();
    };

    let l = l_search.max(k);
    let mut visited: HashSet<u32> = HashSet::new();
    // Pops the closest candidate first (see `Candidate`'s inverted ordering).
    let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
    // `Reverse` flips the ordering back, so `peek`/`pop` yield the farthest hit.
    let mut result: BinaryHeap<std::cmp::Reverse<Candidate>> = BinaryHeap::new();

    let entry_idx = graph.entry as u32;
    candidates.push(Candidate {
        dist: codec.exact_asymmetric_distance(query, entry_code),
        idx: entry_idx,
    });
    visited.insert(entry_idx);

    let mut frontier_indices: Vec<u32> = Vec::new();
    while let Some(current) = candidates.pop() {
        // Beam is full and the best remaining candidate cannot improve it.
        if result.len() >= l
            && result
                .peek()
                .is_some_and(|worst| current.dist > worst.0.dist)
        {
            break;
        }

        result.push(std::cmp::Reverse(current));
        if result.len() > l {
            result.pop();
        }

        frontier_indices.clear();
        for &neighbor_idx in graph.neighbors(current.idx as usize) {
            // `insert` returns false for already-visited nodes, so a single
            // lookup both filters and marks; it also de-duplicates repeated
            // entries within one neighbor list.
            if (neighbor_idx as usize) < quantized.len() && visited.insert(neighbor_idx) {
                frontier_indices.push(neighbor_idx);
            }
        }

        fetcher.prefetch_batch(&frontier_indices);
        for &neighbor_idx in &frontier_indices {
            let d = codec.exact_asymmetric_distance(query, &quantized[neighbor_idx as usize]);
            candidates.push(Candidate {
                dist: d,
                idx: neighbor_idx,
            });
        }
    }

    let mut out: Vec<Candidate> = result.into_iter().map(|r| r.0).collect();
    // Total order: ascending distance, NaN last, node index as tie-break.
    // A non-total comparator here could make the sort panic on NaN input.
    out.sort_unstable_by(|a, b| {
        a.dist
            .partial_cmp(&b.dist)
            .unwrap_or_else(|| a.dist.is_nan().cmp(&b.dist.is_nan()))
            .then(a.idx.cmp(&b.idx))
    });
    out.truncate(k);
    out.into_iter()
        .map(|c| BeamSearchResult {
            id: graph.external_id(c.idx as usize),
            distance: c.dist,
        })
        .collect()
}
/// Re-scores `candidates` with squared L2 distance on FP32 vectors and
/// returns the `k` closest.
///
/// Each candidate's external id is mapped back to its graph index by scanning
/// `graph.iter()`; the indices are handed to `fetcher.prefetch_batch`, then
/// each vector is obtained with `fetcher.fetch_fp32` and compared to
/// `query_fp32` via `l2_squared`.
///
/// Candidates whose id is not present in the graph, whose index does not fit
/// in `u32`, or for which `fetch_fp32` yields `None` are dropped.
///
/// # Returns
///
/// At most `k` results sorted by ascending distance (NaN distances last);
/// candidates with equal distance keep their input order. Returns an empty
/// vector when `k == 0` or `candidates` is empty.
pub fn rerank<F: NodeFetcher>(
    candidates: Vec<BeamSearchResult>,
    query_fp32: &[f32],
    fetcher: &mut F,
    graph: &VamanaGraph,
    k: usize,
) -> Vec<BeamSearchResult> {
    if k == 0 || candidates.is_empty() {
        return Vec::new();
    }

    // Only remember indices for ids we actually need, instead of building a
    // map with one entry per graph node on every call.
    // NOTE(review): this is still a linear scan of the graph; a reverse
    // id -> index lookup on the graph itself would remove it entirely.
    let wanted: HashSet<u64> = candidates.iter().map(|c| c.id).collect();
    let mut id_to_idx: std::collections::HashMap<u64, u32> =
        std::collections::HashMap::with_capacity(wanted.len());
    for (idx, node) in graph.iter() {
        if wanted.contains(&node.id) {
            // Checked narrowing: never fetch a wrong vector through a
            // silently truncated index.
            if let Ok(idx) = u32::try_from(idx) {
                // Later entries overwrite earlier ones for a repeated id.
                id_to_idx.insert(node.id, idx);
            }
        }
    }

    let candidate_indices: Vec<u32> = candidates
        .iter()
        .filter_map(|c| id_to_idx.get(&c.id).copied())
        .collect();
    fetcher.prefetch_batch(&candidate_indices);

    let mut reranked: Vec<BeamSearchResult> = candidates
        .into_iter()
        .filter_map(|c| {
            let idx = *id_to_idx.get(&c.id)?;
            let vec = fetcher.fetch_fp32(idx)?;
            Some(BeamSearchResult {
                id: c.id,
                distance: l2_squared(query_fp32, &vec),
            })
        })
        .collect();

    // Stable sort with a total order (NaN last); a non-total comparator could
    // make the sort panic when a distance is NaN.
    reranked.sort_by(|a, b| {
        a.distance
            .partial_cmp(&b.distance)
            .unwrap_or_else(|| a.distance.is_nan().cmp(&b.distance.is_nan()))
    });
    reranked.truncate(k);
    reranked
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vamana::build::build_vamana;
    use crate::vamana::node_fetcher::InMemoryFetcher;

    /// Pass-through codec for tests: vectors are stored unquantized and every
    /// distance is the squared L2 distance.
    struct L2Codec;

    /// "Quantized" form used by [`L2Codec`]: simply the original FP32 values.
    struct L2Quantized(Vec<f32>);

    impl AsRef<nodedb_codec::vector_quant::layout::UnifiedQuantizedVector> for L2Quantized {
        // Exists only to satisfy the trait bound; never expected to be called.
        fn as_ref(&self) -> &nodedb_codec::vector_quant::layout::UnifiedQuantizedVector {
            panic!("stub: UnifiedQuantizedVector not needed for L2Codec tests")
        }
    }

    impl VectorCodec for L2Codec {
        type Quantized = L2Quantized;
        type Query = Vec<f32>;

        fn encode(&self, v: &[f32]) -> Self::Quantized {
            L2Quantized(Vec::from(v))
        }

        fn prepare_query(&self, q: &[f32]) -> Self::Query {
            Vec::from(q)
        }

        fn fast_symmetric_distance(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
            l2_squared(&a.0, &b.0)
        }

        fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
            l2_squared(q, &v.0)
        }
    }

    /// Produces `n` deterministic pseudo-random vectors of length `dim`,
    /// driven by an xorshift generator seeded with `seed` (0 is replaced by 1).
    fn random_vecs(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng_state = seed.max(1);
        let mut rows: Vec<Vec<f32>> = Vec::with_capacity(n);
        for _ in 0..n {
            let mut row: Vec<f32> = Vec::with_capacity(dim);
            for _ in 0..dim {
                rng_state ^= rng_state << 13;
                rng_state ^= rng_state >> 7;
                rng_state ^= rng_state << 17;
                row.push((rng_state as f32) / (u64::MAX as f32));
            }
            rows.push(row);
        }
        rows
    }

    /// Querying with a vector that is itself in the index must return that
    /// vector's id first, at (near) zero distance.
    #[test]
    fn beam_search_finds_self_as_nearest() {
        let dim = 8;
        let n = 50;
        let codec = L2Codec;

        let vecs = random_vecs(n, dim, 42);
        let ids: Vec<u64> = (0..n as u64).collect();
        let mut quantized: Vec<L2Quantized> = Vec::with_capacity(vecs.len());
        for v in &vecs {
            quantized.push(codec.encode(v));
        }
        let graph = build_vamana(&vecs, &ids, &codec, &quantized, 8, 1.2, 20);

        let query_vec = vecs[7].clone();
        let query = codec.prepare_query(&query_vec);
        let mut fetcher = InMemoryFetcher::new(dim, vecs.clone());

        let results = beam_search(&graph, &query, &codec, &quantized, &mut fetcher, 5, 20);

        assert!(
            !results.is_empty(),
            "beam_search must return at least one result"
        );
        let nearest = &results[0];
        assert_eq!(
            nearest.id, 7,
            "nearest result must be the query vector itself"
        );
        assert!(
            nearest.distance < 1e-6,
            "distance to self must be near zero"
        );
    }
}