kdtree_rust/
lib.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::Debug;
4use std::marker::PhantomData;
5use std::ops::{Add, Deref, Mul, Neg, Sub};
6use std::rc::Rc;
7
/// Two-state color tag stored on every `KDNode`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Color {
    /// The red state.
    Red,
    /// The black state.
    Black,
}
/// Side of a node that an insertion descended into.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LR {
    /// The left child.
    L,
    /// The right child.
    R,
}
/// Rebalancing request passed back up the tree while an insertion unwinds.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Balance {
    /// Reported when a leaf is created at the last dimension under a red parent color.
    Pre,
    /// Follow-up state produced after a `Pre` request has been handled.
    Fix,
    /// Nothing to do.
    None,
}
24impl Neg for Color {
25    type Output = Color;
26    fn neg(self) -> Self::Output {
27        match self {
28            Color::Red => Color::Black,
29            Color::Black => Color::Red
30        }
31    }
32}
/// Squaring for any copyable value that can be multiplied by itself.
pub trait Square: Mul + Sized + Clone + Copy {
    /// Returns `self * self`.
    fn square(self) -> Self::Output {
        Mul::mul(self, self)
    }
}

/// Every type meeting the supertrait bounds gets `square` for free.
impl<T: Mul + Sized + Clone + Copy> Square for T {}
39pub trait EuclideanDistance<T> {
40    type Output;
41    fn euclidean_distance(self, rhs:T) -> Self::Output;
42}
43impl<'a,const K:usize,P> EuclideanDistance<&'a [P; K]> for &'a [P; K]
44    where P: PartialOrd + Mul<P, Output = P> + Add<P, Output = P> + Sub<P, Output = P> +
45             Clone + Copy + Default + Distance<P, Output = P> + Square + Sized {
46    type Output = P;
47
48    fn euclidean_distance(self, rhs: &'a [P; K]) -> P {
49        self.iter().zip(rhs.iter()).fold(P::default(),|acc,(p1,p2)| {
50            acc + p1.distance(p2).square()
51        })
52    }
53}
/// Absolute gap between two scalar values.
pub trait Distance<T> {
    /// Type of the computed gap.
    type Output;

    /// Returns the gap between `self` and `rhs`.
    fn distance(&self, rhs: &T) -> Self::Output;
}

/// Blanket implementation for ordered, subtractable, copyable values.
///
/// # Panics
///
/// Panics when the two values cannot be ordered (`partial_cmp` yields `None`).
impl<P> Distance<P> for P
where
    P: PartialOrd
        + Mul<P, Output = P>
        + Add<P, Output = P>
        + Sub<P, Output = P>
        + Clone
        + Copy
        + Default,
{
    type Output = P;

    fn distance(&self, rhs: &P) -> P {
        // Always subtract the smaller operand from the larger one so the
        // subtraction never goes below zero (matters for unsigned types).
        let (low, high) = if self.partial_cmp(rhs).unwrap().is_le() {
            (*self, *rhs)
        } else {
            (*rhs, *self)
        };

        high - low
    }
}
72#[derive(Debug)]
73pub struct KDNode<'a,const K:usize,P,T>
74    where P: Debug + PartialOrd + Mul<Output = P> + Add + Sub +
75             Clone + Copy + Default + Distance<P, Output = P>  + Square + Sized + 'a,
76             &'a [P; K]: EuclideanDistance<&'a [P; K], Output = P> + 'a {
77    positions:Rc<[P; K]>,
78    value:Rc<RefCell<T>>,
79    color: Rc<RefCell<Color>>,
80    left:Option<Box<KDNode<'a, K,P,T>>>,
81    right:Option<Box<KDNode<'a, K,P,T>>>,
82    demension:usize,
83    l:PhantomData<&'a ()>
84}
85impl<'a,const K:usize,P,T> KDNode<'a, K,P,T>
86    where P: Debug + PartialOrd + Mul<Output = P> + Add + Sub +
87             Clone + Copy + Default + Distance<P, Output = P> + 'a,
88             &'a [P; K]: EuclideanDistance<&'a [P; K], Output = P> + 'a {
89    pub fn new(positions:Rc<[P; K]>, value:Rc<RefCell<T>>, demension:usize) -> KDNode<'a, K,P,T> {
90        KDNode {
91            positions: positions,
92            value: value,
93            color: Rc::new(RefCell::new(Color::Red)),
94            left: None,
95            right: None,
96            demension: demension,
97            l:PhantomData::<&'a ()>
98        }
99    }
100
101    fn with_color(positions:Rc<[P; K]>, value:Rc<RefCell<T>>, color:Rc<RefCell<Color>>, demension:usize) -> KDNode<'a, K,P,T> {
102        KDNode {
103            positions: positions,
104            value: value,
105            color: color,
106            left: None,
107            right: None,
108            demension: demension,
109            l:PhantomData::<&'a ()>
110        }
111    }
112
113    pub fn right_rotate(t: KDNode<'a, K,P,T>) -> KDNode<'a, K,P,T> {
114        match t.left {
115            Some(left) => {
116                KDNode {
117                    positions: left.positions,
118                    value: left.value,
119                    color: left.color,
120                    left: left.left,
121                    right: Some(Box::new(KDNode {
122                        positions: t.positions,
123                        value: t.value,
124                        color: t.color,
125                        left: left.right,
126                        right: t.right,
127                        demension: t.demension,
128                        l:PhantomData::<&'a ()>
129                    },)),
130                    demension: left.demension,
131                    l:PhantomData::<&'a ()>
132                }
133            },
134            None => t
135        }
136    }
137
138    pub fn left_rotate(t: KDNode<'a, K,P,T>) -> KDNode<'a, K,P,T> {
139        match t.right {
140            Some(right) => {
141                KDNode {
142                    positions: right.positions,
143                    value: right.value,
144                    color: right.color,
145                    right: right.right,
146                    left: Some(Box::new(KDNode {
147                        positions: t.positions,
148                        value: t.value,
149                        color: t.color,
150                        right: right.left,
151                        left: t.left,
152                        demension:  t.demension,
153                        l:PhantomData::<&'a ()>
154                    })),
155                    demension: right.demension,
156                    l:PhantomData::<&'a ()>
157                }
158            },
159            None => t
160        }
161    }
162
163    #[allow(dead_code)]
164    fn left_and_right_rotate(mut t: KDNode<'a, K,P,T>) -> KDNode<'a, K,P,T> {
165        match t.left.take() {
166            None => {
167                t
168            },
169            Some(left) => {
170                t.left = Some(Box::new(Self::left_rotate(*left)));
171                Self::right_rotate(t)
172            }
173        }
174    }
175
176    #[allow(dead_code)]
177    fn right_and_left_rotate(mut t: KDNode<'a, K,P,T>) -> KDNode<'a, K,P,T> {
178        match t.right.take() {
179            None => {
180                t
181            },
182            Some(right) => {
183                t.right = Some(Box::new(Self::right_rotate(*right)));
184                Self::left_rotate(t)
185            }
186        }
187    }
188
189    fn nearest(t: Option<&'a Box<KDNode<'a, K,P,T>>>,
190               positions:&'a [P; K],
191               distance:P,
192               nearest_positions:&'a [P; K],
193               current_value:&Rc<RefCell<T>>,
194               demension:usize) -> Option<(P, &'a [P; K], Rc<RefCell<T>>)> {
195        t.and_then(|t| {
196            if positions[demension].partial_cmp(&t.positions[demension]).unwrap().is_lt() {
197                if let Some(c) = t.left.as_ref() {
198                    if let Some((distance,current_positions,current_value)) = Self::nearest(
199                        Some(&c), positions, distance, nearest_positions, current_value, (demension + 1) % K) {
200                        let (distance,current_positions,current_value) = Self::nearest_center(Some(t),
201                                                                                                    positions,
202                                                                                                    distance,
203                                                                                                    current_positions,
204                                                                                            &current_value,
205                                                                                                     demension).unwrap();
206
207                        if let Some(c) = t.right.as_ref() {
208                            if distance.partial_cmp(&positions[demension].distance(&c.positions[demension]).square()).unwrap().is_lt() {
209                                Some((distance,current_positions,current_value))
210                            } else {
211                                Self::nearest(Some(&c),positions,distance,current_positions,&current_value,(demension + 1) % K)
212                            }
213                        } else {
214                            Some((distance,current_positions,current_value))
215                        }
216                    } else {
217                        unreachable!()
218                    }
219                } else {
220                    let (distance,current_positions,current_value) = Self::nearest_center(Some(t),
221                                                                              positions,
222                                                                              distance,
223                                                                              nearest_positions,
224                                                                              &current_value,
225                                                                              demension).unwrap();
226
227                    if let Some(c) = t.right.as_ref() {
228                        if distance.partial_cmp(&positions[demension].distance(&c.positions[demension]).square()).unwrap().is_lt() {
229                            Some((distance,current_positions,current_value))
230                        } else {
231                            Self::nearest(Some(&c),positions,distance,current_positions,&current_value,(demension + 1) % K)
232                        }
233                    } else {
234                        Some((distance,current_positions,current_value))
235                    }
236                }
237            } else {
238                if let Some(c) = t.right.as_ref() {
239                    if let Some((distance,current_positions,current_value)) = Self::nearest(
240                        Some(&c),positions,distance,nearest_positions,current_value,(demension + 1) % K) {
241                        let (distance,current_positions,current_value) = Self::nearest_center(Some(t),
242                                                                                                                        positions,
243                                                                                                                distance,
244                                                                                                            current_positions,
245                                                                                                                        &current_value,
246                                                                                                                    demension).unwrap();
247
248                        if let Some(c) = t.left.as_ref() {
249                            if distance.partial_cmp(&positions[demension].distance(&c.positions[demension]).square()).unwrap().is_lt() {
250                                Some((distance,current_positions,current_value))
251                            } else {
252                                Self::nearest(Some(&c),positions,distance,current_positions,&current_value,(demension + 1) % K)
253                            }
254                        } else {
255                            Some((distance,current_positions,current_value))
256                        }
257                    } else {
258                        unreachable!()
259                    }
260                } else {
261                    let (distance,current_positions,current_value) = Self::nearest_center(Some(t),
262                                                                              positions,
263                                                                              distance,
264                                                                              nearest_positions,
265                                                                              &current_value,
266                                                                              demension).unwrap();
267
268
269                    if let Some(c) = t.left.as_ref() {
270                        if distance.partial_cmp(&positions[demension].distance(&c.positions[demension]).square()).unwrap().is_lt() {
271                            Some((distance,current_positions,current_value))
272                        } else {
273                            Self::nearest(Some(&c),positions,distance,current_positions,&current_value,(demension + 1) % K)
274                        }
275                    } else {
276                        Some((distance,current_positions,current_value))
277                    }
278                }
279            }
280        })
281    }
282
283    fn nearest_center(t: Option<&'a Box<KDNode<'a, K,P,T>>>,
284               positions:&'a [P; K],
285               mut distance:P,
286               nearest_positions:&'a [P; K],
287               current_value:&Rc<RefCell<T>>,
288               _:usize) -> Option<(P, &'a [P; K], Rc<RefCell<T>>)> {
289        Some(t.and_then(|t| {
290            let d = positions.euclidean_distance(&t.positions);
291
292            let mut current_value = Rc::clone(&current_value);
293
294            let current_positions = if d.partial_cmp(&distance).unwrap().is_le() {
295                distance = d;
296                current_value = Rc::clone(&t.value);
297                t.positions.borrow()
298            } else {
299                nearest_positions
300            };
301
302            Some((distance,current_positions, current_value))
303        }).expect("current node is none."))
304    }
305
306    fn insert(t: Option<Box<KDNode<'a, K,P,T>>>,
307              positions:&Rc<[P; K]>,
308              color:&Rc<RefCell<Color>>,
309              parent_color:Option<Color>,
310              lr:Option<LR>,
311              value:Rc<RefCell<T>>,
312              demension:usize) -> (KDNode<'a, K,P,T>, Balance) {
313        match t {
314            None if demension == K - 1 => {
315                let b = if parent_color.map(|c| c == Color::Red).unwrap_or(false) {
316                    Balance::Pre
317                } else {
318                    Balance::None
319                };
320
321                let color = if let Some(_) = parent_color {
322                    Rc::new(RefCell::new(Color::Black))
323                } else {
324                    Rc::clone(&color)
325                };
326
327                (KDNode::with_color(Rc::clone(positions), Rc::clone(&value),color,demension),b)
328            },
329            None if demension == 0 => {
330                let t = KDNode {
331                    positions: Rc::clone(positions),
332                    value: Rc::clone(&value),
333                    color: Rc::clone(color),
334                    left: None,
335                    right: None,
336                    demension: demension,
337                    l:PhantomData::<&'a ()>
338                };
339
340                (t,Balance::None)
341            },
342            None => {
343                let t = KDNode {
344                    positions: Rc::clone(positions),
345                    value: Rc::clone(&value),
346                    color: Rc::clone(&color),
347                    left: None,
348                    right: None,
349                    demension: demension,
350                    l:PhantomData::<&'a ()>
351                };
352
353                (t,Balance::None)
354            },
355            Some(mut t) if demension == K - 1 => {
356                let parent_color = Some(color.deref().borrow().clone());
357
358                if positions[demension].partial_cmp(&t.positions[demension]).unwrap().is_lt() {
359                    let (n,b) = Self::insert(t.left,
360                                             positions,
361                                             &Rc::clone(&t.color),
362                                             parent_color,
363                                             Some(LR::L),
364                                             value, (demension+1) % K);
365
366                    t.left = Some(Box::new(n));
367
368                    (*t,b)
369                } else {
370                    let (n,b) = Self::insert(t.right,
371                                             positions,
372                                             &Rc::clone(&t.color),
373                                             parent_color,
374                                             Some(LR::R),
375                                             value, (demension+1) % K);
376
377                    t.right = Some(Box::new(n));
378
379                    (*t,b)
380                }
381            },
382            Some(mut t) if demension == 0 => {
383                if positions[demension].partial_cmp(&t.positions[demension]).unwrap().is_lt() {
384                    let (n,b) = Self::insert(t.left,
385                                             positions,
386                                             color,
387                                             parent_color,
388                                             lr,
389                                             value, (demension+1) % K);
390
391                    t.left = Some(Box::new(n));
392
393                    Self::balance(*t, demension, b, lr,Some(LR::L))
394                } else {
395                    let (n,b) = Self::insert(t.right,
396                                             positions,
397                                             color,
398                                             parent_color,
399                                             lr,
400                                             value, (demension+1) % K);
401
402                    t.right = Some(Box::new(n));
403
404                    Self::balance(*t, demension, b, lr,Some(LR::R))
405                }
406            },
407            Some(mut t) => {
408                if positions[demension].partial_cmp(&t.positions[demension]).unwrap().is_lt() {
409                    let (n,b) = Self::insert(t.left,
410                                             positions,
411                                             color,
412                                             parent_color,
413                                             lr,
414                                             value, (demension+1) % K);
415
416                    t.left = Some(Box::new(n));
417
418                    (*t,b)
419                } else {
420                    let (n,b) = Self::insert(t.right,
421                                             positions,
422                                             color,
423                                             parent_color,
424                                             lr,
425                                             value, (demension+1) % K);
426
427                    t.right = Some(Box::new(n));
428
429                    (*t,b)
430                }
431            }
432        }
433    }
434
435    fn balance(mut t: KDNode<'a, K,P,T>, demension:usize, balance:Balance, parent_lr:Option<LR>, lr:Option<LR>) -> (KDNode<'a, K,P,T>, Balance) {
436        if demension > 0 {
437            (t,balance)
438        } else {
439            match balance {
440                Balance::None => (t, balance),
441                Balance::Pre => {
442                    let lr = lr.unwrap();
443                    let parent_lr = parent_lr.unwrap();
444
445                    for _ in 0..K {
446                        t = if parent_lr != lr && lr == LR::L {
447                            Self::right_rotate(t)
448                        } else if parent_lr != lr && lr == LR::R {
449                            Self::left_rotate(t)
450                        } else {
451                            t
452                        };
453                    }
454
455                    (t, Balance::Fix)
456                },
457                Balance::Fix => {
458                    let lr = lr.unwrap();
459
460                    for _ in 0..K {
461                        t = match lr {
462                            LR::L => Self::right_rotate(t),
463                            LR::R => Self::left_rotate(t)
464                        };
465                    }
466
467                    match lr {
468                        LR::L => {
469                            if let Some(c) = t.left.as_ref() {
470                                *c.color.borrow_mut() = Color::Black;
471                            }
472                        },
473                        LR::R => {
474                            if let Some(c) = t.right.as_ref() {
475                                *c.color.borrow_mut() = Color::Black;
476                            }
477                        }
478                    }
479                    (t, Balance::None)
480                }
481            }
482        }
483    }
484}
485#[derive(Debug)]
486pub struct KDTree<'a,const K:usize,P,T>
487    where P: Debug + PartialOrd + Mul<Output = P> + Add + Sub + Clone + Copy + Default + Distance<P, Output = P> + Square + Sized + 'a,
488          &'a [P; K]: EuclideanDistance<&'a [P; K], Output = P> + 'a {
489    root: Option<Box<KDNode<'a, K,P,T>>>,
490    l:PhantomData<&'a ()>
491}
492impl<'a,const K:usize,P,T> KDTree<'a, K,P,T>
493    where P: Debug + PartialOrd + Mul<Output = P> + Add + Sub + Clone + Copy + Default + Distance<P, Output = P> + 'a,
494          &'a [P; K]: EuclideanDistance<&'a [P; K], Output = P> + 'a {
495    pub fn new() -> KDTree<'a, K,P,T> {
496        KDTree {
497            root: None,
498            l:PhantomData::<&'a ()>
499        }
500    }
501
502    pub fn nearest(&'a self, positions:&'a [P; K]) -> Option<(&'a [P; K], Rc<RefCell<T>>)> {
503        self.root.as_ref().and_then(|root| {
504            let distance = positions.euclidean_distance(&root.positions);
505
506            KDNode::nearest(Some(root),positions,distance,&root.positions,&root.value,0).map(|(_,p,v)| {
507                (p,v)
508            })
509        })
510    }
511
512    pub fn nearest_position(&'a self, positions:&'a [P; K]) -> Option<&'a [P; K]> {
513        self.root.as_ref().and_then(|root| {
514            let distance = positions.euclidean_distance(&root.positions);
515
516            KDNode::nearest(Some(root),positions,distance,&root.positions,&root.value,0).map(|(_,p,_)| {
517                p
518            })
519        })
520    }
521
522    pub fn insert(&mut self, positions:[P; K], value:T) {
523        let (n,_) = KDNode::insert(self.root.take(),
524                                   &Rc::new(positions),
525                                   &Rc::new(RefCell::new(Color::Black)),
526                                   None,
527                                   None,
528                                   Rc::new(RefCell::new(value)),
529                                   0);
530        self.root = Some(Box::new(n));
531        self.root.as_ref().map(|root| {
532            *root.color.borrow_mut() = Color::Black;
533        });
534   }
535}
536
537