use std::collections::{HashMap, HashSet};
use crate::error::Error;
use crate::id::NodeId;
use crate::index::PropPredicate;
use crate::objects::Node;
use crate::repo::readonly::ReadonlyRepo;
/// Fuses two ranked hit lists by min-max normalizing each list's scores to
/// `[0, 1]` and combining them as `a_weight * a_norm + b_weight * b_norm`.
///
/// A hit present in only one list contributes only that list's weighted,
/// normalized score. The result is sorted by fused score descending, with
/// ties broken by ascending `NodeId` so the output order is deterministic.
pub fn score_normalized_fusion(
    a_hits: &[(NodeId, f32)],
    a_weight: f32,
    b_hits: &[(NodeId, f32)],
    b_weight: f32,
) -> Vec<(NodeId, f32)> {
    /// Min-max normalizes one list's scores to `[0, 1]`. A degenerate list
    /// (all scores equal) maps every hit to 0.5 so it still contributes a
    /// neutral, non-zero signal instead of dividing by zero.
    fn normalize(hits: &[(NodeId, f32)]) -> HashMap<NodeId, f32> {
        if hits.is_empty() {
            return HashMap::new();
        }
        // Single pass over the scores for both extrema.
        let (min, max) = hits
            .iter()
            .map(|(_, s)| *s)
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
                (lo.min(s), hi.max(s))
            });
        let range = max - min;
        let mut out = HashMap::with_capacity(hits.len());
        for (id, s) in hits {
            let norm = if range.abs() < 1e-12 {
                0.5
            } else {
                (s - min) / range
            };
            out.insert(*id, norm);
        }
        out
    }
    let a_norm = normalize(a_hits);
    let b_norm = normalize(b_hits);
    let mut all: HashMap<NodeId, f32> = HashMap::with_capacity(a_norm.len() + b_norm.len());
    for (id, s) in &a_norm {
        *all.entry(*id).or_insert(0.0) += a_weight * s;
    }
    for (id, s) in &b_norm {
        *all.entry(*id).or_insert(0.0) += b_weight * s;
    }
    let mut fused: Vec<(NodeId, f32)> = all.into_iter().collect();
    // `total_cmp` is a true total order over f32 (IEEE 754 totalOrder).
    // The previous `partial_cmp(..).unwrap_or(Equal)` comparator is
    // inconsistent in the presence of NaN scores, which violates the
    // total-order contract of the sort and can panic on recent rustc.
    // Unstable sort is safe here: the id tie-break fully determines order.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
/// Fuses any number of scored "lanes" into one ranking: each lane's scores
/// are min-max normalized to `[0, 1]` (0.5 for a constant-score lane) and
/// accumulated into a weighted sum per node. Empty or zero-weight lanes are
/// skipped entirely, so they contribute nothing — not even the 0.5 floor.
///
/// The result is sorted by fused score descending, ties broken by ascending
/// `NodeId` for a deterministic order.
pub fn convex_min_max_fusion(lanes: &[(Vec<(NodeId, f32)>, f32)]) -> Vec<(NodeId, f32)> {
    let mut totals: HashMap<NodeId, f32> = HashMap::new();
    for (hits, weight) in lanes {
        if hits.is_empty() || *weight == 0.0 {
            continue;
        }
        // One pass for both extrema of this lane.
        let (min, max) = hits
            .iter()
            .map(|(_, s)| *s)
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
                (lo.min(s), hi.max(s))
            });
        let range = max - min;
        for (id, s) in hits {
            // Degenerate lane (all scores equal): neutral 0.5 per hit.
            let norm = if range.abs() < 1e-12 {
                0.5
            } else {
                (s - min) / range
            };
            *totals.entry(*id).or_insert(0.0) += weight * norm;
        }
    }
    let mut fused: Vec<(NodeId, f32)> = totals.into_iter().collect();
    // `total_cmp` replaces `partial_cmp(..).unwrap_or(Equal)`, whose NaN
    // fallback breaks the total-order contract `sort_by` requires and can
    // panic on recent rustc. Unstable sort is fine given the id tie-break.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
pub fn reciprocal_rank_fusion(lists: &[Vec<NodeId>], k: f32) -> Vec<(NodeId, f32)> {
let weighted: Vec<(Vec<NodeId>, f32)> = lists.iter().map(|l| (l.clone(), 1.0)).collect();
weighted_reciprocal_rank_fusion(&weighted, k)
}
/// Weighted reciprocal rank fusion: each list contributes
/// `weight / (k + rank + 1)` per hit (rank is 0-based). Lists with a weight
/// of exactly 0.0 are skipped without being walked.
///
/// The result is sorted by fused score descending, with ties broken by
/// ascending `NodeId` so the output order is deterministic.
pub fn weighted_reciprocal_rank_fusion(lists: &[(Vec<NodeId>, f32)], k: f32) -> Vec<(NodeId, f32)> {
    let mut scores: HashMap<NodeId, f32> = HashMap::new();
    for (list, weight) in lists {
        if *weight == 0.0 {
            continue;
        }
        for (rank, id) in list.iter().enumerate() {
            // +1.0 keeps the denominator positive even at rank 0 with k = 0.
            let contrib = weight / (k + (rank as f32) + 1.0);
            *scores.entry(*id).or_insert(0.0) += contrib;
        }
    }
    let mut fused: Vec<(NodeId, f32)> = scores.into_iter().collect();
    // `total_cmp` replaces the NaN-unsafe `partial_cmp(..).unwrap_or(Equal)`
    // comparator, which violates `sort_by`'s total-order requirement when
    // NaN weights/scores sneak in. Unstable sort is safe with the tie-break.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
/// Resolves ranked ids to full nodes, dropping duplicate ids, ids with no
/// backing node in the repo, and nodes that fail the optional label or
/// property filters. Input order is preserved; the first occurrence of a
/// duplicated id wins. Repo lookup errors are propagated to the caller.
///
/// NOTE(review): only `PropPredicate::Eq` is enforced here — a `prop` whose
/// predicate is any other variant lets every node through unfiltered.
/// Confirm that is intentional at the call sites.
pub(super) fn prefetch_and_filter(
    repo: &ReadonlyRepo,
    ranked: Vec<(NodeId, f32)>,
    label: Option<&str>,
    prop: Option<&(String, PropPredicate)>,
) -> Result<Vec<(NodeId, f32, Node)>, Error> {
    let mut kept = Vec::with_capacity(ranked.len());
    let mut visited: HashSet<NodeId> = HashSet::with_capacity(ranked.len());
    for (id, score) in ranked {
        // `insert` returns false for an id we've already emitted/considered.
        if !visited.insert(id) {
            continue;
        }
        // Missing nodes are skipped silently rather than treated as errors.
        let node = match repo.lookup_node(&id)? {
            Some(found) => found,
            None => continue,
        };
        if let Some(wanted) = label {
            if node.ntype != wanted {
                continue;
            }
        }
        if let Some((key, PropPredicate::Eq(expected))) = prop {
            if node.props.get(key) != Some(expected) {
                continue;
            }
        }
        kept.push((id, score, node));
    }
    Ok(kept)
}