1use std::cmp::max;
2
3use crate::RefCounter;
4
/// A persistent (immutable) AVL tree mapping keys of type `K` to values of type `V`.
///
/// Updating operations such as `put` and `delete` take `&self` and return a new
/// tree instead of modifying `self`. Keys, values and subtrees are held behind
/// `RefCounter` handles, so a new tree reuses the untouched parts of the old one
/// by cloning handles. After each change the updating operations rebalance with
/// rotations any node whose two subtree heights differ by 2.
///
/// With the default `V = ()` the tree behaves as a set of keys.
pub enum AVL<K, V = ()> {
    /// The empty tree.
    Empty,
    /// A node holding one entry and its two subtrees.
    Node {
        /// This node's key. The operations in this module keep smaller keys in
        /// `left` and larger keys in `right`.
        key: RefCounter<K>,
        /// The value associated with `key`.
        value: RefCounter<V>,
        /// Subtree of entries whose keys compare less than `key`.
        left: RefCounter<AVL<K, V>>,
        /// Subtree of entries whose keys compare greater than `key`.
        right: RefCounter<AVL<K, V>>,
    },
}

/// An ordered key/value map backed by [`AVL`].
pub type OrderedMap<K, V> = AVL<K, V>;
/// An ordered set backed by [`AVL`]; each element is stored as a key with value `()`.
pub type OrderedSet<K> = AVL<K>;
17
18impl<K, V> Clone for AVL<K, V> {
19 fn clone(&self) -> Self {
20 match self {
21 Self::Empty => Self::Empty,
22 Self::Node {
23 key,
24 value,
25 left,
26 right,
27 } => Self::Node {
28 key: key.clone(),
29 value: value.clone(),
30 left: left.clone(),
31 right: right.clone(),
32 },
33 }
34 }
35}
36
37impl<K: Ord> AVL<K> {
38 pub fn insert(&self, value: K) -> Self {
39 self.put(value, ())
40 }
41 pub fn search(&self, value: &K) -> bool {
42 self.find(value).is_some()
43 }
44}
45
46impl<K: Ord, V> AVL<K, V> {
47 pub fn empty() -> AVL<K, V> {
48 return AVL::Empty;
49 }
50 pub fn is_empty(&self) -> bool {
51 match self {
52 AVL::Empty => true,
53 _ => false,
54 }
55 }
56 fn height(&self) -> i64 {
57 match self {
58 AVL::Empty => 0,
59 AVL::Node {
60 key: _,
61 value: _,
62 left,
63 right,
64 } => 1 + max(&left.height(), &right.height()),
65 }
66 }
67 fn diff(&self) -> i64 {
68 match self {
69 AVL::Empty => 0,
70 AVL::Node {
71 key: _,
72 value: _,
73 left,
74 right,
75 } => &left.height() - &right.height(),
76 }
77 }
78 pub fn find(&self, target_value: &K) -> Option<&V> {
79 match self {
80 AVL::Empty => Option::None,
81 AVL::Node {
82 key,
83 value,
84 left,
85 right,
86 } => match target_value.cmp(key) {
87 std::cmp::Ordering::Less => left.find(target_value),
88 std::cmp::Ordering::Equal => Option::Some(value.as_ref()),
89 std::cmp::Ordering::Greater => right.find(target_value),
90 },
91 }
92 }
93 fn right_rotation(&self) -> AVL<K, V> {
94 if let AVL::Node {
95 key: x,
96 value: vx,
97 left: lt,
98 right: t3,
99 } = self
100 {
101 if let AVL::Node {
102 key: y,
103 value: vy,
104 left: t1,
105 right: t2,
106 } = (*lt).as_ref()
107 {
108 return AVL::Node {
109 key: y.clone(),
110 left: t1.clone(),
111 value: vy.clone(),
112 right: RefCounter::new(AVL::Node {
113 key: x.clone(),
114 value: vx.clone(),
115 left: t2.clone(),
116 right: t3.clone(),
117 }),
118 };
119 }
120 }
121 return self.clone();
122 }
123 fn right_fix(&self) -> AVL<K, V> {
124 if let AVL::Node {
125 key: x,
126 value: vx,
127 left: t1,
128 right: t2,
129 } = self
130 {
131 if t1.diff() == -1 {
132 return AVL::Node {
133 key: x.clone(),
134 value: vx.clone(),
135 left: RefCounter::new(t1.left_rotation()),
136 right: t2.clone(),
137 }
138 .right_rotation();
139 } else {
140 return self.right_rotation();
141 }
142 }
143 return self.clone();
144 }
145 fn left_rotation(&self) -> AVL<K, V> {
146 if let AVL::Node {
147 key: x,
148 value: vx,
149 left: t1,
150 right: rt,
151 } = self
152 {
153 if let AVL::Node {
154 key: y,
155 value: vy,
156 left: t2,
157 right: t3,
158 } = (*rt).as_ref()
159 {
160 return AVL::Node {
161 key: y.clone(),
162 value: vy.clone(),
163 left: RefCounter::new(AVL::Node {
164 key: x.clone(),
165 value: vx.clone(),
166 left: t1.clone(),
167 right: t2.clone(),
168 }),
169 right: t3.clone(),
170 };
171 }
172 }
173 return self.clone();
174 }
175 fn left_fix(&self) -> AVL<K, V> {
176 if let AVL::Node {
177 key: x,
178 value: vx,
179 left: t1,
180 right: t2,
181 } = self
182 {
183 if t2.diff() == 1 {
184 return AVL::Node {
185 key: x.clone(),
186 value: vx.clone(),
187 left: t1.clone(),
188 right: RefCounter::new(t2.right_rotation()),
189 }
190 .left_rotation();
191 } else {
192 return self.left_rotation();
193 }
194 }
195 return self.clone();
196 }
197 fn fix(&self) -> AVL<K, V> {
198 match self.diff() {
199 2 => self.right_fix(),
200 -2 => self.left_fix(),
201 _ => self.clone(),
202 }
203 }
204 pub fn put(&self, key: K, value: V) -> AVL<K, V> {
205 self.put_rc(RefCounter::new(key), RefCounter::new(value))
206 }
207 fn put_rc(&self, key_rc: RefCounter<K>, value_rc: RefCounter<V>) -> AVL<K, V> {
208 match self {
209 AVL::Empty => AVL::Node {
210 key: key_rc,
211 value: value_rc,
212 left: RefCounter::new(AVL::Empty),
213 right: RefCounter::new(AVL::Empty),
214 },
215 AVL::Node {
216 key,
217 value,
218 left,
219 right,
220 } => match key_rc.cmp(key) {
221 std::cmp::Ordering::Less => AVL::Node {
222 key: key.clone(),
223 value: value.clone(),
224 left: RefCounter::new(left.put_rc(key_rc, value_rc)),
225 right: right.clone(),
226 }
227 .fix(),
228 std::cmp::Ordering::Equal => AVL::Node {
229 key: key_rc,
230 value: value_rc,
231 left: left.clone(),
232 right: right.clone(),
233 },
234 std::cmp::Ordering::Greater => AVL::Node {
235 key: key.clone(),
236 value: value.clone(),
237 left: left.clone(),
238 right: RefCounter::new(right.put_rc(key_rc, value_rc)),
239 }
240 .fix(),
241 },
242 }
243 }
244 pub fn delete(&self, target_key: &K) -> AVL<K, V> {
245 match self {
246 AVL::Empty => AVL::Empty,
247 AVL::Node {
248 key,
249 value,
250 left,
251 right,
252 } => {
253 match target_key.cmp(key) {
254 std::cmp::Ordering::Less => {
255 let left_deleted = left.delete(target_key);
256 AVL::Node {
257 key: key.clone(),
258 value: value.clone(),
259 left: RefCounter::new(left_deleted),
260 right: right.clone(),
261 }
262 .fix()
263 }
264 std::cmp::Ordering::Equal => {
265 if left.is_empty() {
267 return right.as_ref().clone();
268 } else if right.is_empty() {
269 return left.as_ref().clone();
270 }
271
272 let inorder_predecessor = left.find_max();
274 if let Some((pred_key, pred_value)) = inorder_predecessor {
275 let left_deleted = left.delete(&pred_key);
276 AVL::Node {
277 key: pred_key.clone(),
278 value: pred_value.clone(),
279 left: RefCounter::new(left_deleted),
280 right: right.clone(),
281 }
282 .fix()
283 } else {
284 self.clone()
285 }
286 }
287 std::cmp::Ordering::Greater => {
288 let right_deleted = right.delete(target_key);
289 AVL::Node {
290 key: key.clone(),
291 value: value.clone(),
292 left: left.clone(),
293 right: RefCounter::new(right_deleted),
294 }
295 .fix()
296 }
297 }
298 }
299 }
300 }
301
302 fn find_max(&self) -> Option<(RefCounter<K>, RefCounter<V>)> {
303 match self {
304 AVL::Empty => None,
305 AVL::Node {
306 key,
307 value,
308 left: _,
309 right,
310 } => {
311 if right.is_empty() {
312 Some((key.clone(), value.clone()))
313 } else {
314 right.find_max()
315 }
316 }
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set containing every integer in `range`, inserted in ascending order.
    fn set_of(range: std::ops::RangeInclusive<i32>) -> AVL<i32> {
        range.fold(AVL::empty(), |tree: AVL<i32>, x| tree.insert(x))
    }

    /// Asserts the AVL invariant (|balance factor| <= 1) at every node.
    fn assert_balanced(tree: &AVL<i32>) {
        if let AVL::Node { left, right, .. } = tree {
            assert!(tree.diff().abs() <= 1, "unbalanced node: diff = {}", tree.diff());
            assert_balanced(left);
            assert_balanced(right);
        }
    }

    #[test]
    fn test_avl_set() {
        let l = AVL::empty().insert(1).insert(2).insert(3).insert(4);
        let l2 = l.clone().insert(5);
        for i in 1..=4 {
            assert!(l.search(&i));
            assert!(l2.search(&i));
        }
        assert!(!l.search(&5));
        assert!(l2.search(&5));
    }

    #[test]
    fn test_avl_map() {
        let l = AVL::empty().put(1, 999);
        let l2 = l.clone().put(1, 123).put(2, 3);
        assert_eq!(l.find(&1), Some(&999));
        assert_eq!(l2.find(&1), Some(&123));
        assert!(l.find(&2).is_none());
        assert!(l2.find(&2).is_some());
    }

    #[test]
    fn test_avl_delete() {
        let l = AVL::empty()
            .insert(1)
            .insert(2)
            .insert(3)
            .insert(4)
            .insert(5);
        let l = l.delete(&3);
        assert!(!l.search(&3));
        assert!(l.search(&1));
        assert!(l.search(&2));
        assert!(l.search(&4));
        assert!(l.search(&5));
    }

    #[test]
    fn test_avl_stays_balanced_on_insert() {
        // Ascending inserts are the worst case for an unbalanced BST.
        let ascending = set_of(1..=100);
        assert_balanced(&ascending);
        // A balanced tree of height 10 needs at least 143 nodes, so 100 fit in 9 levels.
        assert!(ascending.height() <= 9);

        // 37 is coprime with 100, so this inserts each of 0..100 exactly once in
        // a scrambled order, exercising the double-rotation cases as well.
        let scrambled: AVL<i32> = (0..100).fold(AVL::empty(), |tree: AVL<i32>, i| {
            tree.insert((i * 37) % 100)
        });
        assert_balanced(&scrambled);
        for i in 0..100 {
            assert!(scrambled.search(&i));
        }
        assert!(!scrambled.search(&100));
    }

    #[test]
    fn test_avl_delete_is_persistent() {
        let full = set_of(1..=10);
        let removed = full.delete(&5);
        assert!(full.search(&5));
        assert!(!removed.search(&5));
        for i in (1..=10).filter(|&i| i != 5) {
            assert!(full.search(&i));
            assert!(removed.search(&i));
        }
    }

    #[test]
    fn test_avl_delete_missing_key() {
        let tree = set_of(1..=10).delete(&42);
        assert_balanced(&tree);
        for i in 1..=10 {
            assert!(tree.search(&i));
        }
        let empty: AVL<i32> = AVL::empty();
        assert!(empty.delete(&1).is_empty());
    }

    #[test]
    fn test_avl_stays_balanced_on_delete() {
        let mut tree = set_of(1..=100);
        // Removing every even key hits leaves, single-child and two-child nodes.
        for i in (2..=100).step_by(2) {
            tree = tree.delete(&i);
            assert_balanced(&tree);
        }
        for i in 1..=100 {
            assert_eq!(tree.search(&i), i % 2 == 1);
        }
        for i in (1..=100).step_by(2) {
            tree = tree.delete(&i);
            assert_balanced(&tree);
        }
        assert!(tree.is_empty());
    }

    #[test]
    fn test_avl_map_delete_keeps_other_values() {
        let map = AVL::empty().put(1, "one").put(2, "two").put(3, "three");
        // Key 2 is the root here and has two children.
        let smaller = map.delete(&2);
        assert_eq!(smaller.find(&1), Some(&"one"));
        assert_eq!(smaller.find(&3), Some(&"three"));
        assert!(smaller.find(&2).is_none());
        assert_eq!(map.find(&2), Some(&"two"));
    }
}