dsalgo/
avl_tree_ordered_multiset_merge_split_based.rs1use crate::avl_tree_node_with_box_recurse_merge_split_based::Node;
2
3#[derive(Debug)]
4
5pub struct AVLMultiset<T>(Option<Box<Node<T>>>);
6
7impl<T: Ord> AVLMultiset<T> {
8 pub fn new() -> Self {
9 Self(None)
10 }
11
12 pub fn size(&self) -> usize {
13 Node::size(self.0.as_ref())
14 }
15
16 pub fn lower_bound(
17 &self,
18 value: &T,
19 ) -> usize {
20 Node::binary_search(|v| v >= &value, self.0.as_ref())
21 }
22
23 pub fn upper_bound(
24 &self,
25 value: &T,
26 ) -> usize {
27 Node::binary_search(|v| v > &value, self.0.as_ref())
28 }
29
30 pub fn count(
31 &self,
32 value: &T,
33 ) -> usize {
34 self.upper_bound(value) - self.lower_bound(value)
35 }
36
37 pub fn contains(
38 &self,
39 value: &T,
40 ) -> bool {
41 self.count(value) > 0
42 }
43
44 pub fn insert(
45 &mut self,
46 value: T,
47 ) {
48 let i = self.lower_bound(&value);
49
50 self.0 = Node::insert(self.0.take(), i, Some(Node::new(value)));
51 }
52
53 pub fn remove(
54 &mut self,
55 value: &T,
56 ) {
57 if !self.contains(value) {
58 return;
59 }
60
61 let i = self.lower_bound(value);
62
63 self.0 = Node::remove(self.0.take(), i);
64 }
65
66 pub fn remove_all(
67 &mut self,
68 value: &T,
69 ) {
70 let l = self.lower_bound(value);
71
72 let r = self.upper_bound(value);
73
74 self.0 = Node::remove_range(self.0.take(), l, r);
75 }
76
77 pub fn iter<'a>(&'a self) -> std::vec::IntoIter<&'a T> {
78 self.0.as_ref().unwrap().iter()
79 }
80}
81
82use std::ops::*;
83
84impl<T> Index<usize> for AVLMultiset<T> {
85 type Output = T;
86
87 fn index(
88 &self,
89 i: usize,
90 ) -> &Self::Output {
91 &Node::kth_node(self.0.as_ref().unwrap(), i).value
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: inserts "b", "a", "b", prints the set and its first three
    /// elements, checks that "b" is counted twice, removes every "b" and
    /// prints whatever remains.
    #[test]
    fn test() {
        let mut s = AVLMultiset::new();

        for value in ["b", "a", "b"] {
            s.insert(value);
        }

        println!("{:?}", s);

        for i in 0..3 {
            println!("{:?}", s[i]);
        }

        assert_eq!(s.count(&"b"), 2);

        s.remove_all(&"b");

        s.iter().for_each(|&v| println!("{:?}", v));
    }
}