Skip to main content

Classifier

Struct Classifier 

Source
pub struct Classifier<C> { /* private fields */ }
Expand description

A Naive Bayes classifier, optionally upgraded with TF-IDF weights.

Plain Naive Bayes is always available via [classify], [log_scores], and [prob_scores]. After calling [build_tfidf] the TF-IDF variants become available. You may call [build_tfidf] again at any time after learning more documents — it recomputes from scratch over all accumulated data.

Implementations§

Source§

impl<C: Eq + Hash + Clone> Classifier<C>

Source

pub fn new(classes: Vec<C>) -> Self

Creates a new classifier.

Panics if fewer than two classes are given, or if they are not unique.

Examples found in repository?
examples/names.rs (line 60)
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}
More examples
Hide additional examples
examples/spam.rs (line 23)
20fn main() {
21    use Category::*;
22
23    let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25    let training: &[(&str, Category)] = &[
26        ("Buy cheap pills online now limited offer", Spam),
27        ("You have won a free prize click here to claim", Spam),
28        ("Make money fast work from home guaranteed", Spam),
29        ("Exclusive deal buy now huge discount sale", Spam),
30        ("Free Viagra cheap meds no prescription needed", Spam),
31        ("Congratulations you are our lucky winner click", Spam),
32        ("Earn cash fast no experience required apply now", Spam),
33        ("Special offer buy one get one free today only", Spam),
34        ("Hey are you coming to the meeting tomorrow", Ham),
35        ("Can we schedule a call this afternoon to catch up", Ham),
36        ("The project report is due on Friday please review", Ham),
37        ("Let me know if you need anything from the store", Ham),
38        ("Dinner tonight at the usual place around seven", Ham),
39        ("Just wanted to check in hope everything is well", Ham),
40        (
41            "The presentation went really well thanks for your help",
42            Ham,
43        ),
44        ("Could you send me the notes from this morning", Ham),
45    ];
46
47    for (text, category) in training {
48        classifier.learn(&tokenize(text), category);
49    }
50
51    let test_cases: &[(&str, Category)] = &[
52        ("Free money click now buy cheap online offer", Spam),
53        ("Hey can we meet tomorrow morning for coffee", Ham),
54        ("You won a prize claim your free gift now", Spam),
55        ("The report looks good nice work on the project", Ham),
56        ("Limited time offer buy now get discount pills", Spam),
57        ("Just checking in let me know how you are doing", Ham),
58    ];
59
60    let spam_idx = classifier
61        .classes()
62        .iter()
63        .position(|c| c == &Spam)
64        .unwrap();
65    let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67    println!("Trained on {} documents\n", classifier.learned());
68    println!("{}", "-".repeat(60));
69
70    let mut correct = 0;
71
72    for (text, expected) in test_cases {
73        let tokens = tokenize(text);
74        let predicted = classifier.classify(&tokens);
75        let probs = classifier.prob_scores(&tokens);
76
77        if predicted == expected {
78            correct += 1;
79        }
80
81        println!(
82            "[{}] \"{text}\"",
83            if predicted == expected { "✓" } else { "✗" }
84        );
85        println!(
86            "    predicted={predicted:?}  expected={expected:?}  spam={:.1}%  ham={:.1}%\n",
87            probs[spam_idx] * 100.0,
88            probs[ham_idx] * 100.0,
89        );
90    }
91
92    println!("{}", "-".repeat(60));
93    println!("Accuracy: {correct}/{}", test_cases.len());
94}
Source

pub fn learn<S: AsRef<str>>(&mut self, document: &[S], class: &C)

Trains the classifier on a document (a slice of words) for the given class.

Accumulates both raw word counts (used by plain Naive Bayes) and per-document TF samples (used by [build_tfidf]). The two sets of data are kept separate, so calling [build_tfidf] never destroys the raw counts and plain classification always remains available.

Examples found in repository?
examples/names.rs (line 63)
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}
More examples
Hide additional examples
examples/spam.rs (line 48)
20fn main() {
21    use Category::*;
22
23    let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25    let training: &[(&str, Category)] = &[
26        ("Buy cheap pills online now limited offer", Spam),
27        ("You have won a free prize click here to claim", Spam),
28        ("Make money fast work from home guaranteed", Spam),
29        ("Exclusive deal buy now huge discount sale", Spam),
30        ("Free Viagra cheap meds no prescription needed", Spam),
31        ("Congratulations you are our lucky winner click", Spam),
32        ("Earn cash fast no experience required apply now", Spam),
33        ("Special offer buy one get one free today only", Spam),
34        ("Hey are you coming to the meeting tomorrow", Ham),
35        ("Can we schedule a call this afternoon to catch up", Ham),
36        ("The project report is due on Friday please review", Ham),
37        ("Let me know if you need anything from the store", Ham),
38        ("Dinner tonight at the usual place around seven", Ham),
39        ("Just wanted to check in hope everything is well", Ham),
40        (
41            "The presentation went really well thanks for your help",
42            Ham,
43        ),
44        ("Could you send me the notes from this morning", Ham),
45    ];
46
47    for (text, category) in training {
48        classifier.learn(&tokenize(text), category);
49    }
50
51    let test_cases: &[(&str, Category)] = &[
52        ("Free money click now buy cheap online offer", Spam),
53        ("Hey can we meet tomorrow morning for coffee", Ham),
54        ("You won a prize claim your free gift now", Spam),
55        ("The report looks good nice work on the project", Ham),
56        ("Limited time offer buy now get discount pills", Spam),
57        ("Just checking in let me know how you are doing", Ham),
58    ];
59
60    let spam_idx = classifier
61        .classes()
62        .iter()
63        .position(|c| c == &Spam)
64        .unwrap();
65    let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67    println!("Trained on {} documents\n", classifier.learned());
68    println!("{}", "-".repeat(60));
69
70    let mut correct = 0;
71
72    for (text, expected) in test_cases {
73        let tokens = tokenize(text);
74        let predicted = classifier.classify(&tokens);
75        let probs = classifier.prob_scores(&tokens);
76
77        if predicted == expected {
78            correct += 1;
79        }
80
81        println!(
82            "[{}] \"{text}\"",
83            if predicted == expected { "✓" } else { "✗" }
84        );
85        println!(
86            "    predicted={predicted:?}  expected={expected:?}  spam={:.1}%  ham={:.1}%\n",
87            probs[spam_idx] * 100.0,
88            probs[ham_idx] * 100.0,
89        );
90    }
91
92    println!("{}", "-".repeat(60));
93    println!("Accuracy: {correct}/{}", test_cases.len());
94}
Source

pub fn build_tfidf(&mut self)

Computes TF-IDF weights from all documents learned so far and stores them internally, enabling the `_tfidf`-suffixed family of classification methods.

Safe to call multiple times — each call recomputes from scratch over the full accumulated training set, so you can learn more documents and call this again to refresh the weights without losing anything.

IDF formula: ln(1 + N / df) where N is the total number of documents learned and df is the number of documents (across all classes) that contain the word.

Per-word weight: Σ ln(1 + tf) · idf, summed over every training document in the class that contains the word (tf is the word's count within that document).

Source

pub fn has_tfidf(&self) -> bool

Returns true if [build_tfidf] has been called and TF-IDF weights are ready for classification.

Source

pub fn learned(&self) -> usize

Returns the total number of documents the classifier has been trained on.

Examples found in repository?
examples/names.rs (line 112)
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}
More examples
Hide additional examples
examples/spam.rs (line 67)
20fn main() {
21    use Category::*;
22
23    let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25    let training: &[(&str, Category)] = &[
26        ("Buy cheap pills online now limited offer", Spam),
27        ("You have won a free prize click here to claim", Spam),
28        ("Make money fast work from home guaranteed", Spam),
29        ("Exclusive deal buy now huge discount sale", Spam),
30        ("Free Viagra cheap meds no prescription needed", Spam),
31        ("Congratulations you are our lucky winner click", Spam),
32        ("Earn cash fast no experience required apply now", Spam),
33        ("Special offer buy one get one free today only", Spam),
34        ("Hey are you coming to the meeting tomorrow", Ham),
35        ("Can we schedule a call this afternoon to catch up", Ham),
36        ("The project report is due on Friday please review", Ham),
37        ("Let me know if you need anything from the store", Ham),
38        ("Dinner tonight at the usual place around seven", Ham),
39        ("Just wanted to check in hope everything is well", Ham),
40        (
41            "The presentation went really well thanks for your help",
42            Ham,
43        ),
44        ("Could you send me the notes from this morning", Ham),
45    ];
46
47    for (text, category) in training {
48        classifier.learn(&tokenize(text), category);
49    }
50
51    let test_cases: &[(&str, Category)] = &[
52        ("Free money click now buy cheap online offer", Spam),
53        ("Hey can we meet tomorrow morning for coffee", Ham),
54        ("You won a prize claim your free gift now", Spam),
55        ("The report looks good nice work on the project", Ham),
56        ("Limited time offer buy now get discount pills", Spam),
57        ("Just checking in let me know how you are doing", Ham),
58    ];
59
60    let spam_idx = classifier
61        .classes()
62        .iter()
63        .position(|c| c == &Spam)
64        .unwrap();
65    let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67    println!("Trained on {} documents\n", classifier.learned());
68    println!("{}", "-".repeat(60));
69
70    let mut correct = 0;
71
72    for (text, expected) in test_cases {
73        let tokens = tokenize(text);
74        let predicted = classifier.classify(&tokens);
75        let probs = classifier.prob_scores(&tokens);
76
77        if predicted == expected {
78            correct += 1;
79        }
80
81        println!(
82            "[{}] \"{text}\"",
83            if predicted == expected { "✓" } else { "✗" }
84        );
85        println!(
86            "    predicted={predicted:?}  expected={expected:?}  spam={:.1}%  ham={:.1}%\n",
87            probs[spam_idx] * 100.0,
88            probs[ham_idx] * 100.0,
89        );
90    }
91
92    println!("{}", "-".repeat(60));
93    println!("Accuracy: {correct}/{}", test_cases.len());
94}
Source

pub fn classes(&self) -> &[C]

Returns the ordered slice of class labels used to build this classifier.

Examples found in repository?
examples/names.rs (line 101)
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}
More examples
Hide additional examples
examples/spam.rs (line 61)
20fn main() {
21    use Category::*;
22
23    let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25    let training: &[(&str, Category)] = &[
26        ("Buy cheap pills online now limited offer", Spam),
27        ("You have won a free prize click here to claim", Spam),
28        ("Make money fast work from home guaranteed", Spam),
29        ("Exclusive deal buy now huge discount sale", Spam),
30        ("Free Viagra cheap meds no prescription needed", Spam),
31        ("Congratulations you are our lucky winner click", Spam),
32        ("Earn cash fast no experience required apply now", Spam),
33        ("Special offer buy one get one free today only", Spam),
34        ("Hey are you coming to the meeting tomorrow", Ham),
35        ("Can we schedule a call this afternoon to catch up", Ham),
36        ("The project report is due on Friday please review", Ham),
37        ("Let me know if you need anything from the store", Ham),
38        ("Dinner tonight at the usual place around seven", Ham),
39        ("Just wanted to check in hope everything is well", Ham),
40        (
41            "The presentation went really well thanks for your help",
42            Ham,
43        ),
44        ("Could you send me the notes from this morning", Ham),
45    ];
46
47    for (text, category) in training {
48        classifier.learn(&tokenize(text), category);
49    }
50
51    let test_cases: &[(&str, Category)] = &[
52        ("Free money click now buy cheap online offer", Spam),
53        ("Hey can we meet tomorrow morning for coffee", Ham),
54        ("You won a prize claim your free gift now", Spam),
55        ("The report looks good nice work on the project", Ham),
56        ("Limited time offer buy now get discount pills", Spam),
57        ("Just checking in let me know how you are doing", Ham),
58    ];
59
60    let spam_idx = classifier
61        .classes()
62        .iter()
63        .position(|c| c == &Spam)
64        .unwrap();
65    let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67    println!("Trained on {} documents\n", classifier.learned());
68    println!("{}", "-".repeat(60));
69
70    let mut correct = 0;
71
72    for (text, expected) in test_cases {
73        let tokens = tokenize(text);
74        let predicted = classifier.classify(&tokens);
75        let probs = classifier.prob_scores(&tokens);
76
77        if predicted == expected {
78            correct += 1;
79        }
80
81        println!(
82            "[{}] \"{text}\"",
83            if predicted == expected { "✓" } else { "✗" }
84        );
85        println!(
86            "    predicted={predicted:?}  expected={expected:?}  spam={:.1}%  ham={:.1}%\n",
87            probs[spam_idx] * 100.0,
88            probs[ham_idx] * 100.0,
89        );
90    }
91
92    println!("{}", "-".repeat(60));
93    println!("Accuracy: {correct}/{}", test_cases.len());
94}
Source

pub fn log_scores<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64>

Returns the log-likelihood score for each class using raw word counts. Index i corresponds to classes()[i].

Source

pub fn prob_scores<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64>

Returns normalised probability scores for each class using raw word counts; the scores sum to 1.0. Index i corresponds to classes()[i].

Examples found in repository?
examples/names.rs (line 120)
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}
More examples
Hide additional examples
examples/spam.rs (line 75)
20fn main() {
21    use Category::*;
22
23    let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25    let training: &[(&str, Category)] = &[
26        ("Buy cheap pills online now limited offer", Spam),
27        ("You have won a free prize click here to claim", Spam),
28        ("Make money fast work from home guaranteed", Spam),
29        ("Exclusive deal buy now huge discount sale", Spam),
30        ("Free Viagra cheap meds no prescription needed", Spam),
31        ("Congratulations you are our lucky winner click", Spam),
32        ("Earn cash fast no experience required apply now", Spam),
33        ("Special offer buy one get one free today only", Spam),
34        ("Hey are you coming to the meeting tomorrow", Ham),
35        ("Can we schedule a call this afternoon to catch up", Ham),
36        ("The project report is due on Friday please review", Ham),
37        ("Let me know if you need anything from the store", Ham),
38        ("Dinner tonight at the usual place around seven", Ham),
39        ("Just wanted to check in hope everything is well", Ham),
40        (
41            "The presentation went really well thanks for your help",
42            Ham,
43        ),
44        ("Could you send me the notes from this morning", Ham),
45    ];
46
47    for (text, category) in training {
48        classifier.learn(&tokenize(text), category);
49    }
50
51    let test_cases: &[(&str, Category)] = &[
52        ("Free money click now buy cheap online offer", Spam),
53        ("Hey can we meet tomorrow morning for coffee", Ham),
54        ("You won a prize claim your free gift now", Spam),
55        ("The report looks good nice work on the project", Ham),
56        ("Limited time offer buy now get discount pills", Spam),
57        ("Just checking in let me know how you are doing", Ham),
58    ];
59
60    let spam_idx = classifier
61        .classes()
62        .iter()
63        .position(|c| c == &Spam)
64        .unwrap();
65    let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67    println!("Trained on {} documents\n", classifier.learned());
68    println!("{}", "-".repeat(60));
69
70    let mut correct = 0;
71
72    for (text, expected) in test_cases {
73        let tokens = tokenize(text);
74        let predicted = classifier.classify(&tokens);
75        let probs = classifier.prob_scores(&tokens);
76
77        if predicted == expected {
78            correct += 1;
79        }
80
81        println!(
82            "[{}] \"{text}\"",
83            if predicted == expected { "✓" } else { "✗" }
84        );
85        println!(
86            "    predicted={predicted:?}  expected={expected:?}  spam={:.1}%  ham={:.1}%\n",
87            probs[spam_idx] * 100.0,
88            probs[ham_idx] * 100.0,
89        );
90    }
91
92    println!("{}", "-".repeat(60));
93    println!("Accuracy: {correct}/{}", test_cases.len());
94}
Source

pub fn classify<S: AsRef<str>>(&self, document: &[S]) -> &C

Returns the most likely class for the given document using plain Naive Bayes.

Examples found in repository?
examples/names.rs (line 119)
57fn main() {
58    use Sex::*;
59
60    let mut classifier = Classifier::new(vec![Male, Female]);
61
62    for (name, sex) in training_data() {
63        classifier.learn(&featurize(&name), &sex);
64    }
65
66    let test_cases: &[(&str, Sex)] = &[
67        // Unambiguous male
68        ("james", Male),
69        ("michael", Male),
70        ("bob", Male),
71        ("joe", Male),
72        ("corey", Male),
73        ("tyler", Male),
74        ("brad", Male),
75        ("chad", Male),
76        ("derek", Male),
77        ("lance", Male),
78        // Unambiguous female
79        ("sara", Female),
80        ("whitney", Female),
81        ("jessica", Female),
82        ("emily", Female),
83        ("tiffany", Female),
84        ("brittany", Female),
85        ("amanda", Female),
86        ("natalie", Female),
87        // Genuinely ambiguous / androgynous
88        ("sam", Male),
89        ("alex", Male),
90        ("morgan", Female),
91        ("taylor", Female),
92        ("jordan", Male),
93        ("casey", Male),
94        ("riley", Female),
95        ("charlie", Male),
96        ("robin", Female),
97        ("drew", Male),
98    ];
99
100    let male_idx = classifier
101        .classes()
102        .iter()
103        .position(|c| c == &Male)
104        .unwrap();
105
106    let female_idx = classifier
107        .classes()
108        .iter()
109        .position(|c| c == &Female)
110        .unwrap();
111
112    println!("Trained on {} names\n", classifier.learned());
113    println!("{}", "-".repeat(56));
114
115    let mut correct = 0;
116
117    for (name, expected) in test_cases {
118        let features = featurize(name);
119        let predicted = classifier.classify(&features);
120        let probs = classifier.prob_scores(&features);
121
122        if predicted == expected {
123            correct += 1;
124        }
125
126        println!(
127            "[{}] {name:<12} predicted={predicted:?}  expected={expected:?}  male={:.1}%  female={:.1}%",
128            if predicted == expected { "✓" } else { "✗" },
129            probs[male_idx] * 100.0,
130            probs[female_idx] * 100.0,
131        );
132    }
133
134    println!("{}", "-".repeat(56));
135    println!(
136        "Accuracy: {correct}/{}  ({:.0}%)",
137        test_cases.len(),
138        correct as f64 / test_cases.len() as f64 * 100.0
139    );
140}
More examples
Hide additional examples
examples/spam.rs (line 74)
20fn main() {
21    use Category::*;
22
23    let mut classifier = Classifier::new(vec![Spam, Ham]);
24
25    let training: &[(&str, Category)] = &[
26        ("Buy cheap pills online now limited offer", Spam),
27        ("You have won a free prize click here to claim", Spam),
28        ("Make money fast work from home guaranteed", Spam),
29        ("Exclusive deal buy now huge discount sale", Spam),
30        ("Free Viagra cheap meds no prescription needed", Spam),
31        ("Congratulations you are our lucky winner click", Spam),
32        ("Earn cash fast no experience required apply now", Spam),
33        ("Special offer buy one get one free today only", Spam),
34        ("Hey are you coming to the meeting tomorrow", Ham),
35        ("Can we schedule a call this afternoon to catch up", Ham),
36        ("The project report is due on Friday please review", Ham),
37        ("Let me know if you need anything from the store", Ham),
38        ("Dinner tonight at the usual place around seven", Ham),
39        ("Just wanted to check in hope everything is well", Ham),
40        (
41            "The presentation went really well thanks for your help",
42            Ham,
43        ),
44        ("Could you send me the notes from this morning", Ham),
45    ];
46
47    for (text, category) in training {
48        classifier.learn(&tokenize(text), category);
49    }
50
51    let test_cases: &[(&str, Category)] = &[
52        ("Free money click now buy cheap online offer", Spam),
53        ("Hey can we meet tomorrow morning for coffee", Ham),
54        ("You won a prize claim your free gift now", Spam),
55        ("The report looks good nice work on the project", Ham),
56        ("Limited time offer buy now get discount pills", Spam),
57        ("Just checking in let me know how you are doing", Ham),
58    ];
59
60    let spam_idx = classifier
61        .classes()
62        .iter()
63        .position(|c| c == &Spam)
64        .unwrap();
65    let ham_idx = classifier.classes().iter().position(|c| c == &Ham).unwrap();
66
67    println!("Trained on {} documents\n", classifier.learned());
68    println!("{}", "-".repeat(60));
69
70    let mut correct = 0;
71
72    for (text, expected) in test_cases {
73        let tokens = tokenize(text);
74        let predicted = classifier.classify(&tokens);
75        let probs = classifier.prob_scores(&tokens);
76
77        if predicted == expected {
78            correct += 1;
79        }
80
81        println!(
82            "[{}] \"{text}\"",
83            if predicted == expected { "✓" } else { "✗" }
84        );
85        println!(
86            "    predicted={predicted:?}  expected={expected:?}  spam={:.1}%  ham={:.1}%\n",
87            probs[spam_idx] * 100.0,
88            probs[ham_idx] * 100.0,
89        );
90    }
91
92    println!("{}", "-".repeat(60));
93    println!("Accuracy: {correct}/{}", test_cases.len());
94}
Source

pub fn log_scores_tfidf<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64>

Returns the log-likelihood score for each class using TF-IDF weights. Index i corresponds to classes()[i].

Panics if [build_tfidf] has not been called yet.

Source

pub fn prob_scores_tfidf<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64>

Returns normalised probability scores for each class using TF-IDF weights; the scores sum to 1.0. Index i corresponds to classes()[i].

Panics if [build_tfidf] has not been called yet.

Source

pub fn classify_tfidf<S: AsRef<str>>(&self, document: &[S]) -> &C

Returns the most likely class for the given document using TF-IDF weights.

Panics if [build_tfidf] has not been called yet.

Source

pub fn serialize(&self) -> Result<Vec<u8>>
where C: Serialize,

Serializes the classifier (including any built TF-IDF weights) to a compressed binary blob (bincode + gzip).

The resulting bytes can be stored on disk, sent over the network, and so on. Pass them back to Classifier::from_data to reconstruct an identical classifier.

Source

pub fn from_data(data: impl AsRef<[u8]>) -> Result<Self>
where C: for<'de> Deserialize<'de>,

Reconstructs a Classifier from bytes previously produced by Classifier::serialize.

Source

pub fn serialize_to_file(&self, path: impl AsRef<Path>) -> Result<()>
where C: Serialize,

Writes the serialized classifier to a file at the given path.

Creates the file if it does not exist, truncates it if it does.

Source

pub fn from_file(path: impl AsRef<Path>) -> Result<Self>
where C: for<'de> Deserialize<'de>,

Reconstructs a Classifier from a file previously written by Classifier::serialize_to_file.

Trait Implementations§

Source§

impl<'de, C> Deserialize<'de> for Classifier<C>
where C: Deserialize<'de> + Eq + Hash,

Source§

fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error>
where __D: Deserializer<'de>,

Deserialize this value from the given Serde deserializer. Read more
Source§

impl<C> Serialize for Classifier<C>
where C: Serialize,

Source§

fn serialize<__S>(&self, __serializer: __S) -> Result<__S::Ok, __S::Error>
where __S: Serializer,

Serialize this value into the given Serde serializer. Read more

Auto Trait Implementations§

§

impl<C> Freeze for Classifier<C>

§

impl<C> RefUnwindSafe for Classifier<C>
where C: RefUnwindSafe,

§

impl<C> Send for Classifier<C>
where C: Send,

§

impl<C> Sync for Classifier<C>
where C: Sync,

§

impl<C> Unpin for Classifier<C>
where C: Unpin,

§

impl<C> UnsafeUnpin for Classifier<C>

§

impl<C> UnwindSafe for Classifier<C>
where C: UnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,