use std::collections::{HashMap, HashSet};
use crate::metrics::accuracy;
/// Gaussian naive Bayes classifier over dense `f64` feature rows and `i32`
/// class labels.
///
/// Every map is keyed by class label and is populated by `fit`; a freshly
/// created value holds empty maps and `n_features == 0`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct NaiveBayesClassifier {
    /// Prior probability of each class: its training-sample count divided by
    /// the total number of training samples.
    pub priors: HashMap<i32, f64>,
    /// Per-class mean of every feature (one entry per feature column).
    pub feature_means: HashMap<i32, Vec<f64>>,
    /// Per-class population variance of every feature (divides by the class
    /// count, not by `count - 1`).
    pub feature_vars: HashMap<i32, Vec<f64>>,
    /// Number of training samples observed for each class.
    pub class_counts: HashMap<i32, i32>,
    /// Number of feature columns seen during fitting (length of the first
    /// training row).
    pub n_features: i32,
}
impl NaiveBayesClassifier {
    /// Lower bound applied to every fitted variance.
    ///
    /// A feature that is constant within a class has zero variance, which
    /// would make the Gaussian density divide by zero and poison every
    /// prediction with NaN. Flooring the variance keeps the density finite.
    const MIN_VARIANCE: f64 = 1e-9;

    /// Creates an unfitted classifier with empty parameters.
    pub fn new() -> NaiveBayesClassifier {
        NaiveBayesClassifier {
            priors: HashMap::new(),
            feature_means: HashMap::new(),
            feature_vars: HashMap::new(),
            class_counts: HashMap::new(),
            n_features: 0,
        }
    }

    /// Estimates class priors and per-class Gaussian parameters.
    ///
    /// Any parameters from a previous call are discarded first, so refitting
    /// never leaves stale classes behind.
    ///
    /// * `x_train` - one row per sample, every row with the same length.
    /// * `y_train` - class label of each row; any `i32` values are accepted.
    ///
    /// # Panics
    ///
    /// Panics if `x_train` is empty, if `x_train` and `y_train` differ in
    /// length, or if the rows of `x_train` do not all have the same length.
    pub fn fit(&mut self, x_train: &Vec<Vec<f64>>, y_train: &Vec<i32>) {
        assert!(
            !x_train.is_empty(),
            "x_train must contain at least one sample"
        );
        assert_eq!(
            x_train.len(),
            y_train.len(),
            "x_train and y_train must contain the same number of samples"
        );
        let n_samples = x_train.len();
        let n_features = x_train[0].len();
        assert!(
            x_train.iter().all(|row| row.len() == n_features),
            "every row of x_train must have the same number of features"
        );

        self.priors.clear();
        self.feature_means.clear();
        self.feature_vars.clear();
        self.class_counts.clear();
        self.n_features = n_features as i32;

        let classes: HashSet<i32> = y_train.iter().copied().collect();
        for class in classes {
            let rows: Vec<&Vec<f64>> = x_train
                .iter()
                .zip(y_train.iter())
                .filter(|&(_, &label)| label == class)
                .map(|(row, _)| row)
                .collect();
            let count = rows.len() as f64;

            let means: Vec<f64> = (0..n_features)
                .map(|j| rows.iter().map(|row| row[j]).sum::<f64>() / count)
                .collect();
            // Two-pass population variance: summing squared deviations from
            // the mean avoids the cancellation of E[x^2] - E[x]^2 and can
            // never come out negative.
            let vars: Vec<f64> = means
                .iter()
                .enumerate()
                .map(|(j, &mean)| {
                    let sq_dev: f64 = rows.iter().map(|row| (row[j] - mean).powi(2)).sum();
                    (sq_dev / count).max(Self::MIN_VARIANCE)
                })
                .collect();

            self.class_counts.insert(class, rows.len() as i32);
            self.priors.insert(class, count / n_samples as f64);
            self.feature_means.insert(class, means);
            self.feature_vars.insert(class, vars);
        }
    }

    /// Returns the posterior probability of every class for every sample.
    ///
    /// Each inner vector has one entry per fitted class, ordered by ascending
    /// class label (for labels `0..k` the column index equals the label), and
    /// sums to 1.
    ///
    /// # Panics
    ///
    /// Panics if the model has not been fitted or if a row's length differs
    /// from the number of features seen during fitting.
    pub fn predict_proba(&self, x_test: &Vec<Vec<f64>>) -> Vec<Vec<f64>> {
        self.predict_log_proba(x_test)
            .into_iter()
            .map(|log_probs| log_probs.into_iter().map(f64::exp).collect())
            .collect()
    }

    /// Predicts the most probable class label for every sample.
    ///
    /// The decision is taken on log-posteriors, so it stays correct even when
    /// the plain probabilities would underflow to zero. Ties go to the
    /// smallest class label.
    ///
    /// # Panics
    ///
    /// Same conditions as `predict_log_proba`.
    pub fn predict(&self, x_test: &Vec<Vec<f64>>) -> Vec<i32> {
        let log_probs = self.predict_log_proba(x_test);
        let classes = self.sorted_classes();
        log_probs
            .iter()
            .map(|row| {
                let mut best = 0;
                for (idx, &log_prob) in row.iter().enumerate() {
                    if log_prob > row[best] {
                        best = idx;
                    }
                }
                classes[best]
            })
            .collect()
    }

    /// Returns the natural log of the posterior probability of every class
    /// for every sample, with the same column layout as `predict_proba`.
    ///
    /// Computed directly in log space (log-sum-exp normalisation) rather than
    /// as `ln(predict_proba(..))`, so tiny probabilities do not collapse to
    /// negative infinity.
    ///
    /// # Panics
    ///
    /// Panics if the model has not been fitted or if a row's length differs
    /// from the number of features seen during fitting.
    pub fn predict_log_proba(&self, x_test: &Vec<Vec<f64>>) -> Vec<Vec<f64>> {
        assert!(
            !self.priors.is_empty(),
            "NaiveBayesClassifier must be fitted before predicting"
        );
        let classes = self.sorted_classes();
        let expected_features = self.n_features as usize;
        x_test
            .iter()
            .map(|x| {
                if x.len() != expected_features {
                    panic!("Number of features in x_test is not equal to number of features in x_train");
                }
                let mut log_probs: Vec<f64> = classes
                    .iter()
                    .map(|&class| self.joint_log_likelihood(class, x))
                    .collect();
                Self::normalize_log_probs(&mut log_probs);
                log_probs
            })
            .collect()
    }

    /// Predicts labels for `x_test` and returns the result of
    /// `crate::metrics::accuracy` applied to the predictions and `y_test`.
    pub fn score(&self, x_test: &Vec<Vec<f64>>, y_test: &Vec<i32>) -> HashMap<String, f64> {
        let preds = self.predict(x_test);
        accuracy(&preds, &y_test)
    }

    /// Returns the fitted parameters as debug-formatted strings keyed by
    /// field name (`priors`, `feature_means`, `feature_vars`, `class_counts`,
    /// `n_features`).
    ///
    /// Map entries are rendered in ascending class order so the output is
    /// deterministic across runs.
    pub fn get_params(&self) -> HashMap<String, String> {
        vec![
            ("priors".to_string(), Self::sorted_debug(&self.priors)),
            (
                "feature_means".to_string(),
                Self::sorted_debug(&self.feature_means),
            ),
            (
                "feature_vars".to_string(),
                Self::sorted_debug(&self.feature_vars),
            ),
            (
                "class_counts".to_string(),
                Self::sorted_debug(&self.class_counts),
            ),
            ("n_features".to_string(), format!("{:?}", self.n_features)),
        ]
        .into_iter()
        .collect()
    }

    /// Class labels known to the model, in ascending order; this order
    /// defines the column layout of the probability matrices.
    fn sorted_classes(&self) -> Vec<i32> {
        let mut classes: Vec<i32> = self.priors.keys().copied().collect();
        classes.sort_unstable();
        classes
    }

    /// Log of `P(class) * prod_j N(x_j | mean_j, var_j)` for a single sample.
    fn joint_log_likelihood(&self, class: i32, x: &[f64]) -> f64 {
        let means = &self.feature_means[&class];
        let vars = &self.feature_vars[&class];
        let log_density: f64 = x
            .iter()
            .zip(means.iter().zip(vars.iter()))
            .map(|(&value, (&mean, &var))| {
                -(value - mean).powi(2) / (2.0 * var)
                    - 0.5 * (2.0 * std::f64::consts::PI * var).ln()
            })
            .sum();
        self.priors[&class].ln() + log_density
    }

    /// Shifts joint log-likelihoods so that their exponentials sum to 1.
    fn normalize_log_probs(log_probs: &mut [f64]) {
        let max = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if !max.is_finite() {
            // Every class is impossible (or the input was not finite);
            // there is nothing meaningful to normalise against.
            return;
        }
        // Subtracting the maximum before exponentiating keeps the largest
        // term at exp(0) = 1, so the sum can neither overflow nor vanish.
        let log_norm = max + log_probs.iter().map(|lp| (lp - max).exp()).sum::<f64>().ln();
        for lp in log_probs.iter_mut() {
            *lp -= log_norm;
        }
    }

    /// Debug-formats a class-keyed map with its entries sorted by class.
    fn sorted_debug<V: std::fmt::Debug>(map: &HashMap<i32, V>) -> String {
        format!(
            "{:?}",
            map.iter().collect::<std::collections::BTreeMap<_, _>>()
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Six two-feature samples: class 0 is centred on (2, 3) and class 1 on
    /// (5, 6), each with the same spread and the same number of samples.
    fn training_data() -> (Vec<Vec<f64>>, Vec<i32>) {
        let x_train = vec![
            vec![1.0, 2.0],
            vec![2.0, 3.0],
            vec![3.0, 4.0],
            vec![4.0, 5.0],
            vec![5.0, 6.0],
            vec![6.0, 7.0],
        ];
        let y_train = vec![0, 0, 0, 1, 1, 1];
        (x_train, y_train)
    }

    #[test]
    fn test_naive_bayes() {
        let (x_train, y_train) = training_data();
        let x_test = vec![
            vec![-1.0, 0.0],
            vec![4.0, 5.0],
            vec![7.0, 7.0],
            vec![1.0, 9.0],
        ];
        let mut model = NaiveBayesClassifier::new();
        model.fit(&x_train, &y_train);
        let preds = model.predict(&x_test);
        // Both classes share the same variance and prior, so the decision
        // reduces to the nearest class mean. (1, 9) is at squared distance 25
        // from (5, 6) but 37 from (2, 3), hence class 1.
        assert_eq!(preds, vec![0, 1, 1, 1]);
    }

    #[test]
    fn test_fit_stores_class_statistics() {
        let (x_train, y_train) = training_data();
        let mut model = NaiveBayesClassifier::new();
        model.fit(&x_train, &y_train);

        assert_eq!(model.n_features, 2);
        assert_eq!(model.class_counts[&0], 3);
        assert_eq!(model.class_counts[&1], 3);
        assert_eq!(model.priors[&0], 0.5);
        assert_eq!(model.priors[&1], 0.5);
        assert_eq!(model.feature_means[&0], vec![2.0, 3.0]);
        assert_eq!(model.feature_means[&1], vec![5.0, 6.0]);
        // Population variance of three consecutive integers is 2/3.
        for class in [0, 1].iter() {
            for var in model.feature_vars[class].iter() {
                assert!((var - 2.0 / 3.0).abs() < 1e-9);
            }
        }
    }
}