use std::collections::BTreeMap;
use miette::{bail, Result};
use smartstring::{LazyCompact, SmartString};
use crate::data::expr::Expr;
use crate::data::symb::Symbol;
use crate::data::value::DataValue;
use crate::fixed_rule::{FixedRule, FixedRulePayload};
use crate::parse::SourceSpan;
use crate::runtime::db::Poison;
use crate::runtime::temp_store::RegularTempStore;
/// Fixed rule that fuses several scored lists with Reciprocal Rank Fusion (RRF).
///
/// Input relation (exactly 3 columns): `[list_id, item, score]`.
/// Output relation (2 columns): `[item, fused_score]`, one row per distinct item.
///
/// Options:
/// * `k` (float, default `60.0`): smoothing constant; values below `0.0` are raised to `0.0`.
/// * `descending` (bool, default `true`): when `true` a larger `score` ranks better,
///   when `false` a smaller `score` ranks better.
///
/// Within each list an item that occurs several times keeps only its best score. The
/// surviving items are ranked 1, 2, 3, … and each list adds `1 / (k + rank)` to the item's
/// fused score. Float scores that are NaN or infinite are rejected with an error.
pub(crate) struct ReciprocalRankFusion;

impl ReciprocalRankFusion {
    /// Collapses repeated items of a single list to their best score and yields the items
    /// in rank order (rank 1 first).
    ///
    /// `descending` picks both what "best" means and the sort direction: the highest score
    /// comes first when `true` and the lowest first when `false`. The sort is stable and
    /// starts from ascending item order, so items with equal scores are ranked by item.
    fn ranked_items(
        entries: Vec<(DataValue, DataValue)>,
        descending: bool,
    ) -> impl Iterator<Item = DataValue> {
        let mut best: BTreeMap<DataValue, DataValue> = BTreeMap::new();
        for (item, score) in entries {
            if let Some(current) = best.get_mut(&item) {
                let improves = if descending {
                    score > *current
                } else {
                    score < *current
                };
                if improves {
                    *current = score;
                }
            } else {
                best.insert(item, score);
            }
        }

        let mut scored: Vec<(DataValue, DataValue)> = best.into_iter().collect();
        // One comparator for both directions: swapping the operands flips the order.
        scored.sort_by(|lhs, rhs| {
            if descending {
                rhs.1.cmp(&lhs.1)
            } else {
                lhs.1.cmp(&rhs.1)
            }
        });
        scored.into_iter().map(|(item, _)| item)
    }
}

impl FixedRule for ReciprocalRankFusion {
    fn run(
        &self,
        payload: FixedRulePayload<'_, '_>,
        out: &mut RegularTempStore,
        poison: Poison,
    ) -> Result<()> {
        let in_rel = payload.get_input(0)?;
        // Keeping k non-negative guarantees `k + rank >= 1`, so the reciprocal is finite.
        let k = f64::max(payload.float_option("k", Some(60.0))?, 0.0);
        let descending = payload.bool_option("descending", Some(true))?;

        // Pass 1: validate each row and bucket its (item, score) pair by list id.
        let mut by_list: BTreeMap<DataValue, Vec<(DataValue, DataValue)>> = BTreeMap::new();
        for row in in_rel.iter()? {
            let row = row?;
            let arity = row.len();
            if arity != 3 {
                bail!(
                    "ReciprocalRankFusion expects a 3-column input [list_id, item, score], \
                     got a row of arity {}",
                    arity
                );
            }
            let mut cols = row.into_iter();
            let (list_id, item, score) = (
                cols.next().unwrap(),
                cols.next().unwrap(),
                cols.next().unwrap(),
            );
            if let Some(f) = score.get_float().filter(|v| !v.is_finite()) {
                bail!("ReciprocalRankFusion: score (column 3) must be finite, got {f}");
            }
            by_list.entry(list_id).or_default().push((item, score));
            poison.check()?;
        }

        // Pass 2: every list contributes 1 / (k + rank) to each item it ranks.
        let mut fused: BTreeMap<DataValue, f64> = BTreeMap::new();
        for entries in by_list.into_values() {
            let ranked = Self::ranked_items(entries, descending);
            for (rank, item) in (1usize..).zip(ranked) {
                *fused.entry(item).or_default() += 1.0 / (k + rank as f64);
            }
            poison.check()?;
        }

        // Emit one `[item, fused_score]` row per distinct item.
        for (item, total) in fused {
            out.put(vec![item, DataValue::from(total)]);
        }
        Ok(())
    }

    fn arity(
        &self,
        _options: &BTreeMap<SmartString<LazyCompact>, Expr>,
        _rule_head: &[Symbol],
        _span: SourceSpan,
    ) -> Result<usize> {
        // The output shape is always `[item, fused_score]`, whatever the options or head.
        Ok(2)
    }
}