use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::Keys;
use std::iter::FromIterator;
use std::vec::Vec;
/// Co-occurrence counts of attributes and labels.
///
/// Maps attribute -> label -> number of times [`Attributes::add`] was called
/// with that (attribute, label) pair.
#[derive(Debug, Default)]
struct Attributes {
    attributes: HashMap<String, HashMap<String, i64>>,
}

impl Attributes {
    /// Creates an empty co-occurrence table.
    pub fn new() -> Attributes {
        Attributes::default()
    }

    /// Increments the count of the (`attribute`, `label`) pair, creating the
    /// attribute and label entries on first use.
    fn add(&mut self, attribute: &str, label: &str) {
        *self
            .attributes
            .entry(attribute.to_string())
            // `or_default` only builds the inner map when the attribute is new.
            .or_default()
            .entry(label.to_string())
            .or_insert(0) += 1;
    }

    /// Looks up the count of the (`attribute`, `label`) pair.
    ///
    /// The second element is `true` when `attribute` has been added with any
    /// label. This lets callers distinguish an unknown attribute, `(None, false)`,
    /// from a known attribute that never occurred with `label`, `(None, true)`.
    fn get_frequency(&self, attribute: &str, label: &str) -> (Option<&i64>, bool) {
        match self.attributes.get(attribute) {
            Some(labels) => (labels.get(label), true),
            None => (None, false),
        }
    }
}
/// How many times each label has been added.
#[derive(Debug, Default)]
struct Labels {
    counts: HashMap<String, i64>,
}

impl Labels {
    /// Creates an empty label counter.
    pub fn new() -> Labels {
        Labels::default()
    }

    /// Increments the count of `label`, starting it at 1 on first use.
    fn add(&mut self, label: &str) {
        *self.counts.entry(label.to_string()).or_insert(0) += 1;
    }

    /// Count of `label`, or `None` if it was never added.
    fn get_count(&self, label: &str) -> Option<&i64> {
        self.counts.get(label)
    }

    /// Every label added so far, in unspecified (hash-map) order.
    fn get_labels(&self) -> Keys<String, i64> {
        self.counts.keys()
    }

    /// Sum of all label counts; 0 when nothing was added.
    fn get_total(&self) -> i64 {
        self.counts.values().sum()
    }
}
struct Model {
labels: Labels,
attributes: Attributes,
}
impl Model {
pub fn new() -> Model {
Model {
labels: Labels::new(),
attributes: Attributes::new(),
}
}
fn train(&mut self, data: &Vec<String>, label: &String) {
self.labels.add(label);
for attribute in data {
self.attributes.add(attribute, label);
}
}
}
/// Naive Bayes classifier over string attributes and string labels.
///
/// [`NaiveBayes::classify`] scores a document for every trained label as
/// `P(label) * Π P(attribute | label)`. [`NaiveBayes::log_classify`] scores it
/// as the natural log of that product. Scores are not normalised across labels.
pub struct NaiveBayes {
    model: Model,
    /// Probability used by `classify` for an attribute that was seen in
    /// training, but never with the label being scored.
    minimum_probability: f64,
    /// Log-probability used by `log_classify` in the same situation.
    ///
    /// NOTE(review): the default -100.0 is not `ln(1e-9)` (about -20.7), so
    /// `log_classify` penalises such attributes far harder than `classify`.
    minimum_log_probability: f64,
}

impl NaiveBayes {
    /// Creates an untrained classifier.
    pub fn new() -> NaiveBayes {
        NaiveBayes {
            model: Model::new(),
            minimum_probability: 1e-9,
            minimum_log_probability: -100.0,
        }
    }

    /// Prior `P(label)`: the label's count divided by the total of all label
    /// counts. Returns `None` if `label` is unknown or the total is zero.
    fn prior(&mut self, label: &String) -> Option<f64> {
        let total = self.model.labels.get_total() as f64;
        let count = *self.model.labels.get_count(label)? as f64;
        if total > 0.0 {
            Some(count / total)
        } else {
            None
        }
    }

    /// Natural log of the prior, computed as `ln(count) - ln(total)`.
    /// Returns `None` under the same conditions as `prior`.
    fn log_prior(&mut self, label: &String) -> Option<f64> {
        let total = self.model.labels.get_total() as f64;
        let count = *self.model.labels.get_count(label)? as f64;
        if total > 0.0 {
            Some(count.ln() - total.ln())
        } else {
            None
        }
    }

    /// Conditional probability `P(attribute | label)`:
    ///
    /// * attribute seen with `label`: its frequency divided by the label count;
    /// * attribute seen only with other labels: `minimum_probability`;
    /// * attribute never seen in training: `None`.
    fn calculate_attr_prob(&mut self, attribute: &String, label: &String) -> Option<f64> {
        let frequency = match self.model.attributes.get_frequency(attribute, label) {
            (Some(&frequency), true) => frequency,
            (None, true) => return Some(self.minimum_probability),
            // Unknown attribute (a frequency is never reported without `true`).
            (_, false) => return None,
        };
        let count = *self.model.labels.get_count(label)?;
        Some(frequency as f64 / count as f64)
    }

    /// Log-space counterpart of `calculate_attr_prob`. Uses
    /// `minimum_log_probability` for attributes seen only with other labels.
    fn calculate_attr_log_prob(&mut self, attribute: &String, label: &String) -> Option<f64> {
        let frequency = match self.model.attributes.get_frequency(attribute, label) {
            (Some(&frequency), true) => frequency,
            (None, true) => return Some(self.minimum_log_probability),
            (_, false) => return None,
        };
        let count = *self.model.labels.get_count(label)?;
        Some((frequency as f64).ln() - (count as f64).ln())
    }

    /// `P(attr | label)` for each attribute in `attrs` that the model knows.
    /// Attributes never seen in training are left out.
    fn label_prob(&mut self, label: &String, attrs: &HashSet<String>) -> Vec<f64> {
        attrs
            .iter()
            .filter_map(|attr| self.calculate_attr_prob(attr, label))
            .collect()
    }

    /// `ln P(attr | label)` for each attribute in `attrs` that the model knows.
    /// Attributes never seen in training are left out.
    fn label_log_prob(&mut self, label: &String, attrs: &HashSet<String>) -> Vec<f64> {
        attrs
            .iter()
            .filter_map(|attr| self.calculate_attr_log_prob(attr, label))
            .collect()
    }

    /// Adds one training example (`data` = its attributes) under `label`.
    pub fn train(&mut self, data: &Vec<String>, label: &String) {
        self.model.train(data, label);
    }

    /// Scores `data` against every trained label as
    /// `P(label) * Π P(attr | label)` over the distinct attributes of `data`.
    pub fn classify(&mut self, data: &Vec<String>) -> HashMap<String, f64> {
        let attribute_set: HashSet<String> = HashSet::from_iter(data.iter().cloned());
        let labels: Vec<String> = self.model.labels.get_labels().cloned().collect();
        let mut result = HashMap::with_capacity(labels.len());
        for label in labels {
            let likelihood: f64 = self.label_prob(&label, &attribute_set).iter().product();
            let prior = self
                .prior(&label)
                .expect("a label listed by the model always has a prior");
            result.insert(label, likelihood * prior);
        }
        result
    }

    /// Scores `data` against every trained label as
    /// `ln P(label) + Σ ln P(attr | label)` over the distinct attributes of `data`.
    pub fn log_classify(&mut self, data: &Vec<String>) -> HashMap<String, f64> {
        let attribute_set: HashSet<String> = HashSet::from_iter(data.iter().cloned());
        let labels: Vec<String> = self.model.labels.get_labels().cloned().collect();
        let mut result = HashMap::with_capacity(labels.len());
        for label in labels {
            // The log of a product is the sum of the logs. Log-sum-exp would
            // give ln(Σ p) instead. With no known attributes the sum is 0, so
            // the score falls back to the log-prior, matching `classify`'s
            // empty product of 1.
            let log_likelihood: f64 = self.label_log_prob(&label, &attribute_set).iter().sum();
            let log_prior = self
                .log_prior(&label)
                .expect("a label listed by the model always has a prior");
            result.insert(label, log_likelihood + log_prior);
        }
        result
    }
}

impl Default for NaiveBayes {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod test_attributes {
    use super::*;

    /// Owned copy of a string literal, used for keys in inserts and lookups.
    fn owned(text: &str) -> String {
        text.to_string()
    }

    #[test]
    fn attribute_add() {
        let mut table = Attributes::new();
        table.add(&owned("rust"), &owned("naive"));
        let count = table.get_frequency(&owned("rust"), &owned("naive")).0.copied();
        assert_eq!(count, Some(1));
    }

    #[test]
    fn get_non_existing() {
        // Bound through `&mut` so the lookup works whichever receiver it takes.
        let table = &mut Attributes::new();
        assert_eq!(table.get_frequency(&owned("rust"), &owned("naive")).0, None);
    }
}
#[cfg(test)]
mod test_labels {
    use super::*;

    // Fixtures are bound through `&mut` so every accessor can be called on
    // them, whether it borrows the value shared or exclusively.

    /// Builds a `Labels` with each entry of `names` added once, in order.
    fn labels_with(names: &[&str]) -> Labels {
        let mut labels = Labels::new();
        for &name in names {
            labels.add(&name.to_string());
        }
        labels
    }

    #[test]
    fn label_add() {
        let labels = &mut labels_with(&["rust"]);
        assert_eq!(labels.get_count(&"rust".to_string()).copied(), Some(1));
    }

    #[test]
    fn label_get_nonexistent() {
        let labels = &mut Labels::new();
        assert_eq!(labels.get_count(&"rust".to_string()), None);
    }

    #[test]
    fn get_labels() {
        let labels = &mut labels_with(&["rust"]);
        let names: Vec<&String> = labels.get_labels().collect();
        assert_eq!(names, ["rust"]);
    }

    #[test]
    fn get_counts() {
        let labels = &mut labels_with(&["rust", "rust"]);
        assert_eq!(labels.get_labels().len(), 1);
        assert_eq!(labels.get_count(&"rust".to_string()).copied(), Some(2));
    }

    #[test]
    fn get_nonexistent_counts() {
        let labels = &mut Labels::new();
        assert_eq!(labels.get_labels().len(), 0);
        assert_eq!(labels.get_count(&"rust".to_string()), None);
    }

    #[test]
    fn get_nonexistent_total() {
        assert_eq!(Labels::new().get_total(), 0);
    }

    #[test]
    fn get_total() {
        assert_eq!(labels_with(&["rust", "rust", "naive", "bayes"]).get_total(), 4);
    }
}
#[cfg(test)]
mod test_naive_bayes {
    use super::*;
    use std::f64::consts::LN_2;

    /// Owned `String`s for each word, in order.
    fn words(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| item.to_string()).collect()
    }

    /// A classifier trained on a single 👍 example.
    fn trained_on_thumbs_up() -> NaiveBayes {
        let mut nb = NaiveBayes::new();
        nb.model.train(&words(&["rust", "naive", "bayes"]), &"👍".to_string());
        nb
    }

    /// The 👍 classifier plus one 👎 example with disjoint attributes.
    fn trained_on_both() -> NaiveBayes {
        let mut nb = trained_on_thumbs_up();
        nb.model.train(&words(&["golang", "java", "javascript"]), &"👎".to_string());
        nb
    }

    /// Query shared by the classification tests: "rust" was trained only
    /// under 👍; "scala" and "c++" never appear in training.
    fn query() -> Vec<String> {
        words(&["rust", "scala", "c++"])
    }

    #[test]
    fn test_prior() {
        let mut nb = trained_on_thumbs_up();
        assert_eq!(nb.prior(&"👍".to_string()), Some(1.0));
    }

    #[test]
    fn test_log_prior() {
        let mut nb = trained_on_thumbs_up();
        assert_eq!(nb.log_prior(&"👍".to_string()), Some(0.0));
    }

    #[test]
    fn test_prior_nonexistent() {
        let mut nb = trained_on_thumbs_up();
        assert_eq!(nb.prior(&"👎".to_string()), None);
    }

    #[test]
    fn test_classification() {
        let classes = trained_on_both().classify(&query());
        assert_eq!(classes["👍"], 0.5);
        assert_eq!(classes["👎"], 5e-10);
        print!("{:?}", classes);
    }

    #[test]
    fn test_log_classification() {
        let classes = trained_on_both().log_classify(&query());
        assert_eq!(classes["👍"], -LN_2);
        assert_eq!(classes["👎"], -100.69314718055995);
        print!("{:?}", classes);
    }
}