Skip to main content

names/
names.rs

1use bayesian::Classifier;
2
/// The two class labels the classifier is trained to tell apart.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Sex {
    /// Label written as `male` in the training data.
    Male,
    /// Label written as `female` in the training data.
    Female,
}

impl std::str::FromStr for Sex {
    type Err = String;

    /// Parses a textual label into a [`Sex`].
    ///
    /// Surrounding whitespace is ignored and the comparison is ASCII
    /// case-insensitive, so `"male"`, `"Male"` and `" MALE "` all parse to
    /// [`Sex::Male`]. Without this tolerance a hand-edited data row such as
    /// `Alice, Female` would be rejected and silently dropped by callers that
    /// discard parse errors.
    ///
    /// # Errors
    ///
    /// Returns `unknown sex: <input>` (echoing the input exactly as given)
    /// when the label is neither `male` nor `female`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let label = s.trim();
        if label.eq_ignore_ascii_case("male") {
            Ok(Self::Male)
        } else if label.eq_ignore_ascii_case("female") {
            Ok(Self::Female)
        } else {
            Err(format!("unknown sex: {s}"))
        }
    }
}
20
/// Turns a name into the character-level features fed to the classifier.
///
/// The name is lowercased first. The returned features are, in order:
///   - `suffix:<s>` for the last 1, 2, and 3 characters (each one only if the
///     name has at least that many characters)
///   - `prefix:<c>` for the first character (only if the name is non-empty)
///
/// An empty name yields an empty vector.
fn featurize(name: &str) -> Vec<String> {
    // Work on chars rather than bytes so multi-byte characters are never split.
    let letters: Vec<char> = name.to_lowercase().chars().collect();

    let suffixes = (1..=3usize)
        .filter(|&width| width <= letters.len())
        .map(|width| {
            let tail: String = letters[letters.len() - width..].iter().collect();
            format!("suffix:{tail}")
        });

    let prefix = letters.first().map(|initial| format!("prefix:{initial}"));

    // `Option` is a zero-or-one-item iterator, so the prefix lands last.
    suffixes.chain(prefix).collect()
}
43
44/// Parses the embedded CSV and returns `(name, sex)` pairs, skipping the header
45/// and any malformed rows.
46fn training_data() -> Vec<(String, Sex)> {
47    const CSV: &str = include_str!("names.csv");
48    CSV.lines()
49        .skip(1) // header
50        .filter_map(|line| {
51            let (name, sex) = line.split_once(',')?;
52            Some((name.to_owned(), sex.parse().ok()?))
53        })
54        .collect()
55}
56
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}