#![doc = include_str!("../README.md")]
mod base_nb;
mod bernoulli_nb;
mod error;
mod gaussian_nb;
mod hyperparams;
mod multinomial_nb;
pub use base_nb::NaiveBayes;
pub use bernoulli_nb::BernoulliNb;
pub use error::{NaiveBayesError, Result};
pub use gaussian_nb::GaussianNb;
pub use hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
pub use hyperparams::{GaussianNbParams, GaussianNbValidParams};
pub use hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
pub use multinomial_nb::MultinomialNb;
use linfa::{Float, Label};
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Per-class statistics accumulated from the training samples of a single class.
///
/// The counts and log-probabilities are (re)computed by
/// `ClassHistogram::update_with_smoothing`; a `Default` histogram starts with
/// zero samples and empty arrays.
#[derive(Debug, Default, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub(crate) struct ClassHistogram<F> {
/// Number of samples folded into this histogram so far (summed over all batches).
class_count: usize,
/// Not read or written by the code in this file.
/// NOTE(review): presumably the class prior probability set by the model — confirm.
prior: F,
/// Column-wise sum of every sample of this class seen so far (one entry per feature).
feature_count: Array1<F>,
/// Log of each smoothed feature count minus the log of the normalising count,
/// as computed by `update_with_smoothing`.
feature_log_prob: Array1<F>,
}
impl<F: Float> ClassHistogram<F> {
    /// Folds a new batch of samples of this class into the histogram and recomputes
    /// the smoothed per-feature log probabilities.
    ///
    /// * `x_new` - the new samples, one row per sample and one column per feature.
    /// * `alpha` - additive smoothing constant added to every accumulated feature count.
    /// * `total_count` - chooses the normaliser for the log probabilities:
    ///   - `true`: the total number of samples seen for this class across *all*
    ///     batches, plus `2 * alpha`;
    ///   - `false`: the sum of all smoothed feature counts.
    ///
    /// An empty batch leaves the histogram completely unchanged.
    fn update_with_smoothing(&mut self, x_new: ArrayView2<F>, alpha: F, total_count: bool) {
        let n_new = x_new.nrows();
        // Nothing to fold in: keep counts and probabilities exactly as they are.
        if n_new == 0 {
            return;
        }

        let ClassHistogram {
            class_count,
            feature_count,
            feature_log_prob,
            ..
        } = self;

        // Accumulate per-feature sums. A histogram that has not seen any samples yet
        // (e.g. a `Default` one, whose array is empty) takes the new sums as-is.
        let feature_count_new: Array1<F> = x_new.sum_axis(Axis(0));
        if *class_count > 0 {
            *feature_count = feature_count_new + feature_count.view();
        } else {
            *feature_count = feature_count_new;
        }

        // Update the sample count *before* normalising. `feature_count` is cumulative,
        // so the `total_count` denominator must be cumulative too; using only this
        // batch's row count could yield probabilities above 1 after incremental fits.
        *class_count += n_new;

        let feature_count_smoothed = feature_count.mapv(|x| x + alpha);
        let count = if total_count {
            F::cast(*class_count) + alpha * F::cast(2)
        } else {
            feature_count_smoothed.sum()
        };

        // log(p_i) = log(smoothed_count_i) - log(normaliser)
        let log_count = count.ln();
        *feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - log_count);
    }
}
pub(crate) fn filter<F: Float, L: Label + Ord>(
x: ArrayView2<F>,
y: ArrayView1<L>,
ycondition: &L,
) -> Array2<F> {
let index = y
.into_iter()
.enumerate()
.filter(|&(_, y)| *ycondition == *y)
.map(|(i, _)| i)
.collect::<Vec<_>>();
let mut xsubset = Array2::zeros((index.len(), x.ncols()));
index
.into_iter()
.enumerate()
.for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));
xsubset
}