1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3use std::io::{Read, Write};
4
5use flate2::{Compression, read::GzDecoder, write::GzEncoder};
6use serde::{Deserialize, Serialize};
7
/// Fallback probability returned by `word_prob` when a class has no learned
/// data, so downstream `ln()` never sees 0.
const DEFAULT_PROB: f64 = 1e-11;
9
/// Raw per-class statistics collected by `Classifier::learn`.
#[derive(Serialize, Deserialize)]
struct ClassData {
    // Word -> number of occurrences seen for this class (stored as f64 so it
    // plugs directly into the smoothing arithmetic in `word_prob`).
    freqs: HashMap<String, f64>,
    // Total word occurrences learned for this class.
    total: usize,
    // Word -> TF accumulator, consumed later by `Classifier::build_tfidf`.
    tf_acc: HashMap<String, TfAccumulator>,
}
20
21impl ClassData {
22 fn new() -> Self {
23 Self {
24 freqs: HashMap::new(),
25 total: 0,
26 tf_acc: HashMap::new(),
27 }
28 }
29
30 fn word_prob(&self, word: &str) -> f64 {
31 let vocab = self.freqs.len();
32 if self.total == 0 || vocab == 0 {
33 return DEFAULT_PROB;
34 }
35 let count = self.freqs.get(word).copied().unwrap_or(0.0);
36 (count + 1.0) / (self.total as f64 + vocab as f64)
37 }
38}
39
40#[derive(Serialize, Deserialize)]
41struct TfAccumulator {
42 count: usize,
43 sum_ln1p_tf: f64,
44}
45
46impl Default for TfAccumulator {
47 fn default() -> Self {
48 Self {
49 count: 0,
50 sum_ln1p_tf: 0.0,
51 }
52 }
53}
54
/// Per-class TF-IDF weight table produced by `Classifier::build_tfidf`.
#[derive(Serialize, Deserialize)]
struct TfIdfClassData {
    // Word -> TF-IDF weight for this class.
    weights: HashMap<String, f64>,
    // Sum of all weights; smoothing denominator in `word_prob`.
    total: f64,
}
60
61impl TfIdfClassData {
62 fn word_prob(&self, word: &str) -> f64 {
63 let vocab = self.weights.len();
64 if self.total == 0.0 || vocab == 0 {
65 return DEFAULT_PROB;
66 }
67 let weight = self.weights.get(word).copied().unwrap_or(0.0);
68 (weight + 1.0) / (self.total + vocab as f64)
69 }
70}
71
/// Naive Bayes text classifier with optional TF-IDF weighting, generic over
/// the class label type `C`.
// The serde bound overrides the default derive bounds: deserialization needs
// `Eq + Hash` to rebuild the `HashMap<C, _>` fields.
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "C: Serialize",
    deserialize = "C: Deserialize<'de> + Eq + Hash"
))]
pub struct Classifier<C> {
    // Labels in the order supplied to `new`; all score vectors follow it.
    classes: Vec<C>,
    // Per-class raw counts and TF accumulators.
    data: HashMap<C, ClassData>,
    // Built by `build_tfidf`; `None` until then.
    tfidf: Option<HashMap<C, TfIdfClassData>>,
    // Number of documents passed to `learn`.
    learned: usize,
}
93
94impl<C: Eq + Hash + Clone> Classifier<C> {
95 pub fn new(classes: Vec<C>) -> Self {
99 assert!(classes.len() >= 2, "provide at least two classes");
100
101 let mut seen = HashSet::new();
102 for c in &classes {
103 assert!(seen.insert(c), "class labels must be unique");
104 }
105
106 let data = classes
107 .iter()
108 .map(|c| (c.clone(), ClassData::new()))
109 .collect();
110
111 Self {
112 classes,
113 data,
114 tfidf: None,
115 learned: 0,
116 }
117 }
118
119 pub fn learn<S: AsRef<str>>(&mut self, document: &[S], class: &C) {
126 let entry = self.data.get_mut(class).expect("unknown class");
127
128 for word in document {
130 *entry.freqs.entry(word.as_ref().to_owned()).or_default() += 1.0;
131 entry.total += 1;
132 }
133
134 let doc_len = document.len() as f64;
137 if doc_len > 0.0 {
138 let mut counts: HashMap<&str, usize> = HashMap::new();
139 for word in document {
140 *counts.entry(word.as_ref()).or_default() += 1;
141 }
142 for (word, count) in counts {
143 let tf = count as f64 / doc_len;
144 let ln1p = tf.ln_1p(); let acc = entry.tf_acc.entry(word.to_owned()).or_default();
146 acc.count += 1;
147 acc.sum_ln1p_tf += ln1p;
148 }
149 }
150
151 self.learned += 1;
152 }
153
154 pub fn build_tfidf(&mut self) {
161 let total_docs = self.learned as f64;
162
163 let mut doc_freq: HashMap<&str, usize> = HashMap::new();
167 for class_data in self.data.values() {
168 for (word, acc) in &class_data.tf_acc {
169 *doc_freq.entry(word.as_str()).or_default() += acc.count;
170 }
171 }
172
173 let mut tfidf: HashMap<C, TfIdfClassData> = HashMap::new();
174 for (class, class_data) in &self.data {
175 let mut weights: HashMap<String, f64> = HashMap::new();
176 let mut total = 0.0_f64;
177 for (word, acc) in &class_data.tf_acc {
178 let df = doc_freq[word.as_str()] as f64;
179 let idf = (1.0 + total_docs / df).ln();
180 let weight: f64 = acc.sum_ln1p_tf * idf;
182 weights.insert(word.clone(), weight);
183 total += weight;
184 }
185 tfidf.insert(class.clone(), TfIdfClassData { weights, total });
186 }
187
188 self.tfidf = Some(tfidf);
189 }
190
    /// Whether `build_tfidf` has been called (TF-IDF scoring available).
    pub fn has_tfidf(&self) -> bool {
        self.tfidf.is_some()
    }

    /// Number of documents learned so far.
    pub fn learned(&self) -> usize {
        self.learned
    }

    /// The class labels, in the order given to `new`.
    pub fn classes(&self) -> &[C] {
        &self.classes
    }
206
207 fn priors(&self) -> Vec<f64> {
210 let n = self.classes.len() as f64;
211 let total: f64 = self.data.values().map(|d| d.total as f64).sum();
212 self.classes
213 .iter()
214 .map(|c| (self.data[c].total as f64 + 1.0) / (total + n))
215 .collect()
216 }
217
218 fn find_max(scores: &[f64]) -> usize {
219 scores
220 .iter()
221 .enumerate()
222 .max_by(|(_, a), (_, b)| a.total_cmp(b))
223 .map(|(i, _)| i)
224 .expect("classifier has no classes")
225 }
226
227 fn log_to_probs(log_scores: &[f64]) -> Vec<f64> {
228 let max = log_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
231 let exps: Vec<f64> = log_scores.iter().map(|&s| (s - max).exp()).collect();
232 let sum: f64 = exps.iter().sum();
233 exps.iter().map(|&e| e / sum).collect()
234 }
235
236 fn score_document<S, F>(&self, document: &[S], word_prob: F) -> Vec<f64>
239 where
240 S: AsRef<str>,
241 F: Fn(&C, &str) -> f64,
242 {
243 self.priors()
244 .into_iter()
245 .zip(&self.classes)
246 .map(|(prior, class)| {
247 document.iter().fold(prior.ln(), |score, word| {
248 score + word_prob(class, word.as_ref()).ln()
249 })
250 })
251 .collect()
252 }
253
    /// Unnormalized per-class log-scores (plain naive Bayes), ordered like
    /// `classes()`.
    pub fn log_scores<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
        self.score_document(document, |class, word| self.data[class].word_prob(word))
    }

    /// Softmax-normalized per-class probabilities, ordered like `classes()`.
    pub fn prob_scores<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
        Self::log_to_probs(&self.log_scores(document))
    }

    /// The highest-scoring class for `document` (plain naive Bayes).
    pub fn classify<S: AsRef<str>>(&self, document: &[S]) -> &C {
        &self.classes[Self::find_max(&self.log_scores(document))]
    }
272
    /// Unnormalized per-class log-scores using TF-IDF weights, ordered like
    /// `classes()`.
    ///
    /// # Panics
    /// Panics if `build_tfidf` has not been called.
    pub fn log_scores_tfidf<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
        let tfidf = self
            .tfidf
            .as_ref()
            .expect("call build_tfidf() before using TF-IDF classification");
        self.score_document(document, |class, word| tfidf[class].word_prob(word))
    }

    /// Softmax-normalized TF-IDF probabilities per class, ordered like
    /// `classes()`.
    ///
    /// # Panics
    /// Panics if `build_tfidf` has not been called.
    pub fn prob_scores_tfidf<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
        Self::log_to_probs(&self.log_scores_tfidf(document))
    }

    /// The highest-scoring class for `document` using TF-IDF weights.
    ///
    /// # Panics
    /// Panics if `build_tfidf` has not been called.
    pub fn classify_tfidf<S: AsRef<str>>(&self, document: &[S]) -> &C {
        &self.classes[Self::find_max(&self.log_scores_tfidf(document))]
    }
301
302 pub fn serialize(&self) -> std::io::Result<Vec<u8>>
310 where
311 C: Serialize,
312 {
313 let bytes = bincode::serialize(self)
314 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
315 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
316 encoder.write_all(&bytes)?;
317 encoder.finish()
318 }
319
320 pub fn from_data(data: impl AsRef<[u8]>) -> std::io::Result<Self>
323 where
324 C: for<'de> Deserialize<'de>,
325 {
326 let mut decoder = GzDecoder::new(data.as_ref());
327 let mut bytes = Vec::new();
328 decoder.read_to_end(&mut bytes)?;
329 bincode::deserialize(&bytes)
330 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
331 }
332
    /// Serializes (gzip-compressed bincode) and writes the result to `path`.
    pub fn serialize_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()>
    where
        C: Serialize,
    {
        std::fs::write(path, self.serialize()?)
    }

    /// Loads a classifier previously written by `serialize_to_file`.
    pub fn from_file(path: impl AsRef<std::path::Path>) -> std::io::Result<Self>
    where
        C: for<'de> Deserialize<'de>,
    {
        Self::from_data(std::fs::read(path)?)
    }
351}
352
353#[cfg(test)]
354mod tests;