dsalgo/
pivot_tree_ordered_set_usize_with_size_with_vec.rs

/// Ordered set of `usize` values stored as a pivot tree laid out in flat vectors.
///
/// Nodes are addressed by their pivot index. The root pivot is
/// `1 << (max_height - 1)`, and the children of an even pivot sit half of its
/// lowest set bit to either side; odd pivots are leaves.
#[derive(Debug, Clone)]
pub struct PivotTreeSet {
    /// Value held at each pivot index, or `None` when that slot is empty.
    /// Values are stored shifted by one (`x + 1`), so a stored value is never 0.
    data: Vec<Option<usize>>,
    /// Number of values stored in the subtree rooted at each pivot index.
    size: Vec<usize>,
    /// Height of the tree; both vectors have `1 << max_height` slots.
    max_height: usize,
}
6
/// Returns the pivot of the left child of `pivot`.
///
/// The child sits half of `pivot`'s lowest set bit below it, e.g. `left(8) == 4`
/// and `left(12) == 10`.
///
/// `pivot` must be even and non-zero: an odd pivot is a leaf and has no child,
/// and the shift amount below would underflow for it.
fn left(pivot: usize) -> usize {
    debug_assert!(
        pivot != 0 && pivot & 1 == 0,
        "left() requires an even, non-zero pivot"
    );

    // Parenthesised explicitly: in Rust `-` binds tighter than `<<`.
    pivot - (1usize << (pivot.trailing_zeros() - 1))
}
10
/// Returns the pivot of the right child of `pivot`.
///
/// The child sits half of `pivot`'s lowest set bit above it, e.g. `right(8) == 12`
/// and `right(12) == 14`.
///
/// `pivot` must be even and non-zero: an odd pivot is a leaf and has no child,
/// and the shift amount below would underflow for it.
fn right(pivot: usize) -> usize {
    debug_assert!(
        pivot != 0 && pivot & 1 == 0,
        "right() requires an even, non-zero pivot"
    );

    // Parenthesised explicitly: in Rust `-` binds tighter than `<<`.
    pivot + (1usize << (pivot.trailing_zeros() - 1))
}
14
15impl PivotTreeSet {
16    pub fn new(max_height: usize) -> Self {
17        assert!(max_height > 0);
18
19        let n = 1 << max_height;
20
21        Self { data: vec![None; n], size: vec![0; n], max_height }
22    }
23
24    fn root_pivot(&self) -> usize {
25        1 << self.max_height - 1
26    }
27
28    fn left_size(
29        &self,
30        p: usize,
31    ) -> usize {
32        if p & 1 == 1 {
33            0
34        } else {
35            self.size[left(p)]
36        }
37    }
38
39    fn right_size(
40        &self,
41        p: usize,
42    ) -> usize {
43        if p & 1 == 1 {
44            0
45        } else {
46            self.size[right(p)]
47        }
48    }
49
50    pub fn size(&self) -> usize {
51        self.size[self.root_pivot()]
52    }
53
54    fn update(
55        &mut self,
56        p: usize,
57    ) {
58        if self.data[p].is_none() {
59            self.size[p] = 0;
60
61            return;
62        }
63
64        self.size[p] = self.left_size(p) + self.right_size(p) + 1;
65    }
66
67    pub fn _insert(
68        &mut self,
69        p: usize,
70        mut v: usize,
71    ) {
72        use std::mem::swap;
73
74        let value = self.data[p];
75
76        if value.is_none() {
77            debug_assert!(self.size[p] == 0);
78
79            self.data[p] = Some(v);
80
81            self.size[p] = 1;
82
83            return;
84        }
85
86        let mut value = value.unwrap();
87
88        if v == value {
89            return;
90        }
91
92        let d = 1 << p.trailing_zeros();
93
94        assert!(p - d < v && v < p + d);
95
96        if value.min(v) < p {
97            if value < v {
98                swap(&mut value, &mut v);
99            }
100
101            self.data[p] = Some(value);
102
103            self._insert(left(p), v);
104        } else {
105            if value > v {
106                swap(&mut value, &mut v);
107            }
108
109            self.data[p] = Some(value);
110
111            self._insert(right(p), v);
112        }
113
114        self.update(p);
115    }
116
117    fn _remove(
118        &mut self,
119        p: usize,
120        i: usize,
121    ) {
122        assert!(i < self.size[p]);
123
124        let lsize = self.left_size(p);
125
126        if i < lsize {
127            self._remove(left(p), i);
128        } else if i > lsize {
129            self._remove(right(p), i - lsize - 1);
130        } else {
131            if self.right_size(p) > 0 {
132                let rp = right(p);
133
134                self.data[p] = Some(self.kth_value(rp, 0));
135
136                self._remove(rp, 0);
137            } else if lsize > 0 {
138                let lp = left(p);
139
140                self.data[p] = Some(self.kth_value(lp, lsize - 1));
141
142                self._remove(lp, lsize - 1);
143            } else {
144                self.data[p] = None;
145            }
146        }
147
148        self.update(p);
149    }
150
151    fn kth_value(
152        &self,
153        p: usize,
154        k: usize,
155    ) -> usize {
156        assert!(k < self.size[p]);
157
158        let lsize = self.left_size(p);
159
160        if k < lsize {
161            self.kth_value(left(p), k)
162        } else if k == lsize {
163            self.data[p].unwrap()
164        } else {
165            self.kth_value(right(p), k - lsize - 1)
166        }
167    }
168
169    fn binary_search<F>(
170        &self,
171        f: F,
172        p: usize,
173    ) -> usize
174    where
175        F: Fn(usize) -> bool,
176    {
177        let v = self.data[p];
178
179        if v.is_none() {
180            return 0;
181        }
182
183        let v = v.unwrap();
184
185        if f(v) {
186            if p & 1 == 1 {
187                0
188            } else {
189                self.binary_search(f, left(p))
190            }
191        } else {
192            let i = self.left_size(p) + 1;
193
194            i + if p & 1 == 1 { 0 } else { self.binary_search(f, right(p)) }
195        }
196    }
197
198    pub fn get(
199        &self,
200        i: usize,
201    ) -> usize {
202        self.kth_value(self.root_pivot(), i) - 1
203    }
204
205    pub fn lower_bound(
206        &self,
207        x: usize,
208    ) -> usize {
209        self.binary_search(|v| v >= x + 1, self.root_pivot())
210    }
211
212    pub fn upper_bound(
213        &self,
214        x: usize,
215    ) -> usize {
216        self.binary_search(|v| v > x + 1, self.root_pivot())
217    }
218
219    pub fn count(
220        &self,
221        x: usize,
222    ) -> usize {
223        self.upper_bound(x) - self.lower_bound(x)
224    }
225
226    pub fn contains(
227        &self,
228        x: usize,
229    ) -> bool {
230        self.count(x) > 0
231    }
232
233    pub fn insert(
234        &mut self,
235        mut x: usize,
236    ) {
237        assert!(x < (1 << self.max_height) - 1);
238
239        if self.contains(x) {
240            return;
241        }
242
243        x += 1;
244
245        self._insert(self.root_pivot(), x);
246    }
247
248    pub fn remove(
249        &mut self,
250        x: usize,
251    ) {
252        if !self.contains(x) {
253            return;
254        }
255
256        let i = self.lower_bound(x);
257
258        self._remove(self.root_pivot(), i);
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts three values, then checks size, rank lookup, membership and removal.
    #[test]
    fn test() {
        let h = 20;

        let mut s = PivotTreeSet::new(h);

        s.insert(1);
        assert_eq!(s.size(), 1);

        s.insert(0);
        assert_eq!(s.size(), 2);

        s.insert(1 << (h - 1));
        assert_eq!(s.size(), 3);

        assert_eq!(s.get(2), 1 << (h - 1));
        assert_eq!(s.get(1), 1);
        assert_eq!(s.get(0), 0);

        assert!(s.contains(0));

        s.remove(0);

        assert!(!s.contains(0));
    }

    /// Each case is `(l, queries)`. The set starts with `0` and `l`; a query
    /// `(1, x)` inserts `x`, and a query `(2, x)` expects the gap between the
    /// stored values surrounding `x` to equal the given answer.
    /// (Presumably taken from an AtCoder ABC217 problem, judging by the name.)
    #[test]
    fn test_abc217() {
        let cases = vec![
            (5, vec![((2, 2), 5), ((1, 3), 0), ((2, 2), 3)]),
            (5, vec![((1, 2), 0), ((1, 4), 0), ((2, 3), 2)]),
            (
                100,
                vec![
                    ((1, 31), 0),
                    ((2, 41), 69),
                    ((1, 59), 0),
                    ((2, 26), 31),
                    ((1, 53), 0),
                    ((2, 58), 6),
                    ((1, 97), 0),
                    ((2, 93), 38),
                    ((1, 23), 0),
                    ((2, 84), 38),
                ],
            ),
        ];

        for (l, q) in cases {
            let mut s = PivotTreeSet::new(20);

            s.insert(0);
            s.insert(l);

            for ((t, x), ans) in q {
                if t == 1 {
                    s.insert(x);
                } else {
                    // `i` is the rank of the first stored value >= x, so
                    // ranks `i - 1` and `i` bracket `x`.
                    let i = s.lower_bound(x);

                    assert_eq!(s.get(i) - s.get(i - 1), ans);
                }
            }
        }
    }

    /// A query `(1, x)` inserts `x`; a query `(2, x)` expects the `x`-th
    /// smallest value (1-indexed) to equal the given answer and then removes it.
    /// (Presumably taken from an AtCoder ARC033 problem, judging by the name.)
    #[test]
    fn test_arc033_3() {
        let cases = vec![
            vec![
                ((1, 11), 0),
                ((1, 29), 0),
                ((1, 89), 0),
                ((2, 2), 29),
                ((2, 2), 89),
            ],
            vec![
                ((1, 8932), 0),
                ((1, 183450), 0),
                ((1, 34323), 0),
                ((1, 81486), 0),
                ((1, 127874), 0),
                ((1, 114850), 0),
                ((1, 55277), 0),
                ((1, 112706), 0),
                ((2, 3), 55277),
                ((1, 39456), 0),
                ((1, 52403), 0),
                ((2, 4), 52403),
            ],
        ];

        for q in cases {
            let mut s = PivotTreeSet::new(18);

            for ((t, x), ans) in q {
                if t == 1 {
                    s.insert(x);
                } else {
                    let v = s.get(x - 1);

                    assert_eq!(v, ans);

                    s.remove(v);
                }
            }
        }
    }
}
387}