algorithms_edu/data_structures/balanced_tree/avl_tree.rs
//! This module contains an implementation of an AVL tree: a self-balancing binary search
//! tree that rebalances itself after every insertion and removal so that lookups,
//! insertions, and removals all take O(log n) time.
3//!
4//! # Resources
5//!
6//! - [W. Fiset's video 1](https://www.youtube.com/watch?v=q4fnJZr8ztY)
7//! - [W. Fiset's video 2](https://www.youtube.com/watch?v=1QSYxIKXXP4)
8//! - [W. Fiset's video 3](https://www.youtube.com/watch?v=g4y2h70D6Nk)
9//! - [W. Fiset's video 4](https://www.youtube.com/watch?v=tqFZzXkbbGY)
//! - [Wikipedia](https://en.wikipedia.org/wiki/AVL_tree)
11
12use std::cmp::Ordering;
13use std::fmt::Debug;
14use std::mem;
15
/// A single node of an AVL tree, owning its (optional) left and right subtrees.
#[derive(Debug, Clone, Eq, PartialEq)]
struct Node<T>
where
    T: Ord + Debug + PartialEq + Eq + Clone,
{
    /// The value stored in this node.
    value: T,
    /// Number of edges on the longest path from this node down to a leaf, so a
    /// childless node has height 0. Refreshed by `Node::update`.
    height: i32,
    /// Height of the right subtree minus height of the left subtree, where a missing
    /// subtree counts as height -1. Refreshed by `Node::update`.
    balance_factor: i8,
    /// Subtree holding values smaller than `value`.
    left: Option<Box<Node<T>>>,
    /// Subtree holding values greater than `value`.
    right: Option<Box<Node<T>>>,
}

impl<T> Node<T>
where
    T: Ord + Debug + PartialEq + Eq + Clone,
{
    /// Creates a childless node holding `value`, with height 0 and balance factor 0.
    fn new(value: T) -> Self {
        Node {
            left: None,
            right: None,
            height: 0,
            balance_factor: 0,
            value,
        }
    }

    /// Recomputes this node's height and balance factor from its direct children.
    ///
    /// Only the children's stored heights are read, so they must already be current.
    fn update(&mut self) {
        // An empty subtree sits one level below a leaf.
        let height_of = |child: &Option<Box<Node<T>>>| match child {
            Some(node) => node.height,
            None => -1,
        };
        let (left_height, right_height) = (height_of(&self.left), height_of(&self.right));
        self.height = 1 + left_height.max(right_height);
        self.balance_factor = (right_height - left_height) as i8;
    }
}
45
46#[derive(Default, Debug, Clone, Eq, PartialEq)]
47pub struct AvlTree<T: Ord + Debug + PartialEq + Eq + Clone> {
48 root: Option<Box<Node<T>>>,
49 len: usize,
50}
51
52impl<T: Ord + Debug + PartialEq + Eq + Clone> AvlTree<T> {
53 pub fn new() -> Self {
54 Self { root: None, len: 0 }
55 }
56 // the height of a rooted tree is the number of edges between the tree's
57 // root and its furthest leaf. This means that a tree containing a single
58 // node has a height of 0
59 pub fn height(&self) -> Option<i32> {
60 self.root.as_ref().map(|node| node.height)
61 }
62 pub fn len(&self) -> usize {
63 self.len
64 }
65 pub fn is_empty(&self) -> bool {
66 self.len() == 0
67 }
68 pub fn contains(&self, value: &T) -> bool {
69 fn _contains<T: Ord + Debug + Clone>(node: &Option<Box<Node<T>>>, value: &T) -> bool {
70 node.as_ref().map_or(false, |node| {
71 // compare the current value to the value of the node.
72 match value.cmp(&node.value) {
73 // dig into the left subtree
74 Ordering::Less => _contains(&node.left, value),
75 // dig into the right subtree
76 Ordering::Greater => _contains(&node.right, value),
77 Ordering::Equal => true,
78 }
79 })
80 }
81 _contains(&self.root, value)
82 }
83 /// If the value is not found in the AVL tree, insert it and return `true`.
84 /// Otherwise, do not insert and return `false`.
85 pub fn insert(&mut self, value: T) -> bool {
86 fn _insert<T: Ord + Debug + Clone>(node: &mut Option<Box<Node<T>>>, value: T) -> bool {
87 let success = match node.as_mut() {
88 None => {
89 *node = Some(Box::new(Node::new(value)));
90 return true;
91 }
92 Some(node) => {
93 // compare the current value to the value of the node.
94 match value.cmp(&node.value) {
95 // insert into the left subtree
96 Ordering::Less => _insert(&mut node.left, value),
97 // insert into the right subtree
98 Ordering::Greater => _insert(&mut node.right, value),
99 Ordering::Equal => false,
100 }
101 }
102 };
103 let node = node.as_mut().unwrap();
104 node.update();
105 AvlTree::balance(node);
106
107 success
108 }
109 let success = _insert(&mut self.root, value);
110 if success {
111 self.len += 1;
112 }
113 success
114 }
115
116 /// re-balance a node if its balance factor is +2 or -2
117 fn balance(node: &mut Box<Node<T>>) {
118 // left heavy
119 match node.balance_factor {
120 -2 => {
121 // left-left case
122 if node.left.as_ref().unwrap().balance_factor < 0 {
123 Self::rotate_right(node);
124 } else {
125 // left-right case
126 Self::rotate_left(&mut node.left.as_mut().unwrap());
127 Self::rotate_right(node);
128 }
129 }
130 2 => {
131 // right-right case
132 if node.right.as_ref().unwrap().balance_factor > 0 {
133 Self::rotate_left(node);
134 } else {
135 // right-left case
136 Self::rotate_right(&mut node.right.as_mut().unwrap());
137 Self::rotate_left(node);
138 }
139 }
140 _ => {}
141 }
142 }
143
144 fn rotate_left(node: &mut Box<Node<T>>) {
145 let right_left = node.right.as_mut().unwrap().left.take();
146 let new_parent = mem::replace(&mut node.right, right_left).unwrap();
147 let new_left_child = mem::replace(node, new_parent);
148 node.left = Some(new_left_child);
149 node.left.as_mut().unwrap().update();
150 node.update();
151 }
152
153 fn rotate_right(node: &mut Box<Node<T>>) {
154 let left_right = node.left.as_mut().unwrap().right.take();
155 let new_parent = mem::replace(&mut node.left, left_right).unwrap();
156 let new_right_child = mem::replace(node, new_parent);
157 node.right = Some(new_right_child);
158 node.right.as_mut().unwrap().update();
159 node.update();
160 }
161
162 // pub fn remove(&mut self, elem: &T) {
163 // fn _remove<T: Ord + Debug + Clone>(
164 // node: Option<Box<Node<T>>>,
165 // elem: &T,
166 // ) -> Option<Box<Node<T>>> {
167 // match node {
168 // None => None,
169 // Some(mut node) => {
170 // // compare the current value to the value of the node.
171 // match elem.cmp(&node.value) {
172 // // Dig into left subtree, the value we're looking
173 // // for is smaller than the current value.
174 // Ordering::Less => node.left = _remove(node.left, elem),
175 // // Dig into right subtree, the value we're looking
176 // // for is greater than the current value.
177 // Ordering::Greater => node.right = _remove(node.right, elem),
178 // Ordering::Equal => {
179 // // This is the case with only a right subtree or no subtree at all.
180 // // In this situation just swap the node we wish to remove
181 // // with its right child.
182 // if node.left.is_none() {
183 // return node.right;
184 // }
185 // // This is the case with only a left subtree or
186 // // no subtree at all. In this situation just
187 // // swap the node we wish to remove with its left child.
188 // else if node.right.is_none() {
189 // return node.left;
190 // }
191 // // When removing a node from a binary tree with two links the
192 // // successor of the node being removed can either be the largest
193 // // value in the left subtree or the smallest value in the right
194 // // subtree. As a heuristic, I will remove from the subtree with
195 // // the greatest hieght in hopes that this may help with balancing.
196 // else {
197 // let left = node.left.as_ref().unwrap();
198 // let right = node.right.as_ref().unwrap();
199
200 // // Choose to remove from left subtree
201 // if left.height >= right.height {
202 // // Swap the value of the successor into the node.
203 // let successor_value = AvlTree::find_max(&left).clone();
204 // node.value = successor_value.clone();
205
206 // // Find the largest node in the left subtree.
207 // node.left = _remove(node.left, &successor_value);
208 // } else {
209 // // Swap the value of the successor into the node.
210 // let successor_value = AvlTree::find_min(&right).clone();
211 // node.value = successor_value.clone();
212
213 // // Go into the right subtree and remove the leftmost node we
214 // // found and swapped data with. This prevents us from having
215 // // two nodes in our tree with the same value.
216 // node.right = _remove(node.right, &successor_value);
217 // }
218 // }
219 // }
220 // }
221 // node.update();
222 // AvlTree::balance(&mut node);
223 // Some(node)
224 // }
225 // }
226 // }
227 // let root = mem::replace(&mut self.root, None);
228 // self.root = _remove(root, elem);
229 // }
230
231 // fn find_min(mut node: &Node<T>) -> &T {
232 // while let Some(next_node) = node.left.as_ref() {
233 // node = &next_node;
234 // }
235 // &node.value
236 // }
237 // fn find_max(mut node: &Node<T>) -> &T {
238 // while let Some(next_node) = node.right.as_ref() {
239 // node = &next_node;
240 // }
241 // &node.value
242 // }
243 pub fn remove(&mut self, elem: &T) -> bool {
244 fn _remove<T: Ord + Debug + Clone>(
245 _node: &mut Option<Box<Node<T>>>,
246 elem: &T,
247 success: &mut bool,
248 ) {
249 match _node {
250 None => {}
251 Some(node) => {
252 match elem.cmp(&node.value) {
253 Ordering::Less => {
254 _remove(&mut node.left, elem, success);
255 }
256 Ordering::Greater => {
257 _remove(&mut node.right, elem, success);
258 }
259 Ordering::Equal => {
260 *success = true;
261 // if the target is found, replace this node with a successor
262 *_node = match (node.left.take(), node.right.take()) {
263 (None, None) => None,
264 (None, Some(right)) => Some(right),
265 (Some(left), None) => Some(left),
266 (Some(left), Some(right)) => {
267 if left.height >= right.height {
268 let mut x = AvlTree::remove_max(left);
269 x.right = Some(right);
270 Some(x)
271 } else {
272 let mut x = AvlTree::remove_min(right);
273 x.left = Some(left);
274 Some(x)
275 }
276 }
277 };
278 }
279 }
280 let mut node = _node.as_mut().unwrap();
281 node.update();
282 AvlTree::balance(&mut node);
283 }
284 }
285 }
286 let mut success = false;
287 _remove(&mut self.root, elem, &mut success);
288 if success {
289 self.len -= 1;
290 }
291 success
292 }
293
294 fn remove_min(mut node: Box<Node<T>>) -> Box<Node<T>> {
295 fn _remove_min<T: Ord + Debug + PartialEq + Eq + Clone>(
296 node: &mut Node<T>,
297 ) -> Option<Box<Node<T>>> {
298 if let Some(next_node) = node.left.as_mut() {
299 let res = _remove_min(next_node);
300 if res.is_none() {
301 node.left.take()
302 } else {
303 res
304 }
305 } else {
306 None
307 }
308 }
309 _remove_min(&mut node).unwrap_or(node)
310 }
311 fn remove_max(mut node: Box<Node<T>>) -> Box<Node<T>> {
312 fn _remove_max<T: Ord + Debug + PartialEq + Eq + Clone>(
313 node: &mut Node<T>,
314 ) -> Option<Box<Node<T>>> {
315 if let Some(next_node) = node.right.as_mut() {
316 let res = _remove_max(next_node);
317 if res.is_none() {
318 node.right.take()
319 } else {
320 res
321 }
322 } else {
323 None
324 }
325 }
326 _remove_max(&mut node).unwrap_or(node)
327 }
328
329 pub fn iter(&self) -> AvlIter<T> {
330 if let Some(trav) = self.root.as_ref() {
331 AvlIter {
332 stack: Some(vec![trav]),
333 trav: Some(trav),
334 }
335 } else {
336 AvlIter {
337 stack: None,
338 trav: None,
339 }
340 }
341 }
342}
343
344// TODO: better ergonomics?
345pub struct AvlIter<'a, T: 'a + Ord + Debug + PartialEq + Eq + Clone> {
346 stack: Option<Vec<&'a Node<T>>>,
347 trav: Option<&'a Node<T>>,
348}
349
350impl<'a, T: 'a + Ord + Debug + PartialEq + Eq + Clone> Iterator for AvlIter<'a, T> {
351 type Item = &'a T;
352 fn next(&mut self) -> Option<Self::Item> {
353 if let (Some(stack), Some(trav)) = (self.stack.as_mut(), self.trav.as_mut()) {
354 while let Some(left) = trav.left.as_ref() {
355 stack.push(left);
356 *trav = left;
357 }
358
359 stack.pop().map(|curr| {
360 if let Some(right) = curr.right.as_ref() {
361 stack.push(right);
362 *trav = right;
363 }
364 &curr.value
365 })
366 } else {
367 None
368 }
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use lazy_static::lazy_static;

    lazy_static! {
        // Shared fixture. Inserting 2, 5, 7, 10, 15 in ascending order triggers two
        // left rotations and leaves the tree shaped like this:
        //
        //       5
        //      / \
        //     2   10
        //        /  \
        //       7    15
        static ref AVL: AvlTree<i32> = {
            let mut tree = AvlTree::new();
            assert!(tree.is_empty());
            for &value in &[2, 5, 7, 10, 15] {
                tree.insert(value);
            }
            assert_eq!(tree.len(), 5);
            tree
        };
    }

    /// Value held by a child link; panics if the link is empty.
    fn value_at(link: &Option<Box<Node<i32>>>) -> i32 {
        link.as_ref().unwrap().value
    }

    #[test]
    fn test_avl() {
        let mut avl = AVL.clone();
        assert_eq!(avl.height().unwrap(), 2);
        for value in &[2, 5, 7, 10, 15] {
            assert!(avl.contains(value));
        }

        //       5
        //      / \
        //     2   10
        //        /  \
        //       7    15
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 5);
        assert_eq!(value_at(&root.left), 2);
        let ten = root.right.as_ref().unwrap();
        assert_eq!(ten.value, 10);
        assert_eq!(value_at(&ten.left), 7);
        assert_eq!(value_at(&ten.right), 15);

        AvlTree::rotate_left(avl.root.as_mut().unwrap());
        //         10
        //        /  \
        //       5    15
        //      / \
        //     2   7
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 10);
        let five = root.left.as_ref().unwrap();
        assert_eq!(five.value, 5);
        assert_eq!(value_at(&root.right), 15);
        assert_eq!(value_at(&five.left), 2);
        assert_eq!(value_at(&five.right), 7);

        avl.remove(&5);
        //         10
        //        /  \
        //       2    15
        //        \
        //         7
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 10);
        let two = root.left.as_ref().unwrap();
        assert_eq!(two.value, 2);
        assert_eq!(value_at(&root.right), 15);
        assert!(two.left.is_none());
        assert_eq!(value_at(&two.right), 7);

        avl.insert(5);
        //         10
        //        /  \
        //       5    15
        //      / \
        //     2   7
        AvlTree::rotate_right(avl.root.as_mut().unwrap());
        // Rotating back restores the fixture exactly, heights and balance factors included.
        assert_eq!(&avl, &*AVL);

        // Inserting a value that is already present is rejected...
        assert!(!avl.insert(5));
        // ...and removing a value that is absent reports failure.
        assert!(!avl.remove(&100));
    }

    #[test]
    fn test_avl_iter() {
        // An in-order walk of the fixture yields its values in ascending order.
        let in_order: Vec<i32> = AVL.iter().copied().collect();
        assert_eq!(in_order, vec![2, 5, 7, 10, 15]);
    }
}