use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
use ndarray::{Array1, ArrayBase, ArrayView2, CowArray, Data, Ix2};
use std::collections::HashMap;
use std::hash::Hash;
use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
use crate::{filter, ClassHistogram};
// Opt `BernoulliNbValidParams` into the `NaiveBayesValidParams` trait from
// `crate::base_nb`. The impl body is intentionally empty: nothing is overridden
// here, so any method reached through this trait (for example the
// `NaiveBayesValidParams::fit` call made by the `Fit` impl in this file) is the
// one the trait itself provides.
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for BernoulliNbValidParams<F, L>
where
F: Float,
L: Label + 'a,
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}
// One-shot fitting entry point for validated Bernoulli Naive Bayes
// hyper-parameters. The resulting object is a `BernoulliNb` model.
impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = BernoulliNb<F, L>;

    // Thin wrapper: delegates to `NaiveBayesValidParams::fit`, handing it the
    // dataset together with `None` as the incoming model, i.e. no previously
    // fitted model is supplied.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let no_previous_model = None;
        NaiveBayesValidParams::fit(self, dataset, no_previous_model)
    }
}
// Incremental fitting for validated Bernoulli Naive Bayes hyper-parameters:
// an optional, previously fitted `BernoulliNb` is updated with a new dataset.
impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
    for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type ObjectIn = Option<BernoulliNb<F, L>>;
    type ObjectOut = BernoulliNb<F, L>;

    /// Updates `model_in` with the records and targets of `dataset` and returns
    /// the updated model.
    ///
    /// * `model_in` - a previously fitted model to continue from, or `None` to
    ///   start from an empty model that takes its binarization threshold from
    ///   these hyper-parameters. When a model is supplied, the threshold already
    ///   stored in that model is the one applied to the records.
    /// * `dataset` - records plus single-target labels to learn from.
    ///
    /// For every label reported by `dataset.labels()` the matching rows are
    /// folded into that class's `ClassHistogram`; afterwards the prior of every
    /// known class is recomputed as its share of the total class count.
    ///
    /// The body shown here never constructs an error itself; it always ends in
    /// `Ok(model)`.
    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let x = dataset.records();
        let y = dataset.as_single_targets();

        // Continue from the supplied model, or start with an empty one.
        let mut model = model_in.unwrap_or_else(|| BernoulliNb {
            class_info: HashMap::new(),
            binarize: self.binarize(),
        });

        // Binarize the records if a threshold is set. The result must not keep
        // borrowing `model`, because `model` is mutated below. `into_owned`
        // reuses the buffer when binarization already produced a fresh array
        // and copies only when the records were passed through untouched.
        let xbin = model.binarize(x).into_owned();

        for class in dataset.labels() {
            // Rows of the (binarized) records that belong to the current class.
            let xclass = filter(xbin.view(), y.view(), &class);

            // Fold those rows into the statistics of the class, creating the
            // entry on first sight of the label. The meaning of the trailing
            // `true` flag is defined by `ClassHistogram::update_with_smoothing`.
            model
                .class_info
                .entry(class)
                .or_insert_with(ClassHistogram::default)
                .update_with_smoothing(xclass.view(), self.alpha(), true);
        }

        // Recompute the priors over all classes known to the model, including
        // classes carried over from `model_in`.
        let class_count_sum = model
            .class_info
            .values()
            .map(|info| info.class_count)
            .sum::<usize>();

        for info in model.class_info.values_mut() {
            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
        }

        Ok(model)
    }
}
// Prediction entry points for a fitted `BernoulliNb`.
impl<F, L, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for BernoulliNb<F, L>
where
    F: Float,
    L: Label,
    D: Data<Elem = F>,
{
    // Thin wrapper: the records are first passed through the model's own
    // `binarize` step, then handed to `NaiveBayes::predict_inplace`, which
    // writes into `y`.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, &self.binarize(x), y);
    }

    // Allocates a target array with one default-valued label per row of `x`.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        let n_rows = x.nrows();
        Array1::default(n_rows)
    }
}
/// Fitted Bernoulli Naive Bayes classifier.
///
/// Hyper-parameters are obtained through [`BernoulliNb::params`], which returns
/// a [`BernoulliNbParams`].
#[derive(Debug, Clone, PartialEq)]
pub struct BernoulliNb<F: PartialEq, L: Eq + Hash> {
/// Per-class statistics, keyed by class label.
class_info: HashMap<L, ClassHistogram<F>>,
/// Optional binarization threshold. When set, input values strictly greater
/// than the threshold are mapped to one and all others to zero before use;
/// when `None`, inputs are used as they are.
binarize: Option<F>,
}
impl<F, L> BernoulliNb<F, L>
where
    F: Float,
    L: Label,
{
    /// Construct a new set of hyperparameters.
    pub fn params() -> BernoulliNbParams<F, L> {
        BernoulliNbParams::new()
    }

    // Apply the model's binarization threshold to `x`, if one is set.
    //
    // With a threshold, a new array is produced in which every value strictly
    // greater than the threshold becomes one and every other value becomes
    // zero. Without a threshold, `x` is wrapped as-is and nothing is computed.
    fn binarize<'a, D>(&'a self, x: &'a ArrayBase<D, Ix2>) -> CowArray<'a, F, Ix2>
    where
        D: Data<Elem = F>,
    {
        match self.binarize {
            Some(threshold) => {
                let indicator = |value: &F| {
                    if *value > threshold {
                        F::one()
                    } else {
                        F::zero()
                    }
                };
                CowArray::from(x.map(indicator))
            }
            None => CowArray::from(x),
        }
    }
}
impl<F, L> NaiveBayes<'_, F, L> for BernoulliNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
    // Compute, for every known class, one score per row of `x`.
    //
    // With `lp` the stored `feature_log_prob` of a class and
    // `neg = ln(1 - exp(lp))` taken element-wise, the score of a row is
    // `row · (lp - neg) + ln(prior) + sum(neg)`.
    // The returned map borrows its keys from the model's class table.
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        self.class_info
            .iter()
            .map(|(class, info)| {
                // Element-wise ln(1 - exp(lp)) for the stored feature log
                // probabilities.
                let log_complement = info
                    .feature_log_prob
                    .map(|lp| (F::one() - lp.exp()).ln());

                // Per-feature weight applied to the input values.
                let weights = &info.feature_log_prob - &log_complement;

                // Weighted sum per row, shifted by the log prior and by the
                // total of the complement terms.
                let scores = x.dot(&weights) + info.prior.ln() + log_complement.sum();
                (class, scores)
            })
            .collect()
    }
}
// Unit tests for the Bernoulli Naive Bayes classifier defined in this file.
#[cfg(test)]
mod tests {
use super::{BernoulliNb, NaiveBayes, Result};
use linfa::{
traits::{Fit, Predict},
DatasetView,
};
use crate::{BernoulliNbParams, BernoulliNbValidParams};
use approx::assert_abs_diff_eq;
use ndarray::array;
use std::collections::HashMap;
// Compile-time check that the fitted model and both hyper-parameter types are
// `Send + Sync + Sized + Unpin`; the test body does no runtime work.
#[test]
fn autotraits() {
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
has_autotraits::<BernoulliNb<f64, usize>>();
has_autotraits::<BernoulliNbValidParams<f64, usize>>();
has_autotraits::<BernoulliNbParams<f64, usize>>();
}
// Fits on four two-feature samples (two per class) with binarization turned
// off, then checks that predicting the training records reproduces the labels
// and that the per-class joint log-likelihoods match the expected values.
#[test]
fn test_bernoulli_nb() -> Result<()> {
let x = array![[1., 0.], [0., 0.], [1., 1.], [0., 1.]];
let y = array![1, 1, 2, 2];
let data = DatasetView::new(x.view(), y.view());
// Explicitly disable binarization; the inputs are already 0/1.
let params = BernoulliNb::params().binarize(None);
let fitted_clf = params.fit(&data)?;
// The fitted model must carry the disabled threshold over from the params.
// NOTE(review): the leading `&` on the asserted expression looks redundant.
assert!(&fitted_clf.binarize.is_none());
let pred = fitted_clf.predict(&x);
assert_abs_diff_eq!(pred, y);
let jll = fitted_clf.joint_log_likelihood(x.view());
// Expected values are written as plain likelihoods and converted to logs,
// one entry per sample, keyed by class label.
let mut expected = HashMap::new();
expected.insert(
&1usize,
(array![0.1875f64, 0.1875, 0.0625, 0.0625]).map(|v| v.ln()),
);
expected.insert(
&2usize,
(array![0.0625f64, 0.0625, 0.1875, 0.1875,]).map(|v| v.ln()),
);
for (key, value) in jll.iter() {
assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
}
Ok(())
}
// Fits with the default hyper-parameters on count-like features (values above
// one occur), predicts a single unseen row and checks the exponentiated joint
// log-likelihood of both classes.
// NOTE(review): presumably the default parameters binarize the counts; confirm
// against `BernoulliNbParams::new`.
#[test]
fn test_text_class() -> Result<()> {
let train = array![
[2., 1., 0., 0., 0., 0.0f64],
[2., 0., 1., 0., 0., 0.],
[1., 0., 0., 1., 0., 0.],
[1., 0., 0., 0., 1., 1.],
];
let y = array![1, 1, 1, 2];
let test = array![[3., 0., 0., 0., 1., 1.0f64]];
let data = DatasetView::new(train.view(), y.view());
let fitted_clf = BernoulliNb::params().fit(&data)?;
let pred = fitted_clf.predict(&test);
assert_abs_diff_eq!(pred, array![2]);
// `joint_log_likelihood` does not binarize on its own, so the test row is
// passed through the model's `binarize` step first.
let jll = fitted_clf.joint_log_likelihood(fitted_clf.binarize(&test).view());
assert_abs_diff_eq!(jll.get(&1).unwrap()[0].exp(), 0.005, epsilon = 1e-3);
assert_abs_diff_eq!(jll.get(&2).unwrap()[0].exp(), 0.022, epsilon = 1e-3);
Ok(())
}
}