// competitive_programming_rs/data_structure/treap.rs

pub mod treap {
    //! An ordered set of distinct values backed by a treap.
    //!
    //! Besides membership queries, the set answers order-statistic queries:
    //! [`Treap::nth`] returns the k-th smallest element and
    //! [`Treap::binary_search`] returns the rank of a value.

    type BNode<T> = Box<Node<T>>;

    /// Seed substituted for `0`, which is a fixed point of xorshift
    /// (Marsaglia's customary default seed for the 32-bit generator).
    const FALLBACK_SEED: u32 = 2_463_534_242;

    /// An ordered set of distinct values.
    ///
    /// Node priorities are drawn from an internal xorshift generator seeded by
    /// [`Treap::new`], so for a given seed and operation sequence the tree
    /// shape is deterministic.
    #[derive(Debug)]
    pub struct Treap<T> {
        rng: XorShift,
        root: Option<BNode<T>>,
    }

    impl<T> Treap<T> {
        /// Creates an empty set whose node priorities come from a xorshift
        /// generator seeded with `seed`.
        ///
        /// A `seed` of `0` is replaced by a fixed non-zero seed: xorshift maps
        /// `0` to `0`, so every priority would be equal, the tree would no
        /// longer be balanced and could degrade to linear depth.
        pub fn new(seed: u32) -> Self {
            let state = if seed == 0 { FALLBACK_SEED } else { seed };
            Self {
                rng: XorShift { state },
                root: None,
            }
        }

        /// Returns the number of elements in the set.
        pub fn len(&self) -> usize {
            size(&self.root)
        }

        /// Returns `true` if the set contains no elements.
        pub fn is_empty(&self) -> bool {
            self.root.is_none()
        }

        /// Returns the `k`-th smallest element (0-indexed).
        ///
        /// # Panics
        ///
        /// Panics if the set is empty or if `k >= self.len()`.
        pub fn nth(&self, k: usize) -> &T {
            let root = self
                .root
                .as_ref()
                .expect("Cannot fetch the k-th element of an empty set.");
            assert!(
                k < root.count,
                "index out of bounds: the len is {} but the index is {}",
                root.count,
                k
            );
            root.nth(k)
        }
    }

    impl<T: PartialOrd> Treap<T> {
        /// Looks `value` up and returns `(found, rank)`, where `rank` is the
        /// number of stored elements smaller than `value` (`(false, 0)` when
        /// the set is empty).
        fn locate(&self, value: &T) -> (bool, usize) {
            self.root
                .as_ref()
                .map_or((false, 0), |root| root.find(value))
        }

        /// Inserts `value`.
        ///
        /// Returns `true` if it was newly inserted, or `false` if an equal value
        /// was already present (the set is then left unchanged).
        pub fn insert(&mut self, value: T) -> bool {
            // Drawn before the lookup, so the generator advances on every call,
            // whether or not the value turns out to be a duplicate.
            let priority = self.rng.next();
            let (contains, k) = self.locate(&value);
            if contains {
                return false;
            }
            let root = self.root.take();
            self.root = Some(insert(root, k, value, priority));
            true
        }

        /// Returns `true` if a value equal to `value` is in the set.
        pub fn contains(&self, value: &T) -> bool {
            self.locate(value).0
        }

        /// Removes the value equal to `value`, returning it, or `None` if no
        /// such value is stored.
        pub fn erase(&mut self, value: &T) -> Option<T> {
            let (contains, k) = self.locate(value);
            if !contains {
                return None;
            }
            let (root, removed) = erase(self.root.take(), k);
            self.root = root;
            removed.map(|node| node.key)
        }

        /// Mirrors `slice::binary_search`: `Ok(index)` of the matching element,
        /// or `Err(index)` where `value` would have to be inserted to keep the
        /// set sorted.
        pub fn binary_search(&self, value: &T) -> Result<usize, usize> {
            match self.locate(value) {
                (true, k) => Ok(k),
                (false, k) => Err(k),
            }
        }
    }

    /// A treap node. `count` caches the size of the subtree rooted here and
    /// must be refreshed with `update_count` whenever a child changes.
    #[derive(Debug)]
    struct Node<T> {
        left: Option<BNode<T>>,
        right: Option<BNode<T>>,
        key: T,
        priority: u32,
        count: usize,
    }

    impl<T> Node<T> {
        /// Creates a leaf node holding `key`.
        fn new(key: T, priority: u32) -> BNode<T> {
            Box::new(Node {
                left: None,
                right: None,
                key,
                priority,
                count: 1,
            })
        }

        /// Recomputes `count` from the children's cached sizes.
        fn update_count(&mut self) {
            self.count = size(&self.left) + size(&self.right) + 1;
        }

        /// Returns the `k`-th smallest key of this subtree (0-indexed).
        ///
        /// Panics if `k` is not smaller than the subtree size.
        fn nth(&self, k: usize) -> &T {
            let left_size = size(&self.left);
            if k < left_size {
                self.left
                    .as_ref()
                    .expect("left subtree is non-empty because k < left_size")
                    .nth(k)
            } else if k == left_size {
                &self.key
            } else {
                self.right
                    .as_ref()
                    .expect("index out of bounds for this subtree")
                    .nth(k - left_size - 1)
            }
        }
    }

    impl<T: PartialOrd> Node<T> {
        /// Returns `(found, rank)` for `value` within this subtree, where
        /// `rank` is the number of keys in the subtree smaller than `value`
        /// (which is also the index of the matching key when `found`).
        fn find(&self, value: &T) -> (bool, usize) {
            let left_size = size(&self.left);
            if &self.key == value {
                (true, left_size)
            } else if &self.key > value {
                match self.left.as_ref() {
                    Some(left) => left.find(value),
                    None => (false, 0),
                }
            } else {
                match self.right.as_ref() {
                    Some(right) => {
                        let (contained, rank) = right.find(value);
                        (contained, rank + left_size + 1)
                    }
                    None => (false, left_size + 1),
                }
            }
        }
    }

    /// Inserts a new node with `value`/`priority` at position `k` of `t`.
    ///
    /// `split` and `merge` keep every cached `count` up to date, so the
    /// returned root needs no further fix-up.
    fn insert<T>(t: Option<BNode<T>>, k: usize, value: T, priority: u32) -> BNode<T> {
        let (first, second) = split(t, k);
        let left = merge(first, Some(Node::new(value, priority)));
        merge(left, second).expect("the tree is non-empty since a node was just added")
    }

    /// Detaches the node at position `k`, returning `(remaining tree, removed node)`.
    fn erase<T>(node: Option<BNode<T>>, k: usize) -> (Option<BNode<T>>, Option<BNode<T>>) {
        let (first, rest) = split(node, k + 1);
        let (first, removed) = split(first, k);
        (merge(first, rest), removed)
    }

    /// Concatenates `s` and `t`, assuming every key in `s` precedes every key
    /// in `t`. The node with the strictly higher priority becomes the root
    /// (on ties the root of `t` wins).
    fn merge<T>(s: Option<BNode<T>>, t: Option<BNode<T>>) -> Option<BNode<T>> {
        match (s, t) {
            (Some(mut s), Some(mut t)) => {
                if s.priority > t.priority {
                    s.right = merge(s.right, Some(t));
                    s.update_count();
                    Some(s)
                } else {
                    t.left = merge(Some(s), t.left);
                    t.update_count();
                    Some(t)
                }
            }
            (Some(s), None) => Some(s),
            (None, Some(t)) => Some(t),
            (None, None) => None,
        }
    }

    /// Splits `node` into `(first k elements, remaining elements)`.
    fn split<T>(node: Option<BNode<T>>, k: usize) -> (Option<BNode<T>>, Option<BNode<T>>) {
        if let Some(mut node) = node {
            let left_size = size(&node.left);
            if k <= left_size {
                let left = node.left.take();
                let (first, second) = split(left, k);
                node.left = second;
                node.update_count();
                (first, Some(node))
            } else {
                let right = node.right.take();
                // k > left_size here, so this cannot underflow.
                let (first, second) = split(right, k - left_size - 1);
                node.right = first;
                node.update_count();
                (Some(node), second)
            }
        } else {
            (None, None)
        }
    }

    /// Cached size of an optional subtree (`0` for `None`).
    fn size<T>(node: &Option<BNode<T>>) -> usize {
        node.as_ref().map_or(0, |node| node.count)
    }

    /// 32-bit xorshift generator (shift triple 13, 17, 5). A zero state stays
    /// zero forever, which is why `Treap::new` never stores one.
    #[derive(Debug)]
    struct XorShift {
        state: u32,
    }

    impl XorShift {
        /// Advances the generator and returns the new state.
        fn next(&mut self) -> u32 {
            self.state = xor_shift(self.state);
            self.state
        }
    }

    /// One xorshift step.
    fn xor_shift(state: u32) -> u32 {
        let mut x = state;
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        x
    }
}
#[cfg(test)]
mod test {
    use super::treap::*;
    use rand::distributions::Uniform;
    use rand::prelude::*;
    use rand::Rng;
    use std::collections::BTreeSet;

    #[test]
    fn test_treap_insert_erase() {
        let mut treap = Treap::new(71);
        let mut rng = StdRng::seed_from_u64(141);
        let max = 1000000;

        let mut v = (0..max).collect::<Vec<_>>();
        v.shuffle(&mut rng);
        for &i in v.iter() {
            assert!(!treap.contains(&i));
            assert!(treap.insert(i));
            assert!(!treap.insert(i));
            assert!(treap.contains(&i));
        }

        v.shuffle(&mut rng);
        for &i in v.iter() {
            assert!(treap.contains(&i));
            assert_eq!(treap.erase(&i), Some(i));
            assert_eq!(treap.erase(&i), None);
            assert!(!treap.contains(&i));
        }
        assert_eq!(treap.len(), 0);
    }

    #[test]
    fn test_treap_nth() {
        let mut rng = StdRng::seed_from_u64(141);

        for _ in 0..10 {
            let mut treap = Treap::new(71);
            let max = 100000;
            let mut v = (0..max)
                .map(|_| rng.gen_range(0, 1_000_000_000))
                .collect::<Vec<_>>();
            v.sort();
            v.dedup();
            v.shuffle(&mut rng);
            for &i in v.iter() {
                assert!(treap.insert(i));
                assert!(!treap.insert(i));
            }
            v.sort();

            for (i, v) in v.into_iter().enumerate() {
                assert_eq!(treap.nth(i), &v);
            }
        }
    }

    /// Differential test against `BTreeSet` with a random mix of inserts and
    /// erases.
    #[test]
    fn test_random_insertion() {
        // Values come from a small range so the same value is drawn again and
        // again: with unrestricted `i64` values a repeat is astronomically
        // unlikely, and the erase / duplicate-insert paths would never run.
        const VALUE_BOUND: i64 = 500;

        // Seeded so that any failure is reproducible.
        let mut rng = StdRng::seed_from_u64(42);
        let mut set = BTreeSet::new();
        let mut treap = Treap::new(42);
        for _ in 0..2000 {
            let x: i64 = rng.sample(Uniform::from(-VALUE_BOUND..VALUE_BOUND));

            if rng.sample(Uniform::from(0..10)) == 0 {
                // remove
                if set.contains(&x) {
                    assert!(treap.contains(&x));
                    set.remove(&x);
                    assert_eq!(treap.erase(&x), Some(x));
                    assert_eq!(treap.erase(&x), None);
                    assert!(!treap.contains(&x));
                } else {
                    assert!(!treap.contains(&x));
                    assert_eq!(treap.erase(&x), None);
                }
                // `x` is absent now: its insertion point is the count of smaller elements.
                assert_eq!(treap.binary_search(&x), Err(set.range(..x).count()));
            } else {
                // insert
                if set.contains(&x) {
                    assert!(treap.contains(&x));
                    assert!(!treap.insert(x));
                } else {
                    assert!(!treap.contains(&x));
                    assert!(treap.insert(x));
                    assert!(!treap.insert(x));
                    set.insert(x);
                    assert!(treap.contains(&x));
                }
            }

            assert_eq!(treap.len(), set.len());
            for (i, x) in set.iter().enumerate() {
                assert_eq!(treap.nth(i), x);
                assert_eq!(treap.binary_search(x), Ok(i));
            }
        }
    }
}
320}