use std::f64::consts::PI;
use std::fmt;
use crate::learner::StreamingLearner;
/// Running per-class feature statistics maintained with a weighted variant
/// of Welford's online algorithm (West, 1979): a single pass yields the
/// weighted mean and the sum of weighted squared deviations per feature.
struct ClassStats {
    // Total accumulated sample weight for this class (not a raw sample count).
    count: f64,
    // Running weighted mean of each feature.
    mean: Vec<f64>,
    // Running sum of weighted squared deviations; variance = m2 / count.
    m2: Vec<f64>,
}
impl ClassStats {
    /// Creates empty statistics for `n_features` features.
    fn new(n_features: usize) -> Self {
        Self {
            count: 0.0,
            mean: vec![0.0; n_features],
            m2: vec![0.0; n_features],
        }
    }
    /// Folds one weighted observation into the running statistics.
    ///
    /// Weights are assumed positive; negative weights are not rejected but
    /// can drive `count` or `m2` negative — TODO confirm callers never pass
    /// them. Panics if `features` is longer than the vectors allocated in
    /// `new`; callers are expected to keep the feature count constant.
    #[inline]
    fn update(&mut self, features: &[f64], weight: f64) {
        // A zero-weight sample carries no information. Skipping it also
        // avoids the 0.0 / 0.0 division (NaN) that would otherwise poison
        // `mean` and `m2` when it is the very first observation.
        if weight == 0.0 {
            return;
        }
        self.count += weight;
        for (j, &fj) in features.iter().enumerate() {
            let delta = fj - self.mean[j];
            self.mean[j] += delta * weight / self.count;
            let delta2 = fj - self.mean[j];
            self.m2[j] += delta * delta2 * weight;
        }
    }
    /// Population (biased) weighted variance of feature `j`;
    /// 0.0 before any data has been accumulated.
    #[inline]
    fn variance(&self, j: usize) -> f64 {
        if self.count > 0.0 {
            self.m2[j] / self.count
        } else {
            0.0
        }
    }
}
// Deep copy: the per-feature vectors are cloned so the copy evolves
// independently of the original accumulator.
impl Clone for ClassStats {
    fn clone(&self) -> Self {
        let Self { count, mean, m2 } = self;
        Self {
            count: *count,
            mean: mean.clone(),
            m2: m2.clone(),
        }
    }
}
// Compact Debug output: reports the accumulated weight and the feature
// count rather than dumping the full mean/m2 vectors.
impl fmt::Debug for ClassStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("ClassStats");
        dbg.field("count", &self.count);
        dbg.field("n_features", &self.mean.len());
        dbg.finish()
    }
}
/// Streaming Gaussian Naive Bayes classifier.
///
/// Keeps one `ClassStats` accumulator per class label (labels are dense
/// `usize` indices) and classifies by maximum posterior under independent
/// per-feature Gaussian likelihoods.
pub struct GaussianNB {
// One running-statistics accumulator per class; the vector index is the label.
class_stats: Vec<ClassStats>,
// Number of `train_one` calls (not weight-adjusted).
samples_seen: u64,
// Feature count captured from the first training sample; `None` until then.
n_features: Option<usize>,
// Relative variance floor: multiplied by the largest observed feature
// variance and added to every variance to avoid division by zero.
var_smoothing: f64,
}
impl GaussianNB {
    /// Creates a classifier with the default variance smoothing of `1e-9`.
    #[inline]
    pub fn new() -> Self {
        Self::with_var_smoothing(1e-9)
    }
    /// Creates a classifier with a caller-chosen variance-smoothing factor.
    #[inline]
    pub fn with_var_smoothing(var_smoothing: f64) -> Self {
        Self {
            class_stats: Vec::new(),
            samples_seen: 0,
            n_features: None,
            var_smoothing,
        }
    }
}
impl GaussianNB {
    /// Number of distinct class labels seen so far.
    #[inline]
    pub fn n_classes(&self) -> usize {
        self.class_stats.len()
    }
    /// Empirical prior `P(class)` computed from accumulated sample weights.
    /// All zeros when no weight has been accumulated yet.
    pub fn class_prior(&self) -> Vec<f64> {
        let total: f64 = self.class_stats.iter().map(|s| s.count).sum();
        if total == 0.0 {
            vec![0.0; self.class_stats.len()]
        } else {
            self.class_stats.iter().map(|s| s.count / total).collect()
        }
    }
    /// Unnormalized log posterior `ln P(class) + ln P(x | class)` per class.
    ///
    /// Returns an empty vector before any training, all zeros when the
    /// accumulated weight is zero, and `-inf` for zero-weight classes.
    /// NOTE(review): indexes class means by feature position, so this
    /// assumes `features.len()` matches the trained feature count — confirm
    /// callers guarantee it.
    pub fn predict_log_proba(&self, features: &[f64]) -> Vec<f64> {
        if self.class_stats.is_empty() {
            return Vec::new();
        }
        let total_count: f64 = self.class_stats.iter().map(|s| s.count).sum();
        if total_count == 0.0 {
            return vec![0.0; self.class_stats.len()];
        }
        let smoothing = self.compute_smoothing();
        let mut scores = Vec::with_capacity(self.class_stats.len());
        for stats in &self.class_stats {
            if stats.count == 0.0 {
                scores.push(f64::NEG_INFINITY);
                continue;
            }
            let log_prior = (stats.count / total_count).ln();
            let mut log_likelihood = 0.0;
            for (j, &x) in features.iter().enumerate() {
                // Smoothed variance keeps the Gaussian density well-defined
                // even for features that were constant within a class.
                let var = stats.variance(j) + smoothing;
                log_likelihood += gaussian_log_likelihood(x, stats.mean[j], var);
            }
            scores.push(log_prior + log_likelihood);
        }
        scores
    }
    /// Posterior probabilities, normalized with the log-sum-exp trick for
    /// numerical stability; empty before any training.
    pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
        let log_probs = self.predict_log_proba(features);
        if log_probs.is_empty() {
            Vec::new()
        } else {
            log_sum_exp_normalize(&log_probs)
        }
    }
}
impl GaussianNB {
    /// Adaptive variance floor: `var_smoothing` scaled by the largest
    /// per-feature variance across all non-empty classes, clamped below by
    /// `1e-12` so smoothing never collapses to exactly zero.
    fn compute_smoothing(&self) -> f64 {
        let max_var = self
            .class_stats
            .iter()
            .filter(|stats| stats.count != 0.0)
            .flat_map(|stats| (0..stats.mean.len()).map(move |j| stats.variance(j)))
            .fold(0.0_f64, f64::max);
        self.var_smoothing * max_var.max(1e-12)
    }
    /// Grows the class table so `label` is a valid index, initializing any
    /// newly created classes with empty statistics for `n_features` features.
    fn ensure_class(&mut self, label: usize, n_features: usize) {
        if self.class_stats.len() <= label {
            self.class_stats
                .resize_with(label + 1, || ClassStats::new(n_features));
        }
    }
}
impl StreamingLearner for GaussianNB {
    /// Incorporates one weighted `(features, target)` observation.
    ///
    /// NOTE(review): the label is `target.round() as usize`, so negative
    /// targets saturate to class 0 — confirm callers only pass labels >= 0.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let n_feat = features.len();
        // Remember the feature count from the first sample only.
        self.n_features.get_or_insert(n_feat);
        let label = target.round() as usize;
        self.ensure_class(label, n_feat);
        self.class_stats[label].update(features, weight);
        self.samples_seen += 1;
    }
    /// Predicts the most probable class label (argmax of the log posterior).
    /// Ties resolve to the lowest label; an untrained model returns 0.0.
    fn predict(&self, features: &[f64]) -> f64 {
        if self.class_stats.is_empty() {
            return 0.0;
        }
        let log_probs = self.predict_log_proba(features);
        // Strict `>` keeps the first (lowest-index) class on ties and
        // ignores NaN scores, matching a plain tracking loop.
        let (best_class, _) = log_probs.iter().enumerate().fold(
            (0usize, f64::NEG_INFINITY),
            |(best_k, best_lp), (k, &lp)| {
                if lp > best_lp {
                    (k, lp)
                } else {
                    (best_k, best_lp)
                }
            },
        );
        best_class as f64
    }
    /// Number of `train_one` calls since construction or the last `reset`.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
    /// Drops all learned state; the `var_smoothing` setting is retained.
    fn reset(&mut self) {
        self.class_stats.clear();
        self.samples_seen = 0;
        self.n_features = None;
    }
}
// Deep copy: duplicates every per-class accumulator so the clone can keep
// training without affecting the original model.
impl Clone for GaussianNB {
    fn clone(&self) -> Self {
        let Self {
            class_stats,
            samples_seen,
            n_features,
            var_smoothing,
        } = self;
        Self {
            class_stats: class_stats.clone(),
            samples_seen: *samples_seen,
            n_features: *n_features,
            var_smoothing: *var_smoothing,
        }
    }
}
// Summarizes the model's shape and settings without dumping the per-class
// statistics themselves.
impl fmt::Debug for GaussianNB {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("GaussianNB");
        dbg.field("n_classes", &self.class_stats.len());
        dbg.field("samples_seen", &self.samples_seen);
        dbg.field("n_features", &self.n_features);
        dbg.field("var_smoothing", &self.var_smoothing);
        dbg.finish()
    }
}
impl Default for GaussianNB {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Log-density of the univariate normal `N(mean, var)` evaluated at `x`:
/// `-0.5 * (ln(2*pi*var) + (x - mean)^2 / var)`.
///
/// The caller must ensure `var > 0`; the classifier adds a smoothing term
/// before calling this.
#[inline]
fn gaussian_log_likelihood(x: f64, mean: f64, var: f64) -> f64 {
    let diff = x - mean;
    let log_norm = (2.0 * PI * var).ln();
    -0.5 * (log_norm + diff * diff / var)
}
/// Converts unnormalized log-probabilities into a probability distribution
/// using the log-sum-exp trick (subtracting the max avoids overflow).
///
/// When every entry is `-inf` (all classes impossible) the result falls
/// back to a uniform distribution. An empty input yields an empty output.
fn log_sum_exp_normalize(log_probs: &[f64]) -> Vec<f64> {
    let max_lp = log_probs
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);
    if max_lp == f64::NEG_INFINITY {
        let uniform = 1.0 / log_probs.len() as f64;
        return log_probs.iter().map(|_| uniform).collect();
    }
    let mut weights: Vec<f64> = log_probs.iter().map(|&lp| (lp - max_lp).exp()).collect();
    let total: f64 = weights.iter().sum();
    for w in &mut weights {
        *w /= total;
    }
    weights
}
// Diagnostic hook for the AutoML layer; GaussianNB currently exposes no
// configuration diagnostics, so this always returns `None`.
impl crate::automl::DiagnosticSource for GaussianNB {
fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
None
}
}
// Unit tests for the streaming GaussianNB classifier. They exercise the
// public API (constructors, training, prediction, priors, reset, clone,
// Debug) including the `StreamingLearner` trait surface.
#[cfg(test)]
mod tests {
use super::*;
// Fresh models start empty, with the expected default / custom smoothing.
#[test]
fn test_creation() {
let nb = GaussianNB::new();
assert_eq!(nb.n_samples_seen(), 0);
assert_eq!(nb.n_classes(), 0);
assert_eq!(nb.n_features, None);
assert!((nb.var_smoothing - 1e-9).abs() < 1e-15);
let nb2 = GaussianNB::with_var_smoothing(1e-6);
assert!((nb2.var_smoothing - 1e-6).abs() < 1e-15);
let nb3 = GaussianNB::default();
assert_eq!(nb3.n_samples_seen(), 0);
}
// Two well-separated clusters must be classified correctly near each mean.
#[test]
fn test_two_class() {
let mut nb = GaussianNB::new();
for _ in 0..50 {
nb.train(&[1.0, 1.0], 0.0);
}
for _ in 0..50 {
nb.train(&[10.0, 10.0], 1.0);
}
assert_eq!(nb.n_samples_seen(), 100);
assert_eq!(nb.n_classes(), 2);
assert_eq!(nb.predict(&[1.5, 1.5]), 0.0);
assert_eq!(nb.predict(&[9.5, 9.5]), 1.0);
}
// Multiclass: each cluster center maps back to its own label.
#[test]
fn test_three_class() {
let mut nb = GaussianNB::new();
for _ in 0..40 {
nb.train(&[0.0, 0.0], 0.0);
nb.train(&[0.1, -0.1], 0.0);
}
for _ in 0..40 {
nb.train(&[10.0, 0.0], 1.0);
nb.train(&[10.1, -0.1], 1.0);
}
for _ in 0..40 {
nb.train(&[5.0, 10.0], 2.0);
nb.train(&[5.1, 9.9], 2.0);
}
assert_eq!(nb.n_classes(), 3);
assert_eq!(nb.predict(&[0.0, 0.0]), 0.0);
assert_eq!(nb.predict(&[10.0, 0.0]), 1.0);
assert_eq!(nb.predict(&[5.0, 10.0]), 2.0);
}
// Posterior probabilities sum to 1 and favor the nearby class.
#[test]
fn test_predict_proba() {
let mut nb = GaussianNB::new();
for _ in 0..50 {
nb.train(&[1.0, 1.0], 0.0);
}
for _ in 0..50 {
nb.train(&[10.0, 10.0], 1.0);
}
let proba = nb.predict_proba(&[1.0, 1.0]);
assert_eq!(proba.len(), 2);
let sum: f64 = proba.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"probabilities should sum to 1.0, got {}",
sum,
);
assert!(
proba[0] > proba[1],
"P(class=0) should be greater for point near class 0: {:?}",
proba,
);
let proba2 = nb.predict_proba(&[10.0, 10.0]);
let sum2: f64 = proba2.iter().sum();
assert!((sum2 - 1.0).abs() < 1e-10);
assert!(proba2[1] > proba2[0]);
}
// Class priors reflect the 30/70 training split exactly.
#[test]
fn test_class_prior() {
let mut nb = GaussianNB::new();
for _ in 0..30 {
nb.train(&[1.0], 0.0);
}
for _ in 0..70 {
nb.train(&[5.0], 1.0);
}
let priors = nb.class_prior();
assert_eq!(priors.len(), 2);
assert!(
(priors[0] - 0.3).abs() < 1e-10,
"prior[0] should be 0.3, got {}",
priors[0],
);
assert!(
(priors[1] - 0.7).abs() < 1e-10,
"prior[1] should be 0.7, got {}",
priors[1],
);
}
// Streaming more class-0 data should not reduce confidence in class 0.
#[test]
fn test_incremental_update() {
let mut nb = GaussianNB::new();
nb.train(&[0.0, 0.0], 0.0);
nb.train(&[10.0, 10.0], 1.0);
let proba_early = nb.predict_proba(&[2.0, 2.0]);
let p0_early = proba_early[0];
for _ in 0..50 {
nb.train(&[0.5, 0.5], 0.0);
nb.train(&[-0.5, -0.5], 0.0);
}
let proba_later = nb.predict_proba(&[2.0, 2.0]);
let p0_later = proba_later[0];
assert!(
p0_later > p0_early || (p0_early - p0_later).abs() < 0.05,
"incremental updates should improve or maintain confidence: early={}, later={}",
p0_early,
p0_later,
);
}
// reset() returns the model to its untrained state.
#[test]
fn test_reset() {
let mut nb = GaussianNB::new();
for _ in 0..50 {
nb.train(&[1.0, 2.0], 0.0);
}
for _ in 0..50 {
nb.train(&[5.0, 6.0], 1.0);
}
assert_eq!(nb.n_samples_seen(), 100);
assert_eq!(nb.n_classes(), 2);
assert!(nb.n_features.is_some());
nb.reset();
assert_eq!(nb.n_samples_seen(), 0);
assert_eq!(nb.n_classes(), 0);
assert_eq!(nb.n_features, None);
assert!(nb.class_prior().is_empty());
assert!(nb.predict_proba(&[1.0, 2.0]).is_empty());
assert_eq!(nb.predict(&[1.0, 2.0]), 0.0);
}
// Batch prediction must agree with per-row prediction.
#[test]
fn test_predict_batch() {
let mut nb = GaussianNB::new();
for _ in 0..50 {
nb.train(&[1.0, 1.0], 0.0);
}
for _ in 0..50 {
nb.train(&[10.0, 10.0], 1.0);
}
let rows: Vec<&[f64]> = vec![&[1.0, 1.0], &[10.0, 10.0], &[5.0, 5.0]];
let batch = nb.predict_batch(&rows);
assert_eq!(batch.len(), rows.len());
for (i, row) in rows.iter().enumerate() {
let individual = nb.predict(row);
assert!(
(batch[i] - individual).abs() < 1e-12,
"batch[{}]={} != individual={}",
i,
batch[i],
individual,
);
}
}
// The learner remains usable behind a Box<dyn StreamingLearner>.
#[test]
fn test_trait_object() {
let mut nb = GaussianNB::new();
nb.train(&[1.0], 0.0);
nb.train(&[10.0], 1.0);
let mut boxed: Box<dyn StreamingLearner> = Box::new(nb);
boxed.train(&[1.0], 0.0);
assert_eq!(boxed.n_samples_seen(), 3);
let pred = boxed.predict(&[1.0]);
assert!(
pred == 0.0 || pred == 1.0,
"prediction should be a class label"
);
boxed.reset();
assert_eq!(boxed.n_samples_seen(), 0);
}
// Up-weighting class 0 shifts the posterior toward class 0 at the midpoint.
#[test]
fn test_weighted_samples() {
let mut nb_uniform = GaussianNB::new();
let mut nb_weighted = GaussianNB::new();
for i in 0..50 {
let x = 1.0 + (i as f64) * 0.1;
nb_uniform.train(&[x], 0.0);
}
for i in 0..50 {
let x = 10.0 + (i as f64) * 0.1;
nb_uniform.train(&[x], 1.0);
}
for i in 0..50 {
let x = 1.0 + (i as f64) * 0.1;
nb_weighted.train_one(&[x], 0.0, 10.0);
}
for i in 0..50 {
let x = 10.0 + (i as f64) * 0.1;
nb_weighted.train_one(&[x], 1.0, 1.0);
}
let proba_uniform = nb_uniform.predict_proba(&[5.5]);
let proba_weighted = nb_weighted.predict_proba(&[5.5]);
assert!(
proba_weighted[0] > proba_uniform[0],
"weighted class 0 should have higher posterior: weighted={}, uniform={}",
proba_weighted[0],
proba_uniform[0],
);
}
// Clones are deep: training a clone must not mutate the original.
#[test]
fn test_clone_independence() {
let mut nb = GaussianNB::new();
for _ in 0..30 {
nb.train(&[1.0, 2.0], 0.0);
}
let mut cloned = nb.clone();
assert_eq!(cloned.n_samples_seen(), nb.n_samples_seen());
for _ in 0..30 {
cloned.train(&[10.0, 20.0], 1.0);
}
assert_eq!(nb.n_samples_seen(), 30);
assert_eq!(cloned.n_samples_seen(), 60);
assert_eq!(nb.n_classes(), 1);
assert_eq!(cloned.n_classes(), 2);
}
// Debug output names the type and its summary fields.
#[test]
fn test_debug_format() {
let nb = GaussianNB::new();
let debug_str = format!("{:?}", nb);
assert!(debug_str.contains("GaussianNB"));
assert!(debug_str.contains("samples_seen"));
assert!(debug_str.contains("var_smoothing"));
}
}