use crate::metrics::evaluation::Metric;
use crate::objective::ObjectiveFunction;
use serde::{Deserialize, Serialize};
/// Fairness criterion applied as a soft penalty on top of binary log-loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FairnessType {
/// Penalize the squared gap in mean predicted probability between the
/// sensitive group (attribute == 1) and the reference group.
DemographicParity,
/// Penalize the squared gaps in mean predicted probability between groups
/// within each true-label class (an equalized-odds style relaxation).
EqualizedOdds,
}
/// Binary-classification objective: log-loss plus `lambda` times a squared
/// group-gap fairness penalty (criterion selected by `fairness_type`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairnessObjective {
/// Per-sample group membership; value 1 marks the sensitive group, any
/// other value is treated as the reference group. Must be at least as long
/// as the label vector passed to `loss`/`gradient` (indexed per sample).
pub sensitive_attr: Vec<i32>,
/// Weight of the fairness penalty relative to the log-loss term.
pub lambda: f32,
/// Which fairness criterion the penalty measures.
pub fairness_type: FairnessType,
}
impl FairnessObjective {
    /// Demographic-parity objective with the given sensitive-attribute
    /// vector and penalty weight `lambda`.
    pub fn new(sensitive_attr: Vec<i32>, lambda: f32) -> Self {
        Self::with_type(sensitive_attr, lambda, FairnessType::DemographicParity)
    }

    /// Objective with an explicitly chosen fairness criterion.
    pub fn with_type(sensitive_attr: Vec<i32>, lambda: f32, fairness_type: FairnessType) -> Self {
        Self {
            sensitive_attr,
            lambda,
            fairness_type,
        }
    }
}
/// Running sum/count accumulator for the mean predicted probability of one
/// group (or group-by-label cell).
#[derive(Default)]
struct GroupStats {
/// Sum of predicted probabilities accumulated into this group.
sum: f64,
/// Number of samples in the group (stored as f64 for direct division).
count: f64,
}
impl GroupStats {
    /// Mean of the accumulated values; 0.0 for an empty group.
    fn mean(&self) -> f64 {
        match self.count < 1.0 {
            true => 0.0,
            false => self.sum / self.count,
        }
    }

    /// Group size floored at 1.0, so dividing by it is always safe.
    fn safe_count(&self) -> f64 {
        match self.count < 1.0 {
            true => 1.0,
            false => self.count,
        }
    }
}
impl ObjectiveFunction for FairnessObjective {
fn loss(&self, y: &[f64], yhat: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> Vec<f32> {
let n = y.len();
let probs: Vec<f64> = yhat.iter().map(|yh| 1.0 / (1.0 + (-yh).exp())).collect();
let mut loss = Vec::with_capacity(n);
for i in 0..n {
let y_i = y[i];
let p = probs[i];
let score = -(y_i * p.max(1e-15).ln() + (1.0 - y_i) * (1.0 - p).max(1e-15).ln());
loss.push(score as f32);
}
let n_f64 = n as f64;
match &self.fairness_type {
FairnessType::DemographicParity => {
let mut s1 = GroupStats::default();
let mut s0 = GroupStats::default();
for (i, &p) in probs.iter().enumerate() {
if self.sensitive_attr[i] == 1 {
s1.sum += p;
s1.count += 1.0;
} else {
s0.sum += p;
s0.count += 1.0;
}
}
let diff = s1.mean() - s0.mean();
let penalty = (self.lambda as f64) * diff * diff * n_f64;
let p_per_n = (penalty / n_f64) as f32;
for l in loss.iter_mut() {
*l += p_per_n;
}
}
FairnessType::EqualizedOdds => {
let mut s0_y0 = GroupStats::default();
let mut s0_y1 = GroupStats::default();
let mut s1_y0 = GroupStats::default();
let mut s1_y1 = GroupStats::default();
for i in 0..n {
let p = probs[i];
let label = if y[i] >= 0.5 { 1 } else { 0 };
let group = self.sensitive_attr[i];
match (group, label) {
(1, 1) => {
s1_y1.sum += p;
s1_y1.count += 1.0;
}
(1, 0) => {
s1_y0.sum += p;
s1_y0.count += 1.0;
}
(_, 1) => {
s0_y1.sum += p;
s0_y1.count += 1.0;
}
(_, 0) => {
s0_y0.sum += p;
s0_y0.count += 1.0;
}
_ => unreachable!(),
}
}
let diff_y0 = s1_y0.mean() - s0_y0.mean();
let diff_y1 = s1_y1.mean() - s0_y1.mean();
let penalty = (self.lambda as f64) * (diff_y0 * diff_y0 + diff_y1 * diff_y1) * n_f64;
let p_per_n = (penalty / n_f64) as f32;
for l in loss.iter_mut() {
*l += p_per_n;
}
}
}
loss
}
fn gradient(
&self,
y: &[f64],
yhat: &[f64],
_sample_weight: Option<&[f64]>,
_group: Option<&[u64]>,
) -> (Vec<f32>, Option<Vec<f32>>) {
let n = y.len();
let mut grad = Vec::with_capacity(n);
let mut hess = Vec::with_capacity(n);
let probs: Vec<f64> = yhat.iter().map(|yh| 1.0 / (1.0 + (-yh).exp())).collect();
let n_f64 = n as f64;
match &self.fairness_type {
FairnessType::DemographicParity => {
let mut s1 = GroupStats::default();
let mut s0 = GroupStats::default();
for (i, &p) in probs.iter().enumerate() {
if self.sensitive_attr[i] == 1 {
s1.sum += p;
s1.count += 1.0;
} else {
s0.sum += p;
s0.count += 1.0;
}
}
let diff = s1.mean() - s0.mean();
for i in 0..n {
let p = probs[i];
let dp = p * (1.0 - p); let mut g = p - y[i];
let fairness_grad = if self.sensitive_attr[i] == 1 {
2.0 * (self.lambda as f64) * diff * (1.0 / s1.safe_count()) * dp * n_f64
} else {
2.0 * (self.lambda as f64) * diff * (-1.0 / s0.safe_count()) * dp * n_f64
};
g += fairness_grad;
grad.push(g as f32);
hess.push(dp as f32);
}
}
FairnessType::EqualizedOdds => {
let mut s0_y0 = GroupStats::default();
let mut s0_y1 = GroupStats::default();
let mut s1_y0 = GroupStats::default();
let mut s1_y1 = GroupStats::default();
for i in 0..n {
let p = probs[i];
let label = if y[i] >= 0.5 { 1 } else { 0 };
let group = self.sensitive_attr[i];
match (group, label) {
(1, 1) => {
s1_y1.sum += p;
s1_y1.count += 1.0;
}
(1, 0) => {
s1_y0.sum += p;
s1_y0.count += 1.0;
}
(_, 1) => {
s0_y1.sum += p;
s0_y1.count += 1.0;
}
(_, 0) => {
s0_y0.sum += p;
s0_y0.count += 1.0;
}
_ => unreachable!(),
}
}
let diff_y0 = s1_y0.mean() - s0_y0.mean();
let diff_y1 = s1_y1.mean() - s0_y1.mean();
for i in 0..n {
let p = probs[i];
let dp = p * (1.0 - p);
let mut g = p - y[i];
let label = if y[i] >= 0.5 { 1 } else { 0 };
let (diff, cnt_s1, cnt_s0) = if label == 1 {
(diff_y1, s1_y1.safe_count(), s0_y1.safe_count())
} else {
(diff_y0, s1_y0.safe_count(), s0_y0.safe_count())
};
let fairness_grad = if self.sensitive_attr[i] == 1 {
2.0 * (self.lambda as f64) * diff * (1.0 / cnt_s1) * dp * n_f64
} else {
2.0 * (self.lambda as f64) * diff * (-1.0 / cnt_s0) * dp * n_f64
};
g += fairness_grad;
grad.push(g as f32);
hess.push(dp as f32);
}
}
}
(grad, Some(hess))
}
fn initial_value(&self, y: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> f64 {
let mean = y.iter().sum::<f64>() / y.len() as f64;
let p = mean.clamp(1e-15, 1.0 - 1e-15);
(p / (1.0 - p)).ln()
}
fn default_metric(&self) -> Metric {
Metric::LogLoss
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Demographic parity: output shapes are correct and a balanced label
    /// vector yields a neutral (zero log-odds) initial value.
    #[test]
    fn test_fairness_dp() {
        let attrs = vec![1, 0, 1, 0];
        let labels = vec![1.0, 1.0, 0.0, 0.0];
        let scores = vec![0.0; 4];
        let objective = FairnessObjective::new(attrs, 1.0);
        // Base rate 0.5 -> log-odds 0.
        let init = objective.initial_value(&labels, None::<&[f64]>, None::<&[u64]>);
        assert!(init.abs() < 1e-6);
        let losses = objective.loss(&labels, &scores, None::<&[f64]>, None::<&[u64]>);
        assert_eq!(losses.len(), 4);
        let (grads, hessians) = objective.gradient(&labels, &scores, None::<&[f64]>, None::<&[u64]>);
        assert_eq!(grads.len(), 4);
        assert!(hessians.is_some());
    }

    /// Equalized odds: loss and gradient have one entry per sample.
    #[test]
    fn test_fairness_eo() {
        let attrs = vec![1, 0, 1, 0];
        let labels = vec![1.0, 1.0, 0.0, 0.0];
        let scores = vec![0.0; 4];
        let objective = FairnessObjective::with_type(attrs, 1.0, FairnessType::EqualizedOdds);
        let losses = objective.loss(&labels, &scores, None::<&[f64]>, None::<&[u64]>);
        assert_eq!(losses.len(), 4);
        let (grads, hessians) = objective.gradient(&labels, &scores, None::<&[f64]>, None::<&[u64]>);
        assert_eq!(grads.len(), 4);
        assert!(hessians.is_some());
    }
}