use std::collections::HashMap;
#[derive(Debug, Clone, PartialEq)]
pub struct Hit {
pub pk: String,
pub score: f32,
}
impl Hit {
pub fn new(pk: impl Into<String>, score: f32) -> Self {
Self {
pk: pk.into(),
score,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RrfReRanker {
pub k: f32,
}
impl Default for RrfReRanker {
fn default() -> Self {
Self { k: 60.0 }
}
}
impl RrfReRanker {
pub fn new(k: f32) -> Self {
assert!(
k.is_finite() && k > 0.0,
"RRF k must be positive and finite"
);
Self { k }
}
pub fn fuse<L, I>(&self, lists: L) -> Vec<Hit>
where
L: IntoIterator<Item = I>,
I: IntoIterator<Item = Hit>,
{
let mut agg: HashMap<String, f32> = HashMap::new();
for list in lists {
let mut seen: HashMap<String, usize> = HashMap::new();
for (rank, hit) in list.into_iter().enumerate() {
let best = seen.entry(hit.pk.clone()).or_insert(rank);
if rank < *best {
*best = rank;
}
}
for (pk, rank) in seen {
let contribution = 1.0 / (self.k + rank as f32 + 1.0);
*agg.entry(pk).or_insert(0.0) += contribution;
}
}
sort_desc(agg)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Normalization {
None,
MinMax,
}
#[derive(Debug, Clone, PartialEq)]
pub struct WeightedReRanker {
pub weights: Vec<f32>,
pub normalization: Normalization,
}
impl WeightedReRanker {
pub fn new(weights: Vec<f32>) -> Self {
Self {
weights,
normalization: Normalization::MinMax,
}
}
pub fn with_normalization(mut self, n: Normalization) -> Self {
self.normalization = n;
self
}
pub fn fuse<L, I>(&self, lists: L) -> Vec<Hit>
where
L: IntoIterator<Item = I>,
I: IntoIterator<Item = Hit>,
{
let mut agg: HashMap<String, f32> = HashMap::new();
let mut list_idx = 0usize;
for list in lists {
let hits: Vec<Hit> = list.into_iter().collect();
let weight = *self.weights.get(list_idx).unwrap_or_else(|| {
panic!(
"WeightedReRanker: got more lists than weights (weights.len() = {})",
self.weights.len()
)
});
let (lo, hi) = extrema(&hits);
for h in hits {
let normalised = match self.normalization {
Normalization::None => h.score,
Normalization::MinMax => {
if hi > lo {
(h.score - lo) / (hi - lo)
} else {
1.0
}
}
};
*agg.entry(h.pk).or_insert(0.0) += weight * normalised;
}
list_idx += 1;
}
if list_idx != self.weights.len() {
panic!(
"WeightedReRanker: got {} lists but {} weights",
list_idx,
self.weights.len()
);
}
sort_desc(agg)
}
}
fn extrema(hits: &[Hit]) -> (f32, f32) {
if hits.is_empty() {
return (0.0, 0.0);
}
let mut lo = hits[0].score;
let mut hi = hits[0].score;
for h in &hits[1..] {
if h.score < lo {
lo = h.score;
}
if h.score > hi {
hi = h.score;
}
}
(lo, hi)
}
fn sort_desc(agg: HashMap<String, f32>) -> Vec<Hit> {
let mut out: Vec<Hit> = agg
.into_iter()
.map(|(pk, score)| Hit { pk, score })
.collect();
out.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.pk.cmp(&b.pk))
});
out
}
#[cfg(test)]
mod tests {
use super::*;
fn hits(pairs: &[(&str, f32)]) -> Vec<Hit> {
pairs.iter().map(|(pk, s)| Hit::new(*pk, *s)).collect()
}
#[test]
fn rrf_prefers_consensus_hits() {
let a = hits(&[("doc1", 0.9), ("doc2", 0.8), ("doc3", 0.7)]);
let b = hits(&[("doc2", 0.95), ("doc1", 0.9), ("doc4", 0.8)]);
let fused = RrfReRanker::default().fuse([a, b]);
assert_eq!(&fused[0].pk, "doc1");
assert_eq!(&fused[1].pk, "doc2");
assert!(fused[0].score > fused[2].score);
assert_eq!(fused.len(), 4);
}
#[test]
fn rrf_dedup_within_list_uses_best_rank() {
let a = hits(&[("doc1", 1.0), ("doc1", 0.5), ("doc2", 0.4)]);
let fused = RrfReRanker::new(10.0).fuse([a]);
assert_eq!(fused.len(), 2);
assert_eq!(&fused[0].pk, "doc1");
let expected = 1.0_f32 / 11.0;
assert!((fused[0].score - expected).abs() < 1e-6);
}
#[test]
fn weighted_minmax_flattens_score_scales() {
let a = hits(&[("doc1", 10.0), ("doc2", 5.0), ("doc3", 0.0)]);
let b = hits(&[("doc3", 1.0), ("doc2", 0.5), ("doc1", 0.0)]);
let fused = WeightedReRanker::new(vec![1.0, 1.0]).fuse([a, b]);
assert_eq!(fused.len(), 3);
assert!((fused[0].score - 1.0).abs() < 1e-6);
assert!((fused[2].score - 1.0).abs() < 1e-6);
let pks: Vec<_> = fused.iter().map(|h| h.pk.as_str()).collect();
assert_eq!(pks, ["doc1", "doc2", "doc3"]);
}
#[test]
fn weighted_respects_weights() {
let a = hits(&[("doc1", 1.0), ("doc2", 0.0)]);
let b = hits(&[("doc2", 1.0), ("doc1", 0.0)]);
let fused = WeightedReRanker::new(vec![1.0, 10.0])
.with_normalization(Normalization::None)
.fuse([a, b]);
assert_eq!(fused.len(), 2);
assert_eq!(&fused[0].pk, "doc2");
assert!(fused[0].score > fused[1].score);
}
#[test]
#[should_panic(expected = "got 1 lists but 2 weights")]
fn weighted_rejects_too_few_lists() {
let a = hits(&[("doc1", 1.0)]);
WeightedReRanker::new(vec![1.0, 1.0]).fuse([a]);
}
#[test]
#[should_panic(expected = "got more lists than weights")]
fn weighted_rejects_too_many_lists() {
let a = hits(&[("doc1", 1.0)]);
let b = hits(&[("doc2", 1.0)]);
WeightedReRanker::new(vec![1.0]).fuse([a, b]);
}
}