use crate::RetrieveError;
use std::collections::{BinaryHeap, HashSet};
/// Predicate deciding whether a node may appear in filtered search results.
///
/// The `Sync` supertrait means a shared reference to a predicate can be used from
/// several threads at once.
pub trait FilterPredicate: Sync {
    /// Returns `true` if the node with id `node_id` satisfies the filter.
    fn matches(&self, node_id: u32) -> bool;
}
/// Adapts a thread-safe closure `Fn(u32) -> bool` into a [`FilterPredicate`].
pub struct FnFilter<F: Fn(u32) -> bool + Sync>(pub F);

impl<F> FilterPredicate for FnFilter<F>
where
    F: Fn(u32) -> bool + Sync,
{
    /// Delegates to the wrapped closure.
    fn matches(&self, node_id: u32) -> bool {
        let predicate = &self.0;
        predicate(node_id)
    }
}
/// [`FilterPredicate`] that evaluates a metadata filter against a metadata store.
pub struct MetadataFilterAdapter<'a> {
    filter: &'a crate::filtering::MetadataFilter,
    store: &'a crate::filtering::MetadataStore,
}

impl<'a> MetadataFilterAdapter<'a> {
    /// Builds an adapter that borrows `filter` and `store` for its whole lifetime.
    pub fn new(
        filter: &'a crate::filtering::MetadataFilter,
        store: &'a crate::filtering::MetadataStore,
    ) -> Self {
        MetadataFilterAdapter { store, filter }
    }
}

impl FilterPredicate for MetadataFilterAdapter<'_> {
    /// Delegates to `MetadataStore::matches`, passing the wrapped filter.
    fn matches(&self, node_id: u32) -> bool {
        // Both fields are shared references (`Copy`), so destructuring copies them.
        let Self { filter, store } = *self;
        store.matches(node_id, filter)
    }
}
/// Predicate that accepts every node, turning a filtered search into an unfiltered one.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoFilter;

impl FilterPredicate for NoFilter {
    /// Always returns `true`.
    fn matches(&self, _node_id: u32) -> bool {
        true
    }
}
/// Tuning parameters for `acorn_search`.
#[derive(Clone, Debug)]
pub struct AcornConfig {
    /// Allows expanding the neighbours of nodes that fail the filter ("two-hop" expansion).
    pub enable_two_hop: bool,
    /// Two-hop expansion is used while the fraction of visited nodes passing the filter
    /// is below this value (or while fewer than `k` results have been found).
    pub two_hop_threshold: f32,
    /// Maximum number of newly visited nodes per two-hop expansion of one filtered-out node.
    pub max_two_hop_neighbors: usize,
    /// Search breadth; the traversal stops once `ef_search * 10` nodes have been visited.
    pub ef_search: usize,
}

impl Default for AcornConfig {
    /// Two-hop expansion on, threshold `0.3`, at most 32 two-hop nodes, `ef_search = 100`.
    fn default() -> Self {
        let ef_search = 100;
        let max_two_hop_neighbors = 32;
        Self {
            ef_search,
            max_two_hop_neighbors,
            two_hop_threshold: 0.3,
            enable_two_hop: true,
        }
    }
}
/// Per-search book-keeping: which nodes were seen and how many passed the filter.
struct SearchState {
    /// Every node id visited so far.
    visited: HashSet<u32>,
    /// Number of visited nodes that passed the filter.
    filtered_count: usize,
    /// Number of distinct nodes visited (kept in step with `visited.len()`).
    visited_count: usize,
}

impl SearchState {
    /// Creates a state in which nothing has been visited.
    fn new() -> Self {
        Self {
            visited: HashSet::new(),
            filtered_count: 0,
            visited_count: 0,
        }
    }

    /// Records a visit to `node_id`.
    ///
    /// Returns `true` and updates the counters the first time a node is seen;
    /// returns `false` and changes nothing for a repeat visit.
    fn visit(&mut self, node_id: u32, passes_filter: bool) -> bool {
        let first_time = self.visited.insert(node_id);
        if first_time {
            self.visited_count += 1;
            self.filtered_count += usize::from(passes_filter);
        }
        first_time
    }

    /// Fraction of visited nodes that passed the filter; `1.0` before any visit.
    fn filter_ratio(&self) -> f32 {
        match self.visited_count {
            0 => 1.0,
            seen => self.filtered_count as f32 / seen as f32,
        }
    }
}
/// A node paired with a distance, compared by that distance only.
///
/// Equality and ordering are both defined through [`f32::total_cmp`], so the
/// `Eq`/`Ord` contracts hold even for `NaN` and signed zeros. `node_id` does not take
/// part in comparisons.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    node_id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `cmp`: plain `f32 ==` makes NaN unequal to itself and treats
        // -0.0 and 0.0 as equal even though `total_cmp` orders them.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Ascending by distance (IEEE 754 total order).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
/// Widens `worst` by the multiplicative `slack`, toward larger distances.
///
/// For non-negative `worst` this equals `worst * slack` (up to rounding), and
/// `INFINITY` stays `INFINITY`. Unlike a bare product, it still loosens the bound when
/// distances are negative (e.g. negated inner-product scores). There, `worst * slack`
/// would tighten the bound and reject candidates that beat the current worst result.
fn relaxed_bound(worst: f32, slack: f32) -> f32 {
    worst + worst.abs() * (slack - 1.0)
}

/// Bounded top-`k` result set plus the best-first exploration frontier of a search.
struct AcornBeam {
    /// Maximum number of results kept.
    k: usize,
    /// Max-heap on distance, so `peek()`/`pop()` yield the worst kept result.
    results: BinaryHeap<Candidate>,
    /// Distances are stored negated so this max-heap pops the closest node first.
    frontier: BinaryHeap<Candidate>,
    /// Distance of the worst kept result; `INFINITY` until a result is recorded.
    worst_result_dist: f32,
}

impl AcornBeam {
    /// Creates an empty beam that keeps at most `k` results.
    fn new(k: usize) -> Self {
        Self {
            k,
            results: BinaryHeap::new(),
            frontier: BinaryHeap::new(),
            worst_result_dist: f32::INFINITY,
        }
    }

    /// Whether `k` results have been collected.
    fn is_full(&self) -> bool {
        self.results.len() >= self.k
    }

    /// Schedules `node_id` for expansion.
    fn push_frontier(&mut self, node_id: u32, dist: f32) {
        self.frontier.push(Candidate {
            node_id,
            distance: -dist,
        });
    }

    /// Removes the closest unexpanded node, returning `(node_id, distance)`.
    fn pop_frontier(&mut self) -> Option<(u32, f32)> {
        self.frontier.pop().map(|c| (c.node_id, -c.distance))
    }

    /// Records a filter-passing node, evicting the worst entries beyond `k` and
    /// refreshing `worst_result_dist`.
    fn offer_result(&mut self, node_id: u32, dist: f32) {
        self.results.push(Candidate {
            node_id,
            distance: dist,
        });
        while self.results.len() > self.k {
            self.results.pop();
        }
        if let Some(worst) = self.results.peek() {
            self.worst_result_dist = worst.distance;
        }
    }

    /// Handles a freshly visited node. It is first recorded as a result if it passes
    /// the filter. It is then enqueued for expansion if it lies within `slack` of the
    /// worst result, or if fewer than `k` results are known.
    fn consider(&mut self, node_id: u32, dist: f32, passes: bool, slack: f32) {
        if passes {
            self.offer_result(node_id, dist);
        }
        if dist < relaxed_bound(self.worst_result_dist, slack) || !self.is_full() {
            self.push_frontier(node_id, dist);
        }
    }

    /// Consumes the beam, returning `(node_id, distance)` pairs by ascending distance.
    fn into_sorted_results(self) -> Vec<(u32, f32)> {
        let mut out: Vec<(u32, f32)> = self
            .results
            .into_iter()
            .map(|c| (c.node_id, c.distance))
            .collect();
        out.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        out
    }
}

/// Filtered best-first graph search returning the `k` nearest nodes that pass `filter`.
///
/// The traversal starts at `entry_point` and always expands the closest unexpanded
/// node next. Nodes failing the filter are still traversed, so the search can cross
/// regions where nothing matches. When `config.enable_two_hop` is set, the neighbours
/// of a filtered-out node are also visited immediately ("two-hop" expansion). This
/// happens while the observed filter pass ratio is below `config.two_hop_threshold`,
/// or while fewer than `k` results are known. Each expansion is capped at
/// `config.max_two_hop_neighbors` newly visited nodes.
///
/// # Arguments
/// * `k` - number of results wanted; `0` returns an empty vector without traversing.
/// * `config` - tuning; the search stops after an expansion leaves at least
///   `ef_search * 10` nodes visited.
/// * `filter` - predicate a node must satisfy to appear in the results.
/// * `get_neighbors` - adjacency lookup for a node id.
/// * `compute_distance` - distance from the query to a node; smaller is closer.
/// * `entry_point` - node id the traversal starts from.
///
/// # Returns
/// Up to `k` `(node_id, distance)` pairs sorted by ascending distance.
///
/// # Errors
/// The current implementation never returns `Err`.
pub fn acorn_search<F, N, D>(
    k: usize,
    config: &AcornConfig,
    filter: &F,
    get_neighbors: N,
    compute_distance: D,
    entry_point: u32,
) -> Result<Vec<(u32, f32)>, RetrieveError>
where
    F: FilterPredicate,
    N: Fn(u32) -> Vec<u32>,
    D: Fn(u32) -> f32,
{
    // Early termination is only allowed once at least this fraction of visited nodes
    // passed the filter; with sparser matches the search keeps exploring.
    const MIN_PASS_RATIO_FOR_EARLY_STOP: f32 = 0.3;
    // Stop when the closest unexpanded node is beyond this slack of the worst result.
    const STOP_SLACK: f32 = 1.5;
    // Once `k` results exist, only enqueue nodes within this slack of the worst result.
    const ENQUEUE_SLACK: f32 = 2.0;

    // Nothing can be returned, so skip the traversal entirely.
    if k == 0 {
        return Ok(Vec::new());
    }

    // Hard cap on visited nodes; saturate so a huge `ef_search` cannot overflow.
    let visit_budget = config.ef_search.saturating_mul(10);
    let mut state = SearchState::new();
    let mut beam = AcornBeam::new(k);

    let entry_passes = filter.matches(entry_point);
    state.visit(entry_point, entry_passes);
    let entry_dist = compute_distance(entry_point);
    // The entry point is always expanded, regardless of its distance.
    beam.push_frontier(entry_point, entry_dist);
    if entry_passes {
        // Via `offer_result` so `worst_result_dist` accounts for the entry as well.
        beam.offer_result(entry_point, entry_dist);
    }

    while let Some((current_id, current_dist)) = beam.pop_frontier() {
        let can_stop = beam.is_full()
            && current_dist > relaxed_bound(beam.worst_result_dist, STOP_SLACK);
        if can_stop && state.filter_ratio() > MIN_PASS_RATIO_FOR_EARLY_STOP {
            break;
        }

        // Look past filtered-out nodes when matches are sparse or results are short.
        let use_two_hop = config.enable_two_hop
            && (state.filter_ratio() < config.two_hop_threshold || !beam.is_full());

        for neighbor in get_neighbors(current_id) {
            let neighbor_passes = filter.matches(neighbor);
            if !state.visit(neighbor, neighbor_passes) {
                continue;
            }
            beam.consider(
                neighbor,
                compute_distance(neighbor),
                neighbor_passes,
                ENQUEUE_SLACK,
            );

            if neighbor_passes || !use_two_hop {
                continue;
            }
            // `neighbor` failed the filter: visit its neighbours too, capped per node.
            let mut expanded = 0;
            for two_hop in get_neighbors(neighbor) {
                if expanded >= config.max_two_hop_neighbors {
                    break;
                }
                let two_hop_passes = filter.matches(two_hop);
                if !state.visit(two_hop, two_hop_passes) {
                    continue;
                }
                beam.consider(
                    two_hop,
                    compute_distance(two_hop),
                    two_hop_passes,
                    ENQUEUE_SLACK,
                );
                expanded += 1;
            }
        }

        if state.visited_count >= visit_budget {
            break;
        }
    }

    Ok(beam.into_sorted_results())
}
#[cfg(test)]
// Test code may panic via `unwrap`/`expect`; a panic is a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Ten-node toy graph: adjacency lists plus a fixed query distance per node.
    fn mock_graph() -> (Vec<Vec<u32>>, Vec<f32>) {
        let neighbors = vec![
            vec![1, 2, 3],
            vec![0, 2, 3, 4],
            vec![0, 1, 3, 4, 5],
            vec![0, 1, 2, 4, 5, 6],
            vec![1, 2, 3, 5, 6, 7],
            vec![2, 3, 4, 6, 7, 8],
            vec![3, 4, 5, 7, 8, 9],
            vec![4, 5, 6, 8, 9],
            vec![5, 6, 7, 9],
            vec![6, 7, 8],
        ];
        let distances = vec![0.5, 0.3, 0.6, 0.4, 0.7, 0.2, 0.8, 0.1, 0.9, 0.35];
        (neighbors, distances)
    }

    /// Node ids of a result list, in result order.
    fn ids(results: &[(u32, f32)]) -> Vec<u32> {
        results.iter().map(|&(id, _)| id).collect()
    }

    #[test]
    fn test_acorn_no_filter() {
        let (neighbors, distances) = mock_graph();
        let config = AcornConfig {
            enable_two_hop: true,
            two_hop_threshold: 0.3,
            max_two_hop_neighbors: 32,
            ef_search: 100,
        };
        let results = acorn_search(
            5,
            &config,
            &NoFilter,
            |id| neighbors[id as usize].clone(),
            |id| distances[id as usize],
            0,
        )
        .unwrap();
        assert_eq!(results.len(), 5, "Should return exactly k results");
        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1, "Results should be sorted");
        }
        // Ground truth: the five smallest distances are nodes 7, 5, 1, 9, 3.
        assert_eq!(ids(&results), vec![7, 5, 1, 9, 3]);
    }

    #[test]
    fn test_acorn_with_filter() {
        let (neighbors, distances) = mock_graph();
        let filter = FnFilter(|id: u32| id % 2 == 0);
        let results = acorn_search(
            3,
            &AcornConfig::default(),
            &filter,
            |id| neighbors[id as usize].clone(),
            |id| distances[id as usize],
            0,
        )
        .unwrap();
        for (id, _) in &results {
            assert_eq!(id % 2, 0, "Node {} should be even", id);
        }
        // Ground truth: the closest even nodes are 0 (0.5), 2 (0.6) and 4 (0.7).
        assert_eq!(ids(&results), vec![0, 2, 4]);
    }

    #[test]
    fn test_acorn_selective_filter() {
        let (neighbors, distances) = mock_graph();
        let filter = FnFilter(|id: u32| id >= 8);
        let config = AcornConfig {
            enable_two_hop: true,
            two_hop_threshold: 0.8,
            max_two_hop_neighbors: 32,
            ef_search: 100,
        };
        let results = acorn_search(
            2,
            &config,
            &filter,
            |id| neighbors[id as usize].clone(),
            |id| distances[id as usize],
            0,
        )
        .unwrap();
        assert!(!results.is_empty(), "Should find at least one node >= 8");
        for (id, _) in &results {
            assert!(*id >= 8, "Node {} should be >= 8", id);
        }
        // Only nodes 8 (0.9) and 9 (0.35) pass; both must be found, closest first.
        assert_eq!(ids(&results), vec![9, 8]);
    }

    #[test]
    fn test_acorn_k_zero_returns_empty() {
        let (neighbors, distances) = mock_graph();
        let results = acorn_search(
            0,
            &AcornConfig::default(),
            &NoFilter,
            |id| neighbors[id as usize].clone(),
            |id| distances[id as usize],
            0,
        )
        .unwrap();
        assert!(results.is_empty(), "k = 0 must yield no results");
    }
}