// prust_lib/avl.rs

use std::cmp::max;
use std::cmp::Ordering;

use crate::RefCounter;

/// A persistent AVL tree mapping keys of type `K` to values of type `V`.
///
/// Keys, values and both subtrees of a node are held behind `RefCounter`
/// handles, so building an updated tree can reuse subtrees of an existing one.
/// With the default `V = ()` the tree acts as a set of keys.
pub enum AVL<K, V = ()> {
    /// The empty tree; also used for missing children.
    Empty,
    /// A node holding one entry and its two subtrees.
    Node {
        key: RefCounter<K>,
        value: RefCounter<V>,
        left: RefCounter<Self>,
        right: RefCounter<Self>,
    },
}

/// An ordered key/value map backed by [`AVL`].
pub type OrderedMap<K, V> = AVL<K, V>;
/// An ordered set of keys backed by [`AVL`] (every value is `()`).
pub type OrderedSet<K> = AVL<K, ()>;

18impl<K, V> Clone for AVL<K, V> {
19    fn clone(&self) -> Self {
20        match self {
21            Self::Empty => Self::Empty,
22            Self::Node {
23                key,
24                value,
25                left,
26                right,
27            } => Self::Node {
28                key: key.clone(),
29                value: value.clone(),
30                left: left.clone(),
31                right: right.clone(),
32            },
33        }
34    }
35}
36
37impl<K: Ord> AVL<K> {
38    pub fn insert(&self, value: K) -> Self {
39        self.put(value, ())
40    }
41    pub fn search(&self, value: &K) -> bool {
42        self.find(value).is_some()
43    }
44}
45
46impl<K: Ord, V> AVL<K, V> {
47    pub fn empty() -> AVL<K, V> {
48        return AVL::Empty;
49    }
50    pub fn is_empty(&self) -> bool {
51        match self {
52            AVL::Empty => true,
53            _ => false,
54        }
55    }
56    fn height(&self) -> i64 {
57        match self {
58            AVL::Empty => 0,
59            AVL::Node {
60                key: _,
61                value: _,
62                left,
63                right,
64            } => 1 + max(&left.height(), &right.height()),
65        }
66    }
67    fn diff(&self) -> i64 {
68        match self {
69            AVL::Empty => 0,
70            AVL::Node {
71                key: _,
72                value: _,
73                left,
74                right,
75            } => &left.height() - &right.height(),
76        }
77    }
78    pub fn find(&self, target_value: &K) -> Option<&V> {
79        match self {
80            AVL::Empty => Option::None,
81            AVL::Node {
82                key,
83                value,
84                left,
85                right,
86            } => match target_value.cmp(key) {
87                std::cmp::Ordering::Less => left.find(target_value),
88                std::cmp::Ordering::Equal => Option::Some(value.as_ref()),
89                std::cmp::Ordering::Greater => right.find(target_value),
90            },
91        }
92    }
93    fn right_rotation(&self) -> AVL<K, V> {
94        if let AVL::Node {
95            key: x,
96            value: vx,
97            left: lt,
98            right: t3,
99        } = self
100        {
101            if let AVL::Node {
102                key: y,
103                value: vy,
104                left: t1,
105                right: t2,
106            } = (*lt).as_ref()
107            {
108                return AVL::Node {
109                    key: y.clone(),
110                    left: t1.clone(),
111                    value: vy.clone(),
112                    right: RefCounter::new(AVL::Node {
113                        key: x.clone(),
114                        value: vx.clone(),
115                        left: t2.clone(),
116                        right: t3.clone(),
117                    }),
118                };
119            }
120        }
121        return self.clone();
122    }
123    fn right_fix(&self) -> AVL<K, V> {
124        if let AVL::Node {
125            key: x,
126            value: vx,
127            left: t1,
128            right: t2,
129        } = self
130        {
131            if t1.diff() == -1 {
132                return AVL::Node {
133                    key: x.clone(),
134                    value: vx.clone(),
135                    left: RefCounter::new(t1.left_rotation()),
136                    right: t2.clone(),
137                }
138                .right_rotation();
139            } else {
140                return self.right_rotation();
141            }
142        }
143        return self.clone();
144    }
145    fn left_rotation(&self) -> AVL<K, V> {
146        if let AVL::Node {
147            key: x,
148            value: vx,
149            left: t1,
150            right: rt,
151        } = self
152        {
153            if let AVL::Node {
154                key: y,
155                value: vy,
156                left: t2,
157                right: t3,
158            } = (*rt).as_ref()
159            {
160                return AVL::Node {
161                    key: y.clone(),
162                    value: vy.clone(),
163                    left: RefCounter::new(AVL::Node {
164                        key: x.clone(),
165                        value: vx.clone(),
166                        left: t1.clone(),
167                        right: t2.clone(),
168                    }),
169                    right: t3.clone(),
170                };
171            }
172        }
173        return self.clone();
174    }
175    fn left_fix(&self) -> AVL<K, V> {
176        if let AVL::Node {
177            key: x,
178            value: vx,
179            left: t1,
180            right: t2,
181        } = self
182        {
183            if t2.diff() == 1 {
184                return AVL::Node {
185                    key: x.clone(),
186                    value: vx.clone(),
187                    left: t1.clone(),
188                    right: RefCounter::new(t2.right_rotation()),
189                }
190                .left_rotation();
191            } else {
192                return self.left_rotation();
193            }
194        }
195        return self.clone();
196    }
197    fn fix(&self) -> AVL<K, V> {
198        match self.diff() {
199            2 => self.right_fix(),
200            -2 => self.left_fix(),
201            _ => self.clone(),
202        }
203    }
204    pub fn put(&self, key: K, value: V) -> AVL<K, V> {
205        self.put_rc(RefCounter::new(key), RefCounter::new(value))
206    }
207    fn put_rc(&self, key_rc: RefCounter<K>, value_rc: RefCounter<V>) -> AVL<K, V> {
208        match self {
209            AVL::Empty => AVL::Node {
210                key: key_rc,
211                value: value_rc,
212                left: RefCounter::new(AVL::Empty),
213                right: RefCounter::new(AVL::Empty),
214            },
215            AVL::Node {
216                key,
217                value,
218                left,
219                right,
220            } => match key_rc.cmp(key) {
221                std::cmp::Ordering::Less => AVL::Node {
222                    key: key.clone(),
223                    value: value.clone(),
224                    left: RefCounter::new(left.put_rc(key_rc, value_rc)),
225                    right: right.clone(),
226                }
227                .fix(),
228                std::cmp::Ordering::Equal => AVL::Node {
229                    key: key_rc,
230                    value: value_rc,
231                    left: left.clone(),
232                    right: right.clone(),
233                },
234                std::cmp::Ordering::Greater => AVL::Node {
235                    key: key.clone(),
236                    value: value.clone(),
237                    left: left.clone(),
238                    right: RefCounter::new(right.put_rc(key_rc, value_rc)),
239                }
240                .fix(),
241            },
242        }
243    }
244    pub fn delete(&self, target_key: &K) -> AVL<K, V> {
245        match self {
246            AVL::Empty => AVL::Empty,
247            AVL::Node {
248                key,
249                value,
250                left,
251                right,
252            } => {
253                match target_key.cmp(key) {
254                    std::cmp::Ordering::Less => {
255                        let left_deleted = left.delete(target_key);
256                        AVL::Node {
257                            key: key.clone(),
258                            value: value.clone(),
259                            left: RefCounter::new(left_deleted),
260                            right: right.clone(),
261                        }
262                        .fix()
263                    }
264                    std::cmp::Ordering::Equal => {
265                        // Node with only one child or no child
266                        if left.is_empty() {
267                            return right.as_ref().clone();
268                        } else if right.is_empty() {
269                            return left.as_ref().clone();
270                        }
271
272                        // Node with two children, get the inorder predecessor (maximum value in the left subtree)
273                        let inorder_predecessor = left.find_max();
274                        if let Some((pred_key, pred_value)) = inorder_predecessor {
275                            let left_deleted = left.delete(&pred_key);
276                            AVL::Node {
277                                key: pred_key.clone(),
278                                value: pred_value.clone(),
279                                left: RefCounter::new(left_deleted),
280                                right: right.clone(),
281                            }
282                            .fix()
283                        } else {
284                            self.clone()
285                        }
286                    }
287                    std::cmp::Ordering::Greater => {
288                        let right_deleted = right.delete(target_key);
289                        AVL::Node {
290                            key: key.clone(),
291                            value: value.clone(),
292                            left: left.clone(),
293                            right: RefCounter::new(right_deleted),
294                        }
295                        .fix()
296                    }
297                }
298            }
299        }
300    }
301
302    fn find_max(&self) -> Option<(RefCounter<K>, RefCounter<V>)> {
303        match self {
304            AVL::Empty => None,
305            AVL::Node {
306                key,
307                value,
308                left: _,
309                right,
310            } => {
311                if right.is_empty() {
312                    Some((key.clone(), value.clone()))
313                } else {
314                    right.find_max()
315                }
316            }
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks `tree` in order and asserts at every node that the balance factor
    /// is within -1..=1. Each key is pushed onto `out`.
    fn collect_balanced<K: Ord + Clone, V>(tree: &AVL<K, V>, out: &mut Vec<K>) {
        if let AVL::Node {
            key, left, right, ..
        } = tree
        {
            assert!((-1..=1).contains(&tree.diff()), "AVL balance violated");
            collect_balanced(&**left, out);
            out.push((**key).clone());
            collect_balanced(&**right, out);
        }
    }

    /// Checks the AVL invariants (balance, and strictly increasing in-order keys)
    /// and returns the keys in order.
    fn checked_keys<K: Ord + Clone, V>(tree: &AVL<K, V>) -> Vec<K> {
        let mut keys = Vec::new();
        collect_balanced(tree, &mut keys);
        assert!(
            keys.windows(2).all(|w| w[0] < w[1]),
            "in-order keys are not strictly increasing"
        );
        keys
    }

    #[test]
    fn test_avl_set() {
        let l = AVL::empty().insert(1).insert(2).insert(3).insert(4);
        let l2 = l.clone().insert(5);
        for i in 1..=4 {
            assert!(l.search(&i));
            assert!(l2.search(&i));
        }
        assert!(!l.search(&5));
        assert!(l2.search(&5));
    }

    #[test]
    fn test_avl_map() {
        let l = AVL::empty().put(1, 999);
        let l2 = l.clone().put(1, 123).put(2, 3);
        assert_eq!(l.find(&1), Some(&999));
        assert_eq!(l2.find(&1), Some(&123));
        assert!(l.find(&2).is_none());
        assert!(l2.find(&2).is_some());
    }

    #[test]
    fn test_avl_delete() {
        let l = AVL::empty()
            .insert(1)
            .insert(2)
            .insert(3)
            .insert(4)
            .insert(5);
        let l = l.delete(&3);
        assert!(!l.search(&3));
        assert!(l.search(&1));
        assert!(l.search(&2));
        assert!(l.search(&4));
        assert!(l.search(&5));
    }

    #[test]
    fn test_avl_stays_balanced_on_sequential_insert() {
        let mut set: OrderedSet<i32> = AVL::empty();
        for i in 0..100 {
            set = set.insert(i);
            checked_keys(&set);
        }
        assert_eq!(checked_keys(&set), (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_avl_delete_node_with_two_children() {
        let mut set: OrderedSet<i32> = AVL::empty();
        for i in 1..=7 {
            set = set.insert(i);
        }
        // Sequential inserts of 1..=7 build a perfect tree rooted at 4.
        match &set {
            AVL::Node { key, .. } => assert_eq!(**key, 4),
            AVL::Empty => panic!("tree should not be empty"),
        }
        let pruned = set.delete(&4);
        assert!(!pruned.search(&4));
        assert_eq!(checked_keys(&pruned), vec![1, 2, 3, 5, 6, 7]);
        // The original tree is persistent and still contains 4.
        assert!(set.search(&4));
    }

    #[test]
    fn test_avl_delete_keeps_balance() {
        let mut set: OrderedSet<i32> = AVL::empty();
        for i in 0..50 {
            set = set.insert(i);
        }
        for i in (0..50).step_by(2) {
            set = set.delete(&i);
            checked_keys(&set);
        }
        assert_eq!(checked_keys(&set), (1..50).step_by(2).collect::<Vec<_>>());
    }

    #[test]
    fn test_avl_delete_missing_key() {
        let empty: OrderedSet<i32> = AVL::empty();
        assert!(empty.delete(&1).is_empty());
        let set = AVL::empty().insert(1).insert(2).insert(3);
        assert_eq!(checked_keys(&set.delete(&42)), vec![1, 2, 3]);
    }
}
362}