1use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
2use linfa::traits::{Fit, FitWith, PredictInplace};
3use linfa::{Float, Label};
4use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
5use ndarray_stats::QuantileExt;
6use std::collections::HashMap;
7use std::hash::Hash;
8
9use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
10use crate::error::{NaiveBayesError, Result};
11use crate::hyperparams::{GaussianNbParams, GaussianNbValidParams};
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
/// Marker implementation wiring the Gaussian NB hyperparameters into the
/// shared naive Bayes fitting interface; all behavior comes from the trait's
/// provided methods, so the body is intentionally empty.
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}
24
25impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for GaussianNbValidParams<F, L>
26where
27 F: Float,
28 L: Label + Ord,
29 D: Data<Elem = F>,
30 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
31{
32 type Object = GaussianNb<F, L>;
33
34 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
36 let model = NaiveBayesValidParams::fit(self, dataset, None)?;
37 Ok(model.unwrap())
38 }
39}
40
41impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
42 for GaussianNbValidParams<F, L>
43where
44 F: Float,
45 L: Label + 'a,
46 D: Data<Elem = F>,
47 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
48{
49 type ObjectIn = Option<GaussianNb<F, L>>;
50 type ObjectOut = Option<GaussianNb<F, L>>;
51
52 fn fit_with(
53 &self,
54 model_in: Self::ObjectIn,
55 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
56 ) -> Result<Self::ObjectOut> {
57 let x = dataset.records();
58 let y = dataset.as_single_targets();
59
60 let epsilon = self.var_smoothing() * *x.var_axis(Axis(0), F::zero()).max()?;
64
65 let mut model = match model_in {
66 Some(mut temp) => {
67 temp.class_info
68 .values_mut()
69 .for_each(|x| x.sigma -= epsilon);
70 temp
71 }
72 None => GaussianNb {
73 class_info: HashMap::new(),
74 },
75 };
76
77 let yunique = dataset.labels();
78
79 for class in yunique {
80 let xclass = filter(x.view(), y.view(), &class);
82
83 let nclass = xclass.nrows();
85
86 let class_info = model
88 .class_info
89 .entry(class)
90 .or_insert_with(GaussianClassInfo::default);
91
92 let (theta_new, sigma_new) = Self::update_mean_variance(class_info, xclass.view());
93
94 class_info.theta = theta_new;
96 class_info.sigma = sigma_new;
97 class_info.class_count += nclass;
98 }
99
100 model
103 .class_info
104 .values_mut()
105 .for_each(|x| x.sigma += epsilon);
106
107 let class_count_sum = model
109 .class_info
110 .values()
111 .map(|x| x.class_count)
112 .sum::<usize>();
113
114 for info in model.class_info.values_mut() {
115 info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
116 }
117
118 Ok(Some(model))
119 }
120}
121
impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for GaussianNb<F, L>
where
    D: Data<Elem = F>,
{
    /// Writes the predicted class label for every row of `x` into `y`,
    /// delegating to the shared `NaiveBayes::predict_inplace`.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    /// Allocates a default-valued target array with one entry per row of `x`.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}
135
impl<F, L> GaussianNbValidParams<F, L>
where
    F: Float,
{
    /// Merges the running per-feature mean/variance of a class with the
    /// statistics of a new batch of samples, returning the combined
    /// `(mean, variance)` arrays.
    ///
    /// Each side is weighted by its sample count, with a correction term for
    /// the shift between the two means, so the result matches the statistics
    /// that a single pass over all samples would produce (up to
    /// floating-point rounding). Statement order is numerically significant;
    /// do not reorder.
    fn update_mean_variance(
        info_old: &GaussianClassInfo<F>,
        x_new: ArrayView2<F>,
    ) -> (Array1<F>, Array1<F>) {
        let (count_old, mu_old, var_old) = (info_old.class_count, &info_old.theta, &info_old.sigma);

        // Empty batch: the running statistics are unchanged.
        if x_new.nrows() == 0 {
            return (mu_old.to_owned(), var_old.to_owned());
        }

        let count_new = x_new.nrows();

        // `mean_axis` returns None only for an empty axis, excluded above.
        let mu_new = x_new.mean_axis(Axis(0)).unwrap();
        let var_new = x_new.var_axis(Axis(0), F::zero());

        // First samples for this class: the batch statistics stand alone.
        if count_old == 0 {
            return (mu_new, var_new);
        }

        let count_total = count_old + count_new;

        // Combined mean: count-weighted average of old and new means.
        let mu_new_weighted = &mu_new * F::cast(count_new);
        let mu_old_weighted = mu_old * F::cast(count_old);
        let mu_weighted = (mu_new_weighted + mu_old_weighted).mapv(|x| x / F::cast(count_total));

        // Combined variance via sums of squared deviations (ssd), plus a
        // cross term accounting for the distance between the two means.
        let ssd_old = var_old * F::cast(count_old);
        let ssd_new = var_new * F::cast(count_new);
        let weight = F::cast(count_new * count_old) / F::cast(count_total);
        let ssd_weighted = ssd_old + ssd_new + (mu_old - mu_new).mapv(|x| weight * x.powi(2));
        let var_weighted = ssd_weighted.mapv(|x| x / F::cast(count_total));

        (mu_weighted, var_weighted)
    }
}
185
/// Fitted Gaussian naive Bayes classifier.
///
/// Holds, for every class label seen during fitting, the per-feature Gaussian
/// parameters and the class prior used to compute joint log-likelihoods.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct GaussianNb<F: PartialEq, L: Eq + Hash> {
    /// Per-class fitted statistics, keyed by class label.
    class_info: HashMap<L, GaussianClassInfo<F>>,
}
241
/// Running per-class statistics accumulated during (incremental) fitting.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Default, Clone, PartialEq)]
struct GaussianClassInfo<F> {
    /// Number of samples of this class seen so far.
    class_count: usize,
    /// Class prior probability (class frequency over all samples).
    prior: F,
    /// Per-feature mean of the class samples.
    theta: Array1<F>,
    /// Per-feature variance of the class samples (smoothing included).
    sigma: Array1<F>,
}
254
impl<F: Float, L: Label> GaussianNb<F, L> {
    /// Returns a default hyperparameter set for configuring and fitting a
    /// Gaussian naive Bayes model.
    pub fn params() -> GaussianNbParams<F, L> {
        GaussianNbParams::new()
    }
}
261
impl<F, L> NaiveBayes<'_, F, L> for GaussianNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
    /// Computes, for each fitted class, the joint log-likelihood
    /// `ln P(class) + ln P(x | class)` of every sample (row) of `x` under the
    /// class's per-feature Gaussians.
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        let mut joint_log_likelihood = HashMap::new();

        for (class, info) in self.class_info.iter() {
            // Log prior of this class.
            let jointi = info.prior.ln();

            // Normalization term of the Gaussian log-density:
            // -0.5 * sum_j ln(2 * pi * sigma_j).
            let mut nij = info
                .sigma
                .mapv(|x| F::cast(2. * std::f64::consts::PI) * x)
                .mapv(|x| x.ln())
                .sum();
            nij = F::cast(-0.5) * nij;

            // Per-sample log-density: normalization minus half the
            // variance-scaled squared distance to the class mean (shadows the
            // scalar `nij` with the per-sample array).
            let nij = ((x.to_owned() - &info.theta).mapv(|x| x.powi(2)) / &info.sigma)
                .sum_axis(Axis(1))
                .mapv(|x| x * F::cast(0.5))
                .mapv(|x| nij - x);

            joint_log_likelihood.insert(class, nij + jointi);
        }

        joint_log_likelihood
    }
}
292
#[cfg(test)]
mod tests {
    use super::{GaussianNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, FitWith, Predict},
        DatasetView,
    };

    use crate::gaussian_nb::GaussianClassInfo;
    use crate::{GaussianNbParams, GaussianNbValidParams, NaiveBayesError};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Axis};
    use std::collections::HashMap;

    // Compile-time check that the public types are Send + Sync + Unpin.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<GaussianNb<f64, usize>>();
        has_autotraits::<GaussianClassInfo<f64>>();
        has_autotraits::<GaussianNbParams<f64, usize>>();
        has_autotraits::<GaussianNbValidParams<f64, usize>>();
        has_autotraits::<NaiveBayesError>();
    }

    // One-shot fit: predictions must recover the training labels and the
    // joint log-likelihoods must match precomputed reference values exactly.
    #[test]
    fn test_gaussian_nb() -> Result<()> {
        // Two well-separated clusters of three points each.
        let x = array![
            [-2., -1.],
            [-1., -1.],
            [-1., -2.],
            [1., 1.],
            [1., 2.],
            [2., 1.]
        ];
        let y = array![1, 1, 1, 2, 2, 2];

        let data = DatasetView::new(x.view(), y.view());
        let fitted_clf = GaussianNb::params().fit(&data)?;
        let pred = fitted_clf.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = fitted_clf.joint_log_likelihood(x.view());

        // Reference values; NOTE(review): exact equality below pins the
        // floating-point evaluation order of the implementation.
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -2.276946847943017,
                -1.5269468546930165,
                -2.276946847943017,
                -25.52694663869301,
                -38.27694652394301,
                -38.27694652394301
            ],
        );
        expected.insert(
            &2usize,
            array![
                -38.27694652394301,
                -25.52694663869301,
                -38.27694652394301,
                -1.5269468546930165,
                -2.276946847943017,
                -2.276946847943017
            ],
        );

        assert_eq!(jll, expected);

        Ok(())
    }

    // Incremental fit: feeding the same data in chunks of two rows must give
    // the same predictions, and log-likelihoods equal (within tolerance) to
    // the one-shot fit above.
    #[test]
    fn test_gnb_fit_with() -> Result<()> {
        let x = array![
            [-2., -1.],
            [-1., -1.],
            [-1., -2.],
            [1., 1.],
            [1., 2.],
            [2., 1.]
        ];
        let y = array![1, 1, 1, 2, 2, 2];

        let clf = GaussianNb::params();

        // Fold fit_with over 2-row chunks, threading the model through.
        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .fold(None, |current, d| clf.fit_with(current, &d).unwrap())
            .unwrap();

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());

        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -2.276946847943017,
                -1.5269468546930165,
                -2.276946847943017,
                -25.52694663869301,
                -38.27694652394301,
                -38.27694652394301
            ],
        );
        expected.insert(
            &2usize,
            array![
                -38.27694652394301,
                -25.52694663869301,
                -38.27694652394301,
                -1.5269468546930165,
                -2.276946847943017,
                -2.276946847943017
            ],
        );

        // Approximate comparison: chunked merging reorders floating-point
        // operations relative to the one-shot fit.
        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}