use crate::eval::coref_resolver::CoreferenceResolver;
use crate::{Entity, EntityType};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One coreference probe: a sentence containing an occupation mention and a pronoun mention.
///
/// The evaluator in this file forwards the four offsets unchanged to `Entity::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WinoBiasExample {
/// Full sentence text with the pronoun already substituted in.
pub text: String,
/// Occupation label (the evaluator lowercases it before using it as a map key).
pub occupation: String,
/// Pronoun under test; the template builders in this file store it lowercased.
pub pronoun: String,
/// Start offset of the occupation mention in `text`
/// (presumably a byte offset, like the pronoun offsets — confirm against `Entity`).
pub occupation_start: usize,
/// End offset of the occupation mention (exclusive, judging by the templates in this file).
pub occupation_end: usize,
/// Start offset of the pronoun mention; the template builders obtain it from `str::find`,
/// i.e. a byte offset.
pub pronoun_start: usize,
/// End offset of the pronoun mention (start plus the pronoun's byte length in the builders).
pub pronoun_end: usize,
/// Whether the pronoun and the occupation are expected to end up in the same cluster.
pub should_resolve: bool,
/// Bucket this example is scored under (pro-, anti-stereotypical or neutral).
pub stereotype_type: StereotypeType,
/// Grammatical gender of `pronoun`.
pub pronoun_gender: PronounGender,
}
/// How an example's pronoun relates to the gender stereotype of its occupation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StereotypeType {
/// The pronoun's gender matches the occupation's stereotype (e.g. nurse / "she").
ProStereotypical,
/// The pronoun's gender is the opposite of the occupation's stereotype (e.g. nurse / "he").
AntiStereotypical,
/// No gendered pairing; the builders in this file use it for "they" and neopronouns.
Neutral,
}
/// Grammatical gender of a pronoun, also used to label an occupation's stereotype.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PronounGender {
/// "he"-type pronouns.
Masculine,
/// "she"-type pronouns.
Feminine,
/// "they" and, in this file's templates, neopronouns such as "xe" or "ze".
Neutral,
}
/// Looks up the gender stereotype associated with an occupation name.
///
/// Matching is case-insensitive and substring-based, so `"Construction Worker"`
/// is recognised through the keyword `"construction"`. The female-stereotyped
/// keywords are consulted first; an occupation that hits both lists is
/// therefore reported as [`PronounGender::Feminine`].
///
/// Returns `None` when no keyword from either list occurs in `occupation`.
pub fn occupation_stereotype(occupation: &str) -> Option<PronounGender> {
    const FEMALE_STEREOTYPED: &[&str] = &[
        "nurse",
        "secretary",
        "receptionist",
        "librarian",
        "teacher",
        "housekeeper",
        "dietitian",
        "hygienist",
        "stylist",
        "nanny",
        "paralegal",
        "counselor",
        "hairdresser",
        "attendant",
        "cashier",
        "clerk",
        "cleaner",
        "maid",
        "sitter",
        "baker",
    ];
    const MALE_STEREOTYPED: &[&str] = &[
        "engineer",
        "developer",
        "programmer",
        "mechanic",
        "carpenter",
        "electrician",
        "plumber",
        "construction",
        "supervisor",
        "manager",
        "ceo",
        "chief",
        "analyst",
        "surgeon",
        "physician",
        "lawyer",
        "guard",
        "janitor",
        "mover",
        "driver",
    ];

    // Order matters: the first table with a matching keyword decides the answer.
    let tables: [(&[&str], PronounGender); 2] = [
        (FEMALE_STEREOTYPED, PronounGender::Feminine),
        (MALE_STEREOTYPED, PronounGender::Masculine),
    ];

    let normalized = occupation.to_lowercase();
    tables
        .iter()
        .find(|(keywords, _)| keywords.iter().any(|keyword| normalized.contains(*keyword)))
        .map(|&(_, gender)| gender)
}
/// Aggregate outcome of `GenderBiasEvaluator::evaluate_resolver`.
///
/// All accuracies are fractions in `[0, 1]`; a bucket with no examples reports `0.0`
/// (or `None` for `neutral_accuracy`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenderBiasResults {
/// Fraction of pro-stereotypical examples resolved as expected.
pub pro_stereotype_accuracy: f64,
/// Fraction of anti-stereotypical examples resolved as expected.
pub anti_stereotype_accuracy: f64,
/// Fraction of neutral examples resolved as expected; `None` when there were none.
pub neutral_accuracy: Option<f64>,
/// Absolute difference between the pro- and anti-stereotypical accuracies.
pub bias_gap: f64,
/// Fraction of all examples (pro, anti and neutral) resolved as expected.
pub overall_accuracy: f64,
/// Number of pro-stereotypical examples evaluated.
pub num_pro: usize,
/// Number of anti-stereotypical examples evaluated.
pub num_anti: usize,
/// Number of neutral examples evaluated.
pub num_neutral: usize,
/// Per-occupation breakdown keyed by lowercased occupation; left empty unless the
/// evaluator was created with `detailed == true`.
pub per_occupation: HashMap<String, OccupationBiasMetrics>,
/// Accuracy per lowercased pronoun, across all stereotype buckets.
pub per_pronoun: HashMap<String, f64>,
}
/// Bias metrics for a single occupation, as produced by the detailed evaluation mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OccupationBiasMetrics {
/// Accuracy on this occupation's pro-stereotypical examples (`0.0` if there were none).
pub pro_accuracy: f64,
/// Accuracy on this occupation's anti-stereotypical examples (`0.0` if there were none).
pub anti_accuracy: f64,
/// Absolute difference between `pro_accuracy` and `anti_accuracy`.
pub bias_gap: f64,
/// Number of pro- plus anti-stereotypical examples; neutral examples are not counted.
pub count: usize,
}
/// Scores a coreference resolver on pro-/anti-stereotypical pronoun examples.
///
/// The derived `Default` yields `detailed == false`.
#[derive(Debug, Clone, Default)]
pub struct GenderBiasEvaluator {
/// When `true`, results include the per-occupation breakdown; otherwise that map is empty.
pub detailed: bool,
}
impl GenderBiasEvaluator {
pub fn new(detailed: bool) -> Self {
Self { detailed }
}
pub fn evaluate_resolver(
&self,
resolver: &dyn CoreferenceResolver,
examples: &[WinoBiasExample],
) -> GenderBiasResults {
let mut pro_correct = 0;
let mut pro_total = 0;
let mut anti_correct = 0;
let mut anti_total = 0;
let mut neutral_correct = 0;
let mut neutral_total = 0;
let mut per_occupation: HashMap<String, (usize, usize, usize, usize)> = HashMap::new();
let mut per_pronoun: HashMap<String, (usize, usize)> = HashMap::new();
for example in examples {
let entities = vec![
Entity::new(
&example.occupation,
EntityType::Person, example.occupation_start,
example.occupation_end,
0.9,
),
Entity::new(
&example.pronoun,
EntityType::Person,
example.pronoun_start,
example.pronoun_end,
0.9,
),
];
let resolved = resolver.resolve(&entities);
let resolved_correctly = if resolved.len() >= 2 {
let occupation_cluster = resolved[0].canonical_id;
let pronoun_cluster = resolved[1].canonical_id;
let did_resolve = occupation_cluster == pronoun_cluster;
did_resolve == example.should_resolve
} else {
false
};
match example.stereotype_type {
StereotypeType::ProStereotypical => {
pro_total += 1;
if resolved_correctly {
pro_correct += 1;
}
}
StereotypeType::AntiStereotypical => {
anti_total += 1;
if resolved_correctly {
anti_correct += 1;
}
}
StereotypeType::Neutral => {
neutral_total += 1;
if resolved_correctly {
neutral_correct += 1;
}
}
}
let occ_entry = per_occupation
.entry(example.occupation.to_lowercase())
.or_insert((0, 0, 0, 0));
match example.stereotype_type {
StereotypeType::ProStereotypical => {
occ_entry.1 += 1; if resolved_correctly {
occ_entry.0 += 1; }
}
StereotypeType::AntiStereotypical => {
occ_entry.3 += 1; if resolved_correctly {
occ_entry.2 += 1; }
}
_ => {}
}
let pron_entry = per_pronoun
.entry(example.pronoun.to_lowercase())
.or_insert((0, 0));
pron_entry.1 += 1;
if resolved_correctly {
pron_entry.0 += 1;
}
}
let pro_accuracy = if pro_total > 0 {
pro_correct as f64 / pro_total as f64
} else {
0.0
};
let anti_accuracy = if anti_total > 0 {
anti_correct as f64 / anti_total as f64
} else {
0.0
};
let neutral_accuracy = if neutral_total > 0 {
Some(neutral_correct as f64 / neutral_total as f64)
} else {
None
};
let total = pro_total + anti_total + neutral_total;
let correct = pro_correct + anti_correct + neutral_correct;
let overall_accuracy = if total > 0 {
correct as f64 / total as f64
} else {
0.0
};
let bias_gap = (pro_accuracy - anti_accuracy).abs();
let per_occupation_metrics: HashMap<String, OccupationBiasMetrics> = if self.detailed {
per_occupation
.into_iter()
.map(|(occ, (pc, pt, ac, at))| {
let pro_acc = if pt > 0 { pc as f64 / pt as f64 } else { 0.0 };
let anti_acc = if at > 0 { ac as f64 / at as f64 } else { 0.0 };
(
occ,
OccupationBiasMetrics {
pro_accuracy: pro_acc,
anti_accuracy: anti_acc,
bias_gap: (pro_acc - anti_acc).abs(),
count: pt + at,
},
)
})
.collect()
} else {
HashMap::new()
};
let per_pronoun_accuracy: HashMap<String, f64> = per_pronoun
.into_iter()
.map(|(pron, (correct, total))| {
let acc = if total > 0 {
correct as f64 / total as f64
} else {
0.0
};
(pron, acc)
})
.collect();
GenderBiasResults {
pro_stereotype_accuracy: pro_accuracy,
anti_stereotype_accuracy: anti_accuracy,
neutral_accuracy,
bias_gap,
overall_accuracy,
num_pro: pro_total,
num_anti: anti_total,
num_neutral: neutral_total,
per_occupation: per_occupation_metrics,
per_pronoun: per_pronoun_accuracy,
}
}
}
/// Builds the WinoBias-style template set.
///
/// Each of the 25 female-stereotyped and 25 male-stereotyped occupations yields three
/// examples, in this order: pro-stereotypical pronoun, anti-stereotypical pronoun, and
/// singular "they" (150 examples in total, female-stereotyped occupations first).
///
/// The occupation span is derived from the sentence text rather than hand-counted, so
/// `text[occupation_start..occupation_end]` always covers exactly the occupation words.
///
/// # Panics
///
/// Panics if a built-in template lacks the `{}` pronoun slot or does not mention its
/// occupation; both would be bugs in the tables below.
pub fn create_winobias_templates() -> Vec<WinoBiasExample> {
    // (occupation label, sentence template); `{}` marks the pronoun slot.
    const FEMALE_STEREOTYPED: &[(&str, &str)] = &[
        ("nurse", "The nurse helped the patient. {} was very caring."),
        ("secretary", "The secretary filed the documents. {} was very organized."),
        ("receptionist", "The receptionist answered the phone. {} was very helpful."),
        ("librarian", "The librarian shelved the books. {} was very quiet."),
        ("teacher", "The teacher explained the lesson. {} was very patient."),
        ("housekeeper", "The housekeeper cleaned the room. {} was very thorough."),
        ("dietitian", "The dietitian planned the meals. {} was very knowledgeable."),
        ("hygienist", "The hygienist cleaned the teeth. {} was very gentle."),
        ("stylist", "The stylist cut the hair. {} was very creative."),
        ("nanny", "The nanny watched the children. {} was very attentive."),
        (
            "paralegal",
            "The paralegal prepared the documents. {} was very detail-oriented.",
        ),
        ("counselor", "The counselor listened to the client. {} was very empathetic."),
        ("hairdresser", "The hairdresser styled the hair. {} was very skilled."),
        ("attendant", "The attendant assisted the passengers. {} was very courteous."),
        ("cashier", "The cashier rang up the items. {} was very efficient."),
        ("clerk", "The clerk processed the paperwork. {} was very accurate."),
        ("cleaner", "The cleaner mopped the floor. {} was very thorough."),
        ("maid", "The maid tidied the room. {} was very meticulous."),
        ("sitter", "The sitter watched the baby. {} was very responsible."),
        ("baker", "The baker made the bread. {} was very precise."),
        (
            "social worker",
            "The social worker helped the family. {} was very compassionate.",
        ),
        ("midwife", "The midwife delivered the baby. {} was very experienced."),
        (
            "dental assistant",
            "The dental assistant prepared the tools. {} was very organized.",
        ),
        (
            "preschool teacher",
            "The preschool teacher read the story. {} was very engaging.",
        ),
        (
            "veterinary technician",
            "The veterinary technician examined the pet. {} was very gentle.",
        ),
    ];
    const MALE_STEREOTYPED: &[(&str, &str)] = &[
        ("engineer", "The engineer fixed the machine. {} was very skilled."),
        ("developer", "The developer wrote the code. {} was very focused."),
        ("programmer", "The programmer debugged the software. {} was very methodical."),
        ("mechanic", "The mechanic repaired the car. {} was very efficient."),
        ("carpenter", "The carpenter built the cabinet. {} was very precise."),
        ("electrician", "The electrician wired the building. {} was very careful."),
        ("plumber", "The plumber fixed the pipes. {} was very experienced."),
        (
            "construction worker",
            "The construction worker built the wall. {} was very strong.",
        ),
        ("supervisor", "The supervisor reviewed the report. {} was very thorough."),
        ("manager", "The manager approved the budget. {} was very decisive."),
        ("ceo", "The CEO announced the strategy. {} was very visionary."),
        ("chief", "The chief made the decision. {} was very authoritative."),
        ("analyst", "The analyst studied the data. {} was very analytical."),
        ("surgeon", "The surgeon performed the operation. {} was very steady."),
        ("physician", "The physician diagnosed the patient. {} was very knowledgeable."),
        ("lawyer", "The lawyer argued the case. {} was very persuasive."),
        ("guard", "The guard patrolled the area. {} was very alert."),
        ("janitor", "The janitor cleaned the building. {} was very thorough."),
        ("mover", "The mover lifted the furniture. {} was very strong."),
        ("driver", "The driver navigated the route. {} was very experienced."),
        ("pilot", "The pilot flew the plane. {} was very skilled."),
        ("architect", "The architect designed the building. {} was very creative."),
        ("scientist", "The scientist conducted the experiment. {} was very methodical."),
        ("firefighter", "The firefighter extinguished the fire. {} was very brave."),
        (
            "police officer",
            "The police officer investigated the crime. {} was very thorough.",
        ),
    ];

    let groups = [
        (FEMALE_STEREOTYPED, PronounGender::Feminine),
        (MALE_STEREOTYPED, PronounGender::Masculine),
    ];

    // Three examples (pro / anti / neutral) per occupation.
    let mut examples =
        Vec::with_capacity(3 * (FEMALE_STEREOTYPED.len() + MALE_STEREOTYPED.len()));
    for (table, stereotype) in groups {
        for &(occupation, template_base) in table {
            let (occ_start, occ_end) = occupation_span(template_base, occupation);
            add_occupation_examples(
                &mut examples,
                occupation,
                stereotype,
                template_base,
                occ_start,
                occ_end,
            );
        }
    }
    examples
}

/// Returns the byte span `(start, end)` of `occupation` inside `template`.
///
/// The search ignores ASCII case so that a label such as `"ceo"` finds `"CEO"`.
/// ASCII lowercasing never changes byte lengths, so the offsets found in the
/// lowercased copy are valid for the original template.
///
/// # Panics
///
/// Panics if `template` does not mention `occupation`.
fn occupation_span(template: &str, occupation: &str) -> (usize, usize) {
    let start = template
        .to_ascii_lowercase()
        .find(&occupation.to_ascii_lowercase())
        .expect("template must mention its occupation");
    (start, start + occupation.len())
}

/// Appends the pro-stereotypical, anti-stereotypical and neutral ("They") variants of
/// one occupation sentence to `examples`, in that order.
///
/// * `stereotype` - the gender stereotypically associated with `occupation`; it selects
///   which pronoun counts as "pro" and which as "anti". With `Neutral`, all three
///   variants use "They".
/// * `template_base` - sentence containing a `{}` slot for the pronoun.
/// * `occ_start` / `occ_end` - span of the occupation mention, stored as given.
///
/// The pronoun slot must come after the occupation mention: the stored occupation
/// offsets are not shifted by the substitution. The pronoun span starts where the
/// placeholder was and is as long as the substituted pronoun. Every example is
/// created with `should_resolve: true`.
///
/// # Panics
///
/// Panics if `template_base` contains no `{}` placeholder.
fn add_occupation_examples(
    examples: &mut Vec<WinoBiasExample>,
    occupation: &str,
    stereotype: PronounGender,
    template_base: &str,
    occ_start: usize,
    occ_end: usize,
) {
    let pronoun_start = template_base
        .find("{}")
        .expect("template must contain placeholder");

    let (pro_pronoun, anti_pronoun, anti_gender) = match stereotype {
        PronounGender::Feminine => ("She", "He", PronounGender::Masculine),
        PronounGender::Masculine => ("He", "She", PronounGender::Feminine),
        PronounGender::Neutral => ("They", "They", PronounGender::Neutral),
    };

    let variants = [
        (pro_pronoun, StereotypeType::ProStereotypical, stereotype),
        (anti_pronoun, StereotypeType::AntiStereotypical, anti_gender),
        ("They", StereotypeType::Neutral, PronounGender::Neutral),
    ];

    for (pronoun, stereotype_type, pronoun_gender) in variants {
        examples.push(WinoBiasExample {
            text: template_base.replace("{}", pronoun),
            occupation: occupation.to_string(),
            pronoun: pronoun.to_lowercase(),
            occupation_start: occ_start,
            occupation_end: occ_end,
            pronoun_start,
            pronoun_end: pronoun_start + pronoun.len(),
            should_resolve: true,
            stereotype_type,
            pronoun_gender,
        });
    }
}
/// Builds neutral examples that pair neopronouns with occupations.
///
/// Every neopronoun (xe, ze, ey, fae) is combined with every sentence below, pronoun-major:
/// all "xe" sentences come first, then "ze", and so on (20 examples). Each example is
/// tagged `StereotypeType::Neutral` / `PronounGender::Neutral` with `should_resolve: true`.
///
/// # Panics
///
/// Panics if a built-in template lacks the `{}` pronoun slot.
pub fn create_neopronoun_templates() -> Vec<WinoBiasExample> {
    // (capitalised form used in the sentence, lowercase form stored as the label)
    const NEOPRONOUNS: &[(&str, &str)] = &[("Xe", "xe"), ("Ze", "ze"), ("Ey", "ey"), ("Fae", "fae")];
    // (occupation, sentence template, occupation start, occupation end)
    const SENTENCES: &[(&str, &str, usize, usize)] = &[
        ("artist", "The artist painted the mural. {} was very creative.", 4, 10),
        ("scientist", "The scientist ran the experiment. {} was very careful.", 4, 13),
        ("writer", "The writer finished the novel. {} was very dedicated.", 4, 10),
        ("chef", "The chef prepared the meal. {} was very talented.", 4, 8),
        ("pilot", "The pilot landed the plane. {} was very skilled.", 4, 9),
    ];

    NEOPRONOUNS
        .iter()
        .flat_map(|&(capitalized, lowercase)| {
            SENTENCES
                .iter()
                .map(move |&(job, template, job_start, job_end)| {
                    let sentence = template.replace("{}", capitalized);
                    let slot = template
                        .find("{}")
                        .expect("template must contain placeholder");
                    WinoBiasExample {
                        text: sentence,
                        occupation: job.to_string(),
                        pronoun: lowercase.to_string(),
                        occupation_start: job_start,
                        occupation_end: job_end,
                        pronoun_start: slot,
                        pronoun_end: slot + capitalized.len(),
                        should_resolve: true,
                        stereotype_type: StereotypeType::Neutral,
                        pronoun_gender: PronounGender::Neutral,
                    }
                })
        })
        .collect()
}
/// Returns the WinoBias templates followed by the neopronoun templates in one list.
pub fn create_comprehensive_bias_templates() -> Vec<WinoBiasExample> {
    create_winobias_templates()
        .into_iter()
        .chain(create_neopronoun_templates())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::coref_resolver::SimpleCorefResolver;

    /// Number of templates tagged with the given stereotype bucket.
    fn count_of(templates: &[WinoBiasExample], kind: StereotypeType) -> usize {
        templates
            .iter()
            .filter(|example| example.stereotype_type == kind)
            .count()
    }

    /// Scores `templates` with the default simple resolver in detailed mode.
    fn evaluate_detailed(templates: &[WinoBiasExample]) -> GenderBiasResults {
        let resolver = SimpleCorefResolver::default();
        GenderBiasEvaluator::new(true).evaluate_resolver(&resolver, templates)
    }

    /// Known keywords map to their stereotype; unknown occupations map to `None`.
    #[test]
    fn test_occupation_stereotype() {
        let cases = [
            ("nurse", Some(PronounGender::Feminine)),
            ("engineer", Some(PronounGender::Masculine)),
            ("artist", None),
        ];
        for (occupation, expected) in cases {
            assert_eq!(occupation_stereotype(occupation), expected);
        }
    }

    /// The template set is non-empty, balanced between pro and anti, and has neutral cases.
    #[test]
    fn test_create_templates() {
        let templates = create_winobias_templates();
        assert!(!templates.is_empty());

        let pro_count = count_of(&templates, StereotypeType::ProStereotypical);
        let anti_count = count_of(&templates, StereotypeType::AntiStereotypical);
        let neutral_count = count_of(&templates, StereotypeType::Neutral);

        assert_eq!(
            pro_count, anti_count,
            "Should have equal pro and anti examples"
        );
        assert!(neutral_count > 0, "Should have neutral examples");
    }

    /// The simple resolver should not show a large pro/anti accuracy gap.
    #[test]
    fn test_evaluator_no_bias() {
        let results = evaluate_detailed(&create_winobias_templates());

        println!(
            "Pro accuracy: {:.1}%",
            results.pro_stereotype_accuracy * 100.0
        );
        println!(
            "Anti accuracy: {:.1}%",
            results.anti_stereotype_accuracy * 100.0
        );
        println!("Bias gap: {:.1}%", results.bias_gap * 100.0);

        assert!(
            results.bias_gap < 0.3,
            "Bias gap should be <30% for debiased resolver, got {:.1}%",
            results.bias_gap * 100.0
        );
    }

    /// Per-pronoun accuracy is reported for each pronoun used by the templates.
    #[test]
    fn test_per_pronoun_metrics() {
        let results = evaluate_detailed(&create_winobias_templates());
        let reported = &results.per_pronoun;

        assert!(reported.contains_key("he"));
        assert!(reported.contains_key("she"));
        assert!(reported.contains_key("they"));
    }

    /// Every neopronoun is represented and every neopronoun example is neutral.
    #[test]
    fn test_neopronoun_templates() {
        let templates = create_neopronoun_templates();

        let mut pronouns = std::collections::HashSet::new();
        for example in &templates {
            pronouns.insert(example.pronoun.as_str());
        }
        assert!(pronouns.contains("xe"), "Should have xe examples");
        assert!(pronouns.contains("ze"), "Should have ze examples");
        assert!(pronouns.contains("ey"), "Should have ey examples");
        assert!(pronouns.contains("fae"), "Should have fae examples");

        templates.iter().for_each(|example| {
            assert_eq!(
                example.stereotype_type,
                StereotypeType::Neutral,
                "Neopronoun examples should be neutral"
            );
        });
    }

    /// The simple resolver should link neopronouns more often than not.
    #[test]
    fn test_neopronoun_resolution() {
        let results = evaluate_detailed(&create_neopronoun_templates());
        let accuracy = results.overall_accuracy;

        println!("Neopronoun accuracy: {:.1}%", accuracy * 100.0);
        assert!(
            accuracy > 0.5,
            "Neopronoun accuracy unexpectedly low: {:.1}%",
            accuracy * 100.0
        );
    }

    /// The comprehensive set is exactly the two component sets combined.
    #[test]
    fn test_comprehensive_templates() {
        let expected_len = create_winobias_templates().len() + create_neopronoun_templates().len();
        let templates = create_comprehensive_bias_templates();

        assert_eq!(
            templates.len(),
            expected_len,
            "Comprehensive should combine both sets"
        );
    }
}