use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
/// Distribution used in the Naive Bayes classifier.
///
/// Implementors (Bernoulli, Categorical, Gaussian, Multinomial — see the
/// submodules declared at the bottom of this file) supply the per-class
/// prior and log-likelihood that `BaseNaiveBayes::predict` combines.
pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
/// Prior probability of the class at `class_index` (indexes into `classes()`).
fn prior(&self, class_index: usize) -> T;
/// Log-likelihood of the sample `j` under the class at `class_index`.
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T;
/// Class labels known to this distribution, in the index order used by
/// `prior`/`log_likelihood`.
/// NOTE(review): returning `&Vec<T>` instead of `&[T]` is non-idiomatic,
/// but changing it would break every implementor — left as-is.
fn classes(&self) -> &Vec<T>;
}
/// Base Naive Bayes classifier: holds a fitted [`NBDistribution`] and turns it
/// into per-row class predictions. Concrete variants (gaussian, bernoulli, …)
/// wrap this with their own distribution type.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
distribution: D,
// PhantomData markers tie the unused type parameters T and M to the struct
// so the compiler accepts the generic bounds without storing values of them.
_phantom_t: PhantomData<T>,
_phantom_m: PhantomData<M>,
}
impl<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> BaseNaiveBayes<T, M, D> {
pub fn fit(distribution: D) -> Result<Self, Failed> {
Ok(Self {
distribution,
_phantom_t: PhantomData,
_phantom_m: PhantomData,
})
}
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let y_classes = self.distribution.classes();
let (rows, _) = x.shape();
let predictions = (0..rows)
.map(|row_index| {
let row = x.get_row(row_index);
let (prediction, _probability) = y_classes
.iter()
.enumerate()
.map(|(class_index, class)| {
(
class,
self.distribution.log_likelihood(class_index, &row)
+ self.distribution.prior(class_index).ln(),
)
})
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
.unwrap();
*prediction
})
.collect::<Vec<T>>();
let y_hat = M::RowVector::from_array(&predictions);
Ok(y_hat)
}
}
// Concrete Naive Bayes variants, each providing an NBDistribution implementation.
pub mod bernoulli;
pub mod categorical;
pub mod gaussian;
pub mod multinomial;