use std::collections::{BinaryHeap, HashMap, HashSet};
use crate::Result;
/// Graph adjacency lists: maps a node id to the ids of its neighbours.
pub type Adjacency = HashMap<u64, Vec<u64>>;
/// Node coordinates: maps a node id to the vector compared against the query.
pub type Positions = HashMap<u64, Vec<f32>>;
/// Tuning knobs for [`search`].
#[derive(Debug, Clone)]
pub struct InDescentOptions {
    /// Weight of centrality in the final score,
    /// `(1 - alpha) * distance - alpha * centrality`.
    /// `search` clamps it to `[0, 1]` before use.
    pub alpha: f32,
    /// Relative distance tolerance for treating a candidate as tied with the
    /// current worst kept candidate: `|d_cand - d_worst| <= d_worst * epsilon`.
    pub epsilon: f32,
    /// Size of the working candidate set maintained while traversing the graph.
    pub ef: usize,
    /// How many ids from the caller-supplied entry list are used as seeds
    /// (`search` always uses at least one).
    pub entry_points: usize,
}

impl Default for InDescentOptions {
    /// Defaults: `alpha = 0.2`, `epsilon = 0.05`, `ef = 64`, `entry_points = 1`.
    fn default() -> Self {
        InDescentOptions {
            alpha: 0.2,
            epsilon: 0.05,
            ef: 64,
            entry_points: 1,
        }
    }
}
/// Best-first graph search returning up to `k` `(node id, score)` pairs,
/// sorted by ascending score (lower is better).
///
/// The traversal starts from the first `opts.entry_points` ids of `entry`
/// (at least one) and repeatedly expands the closest unexpanded node, keeping
/// a bounded set of the best candidates seen so far. The final score of a
/// candidate is `(1 - alpha) * distance - alpha * centrality`, with `alpha`
/// clamped to `[0, 1]`.
///
/// * `adjacency` – neighbour lists; a node without an entry is not expanded.
/// * `positions` – node vectors; a node without an entry gets an infinite
///   distance.
/// * `centrality` – per-node centrality; a node without an entry gets `0.0`.
/// * `prefilter` – when present, nodes for which it returns `false` are
///   neither returned nor expanded, so the search does not traverse through
///   them.
/// * `opts.ef` – working-set size; values below `k` are raised to `k` so a
///   small (or zero) `ef` cannot silently truncate the result.
///
/// Returns an empty vector when `k == 0` or `entry` is empty.
///
/// # Errors
///
/// No code path in this function currently produces an error; the `Result`
/// wrapper is part of the public signature.
pub fn search(
    adjacency: &Adjacency,
    positions: &Positions,
    centrality: &HashMap<u64, f32>,
    entry: &[u64],
    query: &[f32],
    k: usize,
    prefilter: Option<&dyn Fn(u64) -> bool>,
    opts: &InDescentOptions,
) -> Result<Vec<(u64, f32)>> {
    if k == 0 || entry.is_empty() {
        return Ok(Vec::new());
    }

    // The working set must be able to hold at least `k` results. With the raw
    // `opts.ef`, a value of 0 evicted every candidate immediately and a value
    // below `k` returned fewer than `k` hits.
    let ef = opts.ef.max(k);

    let accepts = |id: u64| prefilter.map_or(true, |f| f(id));
    let candidate_for = |id: u64| Candidate {
        id,
        dist: positions
            .get(&id)
            .map_or(f32::INFINITY, |v| l2_dist(query, v)),
        centrality: centrality.get(&id).copied().unwrap_or(0.0),
    };

    let mut visited: HashSet<u64> = HashSet::new();
    // Max-heap: `peek()` is the worst candidate currently kept.
    let mut best: BinaryHeap<Candidate> = BinaryHeap::new();
    // Min-heap: `pop()` yields the closest node not yet expanded.
    let mut frontier: BinaryHeap<std::cmp::Reverse<Candidate>> = BinaryHeap::new();

    for &eid in entry.iter().take(opts.entry_points.max(1)) {
        // A node is marked visited even when the prefilter rejects it, so the
        // filter runs at most once per node.
        if !visited.insert(eid) || !accepts(eid) {
            continue;
        }
        let cand = candidate_for(eid);
        frontier.push(std::cmp::Reverse(cand));
        best.push(cand);
        if best.len() > ef {
            best.pop();
        }
    }

    while let Some(std::cmp::Reverse(current)) = frontier.pop() {
        // Stop once the working set is full and the closest unexpanded node is
        // already farther than the worst kept candidate.
        let exhausted = best.len() >= ef
            && best
                .peek()
                .map_or(false, |worst| current.dist > worst.dist);
        if exhausted {
            break;
        }

        let Some(neighbours) = adjacency.get(&current.id) else { continue };
        for &n in neighbours {
            if !visited.insert(n) || !accepts(n) {
                continue;
            }
            let cand = candidate_for(n);

            let admit = match best.peek() {
                Some(worst) if best.len() >= ef => {
                    let near_tie =
                        (cand.dist - worst.dist).abs() <= worst.dist * opts.epsilon;
                    cand.dist < worst.dist
                        || (near_tie && cand.centrality > worst.centrality)
                }
                // Working set not full yet: take everything.
                _ => true,
            };
            if !admit {
                continue;
            }

            best.push(cand);
            // NOTE(review): a near-tie candidate that is strictly farther than
            // `worst` is the heap maximum and is evicted again right here, so
            // the epsilon rule only replaces `worst` on an exact distance tie;
            // otherwise it just widens exploration via `frontier`. Confirm
            // this is the intended behaviour.
            if best.len() > ef {
                best.pop();
            }
            frontier.push(std::cmp::Reverse(cand));
        }
    }

    let alpha = opts.alpha.clamp(0.0, 1.0);
    let mut out: Vec<(u64, f32)> = best
        .into_iter()
        .map(|c| (c.id, (1.0 - alpha) * c.dist - alpha * c.centrality))
        .collect();

    // Scores can be NaN (e.g. an infinite distance multiplied by a zero
    // weight). Use a genuine total order — NaN last, ties broken by id — so
    // the sort is well-defined and the output is deterministic.
    out.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => a.0.cmp(&b.0),
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        (false, false) => a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)),
    });
    out.truncate(k);
    Ok(out)
}
/// A node considered during the search, together with the values it is
/// ranked by.
///
/// Ordering is "smaller is better": ascending distance, then *descending*
/// centrality, then ascending id. NaN values are ranked as the worst possible
/// (NaN distance as infinitely far, NaN centrality as infinitely low) so the
/// ordering is a genuine total order and heaps built on it stay consistent.
/// Equality is defined through the same ordering, keeping `Eq` and `Ord` in
/// agreement.
#[derive(Copy, Clone, Debug)]
struct Candidate {
    id: u64,
    dist: f32,
    centrality: f32,
}

impl Candidate {
    /// Returns `(distance, centrality)` with NaN replaced by the worst value
    /// for each component, so both can be compared with a total order.
    fn sort_key(&self) -> (f32, f32) {
        let dist = if self.dist.is_nan() {
            f32::INFINITY
        } else {
            self.dist
        };
        let centrality = if self.centrality.is_nan() {
            f32::NEG_INFINITY
        } else {
            self.centrality
        };
        (dist, centrality)
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `a == b` exactly when `a.cmp(&b)` is `Equal`.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (self_dist, self_cent) = self.sort_key();
        let (other_dist, other_cent) = other.sort_key();
        self_dist
            .total_cmp(&other_dist)
            // Reversed on purpose: at equal distance the more central node
            // counts as "smaller" (better).
            .then_with(|| other_cent.total_cmp(&self_cent))
            // Final tie-break keeps distinct nodes from comparing as equal.
            .then_with(|| self.id.cmp(&other.id))
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// When the slices differ in length only the leading components present in
/// both are compared; two empty slices are at distance `0.0`.
fn l2_dist(a: &[f32], b: &[f32]) -> f32 {
    // `zip` stops at the shorter slice; the explicit `fold` from `0.0` keeps
    // the left-to-right accumulation order.
    let sum_sq = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    });
    sum_sq.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the path graph 0 - 1 - 2 - 3 - 4 where node `i` sits at the
    /// one-dimensional position `[i]`.
    fn line_graph() -> (Adjacency, Positions) {
        let mut adj: Adjacency = HashMap::new();
        for i in 0u64..5 {
            let mut n = Vec::new();
            if i > 0 {
                n.push(i - 1);
            }
            if i < 4 {
                n.push(i + 1);
            }
            adj.insert(i, n);
        }
        let mut pos: Positions = HashMap::new();
        for i in 0u64..5 {
            pos.insert(i, vec![i as f32]);
        }
        (adj, pos)
    }

    #[test]
    fn search_finds_nearest_in_line_graph() {
        let (adj, pos) = line_graph();
        let cent: HashMap<u64, f32> = HashMap::new();
        let opts = InDescentOptions::default();
        // The query at 2.5 is equidistant from nodes 2 and 3.
        let r = search(&adj, &pos, &cent, &[0], &[2.5], 1, None, &opts).unwrap();
        assert_eq!(r.len(), 1);
        assert!(r[0].0 == 2 || r[0].0 == 3, "got {r:?}");
    }

    #[test]
    fn centrality_breaks_tie_within_epsilon() {
        let (adj, pos) = line_graph();
        let mut cent: HashMap<u64, f32> = HashMap::new();
        cent.insert(3, 1.0);
        let opts = InDescentOptions { alpha: 0.9, ..Default::default() };
        // Nodes 2 and 3 are equally close; only node 3 has centrality.
        let r = search(&adj, &pos, &cent, &[0], &[2.5], 1, None, &opts).unwrap();
        assert_eq!(r[0].0, 3);
    }

    #[test]
    fn prefilter_drops_non_matches_during_descent() {
        let (adj, pos) = line_graph();
        let cent: HashMap<u64, f32> = HashMap::new();
        let opts = InDescentOptions::default();
        // Reject the two nodes closest to the query.
        let pf = |id: u64| !(id == 2 || id == 3);
        let r = search(&adj, &pos, &cent, &[0], &[2.5], 1, Some(&pf), &opts).unwrap();
        assert_eq!(r.len(), 1);
        assert!(r[0].0 == 1 || r[0].0 == 4, "got {r:?}");
    }

    #[test]
    fn empty_inputs_safe() {
        let r = search(
            &HashMap::new(),
            &HashMap::new(),
            &HashMap::new(),
            &[],
            &[1.0],
            5,
            None,
            &InDescentOptions::default(),
        )
        .unwrap();
        assert!(r.is_empty());
    }
}