use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::weights::{compute_sample_weights, ClassWeight};
/// Bernoulli Naive Bayes classifier over binary (or binarized) features.
///
/// Built with [`BernoulliNB::new`] and the chained `alpha` / `binarize` /
/// `class_weight` builder methods, then trained via `fit`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct BernoulliNB {
    // Additive (Laplace/Lidstone) smoothing parameter; 1.0 by default.
    alpha: f64,
    // Threshold mapping each feature to {0, 1} (`value > threshold` -> 1).
    // `None` feeds raw feature values into the weighted counts unchanged.
    binarize: Option<f64>,
    // Strategy for weighting samples per class during `fit`.
    class_weight: ClassWeight,
    // Per-class, per-feature smoothed parameters p(feature = 1 | class).
    // NOTE(review): despite the name these are stored as plain probabilities,
    // not logs — `predict_proba` applies `ln` at prediction time.
    log_probs: Vec<Vec<f64>>,
    // ln of the weighted class frequencies, one entry per class.
    log_priors: Vec<f64>,
    // Number of classes observed at fit time.
    n_classes: usize,
    // Set to true by a successful `fit`; prediction errors otherwise.
    fitted: bool,
    // Schema tag for serialized models; checked on `predict`.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
impl BernoulliNB {
/// Creates a classifier with default settings: Laplace smoothing
/// (`alpha = 1.0`), binarization at 0.0, and uniform class weights.
pub fn new() -> Self {
    let schema = crate::version::SCHEMA_VERSION;
    Self {
        // sklearn-style defaults.
        alpha: 1.0,
        binarize: Some(0.0),
        class_weight: ClassWeight::Uniform,
        // Model state stays empty until `fit` populates it.
        n_classes: 0,
        log_probs: Vec::new(),
        log_priors: Vec::new(),
        fitted: false,
        _schema_version: schema,
    }
}
pub fn alpha(mut self, a: f64) -> Self {
self.alpha = a;
self
}
pub fn binarize(mut self, threshold: Option<f64>) -> Self {
self.binarize = threshold;
self
}
pub fn class_weight(mut self, cw: ClassWeight) -> Self {
self.class_weight = cw;
self
}
/// Fits the classifier on `data` (features accessed column-major).
///
/// Learns, per class `c` and feature `j`, the smoothed Bernoulli parameter
/// p(x_j = 1 | c) along with weighted log class priors. Note: despite its
/// name, `log_probs` stores the raw probabilities; `predict_proba` applies
/// `ln` at prediction time.
///
/// # Errors
/// Returns an error when the data contains non-finite values or is empty.
pub fn fit(&mut self, data: &Dataset) -> Result<()> {
    data.validate_finite()?;
    let n = data.n_samples();
    let m = data.n_features();
    if n == 0 {
        return Err(ScryLearnError::EmptyDataset);
    }
    self.n_classes = data.n_classes();
    let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
    // Weighted count of "feature on" events per (class, feature), and the
    // total sample weight accumulated per class.
    let mut feature_count = vec![vec![0.0_f64; m]; self.n_classes];
    let mut class_weight_sum = vec![0.0_f64; self.n_classes];
    for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
        // Guard BEFORE the cast: a negative f64 `as usize` saturates to 0
        // in Rust, which would silently fold invalid labels into class 0.
        if target_val < 0.0 {
            continue;
        }
        let c = target_val as usize;
        if c >= self.n_classes {
            // Out-of-range labels are skipped, consistent with the guard above.
            continue;
        }
        class_weight_sum[c] += sw;
        for (j, feat_col) in data.features.iter().enumerate() {
            let val = feat_col[i];
            // Map to {0, 1} when a binarize threshold is set; otherwise the
            // raw value contributes directly to the weighted count.
            let binary = self
                .binarize
                .map_or(val, |thresh| if val > thresh { 1.0 } else { 0.0 });
            feature_count[c][j] += sw * binary;
        }
    }
    // Smoothed Bernoulli parameter: (count + alpha) / (total + 2 * alpha).
    self.log_probs = vec![vec![0.0; m]; self.n_classes];
    for c in 0..self.n_classes {
        for (lp, &cnt) in self.log_probs[c].iter_mut().zip(feature_count[c].iter()) {
            *lp = (cnt + self.alpha) / (class_weight_sum[c] + 2.0 * self.alpha);
        }
    }
    // Log-priors from the weighted class frequencies.
    let total_weight: f64 = class_weight_sum.iter().sum();
    self.log_priors = class_weight_sum
        .iter()
        .map(|&w| (w / total_weight).ln())
        .collect();
    self.fitted = true;
    Ok(())
}
/// Predicts the most probable class index (returned as `f64`) for each
/// row-major input sample.
///
/// # Errors
/// Fails on a schema-version mismatch or if the model has not been fitted.
pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
    crate::version::check_schema_version(self._schema_version)?;
    if !self.fitted {
        return Err(ScryLearnError::NotFitted);
    }
    let mut labels = Vec::with_capacity(features.len());
    for probs in self.predict_proba(features)? {
        // Argmax with last-wins tie-breaking, matching `Iterator::max_by`:
        // the incumbent survives only while it compares strictly greater
        // (incomparable values are treated as equal).
        let mut best: Option<(usize, f64)> = None;
        for (idx, &p) in probs.iter().enumerate() {
            best = match best {
                Some((bi, bp))
                    if bp.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal)
                        == std::cmp::Ordering::Greater =>
                {
                    Some((bi, bp))
                }
                _ => Some((idx, p)),
            };
        }
        // An empty probability row (no classes) defaults to class 0.
        labels.push(best.map_or(0.0, |(idx, _)| idx as f64));
    }
    Ok(labels)
}
pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
if !self.fitted {
return Err(ScryLearnError::NotFitted);
}
Ok(features
.iter()
.map(|row| {
let mut log_probs: Vec<f64> = (0..self.n_classes)
.map(|c| {
let mut lp = self.log_priors[c];
for (j, &x) in row.iter().enumerate() {
if j >= self.log_probs[c].len() {
continue;
}
let binary =
self.binarize
.map_or(x, |thresh| if x > thresh { 1.0 } else { 0.0 });
let p = self.log_probs[c][j];
if binary > 0.5 {
lp += p.ln();
} else {
lp += (1.0 - p).ln();
}
}
lp
})
.collect();
let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
for lp in &mut log_probs {
*lp = ((*lp - max_log).exp()) / sum;
}
log_probs
})
.collect())
}
}
impl Default for BernoulliNB {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bernoulli_nb_binary() {
        // Two column-major features over six samples, perfectly separable.
        let cols = vec![
            vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
        ];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let ds = Dataset::new(cols, labels, vec!["f0".into(), "f1".into()], "class");
        let mut model = BernoulliNB::new().binarize(Some(0.5));
        model.fit(&ds).unwrap();
        let out = model.predict(&[vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
        assert!((out[0] - 0.0).abs() < 1e-6, "should predict class 0");
        assert!((out[1] - 1.0).abs() < 1e-6, "should predict class 1");
    }

    #[test]
    fn test_bernoulli_nb_binarize() {
        // Continuous inputs; the 0.5 threshold must binarize them first.
        let cols = vec![
            vec![0.9, 0.8, 0.7, 0.1, 0.2, 0.3],
            vec![0.1, 0.2, 0.3, 0.9, 0.8, 0.7],
        ];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let ds = Dataset::new(cols, labels, vec!["f0".into(), "f1".into()], "class");
        let mut model = BernoulliNB::new().binarize(Some(0.5));
        model.fit(&ds).unwrap();
        let out = model.predict(&[vec![0.8, 0.1], vec![0.1, 0.9]]).unwrap();
        assert!((out[0] - 0.0).abs() < 1e-6);
        assert!((out[1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bernoulli_nb_predict_proba() {
        let cols = vec![vec![1.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 1.0]];
        let labels = vec![0.0, 0.0, 1.0, 1.0];
        let ds = Dataset::new(cols, labels, vec!["f0".into(), "f1".into()], "class");
        let mut model = BernoulliNB::new().binarize(Some(0.5));
        model.fit(&ds).unwrap();
        let probas = model.predict_proba(&[vec![1.0, 0.0]]).unwrap();
        // One probability per class, normalized to 1.
        assert_eq!(probas[0].len(), 2);
        let sum: f64 = probas[0].iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "probabilities must sum to 1.0, got {sum}"
        );
    }
}