use std::collections::HashMap;
use std::hash::Hash;

use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};

use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};

use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
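
// Marker implementation: the shared Naive Bayes fitting logic comes from the
// `NaiveBayesValidParams` trait's provided methods, so the body is empty.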
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for MultinomialNbValidParams<F, L>
where
F: Float,
L: Label + 'a,
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}
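
// One-shot training: delegate to the incremental fitting path below, starting
// from an empty model.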
impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for MultinomialNbValidParams<F, L>
where
F: Float,
L: Label + Ord,
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
type Object = MultinomialNb<F, L>;
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
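        // `fit_with` always returns `Ok(Some(model))`, so the unwrap is safe.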
let model = NaiveBayesValidParams::fit(self, dataset, None)?;
Ok(model.unwrap())
}
}
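
// Incremental training: fold one batch of samples into an optional previously
// fitted model, enabling out-of-core fitting over chunks of a dataset.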
impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
for MultinomialNbValidParams<F, L>
where
F: Float,
L: Label + 'a,
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
type ObjectIn = Option<MultinomialNb<F, L>>;
type ObjectOut = Option<MultinomialNb<F, L>>;
fn fit_with(
&self,
model_in: Self::ObjectIn,
dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
) -> Result<Self::ObjectOut> {
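        // Unpack records and targets from the current batch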
let x = dataset.records();
let y = dataset.as_single_targets();
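        // Continue training the supplied model, or start from an empty one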
let mut model = match model_in {
Some(temp) => temp,
None => MultinomialNb {
class_info: HashMap::new(),
},
};
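        // Update feature counts and smoothed log-probabilities class by class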
let yunique = dataset.labels();
for class in yunique {
let xclass = filter(x.view(), y.view(), &class);
let nclass = xclass.nrows();
            let class_info = model
                .class_info
                .entry(class)
                .or_insert_with(MultinomialClassInfo::default);
let (feature_log_prob, feature_count) =
self.update_feature_log_prob(class_info, xclass.view());
class_info.feature_log_prob = feature_log_prob;
class_info.feature_count = feature_count;
class_info.class_count += nclass;
}
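        // Recompute class priors from the class counts accumulated so far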
let class_count_sum = model
.class_info
.values()
.map(|x| x.class_count)
.sum::<usize>();
for info in model.class_info.values_mut() {
info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
}
Ok(Some(model))
}
}
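
// Prediction: assign each sample the class with the highest joint
// log-likelihood, as computed by the `NaiveBayes` implementation below.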
impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for MultinomialNb<F, L>
where
D: Data<Elem = F>,
{
fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
NaiveBayes::predict_inplace(self, x, y);
}
fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
Array1::default(x.nrows())
}
}
impl<F, L> MultinomialNbValidParams<F, L>
where
F: Float,
{
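    // Update the feature counts and the Lidstone-smoothed feature
    // log-probabilities for one class, folding a new batch of rows into the
    // statistics accumulated so far.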
fn update_feature_log_prob(
&self,
info_old: &MultinomialClassInfo<F>,
x_new: ArrayView2<F>,
) -> (Array1<F>, Array1<F>) {
let (count_old, feature_log_prob_old, feature_count_old) = (
&info_old.class_count,
&info_old.feature_log_prob,
&info_old.feature_count,
);
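        // An empty batch leaves the previous statistics unchanged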
if x_new.nrows() == 0 {
return (
feature_log_prob_old.to_owned(),
feature_count_old.to_owned(),
);
}
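        // Per-feature occurrence counts summed over the rows of the new batch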
let feature_count_new = x_new.sum_axis(Axis(0));
        let feature_count = if *count_old > 0 {
            feature_count_old + feature_count_new
        } else {
            feature_count_new
        };
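        // Smooth the counts with the configured `alpha` and normalise in log space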
        let feature_count_smoothed = &feature_count + self.alpha();
        let count = feature_count_smoothed.sum();
        let feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - count.ln());
        (feature_log_prob, feature_count)
}
}
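
/// Fitted multinomial Naive Bayes classifier.
///
/// Holds the per-class counts, priors, and feature log-probabilities learned
/// during training, keyed by class label.
///
/// # Example
///
/// A minimal usage sketch; it assumes `MultinomialNb` and `Result` are
/// re-exported at the crate root, as the test module below suggests:
///
/// ```
/// use linfa::traits::{Fit, Predict};
/// use linfa::DatasetView;
/// use linfa_bayes::{MultinomialNb, Result};
/// use ndarray::array;
///
/// fn main() -> Result<()> {
///     let x = array![[1., 0.], [2., 0.], [0., 1.], [0., 2.]];
///     let y = array![1, 1, 2, 2];
///     let ds = DatasetView::new(x.view(), y.view());
///
///     // Fit with the default smoothing and predict on the training data
///     let model = MultinomialNb::params().fit(&ds)?;
///     let pred = model.predict(&x);
///     assert_eq!(pred, y);
///     Ok(())
/// }
/// ```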
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
class_info: HashMap<L, MultinomialClassInfo<F>>,
}
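
// Per-class sufficient statistics: number of samples, prior probability,
// raw feature counts, and smoothed feature log-probabilities.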
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Default, Clone, PartialEq)]
struct MultinomialClassInfo<F> {
class_count: usize,
prior: F,
feature_count: Array1<F>,
feature_log_prob: Array1<F>,
}
impl<F: Float, L: Label> MultinomialNb<F, L> {
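    /// Construct the default set of hyperparameters for a multinomial NB model.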
pub fn params() -> MultinomialNbParams<F, L> {
MultinomialNbParams::new()
}
}
impl<'a, F, L> NaiveBayes<'a, F, L> for MultinomialNb<F, L>
where
F: Float,
L: Label + Ord,
{
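    // Compute the unnormalised joint log-likelihood ln P(c) + x · ln P(x|c)
    // for each class; prediction takes the argmax over classes.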
fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
let mut joint_log_likelihood = HashMap::new();
for (class, info) in self.class_info.iter() {
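            // Log prior of the class, ln P(c)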
let jointi = info.prior.ln();
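            // Dot product accumulates the count-weighted feature log-probabilities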
let nij = x.dot(&info.feature_log_prob);
joint_log_likelihood.insert(class, nij + jointi);
}
joint_log_likelihood
}
}
#[cfg(test)]
mod tests {
use super::{MultinomialNb, NaiveBayes, Result};
use linfa::{
traits::{Fit, FitWith, Predict},
DatasetView,
};
use crate::multinomial_nb::MultinomialClassInfo;
use crate::{MultinomialNbParams, MultinomialNbValidParams};
use approx::assert_abs_diff_eq;
use ndarray::{array, Axis};
use std::collections::HashMap;
#[test]
fn autotraits() {
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
has_autotraits::<MultinomialNb<f64, usize>>();
has_autotraits::<MultinomialClassInfo<f64>>();
has_autotraits::<MultinomialNbValidParams<f64, usize>>();
has_autotraits::<MultinomialNbParams<f64, usize>>();
}
#[test]
fn test_multinomial_nb() -> Result<()> {
let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
let y = array![1, 1, 1, 2, 2, 2];
let data = DatasetView::new(x.view(), y.view());
let fitted_clf = MultinomialNb::params().fit(&data)?;
let pred = fitted_clf.predict(&x);
assert_abs_diff_eq!(pred, y);
let jll = fitted_clf.joint_log_likelihood(x.view());
let mut expected = HashMap::new();
expected.insert(
&1usize,
array![
-0.82667857,
-0.96020997,
-1.09374136,
-2.77258872,
-4.85203026,
-6.93147181
],
);
expected.insert(
&2usize,
array![
-2.77258872,
-4.85203026,
-6.93147181,
-0.82667857,
-0.96020997,
-1.09374136
],
);
for (key, value) in jll.iter() {
assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
}
Ok(())
}
#[test]
fn test_mnb_fit_with() -> Result<()> {
let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
let y = array![1, 1, 1, 2, 2, 2];
let clf = MultinomialNb::params();
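        // Stream the dataset in chunks of two rows, threading the partially
        // trained model through successive `fit_with` calls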
let model = x
.axis_chunks_iter(Axis(0), 2)
.zip(y.axis_chunks_iter(Axis(0), 2))
.map(|(a, b)| DatasetView::new(a, b))
.fold(None, |current, d| clf.fit_with(current, &d).unwrap())
.unwrap();
let pred = model.predict(&x);
assert_abs_diff_eq!(pred, y);
let jll = model.joint_log_likelihood(x.view());
let mut expected = HashMap::new();
expected.insert(
&1usize,
array![
-0.82667857,
-0.96020997,
-1.09374136,
-2.77258872,
-4.85203026,
-6.93147181
],
);
expected.insert(
&2usize,
array![
-2.77258872,
-4.85203026,
-6.93147181,
-0.82667857,
-0.96020997,
-1.09374136
],
);
for (key, value) in jll.iter() {
assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
}
Ok(())
}
}