// rml/knn.rs
1// Copyright 2021 Jonathan Manly.
2
3// This file is part of rml.
4
5// rml is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Lesser General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9
10// rml is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Lesser General Public License for more details.
14
15// You should have received a copy of the GNU Lesser General Public License
16// along with rml. If not, see <https://www.gnu.org/licenses/>.
17
18//! Implementation for K-Nearest Neighbors.
19
20/*!
21Allows for predicting data based on a KNN search.
22A full, working example is contained in the `examples/knn` directory.
23
24# Example
```ignore
// Collect and parse data to a format consistent with:
// type CSVOutput = (Vec<Vec<f64>>, Vec<i32>);

let training_data: CSVOutput = (Vec::new(), Vec::new());
let testing_data: CSVOutput = (Vec::new(), Vec::new());
31
32// Create a new KNN struct.
33let knn = knn::KNN::new(
345, // 5-nearest
35training_data.0, // x
36training_data.1, // y
37None, // Default distance(euclidean)
38Some(math::norm::Norm::L2), // L2 Normalization
39);
40
41// Get a prediction for each point of the testing data.
42let pred: Vec<i32> = testing_data.0.iter().map(|x| knn.predict(x)).collect();
43
44// Count the number that were predicted correctly.
45let num_correct = pred
46 .iter()
47 .cloned()
48 .zip(&testing_data.1)
49 .filter(|(a, b)| *a == **b)
50 .count();
51
52println!(
53 "Accuracy: {}",
54 (num_correct as f64) / (pred.len() as f64)
55);
56
57```
58!*/
59
60use crate::math::distance;
61use crate::math::norm;
62use rayon::prelude::*;
63use std::cmp::Ordering;
64use std::collections::HashSet;
65
/// KNN struct handles the computation and data for the K-Nearest Neighbors algorithm.
/// It is *highly recommended* to not change values inside of this struct manually. Always
/// create a new one using ::new.
///
/// `::new` normalizes `x` in place when `normalize` is `Some`, so the stored
/// features may differ from the raw data passed in.
#[derive(Debug)]
pub struct KNN {
    /// K-Nearest to analyze
    pub k: i32,
    /// Features
    pub x: Vec<Vec<f64>>,
    /// Class labels for each feature.
    // NOTE(review): `predict` uses these labels as indices into a vote table of
    // length `num_labels`, so they must lie in `0..num_labels` — confirm with callers.
    pub y: Vec<i32>,
    /// Number of labels.
    pub num_labels: usize,
    /// Type of distance to use. `None` falls back to euclidean.
    pub distance: Option<distance::Distance>,
    /// The type of normalization, or None.
    pub normalize: Option<norm::Norm>,
}
84
/// A data point.
///
/// Points order by `distance` first, then by `class`, which keeps `PartialOrd`
/// consistent with the derived `PartialEq` (comparing `Some(Equal)` if and
/// only if the two points are `==`).
#[derive(PartialEq, Debug)]
pub struct Point {
    /// The class label for the point.
    pub class: i32,
    /// The distance from the test point.
    pub distance: f64,
}

impl PartialOrd for Point {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Compare by distance; break exact distance ties by class.
        // BUGFIX: previously only `distance` was compared, so two points with
        // equal distance but different classes compared as `Some(Equal)` while
        // `==` said they differ — violating the PartialOrd/PartialEq contract.
        match self.distance.partial_cmp(&other.distance) {
            Some(Ordering::Equal) => Some(self.class.cmp(&other.class)),
            ord => ord,
        }
    }
}

// NOTE(review): `Eq` demands total equality, but `distance` is an `f64` and
// `NaN != NaN`. This is only sound as long as distances are never NaN.
impl Eq for Point {}
101
102impl KNN {
103 /// Create a new KNN with optional normalization.
104 pub fn new(
105 k: i32,
106 x: Vec<Vec<f64>>,
107 y: Vec<i32>,
108 distance: Option<distance::Distance>,
109 normalize: Option<norm::Norm>,
110 ) -> KNN {
111 let num_labels = KNN::get_num_labels(&y);
112 let mut knn = KNN {
113 k,
114 x,
115 y,
116 num_labels,
117 distance,
118 normalize,
119 };
120 knn.normalize_data();
121 knn
122 }
123
124 /// Gets the number of unique labels.
125 /// This function is called when ::new is called. You can access the value using
126 /// the value contained in the KNN struct.
127 pub fn get_num_labels(y: &[i32]) -> usize {
128 let set: HashSet<i32> = y.iter().cloned().collect::<HashSet<_>>();
129 set.len()
130 }
131
132 /// Normalize the data contain in `self` given by the KNN's configured normalization setting.
133 pub fn normalize_data(&mut self) {
134 if let Some(n) = &self.normalize {
135 self.x
136 .iter_mut()
137 .for_each(|xi| norm::normalize_vector(xi, n));
138 }
139 }
140
141 /// Borrow immutable reference to the data.
142 pub fn data(&self) -> (&Vec<Vec<f64>>, &Vec<i32>) {
143 (&self.x, &self.y)
144 }
145
146 /// Calculate the distance from `new_point` to all other points in the set.
147 /// Note: new_point must be the same dimensions as the data passed into ::new.norm
148 pub fn calculate_distances(&self, new_point: &[f64]) -> Vec<Point> {
149 let distance_fn = match self.distance {
150 Some(distance::Distance::Manhattan) => distance::manhattan_distance,
151 _ => distance::euclidean_distance,
152 };
153
154 self.x
155 .par_iter()
156 .zip(self.y.par_iter())
157 .map(|(x, y)| Point {
158 class: *y,
159 distance: distance_fn(new_point, x),
160 })
161 .collect()
162 }
163
164 /// Predict the class of a point `x`.
165 pub fn predict(&self, x: &[f64]) -> i32 {
166 let mut norm_x: Vec<f64> = x.to_owned();
167 if let Some(n) = &self.normalize {
168 norm::normalize_vector(&mut norm_x, n);
169 }
170 let mut points = self.calculate_distances(x);
171 // points.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
172 points.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
173
174 let mut predictions = vec![0; self.num_labels];
175
176 for i in &points[0..self.k as usize] {
177 predictions[i.class as usize] += 1;
178 }
179 KNN::get_max_value(&predictions)
180 }
181
182 /// Get the class of the highest index.
183 fn get_max_value(predictions: &[i32]) -> i32 {
184 predictions
185 .iter()
186 .enumerate() // add index to the iterated items [a, b, c] -> [(0, a), (1, b), (2, c)]
187 .max_by_key(|(_, pred)| **pred) // take maximum by the actual item, not the index,
188 // `pred` has type `&&i32`, because of all the combinators, so we have to dereference twice
189 .map(|(idx, _)| idx) // Option::map - take tuple (idx, value) and transform it to just idx
190 .unwrap() as i32
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_data_test() {
        let p: Vec<Vec<f64>> = vec![vec![2.0, 2.0, 2.0]];
        // `new` already normalizes; the explicit call below checks that
        // re-normalizing an already-unit-length vector leaves it unchanged.
        let mut knn = KNN::new(5, p, vec![1], None, Some(norm::Norm::L2));
        knn.normalize_data();
        assert_eq!(
            knn.data().0.clone(),
            vec![vec![2.0 / f64::from(12).sqrt(); 3]]
        );
    }

    #[test]
    fn calculate_distances_test() {
        let p: Vec<Vec<f64>> = vec![vec![2.0, 2.0]];
        let knn = KNN::new(5, p, vec![1], None, None);

        // BUGFIX: `&(vec![0.0, 0.0] as Vec<f64>)` is a non-primitive cast
        // (E0605); a slice literal coerces to `&[f64]` directly.
        let q = knn.calculate_distances(&[0.0, 0.0]);
        // Euclidean distance from the origin to (2, 2) is sqrt(8).
        assert_eq!(q[0].distance, f64::from(8).sqrt());
    }
}
217}