changeforest/classifier/
knn.rs

1use crate::{Classifier, Control};
2use ndarray::{s, Array1, Array2, ArrayView2, Axis};
3use std::cell::{Ref, RefCell};
4
5#[allow(non_camel_case_types)]
6pub struct kNN<'a, 'b> {
7    X: &'a ArrayView2<'b, f64>,
8    ordering: RefCell<Option<Array2<usize>>>,
9    control: &'a Control,
10}
11
12impl<'a, 'b> kNN<'a, 'b> {
13    pub fn new(X: &'a ArrayView2<'b, f64>, control: &'a Control) -> kNN<'a, 'b> {
14        kNN {
15            X,
16            ordering: RefCell::new(Option::None),
17            control,
18        }
19    }
20
21    fn calculate_ordering(&self) -> Array2<usize> {
22        let n = self.X.nrows();
23        let mut distances = Array2::<f64>::zeros((n, n));
24
25        for i in 0..n {
26            for j in 0..n {
27                if i >= j {
28                    distances[[i, j]] = distances[[j, i]]
29                } else {
30                    for k in 0..self.X.ncols() {
31                        distances[[i, j]] += (self.X[[i, k]] - self.X[[j, k]]).powi(2)
32                    }
33                }
34            }
35        }
36
37        // A rather complex ordering = numpy.argsort(distances, 1)
38        let mut ordering = Array2::<usize>::default((n, n));
39        for (i, mut row) in ordering.axis_iter_mut(Axis(0)).enumerate() {
40            let mut order: Vec<usize> = (0..n).collect();
41            order.sort_unstable_by(|a, b| {
42                distances[[i, *a]].partial_cmp(&distances[[i, *b]]).unwrap()
43            });
44            for (j, val) in row.iter_mut().enumerate() {
45                *val = order[j]
46            }
47        }
48        ordering
49    }
50
51    fn get_ordering(&self) -> Ref<'_, Array2<usize>> {
52        if self.ordering.borrow().is_none() {
53            self.ordering.replace(Some(self.calculate_ordering()));
54        }
55
56        Ref::map(self.ordering.borrow(), |borrow| borrow.as_ref().unwrap())
57    }
58}
59
60impl<'a, 'b> Classifier for kNN<'a, 'b> {
61    fn n(&self) -> usize {
62        self.X.nrows()
63    }
64
65    fn predict(&self, start: usize, stop: usize, split: usize) -> Array1<f64> {
66        let ordering = self.get_ordering();
67        let segment_length = stop - start;
68        let k = (segment_length as f64).sqrt().floor();
69        let k_usize = k as usize;
70        let mut predictions = Array1::<f64>::zeros(segment_length);
71
72        for (i, row) in ordering
73            .slice(s![start..stop, ..])
74            .axis_iter(Axis(0))
75            .enumerate()
76        {
77            predictions[i] = row // order of neighbors by distance
78                .iter()
79                .skip(1) // To get LOOCV-like predictions
80                .filter(|j| (start <= **j) & (**j < stop)) // segment
81                .take(k_usize) // Only look at first k neighbors
82                .filter(|j| **j >= split)
83                .count() as f64
84                / k; // Proportion of neighbors from after split.
85        }
86
87        predictions
88    }
89
90    fn control(&self) -> &Control {
91        self.control
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gain::{ApproxGain, ClassifierGain, Gain};
    use crate::optimizer::{Optimizer, TwoStepSearch};
    use crate::testing;
    use assert_approx_eq::*;
    use ndarray::arr1;
    use rstest::*;

    // Six two-dimensional observations shared by the prediction and gain tests.
    fn six_points() -> ndarray::Array2<f64> {
        ndarray::array![
            [1., 1.],
            [1.5, 1.],
            [0.5, 1.],
            [3., 3.],
            [4.5, 3.],
            [2.5, 2.5]
        ]
    }

    // Every row of the ordering lists all indices by increasing distance.
    #[test]
    fn test_X_ordering() {
        let data = ndarray::array![[1.], [1.5], [3.], [-0.5]];
        let view = data.view();
        let control = Control::default();
        let classifier = kNN::new(&view, &control);

        let expected = ndarray::array![[0, 1, 3, 2], [1, 0, 2, 3], [2, 1, 0, 3], [3, 0, 1, 2]];
        let computed = classifier.calculate_ordering();
        assert_eq!(computed, expected)
    }

    // Predictions for several segments and split points, including an empty one.
    #[rstest]
    #[case(0, 6, 2, arr1(&[0.5, 0.5, 0., 1., 1., 0.5]))]
    #[case(0, 6, 3, arr1(&[0., 0., 0., 1., 1., 0.5]))]
    #[case(1, 6, 2, arr1(&[1., 0.5, 1., 1., 0.5]))]
    #[case(1, 5, 2, arr1(&[1., 0.5, 0.5, 0.5]))]
    #[case(1, 5, 5, arr1(&[0., 0., 0., 0.]))]
    #[case(2, 2, 2, arr1(&[]))]
    fn test_predictions(
        #[case] start: usize,
        #[case] stop: usize,
        #[case] split: usize,
        #[case] expected: Array1<f64>,
    ) {
        let data = six_points();
        let view = data.view();
        let control = Control::default();
        let classifier = kNN::new(&view, &control);

        assert_eq!(classifier.predict(start, stop, split), expected);
    }

    // Exact and approximate gain must both match the recorded values.
    #[rstest]
    #[case(0, 6, arr1(&[0.0, 0.0, -3.3325539228390255, 4.796659545476027, -9.55569673879512, 0.0]))]
    fn test_gain(#[case] start: usize, #[case] stop: usize, #[case] expected: Array1<f64>) {
        // TODO Find out if this makes any sense.
        let data = six_points();
        let view = data.view();
        let control = Control::default();
        let gain = ClassifierGain {
            classifier: kNN::new(&view, &control),
        };

        let candidates: Vec<usize> = (start..stop).collect();
        for (offset, split_point) in (start..stop).enumerate() {
            let wanted = expected[offset];
            assert_approx_eq!(wanted, gain.gain(start, stop, split_point));
            assert_approx_eq!(
                wanted,
                gain.gain_approx(start, stop, split_point, &candidates)
                    .gain[offset]
            )
        }
    }

    // The two-step search recovers the expected split on the shared test array.
    #[rstest]
    #[case(0, 100, 40)]
    fn test_two_step_search(#[case] start: usize, #[case] stop: usize, #[case] expected: usize) {
        let data = testing::array();
        let view = data.view();
        let control = Control::default().with_minimal_relative_segment_length(0.01);

        let optimizer = TwoStepSearch {
            gain: ClassifierGain {
                classifier: kNN::new(&view, &control),
            },
        };

        let found = optimizer.find_best_split(start, stop).unwrap().best_split;
        assert_eq!(expected, found);
    }
}