use crate::{EntityType, Model};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Decade bucket used to group test names by era.
///
/// Variants are declared in chronological order, so the derived
/// `PartialOrd`/`Ord` implementations compare earlier buckets as smaller.
/// The `Debug` name of each variant (e.g. `"D1950s"`) is used as the map key
/// in `TemporalBiasResults::by_decade`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub enum Decade {
/// Single bucket standing in for everything before 1900.
Pre1900,
D1900s,
D1910s,
D1920s,
D1930s,
D1940s,
D1950s,
D1960s,
D1970s,
D1980s,
D1990s,
D2000s,
D2010s,
/// Latest bucket; `midpoint_year` reports 2022 for it rather than 2025.
D2020s,
}
impl Decade {
    /// Returns `true` for the buckets `Pre1900` through `D1940s` inclusive.
    pub fn is_historical(&self) -> bool {
        // Variants are declared chronologically, so the derived ordering
        // expresses "up to and including the 1940s" as one comparison.
        *self <= Decade::D1940s
    }

    /// Returns `true` for `D2000s`, `D2010s` and `D2020s`.
    pub fn is_modern(&self) -> bool {
        // Same ordering argument as `is_historical`, from the other end.
        *self >= Decade::D2000s
    }

    /// Returns a representative year for the bucket.
    ///
    /// Regular decades map to their fifth year (e.g. `D1980s` -> 1985).
    /// The two edge buckets use fixed values: 1890 for `Pre1900` and 2022
    /// for `D2020s`.
    pub fn midpoint_year(&self) -> u16 {
        match *self {
            Self::Pre1900 => 1890,
            Self::D1900s => 1905,
            Self::D1910s => 1915,
            Self::D1920s => 1925,
            Self::D1930s => 1935,
            Self::D1940s => 1945,
            Self::D1950s => 1955,
            Self::D1960s => 1965,
            Self::D1970s => 1975,
            Self::D1980s => 1985,
            Self::D1990s => 1995,
            Self::D2000s => 2005,
            Self::D2010s => 2015,
            Self::D2020s => 2022,
        }
    }
}
/// One test name together with the attributes the evaluator groups results by.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalNameExample {
/// Given name; `TemporalBiasEvaluator::evaluate` searches for it inside the
/// text of extracted person entities.
pub first_name: String,
/// Family name.
pub last_name: String,
/// Name inserted into the test sentence. `TemporalNameExample::new` builds it
/// as `"<first_name> <last_name>"`.
pub full_name: String,
/// Decade bucket this name is counted under.
pub peak_decade: Decade,
/// Gender bucket this name is counted under.
pub gender: TemporalGender,
/// `true` counts the name toward the "classic" rate; `false` counts it
/// toward the "trendy" rate.
pub is_classic: bool,
}
/// Gender bucket attached to each test name.
///
/// The `Debug` name of each variant is used as the map key in
/// `TemporalBiasResults::by_gender`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TemporalGender {
Masculine,
Feminine,
Neutral,
}
impl TemporalNameExample {
    /// Builds an example from its parts.
    ///
    /// `full_name` is derived as the first and last name separated by a
    /// single space; all other fields are stored as given.
    pub fn new(
        first_name: &str,
        last_name: &str,
        peak_decade: Decade,
        gender: TemporalGender,
        is_classic: bool,
    ) -> Self {
        let full_name = [first_name, last_name].join(" ");
        Self {
            first_name: first_name.to_owned(),
            last_name: last_name.to_owned(),
            full_name,
            peak_decade,
            gender,
            is_classic,
        }
    }
}
/// Aggregated output of `TemporalBiasEvaluator::evaluate`.
///
/// Every rate is `recognized / total` for its group, and is `0.0` when the
/// group contains no names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalBiasResults {
/// Recognition rate over all tested names.
pub overall_recognition_rate: f64,
/// Recognition rate per decade, keyed by the `Debug` name of `Decade`
/// (e.g. `"D1950s"`). Only decades present in the input appear.
pub by_decade: HashMap<String, f64>,
/// Rate over names whose decade satisfies `Decade::is_historical`.
pub historical_rate: f64,
/// Rate over names whose decade satisfies `Decade::is_modern`.
pub modern_rate: f64,
/// Absolute difference between `historical_rate` and `modern_rate`.
pub historical_modern_gap: f64,
/// Difference between the highest and lowest per-decade rate; `0.0` when
/// fewer than two decades were seen.
pub temporal_parity_gap: f64,
/// Recognition rate per gender, keyed by the `Debug` name of
/// `TemporalGender`.
pub by_gender: HashMap<String, f64>,
/// Rate over names with `is_classic == true`.
pub classic_rate: f64,
/// Rate over names with `is_classic == false`.
pub trendy_rate: f64,
/// Number of names passed to the evaluator.
pub total_tested: usize,
}
/// Runs a model over a list of names and summarizes recognition by era,
/// gender and classic/trendy status.
#[derive(Debug, Clone, Default)]
pub struct TemporalBiasEvaluator {
/// Verbosity/detail flag stored by `new`.
/// NOTE(review): `evaluate` in this file never reads it — confirm whether
/// other code depends on it.
pub detailed: bool,
}
impl TemporalBiasEvaluator {
    /// Creates an evaluator carrying the given `detailed` flag.
    pub fn new(detailed: bool) -> Self {
        Self { detailed }
    }

    /// Measures how often `model` recognizes each name as a person.
    ///
    /// Each name is embedded in a sentence produced by
    /// `create_realistic_temporal_sentence`. It counts as recognized when at
    /// least one extracted entity is an `EntityType::Person` whose text
    /// contains the example's first name. If extraction fails, the default
    /// value of the entity collection is used instead of propagating the
    /// error.
    ///
    /// Results are tallied overall, per decade, per gender, for the
    /// historical and modern decade groups, and for classic versus
    /// non-classic ("trendy") names. Empty groups report a rate of `0.0`.
    pub fn evaluate(
        &self,
        model: &dyn Model,
        names: &[TemporalNameExample],
    ) -> TemporalBiasResults {
        // Running `(recognized, total)` pair for one group.
        type Tally = (usize, usize);

        // Counts one more example in `tally`, and one more hit if recognized.
        fn record(tally: &mut Tally, hit: bool) {
            tally.1 += 1;
            if hit {
                tally.0 += 1;
            }
        }

        // Converts a tally into a rate, treating an empty group as 0.0.
        fn ratio(tally: &Tally) -> f64 {
            match *tally {
                (_, 0) => 0.0,
                (hits, total) => hits as f64 / total as f64,
            }
        }

        // Converts every tally in a keyed map into its rate.
        fn ratios(tallies: &HashMap<String, Tally>) -> HashMap<String, f64> {
            tallies
                .iter()
                .map(|(key, tally)| (key.clone(), ratio(tally)))
                .collect()
        }

        let mut overall: Tally = (0, 0);
        let mut per_decade: HashMap<String, Tally> = HashMap::new();
        let mut per_gender: HashMap<String, Tally> = HashMap::new();
        let mut historical: Tally = (0, 0);
        let mut modern: Tally = (0, 0);
        let mut classic: Tally = (0, 0);
        let mut trendy: Tally = (0, 0);

        for example in names {
            let sentence = create_realistic_temporal_sentence(&example.full_name);
            let entities = model.extract_entities(&sentence, None).unwrap_or_default();
            let hit = entities.iter().any(|e| {
                e.entity_type == EntityType::Person
                    && e.extract_text(&sentence).contains(&example.first_name)
            });

            record(&mut overall, hit);
            record(
                per_decade
                    .entry(format!("{:?}", example.peak_decade))
                    .or_default(),
                hit,
            );
            // A decade may be neither historical nor modern (1950s-1990s).
            if example.peak_decade.is_historical() {
                record(&mut historical, hit);
            }
            if example.peak_decade.is_modern() {
                record(&mut modern, hit);
            }
            record(
                per_gender
                    .entry(format!("{:?}", example.gender))
                    .or_default(),
                hit,
            );
            record(
                if example.is_classic {
                    &mut classic
                } else {
                    &mut trendy
                },
                hit,
            );
        }

        let by_decade = ratios(&per_decade);
        let by_gender = ratios(&per_gender);
        let historical_rate = ratio(&historical);
        let modern_rate = ratio(&modern);
        let temporal_parity_gap = compute_max_gap(&by_decade);

        TemporalBiasResults {
            overall_recognition_rate: ratio(&overall),
            by_decade,
            historical_rate,
            modern_rate,
            historical_modern_gap: (historical_rate - modern_rate).abs(),
            temporal_parity_gap,
            by_gender,
            classic_rate: ratio(&classic),
            trendy_rate: ratio(&trendy),
            total_tested: names.len(),
        }
    }
}
/// Returns the spread between the largest and smallest rate in `rates`.
///
/// With fewer than two entries there is nothing to compare, so the gap is
/// `0.0`.
fn compute_max_gap(rates: &HashMap<String, f64>) -> f64 {
    if rates.len() < 2 {
        return 0.0;
    }
    // Track both extremes in a single pass over the values.
    let (lowest, highest) = rates.values().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &rate| (lo.min(rate), hi.max(rate)),
    );
    highest - lowest
}
/// Wraps `name` in one of ten fixed sentence templates.
///
/// The template is chosen from a `DefaultHasher` hash of `name`, so the same
/// name maps to the same sentence for a given hasher implementation.
fn create_realistic_temporal_sentence(name: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Each entry renders one template; only the selected one is invoked.
    let templates: [fn(&str) -> String; 10] = [
        |n: &str| format!("{} was featured in the historical archives.", n),
        |n: &str| format!("The biography of {} was published last year.", n),
        |n: &str| format!("{} made significant contributions to the field.", n),
        |n: &str| format!("Records show that {} attended the event in 1950.", n),
        |n: &str| format!("{} was recognized for lifetime achievements.", n),
        |n: &str| format!("The family of {} established a scholarship fund.", n),
        |n: &str| format!("{} served as president of the organization.", n),
        |n: &str| format!("Historical documents mention {} in several contexts.", n),
        |n: &str| format!("{} was known for innovative research methods.", n),
        |n: &str| format!("The legacy of {} continues to inspire new generations.", n),
    ];

    let mut state = DefaultHasher::new();
    name.hash(&mut state);
    let slot = state.finish() as usize % templates.len();
    templates[slot](name)
}
pub fn create_temporal_name_dataset() -> Vec<TemporalNameExample> {
let mut names = Vec::new();
let last_names = ["Smith", "Johnson", "Williams", "Brown", "Jones"];
let pre1900 = [
("Gertrude", TemporalGender::Feminine),
("Clarence", TemporalGender::Masculine),
("Mildred", TemporalGender::Feminine),
("Herbert", TemporalGender::Masculine),
("Bertha", TemporalGender::Feminine),
("Agnes", TemporalGender::Feminine),
("Albert", TemporalGender::Masculine),
("Florence", TemporalGender::Feminine),
("Walter", TemporalGender::Masculine),
("Edith", TemporalGender::Feminine),
];
let d1900s = [
("Ethel", TemporalGender::Feminine),
("Harold", TemporalGender::Masculine),
("Pearl", TemporalGender::Feminine),
("Clarence", TemporalGender::Masculine),
("Minnie", TemporalGender::Feminine),
("Alice", TemporalGender::Feminine),
("Raymond", TemporalGender::Masculine),
("Ruth", TemporalGender::Feminine),
("Frank", TemporalGender::Masculine),
("Helen", TemporalGender::Feminine),
];
let d1910s = [
("Dorothy", TemporalGender::Feminine),
("Earl", TemporalGender::Masculine),
("Gladys", TemporalGender::Feminine),
("Howard", TemporalGender::Masculine),
("Thelma", TemporalGender::Feminine),
];
let d1920s = [
("Betty", TemporalGender::Feminine),
("Donald", TemporalGender::Masculine),
("Doris", TemporalGender::Feminine),
("Raymond", TemporalGender::Masculine),
("Shirley", TemporalGender::Feminine),
];
let d1930s = [
("Barbara", TemporalGender::Feminine),
("Robert", TemporalGender::Masculine),
("Patricia", TemporalGender::Feminine),
("Richard", TemporalGender::Masculine),
("Carol", TemporalGender::Feminine),
];
let d1940s = [
("Linda", TemporalGender::Feminine),
("Gary", TemporalGender::Masculine),
("Sandra", TemporalGender::Feminine),
("Larry", TemporalGender::Masculine),
("Sharon", TemporalGender::Feminine),
];
let d1950s = [
("Deborah", TemporalGender::Feminine),
("Dennis", TemporalGender::Masculine),
("Debra", TemporalGender::Feminine),
("Timothy", TemporalGender::Masculine),
("Pamela", TemporalGender::Feminine),
];
let d1960s = [
("Lisa", TemporalGender::Feminine),
("Mark", TemporalGender::Masculine),
("Kimberly", TemporalGender::Feminine),
("Kevin", TemporalGender::Masculine),
("Michelle", TemporalGender::Feminine),
];
let d1970s = [
("Jennifer", TemporalGender::Feminine),
("Jason", TemporalGender::Masculine),
("Amy", TemporalGender::Feminine),
("Brian", TemporalGender::Masculine),
("Heather", TemporalGender::Feminine),
];
let d1980s = [
("Jessica", TemporalGender::Feminine),
("Michael", TemporalGender::Masculine),
("Amanda", TemporalGender::Feminine),
("Christopher", TemporalGender::Masculine),
("Ashley", TemporalGender::Feminine),
];
let d1990s = [
("Brittany", TemporalGender::Feminine),
("Tyler", TemporalGender::Masculine),
("Taylor", TemporalGender::Neutral),
("Brandon", TemporalGender::Masculine),
("Megan", TemporalGender::Feminine),
];
let d2000s = [
("Madison", TemporalGender::Feminine),
("Aiden", TemporalGender::Masculine),
("Emma", TemporalGender::Feminine),
("Ethan", TemporalGender::Masculine),
("Chloe", TemporalGender::Feminine),
];
let d2010s = [
("Sophia", TemporalGender::Feminine),
("Liam", TemporalGender::Masculine),
("Olivia", TemporalGender::Feminine),
("Noah", TemporalGender::Masculine),
("Ava", TemporalGender::Feminine),
];
let d2020s = [
("Luna", TemporalGender::Feminine),
("Ezra", TemporalGender::Masculine),
("Charlotte", TemporalGender::Feminine),
("Oliver", TemporalGender::Masculine),
("Amelia", TemporalGender::Feminine),
("Mia", TemporalGender::Feminine),
("Liam", TemporalGender::Masculine),
("Harper", TemporalGender::Neutral),
("Mason", TemporalGender::Masculine),
("Evelyn", TemporalGender::Feminine),
];
let classics = [
("James", TemporalGender::Masculine, true),
("Elizabeth", TemporalGender::Feminine, true),
("William", TemporalGender::Masculine, true),
("Mary", TemporalGender::Feminine, true),
("John", TemporalGender::Masculine, true),
("Sarah", TemporalGender::Feminine, true),
("Robert", TemporalGender::Masculine, true),
("Anna", TemporalGender::Feminine, true),
("Michael", TemporalGender::Masculine, true),
("Emily", TemporalGender::Feminine, true),
];
let add_decade = |names: &mut Vec<TemporalNameExample>,
decade_names: &[(&str, TemporalGender)],
decade: Decade,
last_names: &[&str]| {
for (i, (first, gender)) in decade_names.iter().enumerate() {
let last = last_names[i % last_names.len()];
names.push(TemporalNameExample::new(
first, last, decade, *gender, false,
));
}
};
add_decade(&mut names, &pre1900, Decade::Pre1900, &last_names);
add_decade(&mut names, &d1900s, Decade::D1900s, &last_names);
add_decade(&mut names, &d1910s, Decade::D1910s, &last_names);
add_decade(&mut names, &d1920s, Decade::D1920s, &last_names);
add_decade(&mut names, &d1930s, Decade::D1930s, &last_names);
add_decade(&mut names, &d1940s, Decade::D1940s, &last_names);
add_decade(&mut names, &d1950s, Decade::D1950s, &last_names);
add_decade(&mut names, &d1960s, Decade::D1960s, &last_names);
add_decade(&mut names, &d1970s, Decade::D1970s, &last_names);
add_decade(&mut names, &d1980s, Decade::D1980s, &last_names);
add_decade(&mut names, &d1990s, Decade::D1990s, &last_names);
add_decade(&mut names, &d2000s, Decade::D2000s, &last_names);
add_decade(&mut names, &d2010s, Decade::D2010s, &last_names);
add_decade(&mut names, &d2020s, Decade::D2020s, &last_names);
for (i, (first, gender, _is_classic)) in classics.iter().enumerate() {
let last = last_names[i % last_names.len()];
names.push(TemporalNameExample::new(
first,
last,
Decade::D1950s,
*gender,
true,
));
}
names
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in dataset spans a broad range of decade buckets,
    /// including both extremes.
    #[test]
    fn test_create_temporal_dataset() {
        let dataset = create_temporal_name_dataset();
        let mut decades = std::collections::HashSet::new();
        for example in &dataset {
            decades.insert(format!("{:?}", example.peak_decade));
        }

        assert!(decades.len() >= 10, "Should cover at least 10 decades");
        assert!(
            decades.contains("Pre1900"),
            "Should have pre-1900 (Victorian) names"
        );
        assert!(decades.contains("D2020s"), "Should have 2020s names");
    }

    /// Both the historical and the modern group are populated.
    #[test]
    fn test_historical_vs_modern() {
        let dataset = create_temporal_name_dataset();
        let mut historical = 0usize;
        let mut modern = 0usize;
        for example in &dataset {
            if example.peak_decade.is_historical() {
                historical += 1;
            }
            if example.peak_decade.is_modern() {
                modern += 1;
            }
        }

        assert!(historical > 0, "Should have historical names");
        assert!(modern > 0, "Should have modern names");
    }

    /// Well-known classic names carry the `is_classic` flag.
    #[test]
    fn test_classic_names_marked() {
        let dataset = create_temporal_name_dataset();
        let is_classic_named =
            |first: &str| dataset.iter().any(|n| n.is_classic && n.first_name == first);

        assert!(
            dataset.iter().any(|n| n.is_classic),
            "Should have classic names"
        );
        assert!(is_classic_named("James"), "James should be a classic");
        assert!(
            is_classic_named("Elizabeth"),
            "Elizabeth should be a classic"
        );
    }

    /// Decade buckets order chronologically and report the expected midpoint.
    #[test]
    fn test_decade_ordering() {
        assert!(Decade::Pre1900 < Decade::D1900s);
        assert!(Decade::D1900s < Decade::D2020s);
        assert!(Decade::D1980s.midpoint_year() == 1985);
    }

    /// Both masculine and feminine names are well represented.
    #[test]
    fn test_gender_distribution() {
        let dataset = create_temporal_name_dataset();
        let count_of = |wanted: TemporalGender| {
            dataset.iter().filter(|n| n.gender == wanted).count()
        };
        let masculine = count_of(TemporalGender::Masculine);
        let feminine = count_of(TemporalGender::Feminine);

        assert!(masculine > 20, "Should have substantial masculine names");
        assert!(feminine > 20, "Should have substantial feminine names");
    }
}