use std::collections::{BTreeMap, BTreeSet, HashMap};

use scirs2_core::ndarray::{ArrayBase, Data, Ix1, Ix2};
use scirs2_core::numeric::Float;

use crate::error::{MetricsError, Result};
/// A named slice of a dataset: an identifier, a per-sample membership mask and an
/// optional description.
// NOTE(review): nothing in this section constructs or consumes `DataSlice`; the
// meaning of `mask` below is inferred from its type — confirm with callers.
#[derive(Debug, Clone)]
pub struct DataSlice {
/// Identifier of the slice.
pub name: String,
/// Membership flags; presumably `mask[i]` is `true` when sample `i` belongs to the
/// slice (one entry per sample) — TODO confirm.
pub mask: Vec<bool>,
/// Optional free-text description of the slice.
pub description: Option<String>,
}
/// Evaluates `metric_fn` on the whole dataset and on every slice defined by a single
/// value of one categorical feature.
///
/// Categorical values are bucketed by rounding to three decimal places, so e.g. `1.0`
/// and `1.0004` land in the same slice. The metric over all samples is stored under
/// `"overall"`; each slice is stored under `"feature_{index}_{value}"`, where `value`
/// is the rounded bucket value.
///
/// # Arguments
///
/// * `features` - Feature matrix with one row per sample.
/// * `categorical_features` - Column indices of `features` to slice on.
/// * `y_true` - Ground-truth values, one per row of `features`.
/// * `y_pred` - Predicted values, one per row of `features`.
/// * `metric_fn` - Metric evaluated as `metric_fn(y_true_subset, y_pred_subset)`.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the row counts disagree, the inputs are
/// empty, a column index is out of bounds, or a categorical value is not finite or
/// cannot be converted to `f64`.
#[allow(dead_code)]
pub fn slice_analysis<T, S1, S2, S3, F>(
    features: &ArrayBase<S1, Ix2>,
    categorical_features: &[usize],
    y_true: &ArrayBase<S2, Ix1>,
    y_pred: &ArrayBase<S3, Ix1>,
    metric_fn: F,
) -> Result<HashMap<String, f64>>
where
    T: Float + PartialOrd + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    S3: Data<Elem = T>,
    F: Fn(&[T], &[T]) -> f64,
{
    // Maps a categorical value to its bucket: the value scaled by 1000 and rounded.
    fn bucket<V: Float>(value: V, feature_idx: usize) -> Result<i64> {
        let value = value.to_f64().ok_or_else(|| {
            MetricsError::InvalidInput(format!(
                "Categorical feature {} contains a value that cannot be converted to f64",
                feature_idx
            ))
        })?;
        // The saturating `as` cast maps NaN to 0, which would silently merge NaN rows
        // into the 0.0 slice, so non-finite values are rejected explicitly.
        if !value.is_finite() {
            return Err(MetricsError::InvalidInput(format!(
                "Categorical feature {} contains a non-finite value",
                feature_idx
            )));
        }
        Ok((value * 1000.0).round() as i64)
    }

    let n_samples = features.nrows();
    if n_samples != y_true.len() || n_samples != y_pred.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Dimensions mismatch: features ({} rows), y_true ({}), y_pred ({})",
            n_samples,
            y_true.len(),
            y_pred.len()
        )));
    }
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    let n_features = features.ncols();
    if let Some(&bad_idx) = categorical_features.iter().find(|&&idx| idx >= n_features) {
        // `n_features - 1` would underflow for a matrix with no columns.
        let message = match n_features.checked_sub(1) {
            Some(max_idx) => format!(
                "Feature index {} out of bounds (max: {})",
                bad_idx, max_idx
            ),
            None => format!(
                "Feature index {} out of bounds (features has no columns)",
                bad_idx
            ),
        };
        return Err(MetricsError::InvalidInput(message));
    }

    let y_true_vec: Vec<T> = y_true.iter().copied().collect();
    let y_pred_vec: Vec<T> = y_pred.iter().copied().collect();
    let mut results = HashMap::new();
    results.insert("overall".to_string(), metric_fn(&y_true_vec, &y_pred_vec));

    for &feature_idx in categorical_features {
        // One pass over the column: bucket -> (y_true, y_pred) of the rows in that slice.
        // BTreeMap keeps the slices in ascending value order.
        let mut slices: BTreeMap<i64, (Vec<T>, Vec<T>)> = BTreeMap::new();
        let column = features.column(feature_idx);
        for ((value, &t), &p) in column.iter().zip(&y_true_vec).zip(&y_pred_vec) {
            let (slice_true, slice_pred) =
                slices.entry(bucket(*value, feature_idx)?).or_default();
            slice_true.push(t);
            slice_pred.push(p);
        }
        for (key, (slice_true, slice_pred)) in slices {
            let slice_name = format!("feature_{}_{}", feature_idx, key as f64 / 1000.0);
            results.insert(slice_name, metric_fn(&slice_true, &slice_pred));
        }
    }
    Ok(results)
}
#[allow(dead_code)]
pub fn subgroup_performance<T, S1, S2, S3, F>(
y_true: &ArrayBase<S1, Ix1>,
y_pred: &ArrayBase<S2, Ix1>,
groups: &ArrayBase<S3, Ix2>,
group_names: &[String],
metric_fn: F,
) -> Result<HashMap<String, f64>>
where
T: Float + PartialOrd + Clone,
S1: Data<Elem = T>,
S2: Data<Elem = T>,
S3: Data<Elem = T>,
F: Fn(&[T], &[T]) -> f64,
{
let n_samples = y_true.len();
if n_samples != y_pred.len() || n_samples != groups.nrows() {
return Err(MetricsError::InvalidInput(format!(
"Dimensions mismatch: y_true ({}), y_pred ({}), groups ({} rows)",
n_samples,
y_pred.len(),
groups.nrows()
)));
}
if n_samples == 0 {
return Err(MetricsError::InvalidInput(
"Empty arrays provided".to_string(),
));
}
let n_groups = groups.ncols();
if n_groups != group_names.len() {
return Err(MetricsError::InvalidInput(format!(
"Number of group columns ({}) doesn't match number of group _names ({})",
n_groups,
group_names.len()
)));
}
let mut results = HashMap::new();
let y_true_vec: Vec<T> = y_true.iter().cloned().collect();
let y_pred_vec: Vec<T> = y_pred.iter().cloned().collect();
results.insert("overall".to_string(), metric_fn(&y_true_vec, &y_pred_vec));
let mut unique_values = vec![BTreeSet::new(); n_groups];
for i in 0..n_samples {
for j in 0..n_groups {
let value = groups[[i, j]].to_f64().expect("Operation failed");
let rounded_value = (value * 1000.0).round() as i64;
unique_values[j].insert(rounded_value);
}
}
for (group_idx, group_name) in group_names.iter().enumerate() {
for &int_value in &unique_values[group_idx] {
let value = int_value as f64 / 1000.0;
let subgroup_name = format!("{}={}", group_name, value);
let mut subgroup_y_true = Vec::new();
let mut subgroup_y_pred = Vec::new();
for i in 0..n_samples {
let group_value = groups[[i, group_idx]].to_f64().expect("Operation failed");
let rounded_value = (group_value * 1000.0).round() as i64;
if rounded_value == int_value {
subgroup_y_true.push(y_true[i].clone());
subgroup_y_pred.push(y_pred[i].clone());
}
}
if !subgroup_y_true.is_empty() {
let subgroup_metric = metric_fn(&subgroup_y_true, &subgroup_y_pred);
results.insert(subgroup_name, subgroup_metric);
}
}
}
if n_groups > 1 {
generate_intersectional_subgroups(
y_true,
y_pred,
groups,
group_names,
&unique_values,
&mut results,
metric_fn,
)?;
}
Ok(results)
}
#[allow(dead_code)]
fn generate_intersectional_subgroups<T, S1, S2, S3, F>(
y_true: &ArrayBase<S1, Ix1>,
y_pred: &ArrayBase<S2, Ix1>,
groups: &ArrayBase<S3, Ix2>,
group_names: &[String],
unique_values: &[BTreeSet<i64>],
results: &mut HashMap<String, f64>,
metric_fn: F,
) -> Result<()>
where
T: Float + PartialOrd + Clone,
S1: Data<Elem = T>,
S2: Data<Elem = T>,
S3: Data<Elem = T>,
F: Fn(&[T], &[T]) -> f64,
{
let n_samples = y_true.len();
let n_groups = groups.ncols();
for i in 0..n_groups {
for j in (i + 1)..n_groups {
for &int_value_i in &unique_values[i] {
for &int_value_j in &unique_values[j] {
let value_i = int_value_i as f64 / 1000.0;
let value_j = int_value_j as f64 / 1000.0;
let subgroup_name = format!(
"{}={} & {}={}",
group_names[i], value_i, group_names[j], value_j
);
let mut subgroup_y_true = Vec::new();
let mut subgroup_y_pred = Vec::new();
for k in 0..n_samples {
let group_value_i = groups[[k, i]].to_f64().expect("Operation failed");
let rounded_value_i = (group_value_i * 1000.0).round() as i64;
let group_value_j = groups[[k, j]].to_f64().expect("Operation failed");
let rounded_value_j = (group_value_j * 1000.0).round() as i64;
if rounded_value_i == int_value_i && rounded_value_j == int_value_j {
subgroup_y_true.push(y_true[k].clone());
subgroup_y_pred.push(y_pred[k].clone());
}
}
if !subgroup_y_true.is_empty() {
let subgroup_metric = metric_fn(&subgroup_y_true, &subgroup_y_pred);
results.insert(subgroup_name, subgroup_metric);
}
}
}
}
}
if n_groups > 2 {
}
Ok(())
}
/// Fairness metrics for one protected group, as produced by `calculate_fairness_metrics`.
// NOTE(review): the exact definition of each difference lives in `crate::fairness`,
// which is not visible here — confirm semantics there.
#[derive(Debug, Clone)]
pub struct FairnessMetrics {
/// Value returned by `demographic_parity_difference(y_pred, protected_group)`.
pub demographic_parity: f64,
/// Value returned by `equalized_odds_difference`, or `1.0` if that call returned an error.
pub equalized_odds: f64,
/// Value returned by `equal_opportunity_difference`, or `1.0` if that call returned an error.
pub equal_opportunity: f64,
}
/// Computes fairness metrics for every protected group and every pairwise intersection
/// of protected groups.
///
/// Each column of `protected_features` is one protected attribute, named by the matching
/// entry of `feature_names`. Values are bucketed by rounding to three decimal places.
/// Metrics are computed in two passes:
///
/// * for every observed value of every attribute, keyed `"{name}={value}"`;
/// * for every observed pair of values of two different attributes, keyed
///   `"{name_i}={value_i} & {name_j}={value_j}"`.
///
/// In each case a 0/1 membership array is passed to `calculate_fairness_metrics`. Groups
/// that contain no samples, or contain every sample, are skipped.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the inputs have mismatched lengths, are
/// empty, when `feature_names` does not have one entry per protected column, or when a
/// protected value is not finite or cannot be converted to `f64`. Errors from
/// `calculate_fairness_metrics` are propagated.
#[allow(dead_code)]
pub fn intersectional_fairness<T, S1, S2, S3>(
    y_true: &ArrayBase<S1, Ix1>,
    y_pred: &ArrayBase<S2, Ix1>,
    protected_features: &ArrayBase<S3, Ix2>,
    feature_names: &[String],
) -> Result<HashMap<String, FairnessMetrics>>
where
    T: Float + PartialOrd + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    S3: Data<Elem = T>,
{
    // Maps a protected value to its bucket: the value scaled by 1000 and rounded.
    fn bucket<V: Float>(value: V, feat_idx: usize) -> Result<i64> {
        let value = value.to_f64().ok_or_else(|| {
            MetricsError::InvalidInput(format!(
                "Protected feature {} contains a value that cannot be converted to f64",
                feat_idx
            ))
        })?;
        // The saturating `as` cast maps NaN to 0, which would silently merge NaN rows
        // into the 0.0 group, so non-finite values are rejected explicitly.
        if !value.is_finite() {
            return Err(MetricsError::InvalidInput(format!(
                "Protected feature {} contains a non-finite value",
                feat_idx
            )));
        }
        Ok((value * 1000.0).round() as i64)
    }

    // Metrics for the samples flagged in `in_group`, or `None` when the flags select no
    // sample or every sample (there is no complementary group to compare against).
    fn group_metrics<V, A, B>(
        y_true: &ArrayBase<A, Ix1>,
        y_pred: &ArrayBase<B, Ix1>,
        in_group: &[bool],
    ) -> Result<Option<FairnessMetrics>>
    where
        V: Float + PartialOrd + Clone,
        A: Data<Elem = V>,
        B: Data<Elem = V>,
    {
        let n_in_group = in_group.iter().filter(|&&flag| flag).count();
        if n_in_group == 0 || n_in_group == in_group.len() {
            return Ok(None);
        }
        let indicator: Vec<V> = in_group
            .iter()
            .map(|&flag| if flag { V::one() } else { V::zero() })
            .collect();
        let protected_group = scirs2_core::ndarray::Array::from(indicator);
        calculate_fairness_metrics(y_true, y_pred, &protected_group).map(Some)
    }

    let n_samples = y_true.len();
    if n_samples != y_pred.len() || n_samples != protected_features.nrows() {
        return Err(MetricsError::InvalidInput(format!(
            "Dimensions mismatch: y_true ({}), y_pred ({}), protected_features ({} rows)",
            n_samples,
            y_pred.len(),
            protected_features.nrows()
        )));
    }
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    let n_protected = protected_features.ncols();
    if n_protected != feature_names.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Number of protected feature columns ({}) doesn't match number of feature names ({})",
            n_protected,
            feature_names.len()
        )));
    }

    // keys[j][i] is the bucket of sample i for protected feature j, computed once up front.
    let keys: Vec<Vec<i64>> = (0..n_protected)
        .map(|col| {
            protected_features
                .column(col)
                .iter()
                .map(|v| bucket(*v, col))
                .collect::<Result<Vec<i64>>>()
        })
        .collect::<Result<Vec<_>>>()?;

    let mut results = HashMap::new();

    // Groups defined by a single protected feature, in ascending value order.
    for (feat_idx, feat_name) in feature_names.iter().enumerate() {
        let observed: BTreeSet<i64> = keys[feat_idx].iter().copied().collect();
        for key in observed {
            let in_group: Vec<bool> = keys[feat_idx].iter().map(|&k| k == key).collect();
            if let Some(metrics) = group_metrics(y_true, y_pred, &in_group)? {
                results.insert(format!("{}={}", feat_name, key as f64 / 1000.0), metrics);
            }
        }
    }

    // Intersections of two different protected features. Only value pairs that actually
    // co-occur are visited; any other pair selects no sample and would be skipped anyway.
    for i in 0..n_protected {
        for j in (i + 1)..n_protected {
            let observed: BTreeSet<(i64, i64)> = keys[i]
                .iter()
                .copied()
                .zip(keys[j].iter().copied())
                .collect();
            for (key_i, key_j) in observed {
                let in_group: Vec<bool> = keys[i]
                    .iter()
                    .zip(&keys[j])
                    .map(|(&a, &b)| a == key_i && b == key_j)
                    .collect();
                if let Some(metrics) = group_metrics(y_true, y_pred, &in_group)? {
                    let group_name = format!(
                        "{}={} & {}={}",
                        feature_names[i],
                        key_i as f64 / 1000.0,
                        feature_names[j],
                        key_j as f64 / 1000.0
                    );
                    results.insert(group_name, metrics);
                }
            }
        }
    }
    Ok(results)
}
/// Computes demographic parity, equalized odds and equal opportunity differences for one
/// protected-group indicator array.
///
/// The callers in this module pass a 0/1 indicator of the same length as `y_true`.
/// If the equalized-odds or equal-opportunity computation fails, that metric falls back
/// to `1.0` rather than failing the whole call.
///
/// # Errors
///
/// Propagates any error returned by `demographic_parity_difference`.
#[allow(dead_code)]
fn calculate_fairness_metrics<T, S1, S2, S3>(
    y_true: &ArrayBase<S1, Ix1>,
    y_pred: &ArrayBase<S2, Ix1>,
    protected_group: &ArrayBase<S3, Ix1>,
) -> Result<FairnessMetrics>
where
    T: Float + PartialOrd + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    S3: Data<Elem = T>,
{
    use crate::fairness::{
        demographic_parity_difference, equal_opportunity_difference, equalized_odds_difference,
    };

    // NOTE(review): reported when an odds-based metric cannot be computed; presumably
    // intended as the worst-case difference — confirm.
    const FALLBACK_DIFFERENCE: f64 = 1.0;

    // Struct field initializers are evaluated top to bottom, so demographic parity is
    // still computed (and may return early via `?`) before the two odds-based metrics.
    Ok(FairnessMetrics {
        demographic_parity: demographic_parity_difference(y_pred, protected_group)?,
        equalized_odds: equalized_odds_difference(y_true, y_pred, protected_group)
            .unwrap_or(FALLBACK_DIFFERENCE),
        equal_opportunity: equal_opportunity_difference(y_true, y_pred, protected_group)
            .unwrap_or(FALLBACK_DIFFERENCE),
    })
}