1use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
2use linfa::traits::{Fit, FitWith, PredictInplace};
3use linfa::{Float, Label};
4use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
5use std::collections::HashMap;
6use std::hash::Hash;
7
8use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
9use crate::error::{NaiveBayesError, Result};
10use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
11
12#[cfg(feature = "serde")]
13use serde_crate::{Deserialize, Serialize};
14
// Marker implementation: plugs `MultinomialNbValidParams` into the shared
// naive-Bayes fitting machinery defined in `base_nb`. All behavior comes from
// the trait's provided methods; nothing needs to be overridden here.
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}
23
24impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for MultinomialNbValidParams<F, L>
25where
26 F: Float,
27 L: Label + Ord,
28 D: Data<Elem = F>,
29 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
30{
31 type Object = MultinomialNb<F, L>;
32 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
34 let model = NaiveBayesValidParams::fit(self, dataset, None)?;
35 Ok(model.unwrap())
36 }
37}
38
39impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
40 for MultinomialNbValidParams<F, L>
41where
42 F: Float,
43 L: Label + 'a,
44 D: Data<Elem = F>,
45 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
46{
47 type ObjectIn = Option<MultinomialNb<F, L>>;
48 type ObjectOut = Option<MultinomialNb<F, L>>;
49
50 fn fit_with(
51 &self,
52 model_in: Self::ObjectIn,
53 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
54 ) -> Result<Self::ObjectOut> {
55 let x = dataset.records();
56 let y = dataset.as_single_targets();
57
58 let mut model = match model_in {
59 Some(temp) => temp,
60 None => MultinomialNb {
61 class_info: HashMap::new(),
62 },
63 };
64
65 let yunique = dataset.labels();
66
67 for class in yunique {
68 let xclass = filter(x.view(), y.view(), &class);
70 let nclass = xclass.nrows();
72
73 let class_info = model
75 .class_info
76 .entry(class)
77 .or_insert_with(MultinomialClassInfo::default);
78 let (feature_log_prob, feature_count) =
79 self.update_feature_log_prob(class_info, xclass.view());
80 class_info.feature_log_prob = feature_log_prob;
82 class_info.feature_count = feature_count;
83 class_info.class_count += nclass;
84 }
85
86 let class_count_sum = model
88 .class_info
89 .values()
90 .map(|x| x.class_count)
91 .sum::<usize>();
92 for info in model.class_info.values_mut() {
93 info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
94 }
95 Ok(Some(model))
96 }
97}
98
impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for MultinomialNb<F, L>
where
    D: Data<Elem = F>,
{
    // Writes one predicted label per row of `x` into `y`, delegating to the
    // shared `NaiveBayes::predict_inplace`, which picks the class with the
    // highest joint log-likelihood.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    // Allocates the output buffer: one default-valued label per input row.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}
112
113impl<F, L> MultinomialNbValidParams<F, L>
114where
115 F: Float,
116{
117 fn update_feature_log_prob(
119 &self,
120 info_old: &MultinomialClassInfo<F>,
121 x_new: ArrayView2<F>,
122 ) -> (Array1<F>, Array1<F>) {
123 let (count_old, feature_log_prob_old, feature_count_old) = (
125 &info_old.class_count,
126 &info_old.feature_log_prob,
127 &info_old.feature_count,
128 );
129
130 if x_new.nrows() == 0 {
132 return (
133 feature_log_prob_old.to_owned(),
134 feature_count_old.to_owned(),
135 );
136 }
137
138 let feature_count_new = x_new.sum_axis(Axis(0));
139
140 let feature_count = if count_old > &0 {
142 feature_count_old + feature_count_new
143 } else {
144 feature_count_new
145 };
146 let feature_count_smoothed = feature_count.clone() + self.alpha();
148 let count = feature_count_smoothed.sum();
150 let feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - F::cast(count).ln());
152 (feature_log_prob.to_owned(), feature_count.to_owned())
153 }
154}
155
/// Fitted multinomial naive Bayes model: per-class statistics keyed by label.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
    // Counts, prior and feature log-probabilities accumulated per class
    // during (possibly incremental) fitting.
    class_info: HashMap<L, MultinomialClassInfo<F>>,
}
211
/// Per-class sufficient statistics maintained while fitting.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Default, Clone, PartialEq)]
struct MultinomialClassInfo<F> {
    // Number of training rows seen for this class so far.
    class_count: usize,
    // Class prior P(class), recomputed after every fitting pass.
    prior: F,
    // Running per-feature event counts for this class.
    feature_count: Array1<F>,
    // Smoothed per-feature log-probabilities derived from `feature_count`.
    feature_log_prob: Array1<F>,
}
224
impl<F: Float, L: Label> MultinomialNb<F, L> {
    /// Returns a default parameter set from which a multinomial naive Bayes
    /// model can be configured and fitted.
    pub fn params() -> MultinomialNbParams<F, L> {
        MultinomialNbParams::new()
    }
}
231
232impl<F, L> NaiveBayes<'_, F, L> for MultinomialNb<F, L>
233where
234 F: Float,
235 L: Label + Ord,
236{
237 fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
239 let mut joint_log_likelihood = HashMap::new();
240 for (class, info) in self.class_info.iter() {
241 let jointi = info.prior.ln();
243 let nij = x.dot(&info.feature_log_prob);
244 joint_log_likelihood.insert(class, nij + jointi);
245 }
246
247 joint_log_likelihood
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::{MultinomialNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, FitWith, Predict},
        DatasetView,
    };

    use crate::multinomial_nb::MultinomialClassInfo;
    use crate::{MultinomialNbParams, MultinomialNbValidParams};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Axis};
    use std::collections::HashMap;

    // Compile-time check that the public types keep their auto traits
    // (Send/Sync/Unpin), so they remain usable across threads.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<MultinomialNb<f64, usize>>();
        has_autotraits::<MultinomialClassInfo<f64>>();
        has_autotraits::<MultinomialNbValidParams<f64, usize>>();
        has_autotraits::<MultinomialNbParams<f64, usize>>();
    }

    // Batch fit: train on a small, cleanly separable dataset, then check the
    // predictions and the exact joint log-likelihood values per class.
    #[test]
    fn test_multinomial_nb() -> Result<()> {
        // Two classes; class 1 only uses feature 0, class 2 only feature 1.
        let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
        let y = array![1, 1, 1, 2, 2, 2];

        let data = DatasetView::new(x.view(), y.view());
        let fitted_clf = MultinomialNb::params().fit(&data)?;
        let pred = fitted_clf.predict(&x);

        // The training data is separable, so it must be classified perfectly.
        assert_abs_diff_eq!(pred, y);

        let jll = fitted_clf.joint_log_likelihood(x.view());
        // Reference values; note expected JLL for class 2 mirrors class 1
        // because the dataset is symmetric in the two features.
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );

        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }

    // Incremental fit: feeding the same dataset in chunks of two rows through
    // `fit_with` must converge to the same model as the one-shot batch fit.
    #[test]
    fn test_mnb_fit_with() -> Result<()> {
        let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
        let y = array![1, 1, 1, 2, 2, 2];

        let clf = MultinomialNb::params();

        // Fold over row-chunks, threading the partially trained model through.
        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .fold(None, |current, d| clf.fit_with(current, &d).unwrap())
            .unwrap();

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());

        // Same expected values as the batch-fit test above.
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );

        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}