use std::collections::HashMap;
use crate::Result;
/// Centrality values keyed by candidate id.
///
/// `apply_bias` looks candidates up by their `u64` id, treats ids missing from
/// the map as `0.0`, and clamps every looked-up value to `[0.0, 1.0]`.
pub type CentralityMap = HashMap<u64, f32>;
/// Tuning options read by `apply_bias`.
#[derive(Debug, Clone)]
pub struct BiasOptions {
    /// Weight of the centrality term when candidates are re-scored.
    ///
    /// `apply_bias` clamps this to `[0.0, 1.0]`; the normalised distance is
    /// weighted by the remaining `1.0 - alpha`.
    pub alpha: f32,
    /// Over-fetch factor.
    ///
    /// NOTE(review): nothing visible in this file reads this field —
    /// presumably callers use it to decide how many candidates to fetch
    /// before biasing; confirm against the call sites.
    pub over_fetch_multiplier: usize,
}

impl Default for BiasOptions {
    /// Returns the default options: `alpha = 0.2`, `over_fetch_multiplier = 4`.
    fn default() -> Self {
        let alpha = 0.2;
        let over_fetch_multiplier = 4;
        Self {
            alpha,
            over_fetch_multiplier,
        }
    }
}
/// Filters, optionally re-scores, and ranks `(id, distance)` candidates,
/// returning at most `k` of them in ascending score order.
///
/// # Parameters
/// - `candidates`: `(id, distance)` pairs; a smaller distance ranks earlier.
/// - `k`: maximum number of results to return.
/// - `centrality`: when `Some`, every surviving candidate with a finite
///   distance is re-scored as
///   `(1 - alpha) * (distance / max_distance) - alpha * centrality(id)`,
///   where `max_distance` is the largest finite distance among the survivors
///   (the distance term is `0.0` when that maximum is not positive). Ids
///   missing from the map, or mapped to NaN, count as centrality `0.0`;
///   values are clamped to `[0.0, 1.0]`.
/// - `prefilter`: when `Some`, only ids for which it returns `true` are kept.
///   It is applied before truncation to `k`.
/// - `opts`: only `opts.alpha` is read here; it is clamped to `[0.0, 1.0]`.
///
/// # Returns
/// Up to `k` `(id, score)` pairs sorted by ascending score. Without a
/// centrality map the score is the input distance unchanged. Candidates with
/// equal scores keep their input order (the sort is stable). NaN scores always
/// sort last. Non-finite distances are never normalised: they are passed
/// through as the score, so `+inf`/NaN rank last and `-inf` ranks first, the
/// same as in the unbiased path.
///
/// # Errors
/// Every path in the current implementation returns `Ok`.
pub fn apply_bias(
    candidates: Vec<(u64, f32)>,
    k: usize,
    centrality: Option<&CentralityMap>,
    prefilter: Option<&dyn Fn(u64) -> bool>,
    opts: &BiasOptions,
) -> Result<Vec<(u64, f32)>> {
    if candidates.is_empty() || k == 0 {
        return Ok(Vec::new());
    }

    let kept: Vec<(u64, f32)> = candidates
        .into_iter()
        .filter(|(id, _)| prefilter.map_or(true, |keep| keep(*id)))
        .collect();
    if kept.is_empty() {
        return Ok(Vec::new());
    }

    let mut scored: Vec<(u64, f32)> = match centrality {
        Some(cent) => {
            // Only finite distances take part in the normalisation: a single
            // `inf` would otherwise squash every finite distance to 0.0 and
            // turn `inf / inf` into NaN. Starting the fold at 0.0 is enough
            // because a non-positive maximum disables normalisation anyway.
            let max_dist = kept
                .iter()
                .map(|&(_, d)| d)
                .filter(|d| d.is_finite())
                .fold(0.0_f32, f32::max);
            let alpha = opts.alpha.clamp(0.0, 1.0);

            kept.into_iter()
                .map(|(id, dist)| {
                    if !dist.is_finite() {
                        // Leave unusable distances untouched so they rank the
                        // same way as in the unbiased path.
                        return (id, dist);
                    }
                    let normalised = if max_dist > 0.0 { dist / max_dist } else { 0.0 };
                    // A NaN centrality is treated like a missing entry so it
                    // cannot poison the score.
                    let cent_v = cent
                        .get(&id)
                        .copied()
                        .filter(|c| !c.is_nan())
                        .unwrap_or(0.0)
                        .clamp(0.0, 1.0);
                    (id, (1.0 - alpha) * normalised - alpha * cent_v)
                })
                .collect()
        }
        None => kept,
    };

    // `partial_cmp` alone is not a total order once NaN is involved, which
    // makes `sort_by` produce an unspecified order (and may panic). Ordering
    // by "is NaN" first yields a proper total order with NaN scores last while
    // leaving the ordering of all non-NaN scores exactly as before.
    scored.sort_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
    });
    scored.truncate(k);
    Ok(scored)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a centrality map from `(id, centrality)` pairs.
    fn cent_map(pairs: &[(u64, f32)]) -> CentralityMap {
        pairs.iter().copied().collect()
    }

    /// Without a centrality map or prefilter the closest `k` candidates win.
    #[test]
    fn no_bias_no_prefilter_preserves_top_k() {
        let cands = vec![(1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4)];
        let r = apply_bias(cands, 2, None, None, &BiasOptions::default()).unwrap();
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].0, 1);
        assert_eq!(r[1].0, 2);
    }

    /// The prefilter must run before truncation, otherwise ids 2 and 4 could
    /// not both make it into a top-2 result.
    #[test]
    fn prefilter_drops_non_matches_before_truncation() {
        let cands = vec![(1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4)];
        let pf = |id: u64| id % 2 == 0;
        let r = apply_bias(cands, 2, None, Some(&pf), &BiasOptions::default()).unwrap();
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].0, 2);
        assert_eq!(r[1].0, 4);
    }

    /// A prefilter that rejects every id yields an empty result.
    #[test]
    fn prefilter_rejecting_everything_yields_empty() {
        let cands = vec![(1, 0.1), (2, 0.2)];
        let pf = |_id: u64| false;
        let r = apply_bias(cands, 2, None, Some(&pf), &BiasOptions::default()).unwrap();
        assert!(r.is_empty());
    }

    /// With a high `alpha` the centrality term outweighs the distance term.
    #[test]
    fn centrality_promotes_well_connected_node() {
        let cands = vec![(1, 0.10), (2, 0.20), (3, 0.30)];
        let cent = cent_map(&[(1, 0.0), (2, 0.0), (3, 1.0)]);
        let opts = BiasOptions { alpha: 0.9, over_fetch_multiplier: 1 };
        let r = apply_bias(cands, 1, Some(&cent), None, &opts).unwrap();
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 3, "high-centrality node should win");
    }

    /// `alpha = 0` removes the centrality term entirely.
    #[test]
    fn alpha_zero_ignores_centrality() {
        let cands = vec![(1, 0.10), (2, 0.20), (3, 0.30)];
        let cent = cent_map(&[(1, 0.0), (2, 0.0), (3, 1.0)]);
        let opts = BiasOptions { alpha: 0.0, ..BiasOptions::default() };
        let r = apply_bias(cands, 1, Some(&cent), None, &opts).unwrap();
        assert_eq!(r[0].0, 1, "alpha=0 must fall back to pure distance");
    }

    /// Empty candidate lists and `k == 0` both short-circuit to an empty result.
    #[test]
    fn empty_inputs_are_safe() {
        let r = apply_bias(Vec::new(), 5, None, None, &BiasOptions::default()).unwrap();
        assert!(r.is_empty());
        let cands = vec![(1, 0.1)];
        let r = apply_bias(cands, 0, None, None, &BiasOptions::default()).unwrap();
        assert!(r.is_empty());
    }
}