use crate::solver::priority_selection::{PriorityCandidate, rank_priority_candidates};
/// One fitted model to be ranked by [`compare_reml_fits`].
#[derive(Clone, Debug, PartialEq)]
pub struct RemlCandidate {
    /// Caller-supplied identifier. `compare_reml_fits` does not read it; it ranks
    /// using the position of the candidate in the input vector instead.
    pub index: usize,
    /// Label reported in the ranking, the score table and the evidence summary.
    pub name: String,
    /// REML score of the fit; the comparison treats a lower score as better (a cost).
    pub score: f64,
    /// Effective degrees of freedom, if known; copied unchanged into the output rows.
    pub edf: Option<f64>,
}

/// Result of [`compare_reml_fits`].
#[derive(Clone, Debug, PartialEq)]
pub struct RemlComparison {
    /// All candidates in ranked order, winner first.
    pub ranking: Vec<RankedRow>,
    /// Name of the first-ranked candidate.
    pub winner: String,
    /// Human-readable one-line summary of the winner's margin over the runner-up.
    pub evidence_summary: String,
    /// Same rows as `ranking`, with more descriptive field names.
    pub score_table: Vec<ScoreRow>,
}

/// One entry of [`RemlComparison::ranking`].
#[derive(Clone, Debug, PartialEq)]
pub struct RankedRow {
    /// Candidate name.
    pub name: String,
    /// Candidate REML score.
    pub score: f64,
    /// `log_bayes_factor(best_score, score)`: log Bayes factor of the winner over this row.
    pub delta: f64,
    /// `exp(delta)`; `1.0` for the winner's own row.
    pub bayes_factor: f64,
    /// Effective degrees of freedom copied from the candidate.
    pub edf: Option<f64>,
}

/// One entry of [`RemlComparison::score_table`]; mirrors [`RankedRow`].
#[derive(Clone, Debug, PartialEq)]
pub struct ScoreRow {
    /// Candidate name.
    pub name: String,
    /// Candidate REML score.
    pub reml_score: f64,
    /// Log Bayes factor of the winner over this model (`0.0` for the winner).
    pub delta_reml: f64,
    /// `exp(delta_reml)`: Bayes factor of the best model over this one.
    pub bayes_factor_best_over_model: f64,
    /// Effective degrees of freedom copied from the candidate.
    pub effective_dof: Option<f64>,
}
/// Natural-log Bayes factor of model `a` over model `b`, computed from REML scores.
///
/// Returns `reml_score_b - reml_score_a`. This equals `ln(p_a / p_b)` when each
/// score is a negative log marginal likelihood (lower is better). NOTE(review):
/// confirm the REML score is on that scale, e.g. not doubled.
///
/// A positive result favours `a`. Swapping the arguments negates the result, and
/// equal scores give `0.0`.
#[inline]
#[must_use]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    reml_score_b - reml_score_a
}
pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
if candidates.is_empty() {
return Err("compare_models requires at least one fit".to_string());
}
candidates = rank_priority_candidates(
candidates
.into_iter()
.enumerate()
.map(|(idx, row)| {
let score = row.score;
PriorityCandidate::new(row, idx, score, 0)
})
.collect(),
)
.into_iter()
.map(|row| row.item)
.collect();
let best_score = candidates[0].score;
let winner = candidates[0].name.clone();
let mut ranking = Vec::with_capacity(candidates.len());
let mut score_table = Vec::with_capacity(candidates.len());
for row in candidates.iter() {
let delta = log_bayes_factor(best_score, row.score);
let bayes_factor = delta.exp();
ranking.push(RankedRow {
name: row.name.clone(),
score: row.score,
delta,
bayes_factor,
edf: row.edf,
});
score_table.push(ScoreRow {
name: row.name.clone(),
reml_score: row.score,
delta_reml: delta,
bayes_factor_best_over_model: bayes_factor,
effective_dof: row.edf,
});
}
let evidence_summary = if candidates.len() >= 2 {
let runner_up = &candidates[1];
let best_log_bayes_factor_vs_runner_up = log_bayes_factor(best_score, runner_up.score);
format!(
"{} wins by Bayes factor {} over {}",
winner,
format_bayes_factor(best_log_bayes_factor_vs_runner_up),
runner_up.name
)
} else {
format!("{winner} (single fit; no comparison)")
};
Ok(RemlComparison {
ranking,
winner,
evidence_summary,
score_table,
})
}
/// Formats a Bayes factor given its natural logarithm `log_bf`.
///
/// * Non-finite inputs: `+inf` gives `"inf"`. `-inf` gives `"0"` (a Bayes factor of
///   zero). `NaN` gives `"NaN"`.
/// * When `|log_bf| >= 3 * ln(10)` (Bayes factor at least 1000 or at most 0.001),
///   the result is `"1e"` followed by the signed base-10 exponent
///   `log_bf / ln(10)` with one decimal, e.g. `"1e+4.0"`. The exponent may be
///   fractional, so this is a display form rather than a parseable float literal.
/// * Otherwise `exp(log_bf)` is rendered by [`format_three_significant`].
pub fn format_bayes_factor(log_bf: f64) -> String {
    if log_bf.is_nan() {
        return "NaN".to_string();
    }
    if log_bf.is_infinite() {
        // exp(+inf) = inf, exp(-inf) = 0.
        return if log_bf > 0.0 { "inf" } else { "0" }.to_string();
    }
    if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
        return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
    }
    format_three_significant(log_bf.exp())
}

/// Formats `value` with three significant digits.
///
/// * `0` gives `"0"`. Non-finite values use their `Display` form (`"inf"`,
///   `"-inf"`, `"NaN"`).
/// * Magnitudes of at least 1000 use scientific notation with two mantissa
///   decimals (`1234.0` gives `"1.23e3"`).
/// * Smaller magnitudes use fixed notation with `2 - floor(log10(|value|))`
///   decimals (`12.345` gives `"12.3"`, `0.0012345` gives `"0.00123"`).
///
/// Digits are rounded half away from zero. When rounding carries into the next
/// power of ten, the value is re-rendered one decade up so exactly three
/// significant digits remain: `9.999` gives `"10.0"` and `999.6` gives `"1.00e3"`.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let mut exponent = value.abs().log10().floor() as i32;
    while exponent < 3 {
        // exponent < 3, so this is non-negative.
        let decimals = (2 - exponent) as usize;
        let scale = 10f64.powi(2 - exponent);
        if !scale.is_finite() {
            // For magnitudes around 1e-306 and below, 10^decimals overflows. The
            // scaled rounding would then yield NaN, so let the formatter round.
            return format!("{value:.decimals$}");
        }
        // Three-digit integer mantissa (100..=1000) at this decade.
        let digits = (value * scale).abs().round();
        if digits < 1000.0 {
            let rounded = digits / scale * value.signum();
            return format!("{rounded:.decimals$}");
        }
        // Rounding carried (e.g. 9.999 -> 10.00); retry one decade up.
        exponent += 1;
    }
    format!("{value:.2e}")
}
#[cfg(test)]
mod tests {
    use super::{RemlCandidate, compare_reml_fits, log_bayes_factor};

    /// Builds a candidate fit for the comparisons below.
    fn fit(index: usize, name: &str, score: f64, edf: Option<f64>) -> RemlCandidate {
        RemlCandidate {
            index,
            name: name.to_string(),
            score,
            edf,
        }
    }

    #[test]
    fn log_bayes_factor_favours_lower_cost_model() {
        let (better, worse) = (311.06_f64, 815.71_f64);
        let forward = log_bayes_factor(better, worse);
        let backward = log_bayes_factor(worse, better);

        // Lower cost is favoured, and the factor is antisymmetric.
        assert!(forward > 0.0);
        assert!(forward.exp() > 1.0);
        assert!((backward + forward).abs() < 1e-12);
        assert!(backward < 0.0);
        assert_eq!(log_bayes_factor(5.0, 5.0), 0.0);

        let comparison = compare_reml_fits(vec![
            fit(0, "better", better, None),
            fit(1, "worse", worse, None),
        ])
        .expect("finite candidates compare");
        let loser = &comparison.ranking[1];
        assert!((loser.delta - forward).abs() < 1e-12);
    }

    #[test]
    fn ranks_lowest_reml_cost_as_winner() {
        let comparison = compare_reml_fits(vec![
            fit(0, "low_cost", 2.0, None),
            fit(1, "high_cost", 5.0, Some(4.0)),
        ])
        .expect("finite REML candidates should compare");
        let expected_bf = 3.0_f64.exp();

        assert_eq!(comparison.winner, "low_cost");

        let (first, second) = (&comparison.ranking[0], &comparison.ranking[1]);
        assert_eq!(first.name, "low_cost");
        assert_eq!(first.delta, 0.0);
        assert_eq!(first.bayes_factor, 1.0);
        assert_eq!(second.name, "high_cost");
        assert_eq!(second.delta, 3.0);
        assert!((second.bayes_factor - expected_bf).abs() < 1e-12);

        let table_row = &comparison.score_table[1];
        assert_eq!(table_row.name, "high_cost");
        assert_eq!(table_row.reml_score, 5.0);
        assert_eq!(table_row.delta_reml, 3.0);
        assert!((table_row.bayes_factor_best_over_model - expected_bf).abs() < 1e-12);

        assert!(comparison.evidence_summary.contains("low_cost"));
    }

    #[test]
    fn issue_396_lower_reml_cost_wins_against_higher_cost() {
        let comparison = compare_reml_fits(vec![
            fit(0, "m1_xonly", 786.32, Some(9.49)),
            fit(1, "m2_true", 359.06, Some(18.58)),
        ])
        .expect("finite REML candidates should compare");
        assert_eq!(
            comparison.winner, "m2_true",
            "issue #396: the lower-cost (better) model must win"
        );

        let winner_row = &comparison.score_table[0];
        assert_eq!(winner_row.name, "m2_true");
        assert!((winner_row.reml_score - 359.06).abs() < 1.0e-12);
        assert_eq!(winner_row.delta_reml, 0.0);
        assert_eq!(winner_row.bayes_factor_best_over_model, 1.0);

        let loser_row = &comparison.score_table[1];
        let expected_delta = 786.32 - 359.06;
        assert_eq!(loser_row.name, "m1_xonly");
        assert!((loser_row.reml_score - 786.32).abs() < 1.0e-12);
        assert!((loser_row.delta_reml - expected_delta).abs() < 1.0e-12);
        assert!(loser_row.bayes_factor_best_over_model > 1.0e+100);

        assert!(comparison.evidence_summary.starts_with("m2_true wins"));
    }
}