1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
//! An estimator that returns the Cover & Hart bound on the Bayes risk, computed from the error of the nearest-neighbor (NN) classifier.
use ndarray::*;

use crate::Label;
use crate::estimates::{BayesEstimator,KNNEstimator,KNNStrategy,nn_bound};

/// Estimator reporting the Cover & Hart nearest-neighbor (NN) bound.
///
/// The bound is computed from the error of a 1-NN classifier and is,
/// asymptotically, guaranteed to be a lower bound on the true Bayes risk.
pub struct NNBoundEstimator<D>
where
    D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 + Send + Sync + Copy,
{
    /// Number of labels, used when turning the NN error into the bound.
    nlabels: usize,
    /// Wrapped nearest-neighbor estimator whose error feeds the bound.
    knn: KNNEstimator<D>,
}

impl<D> NNBoundEstimator<D>
where
    D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 + Send + Sync + Copy,
{
    /// Creates a new NN bound estimator.
    ///
    /// # Arguments
    ///
    /// * `test_x` - test examples, forwarded to the wrapped [`KNNEstimator`].
    /// * `test_y` - labels of the test examples.
    /// * `distance` - distance function, forwarded to the wrapped
    ///   [`KNNEstimator`].
    /// * `nlabels` - number of labels, used when computing the bound
    ///   returned by `get_error()`.
    pub fn new(test_x: &ArrayView2<f64>, test_y: &ArrayView1<Label>,
               distance: D, nlabels: usize) -> NNBoundEstimator<D> {
        // `max_n` only feeds the computation of `max_k`, which
        // `KNNStrategy::NN` pins to 1, so its value is irrelevant here.
        const MAX_N: usize = 1;

        let knn = KNNEstimator::new(test_x, test_y, MAX_N, distance,
                                    KNNStrategy::NN);
        NNBoundEstimator { nlabels, knn }
    }
}

/// Forwards every method to the wrapped [`KNNEstimator`], except
/// `get_error()`, which converts the NN error into the Cover & Hart bound.
impl<D> BayesEstimator for NNBoundEstimator<D>
where
    D: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64 + Send + Sync + Copy,
{
    /// Adds a new example to the wrapped NN estimator.
    fn add_example(&mut self, x: &ArrayView1<f64>, y: Label) -> Result<(), ()> {
        self.knn.add_example(x, y)
    }

    /// Returns the error count of the wrapped NN estimator.
    ///
    /// This is the NN classifier's own count; it is not derived from the
    /// bound returned by `get_error()`.
    fn get_error_count(&self) -> usize {
        self.knn.get_error_count()
    }

    /// Returns the Cover & Hart bound computed from the current NN error.
    ///
    /// This is *not* the NN error itself: the error reported by the wrapped
    /// estimator is passed through `nn_bound` together with the number of
    /// labels.
    fn get_error(&self) -> f64 {
        nn_bound(self.knn.get_error(), self.nlabels)
    }

    /// Returns the per-test-point errors reported by the wrapped NN
    /// estimator (not adjusted by the bound).
    fn get_individual_errors(&self) -> Vec<bool> {
        self.knn.get_individual_errors()
    }
}