use super::evaluation::evaluate;
use super::{Case, Rule};
use itertools::Itertools;
use ndarray::{ArrayView, Ix1, Ix2, Zip};
use rustc_hash::FxHasher;
use std::collections::HashMap;
use std::hash::{BuildHasherDefault, Hash};
/// Picks the most accurate single-attribute rule for a data set.
///
/// One candidate `Rule` is generated per column of `attributes`; the candidate
/// whose `accuracy` field is greatest is returned together with the index of
/// the column it was generated from.
///
/// Returns `None` when no candidates are generated (i.e. `attributes` has no
/// columns). When two candidates compare equal, or their accuracies cannot be
/// ordered, the candidate from the later column is kept.
pub fn discover<A, C>(
    attributes: &ArrayView<A, Ix2>,
    classes: &ArrayView<C, Ix1>,
) -> Option<(usize, Rule<A, C>)>
where
    A: Eq + Hash + Clone + std::fmt::Debug,
    C: Eq + Hash + Clone + std::fmt::Debug,
{
    let mut best: Option<(usize, Rule<A, C>)> = None;
    for (column, candidate) in generate_hypotheses(attributes, classes).into_iter().enumerate() {
        // The incumbent survives only if it is strictly better than the
        // newcomer; equal or incomparable accuracies favour the newcomer.
        let incumbent_is_better = match &best {
            Some((_, incumbent)) => {
                incumbent.accuracy.partial_cmp(&candidate.accuracy)
                    == Some(std::cmp::Ordering::Greater)
            }
            None => false,
        };
        if !incumbent_is_better {
            best = Some((column, candidate));
        }
    }
    best
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::Accuracy;
    use ndarray::prelude::*;

    /// Checks that `discover` selects the second column (index 1) of a small
    /// ten-row data set and reports the expected cases and accuracy for it.
    #[test]
    fn test1() {
        // Ten samples with three categorical attributes each.
        let samples = array![
            ["good", "small", "yes"],
            ["good", "big", "no"],
            ["good", "big", "no"],
            ["bad", "medium", "no"],
            ["good", "medium", "only cats"],
            ["good", "small", "only cats"],
            ["bad", "medium", "yes"],
            ["bad", "small", "yes"],
            ["bad", "medium", "yes"],
            ["bad", "small", "no"],
        ];
        // One class label per sample, in the same row order.
        let labels = array![
            "high", "high", "high", "medium", "medium", "medium", "medium", "low", "low", "low",
        ];

        // Expected (attribute value, predicted class) pairs, in the order the
        // attribute values first appear in the second column.
        let expected_cases: Vec<_> = [("small", "low"), ("big", "high"), ("medium", "medium")]
            .iter()
            .map(|&(attribute_value, predicted_class)| Case { attribute_value, predicted_class })
            .collect();
        let expected = Some((1, Rule { cases: expected_cases, accuracy: Accuracy(0.7) }));

        let actual = discover(&samples.view(), &labels.view());

        assert_eq!(actual, expected);
    }
}
/// Builds one candidate rule for every column of `attributes`.
///
/// The returned vector is in column order, so the position of a rule is the
/// index of the attribute column it was derived from.
fn generate_hypotheses<A: Eq + Hash + Clone, C: Eq + Hash + Clone>(
    attributes: &ArrayView<A, Ix2>,
    classes: &ArrayView<C, Ix1>,
) -> Vec<Rule<A, C>> {
    attributes
        .gencolumns()
        .into_iter()
        .map(|column| generate_rule_for_attribute(&column, classes))
        .collect()
}
/// Derives a rule from a single attribute column.
///
/// For every distinct value in `attribute_values` (visited in order of first
/// appearance) the classes of the rows carrying that value are tallied, and
/// the most frequent class becomes the prediction for that value. When several
/// classes share the top count, the winner is decided by the tally map's
/// iteration order.
///
/// The resulting cases are then scored with `evaluate` against the same
/// column and classes, and that score is stored as the rule's `accuracy`.
///
/// NOTE(review): `attribute_values` and `classes` are zipped element-wise, so
/// they are presumably expected to have the same length — confirm with callers.
fn generate_rule_for_attribute<A, C>(
    attribute_values: &ArrayView<A, Ix1>,
    classes: &ArrayView<C, Ix1>,
) -> Rule<A, C>
where
    A: Eq + Hash + Clone,
    C: Eq + Hash + Clone,
{
    let cases: Vec<Case<A, C>> = attribute_values
        .iter()
        .unique()
        .filter_map(|value| {
            // Count how often each class occurs among rows holding `value`.
            let mut tally = HashMap::with_hasher(BuildHasherDefault::<FxHasher>::default());
            Zip::from(attribute_values).and(classes).apply(|candidate, class| {
                if candidate == value {
                    *tally.entry(class).or_insert(0) += 1;
                }
            });

            // The class with the highest count becomes the prediction.
            tally.into_iter().max_by_key(|&(_, occurrences)| occurrences).map(|(winner, _)| {
                Case { attribute_value: value.clone(), predicted_class: winner.clone() }
            })
        })
        .collect();

    let accuracy = evaluate(&cases, attribute_values, classes);
    Rule { cases, accuracy }
}