use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::sparse::{CscMatrix, CsrMatrix};
use crate::weights::{compute_sample_weights, ClassWeight};
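
/// Multinomial naive Bayes classifier for non-negative count features
/// (e.g. word counts), with additive (Lidstone/Laplace) smoothing controlled
/// by `alpha` and optional per-class reweighting via `ClassWeight`.
///
/// The fitted state (`log_probs`, `log_priors`) lives in plain `Vec`s so the
/// model can be serialized when the `serde` feature is enabled.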
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MultinomialNB {
alpha: f64,
class_weight: ClassWeight,
log_probs: Vec<Vec<f64>>,
log_priors: Vec<f64>,
n_classes: usize,
fitted: bool,
#[cfg_attr(feature = "serde", serde(default))]
_schema_version: u32,
}
impl MultinomialNB {
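    /// Creates an unfitted model with `alpha = 1.0` (Laplace smoothing) and
    /// uniform class weights.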
pub fn new() -> Self {
Self {
alpha: 1.0,
class_weight: ClassWeight::Uniform,
log_probs: Vec::new(),
log_priors: Vec::new(),
n_classes: 0,
fitted: false,
_schema_version: crate::version::SCHEMA_VERSION,
}
}
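
    /// Sets the additive smoothing parameter; `fit` and `fit_sparse` reject
    /// negative or non-finite values.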
pub fn alpha(mut self, a: f64) -> Self {
self.alpha = a;
self
}
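
    /// Sets the class-weighting scheme applied to samples during fitting.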
pub fn class_weight(mut self, cw: ClassWeight) -> Self {
self.class_weight = cw;
self
}
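
    /// Fits the model on a dense `Dataset`: accumulates sample-weighted
    /// feature counts per class, then stores smoothed log-likelihoods and
    /// log-priors.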
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        // A negative or non-finite alpha would turn the smoothed likelihoods
        // into NaN, so reject it up front.
        if !self.alpha.is_finite() || self.alpha < 0.0 {
            return Err(ScryLearnError::InvalidParameter(format!(
                "alpha must be non-negative and finite, got {}",
                self.alpha
            )));
        }
        self.n_classes = data.n_classes();
        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
        // Per-class weighted feature counts and per-class total sample weight.
        let mut feature_sum = vec![vec![0.0_f64; m]; self.n_classes];
        let mut class_weight_sum = vec![0.0_f64; self.n_classes];
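        // `data.features` is feature-major: each inner Vec holds one feature
        // across all samples, so `feat_col[i]` is feature j of sample i.
        // Labels at or above `n_classes` are skipped.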
for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
let c = target_val as usize;
if c >= self.n_classes {
continue;
}
class_weight_sum[c] += sw;
for (j, feat_col) in data.features.iter().enumerate() {
feature_sum[c][j] += sw * feat_col[i];
}
}
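        // Smoothed per-class log-likelihoods:
        // ln((count + alpha) / (class_total + alpha * n_features)).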
self.log_probs = vec![vec![0.0; m]; self.n_classes];
for (c_probs, c_sums) in self.log_probs.iter_mut().zip(feature_sum.iter()) {
let total: f64 = c_sums.iter().sum::<f64>() + self.alpha * m as f64;
for (lp, &fs) in c_probs.iter_mut().zip(c_sums.iter()) {
*lp = ((fs + self.alpha) / total).ln();
}
}
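        // Log-priors from the (possibly reweighted) class frequencies.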
let total_weight: f64 = class_weight_sum.iter().sum();
self.log_priors = class_weight_sum
.iter()
.map(|&w| (w / total_weight).ln())
.collect();
self.fitted = true;
Ok(())
}
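
    /// Sparse counterpart of `fit`. Expects a CSC matrix with one row per
    /// sample; the number of classes is inferred as `max(target) + 1`.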
#[allow(clippy::needless_range_loop)]
    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
        let n = features.n_rows();
        let m = features.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if target.len() != n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "target length {} != n_rows {}",
                target.len(),
                n
            )));
        }
        // Same guard as `fit`: a negative or non-finite alpha would poison
        // the smoothed likelihoods with NaN.
        if !self.alpha.is_finite() || self.alpha < 0.0 {
            return Err(ScryLearnError::InvalidParameter(format!(
                "alpha must be non-negative and finite, got {}",
                self.alpha
            )));
        }
        // Unlike `fit`, the class count is inferred from the largest label.
let max_class = target.iter().map(|&t| t as usize).max().unwrap_or(0);
self.n_classes = max_class + 1;
let sample_weights = compute_sample_weights(target, &self.class_weight);
let mut feature_sum = vec![vec![0.0_f64; m]; self.n_classes];
let mut class_weight_sum = vec![0.0_f64; self.n_classes];
for (&sw, &t) in sample_weights.iter().zip(target.iter()) {
let c = t as usize;
if c < self.n_classes {
class_weight_sum[c] += sw;
}
}
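        // Walk each CSC column once; entries absent from a column are zeros
        // and contribute nothing to the counts.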
for j in 0..m {
for (row_idx, val) in features.col(j).iter() {
let c = target[row_idx] as usize;
if c < self.n_classes {
feature_sum[c][j] += sample_weights[row_idx] * val;
}
}
}
self.log_probs = vec![vec![0.0; m]; self.n_classes];
for (c_probs, c_sums) in self.log_probs.iter_mut().zip(feature_sum.iter()) {
let total: f64 = c_sums.iter().sum::<f64>() + self.alpha * m as f64;
for (lp, &fs) in c_probs.iter_mut().zip(c_sums.iter()) {
*lp = ((fs + self.alpha) / total).ln();
}
}
let total_weight: f64 = class_weight_sum.iter().sum();
self.log_priors = class_weight_sum
.iter()
.map(|&w| (w / total_weight).ln())
.collect();
self.fitted = true;
Ok(())
}
    /// Index of the largest probability, returned as an `f64` class label.
    /// Ties resolve to the highest index; an empty slice maps to class 0.
    fn argmax(probs: &[f64]) -> f64 {
        probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0.0, |(idx, _)| idx as f64)
    }

    /// Predicts class labels for CSR rows by taking the argmax of
    /// `predict_proba_sparse`.
    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
        // Mirror `predict`: refuse models deserialized from an incompatible schema.
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let probas = self.predict_proba_sparse(features)?;
        Ok(probas.iter().map(|p| Self::argmax(p)).collect())
    }
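
    /// Class-membership probabilities for CSR rows; each output row is
    /// normalized to sum to 1.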
    pub fn predict_proba_sparse(&self, features: &CsrMatrix) -> Result<Vec<Vec<f64>>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok((0..features.n_rows())
            .map(|i| {
                let row = features.row(i);
                // Joint log-likelihood per class: log-prior plus the row's
                // non-zero features weighted by the class log-likelihoods.
                let mut log_probs: Vec<f64> = (0..self.n_classes)
                    .map(|c| {
                        let mut lp = self.log_priors[c];
                        for (col, val) in row.iter() {
                            // Columns beyond the training width are ignored.
                            if col < self.log_probs[c].len() {
                                lp += val * self.log_probs[c][col];
                            }
                        }
                        lp
                    })
                    .collect();
                // Log-sum-exp trick: subtract the max before exponentiating
                // so the softmax stays numerically stable.
                let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
                for lp in &mut log_probs {
                    *lp = ((*lp - max_log).exp()) / sum;
                }
                log_probs
            })
            .collect())
    }
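
    /// Predicts class labels for dense row-major samples by taking the
    /// argmax of `predict_proba`.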
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let probas = self.predict_proba(features)?;
        Ok(probas.iter().map(|p| Self::argmax(p)).collect())
    }
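
    /// Class-membership probabilities for dense row-major samples; each
    /// output row is normalized to sum to 1.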
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                let mut log_probs: Vec<f64> = (0..self.n_classes)
                    .map(|c| {
                        let mut lp = self.log_priors[c];
                        for (j, &x) in row.iter().enumerate() {
                            // Features beyond the training width are ignored.
                            if j >= self.log_probs[c].len() {
                                continue;
                            }
                            lp += x * self.log_probs[c][j];
                        }
                        lp
                    })
                    .collect();
                // Log-sum-exp normalization, as in the sparse path.
                let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
                for lp in &mut log_probs {
                    *lp = ((*lp - max_log).exp()) / sum;
                }
                log_probs
            })
            .collect())
    }
}
impl Default for MultinomialNB {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_multinomial_nb_counts() {
let features = vec![
vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
];
let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let data = Dataset::new(
features,
target,
vec!["word_a".into(), "word_b".into()],
"class",
);
let mut nb = MultinomialNB::new();
nb.fit(&data).unwrap();
let preds = nb.predict(&[vec![4.0, 0.0], vec![0.0, 5.0]]).unwrap();
assert!((preds[0] - 0.0).abs() < 1e-6, "high word_a → class 0");
assert!((preds[1] - 1.0).abs() < 1e-6, "high word_b → class 1");
}
#[test]
fn test_multinomial_nb_predict_proba() {
let features = vec![vec![5.0, 5.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 5.0]];
let target = vec![0.0, 0.0, 1.0, 1.0];
let data = Dataset::new(features, target, vec!["f0".into(), "f1".into()], "class");
let mut nb = MultinomialNB::new();
nb.fit(&data).unwrap();
let probas = nb.predict_proba(&[vec![4.0, 0.0]]).unwrap();
assert_eq!(probas[0].len(), 2);
let sum: f64 = probas[0].iter().sum();
assert!(
(sum - 1.0).abs() < 1e-9,
"probabilities must sum to 1.0, got {sum}"
);
}
#[test]
fn test_sparse_multinomial_nb_matches_dense() {
let features = vec![
vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
];
let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let data = Dataset::new(
features.clone(),
target.clone(),
vec!["w_a".into(), "w_b".into()],
"class",
);
let mut nb_dense = MultinomialNB::new();
nb_dense.fit(&data).unwrap();
let csc = CscMatrix::from_dense(&features);
let mut nb_sparse = MultinomialNB::new();
nb_sparse.fit_sparse(&csc, &target).unwrap();
let test = vec![vec![4.0, 0.0], vec![0.0, 5.0]];
let preds_dense = nb_dense.predict(&test).unwrap();
let csr = CsrMatrix::from_dense(&test);
let preds_sparse = nb_sparse.predict_sparse(&csr).unwrap();
for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
}
}
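
    // A minimal sanity check at the probability level (the test above only
    // compares predicted labels): the sparse and dense paths accumulate the
    // same counts, so the probabilities should agree to tight tolerance.
    #[test]
    fn test_sparse_proba_matches_dense() {
        let features = vec![
            vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features.clone(),
            target.clone(),
            vec!["w_a".into(), "w_b".into()],
            "class",
        );
        let mut nb_dense = MultinomialNB::new();
        nb_dense.fit(&data).unwrap();
        let mut nb_sparse = MultinomialNB::new();
        nb_sparse
            .fit_sparse(&CscMatrix::from_dense(&features), &target)
            .unwrap();
        let test = vec![vec![4.0, 1.0], vec![1.0, 5.0]];
        let probas_dense = nb_dense.predict_proba(&test).unwrap();
        let probas_sparse = nb_sparse
            .predict_proba_sparse(&CsrMatrix::from_dense(&test))
            .unwrap();
        for (dr, sr) in probas_dense.iter().zip(probas_sparse.iter()) {
            for (d, s) in dr.iter().zip(sr.iter()) {
                assert!((d - s).abs() < 1e-9, "dense={d} vs sparse={s}");
            }
        }
    }

    // Predicting before `fit` should surface `NotFitted` rather than panic.
    #[test]
    fn test_predict_before_fit_errors() {
        let nb = MultinomialNB::new();
        assert!(matches!(
            nb.predict(&[vec![1.0, 0.0]]),
            Err(ScryLearnError::NotFitted)
        ));
    }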
}