use std::collections::HashMap;

use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1, Ix2};
use scirs2_core::numeric::Float;
use scirs2_core::random::{random, rngs::StdRng, seq::SliceRandom, RngCore, SeedableRng};

use crate::error::{MetricsError, Result};
/// Outcome of [`performance_invariance`].
#[derive(Clone, Debug)]
pub struct InvarianceResult {
    /// Metric value per evaluated slice of the data.
    ///
    /// `"overall"` holds the metric on the full data set. Every other key has
    /// the form `"<group_name>=<value>"`, one per distinct protected-attribute
    /// value.
    pub group_metrics: HashMap<String, f64>,
    /// Population standard deviation over every entry of `group_metrics`
    /// (including `"overall"`); `0.0` means all slices scored identically.
    pub invariance_score: f64,
}
/// Measures how consistently a metric behaves across protected-attribute groups.
///
/// `metric_fn(y_true, y_pred)` is evaluated once on the full data set and
/// stored under the key `"overall"`. It is then evaluated once for every
/// distinct value of every column of `protected_groups`, stored under
/// `"<group_name>=<value>"`. Column values are bucketed after rounding to three
/// decimal places, so e.g. `0.9996` and `1.0` land in the same group.
///
/// `invariance_score` is the population standard deviation of *all* recorded
/// metric values, including the `"overall"` entry. A score of `0.0` means every
/// slice scored exactly the same.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when:
/// * `y_true`, `y_pred` and the rows of `protected_groups` differ in length;
/// * the inputs are empty;
/// * the number of columns of `protected_groups` differs from `group_names.len()`;
/// * a protected-attribute value is NaN/infinite or cannot be represented as `f64`.
#[allow(dead_code)]
pub fn performance_invariance<T, S1, S2, S3, F>(
    y_true: &ArrayBase<S1, Ix1>,
    y_pred: &ArrayBase<S2, Ix1>,
    protected_groups: &ArrayBase<S3, Ix2>,
    group_names: &[String],
    metric_fn: F,
) -> Result<InvarianceResult>
where
    T: Float + PartialOrd + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    S3: Data<Elem = T>,
    F: Fn(&[T], &[T]) -> f64,
{
    let n_samples = y_true.len();
    if n_samples != y_pred.len() || n_samples != protected_groups.nrows() {
        return Err(MetricsError::InvalidInput(format!(
            "Dimensions mismatch: y_true ({}), y_pred ({}), protected_groups ({} rows)",
            n_samples,
            y_pred.len(),
            protected_groups.nrows()
        )));
    }
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    let n_groups = protected_groups.ncols();
    if n_groups != group_names.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Number of group columns ({}) doesn't match number of group names ({})",
            n_groups,
            group_names.len()
        )));
    }

    let y_true_vec: Vec<T> = y_true.iter().cloned().collect();
    let y_pred_vec: Vec<T> = y_pred.iter().cloned().collect();
    let mut group_metrics = HashMap::new();
    group_metrics.insert("overall".to_string(), metric_fn(&y_true_vec, &y_pred_vec));

    for (group_idx, group_name) in group_names.iter().enumerate() {
        // Bucket samples by attribute value (rounded to 3 decimals) in one pass.
        // BTreeMap keeps buckets in ascending value order; within a bucket the
        // samples keep their original relative order.
        let mut buckets: std::collections::BTreeMap<i64, (Vec<T>, Vec<T>)> =
            std::collections::BTreeMap::new();
        for (row, (t, p)) in y_true_vec.iter().zip(&y_pred_vec).enumerate() {
            // NaN would otherwise cast to 0 and silently join group "=0".
            let value = protected_groups[[row, group_idx]]
                .to_f64()
                .filter(|v| v.is_finite())
                .ok_or_else(|| {
                    MetricsError::InvalidInput(format!(
                        "Protected group value in column '{}' at row {} is not a finite number",
                        group_name, row
                    ))
                })?;
            let (bucket_true, bucket_pred) = buckets
                .entry((value * 1000.0).round() as i64)
                .or_default();
            bucket_true.push(*t);
            bucket_pred.push(*p);
        }
        // Every bucket holds at least one sample, so no empty-slice check is needed.
        for (&key, (group_y_true, group_y_pred)) in &buckets {
            let value = key as f64 / 1000.0;
            let group_performance = metric_fn(group_y_true, group_y_pred);
            group_metrics.insert(format!("{}={}", group_name, value), group_performance);
        }
    }

    // HashMap iteration order is randomized, so sort before summing to make the
    // floating-point result reproducible bit-for-bit across runs.
    let mut performances: Vec<f64> = group_metrics.values().copied().collect();
    performances.sort_by(f64::total_cmp);
    let count = performances.len() as f64;
    let mean_performance = performances.iter().sum::<f64>() / count;
    let variance = performances
        .iter()
        .map(|&p| (p - mean_performance).powi(2))
        .sum::<f64>()
        / count;
    Ok(InvarianceResult {
        group_metrics,
        invariance_score: variance.sqrt(),
    })
}
/// Leave-one-out influence of individual samples on a fairness metric.
///
/// The metric is first evaluated on the complete `(y_pred, protected_group)`
/// pair. Then, for each index `i` in `0..n_samples` (all samples when
/// `n_samples` is `None`), sample `i` is dropped from both arrays and the
/// metric is re-evaluated. The influence of sample `i` is
/// `baseline - metric_without_i`, so a positive score means removing the
/// sample lowers the metric.
///
/// The returned vector always has one entry per input sample. Entries at
/// indices `>= n_samples` stay `0.0`. `y_true` takes part only in the length
/// validation.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the three arrays differ in
/// length, when they are empty, or when `n_samples` exceeds the number of
/// available samples.
#[allow(dead_code)]
pub fn influence_function<T, F>(
    y_true: &Array1<T>,
    y_pred: &Array1<T>,
    protected_group: &Array1<T>,
    fairness_metric: F,
    n_samples: Option<usize>,
) -> Result<Vec<f64>>
where
    T: Float + PartialOrd + Clone,
    F: Fn(&Array1<T>, &Array1<T>) -> f64,
{
    let n_total_samples = y_true.len();
    if n_total_samples != y_pred.len() || n_total_samples != protected_group.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Dimensions mismatch: y_true ({}), y_pred ({}), protected_group ({})",
            n_total_samples,
            y_pred.len(),
            protected_group.len()
        )));
    }
    if n_total_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    let n_samples_to_use = n_samples.unwrap_or(n_total_samples);
    if n_samples_to_use > n_total_samples {
        return Err(MetricsError::InvalidInput(format!(
            "Requested number of samples ({}) exceeds available samples ({})",
            n_samples_to_use, n_total_samples
        )));
    }

    // Copy of `values` with the element at `skip` removed, in logical order.
    let without = |values: &Array1<T>, skip: usize| -> Array1<T> {
        values
            .iter()
            .enumerate()
            .filter(|&(j, _)| j != skip)
            .map(|(_, v)| v.clone())
            .collect()
    };

    let baseline_fairness = fairness_metric(y_pred, protected_group);
    let mut influence_scores = vec![0.0; n_total_samples];
    for (left_out, score) in influence_scores
        .iter_mut()
        .enumerate()
        .take(n_samples_to_use)
    {
        let y_pred_temp = without(y_pred, left_out);
        let protected_group_temp = without(protected_group, left_out);
        *score = baseline_fairness - fairness_metric(&y_pred_temp, &protected_group_temp);
    }
    Ok(influence_scores)
}
/// Kind of perturbation applied to the predictions by [`perturbation_sensitivity`].
///
/// The variant's `Debug` name is what ends up in
/// [`SensitivityResult::perturbation_type`]. The strength of each perturbation
/// is the `perturbation_level` passed to [`perturbation_sensitivity`].
#[derive(Clone, Copy, Debug)]
pub enum PerturbationType {
    /// Flip `round(n * level)` randomly chosen predictions: values `> 0`
    /// become `0`, all others become `1`.
    LabelFlip,
    /// Keep `round(n * level)` randomly chosen predictions. Replace the rest
    /// with the mean prediction of their group, where a sample is in the
    /// protected group when its group value is `> 0`.
    Subsample,
    /// Add zero-mean Gaussian noise with standard deviation `level` to every
    /// prediction. Values that were exactly `0` or `1` are clamped to `[0, 1]`.
    Noise,
}
/// Summary statistics produced by [`perturbation_sensitivity`].
#[derive(Clone, Debug)]
pub struct SensitivityResult {
    /// Fairness metric evaluated on the unperturbed predictions.
    pub original_fairness: f64,
    /// Mean of `perturbed_values`.
    pub mean_fairness: f64,
    /// Population standard deviation of `perturbed_values`.
    pub std_deviation: f64,
    /// `std_deviation / perturbation_level`.
    pub sensitivity_score: f64,
    /// Fairness metric for each perturbation iteration, in iteration order.
    pub perturbed_values: Vec<f64>,
    /// `Debug` name of the [`PerturbationType`] that was applied.
    pub perturbation_type: String,
    /// Perturbation strength used for every iteration.
    pub perturbation_level: f64,
}
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn perturbation_sensitivity<T, F>(
y_true: &Array1<T>,
y_pred: &Array1<T>,
protected_group: &Array1<T>,
perturbation_type: PerturbationType,
fairness_metric: F,
perturbation_level: f64,
n_iterations: usize,
random_seed: Option<u64>,
) -> Result<SensitivityResult>
where
T: Float + PartialOrd + Clone,
F: Fn(&Array1<T>, &Array1<T>) -> f64,
{
let n_samples = y_true.len();
if n_samples != y_pred.len() || n_samples != protected_group.len() {
return Err(MetricsError::InvalidInput(format!(
"Dimensions mismatch: y_true ({}), y_pred ({}), protected_group ({})",
n_samples,
y_pred.len(),
protected_group.len()
)));
}
if n_samples == 0 {
return Err(MetricsError::InvalidInput(
"Empty arrays provided".to_string(),
));
}
if perturbation_level <= 0.0 || perturbation_level >= 1.0 {
return Err(MetricsError::InvalidInput(
"Perturbation _level must be between 0 and 1 exclusive".to_string(),
));
}
if n_iterations == 0 {
return Err(MetricsError::InvalidInput(
"Number of _iterations must be positive".to_string(),
));
}
let mut rng = match random_seed {
Some(_seed) => StdRng::seed_from_u64(_seed),
None => StdRng::seed_from_u64(random()),
};
let original_fairness = fairness_metric(y_pred, protected_group);
let mut perturbed_values = Vec::with_capacity(n_iterations);
for _ in 0..n_iterations {
let perturbed_y_pred = match perturbation_type {
PerturbationType::LabelFlip => {
perturb_by_label_flip(y_pred, perturbation_level, &mut rng)?
}
PerturbationType::Subsample => {
perturb_by_subsample(y_pred, protected_group, perturbation_level, &mut rng)?
}
PerturbationType::Noise => perturb_by_noise(y_pred, perturbation_level, &mut rng)?,
};
let perturbed_fairness = fairness_metric(&perturbed_y_pred, protected_group);
perturbed_values.push(perturbed_fairness);
}
let mean_fairness = perturbed_values.iter().sum::<f64>() / n_iterations as f64;
let variance = perturbed_values
.iter()
.map(|&v| (v - mean_fairness).powi(2))
.sum::<f64>()
/ n_iterations as f64;
let std_deviation = variance.sqrt();
let sensitivity_score = std_deviation / perturbation_level;
Ok(SensitivityResult {
original_fairness,
mean_fairness,
std_deviation,
sensitivity_score,
perturbed_values,
perturbation_type: format!("{:?}", perturbation_type),
perturbation_level,
})
}
/// Flips `round(len * flip_prob)` randomly chosen predictions.
///
/// Predictions `> 0` become `0`; all other predictions (including NaN) become
/// `1`. The flip count is capped at the number of samples, so a `flip_prob`
/// above `1` flips everything instead of panicking.
#[allow(dead_code)]
fn perturb_by_label_flip<T>(
    y_pred: &Array1<T>,
    flip_prob: f64,
    rng: &mut StdRng,
) -> Result<Array1<T>>
where
    T: Float + PartialOrd + Clone,
{
    let n_samples = y_pred.len();
    let n_flips = ((n_samples as f64 * flip_prob).round() as usize).min(n_samples);
    // Copy in logical (iteration) order. `into_raw_vec_and_offset` would return
    // memory order, which differs for e.g. negatively strided arrays.
    let mut perturbed: Vec<T> = y_pred.iter().cloned().collect();
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.shuffle(rng);
    let (zero, one) = (T::zero(), T::one());
    for &idx in &indices[..n_flips] {
        perturbed[idx] = if perturbed[idx] > zero { zero } else { one };
    }
    Ok(Array1::from_vec(perturbed))
}
#[allow(dead_code)]
fn perturb_by_subsample<T>(
y_pred: &Array1<T>,
protected_group: &Array1<T>,
sample_fraction: f64,
rng: &mut StdRng,
) -> Result<Array1<T>>
where
T: Float + PartialOrd + Clone,
{
let n_samples = y_pred.len();
let n_subsample = (n_samples as f64 * sample_fraction).round() as usize;
let y_pred_vec = y_pred.to_owned().into_raw_vec_and_offset().0;
let protected_vec = protected_group.to_owned().into_raw_vec_and_offset().0;
let mut indices: Vec<usize> = (0..n_samples).collect();
indices.shuffle(rng);
let subsample_indices = &indices[0..n_subsample];
let mut perturbed = Vec::with_capacity(n_samples);
for i in 0..n_samples {
if subsample_indices.contains(&i) {
perturbed.push(y_pred_vec[i].clone());
} else {
let group_val = &protected_vec[i];
let zero = T::zero();
let is_protected = group_val > &zero;
let mut group_sum = T::zero();
let mut group_count = 0;
for j in 0..n_samples {
let j_group_val = &protected_vec[j];
let j_is_protected = j_group_val > &zero;
if j_is_protected == is_protected {
group_sum = group_sum + y_pred_vec[j].clone();
group_count += 1;
}
}
if group_count > 0 {
let group_avg = group_sum / T::from(group_count).expect("Operation failed");
perturbed.push(group_avg);
} else {
perturbed.push(y_pred_vec[i].clone());
}
}
}
Ok(Array1::from_vec(perturbed))
}
#[allow(dead_code)]
fn perturb_by_noise<T>(
y_pred: &Array1<T>,
noise_level: f64,
_rng: &mut StdRng, ) -> Result<Array1<T>>
where
T: Float + PartialOrd + Clone,
{
let n_samples = y_pred.len();
let y_pred_vec = y_pred.to_owned().into_raw_vec_and_offset().0;
let mut perturbed = Vec::with_capacity(n_samples);
for i in 0..n_samples {
let y_val = y_pred_vec[i].to_f64().expect("Operation failed");
let u1: f64 = random();
let u2: f64 = random();
let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
let noise = z0 * noise_level;
let mut perturbed_val = y_val + noise;
if (y_val == 0.0 || y_val == 1.0) && !(0.0..=1.0).contains(&perturbed_val) {
perturbed_val = perturbed_val.clamp(0.0, 1.0);
}
perturbed.push(T::from(perturbed_val).expect("Operation failed"));
}
Ok(Array1::from_vec(perturbed))
}