use crate::error::Result;
use crate::primitives::Matrix;
/// Bernoulli naive Bayes classifier over binarized features.
///
/// Every input value is treated as "present" when it is strictly greater than `binarize`.
/// Per-class presence probabilities are estimated with additive smoothing `alpha`.
/// Configure with [`BernoulliNB::with_alpha`] / [`BernoulliNB::with_binarize`], then call
/// `fit` and `predict`.
#[derive(Debug, Clone)]
pub struct BernoulliNB {
    /// Additive smoothing added to every presence count (and `2 * alpha` to each class total).
    alpha: f32,
    /// Threshold: a feature value counts as present when it is strictly greater than this.
    binarize: f32,
    /// Natural log of each class's relative frequency in the training labels, indexed by class.
    class_log_prior: Vec<f32>,
    /// `feature_prob[c][j]`: smoothed probability that feature `j` is present given class `c`.
    feature_prob: Vec<Vec<f32>>,
    /// Number of feature columns recorded by the last completed `fit` (0 before fitting).
    n_features: usize,
}
impl Default for BernoulliNB {
fn default() -> Self {
Self::new()
}
}
impl BernoulliNB {
    /// Floor applied to both `p` and `1 - p` before taking logarithms in `predict`.
    ///
    /// With this floor, a feature the model considers impossible for a class contributes a
    /// large finite penalty (`ln(1e-9)`) instead of `-inf`.
    const PROB_FLOOR: f32 = 1e-9;

    /// Creates an unfitted classifier with `alpha = 1.0` (Laplace smoothing) and
    /// `binarize = 0.0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            alpha: 1.0,
            binarize: 0.0,
            class_log_prior: Vec::new(),
            feature_prob: Vec::new(),
            n_features: 0,
        }
    }

    /// Sets the additive smoothing parameter used by [`BernoulliNB::fit`].
    ///
    /// The value must be finite and non-negative. This is checked when fitting, not here.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the binarization threshold.
    ///
    /// An input value is treated as "feature present" when it is strictly greater than
    /// `binarize`.
    #[must_use]
    pub fn with_binarize(mut self, binarize: f32) -> Self {
        self.binarize = binarize;
        self
    }

    /// Fits the model from a sample matrix and integer class labels.
    ///
    /// `x` has one row per sample, and `y[i]` is the class of row `i`. The fitted classes are
    /// `0..=max(y)`. A label value in that range that never occurs gets a prior of
    /// `ln(0) = -inf`, so it is never predicted.
    ///
    /// # Errors
    ///
    /// Returns an error if `x` has no rows, if `y.len()` differs from the number of rows, or
    /// if `alpha` is negative or not finite.
    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
        let (n_samples, n_features) = x.shape();
        if n_samples == 0 {
            return Err("BernoulliNB: cannot fit with zero samples".into());
        }
        if y.len() != n_samples {
            return Err("BernoulliNB: x/y length mismatch".into());
        }
        // A negative or NaN alpha produces "probabilities" outside [0, 1] and NaN logs.
        if !(self.alpha.is_finite() && self.alpha >= 0.0) {
            return Err("BernoulliNB: alpha must be finite and non-negative".into());
        }

        let n_classes = y.iter().max().map_or(0, |&m| m + 1);
        let mut class_count = vec![0usize; n_classes];
        // present[c][j] = number of class-c samples whose feature j exceeds `binarize`.
        let mut present = vec![vec![0usize; n_features]; n_classes];
        for (i, &c) in y.iter().enumerate() {
            class_count[c] += 1;
            for (j, slot) in present[c].iter_mut().enumerate() {
                if x.get(i, j) > self.binarize {
                    *slot += 1;
                }
            }
        }

        let alpha = f64::from(self.alpha);
        let total = n_samples as f64;
        self.class_log_prior = class_count
            .iter()
            .map(|&n| (n as f64 / total).ln() as f32)
            .collect();
        self.feature_prob = class_count
            .iter()
            .zip(&present)
            .map(|(&n_c, counts)| {
                // Each Bernoulli feature has two outcomes, hence `2 * alpha` in the denominator.
                let denom = n_c as f64 + 2.0 * alpha;
                if denom == 0.0 {
                    // Empty class with alpha == 0. The ratio would be 0/0 = NaN. This class's
                    // prior is already -inf, so an uninformative finite value is safe.
                    return vec![0.5; counts.len()];
                }
                counts
                    .iter()
                    .map(|&k| ((k as f64 + alpha) / denom) as f32)
                    .collect()
            })
            .collect();
        self.n_features = n_features;
        Ok(())
    }

    /// Predicts the most likely class for every row of `x`.
    ///
    /// Each row is binarized with the `binarize` threshold and scored per class as
    /// `log prior + Σ_j ln(p_cj if present else 1 - p_cj)`. Each probability is floored at
    /// `1e-9` before the log. The highest score wins, and ties go to the lowest class index.
    /// An unfitted model returns class `0` for every row.
    ///
    /// # Panics
    ///
    /// Panics if the model has been fitted and `x` has a different number of columns than
    /// the training data.
    #[must_use]
    pub fn predict(&self, x: &Matrix<f32>) -> Vec<usize> {
        let (n_samples, n_cols) = x.shape();
        if !self.class_log_prior.is_empty() {
            assert_eq!(
                n_cols, self.n_features,
                "BernoulliNB: feature count differs between fit and predict"
            );
        }
        (0..n_samples)
            .map(|i| {
                let mut best_c = 0;
                let mut best_ll = f32::NEG_INFINITY;
                for (c, (&prior, probs)) in self
                    .class_log_prior
                    .iter()
                    .zip(&self.feature_prob)
                    .enumerate()
                {
                    let mut ll = prior;
                    for (j, &p) in probs.iter().enumerate() {
                        // Pick the matching term rather than computing b*ln(p) + (1-b)*ln(1-p).
                        // With p == 1.0 that form yields 0 * -inf = NaN. In f32, `1.0 - 1e-9`
                        // rounds to 1.0, so clamping p alone cannot prevent this.
                        let prob = if x.get(i, j) > self.binarize { p } else { 1.0 - p };
                        ll += prob.max(Self::PROB_FLOOR).ln();
                    }
                    if ll > best_ll {
                        best_ll = ll;
                        best_c = c;
                    }
                }
                best_c
            })
            .collect()
    }
}
/// Converts a float-encoded class label to a class index by rounding to the nearest integer.
///
/// Returns `None` for NaN, infinities, and values that round to a negative number. A plain
/// `as usize` cast would silently saturate those to `0` or `usize::MAX`.
fn label_from_f32(v: f32) -> Option<usize> {
    let r = v.round();
    (r.is_finite() && r >= 0.0).then_some(r as usize)
}

impl crate::traits::Estimator for BernoulliNB {
    /// Fits from float-encoded labels, rounding each to the nearest class index.
    ///
    /// # Errors
    ///
    /// Returns an error if any label is NaN or infinite, or rounds to a negative value.
    /// Also returns every error produced by the inherent `BernoulliNB::fit`.
    fn fit(&mut self, x: &Matrix<f32>, y: &crate::primitives::Vector<f32>) -> Result<()> {
        let Some(labels) = y
            .as_slice()
            .iter()
            .copied()
            .map(label_from_f32)
            .collect::<Option<Vec<usize>>>()
        else {
            return Err("BernoulliNB: labels must be finite and non-negative".into());
        };
        BernoulliNB::fit(self, x, &labels)
    }

    /// Predicts class indices and returns them as `f32` values.
    fn predict(&self, x: &Matrix<f32>) -> crate::primitives::Vector<f32> {
        let labels = BernoulliNB::predict(self, x);
        crate::primitives::Vector::from_vec(labels.into_iter().map(|l| l as f32).collect())
    }

    /// Classification accuracy: the fraction of `y` whose rounded label equals the prediction.
    ///
    /// Returns `0.0` when `y` is empty. A target that is not a valid class index (NaN,
    /// infinite, or negative) never counts as correct.
    fn score(&self, x: &Matrix<f32>, y: &crate::primitives::Vector<f32>) -> f32 {
        let preds = BernoulliNB::predict(self, x);
        let n = y.len();
        if n == 0 {
            return 0.0;
        }
        let correct = preds
            .iter()
            .zip(y.as_slice())
            .filter(|&(&p, &t)| label_from_f32(t) == Some(p))
            .count();
        correct as f32 / n as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four binary samples, two per class. The model must reproduce its training labels and
    /// also classify two held-out rows that resemble each class.
    #[test]
    fn bernoulli_nb_matches_sklearn() {
        let train_rows = vec![
            1.0, 0.0, 1.0, 0.0, //
            1.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 1.0, //
            0.0, 1.0, 1.0, 1.0,
        ];
        let train = Matrix::from_vec(4, 4, train_rows).expect("valid");
        let labels = [0usize, 0, 1, 1];

        let mut model = BernoulliNB::new();
        model.fit(&train, &labels).expect("fit");
        assert_eq!(model.predict(&train), vec![0, 0, 1, 1]);

        let held_out_rows = vec![
            1.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 1.0,
        ];
        let held_out = Matrix::from_vec(2, 4, held_out_rows).expect("valid");
        assert_eq!(model.predict(&held_out), vec![0, 1]);
    }
}