// rml/knn.rs

1// Copyright 2021 Jonathan Manly.
2
3// This file is part of rml.
4
5// rml is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Lesser General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9
10// rml is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13// GNU Lesser General Public License for more details.
14
15// You should have received a copy of the GNU Lesser General Public License
16// along with rml.  If not, see <https://www.gnu.org/licenses/>.
17
18//! Implementation for K-Nearest Neighbors.
19
20/*!
21Allows for predicting data based on a KNN search.
22A full, working example is contained in the `examples/knn` directory.
23
24# Example
```rust,ignore
26// Collect and parse data to a format consistent with:
27// type CSVOutput = (Vec<Vec<f64>>, Vec<i32>);
28
let training_data: CSVOutput = (Vec::new(), Vec::new());
let testing_data: CSVOutput = (Vec::new(), Vec::new());
31
32// Create a new KNN struct.
33let knn = knn::KNN::new(
345, // 5-nearest
35training_data.0, // x
36training_data.1, // y
37None, // Default distance(euclidean)
38Some(math::norm::Norm::L2), // L2 Normalization
39);
40
41// Get a prediction for each point of the testing data.
42let pred: Vec<i32> = testing_data.0.iter().map(|x| knn.predict(x)).collect();
43
44// Count the number that were predicted correctly.
45let num_correct = pred
46    .iter()
47    .cloned()
48    .zip(&testing_data.1)
49    .filter(|(a, b)| *a == **b)
50    .count();
51
52println!(
53    "Accuracy: {}",
54    (num_correct as f64) / (pred.len() as f64)
55);
56
57```
58!*/
59
60use crate::math::distance;
61use crate::math::norm;
62use rayon::prelude::*;
63use std::cmp::Ordering;
64use std::collections::HashSet;
65
/// KNN struct handles the computation and data for the K-Nearest Neighbors algorithm.
///
/// It is *highly recommended* not to change values inside of this struct manually. Always
/// create a new one using `::new`, which derives `num_labels` from `y` and applies the
/// configured normalization to `x`.
#[derive(Debug)]
pub struct KNN {
    /// Number of nearest neighbors (K) examined by `predict`.
    pub k: i32,
    /// Feature vectors, one row per training sample. Normalized in place by
    /// `normalize_data` when `normalize` is `Some`.
    pub x: Vec<Vec<f64>>,
    /// Class label for each row of `x`.
    /// NOTE(review): `predict` uses these labels to index a vote vector of
    /// length `num_labels`, so labels are assumed contiguous in `0..num_labels`;
    /// anything else panics there — confirm with callers.
    pub y: Vec<i32>,
    /// Number of unique labels in `y` (computed by `get_num_labels` in `::new`).
    pub num_labels: usize,
    /// Distance metric to use; `None` falls back to Euclidean.
    pub distance: Option<distance::Distance>,
    /// The type of normalization applied to the training data, or None.
    pub normalize: Option<norm::Norm>,
}
84
/// A data point: a training sample's class label paired with its distance
/// from the query point, as produced by `KNN::calculate_distances`.
#[derive(PartialEq, Debug)]
pub struct Point {
    /// The class label for the point.
    pub class: i32,
    /// The distance from the test point.
    pub distance: f64,
}

impl PartialOrd for Point {
    /// Orders points by `distance` using `f64::total_cmp`, which defines a
    /// total order over floats (NaN included). The previous implementation
    /// returned `None` for NaN distances, which made the
    /// `partial_cmp(..).unwrap()` inside `KNN::predict`'s sort panic.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.distance.total_cmp(&other.distance))
    }
}

// NOTE(review): `Eq` is asserted although `distance` is an `f64`; the derived
// `PartialEq` still treats NaN != NaN, so the `Eq` contract is not strictly
// upheld for NaN-bearing points.
impl Eq for Point {}
101
102impl KNN {
103    /// Create a new KNN with optional normalization.
104    pub fn new(
105        k: i32,
106        x: Vec<Vec<f64>>,
107        y: Vec<i32>,
108        distance: Option<distance::Distance>,
109        normalize: Option<norm::Norm>,
110    ) -> KNN {
111        let num_labels = KNN::get_num_labels(&y);
112        let mut knn = KNN {
113            k,
114            x,
115            y,
116            num_labels,
117            distance,
118            normalize,
119        };
120        knn.normalize_data();
121        knn
122    }
123
124    /// Gets the number of unique labels.
125    /// This function is called when ::new is called. You can access the value using
126    /// the value contained in the KNN struct.
127    pub fn get_num_labels(y: &[i32]) -> usize {
128        let set: HashSet<i32> = y.iter().cloned().collect::<HashSet<_>>();
129        set.len()
130    }
131
132    /// Normalize the data contain in `self` given by the KNN's configured normalization setting.
133    pub fn normalize_data(&mut self) {
134        if let Some(n) = &self.normalize {
135            self.x
136                .iter_mut()
137                .for_each(|xi| norm::normalize_vector(xi, n));
138        }
139    }
140
141    /// Borrow immutable reference to the data.
142    pub fn data(&self) -> (&Vec<Vec<f64>>, &Vec<i32>) {
143        (&self.x, &self.y)
144    }
145
146    /// Calculate the distance from `new_point` to all other points in the set.
147    /// Note: new_point must be the same dimensions as the data passed into ::new.norm
148    pub fn calculate_distances(&self, new_point: &[f64]) -> Vec<Point> {
149        let distance_fn = match self.distance {
150            Some(distance::Distance::Manhattan) => distance::manhattan_distance,
151            _ => distance::euclidean_distance,
152        };
153
154        self.x
155            .par_iter()
156            .zip(self.y.par_iter())
157            .map(|(x, y)| Point {
158                class: *y,
159                distance: distance_fn(new_point, x),
160            })
161            .collect()
162    }
163
164    /// Predict the class of a point `x`.
165    pub fn predict(&self, x: &[f64]) -> i32 {
166        let mut norm_x: Vec<f64> = x.to_owned();
167        if let Some(n) = &self.normalize {
168            norm::normalize_vector(&mut norm_x, n);
169        }
170        let mut points = self.calculate_distances(x);
171        // points.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
172        points.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
173
174        let mut predictions = vec![0; self.num_labels];
175
176        for i in &points[0..self.k as usize] {
177            predictions[i.class as usize] += 1;
178        }
179        KNN::get_max_value(&predictions)
180    }
181
182    /// Get the class of the highest index.
183    fn get_max_value(predictions: &[i32]) -> i32 {
184        predictions
185            .iter()
186            .enumerate() // add index to the iterated items [a, b, c] -> [(0, a), (1, b), (2, c)]
187            .max_by_key(|(_, pred)| **pred) // take maximum by the actual item, not the index,
188            // `pred` has type `&&i32`, because of all the combinators, so we have to dereference twice
189            .map(|(idx, _)| idx) // Option::map - take tuple (idx, value) and transform it to just idx
190            .unwrap() as i32
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_data_test() {
        let data = vec![vec![2.0, 2.0, 2.0]];
        let mut knn = KNN::new(5, data, vec![1], None, Some(norm::Norm::L2));
        // `new` already normalized once; normalizing again also exercises the
        // fact that L2 normalization is idempotent on an already-unit vector.
        knn.normalize_data();
        // ||[2, 2, 2]||_2 = sqrt(12), so each component becomes 2 / sqrt(12).
        let expected = vec![vec![2.0 / f64::from(12).sqrt(); 3]];
        assert_eq!(knn.data().0.clone(), expected);
    }

    #[test]
    fn calculate_distances_test() {
        let data = vec![vec![2.0, 2.0]];
        let knn = KNN::new(5, data, vec![1], None, None);

        // Euclidean distance from the origin to (2, 2) is sqrt(8).
        let origin = vec![0.0, 0.0];
        let points = knn.calculate_distances(&origin);
        assert_eq!(points[0].distance, f64::from(8).sqrt());
    }
}