// linfa_bayes/multinomial_nb.rs — Multinomial Naive Bayes implementation

1use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
2use linfa::traits::{Fit, FitWith, PredictInplace};
3use linfa::{Float, Label};
4use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
5use std::collections::HashMap;
6use std::hash::Hash;
7
8use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
9use crate::error::{NaiveBayesError, Result};
10use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
11
12#[cfg(feature = "serde")]
13use serde_crate::{Deserialize, Serialize};
14
// Marker impl: MultinomialNbValidParams satisfies the generic NaiveBayesValidParams
// contract; all required behavior comes from the trait's provided methods.
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}
23
24impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for MultinomialNbValidParams<F, L>
25where
26    F: Float,
27    L: Label + Ord,
28    D: Data<Elem = F>,
29    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
30{
31    type Object = MultinomialNb<F, L>;
32    // Thin wrapper around the corresponding method of NaiveBayesValidParams
33    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
34        let model = NaiveBayesValidParams::fit(self, dataset, None)?;
35        Ok(model.unwrap())
36    }
37}
38
39impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
40    for MultinomialNbValidParams<F, L>
41where
42    F: Float,
43    L: Label + 'a,
44    D: Data<Elem = F>,
45    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
46{
47    type ObjectIn = Option<MultinomialNb<F, L>>;
48    type ObjectOut = Option<MultinomialNb<F, L>>;
49
50    fn fit_with(
51        &self,
52        model_in: Self::ObjectIn,
53        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
54    ) -> Result<Self::ObjectOut> {
55        let x = dataset.records();
56        let y = dataset.as_single_targets();
57
58        let mut model = match model_in {
59            Some(temp) => temp,
60            None => MultinomialNb {
61                class_info: HashMap::new(),
62            },
63        };
64
65        let yunique = dataset.labels();
66
67        for class in yunique {
68            // We filter for records that correspond to the current class
69            let xclass = filter(x.view(), y.view(), &class);
70            // We count the number of occurences of the class
71            let nclass = xclass.nrows();
72
73            // We compute the feature log probabilities and feature counts on the slice corresponding to the current class
74            let class_info = model
75                .class_info
76                .entry(class)
77                .or_insert_with(MultinomialClassInfo::default);
78            let (feature_log_prob, feature_count) =
79                self.update_feature_log_prob(class_info, xclass.view());
80            // We now update the total counts of each feature, feature log probabilities, and class count
81            class_info.feature_log_prob = feature_log_prob;
82            class_info.feature_count = feature_count;
83            class_info.class_count += nclass;
84        }
85
86        // We update the priors
87        let class_count_sum = model
88            .class_info
89            .values()
90            .map(|x| x.class_count)
91            .sum::<usize>();
92        for info in model.class_info.values_mut() {
93            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
94        }
95        Ok(Some(model))
96    }
97}
98
impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for MultinomialNb<F, L>
where
    D: Data<Elem = F>,
{
    // Thin wrapper around the corresponding method of NaiveBayes: writes the
    // most likely class label for each row of `x` into `y`.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    // Allocate an output buffer with one default label per input row.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}
112
113impl<F, L> MultinomialNbValidParams<F, L>
114where
115    F: Float,
116{
117    // Update log probabilities of features given class
118    fn update_feature_log_prob(
119        &self,
120        info_old: &MultinomialClassInfo<F>,
121        x_new: ArrayView2<F>,
122    ) -> (Array1<F>, Array1<F>) {
123        // Deconstruct old state
124        let (count_old, feature_log_prob_old, feature_count_old) = (
125            &info_old.class_count,
126            &info_old.feature_log_prob,
127            &info_old.feature_count,
128        );
129
130        // If incoming data is empty no updates required
131        if x_new.nrows() == 0 {
132            return (
133                feature_log_prob_old.to_owned(),
134                feature_count_old.to_owned(),
135            );
136        }
137
138        let feature_count_new = x_new.sum_axis(Axis(0));
139
140        // If previous batch was empty, we send the new feature count calculated
141        let feature_count = if count_old > &0 {
142            feature_count_old + feature_count_new
143        } else {
144            feature_count_new
145        };
146        // Apply smoothing to feature counts
147        let feature_count_smoothed = feature_count.clone() + self.alpha();
148        // Compute total count over all (smoothed) features
149        let count = feature_count_smoothed.sum();
150        // Compute log probabilities of each feature
151        let feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - F::cast(count).ln());
152        (feature_log_prob.to_owned(), feature_count.to_owned())
153    }
154}
155
/// Fitted Multinomial Naive Bayes classifier.
///
/// See [MultinomialNbParams] for more information on the hyper-parameters.
///
/// # Model assumptions
///
/// The family of Naive Bayes classifiers assume independence between variables. They do not model
/// moments between variables and therefore lack some modelling capability. The advantage is a
/// linear fitting time with maximum-likelihood training in a closed form.
///
/// # Model usage example
///
/// The example below creates a set of hyperparameters, and then uses it to fit a Multinomial Naive
/// Bayes classifier on provided data.
///
/// ```rust
/// use linfa_bayes::{MultinomialNbParams, MultinomialNbValidParams, Result};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![
///     [-2., -1.],
///     [-1., -1.],
///     [-1., -2.],
///     [1., 1.],
///     [1., 2.],
///     [2., 1.]
/// ];
/// let y = array![1, 1, 1, 2, 2, 2];
/// let ds = DatasetView::new(x.view(), y.view());
///
/// // create a new parameter set with smoothing parameter equals `1`
/// let unchecked_params = MultinomialNbParams::new()
///     .alpha(1.0);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // update model with the verified parameters, this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit_with(Some(model), &ds)?;
/// # Result::Ok(())
/// ```
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
    // Per-class fitted statistics: sample counts, prior, feature counts and
    // feature log-probabilities, keyed by class label
    class_info: HashMap<L, MultinomialClassInfo<F>>,
}
211
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Default, Clone, PartialEq)]
// Per-class sufficient statistics accumulated during (incremental) fitting
struct MultinomialClassInfo<F> {
    // Number of training samples observed for this class
    class_count: usize,
    // P(class), computed as class_count / total sample count
    prior: F,
    // Per-feature occurrence counts for this class
    feature_count: Array1<F>,
    // Per-feature ln P(feature | class), with smoothing applied
    feature_log_prob: Array1<F>,
}
224
impl<F: Float, L: Label> MultinomialNb<F, L> {
    /// Construct a new set of hyperparameters with default values, to be
    /// customized and then used to fit a model.
    pub fn params() -> MultinomialNbParams<F, L> {
        MultinomialNbParams::new()
    }
}
231
232impl<F, L> NaiveBayes<'_, F, L> for MultinomialNb<F, L>
233where
234    F: Float,
235    L: Label + Ord,
236{
237    // Compute unnormalized posterior log probability
238    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
239        let mut joint_log_likelihood = HashMap::new();
240        for (class, info) in self.class_info.iter() {
241            // Combine feature log probabilities and class priors to get log-likelihood for each class
242            let jointi = info.prior.ln();
243            let nij = x.dot(&info.feature_log_prob);
244            joint_log_likelihood.insert(class, nij + jointi);
245        }
246
247        joint_log_likelihood
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::{MultinomialNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, FitWith, Predict},
        DatasetView,
    };

    use crate::multinomial_nb::MultinomialClassInfo;
    use crate::{MultinomialNbParams, MultinomialNbValidParams};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1, Array2, Axis};
    use std::collections::HashMap;

    // Shared 6x2 fixture: three samples of class 1, three of class 2.
    // Previously duplicated verbatim in both tests below.
    fn fixture() -> (Array2<f64>, Array1<usize>) {
        (
            array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]],
            array![1, 1, 1, 2, 2, 2],
        )
    }

    // Expected joint log-likelihoods for the fixture,
    // computed with sklearn.naive_bayes.MultinomialNB
    fn expected_jll() -> HashMap<&'static usize, Array1<f64>> {
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );
        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );
        expected
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<MultinomialNb<f64, usize>>();
        has_autotraits::<MultinomialClassInfo<f64>>();
        has_autotraits::<MultinomialNbValidParams<f64, usize>>();
        has_autotraits::<MultinomialNbParams<f64, usize>>();
    }

    // Single-shot fitting: predictions recover the labels and the joint
    // log-likelihood matches the sklearn reference values.
    #[test]
    fn test_multinomial_nb() -> Result<()> {
        let (x, y) = fixture();

        let data = DatasetView::new(x.view(), y.view());
        let fitted_clf = MultinomialNb::params().fit(&data)?;
        let pred = fitted_clf.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = fitted_clf.joint_log_likelihood(x.view());
        let expected = expected_jll();
        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }

    // Incremental fitting in chunks of 2 must reach the same model as the
    // single-shot fit above.
    #[test]
    fn test_mnb_fit_with() -> Result<()> {
        let (x, y) = fixture();

        let clf = MultinomialNb::params();

        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .fold(None, |current, d| clf.fit_with(current, &d).unwrap())
            .unwrap();

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());
        let expected = expected_jll();
        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}
371}