1use std::collections::{BinaryHeap, HashMap};
2
3use crate::{dataset::{Dataset, TaskLabelType}, ndarray::{NdArray, utils::softmax}};
4
5use super::{Model, utils::minkowski_distance};
6
7
/// Nearest-neighbour search backend used by a `KNNModel`.
#[derive(Debug, Clone, Copy)]
pub enum KNNAlg {
    /// Linear scan over every stored sample; O(n) per query.
    BruteForce,
    /// kd-tree space partitioning; prunes subtrees during queries.
    KdTree,
}
13
/// How each of the k neighbours contributes to a prediction.
#[derive(Debug, Clone, Copy)]
pub enum KNNWeighting {
    /// Every neighbour counts equally.
    Uniform,
    /// Closer neighbours count more (see the `predict` impls for the
    /// exact weighting used per task type).
    Distance,
}
21
/// One neighbour returned by a k-NN query: a borrowed feature vector,
/// its label, and its distance to the query point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryRecord<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> {
    // borrowed from the search structure that owns the training samples
    feature: &'a Vec<f32>,
    label: T,
    // Minkowski distance between `feature` and the query point
    distance: f32,
}
28
// `QueryRecord` cannot derive `Eq` because `distance` is an `f32`.
// `Eq` is asserted manually so records can be stored in a `BinaryHeap`;
// this is sound only as long as distances are never NaN.
impl<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> Eq for QueryRecord<'a, T> {

}
32
33impl<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> Ord for QueryRecord<'a, T> {
34 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
35 self.distance.partial_cmp(&other.distance).unwrap()
36 }
37}
38
39impl<'a, T: TaskLabelType + Copy + std::cmp::PartialEq> PartialOrd for QueryRecord<'a, T> {
40 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
41 self.distance.partial_cmp(&other.distance)
42 }
43}
44
/// A single kd-tree node: one stored sample plus the axis it splits on.
struct KdNode<T: TaskLabelType + Copy> {
    // the dimension compared at this node's depth (depth % total_dim)
    feature_idx: usize,
    sample: Vec<f32>,
    label: T,
    // children: `left` holds the samples sorted before this node's
    // median position on `feature_idx`, `right` the ones after it
    left: Option<Box<KdNode<T>>>,
    right: Option<Box<KdNode<T>>>,
}
52
/// kd-tree backend for k-NN queries.
struct KDTree<T: TaskLabelType + Copy> {
    root: Option<Box<KdNode<T>>>,
    // order `p` of the Minkowski distance (2.0 == Euclidean)
    minkowski_distance_p: f32,
    // number of neighbours kept per query
    k: usize,
    weighting: KNNWeighting,
}
59
/// Common query interface shared by the brute-force and kd-tree backends.
pub trait KNNInterface<T: TaskLabelType + Copy + std::cmp::PartialEq> {
    /// The (up to) k stored samples closest to `query`, sorted by
    /// ascending distance (closest first).
    fn nearest<'a>(&'a self, query: &Vec<f32>) -> Vec<QueryRecord<'a, T>>;

    /// How neighbour contributions are weighted when predicting.
    fn get_weighting(&self) -> KNNWeighting;

}
68
69impl<T: TaskLabelType + Copy + std::cmp::PartialEq + 'static> KNNInterface<T> for KDTree<T> {
70
71 fn nearest<'a>(&'a self, query: &Vec<f32>) -> Vec<QueryRecord<'a, T>> {
72 assert!(self.root.is_some());
74
75 let records_heap = BinaryHeap::new();
76 let mut records_heap = self.recursive_nearest(&self.root, query, records_heap);
77 let mut nearest: Vec<QueryRecord<'a, T>> = vec![];
78 while let Some(item) = records_heap.pop() {
79 nearest.push(item);
80 }
81 nearest.reverse();
82 nearest
83 }
84
85 fn get_weighting(&self) -> KNNWeighting {
86 self.weighting
87 }
88}
89
impl<T: TaskLabelType + Copy + std::cmp::PartialEq + 'static> KDTree<T> {

    /// Build a kd-tree over `features`/`labels` and box it behind the
    /// common `KNNInterface`.
    ///
    /// * `k` - number of neighbours returned per query (must be > 0)
    /// * `weighting` - contribution weighting, defaults to `Uniform`
    /// * `total_dim` - dimensionality of every feature vector; the split
    ///   axis cycles through `0..total_dim` with tree depth
    /// * `p` - order of the Minkowski distance, defaults to 2 (Euclidean)
    ///
    /// Panics when `features` is empty, its length differs from
    /// `labels`, or `k` is zero.
    fn new(k: usize, weighting: Option<KNNWeighting>, features: Vec<Vec<f32>>, labels: Vec<T>, total_dim: usize, p: Option<usize>) -> Box<dyn KNNInterface<T>> {
        assert!(features.len() > 0 && features.len() == labels.len());
        assert!(k > 0);
        let feature_label_zip: Vec<(Vec<f32>, T)> = features.into_iter().zip(labels.into_iter()).map(|(f,l)| (f, l)).collect();

        Box::new(Self { root: Self::build(feature_label_zip, total_dim, 0), minkowski_distance_p: p.unwrap_or(2) as f32, k: k, weighting: weighting.unwrap_or(KNNWeighting::Uniform)})
    }

    /// Recursively build a subtree: sort the remaining samples along the
    /// axis for this depth, store the median sample at this node, and
    /// recurse on the two halves.
    fn build(mut feature_label_zip: Vec<(Vec<f32>, T)>, total_dim: usize, depth: usize) -> Option<Box<KdNode<T>>> {
        if feature_label_zip.len() == 0 {
            None
        } else if feature_label_zip.len() == 1 {
            let axis = depth % total_dim;
            let (feature, label) = feature_label_zip.pop().unwrap();
            Some(Box::new(KdNode {feature_idx: axis, label: label, sample: feature, left: None, right: None}))
        } else {
            let axis = depth % total_dim;
            // sort by the split axis; panics if a coordinate is NaN
            feature_label_zip.sort_by(|a, b| {
                a.0[axis].partial_cmp(&b.0[axis]).unwrap()
            });


            let median = feature_label_zip.len() / 2;

            // everything after the median index goes right; the median
            // itself is popped off the left half and stored in this node
            let right_feature_label_zip = feature_label_zip.split_off(median + 1);
            let (median_f, median_l) = feature_label_zip.pop().unwrap();


            let left = Self::build(feature_label_zip, total_dim, depth + 1);
            let right = Self::build(right_feature_label_zip, total_dim, depth + 1);

            Some(Box::new(KdNode {feature_idx: axis, label: median_l, sample: median_f, left: left, right: right}))
        }
    }

    /// Depth-first k-NN search. `records_heap` is a max-heap of the best
    /// candidates found so far, capped at `self.k` entries, so `peek()`
    /// is always the current worst kept neighbour.
    fn recursive_nearest<'a>(&'a self, node: &'a Option<Box<KdNode<T>>>, query: &Vec<f32>, mut records_heap: BinaryHeap<QueryRecord<'a, T>>) -> BinaryHeap<QueryRecord<'a, T>> {
        if node.is_none() {
            records_heap
        } else {
            let d = minkowski_distance(query, &node.as_ref().unwrap().sample, self.minkowski_distance_p);

            let node = node.as_ref().unwrap();

            // keep at most k candidates: replace the worst one only when
            // the current node is strictly closer
            if records_heap.len() == self.k {
                let worst_record = records_heap.peek().unwrap();
                if worst_record.distance > d {
                    records_heap.pop();
                    records_heap.push(QueryRecord { feature: &node.sample, label: node.label, distance: d });
                }
            } else {
                records_heap.push(QueryRecord { feature: &node.sample, label: node.label, distance: d });
            }


            // descend first into the half-space containing the query
            let (good, bad) = if query[node.feature_idx] < node.sample[node.feature_idx] {
                (&node.left, &node.right)
            } else {
                (&node.right, &node.left)
            };

            records_heap = self.recursive_nearest(good, query, records_heap);

            // the heap is non-empty here: this node was pushed above
            let worst_record = records_heap.peek().unwrap();
            // visit the far half only if it could still hold a better
            // neighbour: fewer than k candidates so far, or the splitting
            // plane is closer to the query than the worst kept distance
            if records_heap.len() < self.k ||
                (query[node.feature_idx] - node.sample[node.feature_idx]).abs() < worst_record.distance {
                records_heap = self.recursive_nearest(bad, query, records_heap);
            }

            records_heap
        }
    }
}
184
185
/// Brute-force (linear scan) backend for k-NN queries.
struct BruteForceSearch<T: TaskLabelType + Copy> {
    // number of neighbours kept per query
    k: usize,
    // order `p` of the Minkowski distance (2.0 == Euclidean)
    minkowski_distance_p: f32,
    weighting: KNNWeighting,
    // training samples, kept as parallel vectors
    features: Vec<Vec<f32>>,
    labels: Vec<T>,
}
193
194impl<T: TaskLabelType + Copy + PartialEq> KNNInterface<T> for BruteForceSearch<T> {
195 fn nearest<'a>(&'a self, query: &Vec<f32>) -> Vec<QueryRecord<'a, T>> {
196 let mut records_heap: BinaryHeap<QueryRecord<'a, T>> = BinaryHeap::new();
197 for (feature, label) in self.features.iter().zip(self.labels.iter()) {
198 let d = minkowski_distance(query, feature, self.minkowski_distance_p);
199 if records_heap.len() == self.k {
200 let worst_record = records_heap.peek().unwrap();
201 if d < worst_record.distance {
202 records_heap.pop();
203 records_heap.push(
204 QueryRecord { feature: feature, label: *label, distance: d }
205 );
206 }
207 } else {
208 records_heap.push(
209 QueryRecord { feature: feature, label: *label, distance: d }
210 );
211 }
212 }
213 let mut res = vec![];
214 while let Some(item) = records_heap.pop() {
215 res.push(item);
216 }
217 res.reverse();
218 res
219 }
220
221 fn get_weighting(&self) -> KNNWeighting {
222 self.weighting
223 }
224}
225
226impl<T: TaskLabelType + Copy + std::cmp::PartialEq + 'static> BruteForceSearch<T> {
227 fn new(k: usize, weighting: Option<KNNWeighting>, features: Vec<Vec<f32>>, labels: Vec<T>, p: Option<usize>) -> Box<dyn KNNInterface<T>> {
234 assert!(features.len() > 0 && features.len() == labels.len());
235 assert!(k > 0);
236 Box::new(Self { k: k, minkowski_distance_p: p.unwrap_or(2) as f32, weighting: weighting.unwrap_or(KNNWeighting::Uniform), features: features, labels: labels })
237 }
238}
239
240
241impl Model<usize> for dyn KNNInterface<usize> {
242 fn predict(&self, feature: &Vec<f32>) -> usize {
243 let res = self.nearest(feature);
244 let mut predicts: HashMap<usize, f32> = HashMap::new();
245 for item in res {
246 *predicts.entry(item.label).or_insert(0.0) += match self.get_weighting() {
247 KNNWeighting::Distance => 1.0 / f32::max(item.distance, 1e-6),
248 KNNWeighting::Uniform => 1.0,
249 }
250 }
251 predicts.iter().fold((0, f32::MAX), |s, i| {
252 if *i.1 > s.1 {
253 (*i.0, *i.1)
254 } else {
255 s
256 }
257 }).0
258 }
259}
260
261impl Model<f32> for dyn KNNInterface<f32> {
262 fn predict(&self, feature: &Vec<f32>) -> f32 {
263 let res = self.nearest(feature);
264 let weights = match self.get_weighting() {
265 KNNWeighting::Distance => {
266 let mut a = NdArray::new(res.iter().map(|i| i.distance).collect::<Vec<f32>>());
267 softmax(&mut a, 0);
268 a.destroy().1
269 },
270 KNNWeighting::Uniform => {
271 vec![1.0 / res.len() as f32; res.len()]
272 }
273 };
274 res.iter().zip(weights.iter()).fold(0.0, |s, (i, w)| {
275 s + i.label * w
276 })
277 }
278}
279
280
281
/// User-facing k-nearest-neighbours model; dispatches queries to the
/// backend selected by `alg`.
pub struct KNNModel<T: TaskLabelType + Copy + PartialEq> {
    pub alg: KNNAlg,
    // boxed search backend (brute force or kd-tree)
    interface: Box<dyn KNNInterface<T>>,
}
287
288
289impl<T: TaskLabelType + Copy + PartialEq + 'static> KNNModel<T> {
290 pub fn new(alg: KNNAlg, k: usize, weighting: Option<KNNWeighting>, dataset: Dataset<T>, p: Option<usize>) -> Self {
298 let interface= match alg {
299 KNNAlg::BruteForce => {
300 BruteForceSearch::new(k, weighting, dataset.features, dataset.labels, p)
301 },
302 KNNAlg::KdTree => {
303 let total_dim = dataset.feature_len();
304 KDTree::new(k, weighting, dataset.features, dataset.labels, total_dim, p)
305 }
306 };
307 Self { alg: alg, interface: interface }
308 }
309
310 pub fn nearest(&self, query: &Vec<f32>) -> Vec<QueryRecord<T>> {
311 self.interface.nearest(query)
312 }
313}
314
315impl Model<usize> for KNNModel<usize> {
316 fn predict(&self, feature: &Vec<f32>) -> usize {
317 let res = self.interface.nearest(feature);
318 let mut predicts: HashMap<usize, f32> = HashMap::new();
319 for item in res {
320 *predicts.entry(item.label).or_insert(0.0) += match self.interface.get_weighting() {
321 KNNWeighting::Distance => 1.0 / f32::max(item.distance, 1e-6),
322 KNNWeighting::Uniform => 1.0,
323 }
324 }
325 predicts.iter().fold((0, f32::MIN), |s, i| {
326 if *i.1 > s.1 {
327 (*i.0, *i.1)
328 } else {
329 s
330 }
331 }).0
332 }
333}
334
335impl Model<f32> for KNNModel<f32> {
336 fn predict(&self, feature: &Vec<f32>) -> f32 {
337 let res = self.interface.nearest(feature);
338 let weights = match self.interface.get_weighting() {
339 KNNWeighting::Distance => {
340 let mut a = NdArray::new(res.iter().map(|i| i.distance).collect::<Vec<f32>>());
341 softmax(&mut a, 0);
342 a.destroy().1
343 },
344 KNNWeighting::Uniform => {
345 vec![1.0 / res.len() as f32; res.len()]
346 }
347 };
348 res.iter().zip(weights.iter()).fold(0.0, |s, (i, w)| {
349 s + i.label * w
350 })
351 }
352}
353
#[cfg(test)]
mod test {
    use crate::dataset::{Dataset};
    use crate::model::Model;
    use crate::model::knn::{KNNWeighting, BruteForceSearch};

    use super::{KNNModel, KNNAlg};
    use super::{KDTree};

    #[test]
    fn test_kdtree() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let tree = KDTree::new(20, Some(KNNWeighting::Distance), features, labels, 2, Some(2));
        let query = vec![6.0, 7.0];
        let results = tree.nearest(&query);
        // k (20) exceeds the sample count, so every sample comes back
        assert_eq!(results.len(), 6);
        // neighbours must be sorted by ascending distance
        assert!(results.windows(2).all(|w| w[0].distance <= w[1].distance));
        // (4, 7) is the closest sample to (6, 7): Euclidean distance 2
        assert!((results[0].distance - 2.0).abs() < 1e-5);
        println!("size {} predict {}\nnearest {results:?}", results.len(), tree.predict(&query));
    }

    #[test]
    fn test_brute_force_search() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels = vec![0, 0, 0, 1, 1, 1];
        let tree = BruteForceSearch::new(20, Some(KNNWeighting::Distance), features, labels, Some(2));
        let query = vec![6.0, 7.0];
        let results = tree.nearest(&query);
        // k (20) exceeds the sample count, so every sample comes back
        assert_eq!(results.len(), 6);
        // neighbours must be sorted by ascending distance
        assert!(results.windows(2).all(|w| w[0].distance <= w[1].distance));
        // the nearest sample (4, 7) carries label 1
        assert!((results[0].distance - 2.0).abs() < 1e-5);
        assert_eq!(results[0].label, 1);
        println!("size {} predict {}\nnearest {results:?}", results.len(), tree.predict(&query));
    }

    #[test]
    fn test_knn() {
        let features = vec![
            vec![2.0, 3.0],
            vec![5.0, 4.0],
            vec![9.0, 6.0],
            vec![4.0, 7.0],
            vec![8.0, 1.0],
            vec![7.0, 2.0],
        ];
        let labels = vec![0, 0, 0, 1, 1, 1];
        let dataset = Dataset::new(features, labels, None);
        let knn = KNNModel::new(KNNAlg::KdTree, 1, None, dataset, None);
        let query = vec![7.0, 1.9];
        let results = knn.nearest(&query);
        // exactly k = 1 neighbour
        assert_eq!(results.len(), 1);
        // (7, 2) is by far the nearest sample; with k == 1 its label wins
        assert_eq!(knn.predict(&query), 1);
        println!("nearest {:?}\npredict = {}", results, knn.predict(&query));
    }
}
413}