use crate::error::{MetricsError, Result};
use std::collections::HashMap;
/// Aggregated group-fairness report, as populated by [`fairness_audit`].
///
/// A sample counts as positive when its value is strictly greater than `0.0`.
/// All `group_*` maps are keyed by the group identifiers taken from the
/// `groups` input; a group is absent from a rate map when that rate's
/// denominator is zero for the group.
#[derive(Debug, Clone)]
pub struct FairnessAuditResult {
/// Lowest-to-highest ratio of per-group selection rates (`1.0` = parity).
pub demographic_parity_ratio: f64,
/// Lowest-to-highest ratio of per-group true positive rates; `fairness_audit`
/// reports `1.0` when this ratio cannot be computed.
pub equal_opportunity_ratio: f64,
/// Smaller of the per-group TPR ratio and the per-group FPR ratio.
pub equalized_odds_ratio: f64,
/// Same value as `demographic_parity_ratio`, reported under the
/// disparate-impact name.
pub disparate_impact_ratio: f64,
/// Whether the demographic parity ratio is at least `0.8` (four-fifths rule).
pub four_fifths_rule_satisfied: bool,
/// Share of each group's samples that received a positive prediction.
pub group_selection_rates: HashMap<usize, f64>,
/// Per-group true positive rate; only groups with at least one label `> 0.0`.
pub group_tpr: HashMap<usize, f64>,
/// Per-group false positive rate; only groups with at least one label `<= 0.0`.
pub group_fpr: HashMap<usize, f64>,
/// Per-group positive predictive value; only groups with at least one
/// positive prediction.
pub group_ppv: HashMap<usize, f64>,
}
/// Ratio of the lowest to the highest per-group selection rate.
///
/// A prediction is a positive selection when it is strictly greater than
/// `0.0`. `1.0` means every group is selected at the same rate; `0.0` means
/// some group is never selected while another is. When no group receives any
/// positive prediction the groups are trivially at parity and `1.0` is
/// returned.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `y_pred` and `groups` differ in
/// length, or when they are empty.
pub fn demographic_parity_ratio(y_pred: &[f64], groups: &[usize]) -> Result<f64> {
    if y_pred.len() != groups.len() {
        return Err(MetricsError::InvalidInput(
            "y_pred and groups must have the same length".to_string(),
        ));
    }
    if y_pred.is_empty() {
        return Err(MetricsError::InvalidInput(
            "inputs must not be empty".to_string(),
        ));
    }

    let selection_rates = group_positive_rates(y_pred, groups)?;
    let (lowest, highest) = selection_rates.values().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &rate| (lo.min(rate), hi.max(rate)),
    );

    if highest <= 0.0 {
        // Nobody is selected in any group: treat as perfect parity.
        Ok(1.0)
    } else {
        Ok(lowest / highest)
    }
}
/// Ratio of the lowest to the highest per-group true positive rate.
///
/// Only groups containing at least one actual positive (`y_true > 0.0`)
/// contribute. If every contributing group has a TPR of zero, the groups are
/// considered equal and `1.0` is returned.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the inputs differ in length,
/// are empty, or contain no actual positives at all.
pub fn equal_opportunity_ratio(y_true: &[f64], y_pred: &[f64], groups: &[usize]) -> Result<f64> {
    validate_ternary_inputs(y_true, y_pred, groups)?;

    let tpr_by_group = group_true_positive_rates(y_true, y_pred, groups)?;
    if tpr_by_group.is_empty() {
        return Err(MetricsError::InvalidInput(
            "No groups with positive samples found".to_string(),
        ));
    }

    let mut lowest = f64::INFINITY;
    let mut highest = f64::NEG_INFINITY;
    for &tpr in tpr_by_group.values() {
        lowest = lowest.min(tpr);
        highest = highest.max(tpr);
    }

    Ok(if highest <= 0.0 { 1.0 } else { lowest / highest })
}
pub fn equalized_odds_ratio(y_true: &[f64], y_pred: &[f64], groups: &[usize]) -> Result<f64> {
validate_ternary_inputs(y_true, y_pred, groups)?;
let tpr_map = group_true_positive_rates(y_true, y_pred, groups)?;
let fpr_map = group_false_positive_rates(y_true, y_pred, groups)?;
let tpr_ratio = if !tpr_map.is_empty() {
let min_tpr = tpr_map.values().copied().fold(f64::INFINITY, f64::min);
let max_tpr = tpr_map.values().copied().fold(f64::NEG_INFINITY, f64::max);
if max_tpr <= 0.0 {
1.0
} else {
min_tpr / max_tpr
}
} else {
1.0
};
let fpr_ratio = if !fpr_map.is_empty() {
let min_fpr = fpr_map.values().copied().fold(f64::INFINITY, f64::min);
let max_fpr = fpr_map.values().copied().fold(f64::NEG_INFINITY, f64::max);
if max_fpr <= 0.0 {
1.0
} else {
min_fpr / max_fpr
}
} else {
1.0
};
Ok(tpr_ratio.min(fpr_ratio))
}
/// Compares the demographic parity ratio of `y_pred` against `threshold`.
///
/// The check passes when the ratio is greater than or equal to `threshold`
/// (a NaN threshold therefore never passes). The per-group selection rates
/// used to compute the ratio are returned alongside it.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `y_pred` and `groups` differ in
/// length, or when they are empty.
pub fn disparate_impact_check(
    y_pred: &[f64],
    groups: &[usize],
    threshold: f64,
) -> Result<DisparateImpactResult> {
    // `demographic_parity_ratio` validates the inputs, so it must run before
    // `group_positive_rates`, which indexes `groups` without checking.
    let ratio = demographic_parity_ratio(y_pred, groups)?;
    let group_rates = group_positive_rates(y_pred, groups)?;
    let passes_threshold = ratio >= threshold;

    Ok(DisparateImpactResult {
        ratio,
        passes_threshold,
        threshold,
        group_rates,
    })
}
/// Outcome of [`disparate_impact_check`].
#[derive(Debug, Clone)]
pub struct DisparateImpactResult {
/// Lowest-to-highest ratio of per-group selection rates (`1.0` = parity).
pub ratio: f64,
/// `true` when `ratio >= threshold`.
pub passes_threshold: bool,
/// The threshold the ratio was compared against, as supplied by the caller.
pub threshold: f64,
/// Share of each group's samples with a prediction `> 0.0`, keyed by group id.
pub group_rates: HashMap<usize, f64>,
}
/// Largest absolute gap in positive predictive value (precision) between groups.
///
/// A group's PPV is `true positives / predicted positives`, where a value is
/// positive when it is strictly greater than `0.0`. Groups with no predicted
/// positives have no PPV and are ignored. Returns `0.0` when fewer than two
/// groups have a PPV.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the inputs differ in length or
/// are empty.
pub fn predictive_parity_difference(
    y_true: &[f64],
    y_pred: &[f64],
    groups: &[usize],
) -> Result<f64> {
    validate_ternary_inputs(y_true, y_pred, groups)?;
    let ppv_map = group_positive_predictive_values(y_true, y_pred, groups)?;
    if ppv_map.len() < 2 {
        return Ok(0.0);
    }
    // The widest pairwise gap is always between the two extremes, so a single
    // pass for min/max suffices instead of comparing every pair of groups.
    let (lowest, highest) = ppv_map.values().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &ppv| (lo.min(ppv), hi.max(ppv)),
    );
    Ok(highest - lowest)
}
/// Largest gap between groups in the ratio of false negatives to false positives.
///
/// For each group the ratio `FN / FP` is computed, where a value is positive
/// when it is strictly greater than `0.0`. A group with false negatives but no
/// false positives has an unbounded (infinite) ratio. A group with no errors
/// at all has ratio `0.0`.
///
/// The result is:
/// * `0.0` when fewer than two groups are present, or when every group's
///   ratio is unbounded;
/// * `f64::INFINITY` when at least one group is unbounded and at least one
///   is bounded, regardless of how many groups there are;
/// * otherwise `max(ratio) - min(ratio)` over all groups.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the inputs differ in length or
/// are empty.
pub fn treatment_equality_difference(
    y_true: &[f64],
    y_pred: &[f64],
    groups: &[usize],
) -> Result<f64> {
    validate_ternary_inputs(y_true, y_pred, groups)?;

    // Per group: (false negatives, false positives).
    let mut group_fn_fp: HashMap<usize, (usize, usize)> = HashMap::new();
    for ((&truth, &pred), &g) in y_true.iter().zip(y_pred).zip(groups) {
        let entry = group_fn_fp.entry(g).or_insert((0, 0));
        let is_positive_true = truth > 0.0;
        let is_positive_pred = pred > 0.0;
        if is_positive_true && !is_positive_pred {
            entry.0 += 1;
        } else if !is_positive_true && is_positive_pred {
            entry.1 += 1;
        }
    }
    if group_fn_fp.len() < 2 {
        return Ok(0.0);
    }

    let ratios: Vec<f64> = group_fn_fp
        .values()
        .map(|&(fn_count, fp_count)| {
            if fp_count > 0 {
                fn_count as f64 / fp_count as f64
            } else if fn_count > 0 {
                f64::INFINITY
            } else {
                0.0
            }
        })
        .collect();

    let unbounded = ratios.iter().filter(|r| r.is_infinite()).count();
    if unbounded == ratios.len() {
        // Every group is unbounded: no measurable difference between them.
        return Ok(0.0);
    }
    if unbounded > 0 {
        // A bounded group versus an unbounded one is an unbounded gap. This must
        // hold no matter how many bounded groups exist; the unbounded groups must
        // not be dropped just because several finite ratios are present.
        return Ok(f64::INFINITY);
    }

    let (lowest, highest) = ratios.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &r| (lo.min(r), hi.max(r)),
    );
    Ok(highest - lowest)
}
/// Runs every group-fairness metric in this module and bundles the results.
///
/// After validation, a failure of `equal_opportunity_ratio` or
/// `equalized_odds_ratio` is reported as `1.0`. A failing per-group TPR, FPR
/// or PPV helper yields an empty map. The demographic parity ratio doubles as
/// the disparate impact ratio and is tested against the four-fifths (`0.8`)
/// rule.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the inputs differ in length or
/// are empty.
pub fn fairness_audit(
    y_true: &[f64],
    y_pred: &[f64],
    groups: &[usize],
) -> Result<FairnessAuditResult> {
    // Minimum selection-rate ratio required by the four-fifths rule.
    const FOUR_FIFTHS: f64 = 0.8;

    validate_ternary_inputs(y_true, y_pred, groups)?;

    let parity = demographic_parity_ratio(y_pred, groups)?;
    let opportunity = equal_opportunity_ratio(y_true, y_pred, groups).unwrap_or(1.0);
    let odds = equalized_odds_ratio(y_true, y_pred, groups).unwrap_or(1.0);

    let group_selection_rates = group_positive_rates(y_pred, groups)?;
    let group_tpr = group_true_positive_rates(y_true, y_pred, groups).unwrap_or_default();
    let group_fpr = group_false_positive_rates(y_true, y_pred, groups).unwrap_or_default();
    let group_ppv =
        group_positive_predictive_values(y_true, y_pred, groups).unwrap_or_default();

    Ok(FairnessAuditResult {
        demographic_parity_ratio: parity,
        equal_opportunity_ratio: opportunity,
        equalized_odds_ratio: odds,
        disparate_impact_ratio: parity,
        four_fifths_rule_satisfied: parity >= FOUR_FIFTHS,
        group_selection_rates,
        group_tpr,
        group_fpr,
        group_ppv,
    })
}
/// Fraction of each group's samples whose prediction is strictly positive.
///
/// Only the first `y_pred.len()` entries of `groups` are read. Indexing panics
/// if `groups` is shorter, so callers are expected to validate lengths first.
/// Never returns `Err`; the `Result` keeps the helper signatures uniform.
fn group_positive_rates(y_pred: &[f64], groups: &[usize]) -> Result<HashMap<usize, f64>> {
    // Per group: (positive predictions, total samples).
    let mut tallies: HashMap<usize, (usize, usize)> = HashMap::new();
    for (&pred, &group) in y_pred.iter().zip(&groups[..y_pred.len()]) {
        let (positives, total) = tallies.entry(group).or_insert((0, 0));
        *total += 1;
        if pred > 0.0 {
            *positives += 1;
        }
    }

    Ok(tallies
        .into_iter()
        .filter(|&(_, (_, total))| total > 0)
        .map(|(group, (positives, total))| (group, positives as f64 / total as f64))
        .collect())
}
/// True positive rate (recall) of each group: `true positives / actual positives`.
///
/// A sample is an actual positive when its `y_true` value is strictly greater
/// than `0.0`, so NaN labels are skipped. Only groups with at least one actual
/// positive appear in the result. `groups` and `y_pred` are indexed by the
/// positions of `y_true`, so callers must validate lengths first. Never
/// returns `Err`.
fn group_true_positive_rates(
    y_true: &[f64],
    y_pred: &[f64],
    groups: &[usize],
) -> Result<HashMap<usize, f64>> {
    // Per group: (true positives, actual positives).
    let mut tallies: HashMap<usize, (usize, usize)> = HashMap::new();
    for (idx, &truth) in y_true.iter().enumerate() {
        let group = groups[idx];
        if truth > 0.0 {
            let (hits, positives) = tallies.entry(group).or_insert((0, 0));
            *positives += 1;
            if y_pred[idx] > 0.0 {
                *hits += 1;
            }
        }
    }

    Ok(tallies
        .into_iter()
        .filter(|&(_, (_, positives))| positives > 0)
        .map(|(group, (hits, positives))| (group, hits as f64 / positives as f64))
        .collect())
}
/// False positive rate of each group: `false positives / actual negatives`.
///
/// A sample is an actual negative when its `y_true` value is `<= 0.0`, so NaN
/// labels are skipped. Only groups with at least one actual negative appear in
/// the result. `groups` and `y_pred` are indexed by the positions of
/// `y_true`, so callers must validate lengths first. Never returns `Err`.
fn group_false_positive_rates(
    y_true: &[f64],
    y_pred: &[f64],
    groups: &[usize],
) -> Result<HashMap<usize, f64>> {
    // Per group: (false positives, actual negatives).
    let mut tallies: HashMap<usize, (usize, usize)> = HashMap::new();
    for (idx, (&truth, &group)) in y_true.iter().zip(&groups[..y_true.len()]).enumerate() {
        if truth <= 0.0 {
            let (false_alarms, negatives) = tallies.entry(group).or_insert((0, 0));
            *negatives += 1;
            if y_pred[idx] > 0.0 {
                *false_alarms += 1;
            }
        }
    }

    Ok(tallies
        .into_iter()
        .filter(|&(_, (_, negatives))| negatives > 0)
        .map(|(group, (false_alarms, negatives))| {
            (group, false_alarms as f64 / negatives as f64)
        })
        .collect())
}
/// Positive predictive value (precision) of each group:
/// `true positives / predicted positives`.
///
/// A prediction or label is positive when strictly greater than `0.0`. Only
/// groups with at least one positive prediction appear in the result. `groups`
/// and `y_pred` must be at least as long as `y_true`, or slicing panics.
/// Never returns `Err`.
fn group_positive_predictive_values(
    y_true: &[f64],
    y_pred: &[f64],
    groups: &[usize],
) -> Result<HashMap<usize, f64>> {
    let n = y_true.len();
    // Per group: (true positives, predicted positives).
    let mut tallies: HashMap<usize, (usize, usize)> = HashMap::new();
    for ((&truth, &pred), &group) in y_true.iter().zip(&y_pred[..n]).zip(&groups[..n]) {
        if pred > 0.0 {
            let (hits, predicted) = tallies.entry(group).or_insert((0, 0));
            *predicted += 1;
            if truth > 0.0 {
                *hits += 1;
            }
        }
    }

    Ok(tallies
        .into_iter()
        .filter(|&(_, (_, predicted))| predicted > 0)
        .map(|(group, (hits, predicted))| (group, hits as f64 / predicted as f64))
        .collect())
}
/// Checks that `y_true`, `y_pred` and `groups` are equally long and non-empty.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] listing all three lengths when they
/// disagree, or stating that the inputs are empty when they all have length 0.
fn validate_ternary_inputs(y_true: &[f64], y_pred: &[f64], groups: &[usize]) -> Result<()> {
    let n = y_true.len();
    let lengths_agree = y_pred.len() == n && groups.len() == n;
    if !lengths_agree {
        return Err(MetricsError::InvalidInput(format!(
            "All inputs must have the same length: {}, {}, {}",
            n,
            y_pred.len(),
            groups.len()
        )));
    }
    match n {
        0 => Err(MetricsError::InvalidInput(
            "inputs must not be empty".to_string(),
        )),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the group-fairness metrics, including edge cases for
    //! degenerate inputs (no positives, single comparable group, mismatches).
    use super::*;

    #[test]
    fn test_dp_ratio_perfect() {
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = demographic_parity_ratio(&y_pred, &groups).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_dp_ratio_imbalanced() {
        let y_pred = vec![1.0, 1.0, 0.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = demographic_parity_ratio(&y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_dp_ratio_partial() {
        let y_pred = vec![1.0, 1.0, 0.0, 1.0, 0.0, 0.0];
        let groups = vec![0, 0, 0, 1, 1, 1];
        let val = demographic_parity_ratio(&y_pred, &groups).expect("should succeed");
        assert!((val - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_dp_ratio_multi_group() {
        let y_pred = vec![1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let groups = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let val = demographic_parity_ratio(&y_pred, &groups).expect("should succeed");
        assert!((val - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_dp_ratio_empty() {
        assert!(demographic_parity_ratio(&[], &[]).is_err());
    }

    #[test]
    fn test_dp_ratio_mismatched() {
        assert!(demographic_parity_ratio(&[1.0], &[0, 1]).is_err());
    }

    #[test]
    fn test_dp_ratio_no_positive_predictions() {
        // No group is ever selected, which counts as parity.
        let y_pred = vec![0.0, 0.0, 0.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = demographic_parity_ratio(&y_pred, &groups).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_eo_ratio_perfect() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = equal_opportunity_ratio(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_eo_ratio_unequal() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 0.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = equal_opportunity_ratio(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_eo_ratio_partial() {
        let y_true = vec![1.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0];
        let groups = vec![0, 0, 0, 1, 1, 1];
        let val = equal_opportunity_ratio(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_eo_ratio_mismatched() {
        assert!(equal_opportunity_ratio(&[1.0], &[0.0], &[0, 1]).is_err());
    }

    #[test]
    fn test_eo_ratio_no_positive_labels() {
        // Without any actual positives no TPR can be computed.
        let y_true = vec![0.0, 0.0, 0.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        assert!(equal_opportunity_ratio(&y_true, &y_pred, &groups).is_err());
    }

    #[test]
    fn test_eod_ratio_perfect() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = equalized_odds_ratio(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_eod_ratio_unequal() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 0.0, 1.0];
        let groups = vec![0, 0, 1, 1];
        let val = equalized_odds_ratio(&y_true, &y_pred, &groups).expect("should succeed");
        assert!(val < 0.5);
    }

    #[test]
    fn test_eod_ratio_empty() {
        assert!(equalized_odds_ratio(&[], &[], &[]).is_err());
    }

    #[test]
    fn test_eod_ratio_multi_group() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1, 2, 2];
        let val = equalized_odds_ratio(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_di_check_passes() {
        let y_pred = vec![1.0, 1.0, 1.0, 1.0];
        let groups = vec![0, 0, 1, 1];
        let result = disparate_impact_check(&y_pred, &groups, 0.8).expect("should succeed");
        assert!(result.passes_threshold);
        assert!((result.ratio - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_di_check_fails() {
        let y_pred = vec![1.0, 1.0, 0.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let result = disparate_impact_check(&y_pred, &groups, 0.8).expect("should succeed");
        assert!(!result.passes_threshold);
        assert!((result.ratio - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_di_check_borderline() {
        let y_pred = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0];
        let groups = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let result = disparate_impact_check(&y_pred, &groups, 0.8).expect("should succeed");
        assert!(result.passes_threshold);
    }

    #[test]
    fn test_di_check_empty() {
        assert!(disparate_impact_check(&[], &[], 0.8).is_err());
    }

    #[test]
    fn test_di_check_records_threshold_and_rates() {
        let y_pred = vec![1.0, 0.0, 1.0, 1.0];
        let groups = vec![0, 0, 1, 1];
        let result = disparate_impact_check(&y_pred, &groups, 0.4).expect("should succeed");
        assert!((result.threshold - 0.4).abs() < 1e-10);
        assert!((result.ratio - 0.5).abs() < 1e-10);
        assert!(result.passes_threshold);
        assert!((result.group_rates[&0] - 0.5).abs() < 1e-10);
        assert!((result.group_rates[&1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pp_difference_equal() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = predictive_parity_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_pp_difference_unequal() {
        let y_true = vec![1.0, 0.0, 0.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = predictive_parity_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pp_difference_partial() {
        let y_true = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0];
        let y_pred = vec![1.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let groups = vec![0, 0, 0, 1, 1, 1];
        let val = predictive_parity_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_pp_difference_empty() {
        assert!(predictive_parity_difference(&[], &[], &[]).is_err());
    }

    #[test]
    fn test_pp_difference_single_group_with_predictions() {
        // Only group 0 has predicted positives, so there is nothing to compare.
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 1.0, 0.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = predictive_parity_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_treatment_equality_equal() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![0.0, 1.0, 0.0, 1.0];
        let groups = vec![0, 0, 1, 1];
        let val = treatment_equality_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_treatment_equality_unequal() {
        let y_true = vec![1.0, 1.0, 0.0, 0.0];
        let y_pred = vec![0.0, 1.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = treatment_equality_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!(val.is_infinite() || val > 0.5);
    }

    #[test]
    fn test_treatment_equality_no_errors() {
        let y_true = vec![1.0, 0.0, 1.0, 0.0];
        let y_pred = vec![1.0, 0.0, 1.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let val = treatment_equality_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_treatment_equality_empty() {
        assert!(treatment_equality_difference(&[], &[], &[]).is_err());
    }

    #[test]
    fn test_treatment_equality_all_unbounded() {
        // Both groups have false negatives but no false positives.
        let y_true = vec![1.0, 1.0];
        let y_pred = vec![0.0, 0.0];
        let groups = vec![0, 1];
        let val = treatment_equality_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_treatment_equality_finite_spread() {
        // FN/FP ratios: group 0 = 2/1, group 1 = 1/1, group 2 = no errors (0).
        let y_true = vec![1.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let y_pred = vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0];
        let groups = vec![0, 0, 0, 1, 1, 2];
        let val = treatment_equality_difference(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((val - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_fairness_audit_basic() {
        let y_true = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let y_pred = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let groups = vec![0, 0, 0, 0, 1, 1, 1, 1];
        let result = fairness_audit(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((result.demographic_parity_ratio - 1.0).abs() < 1e-10);
        assert!((result.equal_opportunity_ratio - 1.0).abs() < 1e-10);
        assert!(result.four_fifths_rule_satisfied);
    }

    #[test]
    fn test_fairness_audit_unfair() {
        let y_true = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let y_pred = vec![1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let groups = vec![0, 0, 0, 0, 1, 1, 1, 1];
        let result = fairness_audit(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((result.demographic_parity_ratio - 0.0).abs() < 1e-10);
        assert!(!result.four_fifths_rule_satisfied);
    }

    #[test]
    fn test_fairness_audit_has_per_group_data() {
        let y_true = vec![0.0, 1.0, 0.0, 1.0];
        let y_pred = vec![0.0, 1.0, 0.0, 1.0];
        let groups = vec![0, 0, 1, 1];
        let result = fairness_audit(&y_true, &y_pred, &groups).expect("should succeed");
        assert!(result.group_selection_rates.contains_key(&0));
        assert!(result.group_selection_rates.contains_key(&1));
    }

    #[test]
    fn test_fairness_audit_empty() {
        assert!(fairness_audit(&[], &[], &[]).is_err());
    }

    #[test]
    fn test_fairness_audit_mismatched() {
        assert!(fairness_audit(&[1.0], &[1.0, 0.0], &[0]).is_err());
    }

    #[test]
    fn test_fairness_audit_no_positive_labels() {
        // Equal opportunity is undefined here, so the audit reports 1.0 and an
        // empty TPR map instead of failing.
        let y_true = vec![0.0, 0.0, 0.0, 0.0];
        let y_pred = vec![1.0, 0.0, 0.0, 0.0];
        let groups = vec![0, 0, 1, 1];
        let result = fairness_audit(&y_true, &y_pred, &groups).expect("should succeed");
        assert!((result.equal_opportunity_ratio - 1.0).abs() < 1e-10);
        assert!(result.group_tpr.is_empty());
        assert!((result.demographic_parity_ratio - 0.0).abs() < 1e-10);
        assert!((result.disparate_impact_ratio - 0.0).abs() < 1e-10);
        assert!(!result.four_fifths_rule_satisfied);
    }
}