mlinrust/model/knn.rs

use std::collections::{BinaryHeap, HashMap};

use crate::{dataset::{Dataset, TaskLabelType}, ndarray::{NdArray, utils::softmax}};

use super::{Model, utils::minkowski_distance};


#[derive(Debug, Clone, Copy)]
pub enum KNNAlg {
    BruteForce,
    KdTree,
}

/// Uniform means majority voting for classification tasks
/// Distance means weighting the neighbours by distance (closer neighbours count more)
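///
/// For intuition, a worked sketch with illustrative numbers: suppose k = 3 and the
/// nearest neighbours have labels \[0, 0, 1\] at distances \[1.0, 2.0, 4.0\]
/// (classification uses weight 1/distance for `Distance`):
/// ```text
/// Uniform:  vote(0) = 1 + 1     = 2.0,  vote(1) = 1    -> predict 0
/// Distance: vote(0) = 1/1 + 1/2 = 1.5,  vote(1) = 1/4  -> predict 0
/// ```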
#[derive(Debug, Clone, Copy)]
pub enum KNNWeighting {
    Uniform,
    Distance,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryRecord<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> {
    feature: &'a Vec<f32>,
    label: T,
    distance: f32,
}

impl<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> Eq for QueryRecord<'a, T> {}

// NOTE: records are ordered by `distance` through `partial_cmp(..).unwrap()`, which
// panics on NaN; the heaps below rely on all distances being finite
impl<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> Ord for QueryRecord<'a, T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.partial_cmp(&other.distance).unwrap()
    }
}

impl<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> PartialOrd for QueryRecord<'a, T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.distance.partial_cmp(&other.distance)
    }
}

/// a kd-tree node; `feature_idx` is the splitting axis chosen at this node's depth
struct KdNode<T: TaskLabelType + Copy> {
    feature_idx: usize,
    sample: Vec<f32>,
    label: T,
    left: Option<Box<KdNode<T>>>,
    right: Option<Box<KdNode<T>>>,
}

struct KDTree<T: TaskLabelType + Copy> {
    root: Option<Box<KdNode<T>>>,
    minkowski_distance_p: f32,
    k: usize,
    weighting: KNNWeighting,
}

pub trait KNNInterface<T: TaskLabelType + Copy + std::cmp::PartialEq> {
    /// Find the k nearest nodes around the query.
    /// * return: Vec of QueryRecord (node feature, label, distance), ordered by ascending distance, size = min(k, number of samples)
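    ///
    /// e.g. (illustrative values) with k = 2, a query yields records sorted from
    /// closest to farthest:
    /// ```text
    /// [QueryRecord { distance: 0.5, .. }, QueryRecord { distance: 1.2, .. }]
    /// ```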
    fn nearest<'a>(&'a self, query: &Vec<f32>) -> Vec<QueryRecord<'a, T>>;

    fn get_weighting(&self) -> KNNWeighting;
}

impl<T: TaskLabelType + Copy + std::cmp::PartialEq + 'static> KNNInterface<T> for KDTree<T> {

    fn nearest<'a>(&'a self, query: &Vec<f32>) -> Vec<QueryRecord<'a, T>> {
        assert!(self.root.is_some());

        // start the search with an empty max-heap of best-so-far records
        let records_heap = BinaryHeap::new();
        let mut records_heap = self.recursive_nearest(&self.root, query, records_heap);

        // pop from the heap (worst first), then reverse into ascending distance order
        let mut nearest: Vec<QueryRecord<'a, T>> = vec![];
        while let Some(item) = records_heap.pop() {
            nearest.push(item);
        }
        nearest.reverse();
        nearest
    }

    fn get_weighting(&self) -> KNNWeighting {
        self.weighting
    }
}

impl<T: TaskLabelType + Copy + std::cmp::PartialEq + 'static> KDTree<T> {

    /// * k: k nearest neighbours
    /// * weighting: how to weight the neighbours; defaults to Uniform (all neighbours count equally)
    /// * features: \[batch, feature\]
    /// * labels: \[batch\]
    /// * total_dim: total dimension of the feature
    /// * p: the parameter p of [minkowski distance](https://en.wikipedia.org/wiki/Minkowski_distance)
    ///     * default is p = 2
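    ///
    /// Usage sketch (mirrors `test_kdtree` below):
    /// ```ignore
    /// let tree = KDTree::new(20, Some(KNNWeighting::Distance), features, labels, 2, Some(2));
    /// let neighbours = tree.nearest(&vec![6.0, 7.0]);
    /// ```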
    fn new(k: usize, weighting: Option<KNNWeighting>, features: Vec<Vec<f32>>, labels: Vec<T>, total_dim: usize, p: Option<usize>) -> Box<dyn KNNInterface<T>> {
        assert!(!features.is_empty() && features.len() == labels.len());
        assert!(k > 0);
        let feature_label_zip: Vec<(Vec<f32>, T)> = features.into_iter().zip(labels.into_iter()).collect();

        Box::new(Self { root: Self::build(feature_label_zip, total_dim, 0), minkowski_distance_p: p.unwrap_or(2) as f32, k, weighting: weighting.unwrap_or(KNNWeighting::Uniform) })
    }

    /// * feature_label_zip: \[batch, (feature, label)\]
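    ///
    /// For example (illustrative), splitting 6 samples whose axis-0 values sort to
    /// \[2, 4, 5, 7, 8, 9\]: median index = 6 / 2 = 3, so the sample with value 7
    /// becomes this node, \[2, 4, 5\] builds the left subtree and \[8, 9\] the right.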
    fn build(mut feature_label_zip: Vec<(Vec<f32>, T)>, total_dim: usize, depth: usize) -> Option<Box<KdNode<T>>> {
        if feature_label_zip.is_empty() {
            None
        } else if feature_label_zip.len() == 1 {
            let axis = depth % total_dim;
            let (feature, label) = feature_label_zip.pop().unwrap();
            Some(Box::new(KdNode { feature_idx: axis, label, sample: feature, left: None, right: None }))
        } else {
            // cycle through the axes by depth and sort the samples along that axis
            let axis = depth % total_dim;
            feature_label_zip.sort_by(|a, b| {
                a.0[axis].partial_cmp(&b.0[axis]).unwrap()
            });

            let median = feature_label_zip.len() / 2;

            // samples after the median go to the right subtree; the median itself becomes this node
            let right_feature_label_zip = feature_label_zip.split_off(median + 1);
            let (median_f, median_l) = feature_label_zip.pop().unwrap();

            let left = Self::build(feature_label_zip, total_dim, depth + 1);
            let right = Self::build(right_feature_label_zip, total_dim, depth + 1);

            Some(Box::new(KdNode { feature_idx: axis, label: median_l, sample: median_f, left, right }))
        }
    }

    /// * return: a max-heap of QueryRecord, with the worst (farthest) of the current best k on top
    fn recursive_nearest<'a>(&'a self, node: &'a Option<Box<KdNode<T>>>, query: &Vec<f32>, mut records_heap: BinaryHeap<QueryRecord<'a, T>>) -> BinaryHeap<QueryRecord<'a, T>> {
        if node.is_none() {
            records_heap
        } else {
            let node = node.as_ref().unwrap();

            // calculate the distance between the query and the current node
            let d = minkowski_distance(query, &node.sample, self.minkowski_distance_p);

            // update the best records: once the heap holds k records, replace the worst
            // one whenever the current node is closer
            if records_heap.len() == self.k {
                let worst_record = records_heap.peek().unwrap();
                if worst_record.distance > d {
                    records_heap.pop();
                    records_heap.push(QueryRecord { feature: &node.sample, label: node.label, distance: d });
                }
            } else {
                records_heap.push(QueryRecord { feature: &node.sample, label: node.label, distance: d });
            }

            // search the subtrees: the "good" side is the one on the query's side of the
            // splitting plane (less goes left, greater or equal goes right); "bad" is the
            // other side
            let (good, bad) = if query[node.feature_idx] < node.sample[node.feature_idx] {
                (&node.left, &node.right)
            } else {
                (&node.right, &node.left)
            };

            // explore the good side first
            records_heap = self.recursive_nearest(good, query, records_heap);

            // explore the bad side only if fewer than k records were found, or if the
            // splitting plane is closer than the current worst distance (a nearer
            // neighbour could then lie on the far side); otherwise prune it
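            // e.g. (illustrative) suppose the worst of the current k records is at
            // distance 2.0: if the query is only 0.5 from the splitting plane, a closer
            // neighbour may lie on the bad side, so search it; if it is 3.0 away, every
            // sample there is farther, so prune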
            let worst_record = records_heap.peek().unwrap();
            if records_heap.len() < self.k ||
                (query[node.feature_idx] - node.sample[node.feature_idx]).abs() < worst_record.distance {
                records_heap = self.recursive_nearest(bad, query, records_heap);
            }

            records_heap
        }
    }
}


struct BruteForceSearch<T: TaskLabelType + Copy> {
    k: usize,
    minkowski_distance_p: f32,
    weighting: KNNWeighting,
    features: Vec<Vec<f32>>,
    labels: Vec<T>,
}

impl<T: TaskLabelType + Copy + PartialEq> KNNInterface<T> for BruteForceSearch<T> {
    fn nearest<'a>(&'a self, query: &Vec<f32>) -> Vec<QueryRecord<'a, T>> {
        // scan every sample, keeping the best k seen so far in a max-heap
        let mut records_heap: BinaryHeap<QueryRecord<'a, T>> = BinaryHeap::new();
        for (feature, label) in self.features.iter().zip(self.labels.iter()) {
            let d = minkowski_distance(query, feature, self.minkowski_distance_p);
            if records_heap.len() == self.k {
                let worst_record = records_heap.peek().unwrap();
                if d < worst_record.distance {
                    records_heap.pop();
                    records_heap.push(
                        QueryRecord { feature, label: *label, distance: d }
                    );
                }
            } else {
                records_heap.push(
                    QueryRecord { feature, label: *label, distance: d }
                );
            }
        }
        // pop from the heap (worst first), then reverse into ascending distance order
        let mut res = vec![];
        while let Some(item) = records_heap.pop() {
            res.push(item);
        }
        res.reverse();
        res
    }

    fn get_weighting(&self) -> KNNWeighting {
        self.weighting
    }
}

impl<T: TaskLabelType + Copy + std::cmp::PartialEq + 'static> BruteForceSearch<T> {
    /// * k: k nearest neighbours
    /// * weighting: how to weight the neighbours; defaults to Uniform (all neighbours count equally)
    /// * features: \[batch, feature\]
    /// * labels: \[batch\]
    /// * p: the parameter p of [minkowski distance](https://en.wikipedia.org/wiki/Minkowski_distance)
    ///     * default is p = 2
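    ///
    /// Usage sketch (mirrors `test_brute_force_search` below):
    /// ```ignore
    /// let knn = BruteForceSearch::new(20, Some(KNNWeighting::Distance), features, labels, Some(2));
    /// let neighbours = knn.nearest(&vec![6.0, 7.0]);
    /// ```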
    fn new(k: usize, weighting: Option<KNNWeighting>, features: Vec<Vec<f32>>, labels: Vec<T>, p: Option<usize>) -> Box<dyn KNNInterface<T>> {
        assert!(!features.is_empty() && features.len() == labels.len());
        assert!(k > 0);
        Box::new(Self { k, minkowski_distance_p: p.unwrap_or(2) as f32, weighting: weighting.unwrap_or(KNNWeighting::Uniform), features, labels })
    }
}


impl Model<usize> for dyn KNNInterface<usize> {
    fn predict(&self, feature: &Vec<f32>) -> usize {
        let res = self.nearest(feature);
        // accumulate (weighted) votes per class label
        let mut predicts: HashMap<usize, f32> = HashMap::new();
        for item in res {
            *predicts.entry(item.label).or_insert(0.0) += match self.get_weighting() {
                KNNWeighting::Distance => 1.0 / f32::max(item.distance, 1e-6),
                KNNWeighting::Uniform => 1.0,
            }
        }
        // take the label with the highest vote; the accumulator starts from f32::MIN so
        // that the first entry always wins (f32::MAX would make every comparison fail)
        predicts.iter().fold((0, f32::MIN), |s, i| {
            if *i.1 > s.1 {
                (*i.0, *i.1)
            } else {
                s
            }
        }).0
    }
}

impl Model<f32> for dyn KNNInterface<f32> {
    fn predict(&self, feature: &Vec<f32>) -> f32 {
        let res = self.nearest(feature);
        let weights = match self.get_weighting() {
            KNNWeighting::Distance => {
                // negate the distances so that closer neighbours get larger softmax weights
                let mut a = NdArray::new(res.iter().map(|i| -i.distance).collect::<Vec<f32>>());
                softmax(&mut a, 0);
                a.destroy().1
            },
            KNNWeighting::Uniform => {
                vec![1.0 / res.len() as f32; res.len()]
            }
        };
        // weighted average of the neighbour labels
        res.iter().zip(weights.iter()).fold(0.0, |s, (i, w)| {
            s + i.label * w
        })
    }
}



/// [KNN](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm) implemented by a [k-d tree](https://en.wikipedia.org/wiki/K-d_tree) or by BruteForceSearch
pub struct KNNModel<T: TaskLabelType + Copy + PartialEq> {
    pub alg: KNNAlg,
    interface: Box<dyn KNNInterface<T>>,
}


impl<T: TaskLabelType + Copy + PartialEq + 'static> KNNModel<T> {
    /// For KdTree, construction builds the tree, i.e., training happens here; BruteForce is lazy (no training, the samples are simply stored)
    /// * alg: algorithm of knn
    /// * k: k nearest neighbours
    /// * weighting: for ensembling the k results
    ///     * default is Uniform
    /// * p: parameter of minkowski distance
    ///     * default is 2, i.e., Euclidean distance
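    ///
    /// A minimal usage sketch (mirrors `test_knn` below):
    /// ```ignore
    /// let dataset = Dataset::new(features, labels, None);
    /// let knn = KNNModel::new(KNNAlg::KdTree, 1, None, dataset, None);
    /// let pred = knn.predict(&vec![7.0, 1.9]);
    /// ```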
    pub fn new(alg: KNNAlg, k: usize, weighting: Option<KNNWeighting>, dataset: Dataset<T>, p: Option<usize>) -> Self {
        let interface = match alg {
            KNNAlg::BruteForce => {
                BruteForceSearch::new(k, weighting, dataset.features, dataset.labels, p)
            },
            KNNAlg::KdTree => {
                let total_dim = dataset.feature_len();
                KDTree::new(k, weighting, dataset.features, dataset.labels, total_dim, p)
            }
        };
        Self { alg, interface }
    }

    pub fn nearest(&self, query: &Vec<f32>) -> Vec<QueryRecord<T>> {
        self.interface.nearest(query)
    }
}

impl Model<usize> for KNNModel<usize> {
    fn predict(&self, feature: &Vec<f32>) -> usize {
        let res = self.interface.nearest(feature);
        // accumulate (weighted) votes per class label, then take the argmax
        let mut predicts: HashMap<usize, f32> = HashMap::new();
        for item in res {
            *predicts.entry(item.label).or_insert(0.0) += match self.interface.get_weighting() {
                KNNWeighting::Distance => 1.0 / f32::max(item.distance, 1e-6),
                KNNWeighting::Uniform => 1.0,
            }
        }
        predicts.iter().fold((0, f32::MIN), |s, i| {
            if *i.1 > s.1 {
                (*i.0, *i.1)
            } else {
                s
            }
        }).0
    }
}

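// Regression: the prediction is a weighted average of the neighbour labels. With
// Distance weighting the weights come from a softmax over negated distances, e.g.
// (illustrative) distances [1.0, 2.0] give weights softmax([-1.0, -2.0]) ≈ [0.73, 0.27].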
impl Model<f32> for KNNModel<f32> {
    fn predict(&self, feature: &Vec<f32>) -> f32 {
        let res = self.interface.nearest(feature);
        let weights = match self.interface.get_weighting() {
            KNNWeighting::Distance => {
                // negate the distances so that closer neighbours get larger softmax weights
                let mut a = NdArray::new(res.iter().map(|i| -i.distance).collect::<Vec<f32>>());
                softmax(&mut a, 0);
                a.destroy().1
            },
            KNNWeighting::Uniform => {
                vec![1.0 / res.len() as f32; res.len()]
            }
        };
        res.iter().zip(weights.iter()).fold(0.0, |s, (i, w)| {
            s + i.label * w
        })
    }
}

#[cfg(test)]
mod test {
    use crate::dataset::Dataset;
    use crate::model::Model;
    use crate::model::knn::{KNNWeighting, BruteForceSearch};

    use super::{KDTree, KNNAlg, KNNModel};

    #[test]
    fn test_kdtree() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let tree = KDTree::new(20, Some(KNNWeighting::Distance), features, labels, 2, Some(2));
        let query = vec![6.0, 7.0];
        let results = tree.nearest(&query);
        println!("size {} predict {}\nnearest {results:?}", results.len(), tree.predict(&query));
    }

    #[test]
    fn test_brute_force_search() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels = vec![0, 0, 0, 1, 1, 1];
        let knn = BruteForceSearch::new(20, Some(KNNWeighting::Distance), features, labels, Some(2));
        let query = vec![6.0, 7.0];
        let results = knn.nearest(&query);
        println!("size {} predict {}\nnearest {results:?}", results.len(), knn.predict(&query));
    }

    #[test]
    fn test_knn() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels = vec![0, 0, 0, 1, 1, 1];
        let dataset = Dataset::new(features, labels, None);
        let knn = KNNModel::new(KNNAlg::KdTree, 1, None, dataset, None);
        let query = vec![7.0, 1.9];
        println!("nearest {:?}\npredict = {}", knn.nearest(&query), knn.predict(&query));
    }
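
    // Consistency sketch (assumption: both exact backends must agree): the kd-tree
    // and brute-force searches should return the same neighbour distances for the
    // same data and query.
    #[test]
    fn test_kdtree_matches_brute_force() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels: Vec<usize> = vec![0, 0, 0, 1, 1, 1];
        let tree = KDTree::new(3, None, features.clone(), labels.clone(), 2, None);
        let brute = BruteForceSearch::new(3, None, features, labels, None);
        let query = vec![6.0, 7.0];
        let a = tree.nearest(&query);
        let b = brute.nearest(&query);
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x.distance - y.distance).abs() < 1e-6);
        }
    }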
}