// bayesian/lib.rs

1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3use std::io::{Read, Write};
4
5use flate2::{Compression, read::GzDecoder, write::GzEncoder};
6use serde::{Deserialize, Serialize};
7
/// Probability assigned to every word by a class that has no usable training
/// data. It is small but non-zero, so taking its logarithm while scoring stays finite.
const DEFAULT_PROB: f64 = 1.0e-11;
9
// ---------------------------------------------------------------------------
// Private data structures
// ---------------------------------------------------------------------------
13
14#[derive(Serialize, Deserialize)]
15struct ClassData {
16    freqs: HashMap<String, f64>,
17    total: usize,
18    tf_acc: HashMap<String, TfAccumulator>,
19}
20
21impl ClassData {
22    fn new() -> Self {
23        Self {
24            freqs: HashMap::new(),
25            total: 0,
26            tf_acc: HashMap::new(),
27        }
28    }
29
30    fn word_prob(&self, word: &str) -> f64 {
31        let vocab = self.freqs.len();
32        if self.total == 0 || vocab == 0 {
33            return DEFAULT_PROB;
34        }
35        let count = self.freqs.get(word).copied().unwrap_or(0.0);
36        (count + 1.0) / (self.total as f64 + vocab as f64)
37    }
38}
39
40#[derive(Serialize, Deserialize)]
41struct TfAccumulator {
42    count: usize,
43    sum_ln1p_tf: f64,
44}
45
46impl Default for TfAccumulator {
47    fn default() -> Self {
48        Self {
49            count: 0,
50            sum_ln1p_tf: 0.0,
51        }
52    }
53}
54
55#[derive(Serialize, Deserialize)]
56struct TfIdfClassData {
57    weights: HashMap<String, f64>,
58    total: f64,
59}
60
61impl TfIdfClassData {
62    fn word_prob(&self, word: &str) -> f64 {
63        let vocab = self.weights.len();
64        if self.total == 0.0 || vocab == 0 {
65            return DEFAULT_PROB;
66        }
67        let weight = self.weights.get(word).copied().unwrap_or(0.0);
68        (weight + 1.0) / (self.total + vocab as f64)
69    }
70}
71
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
75
76/// A Naive Bayes classifier, optionally upgraded with TF-IDF weights.
77///
78/// Plain Naive Bayes is always available via [`classify`], [`log_scores`], and
79/// [`prob_scores`]. After calling [`build_tfidf`] the TF-IDF variants become
80/// available. You may call [`build_tfidf`] again at any time after learning
81/// more documents — it recomputes from scratch over all accumulated data.
82#[derive(Serialize, Deserialize)]
83#[serde(bound(
84    serialize = "C: Serialize",
85    deserialize = "C: Deserialize<'de> + Eq + Hash"
86))]
87pub struct Classifier<C> {
88    classes: Vec<C>,
89    data: HashMap<C, ClassData>,
90    tfidf: Option<HashMap<C, TfIdfClassData>>,
91    learned: usize,
92}
93
94impl<C: Eq + Hash + Clone> Classifier<C> {
95    /// Creates a new classifier.
96    ///
97    /// Panics if fewer than two classes are given, or if they are not unique.
98    pub fn new(classes: Vec<C>) -> Self {
99        assert!(classes.len() >= 2, "provide at least two classes");
100
101        let mut seen = HashSet::new();
102        for c in &classes {
103            assert!(seen.insert(c), "class labels must be unique");
104        }
105
106        let data = classes
107            .iter()
108            .map(|c| (c.clone(), ClassData::new()))
109            .collect();
110
111        Self {
112            classes,
113            data,
114            tfidf: None,
115            learned: 0,
116        }
117    }
118
119    /// Trains the classifier on a document (a slice of words) for the given class.
120    ///
121    /// Accumulates both raw word counts (used by plain Naive Bayes) and
122    /// compact TF accumulators (used by [`build_tfidf`]). The two sets of data
123    /// are kept separate, so calling [`build_tfidf`] never destroys the raw
124    /// counts and plain classification always remains available.
125    pub fn learn<S: AsRef<str>>(&mut self, document: &[S], class: &C) {
126        let entry = self.data.get_mut(class).expect("unknown class");
127
128        // Plain NB: raw word counts.
129        for word in document {
130            *entry.freqs.entry(word.as_ref().to_owned()).or_default() += 1.0;
131            entry.total += 1;
132        }
133
134        // TF-IDF: one TF sample per unique word in this document, but stored
135        // compactly as accumulators instead of vectors.
136        let doc_len = document.len() as f64;
137        if doc_len > 0.0 {
138            let mut counts: HashMap<&str, usize> = HashMap::new();
139            for word in document {
140                *counts.entry(word.as_ref()).or_default() += 1;
141            }
142            for (word, count) in counts {
143                let tf = count as f64 / doc_len;
144                let ln1p = tf.ln_1p(); // ln(1 + tf)
145                let acc = entry.tf_acc.entry(word.to_owned()).or_default();
146                acc.count += 1;
147                acc.sum_ln1p_tf += ln1p;
148            }
149        }
150
151        self.learned += 1;
152    }
153
154    /// Computes TF-IDF weights from all documents learned so far and stores
155    /// them internally, enabling the `_tfidf` family of classification methods.
156    ///
157    /// Safe to call multiple times — each call recomputes from scratch over the
158    /// full accumulated training set, so you can learn more documents and call
159    /// this again to refresh the weights without losing anything.
160    pub fn build_tfidf(&mut self) {
161        let total_docs = self.learned as f64;
162
163        // Global document frequency: how many docs across all classes contain
164        // each word. Each accumulator's `count` represents the number of docs
165        // in that class that contain the word.
166        let mut doc_freq: HashMap<&str, usize> = HashMap::new();
167        for class_data in self.data.values() {
168            for (word, acc) in &class_data.tf_acc {
169                *doc_freq.entry(word.as_str()).or_default() += acc.count;
170            }
171        }
172
173        let mut tfidf: HashMap<C, TfIdfClassData> = HashMap::new();
174        for (class, class_data) in &self.data {
175            let mut weights: HashMap<String, f64> = HashMap::new();
176            let mut total = 0.0_f64;
177            for (word, acc) in &class_data.tf_acc {
178                let df = doc_freq[word.as_str()] as f64;
179                let idf = (1.0 + total_docs / df).ln();
180                // Weight for this class is idf * Σ ln(1 + tf) over docs in class
181                let weight: f64 = acc.sum_ln1p_tf * idf;
182                weights.insert(word.clone(), weight);
183                total += weight;
184            }
185            tfidf.insert(class.clone(), TfIdfClassData { weights, total });
186        }
187
188        self.tfidf = Some(tfidf);
189    }
190
191    /// Returns `true` if [`build_tfidf`] has been called and TF-IDF weights
192    /// are ready for classification.
193    pub fn has_tfidf(&self) -> bool {
194        self.tfidf.is_some()
195    }
196
197    /// Returns the total number of documents the classifier has been trained on.
198    pub fn learned(&self) -> usize {
199        self.learned
200    }
201
202    /// Returns the ordered slice of class labels used to build this classifier.
203    pub fn classes(&self) -> &[C] {
204        &self.classes
205    }
206
207    // --- Shared private helpers ---
208
209    fn priors(&self) -> Vec<f64> {
210        let n = self.classes.len() as f64;
211        let total: f64 = self.data.values().map(|d| d.total as f64).sum();
212        self.classes
213            .iter()
214            .map(|c| (self.data[c].total as f64 + 1.0) / (total + n))
215            .collect()
216    }
217
218    fn find_max(scores: &[f64]) -> usize {
219        scores
220            .iter()
221            .enumerate()
222            .max_by(|(_, a), (_, b)| a.total_cmp(b))
223            .map(|(i, _)| i)
224            .expect("classifier has no classes")
225    }
226
227    fn log_to_probs(log_scores: &[f64]) -> Vec<f64> {
228        // Log-sum-exp trick: shift by max before exponentiating to prevent
229        // overflow/underflow for any document length.
230        let max = log_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
231        let exps: Vec<f64> = log_scores.iter().map(|&s| (s - max).exp()).collect();
232        let sum: f64 = exps.iter().sum();
233        exps.iter().map(|&e| e / sum).collect()
234    }
235
236    /// Computes a log-score for each class given a document and a per-class
237    /// word probability function. Shared by both NB and TF-IDF scoring paths.
238    fn score_document<S, F>(&self, document: &[S], word_prob: F) -> Vec<f64>
239    where
240        S: AsRef<str>,
241        F: Fn(&C, &str) -> f64,
242    {
243        self.priors()
244            .into_iter()
245            .zip(&self.classes)
246            .map(|(prior, class)| {
247                document.iter().fold(prior.ln(), |score, word| {
248                    score + word_prob(class, word.as_ref()).ln()
249                })
250            })
251            .collect()
252    }
253
254    // --- Plain Naive Bayes ---
255
256    /// Returns the log-likelihood score for each class using raw word counts.
257    /// Index `i` corresponds to `classes()[i]`.
258    pub fn log_scores<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
259        self.score_document(document, |class, word| self.data[class].word_prob(word))
260    }
261
262    /// Returns normalised probability scores for each class using raw word
263    /// counts (sum to 1.0). Index `i` corresponds to `classes()[i]`.
264    pub fn prob_scores<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
265        Self::log_to_probs(&self.log_scores(document))
266    }
267
268    /// Returns the most likely class for the given document using plain Naive Bayes.
269    pub fn classify<S: AsRef<str>>(&self, document: &[S]) -> &C {
270        &self.classes[Self::find_max(&self.log_scores(document))]
271    }
272
273    // --- TF-IDF ---
274
275    /// Returns the log-likelihood score for each class using TF-IDF weights.
276    /// Index `i` corresponds to `classes()[i]`.
277    ///
278    /// Panics if [`build_tfidf`] has not been called yet.
279    pub fn log_scores_tfidf<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
280        let tfidf = self
281            .tfidf
282            .as_ref()
283            .expect("call build_tfidf() before using TF-IDF classification");
284        self.score_document(document, |class, word| tfidf[class].word_prob(word))
285    }
286
287    /// Returns normalised probability scores for each class using TF-IDF
288    /// weights (sum to 1.0). Index `i` corresponds to `classes()[i]`.
289    ///
290    /// Panics if [`build_tfidf`] has not been called yet.
291    pub fn prob_scores_tfidf<S: AsRef<str>>(&self, document: &[S]) -> Vec<f64> {
292        Self::log_to_probs(&self.log_scores_tfidf(document))
293    }
294
295    /// Returns the most likely class for the given document using TF-IDF weights.
296    ///
297    /// Panics if [`build_tfidf`] has not been called yet.
298    pub fn classify_tfidf<S: AsRef<str>>(&self, document: &[S]) -> &C {
299        &self.classes[Self::find_max(&self.log_scores_tfidf(document))]
300    }
301
302    // --- Serialization ---
303
304    /// Serializes the classifier (including any built TF-IDF weights) to a
305    /// compressed binary blob (bincode + gzip).
306    ///
307    /// The result can be stored to disk, sent over the wire, etc. Pass it back
308    /// to [`Classifier::from_data`] to reconstruct an identical classifier.
309    pub fn serialize(&self) -> std::io::Result<Vec<u8>>
310    where
311        C: Serialize,
312    {
313        let bytes = bincode::serialize(self)
314            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
315        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
316        encoder.write_all(&bytes)?;
317        encoder.finish()
318    }
319
320    /// Reconstructs a [`Classifier`] from bytes previously produced by
321    /// [`Classifier::serialize`].
322    pub fn from_data(data: impl AsRef<[u8]>) -> std::io::Result<Self>
323    where
324        C: for<'de> Deserialize<'de>,
325    {
326        let mut decoder = GzDecoder::new(data.as_ref());
327        let mut bytes = Vec::new();
328        decoder.read_to_end(&mut bytes)?;
329        bincode::deserialize(&bytes)
330            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
331    }
332
333    /// Writes the serialized classifier to a file at the given path.
334    ///
335    /// Creates the file if it does not exist, truncates it if it does.
336    pub fn serialize_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()>
337    where
338        C: Serialize,
339    {
340        std::fs::write(path, self.serialize()?)
341    }
342
343    /// Reconstructs a [`Classifier`] from a file previously written by
344    /// [`Classifier::serialize_to_file`].
345    pub fn from_file(path: impl AsRef<std::path::Path>) -> std::io::Result<Self>
346    where
347        C: for<'de> Deserialize<'de>,
348    {
349        Self::from_data(std::fs::read(path)?)
350    }
351}
352
// Unit tests live in the separate `tests` submodule file (`tests.rs` or
// `tests/mod.rs` next to this file) and are compiled only in test builds.
#[cfg(test)]
mod tests;