#![allow(dead_code)]
use std::collections::HashMap;
use crate::error::{RecommendError, RecommendResult};
/// Identifier of a content creator.
///
/// Used in this file as the key of the per-creator exposure and
/// target-share maps.
pub type CreatorId = u64;
/// A ranked content item together with the creator it is attributed to.
#[derive(Debug, Clone)]
pub struct FairnessCandidate {
    /// Identifier of the content item.
    pub content_id: u64,
    /// Creator the item is attributed to.
    pub creator_id: CreatorId,
    /// Relevance score; [`FairnessCandidate::new`] keeps it within `[0.0, 1.0]`.
    pub relevance: f32,
}

impl FairnessCandidate {
    /// Builds a candidate, forcing `relevance` into `[0.0, 1.0]`.
    ///
    /// Values below `0.0` become `0.0`, values above `1.0` become `1.0`, and
    /// NaN becomes `0.0`.
    #[must_use]
    pub fn new(content_id: u64, creator_id: CreatorId, relevance: f32) -> Self {
        // `f32::clamp` hands NaN straight back, which would let a NaN score
        // leak into every average and comparison built on top of it.
        let relevance = if relevance.is_nan() {
            0.0
        } else {
            relevance.clamp(0.0, 1.0)
        };
        Self {
            content_id,
            creator_id,
            relevance,
        }
    }
}
/// Gini coefficient of rank-discounted exposure across creators.
pub struct ExposureGini;

impl ExposureGini {
    /// Measures how unevenly exposure is spread over the creators in `ranked`.
    ///
    /// The item at 1-based rank `r` adds `1 / log2(r + 1)` to its creator's
    /// exposure. Only creators that appear in `ranked` are considered.
    ///
    /// Returns `0.0` when fewer than two distinct creators are present (this
    /// includes an empty ranking) or when the total exposure is not positive.
    #[must_use]
    pub fn compute(ranked: &[FairnessCandidate]) -> f32 {
        let mut per_creator: HashMap<CreatorId, f64> = HashMap::new();
        for (candidate, rank) in ranked.iter().zip(1_usize..) {
            let weight = 1.0 / (rank as f64 + 1.0).log2();
            *per_creator.entry(candidate.creator_id).or_default() += weight;
        }

        // Zero or one creator: there is nothing to be unequal about.
        let creator_count = per_creator.len();
        if creator_count <= 1 {
            return 0.0;
        }

        // The closed-form Gini below requires the values in ascending order.
        let mut shares: Vec<f64> = per_creator.into_values().collect();
        shares.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));

        let total: f64 = shares.iter().sum();
        if total <= 0.0 {
            return 0.0;
        }

        let rank_weighted: f64 = shares
            .iter()
            .zip(1_usize..)
            .map(|(&share, position)| position as f64 * share)
            .sum();

        let count = creator_count as f64;
        let gini = (2.0 * rank_weighted) / (count * total) - (count + 1.0) / count;
        gini as f32
    }
}
/// Per-creator comparison between the observed share of a ranking and a
/// target share.
#[derive(Debug, Clone)]
pub struct ExposureDisparity {
    /// Fraction of ranked items attributed to each creator present in the ranking.
    pub actual: HashMap<CreatorId, f32>,
    /// Target shares, rescaled to sum to 1 when the supplied weights had a
    /// positive sum; otherwise the supplied weights as given.
    pub target: HashMap<CreatorId, f32>,
    /// `actual / target` for every target creator, or `0.0` when that
    /// creator's target share is not positive.
    pub ratio: HashMap<CreatorId, f32>,
}

impl ExposureDisparity {
    /// Compares the item-count share of each creator in `ranked` against
    /// `target_distribution`.
    ///
    /// Creators missing from `ranked` count as an observed share of `0.0`.
    /// Creators that appear in `ranked` but not in the target are reported in
    /// `actual` only.
    ///
    /// # Errors
    ///
    /// Returns `RecommendError::Other` when `target_distribution` is empty.
    pub fn compute(
        ranked: &[FairnessCandidate],
        target_distribution: &HashMap<CreatorId, f32>,
    ) -> RecommendResult<Self> {
        if target_distribution.is_empty() {
            return Err(RecommendError::Other(
                "target_distribution must not be empty".to_string(),
            ));
        }

        // Rescale the weights to shares; a non-positive total is kept verbatim.
        let weight_total: f32 = target_distribution.values().sum();
        let target: HashMap<CreatorId, f32> = if weight_total > 0.0 {
            target_distribution
                .iter()
                .map(|(&creator, &weight)| (creator, weight / weight_total))
                .collect()
        } else {
            target_distribution.clone()
        };

        let item_total = ranked.len() as f32;
        let mut tallies: HashMap<CreatorId, f32> = HashMap::new();
        for candidate in ranked {
            *tallies.entry(candidate.creator_id).or_default() += 1.0;
        }
        let actual: HashMap<CreatorId, f32> = tallies
            .into_iter()
            .map(|(creator, tally)| {
                let share = if item_total > 0.0 {
                    tally / item_total
                } else {
                    0.0
                };
                (creator, share)
            })
            .collect();

        let ratio: HashMap<CreatorId, f32> = target
            .iter()
            .map(|(&creator, &wanted)| {
                let observed = actual.get(&creator).copied().unwrap_or(0.0);
                let value = if wanted > 0.0 { observed / wanted } else { 0.0 };
                (creator, value)
            })
            .collect();

        Ok(Self {
            actual,
            target,
            ratio,
        })
    }

    /// Mean of `|actual - target|` over the target creators.
    ///
    /// Returns `0.0` when there are no target creators.
    #[must_use]
    pub fn mean_absolute_deviation(&self) -> f32 {
        let creators = self.target.len();
        if creators == 0 {
            return 0.0;
        }
        let total_gap: f32 = self
            .target
            .iter()
            .map(|(creator, &wanted)| {
                let observed = self.actual.get(creator).copied().unwrap_or(0.0);
                (observed - wanted).abs()
            })
            .sum();
        total_gap / creators as f32
    }
}
/// Group-level DCG and representation measures for a protected set of creators.
pub struct NdcgFairness;

impl NdcgFairness {
    /// Returns `(fraction, total_dcg)` for `ranked`.
    ///
    /// `total_dcg` sums `relevance / log2(rank + 1)` over all items (1-based
    /// rank). `fraction` is the part of that sum contributed by items whose
    /// creator is listed in `protected_creators`, or `0.0` when the total is
    /// not positive. An empty ranking yields `(0.0, 0.0)`.
    #[must_use]
    pub fn group_dcg_fraction(
        ranked: &[FairnessCandidate],
        protected_creators: &[CreatorId],
    ) -> (f32, f32) {
        let protected: std::collections::HashSet<CreatorId> =
            protected_creators.iter().copied().collect();

        // Accumulate both sums in a single left-to-right pass; an empty
        // ranking leaves both at zero and falls through to `(0.0, 0.0)`.
        let (overall, protected_part) = ranked.iter().zip(1_usize..).fold(
            (0.0_f64, 0.0_f64),
            |(overall, protected_part), (candidate, rank)| {
                let discount = 1.0 / (rank as f64 + 1.0).log2();
                let gain = f64::from(candidate.relevance) * discount;
                if protected.contains(&candidate.creator_id) {
                    (overall + gain, protected_part + gain)
                } else {
                    (overall + gain, protected_part)
                }
            },
        );

        if overall > 0.0 {
            ((protected_part / overall) as f32, overall as f32)
        } else {
            (0.0, overall as f32)
        }
    }

    /// Fraction of items in `ranked` whose creator is listed in
    /// `protected_creators`; `0.0` for an empty ranking.
    #[must_use]
    pub fn representation_ratio(
        ranked: &[FairnessCandidate],
        protected_creators: &[CreatorId],
    ) -> f32 {
        match ranked.len() {
            0 => 0.0,
            total => {
                let protected: std::collections::HashSet<CreatorId> =
                    protected_creators.iter().copied().collect();
                let hits = ranked
                    .iter()
                    .filter(|candidate| protected.contains(&candidate.creator_id))
                    .count();
                hits as f32 / total as f32
            }
        }
    }
}
/// Settings consumed by [`FairnessReranker`].
#[derive(Debug, Clone)]
pub struct FairnessConfig {
    /// Relative exposure weight per creator. The reranker rescales the
    /// weights by their sum, so they do not have to add up to 1.
    pub target_distribution: HashMap<CreatorId, f32>,
    /// Largest tolerated drop in average relevance; `FairnessReranker::new`
    /// rejects values outside `[0, 1]`.
    pub max_relevance_penalty: f32,
    /// When `true`, the reranker discards candidates whose creator has no
    /// entry in `target_distribution`.
    pub restrict_to_target_creators: bool,
}

impl FairnessConfig {
    /// Builds a configuration that gives every distinct creator in
    /// `creator_ids` the same share.
    ///
    /// Repeated ids are counted once, so the shares always sum to 1 for a
    /// non-empty input. An empty slice yields an empty distribution.
    /// `max_relevance_penalty` defaults to `0.2` and
    /// `restrict_to_target_creators` to `false`.
    #[must_use]
    pub fn uniform(creator_ids: &[CreatorId]) -> Self {
        // Insert first and size the share afterwards: the map collapses
        // duplicates, so only its length is the true number of creators.
        let mut target_distribution: HashMap<CreatorId, f32> =
            creator_ids.iter().map(|&id| (id, 0.0)).collect();
        let distinct = target_distribution.len();
        if distinct > 0 {
            let share = 1.0 / distinct as f32;
            for weight in target_distribution.values_mut() {
                *weight = share;
            }
        }
        Self {
            target_distribution,
            max_relevance_penalty: 0.2,
            restrict_to_target_creators: false,
        }
    }
}
/// Re-orders a candidate list so that every target creator's quota of items
/// is ranked ahead of surplus items.
pub struct FairnessReranker {
    config: FairnessConfig,
}

impl FairnessReranker {
    /// Creates a reranker after validating `config`.
    ///
    /// # Errors
    ///
    /// Returns `RecommendError::Other` when `max_relevance_penalty` lies
    /// outside `[0, 1]`; NaN is rejected as well.
    pub fn new(config: FairnessConfig) -> RecommendResult<Self> {
        if !(0.0..=1.0).contains(&config.max_relevance_penalty) {
            return Err(RecommendError::Other(
                "max_relevance_penalty must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Re-ranks `candidates` towards the configured target distribution.
    ///
    /// 1. Each target creator receives `round(share * candidates.len())`
    ///    quota slots, where `share` is its weight divided by the sum of all
    ///    weights.
    /// 2. The most relevant items of each target creator, up to its quota,
    ///    form the head of the result, ordered by descending relevance.
    /// 3. Everything else (surplus items of target creators and, unless
    ///    `restrict_to_target_creators` is set, items of other creators)
    ///    follows, also ordered by descending relevance.
    ///
    /// Ties in relevance keep their input order, and a NaN relevance sorts
    /// after every real score, so the output is deterministic.
    ///
    /// The input is returned unchanged (as a copy) when the target weights do
    /// not have a positive sum, or when the average relevance of the result
    /// is more than `max_relevance_penalty` below that of the input. The
    /// averages are taken over the whole lists, so they can only differ
    /// materially when `restrict_to_target_creators` dropped items.
    ///
    /// An empty input yields an empty vector.
    #[must_use]
    pub fn rerank(&self, candidates: &[FairnessCandidate]) -> Vec<FairnessCandidate> {
        if candidates.is_empty() {
            return Vec::new();
        }

        let targets = &self.config.target_distribution;
        let target_sum: f32 = targets.values().sum();
        if target_sum.is_nan() || target_sum <= 0.0 {
            return candidates.to_vec();
        }
        let n = candidates.len();

        // Work on indices into `candidates`: nothing is cloned until the
        // final order is known, and the index doubles as a tie-breaker.
        let mut pools: HashMap<CreatorId, Vec<usize>> = HashMap::new();
        let mut backfill: Vec<usize> = Vec::new();
        for (index, item) in candidates.iter().enumerate() {
            if targets.contains_key(&item.creator_id) {
                pools.entry(item.creator_id).or_default().push(index);
            } else if !self.config.restrict_to_target_creators {
                backfill.push(index);
            }
        }

        // Descending relevance, NaN last, input position as the final key.
        // This is a total order, which the sort routines require, and it
        // makes the result independent of `HashMap` iteration order.
        let by_relevance_desc = |a: &usize, b: &usize| -> std::cmp::Ordering {
            let (left, right) = (candidates[*a].relevance, candidates[*b].relevance);
            right
                .partial_cmp(&left)
                // Only reached when at least one score is NaN.
                .unwrap_or_else(|| left.is_nan().cmp(&right.is_nan()))
                .then_with(|| a.cmp(b))
        };

        let mut quota_picks: Vec<usize> = Vec::with_capacity(n);
        for (creator, mut pool) in pools {
            pool.sort_unstable_by(by_relevance_desc);
            let share = targets.get(&creator).copied().unwrap_or(0.0) / target_sum;
            let slots = (share * n as f32).round() as usize;
            let take = slots.min(pool.len());
            // The tail beyond the quota competes with the unclaimed items.
            backfill.extend(pool.drain(take..));
            quota_picks.append(&mut pool);
        }

        // Sort the two sections separately. Sorting the concatenation would
        // erase the quota and degrade to a plain relevance sort.
        quota_picks.sort_unstable_by(by_relevance_desc);
        backfill.sort_unstable_by(by_relevance_desc);

        let result: Vec<FairnessCandidate> = quota_picks
            .into_iter()
            .chain(backfill)
            .map(|index| candidates[index].clone())
            .collect();

        let penalty = avg_relevance(candidates) - avg_relevance(&result);
        if penalty > self.config.max_relevance_penalty {
            return candidates.to_vec();
        }
        result
    }
}
/// Arithmetic mean of the `relevance` of `items`; `0.0` for an empty slice.
fn avg_relevance(items: &[FairnessCandidate]) -> f32 {
    match items.len() {
        0 => 0.0,
        count => {
            let total: f32 = items.iter().map(|item| item.relevance).sum();
            total / count as f32
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Six candidates in descending relevance: four from creator 1 followed
    /// by two from creator 2.
    fn make_candidates() -> Vec<FairnessCandidate> {
        let rows: [(u64, CreatorId, f32); 6] = [
            (1, 1, 0.9),
            (2, 1, 0.85),
            (3, 1, 0.80),
            (4, 1, 0.75),
            (5, 2, 0.70),
            (6, 2, 0.65),
        ];
        rows.iter()
            .map(|&(content_id, creator_id, relevance)| {
                FairnessCandidate::new(content_id, creator_id, relevance)
            })
            .collect()
    }

    /// Target that splits exposure evenly between creators 1 and 2.
    fn even_split() -> HashMap<CreatorId, f32> {
        [(1_u64, 0.5_f32), (2_u64, 0.5_f32)]
            .iter()
            .copied()
            .collect()
    }

    /// An empty ranking has no inequality.
    #[test]
    fn test_exposure_gini_empty() {
        let gini = ExposureGini::compute(&[]);
        assert_eq!(gini, 0.0);
    }

    /// A ranking owned by a single creator has no inequality.
    #[test]
    fn test_exposure_gini_single_creator() {
        let items = [
            FairnessCandidate::new(1, 42, 0.9),
            FairnessCandidate::new(2, 42, 0.8),
        ];
        let gini = ExposureGini::compute(&items);
        assert_eq!(gini, 0.0, "single creator → Gini = 0");
    }

    /// Creator 1 holds the top four positions, so exposure is skewed.
    #[test]
    fn test_exposure_gini_unequal() {
        let gini = ExposureGini::compute(&make_candidates());
        assert!(
            gini > 0.0,
            "unequal distribution should yield positive Gini"
        );
        assert!(gini <= 1.0);
    }

    /// An empty target distribution is a caller error.
    #[test]
    fn test_exposure_disparity_empty_target_error() {
        let no_target: HashMap<CreatorId, f32> = HashMap::new();
        let result = ExposureDisparity::compute(&make_candidates(), &no_target);
        assert!(result.is_err());
    }

    /// With a 50/50 target, creator 1 (4 of 6 items) is above parity and
    /// creator 2 (2 of 6 items) is below it.
    #[test]
    fn test_exposure_disparity_ratios() {
        let disparity =
            ExposureDisparity::compute(&make_candidates(), &even_split()).expect("should compute");
        let r1 = disparity.ratio[&1];
        let r2 = disparity.ratio[&2];
        assert!(r1 > 1.0, "creator 1 is over-represented: {r1}");
        assert!(r2 < 1.0, "creator 2 is under-represented: {r2}");
    }

    /// The mean absolute deviation of shares stays within `[0, 1]`.
    #[test]
    fn test_exposure_disparity_mad() {
        let disparity =
            ExposureDisparity::compute(&make_candidates(), &even_split()).expect("should compute");
        let mad = disparity.mean_absolute_deviation();
        assert!(mad >= 0.0 && mad <= 1.0);
    }

    /// Creator 1 owns the top positions and therefore most of the DCG.
    #[test]
    fn test_ndcg_fairness_group_fraction() {
        let protected: [CreatorId; 1] = [1];
        let (frac, total_dcg) =
            NdcgFairness::group_dcg_fraction(&make_candidates(), &protected);
        assert!(frac > 0.0 && frac <= 1.0);
        assert!(total_dcg > 0.0);
        assert!(frac > 0.5, "creator 1 dominates top slots: {frac}");
    }

    /// Creator 2 contributes two of the six items.
    #[test]
    fn test_ndcg_fairness_representation_ratio() {
        let protected: [CreatorId; 1] = [2];
        let ratio = NdcgFairness::representation_ratio(&make_candidates(), &protected);
        assert!((ratio - 2.0 / 6.0).abs() < 1e-4);
    }

    /// A penalty bound above 1 must be rejected by the constructor.
    #[test]
    fn test_fairness_reranker_invalid_penalty() {
        let config = FairnessConfig {
            max_relevance_penalty: 1.5,
            ..FairnessConfig::uniform(&[])
        };
        assert!(FairnessReranker::new(config).is_err());
    }

    /// Re-ranking towards an even split must not make exposure more unequal.
    #[test]
    fn test_fairness_reranker_improves_gini() {
        let candidates = make_candidates();
        let gini_before = ExposureGini::compute(&candidates);

        let reranker = FairnessReranker::new(FairnessConfig {
            target_distribution: even_split(),
            max_relevance_penalty: 0.5,
            restrict_to_target_creators: false,
        })
        .expect("valid config");

        let reranked = reranker.rerank(&candidates);
        assert!(!reranked.is_empty());

        let gini_after = ExposureGini::compute(&reranked);
        assert!(
            gini_after <= gini_before + 1e-4,
            "Gini should not increase: before={gini_before}, after={gini_after}"
        );
    }

    /// Re-ranking nothing yields nothing.
    #[test]
    fn test_fairness_reranker_empty_input() {
        let reranker =
            FairnessReranker::new(FairnessConfig::uniform(&[1, 2])).expect("valid config");
        let result = reranker.rerank(&[]);
        assert!(result.is_empty());
    }
}