1use crate::{Classifier, Control};
2use ndarray::{s, Array1, Array2, ArrayView2, Axis};
3use std::cell::{Ref, RefCell};
4
5#[allow(non_camel_case_types)]
6pub struct kNN<'a, 'b> {
7 X: &'a ArrayView2<'b, f64>,
8 ordering: RefCell<Option<Array2<usize>>>,
9 control: &'a Control,
10}
11
12impl<'a, 'b> kNN<'a, 'b> {
13 pub fn new(X: &'a ArrayView2<'b, f64>, control: &'a Control) -> kNN<'a, 'b> {
14 kNN {
15 X,
16 ordering: RefCell::new(Option::None),
17 control,
18 }
19 }
20
21 fn calculate_ordering(&self) -> Array2<usize> {
22 let n = self.X.nrows();
23 let mut distances = Array2::<f64>::zeros((n, n));
24
25 for i in 0..n {
26 for j in 0..n {
27 if i >= j {
28 distances[[i, j]] = distances[[j, i]]
29 } else {
30 for k in 0..self.X.ncols() {
31 distances[[i, j]] += (self.X[[i, k]] - self.X[[j, k]]).powi(2)
32 }
33 }
34 }
35 }
36
37 let mut ordering = Array2::<usize>::default((n, n));
39 for (i, mut row) in ordering.axis_iter_mut(Axis(0)).enumerate() {
40 let mut order: Vec<usize> = (0..n).collect();
41 order.sort_unstable_by(|a, b| {
42 distances[[i, *a]].partial_cmp(&distances[[i, *b]]).unwrap()
43 });
44 for (j, val) in row.iter_mut().enumerate() {
45 *val = order[j]
46 }
47 }
48 ordering
49 }
50
51 fn get_ordering(&self) -> Ref<'_, Array2<usize>> {
52 if self.ordering.borrow().is_none() {
53 self.ordering.replace(Some(self.calculate_ordering()));
54 }
55
56 Ref::map(self.ordering.borrow(), |borrow| borrow.as_ref().unwrap())
57 }
58}
59
60impl<'a, 'b> Classifier for kNN<'a, 'b> {
61 fn n(&self) -> usize {
62 self.X.nrows()
63 }
64
65 fn predict(&self, start: usize, stop: usize, split: usize) -> Array1<f64> {
66 let ordering = self.get_ordering();
67 let segment_length = stop - start;
68 let k = (segment_length as f64).sqrt().floor();
69 let k_usize = k as usize;
70 let mut predictions = Array1::<f64>::zeros(segment_length);
71
72 for (i, row) in ordering
73 .slice(s![start..stop, ..])
74 .axis_iter(Axis(0))
75 .enumerate()
76 {
77 predictions[i] = row .iter()
79 .skip(1) .filter(|j| (start <= **j) & (**j < stop)) .take(k_usize) .filter(|j| **j >= split)
83 .count() as f64
84 / k; }
86
87 predictions
88 }
89
90 fn control(&self) -> &Control {
91 self.control
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gain::{ApproxGain, ClassifierGain, Gain};
    use crate::optimizer::{Optimizer, TwoStepSearch};
    use crate::testing;
    use assert_approx_eq::*;
    use ndarray::arr1;
    use rstest::*;

    /// Six two-dimensional points shared by the prediction and gain tests.
    fn six_points() -> Array2<f64> {
        ndarray::array![
            [1., 1.],
            [1.5, 1.],
            [0.5, 1.],
            [3., 3.],
            [4.5, 3.],
            [2.5, 2.5]
        ]
    }

    /// Row `i` of the ordering must list all indices by distance to point `i`.
    #[test]
    fn test_X_ordering() {
        let points = ndarray::array![[1.], [1.5], [3.], [-0.5]];
        let view = points.view();
        let control = Control::default();
        let knn = kNN::new(&view, &control);

        let expected = ndarray::array![[0, 1, 3, 2], [1, 0, 2, 3], [2, 1, 0, 3], [3, 0, 1, 2]];
        assert_eq!(knn.calculate_ordering(), expected)
    }

    /// Predictions for several segments and split points, including an empty segment.
    #[rstest]
    #[case(0, 6, 2, arr1(&[0.5, 0.5, 0., 1., 1., 0.5]))]
    #[case(0, 6, 3, arr1(&[0., 0., 0., 1., 1., 0.5]))]
    #[case(1, 6, 2, arr1(&[1., 0.5, 1., 1., 0.5]))]
    #[case(1, 5, 2, arr1(&[1., 0.5, 0.5, 0.5]))]
    #[case(1, 5, 5, arr1(&[0., 0., 0., 0.]))]
    #[case(2, 2, 2, arr1(&[]))]
    fn test_predictions(
        #[case] start: usize,
        #[case] stop: usize,
        #[case] split: usize,
        #[case] expected: Array1<f64>,
    ) {
        let points = six_points();
        let view = points.view();
        let control = Control::default();
        let knn = kNN::new(&view, &control);

        assert_eq!(knn.predict(start, stop, split), expected);
    }

    /// Exact and approximate gains must both match the reference values at
    /// every split point of the segment.
    #[rstest]
    #[case(0, 6, arr1(&[0.0, 0.0, -3.3325539228390255, 4.796659545476027, -9.55569673879512, 0.0]))]
    fn test_gain(#[case] start: usize, #[case] stop: usize, #[case] expected: Array1<f64>) {
        let points = six_points();
        let view = points.view();
        let control = Control::default();

        let knn = kNN::new(&view, &control);
        let knn_gain = ClassifierGain { classifier: knn };

        let split_points: Vec<usize> = (start..stop).collect();
        // `offset` is `split_point - start`, the position inside the segment.
        for (offset, &split_point) in split_points.iter().enumerate() {
            let want = expected[offset];

            assert_approx_eq!(want, knn_gain.gain(start, stop, split_point));
            assert_approx_eq!(
                want,
                knn_gain
                    .gain_approx(start, stop, split_point, &split_points)
                    .gain[offset]
            )
        }
    }

    /// The two-step search over the shared test array must find the expected split.
    #[rstest]
    #[case(0, 100, 40)]
    fn test_two_step_search(#[case] start: usize, #[case] stop: usize, #[case] expected: usize) {
        let data = testing::array();
        let view = data.view();
        let control = Control::default().with_minimal_relative_segment_length(0.01);

        let classifier = kNN::new(&view, &control);
        let gain = ClassifierGain { classifier };
        let optimizer = TwoStepSearch { gain };

        let found = optimizer.find_best_split(start, stop).unwrap().best_split;
        assert_eq!(expected, found);
    }
}