use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::partial_fit::PartialFit;
use crate::sparse::{CscMatrix, CsrMatrix};
use crate::weights::{compute_sample_weights, ClassWeight};
/// Gaussian Naive Bayes classifier: each feature is modeled per class as an
/// independent normal distribution.
///
/// Supports dense and sparse (CSC) fitting, optional class weighting, and
/// incremental training via the `PartialFit` implementation below.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct GaussianNb {
    // Per-class, per-feature means: `means[class][feature]`.
    means: Vec<Vec<f64>>,
    // Per-class, per-feature variances (smoothing epsilon already added).
    variances: Vec<Vec<f64>>,
    // Per-class prior probabilities (each class's share of total weight).
    priors: Vec<f64>,
    // Sample-weighting scheme applied during fitting.
    class_weight: ClassWeight,
    // Fraction of the largest fitted variance added to every variance so
    // zero-variance features keep a finite log-density.
    var_smoothing: f64,
    // Number of distinct classes established at fit time.
    n_classes: usize,
    // True once any of `fit`, `fit_sparse`, or `partial_fit` has succeeded.
    fitted: bool,
    // Running per-class weight totals for incremental fitting.
    partial_count: Vec<f64>,
    // Running per-class weighted feature sums for incremental fitting.
    partial_sum: Vec<Vec<f64>>,
    // Running per-class weighted sums of squared features for incremental fitting.
    partial_sum_sq: Vec<Vec<f64>>,
    // Feature count locked in by the first `partial_fit` batch.
    n_features_partial: usize,
    // Serialized-model schema version; validated at predict time.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
impl GaussianNb {
    /// Creates an untrained classifier with uniform class weights and a
    /// variance-smoothing factor of `1e-9`.
    pub fn new() -> Self {
        Self {
            means: Vec::new(),
            variances: Vec::new(),
            priors: Vec::new(),
            class_weight: ClassWeight::Uniform,
            var_smoothing: 1e-9,
            n_classes: 0,
            fitted: false,
            partial_count: Vec::new(),
            partial_sum: Vec::new(),
            partial_sum_sq: Vec::new(),
            n_features_partial: 0,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }
    /// Builder-style setter for the class weighting scheme used when fitting.
    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
        self.class_weight = cw;
        self
    }
    /// Builder-style setter for the variance-smoothing factor: the fraction
    /// of the largest fitted variance that is added to every variance so the
    /// Gaussian log-density stays finite for zero-variance features.
    pub fn var_smoothing(mut self, vs: f64) -> Self {
        self.var_smoothing = vs;
        self
    }
    /// Adds `var_smoothing * max_variance` (with the max floored at 1e-12)
    /// to every per-class variance. Shared by `fit` and `fit_sparse`.
    fn smooth_variances(&mut self) {
        let max_var = self
            .variances
            .iter()
            .flat_map(|cv| cv.iter())
            .copied()
            .fold(0.0_f64, f64::max);
        let epsilon = self.var_smoothing * max_var.max(1e-12);
        for c_var in &mut self.variances {
            for vj in c_var.iter_mut() {
                *vj += epsilon;
            }
        }
    }
    /// Index of the largest probability, returned as an `f64` class label.
    /// NaN entries compare as equal (preserving the original behavior); an
    /// empty slice maps to 0.0.
    fn argmax_label(probs: &[f64]) -> f64 {
        probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0.0, |(idx, _)| idx as f64)
    }
    /// Converts unnormalized log-probabilities to probabilities using the
    /// log-sum-exp trick for numerical stability.
    fn normalize_log_probs(mut log_probs: Vec<f64>) -> Vec<f64> {
        let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
        for lp in &mut log_probs {
            *lp = ((*lp - max_log).exp()) / sum;
        }
        log_probs
    }
    /// Fits per-class means, variances, and priors from `data`, weighting
    /// samples according to the configured [`ClassWeight`]. Dispatches to
    /// [`Self::fit_sparse`] when the dataset carries a sparse CSC matrix.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::EmptyDataset`] for an empty dataset and
    /// propagates validation errors for non-finite values.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if let Some(csc) = data.sparse_csc() {
            return self.fit_sparse(csc, &data.target);
        }
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        self.n_classes = data.n_classes();
        self.means = vec![vec![0.0; m]; self.n_classes];
        self.variances = vec![vec![0.0; m]; self.n_classes];
        let mat = data.matrix();
        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
        let mut w_sums = vec![0.0_f64; self.n_classes];
        // Pass 1: accumulate weighted per-class sums, then normalize to means.
        for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
            let c = target_val as usize;
            if c >= self.n_classes {
                continue; // ignore out-of-range labels rather than panic
            }
            w_sums[c] += sw;
            for j in 0..m {
                self.means[c][j] += sw * mat.get(i, j);
            }
        }
        for (c_means, &ws) in self.means.iter_mut().zip(w_sums.iter()) {
            if ws > 0.0 {
                for mj in c_means.iter_mut() {
                    *mj /= ws;
                }
            }
        }
        // Pass 2: weighted squared deviations from the class means.
        for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
            let c = target_val as usize;
            if c >= self.n_classes {
                continue;
            }
            for j in 0..m {
                let diff = mat.get(i, j) - self.means[c][j];
                self.variances[c][j] += sw * diff * diff;
            }
        }
        for (c_var, &ws) in self.variances.iter_mut().zip(w_sums.iter()) {
            if ws > 0.0 {
                for vj in c_var.iter_mut() {
                    *vj /= ws;
                }
            }
        }
        self.smooth_variances();
        // Priors are each class's share of the total sample weight.
        let w_total: f64 = w_sums.iter().sum();
        self.priors = w_sums.iter().map(|&ws| ws / w_total).collect();
        self.fitted = true;
        Ok(())
    }
    /// Fitted per-class feature means (`[class][feature]`).
    pub fn class_means(&self) -> &[Vec<f64>] {
        &self.means
    }
    /// Fitted per-class feature variances, smoothing epsilon included.
    pub fn class_variances(&self) -> &[Vec<f64>] {
        &self.variances
    }
    /// Fitted per-class prior probabilities.
    pub fn class_priors(&self) -> &[f64] {
        &self.priors
    }
    /// Fits from a sparse CSC feature matrix and a dense target vector.
    ///
    /// Zero entries are not stored in the matrix, so the variance pass adds
    /// back the contribution of the implicit zeros per class and feature,
    /// making the result equal to fitting the densified data.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::EmptyDataset`] when `features` has no rows
    /// and [`ScryLearnError::InvalidParameter`] when `target` length does not
    /// match the row count.
    #[allow(clippy::needless_range_loop)]
    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
        let n = features.n_rows();
        let m = features.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if target.len() != n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "target length {} != n_rows {}",
                target.len(),
                n
            )));
        }
        // Classes are inferred from the largest label present.
        let max_class = target.iter().map(|&t| t as usize).max().unwrap_or(0);
        self.n_classes = max_class + 1;
        self.means = vec![vec![0.0; m]; self.n_classes];
        self.variances = vec![vec![0.0; m]; self.n_classes];
        let sample_weights = compute_sample_weights(target, &self.class_weight);
        let mut w_sums = vec![0.0_f64; self.n_classes];
        for (i, &t) in target.iter().enumerate() {
            let c = t as usize;
            if c < self.n_classes {
                w_sums[c] += sample_weights[i];
            }
        }
        // Means: only stored (non-zero) entries contribute to the sums.
        for j in 0..m {
            for (row_idx, val) in features.col(j).iter() {
                let c = target[row_idx] as usize;
                if c < self.n_classes {
                    self.means[c][j] += sample_weights[row_idx] * val;
                }
            }
        }
        for (c_means, &ws) in self.means.iter_mut().zip(w_sums.iter()) {
            if ws > 0.0 {
                for mj in c_means.iter_mut() {
                    *mj /= ws;
                }
            }
        }
        // Variances over the stored entries, tracking how much weight per
        // (class, feature) came from non-zeros.
        let mut nnz_weight_per_class_feat = vec![vec![0.0; m]; self.n_classes];
        for j in 0..m {
            for (row_idx, val) in features.col(j).iter() {
                let c = target[row_idx] as usize;
                if c < self.n_classes {
                    let sw = sample_weights[row_idx];
                    let diff = val - self.means[c][j];
                    self.variances[c][j] += sw * diff * diff;
                    nnz_weight_per_class_feat[c][j] += sw;
                }
            }
        }
        // Implicit zeros: each contributes (0 - mean)^2 = mean^2, scaled by
        // the weight not accounted for by stored entries.
        for c in 0..self.n_classes {
            for j in 0..m {
                let zero_weight = w_sums[c] - nnz_weight_per_class_feat[c][j];
                if zero_weight > 0.0 {
                    self.variances[c][j] += zero_weight * self.means[c][j] * self.means[c][j];
                }
            }
        }
        for (c_var, &ws) in self.variances.iter_mut().zip(w_sums.iter()) {
            if ws > 0.0 {
                for vj in c_var.iter_mut() {
                    *vj /= ws;
                }
            }
        }
        self.smooth_variances();
        let w_total: f64 = w_sums.iter().sum();
        self.priors = w_sums.iter().map(|&ws| ws / w_total).collect();
        self.fitted = true;
        Ok(())
    }
    /// Predicts class labels for each row of a sparse CSR matrix.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] before training; also fails if
    /// the serialized schema version is incompatible (matching `predict`).
    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
        // Consistency fix: enforce the same schema check as `predict` so a
        // stale deserialized model cannot be used via the sparse path.
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let probas = self.predict_proba_sparse(features)?;
        Ok(probas.iter().map(|probs| Self::argmax_label(probs)).collect())
    }
    /// Class membership probabilities for each row of a sparse CSR matrix.
    ///
    /// Log-densities are first computed as if every feature were zero, then
    /// corrected for the explicitly stored entries, so the per-row cost
    /// scales with the number of non-zeros rather than with `n_features`.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] before training; also fails on
    /// a schema-version mismatch (matching `predict`).
    #[allow(clippy::needless_range_loop)]
    pub fn predict_proba_sparse(&self, features: &CsrMatrix) -> Result<Vec<Vec<f64>>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n_feat = self.means[0].len();
        Ok((0..features.n_rows())
            .map(|i| {
                let row = features.row(i);
                // Baseline: log prior plus the log-density of x = 0 for every
                // feature; (0 - mean)^2 / var == mean^2 / var.
                let mut log_probs: Vec<f64> = (0..self.n_classes)
                    .map(|c| {
                        let mut log_p = self.priors[c].ln();
                        for j in 0..n_feat {
                            let mean = self.means[c][j];
                            let var = self.variances[c][j];
                            log_p +=
                                -0.5 * (mean * mean / var + var.ln() + std::f64::consts::TAU.ln());
                        }
                        log_p
                    })
                    .collect();
                // Swap the zero-valued contribution for the true one at each
                // stored entry.
                for (col, val) in row.iter() {
                    if col >= n_feat {
                        continue; // ignore columns beyond the trained width
                    }
                    for c in 0..self.n_classes {
                        let mean = self.means[c][col];
                        let var = self.variances[c][col];
                        log_probs[c] -=
                            -0.5 * (mean * mean / var + var.ln() + std::f64::consts::TAU.ln());
                        log_probs[c] += -0.5
                            * ((val - mean).powi(2) / var + var.ln() + std::f64::consts::TAU.ln());
                    }
                }
                Self::normalize_log_probs(log_probs)
            })
            .collect())
    }
    /// Predicts a class label for each dense sample row.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] before training; also fails on
    /// a schema-version mismatch.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let probas = self.predict_proba(features)?;
        Ok(probas.iter().map(|probs| Self::argmax_label(probs)).collect())
    }
    /// Class membership probabilities for each dense sample row.
    ///
    /// Sums the log prior and per-feature Gaussian log-densities for every
    /// class, then normalizes with log-sum-exp. Extra trailing features
    /// beyond the trained width are ignored.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] before training; also fails on
    /// a schema-version mismatch (matching `predict`).
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                let log_probs: Vec<f64> = (0..self.n_classes)
                    .map(|c| {
                        let mut log_p = self.priors[c].ln();
                        for (j, &x) in row.iter().enumerate() {
                            if j < self.means[c].len() {
                                let mean = self.means[c][j];
                                let var = self.variances[c][j];
                                log_p += -0.5
                                    * ((x - mean).powi(2) / var
                                        + var.ln()
                                        + std::f64::consts::TAU.ln());
                            }
                        }
                        log_p
                    })
                    .collect();
                Self::normalize_log_probs(log_probs)
            })
            .collect())
    }
}
impl Default for GaussianNb {
fn default() -> Self {
Self::new()
}
}
impl PartialFit for GaussianNb {
    /// Incrementally updates the model with one batch of samples.
    ///
    /// Running weighted counts, feature sums, and sums of squares are kept
    /// per class and the means/variances are recomputed from them after
    /// every batch, so training over several batches matches a single `fit`
    /// over the concatenated data (variances via the E[x^2] - E[x]^2
    /// identity). Labels beyond the current class range grow the model.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::EmptyDataset`] for an empty batch before
    /// initialization, [`ScryLearnError::ShapeMismatch`] when a later batch
    /// changes the feature count, and propagates validation errors for
    /// non-finite values (matching `fit`).
    fn partial_fit(&mut self, data: &Dataset) -> Result<()> {
        // Consistency fix: `fit` rejects NaN/inf but this path previously did
        // not; a single bad batch would irreversibly poison the running sums.
        data.validate_finite()?;
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            // An empty follow-up batch is a no-op; an empty first batch is an
            // error because there is nothing to initialize from.
            if self.is_initialized() {
                return Ok(());
            }
            return Err(ScryLearnError::EmptyDataset);
        }
        if !self.is_initialized() {
            self.n_classes = data.n_classes();
            self.n_features_partial = m;
            self.partial_count = vec![0.0; self.n_classes];
            self.partial_sum = vec![vec![0.0; m]; self.n_classes];
            self.partial_sum_sq = vec![vec![0.0; m]; self.n_classes];
        } else if m != self.n_features_partial {
            return Err(ScryLearnError::ShapeMismatch {
                expected: self.n_features_partial,
                got: m,
            });
        }
        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
        let mat = data.matrix();
        for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
            let c = target_val as usize;
            if c >= self.n_classes {
                // Previously unseen label: grow all per-class accumulators.
                let new_n = c + 1;
                self.partial_count.resize(new_n, 0.0);
                self.partial_sum.resize(new_n, vec![0.0; m]);
                self.partial_sum_sq.resize(new_n, vec![0.0; m]);
                self.n_classes = new_n;
            }
            self.partial_count[c] += sw;
            for j in 0..m {
                let x = mat.get(i, j);
                self.partial_sum[c][j] += sw * x;
                self.partial_sum_sq[c][j] += sw * x * x;
            }
        }
        // Rebuild the fitted parameters from the running sums.
        self.means = vec![vec![0.0; m]; self.n_classes];
        self.variances = vec![vec![0.0; m]; self.n_classes];
        for c in 0..self.n_classes {
            if self.partial_count[c] > 0.0 {
                let cnt = self.partial_count[c];
                for j in 0..m {
                    let mean = self.partial_sum[c][j] / cnt;
                    // E[x^2] - mean^2 can dip slightly negative from rounding;
                    // clamp at zero before smoothing.
                    let var = (self.partial_sum_sq[c][j] / cnt - mean * mean).max(0.0);
                    self.means[c][j] = mean;
                    self.variances[c][j] = var;
                }
            }
        }
        // Same smoothing as `fit`: add a fraction of the largest variance so
        // zero-variance features keep a finite log-density.
        let max_var = self
            .variances
            .iter()
            .flat_map(|cv| cv.iter())
            .copied()
            .fold(0.0_f64, f64::max);
        let epsilon = self.var_smoothing * max_var.max(1e-12);
        for c_var in &mut self.variances {
            for vj in c_var.iter_mut() {
                *vj += epsilon;
            }
        }
        let w_total: f64 = self.partial_count.iter().sum();
        self.priors = self.partial_count.iter().map(|&c| c / w_total).collect();
        self.fitted = true;
        Ok(())
    }
    /// True once at least one batch has been absorbed.
    fn is_initialized(&self) -> bool {
        !self.partial_count.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the fixtures below pass features column-major —
    // `features[feature][sample]` — as evidenced by the per-feature slicing
    // in `test_partial_fit_3_batches_matches_fit`; confirm against the
    // `Dataset::new` definition. `predict`, by contrast, takes row-major
    // samples (one Vec per sample).
    #[test]
    fn test_gaussian_nb_simple() {
        // Two identical features; samples 0-2 cluster near 1.5 (class 0) and
        // samples 3-5 near 8.5 (class 1).
        let features = vec![
            vec![1.0, 1.5, 2.0, 8.0, 8.5, 9.0],
            vec![1.0, 1.5, 2.0, 8.0, 8.5, 9.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
        let mut nb = GaussianNb::new();
        nb.fit(&data).unwrap();
        // Points near each cluster center must be assigned to that cluster.
        let preds = nb.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert!((preds[0] - 0.0).abs() < 1e-6);
        assert!((preds[1] - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_sparse_gaussian_nb_matches_dense() {
        // Fitting from a CSC matrix must reproduce the dense fit exactly:
        // same class means and same predictions.
        let features = vec![
            vec![1.0, 1.5, 2.0, 8.0, 8.5, 9.0],
            vec![1.0, 1.5, 2.0, 8.0, 8.5, 9.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features.clone(),
            target.clone(),
            vec!["x".into(), "y".into()],
            "class",
        );
        let mut nb_dense = GaussianNb::new();
        nb_dense.fit(&data).unwrap();
        let csc = CscMatrix::from_dense(&features);
        let mut nb_sparse = GaussianNb::new();
        nb_sparse.fit_sparse(&csc, &target).unwrap();
        // 2 classes x 2 features: means must agree to tight tolerance.
        for c in 0..2 {
            for j in 0..2 {
                assert!(
                    (nb_dense.class_means()[c][j] - nb_sparse.class_means()[c][j]).abs() < 1e-6,
                    "Mean mismatch: class={c} feat={j}"
                );
            }
        }
        // Predictions must also agree between the dense and CSR paths.
        let test = vec![vec![1.0, 1.0], vec![9.0, 9.0]];
        let preds_dense = nb_dense.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = nb_sparse.predict_sparse(&csr).unwrap();
        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }
    #[test]
    fn test_partial_fit_is_initialized() {
        // `is_initialized` flips from false to true after the first batch.
        let mut nb = GaussianNb::new();
        assert!(!nb.is_initialized());
        let features = vec![vec![1.0, 2.0], vec![1.0, 2.0]];
        let target = vec![0.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
        nb.partial_fit(&data).unwrap();
        assert!(nb.is_initialized());
    }
    #[test]
    fn test_partial_fit_3_batches_matches_fit() {
        // Feeding the same data in three 2-sample batches must reproduce the
        // single-shot fit: identical means and identical predictions.
        let features = vec![
            vec![1.0, 1.5, 2.0, 8.0, 8.5, 9.0],
            vec![1.0, 1.5, 2.0, 8.0, 8.5, 9.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let all_data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
        let mut nb_full = GaussianNb::new();
        nb_full.fit(&all_data).unwrap();
        let mut nb_partial = GaussianNb::new();
        for i in 0..3 {
            // Slice samples s..s+2 out of each feature column.
            let s = i * 2;
            let batch = Dataset::new(
                vec![
                    vec![all_data.features[0][s], all_data.features[0][s + 1]],
                    vec![all_data.features[1][s], all_data.features[1][s + 1]],
                ],
                vec![all_data.target[s], all_data.target[s + 1]],
                vec!["x".into(), "y".into()],
                "class",
            );
            nb_partial.partial_fit(&batch).unwrap();
        }
        // Means should agree essentially to machine precision.
        for c in 0..2 {
            for j in 0..2 {
                assert!(
                    (nb_full.class_means()[c][j] - nb_partial.class_means()[c][j]).abs() < 1e-10,
                    "Mean mismatch: class={c}, feat={j}: {} vs {}",
                    nb_full.class_means()[c][j],
                    nb_partial.class_means()[c][j],
                );
            }
        }
        let test = vec![vec![1.0, 1.0], vec![9.0, 9.0]];
        let preds_full = nb_full.predict(&test).unwrap();
        let preds_partial = nb_partial.predict(&test).unwrap();
        assert_eq!(preds_full, preds_partial);
    }
    #[test]
    fn test_partial_fit_classifies_correctly() {
        // Each batch contains only one class; the combined model must still
        // separate the two clusters.
        let mut nb = GaussianNb::new();
        let b1 = Dataset::new(
            vec![vec![1.0, 1.5, 2.0], vec![1.0, 1.5, 2.0]],
            vec![0.0, 0.0, 0.0],
            vec!["x".into(), "y".into()],
            "class",
        );
        let b2 = Dataset::new(
            vec![vec![8.0, 8.5, 9.0], vec![8.0, 8.5, 9.0]],
            vec![1.0, 1.0, 1.0],
            vec!["x".into(), "y".into()],
            "class",
        );
        nb.partial_fit(&b1).unwrap();
        nb.partial_fit(&b2).unwrap();
        let preds = nb.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert!((preds[0] - 0.0).abs() < 1e-6, "x=1 should be class 0");
        assert!((preds[1] - 1.0).abs() < 1e-6, "x=9 should be class 1");
    }
    #[test]
    fn test_zero_variance_features() {
        // Constant features have zero raw variance; var_smoothing must keep
        // every variance a normal (non-subnormal, non-zero) float so the
        // log-density stays finite.
        let features = vec![
            vec![5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
            vec![5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
        let mut nb = GaussianNb::new();
        nb.fit(&data).unwrap();
        for c_var in nb.class_variances() {
            for &v in c_var {
                assert!(v > 1e-300, "variance {v} is subnormal — smoothing failed");
                assert!(v.is_normal(), "variance must be a normal float");
            }
        }
        // Both on-center and off-center points must yield finite predictions.
        let preds = nb.predict(&[vec![5.0, 5.0]]).unwrap();
        assert!(preds[0].is_finite(), "prediction should be finite");
        let preds2 = nb.predict(&[vec![10.0, 10.0]]).unwrap();
        assert!(
            preds2[0].is_finite(),
            "off-center prediction should be finite"
        );
    }
    #[test]
    fn test_auto_dispatch_sparse_fit() {
        // A Dataset built from a CSC matrix must route `fit` through
        // `fit_sparse` automatically and still classify correctly.
        use crate::sparse::CscMatrix;
        let features = vec![vec![1.0, 2.0, 8.0, 9.0], vec![1.0, 2.0, 8.0, 9.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let csc = CscMatrix::from_dense(&features);
        let data = crate::dataset::Dataset::from_sparse(
            csc,
            target,
            vec!["x".into(), "y".into()],
            "class",
        );
        let mut nb = GaussianNb::new();
        nb.fit(&data).unwrap();
        let preds = nb.predict(&[vec![1.5, 1.5], vec![8.5, 8.5]]).unwrap();
        assert!(
            (preds[0] - 0.0).abs() < 1e-6,
            "Auto-dispatch sparse: x=1.5 should be class 0"
        );
        assert!(
            (preds[1] - 1.0).abs() < 1e-6,
            "Auto-dispatch sparse: x=8.5 should be class 1"
        );
    }
}