1use bayesian::Classifier;
2
/// Label assigned to a message by the demo classifier.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Category {
    /// Unwanted message.
    Spam,
    /// Ordinary, wanted message.
    Ham,
}
8
/// Splits `text` on whitespace into normalized word tokens.
///
/// Each word is lowercased first, then stripped of leading and trailing
/// characters that are not alphanumeric. Characters in the interior of a word
/// (for example the apostrophe in `don't`) are kept. Words that are empty
/// after stripping are dropped.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();

    for word in text.split_whitespace() {
        // Lowercase before trimming: the trim predicate is applied to the
        // lowercased characters, not the original ones.
        let lowered = word.to_lowercase();
        let core = lowered.trim_matches(|ch: char| !ch.is_alphanumeric());

        if !core.is_empty() {
            tokens.push(core.to_owned());
        }
    }

    tokens
}
19
20fn main() {
21 use Category::*;
22
23 let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25 let training: &[(&str, Category)] = &[
26 ("Buy cheap pills online now limited offer", Spam),
27 ("You have won a free prize click here to claim", Spam),
28 ("Make money fast work from home guaranteed", Spam),
29 ("Exclusive deal buy now huge discount sale", Spam),
30 ("Free Viagra cheap meds no prescription needed", Spam),
31 ("Congratulations you are our lucky winner click", Spam),
32 ("Earn cash fast no experience required apply now", Spam),
33 ("Special offer buy one get one free today only", Spam),
34 ("Hey are you coming to the meeting tomorrow", Ham),
35 ("Can we schedule a call this afternoon to catch up", Ham),
36 ("The project report is due on Friday please review", Ham),
37 ("Let me know if you need anything from the store", Ham),
38 ("Dinner tonight at the usual place around seven", Ham),
39 ("Just wanted to check in hope everything is well", Ham),
40 (
41 "The presentation went really well thanks for your help",
42 Ham,
43 ),
44 ("Could you send me the notes from this morning", Ham),
45 ];
46
47 for (text, category) in training {
48 classifier.learn(&tokenize(text), category);
49 }
50
51 let test_cases: &[(&str, Category)] = &[
52 ("Free money click now buy cheap online offer", Spam),
53 ("Hey can we meet tomorrow morning for coffee", Ham),
54 ("You won a prize claim your free gift now", Spam),
55 ("The report looks good nice work on the project", Ham),
56 ("Limited time offer buy now get discount pills", Spam),
57 ("Just checking in let me know how you are doing", Ham),
58 ];
59
60 let spam_idx = classifier
61 .classes()
62 .iter()
63 .position(|c| c == &Spam)
64 .unwrap();
65 let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67 println!("Trained on {} documents\n", classifier.learned());
68 println!("{}", "-".repeat(60));
69
70 let mut correct = 0;
71
72 for (text, expected) in test_cases {
73 let tokens = tokenize(text);
74 let predicted = classifier.classify(&tokens);
75 let probs = classifier.prob_scores(&tokens);
76
77 if predicted == expected {
78 correct += 1;
79 }
80
81 println!(
82 "[{}] \"{text}\"",
83 if predicted == expected { "✓" } else { "✗" }
84 );
85 println!(
86 " predicted={predicted:?} expected={expected:?} spam={:.1}% ham={:.1}%\n",
87 probs[spam_idx] * 100.0,
88 probs[ham_idx] * 100.0,
89 );
90 }
91
92 println!("{}", "-".repeat(60));
93 println!("Accuracy: {correct}/{}", test_cases.len());
94}