// algorithms_edu/data_structures/balanced_tree/avl_tree.rs

//! This module contains an implementation of an AVL tree. An AVL tree is a special type of binary
//! search tree that rebalances itself after every insertion and removal so that operations stay
//! logarithmic.
//!
//! # Resources
//!
//! - [W. Fiset's video 1](https://www.youtube.com/watch?v=q4fnJZr8ztY)
//! - [W. Fiset's video 2](https://www.youtube.com/watch?v=1QSYxIKXXP4)
//! - [W. Fiset's video 3](https://www.youtube.com/watch?v=g4y2h70D6Nk)
//! - [W. Fiset's video 4](https://www.youtube.com/watch?v=tqFZzXkbbGY)
//! - [Wikipedia](https://www.wikiwand.com/en/AVL_tree)

use std::cmp::Ordering;
use std::fmt::Debug;
use std::mem;

/// A single node of the AVL tree. Each node owns its two subtrees.
#[derive(Debug, Clone, Eq, PartialEq)]
struct Node<T: Ord + Debug + PartialEq + Eq + Clone> {
    /// The element stored in this node.
    value: T,
    /// Number of edges on the longest path from this node down to a leaf.
    /// A node without children has height 0.
    height: i32,
    /// Height of the right subtree minus height of the left subtree, where an
    /// empty subtree counts as height -1. Negative means left-heavy.
    balance_factor: i8,
    /// Subtree holding the values smaller than `value`.
    left: Option<Box<Node<T>>>,
    /// Subtree holding the values greater than `value`.
    right: Option<Box<Node<T>>>,
}

impl<T: Ord + Debug + PartialEq + Eq + Clone> Node<T> {
    /// Creates a childless node holding `value` (height 0, balance factor 0).
    fn new(value: T) -> Self {
        Self {
            value,
            height: 0,
            balance_factor: 0,
            left: None,
            right: None,
        }
    }

    /// Recomputes this node's `height` and `balance_factor` from the heights
    /// cached in its children.
    ///
    /// Only the direct children are consulted, so they must already be up to
    /// date; callers therefore update bottom-up.
    fn update(&mut self) {
        // An empty subtree counts as height -1 so that a leaf ends up with height 0.
        let left_height = self.left.as_ref().map_or(-1, |child| child.height);
        let right_height = self.right.as_ref().map_or(-1, |child| child.height);

        self.height = std::cmp::max(left_height, right_height) + 1;
        // The difference is cast down to the `i8` field without a range check.
        self.balance_factor = (right_height - left_height) as i8;
    }
}

46#[derive(Default, Debug, Clone, Eq, PartialEq)]
47pub struct AvlTree<T: Ord + Debug + PartialEq + Eq + Clone> {
48    root: Option<Box<Node<T>>>,
49    len: usize,
50}
51
52impl<T: Ord + Debug + PartialEq + Eq + Clone> AvlTree<T> {
53    pub fn new() -> Self {
54        Self { root: None, len: 0 }
55    }
56    // the height of a rooted tree is the number of edges between the tree's
57    // root and its furthest leaf. This means that a tree containing a single
58    // node has a height of 0
59    pub fn height(&self) -> Option<i32> {
60        self.root.as_ref().map(|node| node.height)
61    }
62    pub fn len(&self) -> usize {
63        self.len
64    }
65    pub fn is_empty(&self) -> bool {
66        self.len() == 0
67    }
68    pub fn contains(&self, value: &T) -> bool {
69        fn _contains<T: Ord + Debug + Clone>(node: &Option<Box<Node<T>>>, value: &T) -> bool {
70            node.as_ref().map_or(false, |node| {
71                // compare the current value to the value of the node.
72                match value.cmp(&node.value) {
73                    // dig into the left subtree
74                    Ordering::Less => _contains(&node.left, value),
75                    // dig into the right subtree
76                    Ordering::Greater => _contains(&node.right, value),
77                    Ordering::Equal => true,
78                }
79            })
80        }
81        _contains(&self.root, value)
82    }
83    /// If the value is not found in the AVL tree, insert it and return `true`.
84    /// Otherwise, do not insert and return `false`.
85    pub fn insert(&mut self, value: T) -> bool {
86        fn _insert<T: Ord + Debug + Clone>(node: &mut Option<Box<Node<T>>>, value: T) -> bool {
87            let success = match node.as_mut() {
88                None => {
89                    *node = Some(Box::new(Node::new(value)));
90                    return true;
91                }
92                Some(node) => {
93                    // compare the current value to the value of the node.
94                    match value.cmp(&node.value) {
95                        // insert into the left subtree
96                        Ordering::Less => _insert(&mut node.left, value),
97                        // insert into the right subtree
98                        Ordering::Greater => _insert(&mut node.right, value),
99                        Ordering::Equal => false,
100                    }
101                }
102            };
103            let node = node.as_mut().unwrap();
104            node.update();
105            AvlTree::balance(node);
106
107            success
108        }
109        let success = _insert(&mut self.root, value);
110        if success {
111            self.len += 1;
112        }
113        success
114    }
115
116    /// re-balance a node if its balance factor is +2 or -2
117    fn balance(node: &mut Box<Node<T>>) {
118        // left heavy
119        match node.balance_factor {
120            -2 => {
121                // left-left case
122                if node.left.as_ref().unwrap().balance_factor < 0 {
123                    Self::rotate_right(node);
124                } else {
125                    // left-right case
126                    Self::rotate_left(&mut node.left.as_mut().unwrap());
127                    Self::rotate_right(node);
128                }
129            }
130            2 => {
131                // right-right case
132                if node.right.as_ref().unwrap().balance_factor > 0 {
133                    Self::rotate_left(node);
134                } else {
135                    // right-left case
136                    Self::rotate_right(&mut node.right.as_mut().unwrap());
137                    Self::rotate_left(node);
138                }
139            }
140            _ => {}
141        }
142    }
143
144    fn rotate_left(node: &mut Box<Node<T>>) {
145        let right_left = node.right.as_mut().unwrap().left.take();
146        let new_parent = mem::replace(&mut node.right, right_left).unwrap();
147        let new_left_child = mem::replace(node, new_parent);
148        node.left = Some(new_left_child);
149        node.left.as_mut().unwrap().update();
150        node.update();
151    }
152
153    fn rotate_right(node: &mut Box<Node<T>>) {
154        let left_right = node.left.as_mut().unwrap().right.take();
155        let new_parent = mem::replace(&mut node.left, left_right).unwrap();
156        let new_right_child = mem::replace(node, new_parent);
157        node.right = Some(new_right_child);
158        node.right.as_mut().unwrap().update();
159        node.update();
160    }
161
162    // pub fn remove(&mut self, elem: &T) {
163    //     fn _remove<T: Ord + Debug + Clone>(
164    //         node: Option<Box<Node<T>>>,
165    //         elem: &T,
166    //     ) -> Option<Box<Node<T>>> {
167    //         match node {
168    //             None => None,
169    //             Some(mut node) => {
170    //                 // compare the current value to the value of the node.
171    //                 match elem.cmp(&node.value) {
172    //                     // Dig into left subtree, the value we're looking
173    //                     // for is smaller than the current value.
174    //                     Ordering::Less => node.left = _remove(node.left, elem),
175    //                     // Dig into right subtree, the value we're looking
176    //                     // for is greater than the current value.
177    //                     Ordering::Greater => node.right = _remove(node.right, elem),
178    //                     Ordering::Equal => {
179    //                         // This is the case with only a right subtree or no subtree at all.
180    //                         // In this situation just swap the node we wish to remove
181    //                         // with its right child.
182    //                         if node.left.is_none() {
183    //                             return node.right;
184    //                         }
185    //                         // This is the case with only a left subtree or
186    //                         // no subtree at all. In this situation just
187    //                         // swap the node we wish to remove with its left child.
188    //                         else if node.right.is_none() {
189    //                             return node.left;
190    //                         }
191    //                         // When removing a node from a binary tree with two links the
192    //                         // successor of the node being removed can either be the largest
193    //                         // value in the left subtree or the smallest value in the right
194    //                         // subtree. As a heuristic, I will remove from the subtree with
195    //                         // the greatest hieght in hopes that this may help with balancing.
196    //                         else {
197    //                             let left = node.left.as_ref().unwrap();
198    //                             let right = node.right.as_ref().unwrap();
199
200    //                             // Choose to remove from left subtree
201    //                             if left.height >= right.height {
202    //                                 // Swap the value of the successor into the node.
203    //                                 let successor_value = AvlTree::find_max(&left).clone();
204    //                                 node.value = successor_value.clone();
205
206    //                                 // Find the largest node in the left subtree.
207    //                                 node.left = _remove(node.left, &successor_value);
208    //                             } else {
209    //                                 // Swap the value of the successor into the node.
210    //                                 let successor_value = AvlTree::find_min(&right).clone();
211    //                                 node.value = successor_value.clone();
212
213    //                                 // Go into the right subtree and remove the leftmost node we
214    //                                 // found and swapped data with. This prevents us from having
215    //                                 // two nodes in our tree with the same value.
216    //                                 node.right = _remove(node.right, &successor_value);
217    //                             }
218    //                         }
219    //                     }
220    //                 }
221    //                 node.update();
222    //                 AvlTree::balance(&mut node);
223    //                 Some(node)
224    //             }
225    //         }
226    //     }
227    //     let root = mem::replace(&mut self.root, None);
228    //     self.root = _remove(root, elem);
229    // }
230
231    // fn find_min(mut node: &Node<T>) -> &T {
232    //     while let Some(next_node) = node.left.as_ref() {
233    //         node = &next_node;
234    //     }
235    //     &node.value
236    // }
237    // fn find_max(mut node: &Node<T>) -> &T {
238    //     while let Some(next_node) = node.right.as_ref() {
239    //         node = &next_node;
240    //     }
241    //     &node.value
242    // }
243    pub fn remove(&mut self, elem: &T) -> bool {
244        fn _remove<T: Ord + Debug + Clone>(
245            _node: &mut Option<Box<Node<T>>>,
246            elem: &T,
247            success: &mut bool,
248        ) {
249            match _node {
250                None => {}
251                Some(node) => {
252                    match elem.cmp(&node.value) {
253                        Ordering::Less => {
254                            _remove(&mut node.left, elem, success);
255                        }
256                        Ordering::Greater => {
257                            _remove(&mut node.right, elem, success);
258                        }
259                        Ordering::Equal => {
260                            *success = true;
261                            // if the target is found, replace this node with a successor
262                            *_node = match (node.left.take(), node.right.take()) {
263                                (None, None) => None,
264                                (None, Some(right)) => Some(right),
265                                (Some(left), None) => Some(left),
266                                (Some(left), Some(right)) => {
267                                    if left.height >= right.height {
268                                        let mut x = AvlTree::remove_max(left);
269                                        x.right = Some(right);
270                                        Some(x)
271                                    } else {
272                                        let mut x = AvlTree::remove_min(right);
273                                        x.left = Some(left);
274                                        Some(x)
275                                    }
276                                }
277                            };
278                        }
279                    }
280                    let mut node = _node.as_mut().unwrap();
281                    node.update();
282                    AvlTree::balance(&mut node);
283                }
284            }
285        }
286        let mut success = false;
287        _remove(&mut self.root, elem, &mut success);
288        if success {
289            self.len -= 1;
290        }
291        success
292    }
293
294    fn remove_min(mut node: Box<Node<T>>) -> Box<Node<T>> {
295        fn _remove_min<T: Ord + Debug + PartialEq + Eq + Clone>(
296            node: &mut Node<T>,
297        ) -> Option<Box<Node<T>>> {
298            if let Some(next_node) = node.left.as_mut() {
299                let res = _remove_min(next_node);
300                if res.is_none() {
301                    node.left.take()
302                } else {
303                    res
304                }
305            } else {
306                None
307            }
308        }
309        _remove_min(&mut node).unwrap_or(node)
310    }
311    fn remove_max(mut node: Box<Node<T>>) -> Box<Node<T>> {
312        fn _remove_max<T: Ord + Debug + PartialEq + Eq + Clone>(
313            node: &mut Node<T>,
314        ) -> Option<Box<Node<T>>> {
315            if let Some(next_node) = node.right.as_mut() {
316                let res = _remove_max(next_node);
317                if res.is_none() {
318                    node.right.take()
319                } else {
320                    res
321                }
322            } else {
323                None
324            }
325        }
326        _remove_max(&mut node).unwrap_or(node)
327    }
328
329    pub fn iter(&self) -> AvlIter<T> {
330        if let Some(trav) = self.root.as_ref() {
331            AvlIter {
332                stack: Some(vec![trav]),
333                trav: Some(trav),
334            }
335        } else {
336            AvlIter {
337                stack: None,
338                trav: None,
339            }
340        }
341    }
342}
343
344// TODO: better ergonomics?
345pub struct AvlIter<'a, T: 'a + Ord + Debug + PartialEq + Eq + Clone> {
346    stack: Option<Vec<&'a Node<T>>>,
347    trav: Option<&'a Node<T>>,
348}
349
350impl<'a, T: 'a + Ord + Debug + PartialEq + Eq + Clone> Iterator for AvlIter<'a, T> {
351    type Item = &'a T;
352    fn next(&mut self) -> Option<Self::Item> {
353        if let (Some(stack), Some(trav)) = (self.stack.as_mut(), self.trav.as_mut()) {
354            while let Some(left) = trav.left.as_ref() {
355                stack.push(left);
356                *trav = left;
357            }
358
359            stack.pop().map(|curr| {
360                if let Some(right) = curr.right.as_ref() {
361                    stack.push(right);
362                    *trav = right;
363                }
364                &curr.value
365            })
366        } else {
367            None
368        }
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the five-element tree that every test starts from.
    fn sample_tree() -> AvlTree<i32> {
        //     5
        //   2   10
        //      7  15
        let mut avl = AvlTree::new();
        assert!(avl.is_empty());
        avl.insert(2);
        avl.insert(5);
        avl.insert(7);
        avl.insert(10);
        avl.insert(15);
        assert_eq!(avl.len(), 5);
        avl
    }

    #[test]
    fn test_avl() {
        let mut avl = sample_tree();
        assert_eq!(avl.height().unwrap(), 2);
        assert!(avl.contains(&2));
        assert!(avl.contains(&5));
        assert!(avl.contains(&7));
        assert!(avl.contains(&10));
        assert!(avl.contains(&15));
        //     5
        //   2   10
        //      7  15
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 5);
        let n2 = root.left.as_ref().unwrap();
        let n10 = root.right.as_ref().unwrap();
        assert_eq!(n2.value, 2);
        assert_eq!(n10.value, 10);
        assert_eq!(n10.left.as_ref().unwrap().value, 7);
        assert_eq!(n10.right.as_ref().unwrap().value, 15);
        AvlTree::rotate_left(avl.root.as_mut().unwrap());
        //     10
        //   5    15
        // 2   7
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 10);
        let n5 = root.left.as_ref().unwrap();
        let n15 = root.right.as_ref().unwrap();
        assert_eq!(n5.value, 5);
        assert_eq!(n15.value, 15);
        assert_eq!(n5.left.as_ref().unwrap().value, 2);
        assert_eq!(n5.right.as_ref().unwrap().value, 7);
        //     10
        //   2    15
        //     7
        assert!(avl.remove(&5));
        assert_eq!(avl.len(), 4);
        assert!(!avl.contains(&5));
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 10);
        let n2 = root.left.as_ref().unwrap();
        let n15 = root.right.as_ref().unwrap();
        assert_eq!(n2.value, 2);
        assert_eq!(n15.value, 15);
        assert!(n2.left.as_ref().is_none());
        assert_eq!(n2.right.as_ref().unwrap().value, 7);

        assert!(avl.insert(5));
        assert_eq!(avl.len(), 5);
        //     10
        //   5    15
        // 2   7
        AvlTree::rotate_right(avl.root.as_mut().unwrap());
        //     5
        //   2   10
        //      7  15
        assert_eq!(avl, sample_tree());

        // will not insert an element that's already in the tree
        assert!(!avl.insert(5));
        // will not remove an element that's not in the tree
        assert!(!avl.remove(&100));
        assert_eq!(avl.len(), 5);
    }

    #[test]
    fn test_avl_iter() {
        //     5
        //   2   10
        //      7  15
        let v = sample_tree().iter().cloned().collect::<Vec<_>>();
        assert_eq!(v, vec![2, 5, 7, 10, 15]);

        // an empty tree yields nothing
        assert!(AvlTree::<i32>::new().iter().next().is_none());
    }
}