use scirs2_core::ndarray::{ArrayBase, Data, Ix1, Ix2};
use scirs2_core::numeric::Float;
use std::cmp::Ordering;
use crate::error::{MetricsError, Result};
pub mod bias_detection;
pub mod group_fairness;
pub mod robustness;
/// Absolute gap in positive-prediction rates between the protected and
/// unprotected groups.
///
/// A sample is in the protected group when its `protected_group` value is
/// strictly greater than zero. Otherwise it is in the unprotected group.
/// A prediction counts as positive when it is strictly greater than zero.
///
/// The returned value is
/// `|P(pred > 0 | protected) - P(pred > 0 | unprotected)|`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when:
/// - the two arrays differ in length,
/// - the arrays are empty, or
/// - either group ends up with no members.
#[allow(dead_code)]
pub fn demographic_parity_difference<T, S, R>(
    y_pred: &ArrayBase<S, Ix1>,
    protected_group: &ArrayBase<R, Ix1>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    let (n_pred, n_group) = (y_pred.len(), protected_group.len());
    if n_pred != n_group {
        return Err(MetricsError::InvalidInput(format!(
            "Arrays have different lengths: {} vs {}",
            n_pred, n_group
        )));
    }
    if n_pred == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    // Slot 0 tallies the unprotected group, slot 1 the protected group.
    let mut members = [0usize; 2];
    let mut positives = [0usize; 2];
    for (pred, group) in y_pred.iter().zip(protected_group.iter()) {
        let slot = usize::from(*group > zero);
        members[slot] += 1;
        if *pred > zero {
            positives[slot] += 1;
        }
    }

    if members.contains(&0) {
        return Err(MetricsError::InvalidInput(
            "Each group must have at least one member".to_string(),
        ));
    }

    let positive_rate = |slot: usize| positives[slot] as f64 / members[slot] as f64;
    Ok((positive_rate(1) - positive_rate(0)).abs())
}
/// Ratio of the protected group's positive-prediction rate to the unprotected
/// group's rate.
///
/// Group membership and positivity both use a strict `> 0` test, the same as
/// `demographic_parity_difference`.
///
/// When the unprotected rate is zero:
/// - returns `1.0` if the protected rate is also zero (equal treatment);
/// - returns `f64::INFINITY` otherwise.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when:
/// - the two arrays differ in length,
/// - the arrays are empty, or
/// - either group has no members.
#[allow(dead_code)]
pub fn disparate_impact<T, S, R>(
    y_pred: &ArrayBase<S, Ix1>,
    protected_group: &ArrayBase<R, Ix1>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    let n_pred = y_pred.len();
    let n_group = protected_group.len();
    if n_pred != n_group {
        return Err(MetricsError::InvalidInput(format!(
            "Arrays have different lengths: {} vs {}",
            n_pred, n_group
        )));
    }
    if n_pred == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    // Accumulator: (protected positives, protected size,
    //               unprotected positives, unprotected size).
    let (prot_pos, prot_size, unprot_pos, unprot_size) = y_pred
        .iter()
        .zip(protected_group.iter())
        .fold(
            (0usize, 0usize, 0usize, 0usize),
            |(pp, ps, up, us), (pred, group)| {
                let hit = usize::from(*pred > zero);
                if *group > zero {
                    (pp + hit, ps + 1, up, us)
                } else {
                    (pp, ps, up + hit, us + 1)
                }
            },
        );

    if prot_size == 0 || unprot_size == 0 {
        return Err(MetricsError::InvalidInput(
            "Each group must have at least one member".to_string(),
        ));
    }

    let protected_rate = prot_pos as f64 / prot_size as f64;
    let unprotected_rate = unprot_pos as f64 / unprot_size as f64;
    if unprotected_rate != 0.0 {
        return Ok(protected_rate / unprotected_rate);
    }
    // Division by zero: equal (both zero) maps to parity, anything else is unbounded.
    Ok(if protected_rate == 0.0 {
        1.0
    } else {
        f64::INFINITY
    })
}
/// Largest gap between the protected and unprotected groups in either the
/// true positive rate or the false positive rate.
///
/// All three inputs are binarised with a strict `> 0` test:
/// - a sample is in the protected group when `protected_group[i] > 0`;
/// - it is an actual positive when `y_true[i] > 0`;
/// - it is a predicted positive when `y_pred[i] > 0`.
///
/// The result is
/// `max(|TPR_protected - TPR_unprotected|, |FPR_protected - FPR_unprotected|)`.
///
/// A rate whose denominator is zero is taken as `0.0`. For example, the TPR
/// of a group with no actual positives is `0.0`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when:
/// - the arrays differ in length,
/// - the arrays are empty, or
/// - either group has no members.
///
/// The last case matches `demographic_parity_difference` and
/// `disparate_impact`.
#[allow(dead_code)]
pub fn equalized_odds_difference<T, S, R, Q>(
    y_true: &ArrayBase<S, Ix1>,
    y_pred: &ArrayBase<R, Ix1>,
    protected_group: &ArrayBase<Q, Ix1>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
    Q: Data<Elem = T>,
{
    if y_true.len() != y_pred.len() || y_true.len() != protected_group.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Arrays have different lengths: {}, {}, {}",
            y_true.len(),
            y_pred.len(),
            protected_group.len()
        )));
    }
    if y_true.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Confusion-matrix tallies for one group.
    #[derive(Default)]
    struct Confusion {
        true_pos: usize,
        false_pos: usize,
        true_neg: usize,
        false_neg: usize,
    }

    impl Confusion {
        fn record(&mut self, actual: bool, predicted: bool) {
            match (actual, predicted) {
                (true, true) => self.true_pos += 1,
                (true, false) => self.false_neg += 1,
                (false, true) => self.false_pos += 1,
                (false, false) => self.true_neg += 1,
            }
        }

        fn total(&self) -> usize {
            self.true_pos + self.false_pos + self.true_neg + self.false_neg
        }

        // numerator / denominator, or 0.0 when the denominator is zero.
        fn ratio(numerator: usize, denominator: usize) -> f64 {
            if denominator > 0 {
                numerator as f64 / denominator as f64
            } else {
                0.0
            }
        }

        // TP / (TP + FN); 0.0 when the group has no actual positives.
        fn true_positive_rate(&self) -> f64 {
            Self::ratio(self.true_pos, self.true_pos + self.false_neg)
        }

        // FP / (FP + TN); 0.0 when the group has no actual negatives.
        fn false_positive_rate(&self) -> f64 {
            Self::ratio(self.false_pos, self.false_pos + self.true_neg)
        }
    }

    let zero = T::zero();
    let mut protected = Confusion::default();
    let mut unprotected = Confusion::default();
    for ((truth, pred), group) in y_true.iter().zip(y_pred.iter()).zip(protected_group.iter()) {
        let counts = if *group > zero {
            &mut protected
        } else {
            &mut unprotected
        };
        counts.record(*truth > zero, *pred > zero);
    }

    // An empty group would otherwise report TPR = FPR = 0.0. The "difference"
    // would then just be the other group's largest rate, which is meaningless.
    if protected.total() == 0 || unprotected.total() == 0 {
        return Err(MetricsError::InvalidInput(
            "Each group must have at least one member".to_string(),
        ));
    }

    let tpr_diff = (protected.true_positive_rate() - unprotected.true_positive_rate()).abs();
    let fpr_diff = (protected.false_positive_rate() - unprotected.false_positive_rate()).abs();
    Ok(tpr_diff.max(fpr_diff))
}
/// Absolute gap in true positive rate (recall) between the protected and
/// unprotected groups.
///
/// Only samples with `y_true[i] > 0` contribute. Each such sample belongs to
/// the protected group when `protected_group[i] > 0`. It counts as correctly
/// identified when `y_pred[i] > 0`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when:
/// - the arrays differ in length,
/// - the arrays are empty, or
/// - either group has no actual positive samples. The protected group is
///   checked first.
#[allow(dead_code)]
pub fn equal_opportunity_difference<T, S, R, Q>(
    y_true: &ArrayBase<S, Ix1>,
    y_pred: &ArrayBase<R, Ix1>,
    protected_group: &ArrayBase<Q, Ix1>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
    Q: Data<Elem = T>,
{
    let n = y_true.len();
    if n != y_pred.len() || n != protected_group.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Arrays have different lengths: {}, {}, {}",
            n,
            y_pred.len(),
            protected_group.len()
        )));
    }
    if n == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    // Slot 1 = protected group, slot 0 = unprotected group.
    let mut actual_positives = [0usize; 2];
    let mut caught = [0usize; 2];
    let samples = y_true.iter().zip(y_pred.iter()).zip(protected_group.iter());
    for ((_, pred), group) in samples.filter(|((truth, _), _)| **truth > zero) {
        let slot = usize::from(*group > zero);
        actual_positives[slot] += 1;
        caught[slot] += usize::from(*pred > zero);
    }

    if actual_positives[1] == 0 {
        return Err(MetricsError::InvalidInput(
            "No positive examples in protected group".to_string(),
        ));
    }
    if actual_positives[0] == 0 {
        return Err(MetricsError::InvalidInput(
            "No positive examples in unprotected group".to_string(),
        ));
    }

    let recall = |slot: usize| caught[slot] as f64 / actual_positives[slot] as f64;
    Ok((recall(1) - recall(0)).abs())
}
/// Individual-fairness consistency score: how similar each sample's prediction
/// is to the predictions of its `k` nearest neighbours.
///
/// For each sample `i`:
/// 1. Find the `k` other rows of `features` closest to row `i` by Euclidean
///    distance.
/// 2. Average `|predictions[i] - predictions[j]|` over those neighbours `j`.
///
/// The result is `1 - mean_i(average_i)`. If every sample's neighbours share
/// its prediction, the score is `1.0`.
///
/// Neighbour selection rules:
/// - Equal distances are broken by the lower sample index.
/// - NaN distances (from NaN features) rank after all other distances.
///
/// Cost: O(n²·d) distance work plus O(n² log n) sorting, with O(n) extra
/// memory.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when:
/// - the number of feature rows differs from `predictions.len()`,
/// - there are no samples, or
/// - `k` is not in `1..n_samples`.
#[allow(dead_code)]
pub fn consistency_score<T, S, R>(
    features: &ArrayBase<S, Ix2>,
    predictions: &ArrayBase<R, Ix1>,
    k: usize,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    let n_samples = features.nrows();
    if n_samples != predictions.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Number of samples in features ({}) and predictions ({}) do not match",
            n_samples,
            predictions.len()
        )));
    }
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    if k >= n_samples {
        return Err(MetricsError::InvalidInput(format!(
            "k ({}) must be less than the number of samples ({})",
            k, n_samples
        )));
    }
    if k == 0 {
        return Err(MetricsError::InvalidInput(
            "k must be greater than 0".to_string(),
        ));
    }

    let n_features = features.ncols();
    let to_f64 = |v: T| v.to_f64().unwrap_or(0.0);
    let pred_values: Vec<f64> = predictions.iter().map(|&p| to_f64(p)).collect();

    // Total order over (index, distance) pairs:
    // - non-NaN distances compare numerically;
    // - NaN ranks after every other distance;
    // - ties fall back to the lower index.
    // `partial_cmp(..).unwrap_or(Equal)` is not total when NaN is present, and
    // `sort_by`/`sort_unstable_by` may panic on non-total comparators.
    let by_distance = |a: &(usize, f64), b: &(usize, f64)| -> Ordering {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            .then_with(|| a.0.cmp(&b.0))
    };

    // One buffer of candidate neighbours, reused for every sample. This avoids
    // materialising all n² pairs and re-scanning them once per sample.
    let mut neighbors: Vec<(usize, f64)> = Vec::with_capacity(n_samples - 1);
    let mut consistency_sum = 0.0;
    for i in 0..n_samples {
        neighbors.clear();
        neighbors.extend((0..n_samples).filter(|&j| j != i).map(|j| {
            let squared = (0..n_features).fold(0.0_f64, |acc, c| {
                let diff = to_f64(features[[i, c]]) - to_f64(features[[j, c]]);
                acc + diff * diff
            });
            (j, squared.sqrt())
        }));
        neighbors.sort_unstable_by(by_distance);

        let pred_i = pred_values[i];
        let diff_sum = neighbors
            .iter()
            .take(k)
            .fold(0.0_f64, |acc, &(j, _)| acc + (pred_i - pred_values[j]).abs());
        consistency_sum += diff_sum / k as f64;
    }
    Ok(1.0 - consistency_sum / n_samples as f64)
}