use crate::error::PramanaError;
use serde::{Deserialize, Serialize};
/// The three probabilities needed for one application of Bayes' theorem:
/// posterior P(A|B) = P(B|A) * P(A) / P(B).
///
/// Construct via [`BayesianUpdate::new`], which validates every field's
/// range; the type is `#[non_exhaustive]`, so it cannot be built with a
/// struct literal outside this crate.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct BayesianUpdate {
// Prior probability P(A); `new` requires it to lie in [0, 1].
pub prior: f64,
// Likelihood P(B|A); `new` requires it to lie in [0, 1].
pub likelihood: f64,
// Marginal evidence P(B); `new` requires it to be non-zero and in [0, 1].
pub evidence: f64,
}
impl BayesianUpdate {
    /// Builds a validated update from the three probabilities of Bayes'
    /// theorem.
    ///
    /// # Errors
    ///
    /// Returns [`PramanaError::InvalidParameter`] when `evidence` is zero,
    /// or when any of `prior`, `likelihood`, or `evidence` falls outside
    /// `[0, 1]` (NaN is rejected by the range check as well).
    pub fn new(prior: f64, likelihood: f64, evidence: f64) -> Result<Self, PramanaError> {
        // The zero check comes first so callers get the more specific
        // division-related message for evidence == 0.
        if evidence == 0.0 {
            return Err(PramanaError::InvalidParameter(
                "evidence P(B) must be non-zero".into(),
            ));
        }
        // All three inputs are probabilities; validate them uniformly,
        // in the same order as the original field declarations.
        let checks = [
            (prior, "prior must be in [0, 1]"),
            (likelihood, "likelihood must be in [0, 1]"),
            (evidence, "evidence must be in [0, 1]"),
        ];
        for (value, message) in checks {
            // `contains` is false for NaN, so NaN inputs are rejected here.
            if !(0.0..=1.0).contains(&value) {
                return Err(PramanaError::InvalidParameter(message.into()));
            }
        }
        Ok(Self {
            prior,
            likelihood,
            evidence,
        })
    }
    /// Posterior probability P(A|B) = P(B|A) * P(A) / P(B).
    ///
    /// Division is safe: `new` rejected a zero `evidence`.
    #[must_use]
    #[inline]
    pub fn posterior(&self) -> f64 {
        self.likelihood * self.prior / self.evidence
    }
}
/// Computes the posterior probability P(A|B) = P(B|A) * P(A) / P(B).
///
/// Inputs are validated the same way as in [`BayesianUpdate::new`]: all
/// three must lie in `[0, 1]` and `evidence` must be non-zero.
///
/// # Errors
///
/// Returns [`PramanaError::InvalidParameter`] when `evidence` is zero or
/// when any argument is outside `[0, 1]` (including NaN, which previously
/// slipped past the zero check and produced a NaN result).
#[must_use = "returns the posterior probability"]
pub fn bayes_theorem(prior: f64, likelihood: f64, evidence: f64) -> Result<f64, PramanaError> {
    // Keep the zero check first so existing callers see the same error
    // message for evidence == 0 as before.
    if evidence == 0.0 {
        return Err(PramanaError::InvalidParameter(
            "evidence P(B) must be non-zero".into(),
        ));
    }
    // Range checks mirror BayesianUpdate::new; `contains` is false for NaN,
    // so NaN and negative/oversized probabilities are rejected instead of
    // silently yielding a garbage quotient.
    if !(0.0..=1.0).contains(&prior) {
        return Err(PramanaError::InvalidParameter(
            "prior must be in [0, 1]".into(),
        ));
    }
    if !(0.0..=1.0).contains(&likelihood) {
        return Err(PramanaError::InvalidParameter(
            "likelihood must be in [0, 1]".into(),
        ));
    }
    if !(0.0..=1.0).contains(&evidence) {
        return Err(PramanaError::InvalidParameter(
            "evidence must be in [0, 1]".into(),
        ));
    }
    Ok(likelihood * prior / evidence)
}
/// Classifies `features` with a log-space naive Bayes rule, returning the
/// index of the class maximizing the unnormalized posterior
/// `P(class) * prod_i P(feature_i | class)`.
///
/// `class_likelihoods[c][i]` is the likelihood of feature `i` under class
/// `c`. Classes with a non-positive prior are skipped; non-positive
/// likelihoods are floored at `1e-300` so a single zero does not make every
/// class equally impossible.
///
/// # Errors
///
/// * [`PramanaError::InvalidSample`] when `class_priors` is empty, or when
///   no class has a positive prior (previously this silently returned
///   class 0, an arbitrary answer).
/// * [`PramanaError::DimensionMismatch`] when `class_priors` and
///   `class_likelihoods` differ in length, or when any class's likelihood
///   vector differs in length from `features`.
#[must_use = "returns the predicted class index"]
pub fn naive_bayes_classify(
    features: &[f64],
    class_priors: &[f64],
    class_likelihoods: &[Vec<f64>],
) -> Result<usize, PramanaError> {
    if class_priors.is_empty() {
        return Err(PramanaError::InvalidSample(
            "need at least one class".into(),
        ));
    }
    if class_priors.len() != class_likelihoods.len() {
        return Err(PramanaError::DimensionMismatch(
            "class_priors and class_likelihoods must have the same length".into(),
        ));
    }
    for (i, likelihoods) in class_likelihoods.iter().enumerate() {
        if likelihoods.len() != features.len() {
            return Err(PramanaError::DimensionMismatch(format!(
                "class {} likelihoods length {} != features length {}",
                i,
                likelihoods.len(),
                features.len()
            )));
        }
    }
    // Track the winner explicitly; `None` means no class has yet beaten the
    // strict `>` comparison, so an all-invalid input is detectable instead
    // of silently defaulting to class 0.
    let mut best_class: Option<usize> = None;
    let mut best_log_posterior = f64::NEG_INFINITY;
    for (c, prior) in class_priors.iter().enumerate() {
        // A zero/negative prior can never win and would poison the log.
        if *prior <= 0.0 {
            continue;
        }
        // Work in log space to avoid underflow when many small likelihoods
        // are multiplied together.
        let mut log_posterior = prior.ln();
        for &likelihood in &class_likelihoods[c] {
            if likelihood <= 0.0 {
                // Floor at a tiny constant rather than taking ln of a
                // non-positive value (which would be NaN or -inf).
                log_posterior += 1e-300_f64.ln();
            } else {
                log_posterior += likelihood.ln();
            }
        }
        if log_posterior > best_log_posterior {
            best_log_posterior = log_posterior;
            best_class = Some(c);
        }
    }
    best_class.ok_or_else(|| {
        PramanaError::InvalidSample("no class has a positive prior".into())
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point equality checks.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_bayes_theorem() {
        // P(A|B) = 0.9 * 0.01 / 0.1 = 0.09
        let posterior = bayes_theorem(0.01, 0.9, 0.1).unwrap();
        assert!((posterior - 0.09).abs() < EPS);
    }

    #[test]
    fn test_bayesian_update() {
        // posterior() should reproduce the textbook formula exactly.
        let update = BayesianUpdate::new(0.5, 0.8, 0.6).unwrap();
        assert!((update.posterior() - 0.8 * 0.5 / 0.6).abs() < EPS);
    }

    #[test]
    fn test_bayes_zero_evidence() {
        // Division by zero evidence must be rejected, not yield inf/NaN.
        assert!(bayes_theorem(0.5, 0.5, 0.0).is_err());
    }

    #[test]
    fn test_naive_bayes() {
        // Class 0 dominates both features, so it should be predicted.
        let predicted = naive_bayes_classify(
            &[1.0, 0.5],
            &[0.5, 0.5],
            &[vec![0.9, 0.8], vec![0.1, 0.2]],
        )
        .unwrap();
        assert_eq!(predicted, 0);
    }

    #[test]
    fn serde_roundtrip() {
        // JSON round-trip must preserve every field bit-for-bit.
        let original = BayesianUpdate::new(0.3, 0.7, 0.4).unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: BayesianUpdate = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.prior, decoded.prior);
        assert_eq!(original.likelihood, decoded.likelihood);
        assert_eq!(original.evidence, decoded.evidence);
    }
}