use std;
use ndarray::*;
use ndarray_parallel::rayon::prelude::*;
use std::collections::HashMap;
use ndarray_parallel::prelude::*;
use ordered_float::OrderedFloat;
use std::cmp::Ordering;
use Label;
/// One candidate neighbour of a query point: its distance to the query and
/// the label of the training example it came from. Neighbours are ordered by
/// distance only (see the `Ord` impl).
#[derive(Debug)]
struct Neighbor {
/// Distance from the query point, as returned by the distance function.
distance: f64,
/// Label of the training example.
label: Label,
}
impl Neighbor {
    /// Builds a neighbour record from its distance to the query point and
    /// the label of the corresponding training example.
    fn new(distance: f64, label: Label) -> Neighbor {
        let neighbor = Neighbor {
            label: label,
            distance: distance,
        };
        neighbor
    }
}
impl Ord for Neighbor {
fn cmp(&self, other: &Neighbor) -> Ordering {
let self_d = OrderedFloat::from(self.distance);
let other_d = OrderedFloat::from(other.distance);
self_d.cmp(&other_d)
}
}
impl PartialOrd for Neighbor {
    /// Always comparable: defers to the total order given by `Ord`.
    fn partial_cmp(&self, other: &Neighbor) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl PartialEq for Neighbor {
fn eq(&self, other: &Neighbor) -> bool {
self.distance == other.distance
}
}
impl Eq for Neighbor {}
/// The nearest training examples seen so far for a single query point `x`.
///
/// `neighbors` holds at most `max_k` entries, kept sorted by increasing
/// distance. Examples that tie with the farthest kept neighbour after the list
/// is full are not stored; only their labels are counted in `extra_ties`, so
/// that tie-breaking in `predict` can still account for them.
#[derive(Debug)]
struct NearestNeighbors<D>
where D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 {
/// The query point (an owned copy of the view passed to `new`).
x: Array1<f64>,
/// Kept neighbours, sorted by increasing distance; at most `max_k` long.
neighbors: Vec<Neighbor>,
/// Per-label counts of examples that tied with the farthest kept neighbour
/// but did not fit in `neighbors`.
extra_ties: HashMap<Label, usize>,
/// Distance shared by the examples counted in `extra_ties` (`None` when empty).
extra_ties_dist: Option<f64>,
/// Index of the first neighbour tied with the position touched by the most
/// recent accepted `add_example`; `KNNEstimator` skips re-predicting when
/// this exceeds its current k.
updated_k: usize,
/// Maximum number of neighbours stored in `neighbors`.
max_k: usize,
/// Distance function between two points.
distance: D,
}
impl<D> NearestNeighbors<D>
where D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 + Copy {
fn new(x: &ArrayView1<f64>, max_k: usize, distance: D) -> NearestNeighbors<D> {
NearestNeighbors {
x: x.to_owned(),
neighbors: Vec::with_capacity(max_k + 1),
extra_ties: HashMap::new(),
extra_ties_dist: None,
updated_k: 0,
max_k: max_k,
distance: distance,
}
}
fn from_data(x: &ArrayView1<f64>, train_x: &ArrayView2<f64>,
train_y: &ArrayView1<Label>, max_k: usize, distance: D) -> NearestNeighbors<D> {
assert!(max_k > 0);
let mut knn = NearestNeighbors::new(x, max_k, distance);
for (xi, yi) in train_x.outer_iter().zip(train_y) {
assert!(*yi != std::usize::MAX,
"label {} is too large and currently not supported", *yi);
knn.add_example(&xi, *yi);
}
knn
}
fn predict_no_ties(&self, k: usize) -> Result<Label, ()> {
if k > self.neighbors.len() {
return Err(());
}
let mut label_count = HashMap::new();
let mut y_pred = 0;
let mut y_count = 0;
for neigh in self.neighbors.iter().take(k) {
let count = label_count.entry(neigh.label).or_insert(0);
*count += 1;
if *count > y_count {
y_pred = neigh.label;
y_count = *count;
}
}
Ok(y_pred)
}
fn predict(&self, k: usize) -> Result<Label, ()> {
if k > self.neighbors.len() {
return Err(());
}
let ties_d = self.neighbors[k-1].distance;
let no_ties = (self.neighbors.len() <= k) || (self.neighbors[k].distance > ties_d);
let extra_ties_matter = (self.extra_ties.len() > 0) && (k == self.neighbors.len());
if no_ties && !extra_ties_matter {
return self.predict_no_ties(k);
}
let mut label_count = HashMap::new();
let mut ties_label_count = HashMap::new();
let mut y_pred = 0;
let mut y_count = 0;
let mut ties_y_pred = 0;
let mut ties_y_count = 0;
const TIES_LABEL: usize = std::usize::MAX;
for (i, neigh) in self.neighbors.iter().enumerate() {
if neigh.distance != ties_d {
if i >= k {
break;
}
let count = label_count.entry(neigh.label).or_insert(0);
*count += 1;
if *count > y_count {
y_pred = neigh.label;
y_count = *count;
}
}
else {
let count = ties_label_count.entry(neigh.label).or_insert(0);
*count += 1;
if *count > ties_y_count {
ties_y_pred = neigh.label;
ties_y_count = *count;
}
let count = label_count.entry(TIES_LABEL).or_insert(0);
if i >= k {
continue;
}
*count += 1;
if *count > y_count {
y_pred = TIES_LABEL;
y_count = *count;
}
}
}
if Some(ties_d) == self.extra_ties_dist {
for (y, extra_count) in &self.extra_ties {
let count = ties_label_count.entry(*y).or_insert(0);
*count += extra_count;
if *count > ties_y_count {
ties_y_pred = *y;
ties_y_count = *count;
}
}
}
if y_pred == TIES_LABEL {
y_pred = ties_y_pred;
}
else {
let mut count = *label_count.get(&TIES_LABEL)
.expect("[!] unexpected error in splitting ties");
if let Some(c) = label_count.get(&ties_y_pred) {
count += c;
}
if count > y_count {
y_pred = ties_y_pred;
}
}
Ok(y_pred)
}
fn first_of_ties(&self, mut i: usize) -> usize {
if i == 0 || self.neighbors.len() == 0 {
return 0;
}
let d = self.neighbors.get(i)
.expect("first_of_ties() called on wrong index")
.distance;
while let Some(neigh) = self.neighbors.get(i-1) {
if neigh.distance != d {
break;
}
i -= 1;
if i == 0 {
break;
}
}
i
}
fn add_example(&mut self, x: &ArrayView1<f64>, y: Label) -> bool {
let d = (self.distance)(x, &self.x.view());
if self.neighbors.len() < self.max_k {
let new = Neighbor::new(d, y);
let pos = self.neighbors.binary_search(&new).unwrap_or_else(|e| e);
self.neighbors.insert(pos, new);
self.updated_k = self.first_of_ties(pos);
}
else if self.neighbors.last().unwrap().distance < d {
return false;
}
else if self.neighbors.last().unwrap().distance == d {
if self.extra_ties.len() == 0 {
self.extra_ties_dist = Some(d);
}
{
let count = self.extra_ties.entry(y).or_insert(0);
*count += 1;
}
self.updated_k = self.first_of_ties(self.neighbors.len()-1);
}
else {
let new = Neighbor::new(d, y);
let pos = self.neighbors.binary_search(&new).unwrap_or_else(|e| e);
self.neighbors.insert(pos, new);
self.updated_k = self.first_of_ties(pos);
if let Some(removed) = self.neighbors.pop() {
let last_neigh = self.neighbors.last().unwrap();
if last_neigh.distance == removed.distance {
if self.extra_ties.len() == 0 {
self.extra_ties_dist = Some(removed.distance);
}
else {
assert_eq!(Some(removed.distance), self.extra_ties_dist);
}
let count = self.extra_ties.entry(removed.label).or_insert(0);
*count += 1;
}
else {
self.extra_ties.clear();
self.extra_ties_dist = None;
}
}
}
true
}
}
/// Incremental k-NN error estimator over a fixed test set.
///
/// Training examples can be fed one at a time with `add_example`. For each
/// test point the nearest neighbours are maintained, and the prediction and
/// 0/1 error are recomputed only when they may have changed.
#[derive(Debug)]
pub struct KNNEstimator<D>
where D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 {
/// Nearest-neighbour state, one entry per test point.
neighbors: Vec<NearestNeighbors<D>>,
/// Per test point: `1.` if the current prediction is wrong, `0.` otherwise.
pub errors: Vec<f64>,
/// Current predicted label of each test point.
pub predictions: Vec<Label>,
/// True labels of the test points.
labels: Vec<Label>,
/// Number of neighbours used for prediction.
current_k: usize,
/// Number of misclassified test points (the sum of `errors`).
pub k_error_count: f64,
/// Number of training examples seen so far.
n: usize,
}
impl<D> KNNEstimator<D>
where D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 + Send + Sync + Copy {
    /// Creates an estimator for the test set `(test_x, test_y)` with no
    /// training data yet; feed training examples with `add_example`.
    ///
    /// Until examples arrive every prediction is label `0`, so the initial
    /// error count is the number of test labels different from `0`.
    ///
    /// # Panics
    /// Panics if `test_x` and `test_y` differ in length or are empty.
    pub fn new(test_x: &ArrayView2<f64>, test_y: &ArrayView1<Label>,
               k: usize, max_k: usize, distance: D) -> KNNEstimator<D> {
        assert_eq!(test_x.rows(), test_y.len());
        assert!(test_y.len() > 0);
        let neighbors = test_x.outer_iter()
                              .into_par_iter()
                              .map(|x| NearestNeighbors::new(&x, max_k, distance))
                              .collect::<Vec<_>>();
        let errors = test_y.iter()
                           .map(|y| if *y != 0 { 1. } else { 0. })
                           .collect::<Vec<_>>();
        let error_count = errors.iter().sum();
        KNNEstimator {
            neighbors: neighbors,
            errors: errors,
            predictions: vec![0; test_y.len()],
            labels: test_y.to_vec(),
            current_k: k,
            k_error_count: error_count,
            n: 0,
        }
    }

    /// Builds the estimator from a complete training set and computes the
    /// k-NN prediction and error of every test point.
    ///
    /// # Panics
    /// Panics on shape mismatches, empty inputs, or when `k` is larger than
    /// the number of neighbours available for prediction.
    pub fn from_data(train_x: &ArrayView2<f64>, train_y: &ArrayView1<Label>,
                     test_x: &ArrayView2<f64>, test_y: &ArrayView1<Label>,
                     k: usize, max_k: usize, distance: D) -> KNNEstimator<D> {
        assert_eq!(train_x.cols(), test_x.cols());
        assert_eq!(train_x.rows(), train_y.len());
        assert_eq!(test_x.rows(), test_y.len());
        assert!(train_x.len() > 0);
        assert!(test_x.len() > 0);
        let neighbors = test_x.outer_iter()
                              .into_par_iter()
                              .map(|x| NearestNeighbors::from_data(&x,
                                                                   &train_x.view(),
                                                                   &train_y.view(),
                                                                   max_k,
                                                                   distance))
                              .collect::<Vec<_>>();
        let mut knn_error = 0.;
        let mut errors = Vec::with_capacity(test_y.len());
        let mut predictions = Vec::with_capacity(test_y.len());
        for (neigh, y) in neighbors.iter().zip(test_y) {
            let pred = neigh.predict(k)
                            .expect("unexpected error");
            let error = if pred != *y { 1. } else { 0. };
            predictions.push(pred);
            errors.push(error);
            knn_error += error;
        }
        KNNEstimator {
            neighbors: neighbors,
            errors: errors,
            predictions: predictions,
            labels: test_y.to_vec(),
            current_k: k,
            k_error_count: knn_error,
            n: train_y.len(),
        }
    }

    /// Recomputes every prediction for `current_k`, updating `errors` and
    /// `k_error_count` for the test points whose prediction changed.
    fn update_all(&mut self) -> Result<(), ()> {
        for (neigh, y, p, e) in izip!(&self.neighbors, &self.labels,
                                      &mut self.predictions, &mut self.errors) {
            let pred = neigh.predict(self.current_k)?;
            if pred == *p {
                continue;
            }
            let error = if pred != *y { 1. } else { 0. };
            self.k_error_count += error - *e;
            *p = pred;
            *e = error;
        }
        Ok(())
    }

    /// Feeds one training example to every test point and updates the
    /// predictions and error count for the current `k`.
    ///
    /// # Errors
    /// Returns `Err(())` if some prediction could not be computed, e.g. when
    /// fewer than `k` neighbours are stored so far. The example is still added
    /// to every test point. `k_error_count` stays equal to the sum of
    /// `errors`, and the affected predictions keep their previous values.
    pub fn add_example(&mut self, x: &ArrayView1<f64>, y: Label) -> Result<(), ()> {
        let current_k = self.current_k;
        self.n += 1;
        // Collect every outcome rather than a `Result<Vec<_>, _>`. Rayon's
        // `Result` collection may stop processing items after an error, which
        // would leave some test points without this example. It would also
        // discard error deltas already written into `errors`/`predictions`.
        let updates: Vec<Result<f64, ()>> = self.neighbors.par_iter_mut()
            .zip(&self.labels)
            .zip(&mut self.predictions)
            .zip(&mut self.errors)
            .filter_map(|(((neigh, true_y), old_pred), old_error)| {
                if !neigh.add_example(x, y) {
                    return None;
                }
                // Predictions for k below the first affected index are unchanged.
                if neigh.updated_k > current_k {
                    return None;
                }
                let pred = match neigh.predict(current_k) {
                    Ok(pred) => pred,
                    Err(e) => return Some(Err(e)),
                };
                if pred == *old_pred {
                    return None;
                }
                let error = if pred != *true_y { 1. } else { 0. };
                let update = error - *old_error;
                *old_error = error;
                *old_pred = pred;
                Some(Ok(update))
            })
            .collect();
        self.k_error_count += updates.iter()
                                     .filter_map(|u| u.as_ref().ok())
                                     .sum::<f64>();
        if updates.iter().any(|u| u.is_err()) {
            return Err(());
        }
        Ok(())
    }

    /// Fraction of misclassified test points.
    pub fn get_error(&self) -> f64 {
        self.k_error_count / self.labels.len() as f64
    }

    /// Number of misclassified test points.
    pub fn get_error_count(&self) -> usize {
        self.k_error_count as usize
    }

    /// Changes the number of neighbours used for prediction and recomputes
    /// all predictions.
    ///
    /// # Errors
    /// Returns `Err(())` if predictions cannot be computed with `k`, e.g. when
    /// fewer than `k` neighbours are stored. In that case the previous `k` is
    /// kept.
    pub fn set_k(&mut self, k: usize) -> Result<(), ()> {
        if k == self.current_k {
            return Ok(());
        }
        let previous_k = self.current_k;
        self.current_k = k;
        if let Err(e) = self.update_all() {
            // Every test point receives the same training examples with the
            // same `max_k`, so all store the same number of neighbours. A
            // failing `predict` therefore fails on the first test point,
            // before any prediction changed, and restoring k is consistent.
            self.current_k = previous_k;
            return Err(e);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use estimates::*;

    /// Neighbours are kept sorted by distance and truncated to `max_k`.
    #[test]
    fn knn_init() {
        let train_x = array![[8.],
                             [3.],
                             [1.],
                             [4.],
                             [5.],
                             [7.],
                             [2.],
                             [6.]];
        let train_y = array![0, 0, 0, 1, 0, 1, 1, 2];
        let x = array![0.];
        let max_k = 8;
        let knn = NearestNeighbors::from_data(&x.view(), &train_x.view(),
                                              &train_y.view(), max_k,
                                              euclidean_distance);
        let distances_from_x: Vec<_> = knn.neighbors.iter()
                                          .map(|neigh| neigh.distance)
                                          .collect();
        assert_eq!(distances_from_x, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
        let max_k = 5;
        let knn = NearestNeighbors::from_data(&x.view(), &train_x.view(),
                                              &train_y.view(), max_k,
                                              euclidean_distance);
        let distances_from_x: Vec<_> = knn.neighbors.iter()
                                          .map(|neigh| neigh.distance)
                                          .collect();
        assert_eq!(distances_from_x, vec![1., 2., 3., 4., 5.]);
    }

    /// Ties in distance at the k-th position are resolved the same way
    /// regardless of the order in which training examples arrive.
    #[test]
    fn knn_predictions_ties() {
        let train_x = array![[0.], [1.], [1.], [1.], [1.], [1.], [1.], [2.], [2.]];
        let train_y = array![0, 1, 1, 1, 0, 1, 0, 0, 0];
        let x = array![0.];
        let max_k = 10;
        let knn = NearestNeighbors::from_data(&x.view(), &train_x.view(),
                                              &train_y.view(), max_k,
                                              euclidean_distance);
        assert_eq!(knn.predict(1), Ok(0));
        assert_eq!(knn.predict(3), Ok(1));
        assert_eq!(knn.predict(5), Ok(1));
        let train_x = array![[1.], [1.], [1.], [1.], [1.], [1.], [2.], [2.], [0.]];
        let train_y = array![0, 1, 1, 1, 0, 1, 0, 0, 0];
        let knn = NearestNeighbors::from_data(&x.view(), &train_x.view(),
                                              &train_y.view(), max_k,
                                              euclidean_distance);
        assert_eq!(knn.predict(1), Ok(0));
        assert_eq!(knn.predict(3), Ok(1));
        assert_eq!(knn.predict(5), Ok(1));
    }

    /// Multivariate inputs, including an example added after construction.
    #[test]
    fn knn_predictions_multivariate() {
        let train_x = array![[1., 3.],
                             [1., 2.],
                             [2., 3.],
                             [2., 2.],
                             [3., 2.],
                             [2., 2.],
                             [2., 2.]];
        let train_y = array![0, 0, 0, 1, 1, 2, 2];
        let x1 = array![2., 2.];
        let x2 = array![2., 2.];
        let max_k = 10;
        let distance = euclidean_distance;
        let mut knn1 = NearestNeighbors::from_data(&x1.view(), &train_x.view(),
                                                   &train_y.view(), max_k,
                                                   distance);
        let knn2 = NearestNeighbors::from_data(&x2.view(), &train_x.view(),
                                               &train_y.view(), max_k,
                                               distance);
        assert_eq!(knn1.predict(1), Ok(2));
        knn1.add_example(&array![2., 1.].view(), 2);
        assert_eq!(knn1.predict(1), Ok(2));
        assert_eq!(knn1.predict(3), Ok(2));
        assert!(knn1.predict(5) == Ok(2) || knn1.predict(5) == Ok(0));
        assert_eq!(knn2.predict(1), Ok(2));
        assert_eq!(knn2.predict(3), Ok(2));
        assert!(knn2.predict(5) == Ok(2) || knn2.predict(5) == Ok(1));
    }

    /// Predictions after each incremental `add_example`; `Err` while fewer
    /// than k neighbours are stored.
    #[test]
    fn knn_forward_predictions() {
        let train_x = array![[8.],
                             [7.],
                             [6.],
                             [5.],
                             [4.],
                             [3.],
                             [2.],
                             [1.]];
        let train_y = array![0, 0, 0, 1, 0, 1, 1, 2];
        let x = array![0.];
        let max_k = 8;
        let distance = euclidean_distance;
        let mut knn = NearestNeighbors::new(&x.view(), max_k, distance);
        let expected_preds_1 = vec![Ok(0), Ok(0), Ok(0), Ok(1), Ok(0), Ok(1),
                                    Ok(1), Ok(2)];
        let expected_preds_3 = vec![Err(()), Err(()), Ok(0), Ok(0), Ok(0),
                                    Ok(1), Ok(1), Ok(1)];
        let expected_preds_5 = vec![Err(()), Err(()), Err(()), Err(()), Ok(0),
                                    Ok(0), Ok(1), Ok(1)];
        let expected_preds_7 = vec![Err(()), Err(()), Err(()), Err(()),
                                    Err(()), Err(()), Ok(0), Ok(1)];
        for (i, (x, y)) in train_x.outer_iter().zip(train_y.iter()).enumerate() {
            knn.add_example(&x, *y);
            assert_eq!(knn.predict(1), expected_preds_1[i]);
            assert_eq!(knn.predict(3), expected_preds_3[i]);
            assert_eq!(knn.predict(5), expected_preds_5[i]);
            assert_eq!(knn.predict(7), expected_preds_7[i]);
        }
    }

    /// `KNNEstimator` keeps predictions and the error rate in sync as training
    /// examples arrive, first with a fixed k and then while changing k.
    #[test]
    fn knn_forward_errors() {
        let train_x = array![[8.],
                             [7.],
                             [6.],
                             [5.],
                             [4.],
                             [3.],
                             [2.],
                             [1.]];
        let train_y = array![0, 0, 0, 1, 0, 1, 1, 2];
        let test_x = array![[3.],
                            [0.],
                            [6.],
                            [1.],
                            [6.],
                            [4.],
                            [5.]];
        let test_y = array![0, 0, 2, 1, 0, 1, 0];
        let max_k = 8;
        let distance = euclidean_distance;
        let k = 1;
        let mut knn = KNNEstimator::new(&test_x.view(), &test_y.view(), k,
                                        max_k, distance);
        let expected_preds = vec![[1, 2, 0, 2, 0, 0, 1],
                                  [1, 1, 0, 1, 0, 0, 1],
                                  [1, 1, 0, 1, 0, 0, 1],
                                  [0, 0, 0, 0, 0, 0, 1],
                                  [1, 1, 0, 1, 0, 1, 1],
                                  [0, 0, 0, 0, 0, 0, 0],
                                  [0, 0, 0, 0, 0, 0, 0],
                                  [0, 0, 0, 0, 0, 0, 0]];
        let expected_error = vec![0.8571428571428571, 0.7142857142857143,
                                  0.7142857142857143, 0.5714285714285714,
                                  0.5714285714285714, 0.42857142857142855,
                                  0.42857142857142855, 0.42857142857142855];
        for (i, (x, y)) in train_x.outer_iter().zip(train_y.iter()).enumerate() {
            knn.add_example(&x, *y).unwrap();
            assert_eq!(knn.predictions, expected_preds[expected_preds.len()-1-i]);
            assert_eq!(knn.get_error(), expected_error[expected_error.len()-1-i]);
        }
        let k = 1;
        let max_k = train_y.len();
        let mut knn = KNNEstimator::new(&test_x.view(), &test_y.view(), k,
                                        max_k, distance);
        let expected_error = vec![0.42857142857142855, 0.42857142857142855,
                                  0.42857142857142855, 0.5714285714285714,
                                  0.42857142857142855, 0.42857142857142855,
                                  0.42857142857142855, 0.42857142857142855];
        let expected_preds = vec![vec![0; 7], vec![0; 7], vec![0; 7],
                                  vec![1, 1, 0, 1, 0, 1, 1],
                                  vec![0; 7],
                                  vec![1, 1, 0, 1, 0, 1, 0],
                                  vec![1, 1, 0, 1, 0, 1, 0],
                                  vec![1, 1, 0, 1, 0, 1, 0]];
        let ks = vec![1, 1, 1, 1, 3, 3, 5, 5];
        for (i, (x, y)) in train_x.outer_iter().zip(train_y.iter()).enumerate() {
            knn.set_k(ks[i]).unwrap();
            knn.add_example(&x, *y).unwrap();
            assert_eq!(knn.get_error(), expected_error[i]);
            assert_eq!(knn.predictions, expected_preds[i]);
        }
    }

    /// Examples tied with the farthest neighbour once `max_k` is reached are
    /// counted in `extra_ties` and taken into account when breaking ties.
    #[test]
    fn ties_after_max_k() {
        let max_k = 5;
        let distance = euclidean_distance;
        let train_x = array![[0.], [0.], [0.], [1.], [1.], [1.], [1.], [1.], [0.]];
        let train_y = array![0, 1, 1, 1, 0, 0, 0, 0, 1];
        let x = array![0.];
        let mut nn = NearestNeighbors::from_data(&x.view(), &train_x.view(),
                                                 &train_y.view(), max_k,
                                                 distance);
        let distances_from_x: Vec<_> = nn.neighbors.iter()
                                         .map(|neigh| neigh.distance)
                                         .collect();
        assert_eq!(distances_from_x, vec![0., 0., 0., 0., 1.]);
        assert_eq!(nn.extra_ties_dist, Some(1.));
        assert_eq!(nn.extra_ties.get(&0), Some(&3), "extra_ties = {:?}", nn.extra_ties);
        assert_eq!(nn.extra_ties.get(&1), Some(&1), "extra_ties = {:?}", nn.extra_ties);
        assert_eq!(nn.predict(3).unwrap(), 1);
        assert_eq!(nn.predict(5).unwrap(), 1);
        assert!(nn.predict(6).is_err());
        nn.add_example(&array![1.].view(), 0);
        nn.add_example(&array![1.].view(), 1);
        nn.add_example(&array![1.].view(), 2);
        nn.add_example(&array![1.].view(), 2);
        nn.add_example(&array![1.].view(), 3);
        assert_eq!(nn.extra_ties.get(&0), Some(&4), "extra_ties = {:?}", nn.extra_ties);
        assert_eq!(nn.extra_ties.get(&1), Some(&2), "extra_ties = {:?}", nn.extra_ties);
        assert_eq!(nn.extra_ties.get(&2), Some(&2), "extra_ties = {:?}", nn.extra_ties);
        assert_eq!(nn.extra_ties.get(&3), Some(&1), "extra_ties = {:?}", nn.extra_ties);
        let train_y = array![0, 0, 1, 1, 0, 0, 0, 0, 1];
        let nn = NearestNeighbors::from_data(&x.view(), &train_x.view(),
                                             &train_y.view(), max_k,
                                             euclidean_distance);
        assert_eq!(nn.predict(5).unwrap(), 0);
    }

    /// A k=1 prediction is refreshed when overflow ties at the boundary change
    /// the majority, even though the stored neighbours stay the same.
    #[test]
    fn test_updated_k() {
        let test_x = array![[0.]];
        let test_y = array![1];
        let k = 1;
        let max_k = 5;
        let mut knn = KNNEstimator::new(&test_x.view(), &test_y.view(), k,
                                        max_k, euclidean_distance);
        for _ in 0..5 {
            knn.add_example(&array![2.].view(), 0).unwrap();
        }
        assert_eq!(knn.predictions, vec![0]);
        for _ in 0..6 {
            knn.add_example(&array![2.].view(), 1).unwrap();
        }
        assert_eq!(knn.predictions, vec![1]);
    }
}