use ndarray::*;
use itertools::Itertools;
use std::collections::HashMap;
use ordered_float::OrderedFloat;
use crate::Label;
use crate::estimates::{BayesEstimator,some_or_error};
/// Integer id of a discrete object; `ArrayToIndex::map` hands out one id per distinct
/// feature vector, while `FrequentistEstimator::from_data` receives ids directly.
type ObjectValue = usize;
/// Per-label example counter that keeps track of the most frequent label.
///
/// The prediction only moves when another label becomes strictly more frequent than the
/// currently predicted one, so ties are resolved in favour of the previous prediction.
#[derive(Debug)]
struct FrequencyCount {
    /// `count[label]` is the number of examples currently recorded with that label.
    count: Vec<usize>,
    /// Current majority label, or `None` while no example is recorded.
    prediction: Option<Label>,
}

impl FrequencyCount {
    /// Creates an empty counter for labels `0..n_labels`.
    fn new(n_labels: usize) -> FrequencyCount {
        FrequencyCount {
            prediction: None,
            count: vec![0; n_labels],
        }
    }

    /// The current majority label, if any example has been recorded.
    fn predict(&self) -> Option<Label> {
        self.prediction
    }

    /// Records one example labelled `y`; returns `true` iff the prediction changed.
    ///
    /// Panics if `y` is not a valid index into the counts.
    fn add_example(&mut self, y: Label) -> bool {
        self.count[y] += 1;
        let switch_to_y = match self.prediction {
            None => true,
            Some(current) => current != y && self.count[y] > self.count[current],
        };
        if switch_to_y {
            self.prediction = Some(y);
        }
        switch_to_y
    }

    /// Forgets one example labelled `y`; returns `true` iff the prediction changed.
    ///
    /// Assumes such an example was recorded before (the count would underflow otherwise).
    /// Only removing the predicted label can move the prediction: it goes to the last
    /// label whose count now exceeds `y`'s, or becomes `None` once every count is zero.
    fn remove_example(&mut self, y: Label) -> bool {
        self.count[y] -= 1;
        if self.prediction != Some(y) {
            return false;
        }
        let remaining = self.count[y];
        let best = self
            .count
            .iter()
            .rposition(|&c| c > remaining)
            .unwrap_or(y);
        if self.count[best] == 0 {
            self.prediction = None;
            true
        } else if best != y {
            self.prediction = Some(best);
            true
        } else {
            false
        }
    }
}
pub struct FrequentistEstimator {
joint_count: HashMap<ObjectValue, FrequencyCount>,
priors_count: FrequencyCount,
error_count: usize,
train_x: Vec<ObjectValue>,
train_y: Vec<Label>,
test_x: Vec<ObjectValue>,
test_y: Vec<Label>,
array_to_index: ArrayToIndex,
}
impl FrequentistEstimator {
pub fn new(n_labels: usize, test_x: &ArrayView2<f64>,
test_y: &ArrayView1<Label>)
-> FrequentistEstimator {
let priors_count = FrequencyCount::new(n_labels);
let mut joint_count: HashMap<ObjectValue, FrequencyCount> = HashMap::new();
let mut array_to_index = ArrayToIndex::new();
let test_x = test_x.outer_iter()
.map(|x| array_to_index.map(x))
.collect::<Vec<_>>();
for &x in test_x.iter().unique() {
joint_count.entry(x)
.or_insert_with(|| FrequencyCount::new(n_labels));
}
FrequentistEstimator {
joint_count,
priors_count,
error_count: 0,
train_x: vec![],
train_y: vec![],
test_x: test_x.to_vec(),
test_y: test_y.to_vec(),
array_to_index,
}
}
pub fn from_data(n_labels: usize, train_x: &ArrayView1<ObjectValue>,
train_y: &ArrayView1<Label>, test_x: &ArrayView1<ObjectValue>,
test_y: &ArrayView1<Label>)
-> FrequentistEstimator {
let mut joint_count: HashMap<ObjectValue, FrequencyCount> = HashMap::new();
let mut priors_count = FrequencyCount::new(n_labels);
for &x in test_x.iter().unique() {
joint_count.entry(x)
.or_insert_with(|| FrequencyCount::new(n_labels));
}
for (x, &y) in train_x.iter().zip(train_y) {
assert!(y < n_labels,
"labels' values must be < number of labels");
priors_count.add_example(y);
if let Some(jx) = joint_count.get_mut(x) {
jx.add_example(y);
}
}
let mut error_count = 0;
for (x, &y) in test_x.iter().zip(test_y) {
let jx = joint_count.get(x)
.expect("shouldn't happen");
let pred = match jx.predict() {
Some(pred) => pred,
None => priors_count.predict().expect("not enough info for priors"),
};
if y != pred {
error_count += 1;
}
}
FrequentistEstimator {
joint_count,
priors_count,
error_count,
train_x: train_x.to_vec(),
train_y: train_y.to_vec(),
test_x: test_x.to_vec(),
test_y: test_y.to_vec(),
array_to_index: ArrayToIndex::new(),
}
}
fn add_first_example(&mut self, x: ObjectValue, y: Label) {
self.error_count = 0;
self.priors_count.add_example(y);
let pred = y;
if let Some(jx) = self.joint_count.get_mut(&x) {
jx.add_example(y);
}
for yi in &self.test_y {
let error = if *yi != pred { 1 } else { 0 };
self.error_count += error;
}
}
pub fn remove_one(&mut self) -> Result<(), ()> {
let x = some_or_error(self.train_x.pop())?;
let y = some_or_error(self.train_y.pop())?;
let old_priors_pred = some_or_error(self.priors_count.predict())?;
let priors_changed = self.priors_count.remove_example(y);
if priors_changed {
let new_pred = some_or_error(self.priors_count.predict())?;
for (xi, &yi) in self.test_x.iter().zip(&self.test_y) {
let joint = self.joint_count.get(xi).expect("shouldn't happen");
if joint.predict().is_none() {
let old_error = if yi != old_priors_pred { 1 } else { 0 };
let new_error = if yi != new_pred { 1 } else { 0 };
self.error_count = self.error_count + new_error - old_error;
}
}
}
if let Some(joint) = self.joint_count.get_mut(&x) {
let old_joint_pred = joint.predict()
.expect("shouldn't fail here");
let joint_changed = joint.remove_example(y);
if joint_changed {
let new_pred = match self.priors_count.predict() {
Some(pred) => pred,
None => some_or_error(self.priors_count.predict())?,
};
for (&xi, &yi) in self.test_x.iter().zip(&self.test_y) {
if xi == x {
let old_error = if yi != old_joint_pred { 1 } else { 0 };
let new_error = if yi != new_pred { 1 } else { 0 };
self.error_count = self.error_count + new_error - old_error;
}
}
}
}
Ok(())
}
}
impl BayesEstimator for FrequentistEstimator {
fn add_example(&mut self, x: &ArrayView1<f64>, y: Label) -> Result<(), ()> {
let x = self.array_to_index.map(*x);
self.train_x.push(x);
self.train_y.push(y);
let mut old_priors_pred = match self.priors_count.predict() {
Some(pred) => pred,
None => { self.add_first_example(x, y);
return Ok(())
},
};
let priors_changed = self.priors_count.add_example(y);
if priors_changed {
let new_pred = self.priors_count.predict().unwrap();
for (xi, &yi) in self.test_x.iter().zip(&self.test_y) {
let joint = self.joint_count.get(xi).expect("shouldn't happen");
if joint.predict().is_none() {
let old_error = if yi != old_priors_pred { 1 } else { 0 };
let new_error = if yi != new_pred { 1 } else { 0 };
self.error_count = self.error_count + new_error - old_error;
}
}
old_priors_pred = new_pred;
}
if let Some(joint) = self.joint_count.get_mut(&x) {
let old_pred = match joint.predict() {
Some(pred) => pred,
None => old_priors_pred,
};
let joint_changed = joint.add_example(y);
if joint_changed {
let new_pred = joint.predict().unwrap();
for (&xi, &yi) in self.test_x.iter().zip(&self.test_y) {
if xi == x {
let old_error = if yi != old_pred { 1 } else { 0 };
let new_error = if yi != new_pred { 1 } else { 0 };
self.error_count = self.error_count + new_error - old_error;
}
}
}
}
Ok(())
}
fn get_error_count(&self) -> usize {
self.error_count
}
fn get_error(&self) -> f64 {
(self.error_count as f64) / (self.test_y.len() as f64)
}
fn get_individual_errors(&self) -> Vec<bool> {
let mut errors = Vec::with_capacity(self.test_x.len());
for (xi, &yi) in self.test_x.iter().zip(&self.test_y) {
let pred = if let Some(joint) = self.joint_count.get(&xi) {
joint.predict().unwrap()
}
else {
match self.priors_count.predict() {
Some(pred) => pred,
None => panic!("Call get_individual_errors() after training"),
}
};
errors.push(pred == yi);
}
errors
}
}
/// Assigns consecutive ids (0, 1, 2, ...) to distinct feature vectors, in order of
/// first appearance.
///
/// Entries are wrapped in `OrderedFloat` so that vectors can be hashed; two vectors share
/// an id only if they are equal entry by entry under `OrderedFloat` equality.
struct ArrayToIndex {
    /// Id already handed out for each vector seen so far.
    mapping: HashMap<Vec<OrderedFloat<f64>>, usize>,
    /// Id the next unseen vector will receive.
    next_id: usize,
}

impl ArrayToIndex {
    /// Creates a mapping that has not seen any vector yet.
    pub fn new() -> ArrayToIndex {
        ArrayToIndex {
            next_id: 0,
            mapping: HashMap::new(),
        }
    }

    /// Returns the id of `x`, giving it the next unused id if it was never seen before.
    pub fn map(&mut self, x: ArrayView1<f64>) -> ObjectValue {
        let key: Vec<OrderedFloat<f64>> = x.iter().copied().map(OrderedFloat::from).collect();
        let fresh_id = self.next_id;
        let id = *self.mapping.entry(key).or_insert(fresh_id);
        // Ids already in the map are all below `next_id`, so getting `fresh_id` back
        // means `key` was new and the counter must advance.
        if id == fresh_id {
            self.next_id += 1;
        }
        id
    }
}
// Unit tests for `FrequentistEstimator`. All three tests share one data set in which test
// object 8 never occurs in training and training object 6 never occurs in the test set.
#[cfg(test)]
mod tests {
use super::*;
// One-shot construction with `from_data`: checks the per-object (joint) and overall
// (priors) label counts, the resulting predictions, and the initial error count (3 of 7).
#[test]
fn frequentist_init() {
let n_labels = 3;
let train_x = array![0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 6];
let train_y = array![0, 1, 1, 2, 2, 1, 2, 2, 0, 1, 1, 1];
let test_x = array![0, 0, 1, 2, 2, 8, 8];
let test_y = array![1, 1, 1, 1, 2, 1, 0];
let freq = FrequentistEstimator::from_data(n_labels,
&train_x.view(),
&train_y.view(),
&test_x.view(),
&test_y.view());
assert_eq!(freq.joint_count.len(), 4);
assert_eq!(freq.priors_count.count, vec![2, 6, 4]);
assert_eq!(freq.joint_count.get(&0).unwrap().count, vec![1, 2, 0]);
assert_eq!(freq.joint_count.get(&1).unwrap().count, vec![0, 0, 2]);
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![1, 3, 2]);
assert_eq!(freq.joint_count.get(&8).unwrap().count, vec![0; 3]);
assert_eq!(freq.joint_count.get(&0).unwrap().predict().unwrap(), 1);
assert_eq!(freq.joint_count.get(&1).unwrap().predict().unwrap(), 2);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 1);
assert!(freq.joint_count.get(&8).unwrap().predict().is_none());
assert_eq!(freq.error_count, 3);
assert_eq!(freq.get_error(), 3./7.);
}
// Removes training examples one at a time, last first, with `remove_one`, checking
// counters, predictions and the error count after each step, until nothing is left.
// Ties keep the current prediction, so equal counts can yield a different prediction
// here than in the forward test (the history differs).
#[test]
fn frequentist_estimate_backward() {
let n_labels = 3;
let train_x = array![0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 6];
let train_y = array![0, 1, 1, 2, 2, 1, 2, 2, 0, 1, 1, 1];
let test_x = array![0, 0, 1, 2, 2, 8, 8];
let test_y = array![1, 1, 1, 1, 2, 1, 0];
let mut freq = FrequentistEstimator::from_data(n_labels,
&train_x.view(),
&train_y.view(),
&test_x.view(),
&test_y.view());
assert_eq!(freq.error_count, 3);
assert_eq!(freq.priors_count.count, vec![2, 6, 4]);
freq.remove_one().unwrap();
assert_eq!(freq.priors_count.count, vec![2, 5, 4]);
assert_eq!(freq.error_count, 3);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![1, 2, 2]);
let pred = freq.joint_count.get(&2).unwrap().predict().unwrap();
assert_eq!(pred, 1);
assert_eq!(freq.priors_count.count, vec![2, 4, 4]);
assert_eq!(freq.error_count, 3);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![1, 1, 2]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![2, 3, 4]);
assert_eq!(freq.priors_count.predict().unwrap(), 2);
assert_eq!(freq.error_count, 4);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 1, 2]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![1, 3, 4]);
assert_eq!(freq.error_count, 4);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 1, 1]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![1, 3, 3]);
assert_eq!(freq.error_count, 4);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 1, 0]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 1);
assert_eq!(freq.priors_count.count, vec![1, 3, 2]);
assert_eq!(freq.error_count, 3);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 0, 0]);
assert!(freq.joint_count.get(&2).unwrap().predict().is_none());
assert_eq!(freq.priors_count.count, vec![1, 2, 2]);
assert_eq!(freq.error_count, 3);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&1).unwrap().count, vec![0, 0, 1]);
assert_eq!(freq.joint_count.get(&1).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![1, 2, 1]);
assert_eq!(freq.error_count, 3);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&1).unwrap().count, vec![0, 0, 0]);
assert!(freq.joint_count.get(&1).unwrap().predict().is_none());
assert_eq!(freq.priors_count.count, vec![1, 2, 0]);
assert_eq!(freq.error_count, 2);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&0).unwrap().count, vec![1, 1, 0]);
assert_eq!(freq.priors_count.count, vec![1, 1, 0]);
assert_eq!(freq.error_count, 2);
freq.remove_one().unwrap();
assert_eq!(freq.joint_count.get(&0).unwrap().count, vec![1, 0, 0]);
assert_eq!(freq.joint_count.get(&0).unwrap().predict().unwrap(), 0);
assert_eq!(freq.priors_count.count, vec![1, 0, 0]);
assert_eq!(freq.error_count, 6);
// Only one training example remains in the list at this point? No: train_x/train_y
// were fully consumed above, so popping fails.
assert!(freq.remove_one().is_err());
}
// Builds the same estimator incrementally with `BayesEstimator::add_example`. Features are
// 1-element rows mapped through `ArrayToIndex`; ids follow first appearance in `test_x`,
// so rows [0.], [1.], [2.] get ids 0, 1, 2 and the joint lookups below use those ids.
#[test]
fn frequentist_estimate_forward() {
let n_labels = 3;
let train_x = array![[0.], [0.], [0.], [1.], [1.], [2.], [2.], [2.],
[2.], [2.], [2.], [6.]];
let train_y = array![0, 1, 1, 2, 2, 1, 2, 2, 0, 1, 1, 1];
let test_x = array![[0.], [0.], [1.], [2.], [2.], [8.], [8.]];
let test_y = array![1, 1, 1, 1, 2, 1, 0];
let mut freq = FrequentistEstimator::new(n_labels,
&test_x.view(),
&test_y.view());
freq.add_example(&train_x.row(0), train_y[0]).unwrap();
assert_eq!(freq.joint_count.get(&0).unwrap().count, vec![1, 0, 0]);
assert_eq!(freq.joint_count.get(&0).unwrap().predict().unwrap(), 0);
assert_eq!(freq.priors_count.count, vec![1, 0, 0]);
assert_eq!(freq.error_count, 6);
freq.add_example(&train_x.row(1), train_y[1]).unwrap();
assert_eq!(freq.joint_count.get(&0).unwrap().count, vec![1, 1, 0]);
assert_eq!(freq.priors_count.count, vec![1, 1, 0]);
assert_eq!(freq.error_count, 6);
freq.add_example(&train_x.row(2), train_y[2]).unwrap();
assert_eq!(freq.joint_count.get(&1).unwrap().count, vec![0, 0, 0]);
assert!(freq.joint_count.get(&1).unwrap().predict().is_none());
assert_eq!(freq.priors_count.count, vec![1, 2, 0]);
assert_eq!(freq.error_count, 2);
freq.add_example(&train_x.row(3), train_y[3]).unwrap();
assert_eq!(freq.joint_count.get(&1).unwrap().count, vec![0, 0, 1]);
assert_eq!(freq.joint_count.get(&1).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![1, 2, 1]);
assert_eq!(freq.error_count, 3);
freq.add_example(&train_x.row(4), train_y[4]).unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 0, 0]);
assert!(freq.joint_count.get(&2).unwrap().predict().is_none());
assert_eq!(freq.priors_count.count, vec![1, 2, 2]);
assert_eq!(freq.error_count, 3);
freq.add_example(&train_x.row(5), train_y[5]).unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 1, 0]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 1);
assert_eq!(freq.priors_count.count, vec![1, 3, 2]);
assert_eq!(freq.error_count, 3);
freq.add_example(&train_x.row(6), train_y[6]).unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 1, 1]);
let pred = freq.joint_count.get(&2).unwrap().predict().unwrap();
assert_eq!(pred, 1);
assert_eq!(freq.priors_count.count, vec![1, 3, 3]);
assert!(freq.error_count == 3);
freq.add_example(&train_x.row(7), train_y[7]).unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![0, 1, 2]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![1, 3, 4]);
assert_eq!(freq.error_count, 4);
freq.add_example(&train_x.row(8), train_y[8]).unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![1, 1, 2]);
assert_eq!(freq.joint_count.get(&2).unwrap().predict().unwrap(), 2);
assert_eq!(freq.priors_count.count, vec![2, 3, 4]);
assert_eq!(freq.priors_count.predict().unwrap(), 2);
assert_eq!(freq.error_count, 4);
freq.add_example(&train_x.row(9), train_y[9]).unwrap();
assert_eq!(freq.joint_count.get(&2).unwrap().count, vec![1, 2, 2]);
let pred = freq.joint_count.get(&2).unwrap().predict().unwrap();
assert_eq!(pred, 2);
assert_eq!(freq.priors_count.count, vec![2, 4, 4]);
assert_eq!(freq.error_count, 4);
freq.add_example(&train_x.row(10), train_y[10]).unwrap();
assert_eq!(freq.priors_count.count, vec![2, 5, 4]);
assert_eq!(freq.error_count, 3);
// Row [6.] is not a test object: it only updates the priors.
freq.add_example(&train_x.row(11), train_y[11]).unwrap();
assert_eq!(freq.error_count, 3);
assert_eq!(freq.priors_count.count, vec![2, 6, 4]);
}
}