#![warn(missing_docs)]
use std::collections::HashMap;
use std::hash::Hash;
pub mod adapt;
pub mod diagnostics;
pub mod dp_topk;
pub mod metrics;
pub mod pipeline;
pub mod rerank;
pub mod validate;
#[cfg(test)]
mod proptests;
/// Errors produced by the fallible fusion entry points in this crate.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionError {
/// The supplied weights sum to (approximately) zero, so contributions
/// cannot be normalized by the weight total.
ZeroWeights,
/// A configuration value is invalid; the message describes the problem.
InvalidConfig(String),
}
impl std::fmt::Display for FusionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ZeroWeights => write!(f, "weights sum to zero"),
Self::InvalidConfig(msg) => write!(f, "invalid config: {msg}"),
}
}
}
impl std::error::Error for FusionError {}
/// Convenience alias for results returned by this crate's fallible functions.
pub type Result<T> = std::result::Result<T, FusionError>;
// Threshold below which a weight sum is treated as zero (guards division by ~0).
const WEIGHT_EPSILON: f32 = 1e-9;
// Threshold below which a score range / standard deviation is treated as degenerate.
const SCORE_RANGE_EPSILON: f32 = 1e-9;
/// Configuration for Reciprocal Rank Fusion (RRF) and ISR fusion.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RrfConfig {
    /// Rank-smoothing constant added to each rank; larger values flatten the curve.
    pub k: u32,
    /// Optional cap on how many fused results are returned.
    pub top_k: Option<usize>,
}

impl Default for RrfConfig {
    /// The conventional RRF defaults: `k = 60`, no result cap.
    fn default() -> Self {
        RrfConfig { k: 60, top_k: None }
    }
}

impl RrfConfig {
    /// Creates a config with the given `k` and no result cap.
    ///
    /// # Panics
    /// Panics if `k == 0`, since the RRF denominator would be zero at rank 0.
    #[must_use]
    pub fn new(k: u32) -> Self {
        Self::default().with_k(k)
    }

    /// Replaces `k`, keeping the rest of the config.
    ///
    /// # Panics
    /// Panics if `k == 0`, since the RRF denominator would be zero at rank 0.
    #[must_use]
    pub fn with_k(mut self, k: u32) -> Self {
        assert!(
            k >= 1,
            "k must be >= 1 to avoid division by zero in RRF formula"
        );
        self.k = k;
        self
    }

    /// Caps the number of fused results that are returned.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }
}
/// Configuration for score-weighted fusion of two ranked lists.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WeightedConfig {
    /// Weight applied to the first list's scores.
    pub weight_a: f32,
    /// Weight applied to the second list's scores.
    pub weight_b: f32,
    /// When `true`, min-max normalizes each list before combining.
    pub normalize: bool,
    /// Optional cap on how many fused results are returned.
    pub top_k: Option<usize>,
}

impl Default for WeightedConfig {
    /// Equal weights (0.5 / 0.5), normalization enabled, no result cap.
    fn default() -> Self {
        WeightedConfig {
            normalize: true,
            top_k: None,
            weight_a: 0.5,
            weight_b: 0.5,
        }
    }
}

impl WeightedConfig {
    /// Creates a config with the given weights; normalization on, no cap.
    #[must_use]
    pub const fn new(weight_a: f32, weight_b: f32) -> Self {
        Self {
            normalize: true,
            top_k: None,
            weight_a,
            weight_b,
        }
    }

    /// Replaces both weights.
    #[must_use]
    pub const fn with_weights(mut self, weight_a: f32, weight_b: f32) -> Self {
        self.weight_b = weight_b;
        self.weight_a = weight_a;
        self
    }

    /// Enables or disables per-list min-max normalization.
    #[must_use]
    pub const fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Caps the number of fused results that are returned.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }
}
/// Minimal configuration shared by fusion methods that only need a result cap.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FusionConfig {
    /// Optional cap on how many fused results are returned.
    pub top_k: Option<usize>,
}

impl FusionConfig {
    /// Caps the number of fused results that are returned.
    #[must_use]
    pub const fn with_top_k(self, top_k: usize) -> Self {
        Self {
            top_k: Some(top_k),
            ..self
        }
    }
}
/// Commonly used fusion functions, evaluation metrics, and config types
/// gathered behind a single import.
pub mod prelude {
pub use crate::{
additive_multi_task, additive_multi_task_multi, additive_multi_task_with_config, borda,
combanz, combmax, combmed, combmnz, combsum, condorcet, copeland, dbsf, isr,
isr_with_config, median_rank, rrf, rrf_with_config, standardized, standardized_multi,
standardized_with_config, weighted,
};
pub use crate::{
evaluate_metric, hit_rate, map, map_at_k, mrr, ndcg_at_k, precision_at_k, recall_at_k,
};
pub use crate::{
AdditiveMultiTaskConfig, FusionConfig, FusionError, FusionMethod, Normalization, Result,
RrfConfig, StandardizedConfig, WeightedConfig,
};
}
/// Explainable fusion: variants that record per-retriever provenance
/// alongside the fused scores.
pub mod explain {
pub use crate::{
analyze_consensus, attribute_top_k, combmnz_explain, combsum_explain, dbsf_explain,
rrf_explain, ConsensusReport, Explanation, FusedResult, RetrieverId, RetrieverStats,
SourceContribution,
};
}
// Re-export the validation helpers at the crate root for convenience.
pub use validate::{
validate, validate_bounds, validate_finite_scores, validate_no_duplicates,
validate_non_negative_scores, validate_sorted, ValidationResult,
};
/// Selects which fusion algorithm [`FusionMethod::fuse`] and
/// [`FusionMethod::fuse_multi`] dispatch to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
/// Reciprocal Rank Fusion: each appearance contributes `1 / (k + rank)`.
Rrf {
/// Rank-smoothing constant; conventionally 60.
k: u32,
},
/// Inverse Square Root fusion: each appearance contributes `1 / sqrt(k + rank)`.
Isr {
/// Rank-smoothing constant.
k: u32,
},
/// Sum of min-max normalized scores; dispatches to [`combsum`].
CombSum,
/// CombSUM scaled by per-item hit count; dispatches to [`combmnz`].
CombMnz,
/// Positional Borda count; dispatches to [`borda`].
Borda,
/// Dispatches to [`condorcet`].
Condorcet,
/// Dispatches to [`copeland`].
Copeland,
/// Dispatches to [`median_rank`].
MedianRank,
/// Dispatches to [`combmax`].
CombMax,
/// Dispatches to `combmin`.
CombMin,
/// Dispatches to [`combmed`].
CombMed,
/// Dispatches to [`combanz`].
CombAnz,
/// Rank-biased centroid; dispatches to `rbc_multi`.
Rbc {
/// Decay parameter; defaults to 0.8 via [`FusionMethod::rbc`].
persistence: f32,
},
/// Weighted two-list score combination; dispatches to [`weighted`].
Weighted {
/// Weight for the first list.
weight_a: f32,
/// Weight for the second list.
weight_b: f32,
/// Min-max normalize each list before combining.
normalize: bool,
},
/// Distribution-based score fusion (z-scores clipped at ±3); dispatches to [`dbsf`].
Dbsf,
/// Z-score fusion with a configurable clip range; dispatches to [`standardized_with_config`].
Standardized {
/// Inclusive `(min, max)` clamp applied to z-scores.
clip_range: (f32, f32),
},
/// Additive multi-task fusion; dispatches to [`additive_multi_task_with_config`].
AdditiveMultiTask {
/// Weight for the first list.
weight_a: f32,
/// Weight for the second list.
weight_b: f32,
/// Per-list normalization applied before the weighted sum.
normalization: Normalization,
},
}
impl Default for FusionMethod {
fn default() -> Self {
Self::Rrf { k: 60 }
}
}
impl FusionMethod {
/// RRF with the conventional `k = 60`.
#[must_use]
pub const fn rrf() -> Self {
Self::Rrf { k: 60 }
}
/// RRF with an explicit `k`.
#[must_use]
pub const fn rrf_with_k(k: u32) -> Self {
Self::Rrf { k }
}
/// ISR with the conventional `k = 1`.
#[must_use]
pub const fn isr() -> Self {
Self::Isr { k: 1 }
}
/// ISR with an explicit `k`.
#[must_use]
pub const fn isr_with_k(k: u32) -> Self {
Self::Isr { k }
}
/// Rank-biased centroid with the default persistence of 0.8.
#[must_use]
pub const fn rbc() -> Self {
Self::Rbc { persistence: 0.8 }
}
/// Rank-biased centroid with an explicit persistence.
#[must_use]
pub const fn rbc_with_persistence(persistence: f32) -> Self {
Self::Rbc { persistence }
}
/// Weighted fusion with per-list normalization enabled.
#[must_use]
pub const fn weighted(weight_a: f32, weight_b: f32) -> Self {
Self::Weighted {
weight_a,
weight_b,
normalize: true,
}
}
/// Standardized (z-score) fusion with an explicit clip range.
#[must_use]
pub const fn standardized(clip_range: (f32, f32)) -> Self {
Self::Standardized { clip_range }
}
/// Standardized fusion with the conventional `(-3.0, 3.0)` clip range.
#[must_use]
pub const fn standardized_default() -> Self {
Self::Standardized {
clip_range: (-3.0, 3.0),
}
}
/// Additive multi-task fusion with z-score normalization.
#[must_use]
pub const fn additive_multi_task(weight_a: f32, weight_b: f32) -> Self {
Self::AdditiveMultiTask {
weight_a,
weight_b,
normalization: Normalization::ZScore,
}
}
/// Additive multi-task fusion with an explicit normalization scheme.
#[must_use]
pub fn additive_multi_task_with_norm(
weight_a: f32,
weight_b: f32,
normalization: Normalization,
) -> Self {
Self::AdditiveMultiTask {
weight_a,
weight_b,
normalization,
}
}
/// Fuses two ranked lists with this method.
///
/// `Rrf`/`Isr` with `k == 0` return an empty vector instead of dividing
/// by zero.
#[must_use]
pub fn fuse<I: Clone + Eq + Hash>(&self, a: &[(I, f32)], b: &[(I, f32)]) -> Vec<(I, f32)> {
match self {
Self::Rrf { k } => {
// Guard before RrfConfig::new, which panics on k == 0.
if *k == 0 {
return Vec::new();
}
crate::rrf_multi(&[a, b], RrfConfig::new(*k))
}
Self::Isr { k } => {
// Same guard as above: RrfConfig::new panics on k == 0.
if *k == 0 {
return Vec::new();
}
crate::isr_multi(&[a, b], RrfConfig::new(*k))
}
Self::CombSum => crate::combsum(a, b),
Self::CombMnz => crate::combmnz(a, b),
Self::Borda => crate::borda(a, b),
Self::Condorcet => crate::condorcet(a, b),
Self::Copeland => crate::copeland(a, b),
Self::MedianRank => crate::median_rank(a, b),
Self::CombMax => crate::combmax(a, b),
Self::CombMin => crate::combmin(a, b),
Self::CombMed => crate::combmed(a, b),
Self::CombAnz => crate::combanz(a, b),
Self::Rbc { persistence } => crate::rbc_multi(&[a, b], *persistence),
Self::Weighted {
weight_a,
weight_b,
normalize,
} => crate::weighted(
a,
b,
WeightedConfig::new(*weight_a, *weight_b).with_normalize(*normalize),
),
Self::Dbsf => crate::dbsf(a, b),
Self::Standardized { clip_range } => {
crate::standardized_with_config(a, b, StandardizedConfig::new(*clip_range))
}
Self::AdditiveMultiTask {
weight_a,
weight_b,
normalization,
} => crate::additive_multi_task_with_config(
a,
b,
AdditiveMultiTaskConfig::new((*weight_a, *weight_b))
.with_normalization(*normalization),
),
}
}
/// Fuses any number of ranked lists with this method.
///
/// NOTE: for `Weighted` and `AdditiveMultiTask` with anything other than
/// exactly two lists, the configured per-list weights are ignored and
/// every list receives the uniform weight `1 / lists.len()`.
#[must_use]
pub fn fuse_multi<I, L>(&self, lists: &[L]) -> Vec<(I, f32)>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
match self {
Self::Rrf { k } => {
// Guard before RrfConfig::new, which panics on k == 0.
if *k == 0 {
return Vec::new();
}
crate::rrf_multi(lists, RrfConfig::new(*k))
}
Self::Isr { k } => {
if *k == 0 {
return Vec::new();
}
crate::isr_multi(lists, RrfConfig::new(*k))
}
Self::CombSum => crate::combsum_multi(lists, FusionConfig::default()),
Self::CombMnz => crate::combmnz_multi(lists, FusionConfig::default()),
Self::Borda => crate::borda_multi(lists, FusionConfig::default()),
Self::Condorcet => crate::condorcet_multi(lists, FusionConfig::default()),
Self::Copeland => crate::copeland_multi(lists, FusionConfig::default()),
Self::MedianRank => crate::median_rank_multi(lists, FusionConfig::default()),
Self::CombMax => crate::combmax_multi(lists, FusionConfig::default()),
Self::CombMin => crate::combmin_multi(lists, FusionConfig::default()),
Self::CombMed => crate::combmed_multi(lists, FusionConfig::default()),
Self::CombAnz => crate::combanz_multi(lists, FusionConfig::default()),
Self::Rbc { persistence } => crate::rbc_multi(lists, *persistence),
Self::Weighted { normalize, .. } => {
if lists.len() == 2 {
// Exactly two lists: honor the configured pair of weights.
self.fuse(lists[0].as_ref(), lists[1].as_ref())
} else {
// Fallback: uniform weights across all lists.
let weighted_lists: Vec<_> = lists
.iter()
.map(|l| (l.as_ref(), 1.0 / lists.len() as f32))
.collect();
crate::weighted_multi(&weighted_lists, *normalize, None).unwrap_or_default()
}
}
Self::Dbsf => crate::dbsf_multi(lists, FusionConfig::default()),
Self::Standardized { clip_range } => {
crate::standardized_multi(lists, StandardizedConfig::new(*clip_range))
}
Self::AdditiveMultiTask {
weight_a,
weight_b,
normalization,
} => {
if lists.len() == 2 {
// Exactly two lists: honor the configured pair of weights.
self.fuse(lists[0].as_ref(), lists[1].as_ref())
} else {
// Fallback: uniform weights across all lists.
let weighted_lists: Vec<_> = lists
.iter()
.map(|l| (l.as_ref(), 1.0 / lists.len() as f32))
.collect();
crate::additive_multi_task_multi(
&weighted_lists,
AdditiveMultiTaskConfig::new((*weight_a, *weight_b))
.with_normalization(*normalization),
)
}
}
}
}
}
/// Reciprocal Rank Fusion of two ranked lists with the default config
/// (`k = 60`, no result cap).
#[must_use]
pub fn rrf<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
rrf_with_config(results_a, results_b, RrfConfig::default())
}
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn rrf_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: RrfConfig,
) -> Vec<(I, f32)> {
if config.k == 0 {
return Vec::new();
}
let k = config.k as f32;
let estimated_size = results_a.len() + results_b.len();
let mut scores: HashMap<I, f32> = HashMap::with_capacity(estimated_size);
for (rank, (id, _)) in results_a.iter().enumerate() {
let contribution = 1.0 / (k + rank as f32);
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
scores.insert(id.clone(), contribution);
}
}
for (rank, (id, _)) in results_b.iter().enumerate() {
let contribution = 1.0 / (k + rank as f32);
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
scores.insert(id.clone(), contribution);
}
}
finalize(scores, config.top_k)
}
/// Reciprocal Rank Fusion over an arbitrary number of ranked lists.
///
/// Returns an empty vector for an empty list-of-lists or when
/// `config.k == 0` (the formula would divide by zero at rank 0).
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn rrf_multi<I, L>(lists: &[L], config: RrfConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() || config.k == 0 {
        return Vec::new();
    }
    let smoothing = config.k as f32;
    let capacity: usize = lists.iter().map(|l| l.as_ref().len()).sum();
    let mut accumulated: HashMap<I, f32> = HashMap::with_capacity(capacity);
    for list in lists {
        for (rank, (id, _)) in list.as_ref().iter().enumerate() {
            let gain = 1.0 / (smoothing + rank as f32);
            // get_mut-then-insert avoids cloning the id on the hot update path.
            match accumulated.get_mut(id) {
                Some(total) => *total += gain,
                None => {
                    accumulated.insert(id.clone(), gain);
                }
            }
        }
    }
    finalize(accumulated, config.top_k)
}
#[allow(clippy::cast_precision_loss)]
pub fn rrf_weighted<I, L>(lists: &[L], weights: &[f32], config: RrfConfig) -> Result<Vec<(I, f32)>>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
if lists.len() != weights.len() {
return Err(FusionError::InvalidConfig(format!(
"lists.len() ({}) != weights.len() ({}). Each list must have a corresponding weight.",
lists.len(),
weights.len()
)));
}
let weight_sum: f32 = weights.iter().sum();
if weight_sum.abs() < WEIGHT_EPSILON {
return Err(FusionError::ZeroWeights);
}
let k = config.k as f32;
let estimated_size: usize = lists.iter().map(|l| l.as_ref().len()).sum();
let mut scores: HashMap<I, f32> = HashMap::with_capacity(estimated_size);
for (list, &weight) in lists.iter().zip(weights.iter()) {
let normalized_weight = weight / weight_sum;
for (rank, (id, _)) in list.as_ref().iter().enumerate() {
let contribution = normalized_weight / (k + rank as f32);
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
scores.insert(id.clone(), contribution);
}
}
}
Ok(finalize(scores, config.top_k))
}
/// Inverse Square Root fusion of two ranked lists with the conventional
/// `k = 1` and no result cap.
#[must_use]
pub fn isr<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
isr_with_config(results_a, results_b, RrfConfig::new(1))
}
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn isr_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: RrfConfig,
) -> Vec<(I, f32)> {
if config.k == 0 {
return Vec::new();
}
let k = config.k as f32;
let estimated_size = results_a.len() + results_b.len();
let mut scores: HashMap<I, f32> = HashMap::with_capacity(estimated_size);
for (rank, (id, _)) in results_a.iter().enumerate() {
let contribution = 1.0 / (k + rank as f32).sqrt();
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
scores.insert(id.clone(), contribution);
}
}
for (rank, (id, _)) in results_b.iter().enumerate() {
let contribution = 1.0 / (k + rank as f32).sqrt();
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
scores.insert(id.clone(), contribution);
}
}
finalize(scores, config.top_k)
}
/// Inverse Square Root fusion over an arbitrary number of ranked lists.
///
/// Returns an empty vector for an empty list-of-lists or when
/// `config.k == 0`.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn isr_multi<I, L>(lists: &[L], config: RrfConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() || config.k == 0 {
        return Vec::new();
    }
    let smoothing = config.k as f32;
    let capacity: usize = lists.iter().map(|l| l.as_ref().len()).sum();
    let mut accumulated: HashMap<I, f32> = HashMap::with_capacity(capacity);
    for list in lists {
        for (rank, (id, _)) in list.as_ref().iter().enumerate() {
            let gain = 1.0 / (smoothing + rank as f32).sqrt();
            // get_mut-then-insert avoids cloning the id on the hot update path.
            match accumulated.get_mut(id) {
                Some(total) => *total += gain,
                None => {
                    accumulated.insert(id.clone(), gain);
                }
            }
        }
    }
    finalize(accumulated, config.top_k)
}
/// Weighted fusion of two ranked lists using the weights, normalization
/// flag, and result cap from `config`.
#[must_use]
pub fn weighted<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: WeightedConfig,
) -> Vec<(I, f32)> {
weighted_impl(
&[(results_a, config.weight_a), (results_b, config.weight_b)],
config.normalize,
config.top_k,
)
}
/// Weighted fusion over any number of `(list, weight)` pairs. Weights are
/// normalized by their sum; when `normalize` is true each list is min-max
/// normalized before weighting.
///
/// # Errors
///
/// Returns [`FusionError::ZeroWeights`] when the weights sum to
/// (approximately) zero.
pub fn weighted_multi<I, L>(
lists: &[(L, f32)],
normalize: bool,
top_k: Option<usize>,
) -> Result<Vec<(I, f32)>>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
let total_weight: f32 = lists.iter().map(|(_, w)| w).sum();
if total_weight.abs() < WEIGHT_EPSILON {
return Err(FusionError::ZeroWeights);
}
let estimated_size: usize = lists.iter().map(|(l, _)| l.as_ref().len()).sum();
let mut scores: HashMap<I, f32> = HashMap::with_capacity(estimated_size);
for (list, weight) in lists {
let items = list.as_ref();
// Normalize so the weights form a convex combination regardless of scale.
let w = weight / total_weight;
// (norm, off) is a multiply/subtract pair for min-max scaling; the
// identity transform is used when normalization is disabled.
let (norm, off) = if normalize {
min_max_params(items)
} else {
(1.0, 0.0)
};
for (id, s) in items {
let contribution = w * (s - off) * norm;
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
// First sighting of this id; only clone on insert.
scores.insert(id.clone(), contribution);
}
}
}
Ok(finalize(scores, top_k))
}
fn weighted_impl<I, L>(lists: &[(L, f32)], normalize: bool, top_k: Option<usize>) -> Vec<(I, f32)>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
let total_weight: f32 = lists.iter().map(|(_, w)| w).sum();
if total_weight.abs() < WEIGHT_EPSILON {
return Vec::new();
}
let estimated_size: usize = lists.iter().map(|(l, _)| l.as_ref().len()).sum();
let mut scores: HashMap<I, f32> = HashMap::with_capacity(estimated_size);
for (list, weight) in lists {
let items = list.as_ref();
let w = weight / total_weight;
let (norm, off) = if normalize {
min_max_params(items)
} else {
(1.0, 0.0)
};
for (id, s) in items {
let contribution = w * (s - off) * norm;
if let Some(score) = scores.get_mut(id) {
*score += contribution;
} else {
scores.insert(id.clone(), contribution);
}
}
}
finalize(scores, top_k)
}
/// CombSUM fusion of two ranked lists with no result cap.
#[must_use]
pub fn combsum<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
combsum_with_config(results_a, results_b, FusionConfig::default())
}
/// CombSUM fusion of two ranked lists with an explicit config.
#[must_use]
pub fn combsum_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: FusionConfig,
) -> Vec<(I, f32)> {
combsum_multi(&[results_a, results_b], config)
}
/// CombSUM over any number of lists: min-max normalize each list, then sum
/// the per-item normalized scores.
#[must_use]
pub fn combsum_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let capacity: usize = lists.iter().map(|l| l.as_ref().len()).sum();
    let mut totals: HashMap<I, f32> = HashMap::with_capacity(capacity);
    for list in lists {
        let items = list.as_ref();
        let (scale, offset) = min_max_params(items);
        for (id, raw) in items {
            let normalized = (raw - offset) * scale;
            // get_mut-then-insert avoids cloning the id on the hot update path.
            match totals.get_mut(id) {
                Some(total) => *total += normalized,
                None => {
                    totals.insert(id.clone(), normalized);
                }
            }
        }
    }
    finalize(totals, config.top_k)
}
/// CombMNZ fusion of two ranked lists with no result cap.
#[must_use]
pub fn combmnz<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
combmnz_with_config(results_a, results_b, FusionConfig::default())
}
/// CombMNZ fusion of two ranked lists with an explicit config.
#[must_use]
pub fn combmnz_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: FusionConfig,
) -> Vec<(I, f32)> {
combmnz_multi(&[results_a, results_b], config)
}
/// CombMNZ over any number of lists: the CombSUM total multiplied by the
/// number of lists containing each item.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn combmnz_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let capacity: usize = lists.iter().map(|l| l.as_ref().len()).sum();
    // Value is (normalized-score sum, number of lists the item appeared in).
    let mut totals: HashMap<I, (f32, u32)> = HashMap::with_capacity(capacity);
    for list in lists {
        let items = list.as_ref();
        let (scale, offset) = min_max_params(items);
        for (id, raw) in items {
            let normalized = (raw - offset) * scale;
            match totals.get_mut(id) {
                Some((sum, hits)) => {
                    *sum += normalized;
                    *hits += 1;
                }
                None => {
                    totals.insert(id.clone(), (normalized, 1));
                }
            }
        }
    }
    let mut fused: Vec<_> = totals
        .into_iter()
        .map(|(id, (sum, hits))| (id, sum * hits as f32))
        .collect();
    sort_scored_desc(&mut fused);
    if let Some(limit) = config.top_k {
        fused.truncate(limit);
    }
    fused
}
/// Borda-count fusion of two ranked lists with no result cap.
#[must_use]
pub fn borda<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
borda_with_config(results_a, results_b, FusionConfig::default())
}
/// Borda-count fusion of two ranked lists with an explicit config.
#[must_use]
pub fn borda_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: FusionConfig,
) -> Vec<(I, f32)> {
borda_multi(&[results_a, results_b], config)
}
/// Borda-count fusion: within each list, the item at rank `r` of an
/// `n`-item list earns `n - r` points; points are summed across lists.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn borda_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let capacity: usize = lists.iter().map(|l| l.as_ref().len()).sum();
    let mut points: HashMap<I, f32> = HashMap::with_capacity(capacity);
    for list in lists {
        let items = list.as_ref();
        let list_len = items.len() as f32;
        for (rank, (id, _)) in items.iter().enumerate() {
            let awarded = list_len - rank as f32;
            // get_mut-then-insert avoids cloning the id on the hot update path.
            match points.get_mut(id) {
                Some(total) => *total += awarded,
                None => {
                    points.insert(id.clone(), awarded);
                }
            }
        }
    }
    finalize(points, config.top_k)
}
/// Distribution-based score fusion of two lists with no result cap.
#[must_use]
pub fn dbsf<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
dbsf_with_config(results_a, results_b, FusionConfig::default())
}
/// Distribution-based score fusion of two lists with an explicit config.
#[must_use]
pub fn dbsf_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: FusionConfig,
) -> Vec<(I, f32)> {
dbsf_multi(&[results_a, results_b], config)
}
#[must_use]
pub fn dbsf_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
let std_config = StandardizedConfig {
clip_range: (-3.0, 3.0),
top_k: config.top_k,
};
standardized_multi(lists, std_config)
}
/// Returns `(mean, std)` of the scores, using population variance.
/// An empty slice yields `(0.0, 1.0)` so callers can divide safely.
#[inline(always)]
fn zscore_params<I>(results: &[(I, f32)]) -> (f32, f32) {
    if results.is_empty() {
        return (0.0, 1.0);
    }
    let count = results.len() as f32;
    let mut total = 0.0f32;
    for (_, score) in results {
        total += score;
    }
    let mean = total / count;
    let mut squared_dev = 0.0f32;
    for (_, score) in results {
        squared_dev += (score - mean).powi(2);
    }
    (mean, (squared_dev / count).sqrt())
}
/// Configuration for z-score (standardized) fusion.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StandardizedConfig {
    /// Inclusive `(min, max)` clamp applied to each z-score.
    pub clip_range: (f32, f32),
    /// Optional cap on how many fused results are returned.
    pub top_k: Option<usize>,
}

impl Default for StandardizedConfig {
    /// Same settings as [`StandardizedConfig::dbsf`]: clip at ±3, no cap.
    fn default() -> Self {
        Self::dbsf()
    }
}

impl StandardizedConfig {
    /// Creates a config with the given clip range and no result cap.
    #[must_use]
    pub const fn new(clip_range: (f32, f32)) -> Self {
        Self {
            top_k: None,
            clip_range,
        }
    }

    /// The conventional DBSF settings: clip z-scores to `(-3.0, 3.0)`.
    #[must_use]
    pub const fn dbsf() -> Self {
        Self::new((-3.0, 3.0))
    }

    /// Caps the number of fused results that are returned.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }
}
/// Z-score fusion of two lists with the default config (clip at ±3).
#[must_use]
pub fn standardized<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
standardized_with_config(results_a, results_b, StandardizedConfig::default())
}
/// Z-score fusion of two lists with an explicit config.
#[must_use]
pub fn standardized_with_config<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: StandardizedConfig,
) -> Vec<(I, f32)> {
standardized_multi(&[results_a, results_b], config)
}
/// Z-score fusion over any number of lists: standardize each list's scores,
/// clamp them to `config.clip_range`, and sum per item. A list whose
/// standard deviation is (near) zero contributes 0 for all of its items.
#[must_use]
pub fn standardized_multi<I, L>(lists: &[L], config: StandardizedConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let (clip_min, clip_max) = config.clip_range;
    let capacity: usize = lists.iter().map(|l| l.as_ref().len()).sum();
    let mut totals: HashMap<I, f32> = HashMap::with_capacity(capacity);
    for list in lists {
        let items = list.as_ref();
        let (mean, std) = zscore_params(items);
        // Degenerate (near-constant) distributions cannot be standardized.
        let usable_std = std > SCORE_RANGE_EPSILON;
        for (id, raw) in items {
            let z = if usable_std {
                ((raw - mean) / std).clamp(clip_min, clip_max)
            } else {
                0.0
            };
            match totals.get_mut(id) {
                Some(total) => *total += z,
                None => {
                    totals.insert(id.clone(), z);
                }
            }
        }
    }
    finalize(totals, config.top_k)
}
/// Configuration for additive multi-task fusion.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdditiveMultiTaskConfig {
    /// `(weight_a, weight_b)` applied to the two task lists.
    pub weights: (f32, f32),
    /// Per-list score normalization applied before the weighted sum.
    pub normalization: Normalization,
    /// Optional cap on how many fused results are returned.
    pub top_k: Option<usize>,
}

impl Default for AdditiveMultiTaskConfig {
    /// Equal unit weights, z-score normalization, no result cap.
    fn default() -> Self {
        AdditiveMultiTaskConfig {
            normalization: Normalization::ZScore,
            top_k: None,
            weights: (1.0, 1.0),
        }
    }
}

impl AdditiveMultiTaskConfig {
    /// Creates a config with the given weights; z-score normalization, no cap.
    #[must_use]
    pub const fn new(weights: (f32, f32)) -> Self {
        Self {
            normalization: Normalization::ZScore,
            top_k: None,
            weights,
        }
    }

    /// Replaces the per-list normalization scheme.
    #[must_use]
    pub const fn with_normalization(mut self, normalization: Normalization) -> Self {
        self.normalization = normalization;
        self
    }

    /// Caps the number of fused results that are returned.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }
}
/// Additive multi-task fusion of two lists. Currently forwards directly to
/// [`additive_multi_task_with_config`]; both take the full config.
#[must_use]
pub fn additive_multi_task<I: Clone + Eq + Hash>(
results_a: &[(I, f32)],
results_b: &[(I, f32)],
config: AdditiveMultiTaskConfig,
) -> Vec<(I, f32)> {
additive_multi_task_with_config(results_a, results_b, config)
}
/// Additive multi-task fusion of two lists using the weights and
/// normalization from `config`.
#[must_use]
pub fn additive_multi_task_with_config<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
    config: AdditiveMultiTaskConfig,
) -> Vec<(I, f32)> {
    // A stack array suffices here; the previous `vec!` heap-allocated for
    // exactly two elements (clippy::useless_vec).
    let weighted_lists = [(results_a, config.weights.0), (results_b, config.weights.1)];
    additive_multi_task_multi(&weighted_lists, config)
}
/// Additive multi-task fusion over `(list, weight)` pairs: each list is
/// normalized with `config.normalization`, then the per-item normalized
/// scores are combined as a weighted sum.
///
/// Note: `config.weights` is not read here; the per-list weights come from
/// `weighted_lists` themselves.
#[must_use]
pub fn additive_multi_task_multi<I, L>(
weighted_lists: &[(L, f32)],
config: AdditiveMultiTaskConfig,
) -> Vec<(I, f32)>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
if weighted_lists.is_empty() {
return Vec::new();
}
// Normalize every list up front; `normalized` keeps the same order as
// `weighted_lists` so the zip below pairs each list with its weight.
let normalized: Vec<_> = weighted_lists
.iter()
.map(|(list, _)| normalize_scores(list.as_ref(), config.normalization))
.collect();
let estimated_size: usize = normalized.iter().map(|n| n.len()).sum();
let mut scores: HashMap<I, f32> = HashMap::with_capacity(estimated_size);
for (normalized_list, (_, weight)) in normalized.iter().zip(weighted_lists.iter()) {
for (id, norm_score) in normalized_list {
if let Some(score) = scores.get_mut(id) {
*score += weight * norm_score;
} else {
scores.insert(id.clone(), weight * norm_score);
}
}
}
finalize(scores, config.top_k)
}
/// Sorts accumulated scores descending (NaN-safe via IEEE total ordering)
/// and applies the optional result cap.
#[inline]
fn finalize<I>(scores: HashMap<I, f32>, top_k: Option<usize>) -> Vec<(I, f32)> {
    // All entries must be collected before sorting, so the buffer must hold
    // `scores.len()` elements; the old `top_k`-clamped capacity just forced a
    // reallocation inside `extend` whenever `top_k < scores.len()`.
    let mut results: Vec<(I, f32)> = scores.into_iter().collect();
    sort_scored_desc(&mut results);
    if let Some(k) = top_k {
        results.truncate(k);
    }
    results
}
/// Sorts `(item, score)` pairs in place by descending score, using the IEEE
/// total order so NaN values cannot cause a panic.
#[inline]
fn sort_scored_desc<I>(results: &mut [(I, f32)]) {
    results.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
}
/// Score-normalization schemes used by [`normalize_scores`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Normalization {
/// Linear rescaling of scores into `[0, 1]`.
#[default]
MinMax,
/// Standard scores clamped to `[-3, 3]`; degenerate lists map to 0.
ZScore,
/// Divides each score by the sum of all scores.
Sum,
/// Positional scheme derived from descending-score order: `1 - rank / n`.
Rank,
/// Maps scores to their empirical quantile in `[0, 1]`.
Quantile,
/// Logistic squashing: `1 / (1 + e^-s)`.
Sigmoid,
/// Leaves scores untouched.
None,
}
/// Normalizes a list's scores with the chosen [`Normalization`] scheme.
/// Most schemes preserve input order; `Rank` returns items sorted by
/// descending score. An empty input yields an empty vector.
pub fn normalize_scores<I: Clone>(results: &[(I, f32)], method: Normalization) -> Vec<(I, f32)> {
if results.is_empty() {
return Vec::new();
}
match method {
Normalization::MinMax => {
let (norm, off) = min_max_params(results);
results
.iter()
.map(|(id, s)| (id.clone(), (s - off) * norm))
.collect()
}
Normalization::ZScore => {
let (mean, std) = zscore_params(results);
results
.iter()
.map(|(id, s)| {
// Degenerate (near-constant) lists cannot be standardized; emit 0.
let z = if std > SCORE_RANGE_EPSILON {
((s - mean) / std).clamp(-3.0, 3.0)
} else {
0.0
};
(id.clone(), z)
})
.collect()
}
Normalization::Sum => {
let sum: f32 = results.iter().map(|(_, s)| s).sum();
// A (near-)zero sum would blow up the division; return scores unchanged.
if sum.abs() < SCORE_RANGE_EPSILON {
return results.to_vec();
}
results
.iter()
.map(|(id, s)| (id.clone(), s / sum))
.collect()
}
Normalization::Rank => {
let mut sorted: Vec<_> = results.to_vec();
// Sort descending by score; non-finite scores (NaN/inf comparisons
// aside, anything !is_finite) sink to the end of the ordering.
sorted.sort_by(|a, b| {
match (a.1.is_finite(), b.1.is_finite()) {
(true, true) => b.1.total_cmp(&a.1), (true, false) => std::cmp::Ordering::Less, (false, true) => std::cmp::Ordering::Greater, (false, false) => std::cmp::Ordering::Equal, }
});
let n = sorted.len() as f32;
sorted
.iter()
.enumerate()
.map(|(rank, (id, _))| (id.clone(), 1.0 - (rank as f32 / n)))
.collect()
}
Normalization::Quantile => {
// Pair each score with its original index so quantiles can be written
// back in input order after sorting by score.
let mut indexed: Vec<(usize, f32)> = results
.iter()
.enumerate()
.map(|(i, (_, s))| (i, *s))
.collect();
indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let n = indexed.len();
let mut quantiles = vec![0.0f32; n];
if n == 1 {
// A single element has no spread; place it at the median.
quantiles[0] = 0.5; } else {
for (rank, &(orig_idx, _)) in indexed.iter().enumerate() {
quantiles[orig_idx] = rank as f32 / (n - 1) as f32;
}
}
results
.iter()
.enumerate()
.map(|(i, (id, _))| (id.clone(), quantiles[i]))
.collect()
}
Normalization::Sigmoid => results
.iter()
.map(|(id, s)| (id.clone(), 1.0 / (1.0 + (-s).exp())))
.collect(),
Normalization::None => results.to_vec(),
}
}
/// Returns `(scale, offset)` such that `(score - offset) * scale` min-max
/// normalizes scores into `[0, 1]`. Empty or (near-)constant inputs yield
/// the identity transform `(1.0, 0.0)`.
#[inline(always)]
fn min_max_params<I>(results: &[(I, f32)]) -> (f32, f32) {
    if results.is_empty() {
        return (1.0, 0.0);
    }
    let mut lo = f32::INFINITY;
    let mut hi = f32::NEG_INFINITY;
    for (_, score) in results {
        lo = lo.min(*score);
        hi = hi.max(*score);
    }
    if hi - lo < SCORE_RANGE_EPSILON {
        (1.0, 0.0)
    } else {
        (1.0 / (hi - lo), lo)
    }
}
/// A fused item along with the provenance explaining its score.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult<K> {
/// The fused item's identifier.
pub id: K,
/// Final fused score.
pub score: f32,
/// Zero-based position after sorting by descending score.
pub rank: usize,
/// Per-retriever breakdown of how the score was assembled.
pub explanation: Explanation,
}
/// Describes how a fused score was produced.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Explanation {
/// One entry per retriever that returned the item.
pub sources: Vec<SourceContribution>,
/// Name of the fusion method, e.g. "rrf" or "combsum".
pub method: &'static str,
/// Fraction of retrievers that returned the item (0.0..=1.0).
pub consensus_score: f32,
}
/// One retriever's contribution to a fused score.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SourceContribution {
/// Identifier of the contributing retriever.
pub retriever_id: String,
/// Zero-based rank of the item in that retriever's list, if ranked.
pub original_rank: Option<usize>,
/// Raw score from that retriever, if available.
pub original_score: Option<f32>,
/// Score after per-list normalization; `None` for rank-only methods.
pub normalized_score: Option<f32>,
/// Amount this source added to the fused score.
pub contribution: f32,
}
/// Identifier naming a retriever (source list) in explanation output.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RetrieverId {
    id: String,
}

impl RetrieverId {
    /// Creates an identifier from anything convertible to a `String`.
    pub fn new<S: Into<String>>(id: S) -> Self {
        RetrieverId { id: id.into() }
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.id.as_str()
    }
}

impl From<&str> for RetrieverId {
    fn from(id: &str) -> Self {
        RetrieverId::new(id)
    }
}

impl From<String> for RetrieverId {
    fn from(id: String) -> Self {
        RetrieverId::new(id)
    }
}
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn rrf_explain<I, L>(
lists: &[L],
retriever_ids: &[RetrieverId],
config: RrfConfig,
) -> Vec<FusedResult<I>>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
if lists.is_empty() || lists.len() != retriever_ids.len() {
return Vec::new();
}
let k = config.k as f32;
let num_retrievers = lists.len() as f32;
let mut scores: HashMap<I, f32> = HashMap::new();
let mut provenance: HashMap<I, Vec<SourceContribution>> = HashMap::new();
for (list, retriever_id) in lists.iter().zip(retriever_ids.iter()) {
for (rank, (id, original_score)) in list.as_ref().iter().enumerate() {
let contribution = 1.0 / (k + rank as f32);
*scores.entry(id.clone()).or_insert(0.0) += contribution;
provenance
.entry(id.clone())
.or_default()
.push(SourceContribution {
retriever_id: retriever_id.id.clone(),
original_rank: Some(rank),
original_score: Some(*original_score),
normalized_score: None, contribution,
});
}
}
let mut results: Vec<FusedResult<I>> = scores
.into_iter()
.map(|(id, score)| {
let sources = provenance.remove(&id).unwrap_or_default();
let consensus_score = sources.len() as f32 / num_retrievers;
FusedResult {
id,
score,
rank: 0, explanation: Explanation {
sources,
method: "rrf",
consensus_score,
},
}
})
.collect();
results.sort_by(|a, b| b.score.total_cmp(&a.score));
for (rank, result) in results.iter_mut().enumerate() {
result.rank = rank;
}
if let Some(top_k) = config.top_k {
results.truncate(top_k);
}
results
}
/// Buckets of fused results grouped by cross-retriever agreement.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConsensusReport<K> {
/// Items whose consensus score is (approximately) 1.0, i.e. returned by
/// every retriever.
pub high_consensus: Vec<K>,
/// Items returned by exactly one retriever.
pub single_source: Vec<K>,
/// Items whose per-retriever ranks differ by more than 10 positions,
/// with the `(retriever, rank)` pairs that disagree.
pub rank_disagreement: Vec<(K, Vec<(String, usize)>)>,
}
/// Splits fused results into consensus buckets: full agreement, single
/// source, and large rank disagreement (spread greater than 10 positions).
/// Items retain the order in which they appear in `results`.
pub fn analyze_consensus<K: Clone + Eq + Hash>(results: &[FusedResult<K>]) -> ConsensusReport<K> {
    let mut report = ConsensusReport {
        high_consensus: Vec::new(),
        single_source: Vec::new(),
        rank_disagreement: Vec::new(),
    };
    for result in results {
        let sources = &result.explanation.sources;
        // Tolerance absorbs float error in consensus_score.
        if result.explanation.consensus_score >= 1.0 - 1e-6 {
            report.high_consensus.push(result.id.clone());
        }
        if sources.len() == 1 {
            report.single_source.push(result.id.clone());
        }
        if sources.len() > 1 {
            let ranks: Vec<usize> = sources.iter().filter_map(|s| s.original_rank).collect();
            if let (Some(&lo), Some(&hi)) = (ranks.iter().min(), ranks.iter().max()) {
                if hi - lo > 10 {
                    let rank_info: Vec<(String, usize)> = sources
                        .iter()
                        .filter_map(|s| s.original_rank.map(|r| (s.retriever_id.clone(), r)))
                        .collect();
                    report.rank_disagreement.push((result.id.clone(), rank_info));
                }
            }
        }
    }
    report
}
/// Aggregate statistics for one retriever over the top-k fused results.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RetrieverStats {
/// How many of the top-k results this retriever contributed to.
pub top_k_count: usize,
/// Mean contribution per top-k result it appeared in.
pub avg_contribution: f32,
/// Number of top-k documents only this retriever returned.
pub unique_docs: usize,
}
/// Computes per-retriever statistics over the first `k` fused results:
/// appearance counts, mean contribution, and how many documents each
/// retriever contributed exclusively.
pub fn attribute_top_k<K: Clone + Eq + Hash>(
results: &[FusedResult<K>],
k: usize,
) -> std::collections::HashMap<String, RetrieverStats> {
let top_k = results.iter().take(k);
let mut stats: std::collections::HashMap<String, RetrieverStats> =
std::collections::HashMap::new();
// Tracks which documents each retriever surfaced, for the uniqueness
// counts computed in the second pass.
let mut retriever_docs: std::collections::HashMap<String, std::collections::HashSet<K>> =
std::collections::HashMap::new();
for result in top_k {
for source in &result.explanation.sources {
let entry =
stats
.entry(source.retriever_id.clone())
.or_insert_with(|| RetrieverStats {
top_k_count: 0,
avg_contribution: 0.0,
unique_docs: 0,
});
entry.top_k_count += 1;
// Accumulate the sum here; it is divided into a mean below.
entry.avg_contribution += source.contribution;
retriever_docs
.entry(source.retriever_id.clone())
.or_default()
.insert(result.id.clone());
}
}
for (retriever_id, stat) in &mut stats {
if stat.top_k_count > 0 {
stat.avg_contribution /= stat.top_k_count as f32;
}
let this_retriever_docs = retriever_docs
.get(retriever_id)
.cloned()
.unwrap_or_default();
// Union of every other retriever's documents; anything outside it is
// unique to this retriever.
let other_retriever_docs: std::collections::HashSet<K> = retriever_docs
.iter()
.filter(|(id, _)| *id != retriever_id)
.flat_map(|(_, docs)| docs.iter().cloned())
.collect();
stat.unique_docs = this_retriever_docs
.difference(&other_retriever_docs)
.count();
}
stats
}
/// CombSUM fusion that also records per-retriever provenance. Returns an
/// empty vector when `lists` is empty or its length differs from
/// `retriever_ids`.
#[must_use]
pub fn combsum_explain<I, L>(
lists: &[L],
retriever_ids: &[RetrieverId],
config: FusionConfig,
) -> Vec<FusedResult<I>>
where
I: Clone + Eq + Hash,
L: AsRef<[(I, f32)]>,
{
if lists.is_empty() || lists.len() != retriever_ids.len() {
return Vec::new();
}
let num_retrievers = lists.len() as f32;
let mut scores: HashMap<I, f32> = HashMap::new();
let mut provenance: HashMap<I, Vec<SourceContribution>> = HashMap::new();
for (list, retriever_id) in lists.iter().zip(retriever_ids.iter()) {
let items = list.as_ref();
let (norm, off) = min_max_params(items);
for (rank, (id, original_score)) in items.iter().enumerate() {
let normalized_score = (original_score - off) * norm;
// For CombSUM the contribution is exactly the normalized score.
let contribution = normalized_score;
*scores.entry(id.clone()).or_insert(0.0) += contribution;
provenance
.entry(id.clone())
.or_default()
.push(SourceContribution {
retriever_id: retriever_id.id.clone(),
original_rank: Some(rank),
original_score: Some(*original_score),
normalized_score: Some(normalized_score),
contribution,
});
}
}
build_explained_results(scores, provenance, num_retrievers, "combsum", config.top_k)
}
/// CombMNZ fusion with per-document provenance.
///
/// Like CombSUM (min-max normalized scores summed across lists), but the sum
/// is then multiplied by the number of entries that mentioned the document,
/// rewarding cross-retriever consensus. Note the multiplier counts entries,
/// so a duplicate id within a single list also increases it.
/// Returns an empty vector when `lists` is empty or its length does not
/// match `retriever_ids`.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn combmnz_explain<I, L>(
    lists: &[L],
    retriever_ids: &[RetrieverId],
    config: FusionConfig,
) -> Vec<FusedResult<I>>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() || lists.len() != retriever_ids.len() {
        return Vec::new();
    }
    let num_retrievers = lists.len() as f32;
    // Per-document accumulator: (normalized score sum, occurrence count).
    let mut scores: HashMap<I, (f32, u32)> = HashMap::new();
    let mut provenance: HashMap<I, Vec<SourceContribution>> = HashMap::new();
    for (list, retriever_id) in lists.iter().zip(retriever_ids.iter()) {
        let items = list.as_ref();
        // Min-max normalize each list so different score scales are comparable.
        let (norm, off) = min_max_params(items);
        for (rank, (id, original_score)) in items.iter().enumerate() {
            let normalized_score = (original_score - off) * norm;
            let contribution = normalized_score;
            let entry = scores.entry(id.clone()).or_insert((0.0, 0));
            entry.0 += contribution;
            entry.1 += 1;
            provenance
                .entry(id.clone())
                .or_default()
                .push(SourceContribution {
                    retriever_id: retriever_id.id.clone(),
                    original_rank: Some(rank),
                    original_score: Some(*original_score),
                    normalized_score: Some(normalized_score),
                    contribution,
                });
        }
    }
    // Second pass: apply the MNZ overlap multiplier to both the final score
    // and each recorded contribution, so contributions still sum to the score.
    let mut final_scores: HashMap<I, f32> = HashMap::new();
    let mut final_provenance: HashMap<I, Vec<SourceContribution>> = HashMap::new();
    for (id, (sum, overlap_count)) in scores {
        let final_score = sum * overlap_count as f32;
        final_scores.insert(id.clone(), final_score);
        if let Some(mut sources) = provenance.remove(&id) {
            for source in &mut sources {
                source.contribution *= overlap_count as f32;
            }
            final_provenance.insert(id, sources);
        }
    }
    build_explained_results(
        final_scores,
        final_provenance,
        num_retrievers,
        "combmnz",
        config.top_k,
    )
}
/// Distribution-based score fusion (DBSF) with per-document provenance.
///
/// Scores in each list are z-score normalized and clamped to ±3 standard
/// deviations before being summed, making the fusion robust to lists with
/// very different score distributions. Zero-variance lists contribute 0.
///
/// Returns an empty vector when `lists` is empty or its length does not
/// match `retriever_ids`.
#[must_use]
#[allow(clippy::cast_precision_loss)] // retriever counts are tiny; consistent with combmnz_explain
pub fn dbsf_explain<I, L>(
    lists: &[L],
    retriever_ids: &[RetrieverId],
    config: FusionConfig,
) -> Vec<FusedResult<I>>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() || lists.len() != retriever_ids.len() {
        return Vec::new();
    }
    let num_retrievers = lists.len() as f32;
    let mut scores: HashMap<I, f32> = HashMap::new();
    let mut provenance: HashMap<I, Vec<SourceContribution>> = HashMap::new();
    for (list, retriever_id) in lists.iter().zip(retriever_ids.iter()) {
        let items = list.as_ref();
        let (mean, std) = zscore_params(items);
        for (rank, (id, original_score)) in items.iter().enumerate() {
            // Degenerate (zero-variance) lists contribute 0 for every item;
            // otherwise clamp outliers to ±3 standard deviations.
            let z = if std > SCORE_RANGE_EPSILON {
                ((original_score - mean) / std).clamp(-3.0, 3.0)
            } else {
                0.0
            };
            *scores.entry(id.clone()).or_insert(0.0) += z;
            provenance
                .entry(id.clone())
                .or_default()
                .push(SourceContribution {
                    retriever_id: retriever_id.id.clone(),
                    original_rank: Some(rank),
                    original_score: Some(*original_score),
                    normalized_score: Some(z),
                    contribution: z,
                });
        }
    }
    build_explained_results(scores, provenance, num_retrievers, "dbsf", config.top_k)
}
/// Converts accumulated scores and provenance into ranked `FusedResult`s.
///
/// Sorts by fused score (descending, NaN-total order), assigns dense ranks,
/// and truncates to `top_k` when requested. `consensus_score` is the fraction
/// of retrievers that contributed at least one source for the document.
fn build_explained_results<I: Clone + Eq + Hash>(
    scores: HashMap<I, f32>,
    mut provenance: HashMap<I, Vec<SourceContribution>>,
    num_retrievers: f32,
    method: &'static str,
    top_k: Option<usize>,
) -> Vec<FusedResult<I>> {
    let mut results: Vec<FusedResult<I>> = Vec::with_capacity(scores.len());
    for (id, score) in scores {
        let sources = provenance.remove(&id).unwrap_or_default();
        // Fraction of input lists that mentioned this document.
        let consensus_score = sources.len() as f32 / num_retrievers;
        results.push(FusedResult {
            id,
            score,
            rank: 0, // placeholder; assigned after sorting below
            explanation: Explanation {
                sources,
                method,
                consensus_score,
            },
        });
    }
    results.sort_by(|a, b| b.score.total_cmp(&a.score));
    for (rank, result) in results.iter_mut().enumerate() {
        result.rank = rank;
    }
    if let Some(k) = top_k {
        results.truncate(k);
    }
    results
}
/// CombMAX fusion of two lists: each document keeps its maximum score.
#[must_use]
pub fn combmax<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    combmax_multi(&lists, FusionConfig::default())
}
/// CombMAX fusion across many lists: each document's fused score is the
/// maximum score any list gave it.
#[must_use]
pub fn combmax_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let mut best: HashMap<I, f32> = HashMap::new();
    for list in lists {
        for (id, score) in list.as_ref() {
            if let Some(current) = best.get_mut(id) {
                // f32::max returns the non-NaN operand when one side is NaN.
                *current = current.max(*score);
            } else {
                best.insert(id.clone(), *score);
            }
        }
    }
    finalize(best, config.top_k)
}
/// CombMIN fusion of two lists: each document keeps its minimum score.
#[must_use]
pub fn combmin<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    combmin_multi(&lists, FusionConfig::default())
}
/// CombMIN fusion across many lists: each document's fused score is the
/// minimum score any list gave it.
#[must_use]
pub fn combmin_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let mut worst: HashMap<I, f32> = HashMap::new();
    for list in lists {
        for (id, score) in list.as_ref() {
            if let Some(current) = worst.get_mut(id) {
                // f32::min returns the non-NaN operand when one side is NaN.
                *current = current.min(*score);
            } else {
                worst.insert(id.clone(), *score);
            }
        }
    }
    finalize(worst, config.top_k)
}
/// CombMED fusion of two lists: each document scores the median of its scores.
#[must_use]
pub fn combmed<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    combmed_multi(&lists, FusionConfig::default())
}
/// CombMED fusion across many lists: each document's fused score is the
/// median of the scores it received (only lists that returned it count).
#[must_use]
pub fn combmed_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    // Gather every score observed for each document.
    let mut per_doc: HashMap<I, Vec<f32>> = HashMap::new();
    for list in lists {
        for (id, score) in list.as_ref() {
            per_doc.entry(id.clone()).or_default().push(*score);
        }
    }
    let scores: HashMap<I, f32> = per_doc
        .into_iter()
        .map(|(id, mut observed)| {
            observed.sort_by(f32::total_cmp);
            let mid = observed.len() / 2;
            // Even count: mean of the two central values; odd: the middle one.
            let median = if observed.len() % 2 == 0 {
                (observed[mid - 1] + observed[mid]) / 2.0
            } else {
                observed[mid]
            };
            (id, median)
        })
        .collect();
    finalize(scores, config.top_k)
}
/// CombANZ fusion of two lists: mean score over the lists that returned each
/// document.
#[must_use]
pub fn combanz<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    combanz_multi(&lists, FusionConfig::default())
}
/// CombANZ fusion across many lists: each document's fused score is its score
/// sum divided by the number of entries that mentioned it.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn combanz_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    // Per-document accumulator: (score sum, mention count).
    let mut accum: HashMap<I, (f32, usize)> = HashMap::new();
    for list in lists {
        for (id, score) in list.as_ref() {
            let slot = accum.entry(id.clone()).or_insert((0.0, 0));
            slot.0 += score;
            slot.1 += 1;
        }
    }
    let mut fused: Vec<(I, f32)> = accum
        .into_iter()
        .map(|(id, (total, mentions))| {
            debug_assert!(mentions > 0, "Count should always be > 0 for CombANZ");
            (id, total / mentions as f32)
        })
        .collect();
    sort_scored_desc(&mut fused);
    if let Some(limit) = config.top_k {
        fused.truncate(limit);
    }
    fused
}
/// Rank-biased fusion of two lists using the default persistence of 0.8.
#[must_use]
pub fn rbc<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    rbc_multi(&lists, 0.8)
}
/// Rank-biased fusion across many lists.
///
/// A document at 0-based rank `r` in a list of length `n` contributes the
/// geometrically decaying weight `(1 - p) * p^r / (1 - p^n)`, so every
/// non-degenerate list contributes a total mass of 1 regardless of length
/// (the denominator is the closed form of the finite geometric sum).
/// `persistence` is clamped to `[0, 1]`; at `p == 1` the normalizer is 0 and
/// the list contributes nothing.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn rbc_multi<I, L>(lists: &[L], persistence: f32) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    let p = persistence.clamp(0.0, 1.0);
    let mut scores: HashMap<I, f32> = HashMap::new();
    for list in lists {
        let items = list.as_ref();
        // Normalizer: sum_{r=0}^{n-1} (1 - p) * p^r == 1 - p^n.
        let denominator = 1.0 - p.powi(items.len() as i32);
        for (rank, (id, _)) in items.iter().enumerate() {
            // Fixed: was `(1.0 - p).powi(rank)`, which decays by (1 - p) per
            // rank and does not match the `1 - p^n` normalizer above; the
            // rank-biased weight decays geometrically in `p`.
            let numerator = (1.0 - p) * p.powi(rank as i32);
            let contribution = if denominator > 1e-9 {
                numerator / denominator
            } else {
                0.0
            };
            *scores.entry(id.clone()).or_insert(0.0) += contribution;
        }
    }
    finalize(scores, None)
}
/// Condorcet fusion of two ranked lists.
#[must_use]
pub fn condorcet<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    condorcet_multi(&lists, FusionConfig::default())
}
/// Condorcet fusion across many ranked lists.
///
/// Each document's fused score is the number of pairwise contests it wins:
/// doc A beats doc B when a strict majority of lists prefer A (better rank,
/// or A present while B is absent). O(d² · lists) over distinct documents,
/// so best suited to modest candidate pools.
#[must_use]
pub fn condorcet_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    // Per-document rank in each list; None when the list omits the document.
    let mut doc_ranks: HashMap<I, Vec<Option<usize>>> = HashMap::new();
    let mut all_docs: std::collections::HashSet<I> = std::collections::HashSet::new();
    for list in lists {
        let items = list.as_ref();
        for doc_id in items.iter().map(|(id, _)| id) {
            all_docs.insert(doc_id.clone());
        }
    }
    for doc_id in &all_docs {
        doc_ranks.insert(doc_id.clone(), vec![None; lists.len()]);
    }
    for (list_idx, list) in lists.iter().enumerate() {
        for (rank, (id, _)) in list.as_ref().iter().enumerate() {
            if let Some(ranks) = doc_ranks.get_mut(id) {
                ranks[list_idx] = Some(rank);
            }
        }
    }
    let mut scores: HashMap<I, f32> = HashMap::new();
    let doc_vec: Vec<I> = all_docs.into_iter().collect();
    for (i, d1) in doc_vec.iter().enumerate() {
        let mut wins = 0;
        for (j, d2) in doc_vec.iter().enumerate() {
            if i == j {
                continue;
            }
            let d1_ranks = &doc_ranks[d1];
            let d2_ranks = &doc_ranks[d2];
            let mut d1_wins = 0;
            // Count lists preferring d1 over d2: a smaller rank wins, presence
            // beats absence; ties and double-absence count for neither side.
            for (r1, r2) in d1_ranks.iter().zip(d2_ranks.iter()) {
                match (r1, r2) {
                    (Some(rank1), Some(rank2)) if rank1 < rank2 => d1_wins += 1,
                    (Some(_), None) => d1_wins += 1, _ => {}
                }
            }
            // A pairwise win requires a strict majority of the lists.
            if d1_wins > lists.len() / 2 {
                wins += 1;
            }
        }
        scores.insert(d1.clone(), wins as f32);
    }
    finalize(scores, config.top_k)
}
/// Copeland fusion of two ranked lists.
#[must_use]
pub fn copeland<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    copeland_multi(&lists, FusionConfig::default())
}
/// Copeland fusion across many ranked lists.
///
/// Each document's fused score is its net pairwise record: +1 for every other
/// document it beats in more lists than it loses, -1 for every document that
/// beats it. O(d² · lists) over distinct documents.
#[must_use]
pub fn copeland_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    // rank_table[doc][list] = Some(rank) when the list returned the doc.
    let mut rank_table: HashMap<I, Vec<Option<usize>>> = HashMap::new();
    let mut universe: std::collections::HashSet<I> = std::collections::HashSet::new();
    for list in lists {
        for (id, _) in list.as_ref() {
            universe.insert(id.clone());
        }
    }
    for id in &universe {
        rank_table.insert(id.clone(), vec![None; lists.len()]);
    }
    for (list_idx, list) in lists.iter().enumerate() {
        for (rank, (id, _)) in list.as_ref().iter().enumerate() {
            if let Some(row) = rank_table.get_mut(id) {
                row[list_idx] = Some(rank);
            }
        }
    }
    let docs: Vec<I> = universe.into_iter().collect();
    let mut scores: HashMap<I, f32> = HashMap::new();
    for (i, doc_a) in docs.iter().enumerate() {
        let mut net = 0i32;
        for (j, doc_b) in docs.iter().enumerate() {
            if i == j {
                continue;
            }
            let mut a_votes = 0;
            let mut b_votes = 0;
            for (ra, rb) in rank_table[doc_a].iter().zip(rank_table[doc_b].iter()) {
                match (ra, rb) {
                    // Both present: the smaller (better) rank takes the vote.
                    (Some(x), Some(y)) if x < y => a_votes += 1,
                    (Some(x), Some(y)) if y < x => b_votes += 1,
                    (Some(_), Some(_)) => {}
                    // Appearing in a list beats being absent from it.
                    (Some(_), None) => a_votes += 1,
                    (None, Some(_)) => b_votes += 1,
                    (None, None) => {}
                }
            }
            if a_votes > b_votes {
                net += 1;
            } else if b_votes > a_votes {
                net -= 1;
            }
        }
        scores.insert(doc_a.clone(), net as f32);
    }
    finalize(scores, config.top_k)
}
/// Median-rank fusion of two ranked lists.
#[must_use]
pub fn median_rank<I: Clone + Eq + Hash>(
    results_a: &[(I, f32)],
    results_b: &[(I, f32)],
) -> Vec<(I, f32)> {
    let lists = [results_a, results_b];
    median_rank_multi(&lists, FusionConfig::default())
}
/// Median-rank fusion across many ranked lists.
///
/// Each document is scored `1 / (1 + median_rank)`, where lists that omit the
/// document are assigned a penalty rank one past the longest list.
#[must_use]
pub fn median_rank_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    L: AsRef<[(I, f32)]>,
{
    if lists.is_empty() {
        return Vec::new();
    }
    // Missing documents are penalized with a rank beyond any real one.
    let longest = lists.iter().map(|l| l.as_ref().len()).max().unwrap_or(0);
    let penalty_rank = longest + 1;
    let mut observed: HashMap<I, Vec<usize>> = HashMap::new();
    for list in lists {
        for (rank, (id, _)) in list.as_ref().iter().enumerate() {
            observed.entry(id.clone()).or_default().push(rank);
        }
    }
    let mut scores: HashMap<I, f32> = HashMap::new();
    for (id, mut ranks) in observed {
        // Pad (never truncate — duplicate ids can exceed lists.len()) so
        // absent lists contribute the penalty rank to the median.
        let missing = lists.len().saturating_sub(ranks.len());
        ranks.extend(std::iter::repeat(penalty_rank).take(missing));
        ranks.sort_unstable();
        let mid = ranks.len() / 2;
        let median = if ranks.len() % 2 == 1 {
            ranks[mid] as f32
        } else {
            (ranks[mid - 1] + ranks[mid]) as f32 / 2.0
        };
        // Lower median rank maps to a higher fused score.
        scores.insert(id, 1.0 / (1.0 + median));
    }
    finalize(scores, config.top_k)
}
/// Configuration for Maximal Marginal Relevance (MMR) re-ranking.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MmrConfig {
    /// Relevance/diversity trade-off in `[0, 1]`: 1.0 = pure relevance,
    /// 0.0 = pure diversity.
    pub lambda: f32,
    /// Number of items to select.
    pub top_k: usize,
}
impl Default for MmrConfig {
    /// Defaults: relevance-leaning `lambda = 0.7`, selecting 10 items.
    fn default() -> Self {
        Self { lambda: 0.7, top_k: 10 }
    }
}
impl MmrConfig {
    /// Creates a config with the given relevance/diversity trade-off and the
    /// default `top_k` of 10.
    ///
    /// # Panics
    /// Panics if `lambda` is outside `[0.0, 1.0]`.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&lambda),
            "lambda must be in [0.0, 1.0], got {lambda}"
        );
        Self { lambda, top_k: 10 }
    }
    /// Sets how many items to select.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = top_k;
        self
    }
}
/// Maximal Marginal Relevance (MMR) re-ranking.
///
/// Greedily selects up to `config.top_k` candidates, each round maximizing
/// `lambda * relevance - (1 - lambda) * max_similarity_to_selected`.
/// Relevance scores are min-max normalized to `[0, 1]` first (all-equal
/// scores collapse to 1.0). `similarity` should return larger values for more
/// similar items; negative similarities are floored at 0 by the fold below.
/// The returned scores are the MMR objective values, not the original
/// relevance scores.
#[must_use]
pub fn mmr<I, F>(candidates: &[(I, f32)], similarity: F, config: MmrConfig) -> Vec<(I, f32)>
where
    I: Clone + Eq + Hash,
    F: Fn(&I, &I) -> f32,
{
    if candidates.is_empty() {
        return Vec::new();
    }
    let k = config.top_k.min(candidates.len());
    let lambda = config.lambda;
    // Min-max normalize relevance so lambda balances comparable magnitudes.
    let max_rel = candidates
        .iter()
        .map(|(_, s)| *s)
        .fold(f32::NEG_INFINITY, f32::max);
    let min_rel = candidates
        .iter()
        .map(|(_, s)| *s)
        .fold(f32::INFINITY, f32::min);
    let rel_range = max_rel - min_rel;
    let normalized: Vec<(I, f32)> = if rel_range > SCORE_RANGE_EPSILON {
        candidates
            .iter()
            .map(|(id, s)| (id.clone(), (s - min_rel) / rel_range))
            .collect()
    } else {
        // Degenerate case: all scores equal — treat everything as fully relevant.
        candidates.iter().map(|(id, _)| (id.clone(), 1.0)).collect()
    };
    let mut selected: Vec<(I, f32)> = Vec::with_capacity(k);
    let mut remaining: Vec<(I, f32)> = normalized;
    // Greedy loop: each pass scans the remaining pool for the best MMR score.
    while selected.len() < k && !remaining.is_empty() {
        let mut best_idx = 0;
        let mut best_mmr = f32::NEG_INFINITY;
        for (idx, (cand_id, cand_rel)) in remaining.iter().enumerate() {
            let relevance_term = lambda * cand_rel;
            let redundancy_term = if selected.is_empty() {
                0.0
            } else {
                // Penalize by the strongest similarity to anything already
                // chosen; the 0.0 fold seed means dissimilarity is not rewarded.
                let max_sim = selected
                    .iter()
                    .map(|(sel_id, _)| similarity(cand_id, sel_id))
                    .fold(0.0_f32, f32::max);
                (1.0 - lambda) * max_sim
            };
            let mmr_score = relevance_term - redundancy_term;
            if mmr_score > best_mmr {
                best_mmr = mmr_score;
                best_idx = idx;
            }
        }
        let (id, _) = remaining.remove(best_idx);
        // Report the MMR objective value achieved at selection time.
        selected.push((id, best_mmr));
    }
    selected
}
/// MMR where pairwise similarities come from a precomputed lookup table.
///
/// Pairs missing from `sim_matrix` are treated as similarity 0.0.
#[must_use]
pub fn mmr_with_matrix<I: Clone + Eq + Hash>(
    candidates: &[(I, f32)],
    sim_matrix: &HashMap<(I, I), f32>,
    config: MmrConfig,
) -> Vec<(I, f32)> {
    let lookup = |a: &I, b: &I| -> f32 {
        sim_matrix.get(&(a.clone(), b.clone())).copied().unwrap_or(0.0)
    };
    mmr(candidates, lookup, config)
}
/// MMR where diversity is measured by cosine similarity of embeddings.
///
/// Candidates whose embeddings cannot be found fall back to similarity 0.0
/// (no diversity penalty).
#[must_use]
pub fn mmr_embeddings<I: Clone + Eq + Hash>(
    candidates: &[(I, f32, Vec<f32>)],
    config: MmrConfig,
) -> Vec<(I, f32)> {
    if candidates.is_empty() {
        return Vec::new();
    }
    // Split into an id → embedding lookup plus plain (id, score) pairs.
    let mut vectors: HashMap<I, &[f32]> = HashMap::with_capacity(candidates.len());
    let mut scored: Vec<(I, f32)> = Vec::with_capacity(candidates.len());
    for (id, score, emb) in candidates {
        vectors.insert(id.clone(), emb.as_slice());
        scored.push((id.clone(), *score));
    }
    let sim = |a: &I, b: &I| -> f32 {
        match (vectors.get(a), vectors.get(b)) {
            (Some(va), Some(vb)) => cosine_similarity(va, vb),
            _ => 0.0,
        }
    };
    mmr(&scored, sim, config)
}
/// Cosine similarity of two equal-length vectors, clamped to `[-1, 1]`.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or near-zero norms.
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0_f32;
    let mut norm_a_sq = 0.0_f32;
    let mut norm_b_sq = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }
    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    // Guard against division by (near-)zero magnitude vectors.
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Relevance judgments for one query: document id → graded relevance
/// (0 means not relevant).
pub type Qrels<K> = std::collections::HashMap<K, u32>;
pub fn ndcg_at_k<K: Clone + Eq + Hash>(results: &[(K, f32)], qrels: &Qrels<K>, k: usize) -> f32 {
if qrels.is_empty() || results.is_empty() {
return 0.0;
}
let k = k.min(results.len());
let mut dcg = 0.0;
for (i, (id, _)) in results.iter().take(k).enumerate() {
if let Some(&rel) = qrels.get(id) {
let gain = (2.0_f32.powi(rel as i32) - 1.0) / ((i + 2) as f32).log2();
dcg += gain;
}
}
let mut ideal_relevances: Vec<u32> = qrels.values().copied().collect();
ideal_relevances.sort_by(|a, b| b.cmp(a));
let mut idcg = 0.0;
for (i, &rel) in ideal_relevances.iter().take(k).enumerate() {
let gain = (2.0_f32.powi(rel as i32) - 1.0) / ((i + 2) as f32).log2();
idcg += gain;
}
if idcg > 1e-9 {
dcg / idcg
} else {
0.0
}
}
pub fn mrr<K: Clone + Eq + Hash>(results: &[(K, f32)], qrels: &Qrels<K>) -> f32 {
for (rank, (id, _)) in results.iter().enumerate() {
if qrels.contains_key(id) && qrels[id] > 0 {
return 1.0 / (rank + 1) as f32;
}
}
0.0
}
/// Recall@k: fraction of all relevant documents found in the top `k` results.
///
/// Returns 0.0 when the judgments contain no relevant documents.
pub fn recall_at_k<K: Clone + Eq + Hash>(results: &[(K, f32)], qrels: &Qrels<K>, k: usize) -> f32 {
    let total_relevant = qrels.values().filter(|&&rel| rel > 0).count();
    if total_relevant == 0 {
        return 0.0;
    }
    let cutoff = k.min(results.len());
    let mut hits = 0usize;
    for (id, _) in &results[..cutoff] {
        if qrels.get(id).is_some_and(|&rel| rel > 0) {
            hits += 1;
        }
    }
    hits as f32 / total_relevant as f32
}
/// Precision@k: fraction of the top `k` results that are relevant.
///
/// The effective cutoff (and denominator) is capped at `results.len()`.
/// Returns 0.0 when `k` is 0 or there are no results.
pub fn precision_at_k<K: Clone + Eq + Hash>(
    results: &[(K, f32)],
    qrels: &Qrels<K>,
    k: usize,
) -> f32 {
    if k == 0 || results.is_empty() {
        return 0.0;
    }
    let cutoff = k.min(results.len());
    let hits = results[..cutoff]
        .iter()
        .filter(|(id, _)| qrels.get(id).is_some_and(|&rel| rel > 0))
        .count();
    hits as f32 / cutoff as f32
}
/// Mean Average Precision over the full result list.
///
/// Sums precision at every relevant hit and normalizes by the total number
/// of relevant documents in the judgments. Returns 0.0 when there are no
/// relevant documents or no results.
pub fn map<K: Clone + Eq + Hash>(results: &[(K, f32)], qrels: &Qrels<K>) -> f32 {
    let total_relevant = qrels.values().filter(|&&rel| rel > 0).count();
    if total_relevant == 0 || results.is_empty() {
        return 0.0;
    }
    let mut running = 0.0_f32;
    let mut hits = 0usize;
    for (idx, (id, _)) in results.iter().enumerate() {
        if qrels.get(id).is_some_and(|&rel| rel > 0) {
            hits += 1;
            // Precision at this position, accumulated only at relevant hits.
            running += hits as f32 / (idx + 1) as f32;
        }
    }
    running / total_relevant as f32
}
/// Mean Average Precision over the top `k` results.
///
/// Normalized by `min(total_relevant, cutoff)` where the cutoff is capped at
/// `results.len()`. Returns 0.0 when there are no relevant documents, no
/// results, or `k` is 0.
pub fn map_at_k<K: Clone + Eq + Hash>(results: &[(K, f32)], qrels: &Qrels<K>, k: usize) -> f32 {
    let total_relevant = qrels.values().filter(|&&rel| rel > 0).count();
    if total_relevant == 0 || results.is_empty() || k == 0 {
        return 0.0;
    }
    let cutoff = k.min(results.len());
    let mut running = 0.0_f32;
    let mut hits = 0usize;
    for (idx, (id, _)) in results[..cutoff].iter().enumerate() {
        if qrels.get(id).is_some_and(|&rel| rel > 0) {
            hits += 1;
            running += hits as f32 / (idx + 1) as f32;
        }
    }
    running / total_relevant.min(cutoff) as f32
}
/// Hit rate@k: 1.0 if any of the top `k` results is relevant, else 0.0.
pub fn hit_rate<K: Clone + Eq + Hash>(results: &[(K, f32)], qrels: &Qrels<K>, k: usize) -> f32 {
    if k == 0 || results.is_empty() {
        return 0.0;
    }
    let cutoff = k.min(results.len());
    let found = results[..cutoff]
        .iter()
        .any(|(id, _)| qrels.get(id).is_some_and(|&rel| rel > 0));
    if found {
        1.0
    } else {
        0.0
    }
}
/// Settings for [`optimize_fusion`]: which metric to maximize and which
/// parameter grid to search.
#[derive(Debug, Clone)]
pub struct OptimizeConfig {
    /// Metric maximized during the grid search.
    pub metric: OptimizeMetric,
    /// Candidate parameter values to evaluate.
    pub param_grid: ParamGrid,
}
/// Evaluation metric driving parameter optimization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizeMetric {
    /// Normalized discounted cumulative gain at cutoff `k`.
    Ndcg {
        k: usize,
    },
    /// Reciprocal rank of the first relevant result.
    Mrr,
    /// Recall at cutoff `k`.
    Recall {
        k: usize,
    },
    /// Precision at cutoff `k`.
    Precision {
        k: usize,
    },
    /// Mean average precision over the full list.
    Map,
    /// Mean average precision at cutoff `k`.
    MapAtK {
        k: usize,
    },
    /// 1.0 if any relevant document appears in the top `k`, else 0.0.
    HitRate {
        k: usize,
    },
}
impl Default for OptimizeMetric {
    /// Defaults to NDCG@10.
    fn default() -> Self {
        Self::Ndcg { k: 10 }
    }
}
/// Parameter grid searched by [`optimize_fusion`].
#[derive(Debug, Clone)]
pub enum ParamGrid {
    /// Candidate `k` constants for RRF fusion.
    RrfK {
        /// Values of `k` to evaluate.
        values: Vec<u32>,
    },
    /// Candidate per-run weight vectors for weighted fusion.
    Weighted {
        /// Each inner vector must contain exactly one weight per run;
        /// mismatched combinations are skipped.
        weight_combinations: Vec<Vec<f32>>,
    },
}
/// Outcome of a grid search over fusion parameters.
#[derive(Debug, Clone)]
pub struct OptimizedParams {
    /// Best metric value found (`f32::NEG_INFINITY` when nothing was evaluated).
    pub best_score: f32,
    /// Human-readable winning parameters, e.g. `"k=60"` or `"weights=[0.5, 0.5]"`.
    pub best_params: String,
}
/// Evaluates `results` against `qrels` using the metric selected by `metric`.
pub fn evaluate_metric<K: Clone + Eq + Hash>(
    results: &[(K, f32)],
    qrels: &Qrels<K>,
    metric: OptimizeMetric,
) -> f32 {
    match metric {
        OptimizeMetric::Ndcg { k } => ndcg_at_k(results, qrels, k),
        OptimizeMetric::Mrr => mrr(results, qrels),
        OptimizeMetric::Recall { k } => recall_at_k(results, qrels, k),
        OptimizeMetric::Precision { k } => precision_at_k(results, qrels, k),
        OptimizeMetric::Map => map(results, qrels),
        OptimizeMetric::MapAtK { k } => map_at_k(results, qrels, k),
        OptimizeMetric::HitRate { k } => hit_rate(results, qrels, k),
    }
}
/// Grid-searches fusion parameters against relevance judgments.
///
/// Evaluates every candidate in `config.param_grid`, fusing `runs` and
/// scoring the result with `config.metric`. Returns the best score together
/// with a human-readable description of the winning parameters. When the
/// grid is empty (or no weight combination matches the number of runs) the
/// result has `best_score == f32::NEG_INFINITY` and an empty `best_params`.
pub fn optimize_fusion<K: Clone + Eq + Hash>(
    qrels: &Qrels<K>,
    runs: &[Vec<(K, f32)>],
    config: OptimizeConfig,
) -> OptimizedParams {
    let mut best_score = f32::NEG_INFINITY;
    let mut best_params = String::new();
    match config.param_grid {
        ParamGrid::RrfK { values } => {
            for k in values {
                let method = FusionMethod::Rrf { k };
                let fused = method.fuse_multi(runs);
                let score = evaluate_metric(&fused, qrels, config.metric);
                if score > best_score {
                    best_score = score;
                    // Inline format args, consistent with the rest of the file.
                    best_params = format!("k={k}");
                }
            }
        }
        ParamGrid::Weighted {
            ref weight_combinations,
        } => {
            for weights in weight_combinations {
                // Skip combinations that don't supply one weight per run.
                if weights.len() != runs.len() {
                    continue;
                }
                let lists: Vec<(&[(K, f32)], f32)> = runs
                    .iter()
                    .zip(weights.iter())
                    .map(|(run, &w)| (run.as_slice(), w))
                    .collect();
                // Zero-sum weight vectors are rejected by weighted_multi; ignore them.
                if let Ok(fused) = weighted_multi(&lists, true, None) {
                    let score = evaluate_metric(&fused, qrels, config.metric);
                    if score > best_score {
                        best_score = score;
                        best_params = format!("weights={weights:?}");
                    }
                }
            }
        }
    }
    OptimizedParams {
        best_score,
        best_params,
    }
}
/// Convenience re-exports of the evaluation metrics and parameter-search API.
pub mod optimize {
    pub use crate::{
        evaluate_metric, hit_rate, map, map_at_k, mrr, ndcg_at_k, optimize_fusion, precision_at_k,
        recall_at_k, OptimizeConfig, OptimizeMetric, OptimizedParams, ParamGrid, Qrels,
    };
}
#[cfg(test)]
mod tests {
use super::*;
fn ranked<'a>(ids: &[&'a str]) -> Vec<(&'a str, f32)> {
ids.iter()
.enumerate()
.map(|(i, &id)| (id, 1.0 - i as f32 * 0.1))
.collect()
}
#[test]
fn rrf_basic() {
let a = ranked(&["d1", "d2", "d3"]);
let b = ranked(&["d2", "d3", "d4"]);
let f = rrf(&a, &b);
assert_eq!(f.iter().position(|(id, _)| *id == "d2").unwrap(), 0);
}
#[test]
fn rrf_with_top_k() {
let a = ranked(&["d1", "d2", "d3"]);
let b = ranked(&["d2", "d3", "d4"]);
let f = rrf_with_config(&a, &b, RrfConfig::default().with_top_k(2));
assert_eq!(f.len(), 2);
}
#[test]
fn rrf_score_formula() {
let a = vec![("d1", 1.0)];
let b: Vec<(&str, f32)> = vec![];
let f = rrf_with_config(&a, &b, RrfConfig::new(60));
let expected = 1.0 / 60.0;
assert!((f[0].1 - expected).abs() < 1e-6);
}
#[test]
fn rrf_exact_score_computation() {
let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let b = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7)];
let f = rrf_with_config(&a, &b, RrfConfig::new(60));
let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
let expected = 1.0 / 60.0 + 1.0 / 62.0;
assert!(
(d1_score - expected).abs() < 1e-6,
"d1 score {} != expected {}",
d1_score,
expected
);
}
#[test]
fn isr_exact_score_computation() {
let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let b = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7)];
let f = isr_with_config(&a, &b, RrfConfig::new(1));
let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
let expected = 1.0 / 1.0_f32.sqrt() + 1.0 / 3.0_f32.sqrt();
assert!(
(d1_score - expected).abs() < 1e-6,
"d1 score {} != expected {}",
d1_score,
expected
);
}
#[test]
fn borda_exact_score_computation() {
let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let b = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7), ("d6", 0.6)];
let f = borda(&a, &b);
let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
let expected = 3.0 + 2.0;
assert!(
(d1_score - expected).abs() < 1e-6,
"d1 score {} != expected {}",
d1_score,
expected
);
}
#[test]
fn rrf_weighted_applies_weights() {
let list_a = [("d1", 0.0)];
let list_b = [("d2", 0.0)];
let weights = [0.25, 0.75];
let f = rrf_weighted(&[&list_a[..], &list_b[..]], &weights, RrfConfig::new(60)).unwrap();
assert_eq!(f[0].0, "d2", "weighted RRF should favor higher-weight list");
let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
let d2_score = f.iter().find(|(id, _)| *id == "d2").unwrap().1;
assert!(
d2_score > d1_score * 2.0,
"d2 should score ~3x higher than d1"
);
}
#[test]
fn rrf_weighted_zero_weights_error() {
let list_a = [("d1", 0.0)];
let list_b = [("d2", 0.0)];
let weights = [0.0, 0.0];
let result = rrf_weighted(&[&list_a[..], &list_b[..]], &weights, RrfConfig::default());
assert!(matches!(result, Err(FusionError::ZeroWeights)));
}
#[test]
fn isr_basic() {
let a = ranked(&["d1", "d2", "d3"]);
let b = ranked(&["d2", "d3", "d4"]);
let f = isr(&a, &b);
assert_eq!(f.iter().position(|(id, _)| *id == "d2").unwrap(), 0);
}
#[test]
fn isr_score_formula() {
let a = vec![("d1", 1.0)];
let b: Vec<(&str, f32)> = vec![];
let f = isr_with_config(&a, &b, RrfConfig::new(1));
let expected = 1.0 / 1.0_f32.sqrt(); assert!((f[0].1 - expected).abs() < 1e-6);
}
#[test]
fn isr_gentler_decay_than_rrf() {
let a = vec![("d1", 1.0), ("d2", 0.9), ("d3", 0.8), ("d4", 0.7)];
let b: Vec<(&str, f32)> = vec![];
let rrf_result = rrf_with_config(&a, &b, RrfConfig::new(1));
let isr_result = isr_with_config(&a, &b, RrfConfig::new(1));
let rrf_ratio = rrf_result[0].1 / rrf_result[3].1;
let isr_ratio = isr_result[0].1 / isr_result[3].1;
assert!(
isr_ratio < rrf_ratio,
"ISR should have gentler decay: ISR ratio={}, RRF ratio={}",
isr_ratio,
rrf_ratio
);
}
#[test]
fn isr_multi_works() {
let a = ranked(&["d1", "d2"]);
let b = ranked(&["d2", "d3"]);
let c = ranked(&["d3", "d4"]);
let f = isr_multi(&[&a, &b, &c], RrfConfig::new(1));
assert_eq!(f.len(), 4);
let top_2: Vec<_> = f.iter().take(2).map(|(id, _)| *id).collect();
assert!(top_2.contains(&"d2") && top_2.contains(&"d3"));
}
#[test]
fn isr_with_top_k() {
let a = ranked(&["d1", "d2", "d3"]);
let b = ranked(&["d2", "d3", "d4"]);
let f = isr_with_config(&a, &b, RrfConfig::new(1).with_top_k(2));
assert_eq!(f.len(), 2);
}
#[test]
fn isr_empty_lists() {
let empty: Vec<(&str, f32)> = vec![];
let non_empty = ranked(&["d1"]);
assert_eq!(isr(&empty, &non_empty).len(), 1);
assert_eq!(isr(&non_empty, &empty).len(), 1);
assert_eq!(isr(&empty, &empty).len(), 0);
}
#[test]
fn fusion_method_isr() {
let a = ranked(&["d1", "d2"]);
let b = ranked(&["d2", "d3"]);
let f = FusionMethod::isr().fuse(&a, &b);
assert_eq!(f[0].0, "d2");
let f = FusionMethod::isr_with_k(10).fuse(&a, &b);
assert_eq!(f[0].0, "d2");
}
#[test]
fn fusion_method_isr_multi() {
let a = ranked(&["d1", "d2"]);
let b = ranked(&["d2", "d3"]);
let c = ranked(&["d3", "d4"]);
let lists = [&a[..], &b[..], &c[..]];
let f = FusionMethod::isr().fuse_multi(&lists);
assert!(!f.is_empty());
}
#[test]
fn combmnz_rewards_overlap() {
let a = ranked(&["d1", "d2"]);
let b = ranked(&["d2", "d3"]);
let f = combmnz(&a, &b);
assert_eq!(f[0].0, "d2");
}
#[test]
fn combsum_basic() {
let a = vec![("d1", 0.5), ("d2", 1.0)];
let b = vec![("d2", 1.0), ("d3", 0.5)];
let f = combsum(&a, &b);
assert_eq!(f[0].0, "d2");
}
#[test]
fn weighted_skewed() {
let a = vec![("d1", 1.0)];
let b = vec![("d2", 1.0)];
let f = weighted(
&a,
&b,
WeightedConfig::default()
.with_weights(0.9, 0.1)
.with_normalize(false),
);
assert_eq!(f[0].0, "d1");
let f = weighted(
&a,
&b,
WeightedConfig::default()
.with_weights(0.1, 0.9)
.with_normalize(false),
);
assert_eq!(f[0].0, "d2");
}
#[test]
fn borda_symmetric() {
let a = ranked(&["d1", "d2", "d3"]);
let b = ranked(&["d3", "d2", "d1"]);
let f = borda(&a, &b);
let scores: Vec<f32> = f.iter().map(|(_, s)| *s).collect();
assert!((scores[0] - scores[1]).abs() < 0.01);
assert!((scores[1] - scores[2]).abs() < 0.01);
}
#[test]
fn rrf_multi_works() {
let lists: Vec<Vec<(&str, f32)>> = vec![
ranked(&["d1", "d2"]),
ranked(&["d2", "d3"]),
ranked(&["d1", "d3"]),
];
let f = rrf_multi(&lists, RrfConfig::default());
assert_eq!(f.len(), 3);
}
#[test]
fn borda_multi_works() {
let lists: Vec<Vec<(&str, f32)>> = vec![
ranked(&["d1", "d2"]),
ranked(&["d2", "d3"]),
ranked(&["d1", "d3"]),
];
let f = borda_multi(&lists, FusionConfig::default());
assert_eq!(f.len(), 3);
assert_eq!(f[0].0, "d1");
}
#[test]
fn combsum_multi_works() {
let lists: Vec<Vec<(&str, f32)>> = vec![
vec![("d1", 1.0), ("d2", 0.5)],
vec![("d2", 1.0), ("d3", 0.5)],
vec![("d1", 1.0), ("d3", 0.5)],
];
let f = combsum_multi(&lists, FusionConfig::default());
assert_eq!(f.len(), 3);
}
#[test]
fn combmnz_multi_works() {
let lists: Vec<Vec<(&str, f32)>> = vec![
vec![("d1", 1.0)],
vec![("d1", 1.0), ("d2", 0.5)],
vec![("d1", 1.0), ("d2", 0.5)],
];
let f = combmnz_multi(&lists, FusionConfig::default());
assert_eq!(f[0].0, "d1");
}
#[test]
fn weighted_multi_works() {
let a = vec![("d1", 1.0)];
let b = vec![("d2", 1.0)];
let c = vec![("d3", 1.0)];
let f = weighted_multi(&[(&a, 1.0), (&b, 1.0), (&c, 1.0)], false, None).unwrap();
assert_eq!(f.len(), 3);
let f = weighted_multi(&[(&a, 10.0), (&b, 1.0), (&c, 1.0)], false, None).unwrap();
assert_eq!(f[0].0, "d1");
}
#[test]
fn weighted_multi_zero_weights() {
let a = vec![("d1", 1.0)];
let result = weighted_multi(&[(&a, 0.0)], false, None);
assert!(matches!(result, Err(FusionError::ZeroWeights)));
}
#[test]
fn empty_inputs() {
let empty: Vec<(&str, f32)> = vec![];
let non_empty = ranked(&["d1"]);
assert_eq!(rrf(&empty, &non_empty).len(), 1);
assert_eq!(rrf(&non_empty, &empty).len(), 1);
}
#[test]
fn both_empty() {
let empty: Vec<(&str, f32)> = vec![];
assert_eq!(rrf(&empty, &empty).len(), 0);
assert_eq!(combsum(&empty, &empty).len(), 0);
assert_eq!(borda(&empty, &empty).len(), 0);
}
#[test]
fn duplicate_ids_in_same_list() {
let a = vec![("d1", 1.0), ("d1", 0.5)];
let b: Vec<(&str, f32)> = vec![];
let f = rrf_with_config(&a, &b, RrfConfig::new(60));
assert_eq!(f.len(), 1);
let expected = 1.0 / 60.0 + 1.0 / 61.0;
assert!((f[0].1 - expected).abs() < 1e-6);
}
#[test]
fn builder_pattern() {
let config = RrfConfig::default().with_k(30).with_top_k(5);
assert_eq!(config.k, 30);
assert_eq!(config.top_k, Some(5));
let config = WeightedConfig::default()
.with_weights(0.8, 0.2)
.with_normalize(false)
.with_top_k(10);
assert_eq!(config.weight_a, 0.8);
assert!(!config.normalize);
assert_eq!(config.top_k, Some(10));
}
#[test]
fn nan_scores_handled() {
let a = vec![("d1", f32::NAN), ("d2", 0.5)];
let b = vec![("d2", 0.9), ("d3", 0.1)];
let r = rrf(&a, &b);
assert!(!r.is_empty());
assert!(r.iter().all(|(_, s)| s.is_finite()));
let r = combsum(&a, &b);
assert!(!r.is_empty());
let r = combmnz(&a, &b);
assert!(!r.is_empty());
let r = borda(&a, &b);
assert!(!r.is_empty());
}
#[test]
fn inf_scores_handled() {
let a = vec![("d1", f32::INFINITY), ("d2", 0.5)];
let b = vec![("d2", f32::NEG_INFINITY), ("d3", 0.1)];
let r = rrf(&a, &b);
assert!(!r.is_empty());
assert!(r.iter().all(|(_, s)| s.is_finite()));
let r = combsum(&a, &b);
assert!(!r.is_empty());
}
#[test]
fn zero_scores() {
let a = vec![("d1", 0.0), ("d2", 0.0)];
let b = vec![("d2", 0.0), ("d3", 0.0)];
let f = combsum(&a, &b);
assert_eq!(f.len(), 3);
}
#[test]
fn negative_scores() {
let a = vec![("d1", -1.0), ("d2", -0.5)];
let b = vec![("d2", -0.9), ("d3", -0.1)];
let f = combsum(&a, &b);
assert_eq!(f.len(), 3);
}
#[test]
fn large_k_value() {
let a = ranked(&["d1", "d2"]);
let b = ranked(&["d2", "d3"]);
let f = rrf_with_config(&a, &b, RrfConfig::new(u32::MAX));
assert!(!f.is_empty());
}
#[test]
#[should_panic(expected = "k must be >= 1")]
fn k_zero_panics() {
let _ = RrfConfig::new(0);
}
#[test]
#[should_panic(expected = "k must be >= 1")]
fn k_zero_with_k_panics() {
let _ = RrfConfig::default().with_k(0);
}
#[test]
fn all_nan_scores() {
let a = vec![("d1", f32::NAN), ("d2", f32::NAN)];
let b = vec![("d3", f32::NAN), ("d4", f32::NAN)];
let f = rrf(&a, &b);
assert_eq!(f.len(), 4);
for (_, score) in &f {
assert!(
score.is_finite(),
"RRF scores should be finite (based on ranks, not input scores)"
);
}
}
#[test]
fn empty_lists_multi() {
let empty: Vec<Vec<(&str, f32)>> = vec![];
assert_eq!(rrf_multi(&empty, RrfConfig::default()).len(), 0);
assert_eq!(combsum_multi(&empty, FusionConfig::default()).len(), 0);
assert_eq!(combmnz_multi(&empty, FusionConfig::default()).len(), 0);
assert_eq!(borda_multi(&empty, FusionConfig::default()).len(), 0);
assert_eq!(dbsf_multi(&empty, FusionConfig::default()).len(), 0);
assert_eq!(isr_multi(&empty, RrfConfig::default()).len(), 0);
}
#[test]
fn rrf_weighted_list_weight_mismatch() {
let a = [("d1", 1.0)];
let b = [("d2", 1.0)];
let weights = [0.5, 0.5, 0.0];
let result = rrf_weighted(&[&a[..], &b[..]], &weights, RrfConfig::default());
assert!(matches!(result, Err(FusionError::InvalidConfig(_))));
}
#[test]
fn rrf_weighted_list_weight_mismatch_short() {
let a = [("d1", 1.0)];
let b = [("d2", 1.0)];
let weights = [0.5];
let result = rrf_weighted(&[&a[..], &b[..]], &weights, RrfConfig::default());
assert!(matches!(result, Err(FusionError::InvalidConfig(_))));
}
#[test]
fn duplicate_ids_commutative() {
    // A duplicate id inside one list must not break commutativity:
    // swapping the inputs yields the same set of ids.
    let left = vec![("d1", 1.0), ("d1", 0.5), ("d2", 0.3)];
    let right = vec![("d2", 0.9), ("d3", 0.7)];
    let forward = rrf(&left, &right);
    let backward = rrf(&right, &left);
    let forward_ids: Vec<&str> = forward.iter().map(|(id, _)| *id).collect();
    let backward_ids: Vec<&str> = backward.iter().map(|(id, _)| *id).collect();
    assert_eq!(forward_ids.len(), backward_ids.len());
    assert!(forward_ids.iter().all(|id| backward_ids.contains(id)));
}
#[test]
fn dbsf_zero_variance() {
    // List A has zero score variance; DBSF must not blow up on the
    // degenerate standard deviation, and list B's order decides the top.
    let flat = vec![("d1", 1.0), ("d2", 1.0), ("d3", 1.0)];
    let spread = vec![("d1", 0.9), ("d2", 0.5), ("d3", 0.1)];
    let fused = dbsf(&flat, &spread);
    assert_eq!(fused.len(), 3);
    assert_eq!(fused[0].0, "d1");
}
#[test]
fn single_item_lists() {
    // A single shared document fuses to exactly one result under each method.
    let only = vec![("d1", 1.0)];
    assert_eq!(rrf(&only, &only).len(), 1);
    assert_eq!(combsum(&only, &only).len(), 1);
    assert_eq!(borda(&only, &only).len(), 1);
}
#[test]
fn disjoint_lists() {
    // No overlap between inputs: the fusion covers the full union.
    let left = vec![("d1", 1.0), ("d2", 0.9)];
    let right = vec![("d3", 1.0), ("d4", 0.9)];
    assert_eq!(rrf(&left, &right).len(), 4);
    assert_eq!(combmnz(&left, &right).len(), 4);
}
#[test]
fn identical_lists() {
    // Two agreeing inputs keep their common order after fusion.
    let order = ranked(&["d1", "d2", "d3"]);
    let fused = rrf(&order, &order);
    let ids: Vec<&str> = fused.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, ["d1", "d2", "d3"]);
}
#[test]
fn reversed_lists() {
    // Fully reversed inputs still fuse every document exactly once.
    let forward = ranked(&["d1", "d2", "d3"]);
    let backward = ranked(&["d3", "d2", "d1"]);
    assert_eq!(rrf(&forward, &backward).len(), 3);
}
#[test]
fn top_k_larger_than_result() {
    // A top_k beyond the result size truncates nothing and must not panic.
    let left = ranked(&["d1"]);
    let right = ranked(&["d2"]);
    let fused = rrf_with_config(&left, &right, RrfConfig::default().with_top_k(100));
    assert_eq!(fused.len(), 2);
}
#[test]
fn top_k_zero() {
    // top_k = 0 truncates the fused output to nothing.
    let left = ranked(&["d1", "d2"]);
    let right = ranked(&["d2", "d3"]);
    let fused = rrf_with_config(&left, &right, RrfConfig::default().with_top_k(0));
    assert!(fused.is_empty());
}
#[test]
fn fusion_method_rrf() {
    // d2 appears in both lists, so the RRF dispatch ranks it first.
    let left = ranked(&["d1", "d2"]);
    let right = ranked(&["d2", "d3"]);
    let fused = FusionMethod::rrf().fuse(&left, &right);
    assert_eq!(fused[0].0, "d2");
}
#[test]
fn fusion_method_combsum() {
    // d2 is the only document present in both lists, so CombSum puts it
    // on top.
    let left = vec![("d1", 1.0_f32), ("d2", 0.6), ("d4", 0.2)];
    let right = vec![("d2", 1.0_f32), ("d3", 0.5)];
    let fused = FusionMethod::CombSum.fuse(&left, &right);
    assert_eq!(fused[0].0, "d2");
}
#[test]
fn fusion_method_combmnz() {
    // d2 appears in both lists, so CombMNZ boosts it to the top.
    let left = ranked(&["d1", "d2"]);
    let right = ranked(&["d2", "d3"]);
    let fused = FusionMethod::CombMnz.fuse(&left, &right);
    assert_eq!(fused[0].0, "d2");
}
#[test]
fn fusion_method_borda() {
    // d2 appears in both lists, so Borda counting favors it.
    let left = ranked(&["d1", "d2"]);
    let right = ranked(&["d2", "d3"]);
    let fused = FusionMethod::Borda.fuse(&left, &right);
    assert_eq!(fused[0].0, "d2");
}
#[test]
fn fusion_method_weighted() {
    // The document from the heavier-weighted list must come out on top.
    let left = vec![("d1", 1.0f32)];
    let right = vec![("d2", 1.0f32)];
    let favor_left = FusionMethod::weighted(0.9, 0.1).fuse(&left, &right);
    assert_eq!(favor_left[0].0, "d1");
    let favor_right = FusionMethod::weighted(0.1, 0.9).fuse(&left, &right);
    assert_eq!(favor_right[0].0, "d2");
}
#[test]
fn fusion_method_multi() {
    // fuse_multi over three overlapping lists covers the union of ids.
    let lists: Vec<Vec<(&str, f32)>> = vec![
        ranked(&["d1", "d2"]),
        ranked(&["d2", "d3"]),
        ranked(&["d1", "d3"]),
    ];
    assert_eq!(FusionMethod::rrf().fuse_multi(&lists).len(), 3);
}
#[test]
fn fusion_method_default_is_rrf() {
    // The default fusion method is RRF with the canonical k = 60.
    assert!(matches!(
        FusionMethod::default(),
        FusionMethod::Rrf { k: 60 }
    ));
}
#[test]
fn mmr_basic() {
    // With lambda = 0.5, the near-duplicate d2 (similarity 0.95 to d1)
    // is demoted below the more diverse d3.
    let candidates = vec![("d1", 0.95), ("d2", 0.90), ("d3", 0.85)];
    let similarity = |x: &&str, y: &&str| -> f32 {
        if x == y {
            1.0
        } else if (*x == "d1" && *y == "d2") || (*x == "d2" && *y == "d1") {
            0.95
        } else {
            0.1
        }
    };
    let results = mmr(&candidates, similarity, MmrConfig::new(0.5).with_top_k(3));
    let ids: Vec<&str> = results.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, ["d1", "d3", "d2"]);
}
#[test]
fn mmr_pure_relevance() {
    // lambda = 1.0 ignores diversity entirely: pure relevance ordering.
    let candidates = vec![("d1", 0.9), ("d2", 0.95), ("d3", 0.8)];
    let similarity = |_x: &&str, _y: &&str| -> f32 { 0.5 };
    let results = mmr(&candidates, similarity, MmrConfig::new(1.0).with_top_k(3));
    let ids: Vec<&str> = results.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, ["d2", "d1", "d3"]);
}
#[test]
fn mmr_pure_diversity() {
    // lambda = 0.0 is pure diversity: with d1/d2 near-duplicates, d3 must
    // make it into any top-2 selection.
    let candidates = vec![("d1", 0.9), ("d2", 0.9), ("d3", 0.9)];
    let similarity = |x: &&str, y: &&str| -> f32 {
        if x == y {
            1.0
        } else if (*x == "d1" && *y == "d2") || (*x == "d2" && *y == "d1") {
            0.9
        } else {
            0.1
        }
    };
    let results = mmr(&candidates, similarity, MmrConfig::new(0.0).with_top_k(2));
    assert_eq!(results.len(), 2);
    assert!(
        results.iter().any(|(id, _)| *id == "d3"),
        "d3 must appear: it is most diverse from d1 and d2"
    );
}
#[test]
fn mmr_config_lambda_bounds() {
    // The boundary lambdas 0.0, 0.5 and 1.0 are all accepted verbatim.
    for lambda in [0.0, 0.5, 1.0] {
        assert_eq!(MmrConfig::new(lambda).lambda, lambda);
    }
}
#[test]
#[should_panic(expected = "lambda must be in [0.0, 1.0]")]
fn mmr_config_lambda_negative() {
    // Negative lambda is outside the valid trade-off range.
    let _ = MmrConfig::new(-0.1);
}
#[test]
#[should_panic(expected = "lambda must be in [0.0, 1.0]")]
fn mmr_config_lambda_too_large() {
    // Lambda above 1.0 is outside the valid trade-off range.
    let _ = MmrConfig::new(1.1);
}
#[test]
fn mmr_empty_candidates() {
    // No candidates: MMR returns an empty selection.
    let candidates: Vec<(&str, f32)> = vec![];
    let similarity = |_x: &&str, _y: &&str| -> f32 { 0.0 };
    assert!(mmr(&candidates, similarity, MmrConfig::default()).is_empty());
}
#[test]
fn mmr_single_candidate() {
    // A lone candidate is always selected.
    let candidates = vec![("d1", 0.9)];
    let similarity = |_x: &&str, _y: &&str| -> f32 { 1.0 };
    let results = mmr(&candidates, similarity, MmrConfig::default());
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, "d1");
}
#[test]
fn mmr_matrix_based() {
    // Similarities come from a precomputed symmetric matrix: "a" wins on
    // relevance, then the dissimilar "c" beats the near-duplicate "b".
    let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.8)];
    let matrix: HashMap<(&str, &str), f32> = HashMap::from([
        (("a", "b"), 0.9),
        (("b", "a"), 0.9),
        (("a", "c"), 0.1),
        (("c", "a"), 0.1),
        (("b", "c"), 0.2),
        (("c", "b"), 0.2),
    ]);
    let results = mmr_with_matrix(&candidates, &matrix, MmrConfig::new(0.5).with_top_k(2));
    let ids: Vec<&str> = results.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, ["a", "c"]);
}
#[test]
fn mmr_embedding_based() {
    // d2's embedding is nearly parallel to d1's, so the orthogonal d3
    // takes the second slot.
    let candidates = vec![
        ("d1", 0.9, vec![1.0, 0.0, 0.0]),
        ("d2", 0.85, vec![0.9, 0.1, 0.0]),
        ("d3", 0.8, vec![0.0, 1.0, 0.0]),
    ];
    let results = mmr_embeddings(&candidates, MmrConfig::new(0.5).with_top_k(2));
    let ids: Vec<&str> = results.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, ["d1", "d3"]);
}
#[test]
fn cosine_sim_basic() {
    // Parallel, orthogonal and anti-parallel vectors give 1, 0 and -1.
    let eps = 1e-6;
    assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < eps);
    assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < eps);
    assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < eps);
    // Nearly-parallel vectors stay close to 1.
    assert!(cosine_similarity(&[1.0, 0.0, 0.0], &[0.9, 0.1, 0.0]) > 0.9);
    // Degenerate inputs (empty or mismatched lengths) yield exactly 0.
    assert_eq!(cosine_similarity(&[], &[]), 0.0);
    assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
}
#[test]
fn quantile_normalization() {
    // Five evenly-ranked items map onto evenly spaced quantiles in [0, 1].
    let scored = vec![("a", 10.0), ("b", 20.0), ("c", 30.0), ("d", 40.0), ("e", 50.0)];
    let out = normalize_scores(&scored, Normalization::Quantile);
    assert!((out[0].1 - 0.0).abs() < 1e-6, "a should be 0.0");
    assert!((out[1].1 - 0.25).abs() < 1e-6, "b should be 0.25");
    assert!((out[2].1 - 0.5).abs() < 1e-6, "c should be 0.5");
    assert!((out[4].1 - 1.0).abs() < 1e-6, "e should be 1.0");
}
#[test]
fn quantile_normalization_single() {
    // A lone score has no rank context and lands on the midpoint.
    let lone = vec![("a", 42.0)];
    let out = normalize_scores(&lone, Normalization::Quantile);
    assert!((out[0].1 - 0.5).abs() < 1e-6, "single item gets 0.5");
}
#[test]
fn sigmoid_normalization() {
    // Sigmoid squashes extremes toward 0/1 and maps 0 to exactly 0.5.
    let scored = vec![("a", -10.0), ("b", 0.0), ("c", 10.0)];
    let out = normalize_scores(&scored, Normalization::Sigmoid);
    assert!(out[0].1 < 0.01, "sigmoid(-10) should be near 0");
    assert!((out[1].1 - 0.5).abs() < 1e-6, "sigmoid(0) should be 0.5");
    assert!(out[2].1 > 0.99, "sigmoid(10) should be near 1");
}
#[test]
fn sigmoid_preserves_order() {
    // Sigmoid is monotonic, so the relative order of scores is preserved.
    let scored = vec![("a", 1.0), ("b", 3.0), ("c", 2.0)];
    let out = normalize_scores(&scored, Normalization::Sigmoid);
    assert!(out[1].1 > out[2].1);
    assert!(out[2].1 > out[0].1);
}
#[test]
fn quantile_handles_non_gaussian() {
    // The 100.0 outlier does not distort quantile spacing: ranks, not
    // magnitudes, drive the normalized values.
    let scored = vec![("a", 0.1), ("b", 0.2), ("c", 0.3), ("d", 100.0)];
    let out = normalize_scores(&scored, Normalization::Quantile);
    let expected = [0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0];
    for (got, want) in out.iter().zip(expected) {
        assert!((got.1 - want).abs() < 1e-6);
    }
}
#[test]
fn copeland_basic() {
    // d2 wins its pairwise duels across the three rankings and must be
    // the Copeland winner.
    let r1 = ranked(&["d1", "d2", "d3"]);
    let r2 = ranked(&["d2", "d1", "d3"]);
    let r3 = ranked(&["d2", "d3", "d1"]);
    let fused = copeland_multi(&[&r1, &r2, &r3], FusionConfig::default());
    assert_eq!(fused[0].0, "d2", "d2 should be Copeland winner");
}
#[test]
fn copeland_net_wins() {
    // Two identical rankings: the top doc wins both pairwise duels (+2),
    // the bottom doc loses both (-2).
    let list = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
    let fused = copeland(&list, &list);
    assert_eq!(fused[0].0, "d1");
    assert!((fused[0].1 - 2.0).abs() < 1e-6, "d1 net wins should be 2");
    assert_eq!(fused[2].0, "d3");
    assert!((fused[2].1 + 2.0).abs() < 1e-6, "d3 net wins should be -2");
}
#[test]
fn copeland_vs_condorcet_more_discriminative() {
    // Both methods agree on the unanimous winner d1 even though the two
    // lists disagree on the order of d2 and d3.
    let left = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
    let right = vec![("d1", 0.9), ("d3", 0.8), ("d2", 0.7)];
    assert_eq!(copeland(&left, &right)[0].0, "d1");
    assert_eq!(condorcet(&left, &right)[0].0, "d1");
}
#[test]
fn copeland_commutative() {
    // Swapping the input lists must not change ids or net-win scores.
    let left = ranked(&["d1", "d2", "d3"]);
    let right = ranked(&["d3", "d1", "d2"]);
    let forward = copeland(&left, &right);
    let backward = copeland(&right, &left);
    assert_eq!(forward.len(), backward.len());
    for (x, y) in forward.iter().zip(&backward) {
        assert_eq!(x.0, y.0);
        assert!((x.1 - y.1).abs() < 1e-6);
    }
}
#[test]
fn median_rank_basic() {
    // d1 holds the top rank in both lists, so its median rank is best.
    let left = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
    let right = vec![("d1", 0.9), ("d2", 0.8)];
    let fused = median_rank(&left, &right);
    assert_eq!(fused[0].0, "d1", "d1 should rank first (median rank 0)");
    assert!((fused[0].1 - 1.0).abs() < 1e-6);
}
#[test]
fn median_rank_outlier_robust() {
    // Two lists rank d1 above d2 while a third buries both; the median
    // discounts the single outlier list, so d1 must stay above d2.
    let first = vec![("d1", 0.9), ("d2", 0.8)];
    let second = vec![("d1", 0.9), ("d2", 0.8)];
    let outlier: Vec<(&str, f32)> = vec![
        ("x1", 0.9),
        ("x2", 0.8),
        ("x3", 0.7),
        ("x4", 0.6),
        ("x5", 0.5),
        ("d1", 0.4),
        ("d2", 0.3),
    ];
    let fused = median_rank_multi(&[&first, &second, &outlier], FusionConfig::default());
    let pos = |target: &str| fused.iter().position(|(id, _)| *id == target).unwrap();
    assert!(
        pos("d1") < pos("d2"),
        "d1 should rank above d2 (outlier-robust median)"
    );
}
#[test]
fn median_rank_commutative() {
    // Swapping the input lists must not change ids or median-rank scores.
    let left = ranked(&["d1", "d2", "d3"]);
    let right = ranked(&["d3", "d1", "d2"]);
    let forward = median_rank(&left, &right);
    let backward = median_rank(&right, &left);
    assert_eq!(forward.len(), backward.len());
    for (x, y) in forward.iter().zip(&backward) {
        assert_eq!(x.0, y.0);
        assert!((x.1 - y.1).abs() < 1e-6);
    }
}
#[test]
fn fusion_method_copeland_dispatch() {
    // FusionMethod::Copeland must produce the same scores as a direct
    // call to copeland().
    let left = ranked(&["d1", "d2", "d3"]);
    let right = ranked(&["d2", "d1", "d3"]);
    let expected: HashMap<_, _> = copeland(&left, &right).into_iter().collect();
    let actual: HashMap<_, _> = FusionMethod::Copeland
        .fuse(&left, &right)
        .into_iter()
        .collect();
    assert_eq!(expected.len(), actual.len());
    for (id, score) in &expected {
        let got = actual.get(id).expect("same keys");
        assert!((score - got).abs() < 1e-6);
    }
}
#[test]
fn fusion_method_median_rank_dispatch() {
    // FusionMethod::MedianRank must produce the same scores as a direct
    // call to median_rank().
    let left = ranked(&["d1", "d2", "d3"]);
    let right = ranked(&["d3", "d1", "d2"]);
    let expected: HashMap<_, _> = median_rank(&left, &right).into_iter().collect();
    let actual: HashMap<_, _> = FusionMethod::MedianRank
        .fuse(&left, &right)
        .into_iter()
        .collect();
    assert_eq!(expected.len(), actual.len());
    for (id, score) in &expected {
        let got = actual.get(id).expect("same keys");
        assert!((score - got).abs() < 1e-6);
    }
}
fn make_qrels() -> Qrels<&'static str> {
HashMap::from([("d1", 2), ("d2", 1), ("d3", 1)])
}
#[test]
fn ndcg_at_k_formula() {
    // Exponential-gain NDCG: DCG = (2^2-1)/log2(2) + 0 + (2^1-1)/log2(4)
    // = 3.5; IDCG = 3 + 1/log2(3); so NDCG = 3.5 / (3 + 1/log2(3)).
    let qrels: Qrels<&str> = HashMap::from([("doc1", 2u32), ("doc2", 0u32), ("doc3", 1u32)]);
    let ranking = vec![("doc1", 0.9_f32), ("doc2", 0.5), ("doc3", 0.1)];
    let ndcg = ndcg_at_k(&ranking, &qrels, 3);
    let expected = 3.5_f32 / (3.0 + 1.0_f32 / 3.0_f32.log2());
    assert!(
        (ndcg - expected).abs() < 1e-4,
        "NDCG={ndcg} expected≈{expected}"
    );
}
#[test]
fn precision_at_k_basic() {
    // Relevant docs d1, d2 and d3 sit at ranks 1, 3 and 5.
    let qrels = make_qrels();
    let ranking = vec![("d1", 0.9), ("d4", 0.8), ("d2", 0.7), ("d5", 0.6), ("d3", 0.5)];
    for (k, want) in [(1, 1.0), (2, 0.5), (3, 2.0 / 3.0), (5, 0.6)] {
        assert!((precision_at_k(&ranking, &qrels, k) - want).abs() < 1e-6);
    }
}
#[test]
fn precision_at_k_edge_cases() {
    // k = 0 and empty results both give 0; a k larger than the list still
    // yields 1.0 for an all-relevant list.
    let qrels = make_qrels();
    let single = vec![("d1", 0.9)];
    assert_eq!(precision_at_k(&single, &qrels, 0), 0.0);
    assert_eq!(precision_at_k(&[], &qrels, 5), 0.0);
    assert!((precision_at_k(&single, &qrels, 10) - 1.0).abs() < 1e-6);
}
#[test]
fn map_basic() {
    // A perfect ranking has MAP 1.0; interleaving irrelevant docs lowers
    // it to the mean of precision at each relevant hit.
    let qrels = make_qrels();
    let perfect = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7), ("d4", 0.6)];
    assert!((map(&perfect, &qrels) - 1.0).abs() < 1e-6);
    let interleaved = vec![("d1", 0.9), ("d4", 0.8), ("d2", 0.7), ("d5", 0.6), ("d3", 0.5)];
    let expected = (1.0 + 2.0 / 3.0 + 3.0 / 5.0) / 3.0;
    assert!((map(&interleaved, &qrels) - expected).abs() < 1e-4);
}
#[test]
fn map_at_k_truncation() {
    // Only the top 3 are scored: d1 at rank 2 contributes 1/2, divided
    // by all 3 relevant docs; d2 at rank 4 is cut off.
    let qrels = make_qrels();
    let ranking = vec![("d4", 0.9), ("d1", 0.8), ("d5", 0.7), ("d2", 0.6)];
    let expected = (1.0 / 2.0) / 3.0;
    let got = map_at_k(&ranking, &qrels, 3);
    assert!(
        (got - expected).abs() < 1e-4,
        "MAP@3 = {}, expected {}",
        got,
        expected
    );
}
#[test]
fn map_empty() {
    // Empty results or empty judgments both give MAP 0.
    let qrels = make_qrels();
    assert_eq!(map(&[], &qrels), 0.0);
    assert_eq!(map_at_k(&[], &qrels, 10), 0.0);
    let no_judgments: Qrels<&str> = HashMap::new();
    let ranking = vec![("d1", 0.9)];
    assert_eq!(map(&ranking, &no_judgments), 0.0);
}
#[test]
fn hit_rate_basic() {
    // Hit rate flips to 1.0 as soon as any relevant doc falls inside the
    // cutoff window.
    let qrels = make_qrels();
    let hit_first = vec![("d1", 0.9), ("d4", 0.8)];
    assert_eq!(hit_rate(&hit_first, &qrels, 1), 1.0);
    assert_eq!(hit_rate(&hit_first, &qrels, 2), 1.0);
    let hit_third = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7)];
    assert_eq!(hit_rate(&hit_third, &qrels, 1), 0.0);
    assert_eq!(hit_rate(&hit_third, &qrels, 2), 0.0);
    assert_eq!(hit_rate(&hit_third, &qrels, 3), 1.0);
}
#[test]
fn hit_rate_edge_cases() {
    // Empty results or no relevant doc inside the window both give 0.
    let qrels = make_qrels();
    assert_eq!(hit_rate(&[], &qrels, 5), 0.0);
    assert_eq!(hit_rate(&[("d4", 0.9)], &qrels, 1), 0.0);
}
#[test]
fn evaluate_metric_dispatch() {
    // Each OptimizeMetric variant must dispatch to its direct counterpart.
    let qrels = make_qrels();
    let ranking = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
    let close = |x: f32, y: f32| (x - y).abs() < 1e-6;
    assert!(close(
        evaluate_metric(&ranking, &qrels, OptimizeMetric::Ndcg { k: 3 }),
        ndcg_at_k(&ranking, &qrels, 3)
    ));
    assert!(close(
        evaluate_metric(&ranking, &qrels, OptimizeMetric::Map),
        map(&ranking, &qrels)
    ));
    assert!(close(
        evaluate_metric(&ranking, &qrels, OptimizeMetric::Precision { k: 2 }),
        precision_at_k(&ranking, &qrels, 2)
    ));
    assert!(close(
        evaluate_metric(&ranking, &qrels, OptimizeMetric::HitRate { k: 1 }),
        hit_rate(&ranking, &qrels, 1)
    ));
}
}