use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip};
use ndarray_stats::QuantileExt;
use std::collections::HashMap;
use crate::error::{NaiveBayesError, Result};
use linfa::dataset::{AsTargets, DatasetBase, Labels};
use linfa::traits::FitWith;
use linfa::{Float, Label};
/// Shared prediction logic for Naive Bayes classifiers.
///
/// Implementors only supply [`NaiveBayes::joint_log_likelihood`]; class
/// prediction and (log-)probability estimates are derived from it by the
/// default methods of this trait.
pub trait NaiveBayes<'a, F, L>
where
    F: Float,
    L: Label + Ord,
{
    /// Computes, for every class, the joint log-likelihood of each row of `x`.
    ///
    /// The returned map is keyed by class label; each value holds one entry per
    /// row of `x`.
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>>;

    /// Predicts a class label for each row of `x`, writing the result into `y`.
    ///
    /// Each sample is assigned the class with the largest joint log-likelihood.
    /// Classes are ordered by label before comparing, so ties resolve to the
    /// smallest label rather than depending on `HashMap` iteration order.
    ///
    /// # Panics
    ///
    /// Panics if `x.nrows() != y.len()`, or if no maximum can be determined for
    /// a sample (no classes at all, or a NaN likelihood).
    #[doc(hidden)]
    fn predict_inplace<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );

        let joint_log_likelihood = self.joint_log_likelihood(x.view());

        // Sort so the row order (and therefore argmax tie-breaking) is deterministic.
        let mut classes: Vec<&L> = joint_log_likelihood.keys().copied().collect();
        classes.sort();

        // One row per class, one column per sample.
        let mut likelihood = Array2::<F>::zeros((classes.len(), x.nrows()));
        Zip::from(likelihood.rows_mut())
            .and(&classes)
            .for_each(|mut row, &class| row.assign(&joint_log_likelihood[class]));

        // For each sample (column), pick the class whose likelihood is largest.
        *y = likelihood.map_axis(Axis(0), |scores| {
            let best = scores
                .argmax()
                .expect("joint log-likelihoods must be non-empty and free of NaN");
            classes[best].clone()
        });
    }

    /// Estimates the log-probability of every class for each row of `x`.
    ///
    /// Returns a `(n_samples, n_classes)` matrix together with the class labels
    /// in ascending order; column `j` of the matrix corresponds to `classes[j]`.
    /// Rows are normalised with a max-shifted log-sum-exp, so very large or very
    /// negative joint log-likelihoods do not overflow/underflow into `inf`/NaN.
    fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
        let log_likelihood = self.joint_log_likelihood(x);

        let mut classes = log_likelihood.keys().copied().collect::<Vec<_>>();
        classes.sort();

        let n_samples = x.nrows();
        let n_classes = classes.len();
        let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
        Zip::from(log_prob_mat.columns_mut())
            .and(&classes)
            .for_each(|mut jll, &class| jll.assign(&log_likelihood[class]));

        // log(sum_j exp(a_j)) = m + log(sum_j exp(a_j - m)) with m = max_j a_j.
        // Shifting by the row maximum keeps the largest term at exp(0) = 1, so
        // the sum can neither underflow to 0 nor overflow to inf.
        let log_prob_x = log_prob_mat
            .map_axis(Axis(1), |row| {
                let max = row.fold(F::neg_infinity(), |m, &v| if v > m { v } else { m });
                if !max.is_finite() {
                    // All entries -inf (or an +inf entry): shifting would give
                    // inf - inf = NaN, and the unshifted result is `max` anyway.
                    return max;
                }
                max + row.fold(F::zero(), |acc, &v| acc + (v - max).exp()).ln()
            })
            .insert_axis(Axis(1));

        // Broadcast the per-sample normaliser across the class columns.
        (log_prob_mat - log_prob_x, classes)
    }

    /// Estimates the probability of every class for each row of `x`.
    ///
    /// Exponentiates the output of [`NaiveBayes::predict_log_proba`]; the
    /// returned class labels are in ascending order and match the columns.
    fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
        let (log_prob_mat, classes) = self.predict_log_proba(x);
        (log_prob_mat.mapv(|v| v.exp()), classes)
    }
}
/// Validated Naive Bayes hyperparameters that can be fitted to a dataset.
///
/// Provides a `fit` entry point on top of the implementor's
/// [`FitWith::fit_with`].
pub(crate) trait NaiveBayesValidParams<'a, F, L, D, T>:
    FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsTargets<Elem = L> + Labels<Elem = L>,
{
    /// Fits a model on `dataset` by delegating to [`FitWith::fit_with`].
    ///
    /// `model_none` is forwarded unchanged as the incoming model state
    /// (presumably `None` for a fresh fit — confirm against implementors).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fit_with`.
    fn fit(
        &self,
        dataset: &'a DatasetBase<ArrayBase<D, Ix2>, T>,
        model_none: Self::ObjectIn,
    ) -> Result<Self::ObjectOut> {
        self.fit_with(model_none, dataset)
    }
}