ac_library/
segtree.rs

1use crate::internal_bit::ceil_pow2;
2use crate::internal_type_traits::{BoundedAbove, BoundedBelow, One, Zero};
3use std::cmp::{max, min};
4use std::convert::Infallible;
5use std::iter::FromIterator;
6use std::marker::PhantomData;
7use std::ops::{Add, BitAnd, BitOr, BitXor, Bound, Mul, Not, RangeBounds};
8
// TODO: Should the monoid-related traits be split into a separate module?
/// Type-level description of a monoid: a value type `S`, a neutral element and a binary
/// operation on `S`.
///
/// `Segtree` combines stored values in whatever grouping its tree layout dictates, so
/// implementations are expected to make `binary_operation` associative and `identity()`
/// neutral on both sides. The trait itself cannot enforce these laws.
pub trait Monoid {
    /// Value type the operation acts on.
    type S: Clone;

    /// Returns the neutral element of the operation.
    fn identity() -> Self::S;

    /// Combines `lhs` and `rhs`, in that order, into a new value.
    fn binary_operation(lhs: &Self::S, rhs: &Self::S) -> Self::S;
}
15
16pub struct Max<S>(Infallible, PhantomData<fn() -> S>);
17impl<S> Monoid for Max<S>
18where
19    S: Copy + Ord + BoundedBelow,
20{
21    type S = S;
22    fn identity() -> Self::S {
23        S::min_value()
24    }
25    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
26        max(*a, *b)
27    }
28}
29
30pub struct Min<S>(Infallible, PhantomData<fn() -> S>);
31impl<S> Monoid for Min<S>
32where
33    S: Copy + Ord + BoundedAbove,
34{
35    type S = S;
36    fn identity() -> Self::S {
37        S::max_value()
38    }
39    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
40        min(*a, *b)
41    }
42}
43
44pub struct Additive<S>(Infallible, PhantomData<fn() -> S>);
45impl<S> Monoid for Additive<S>
46where
47    S: Copy + Add<Output = S> + Zero,
48{
49    type S = S;
50    fn identity() -> Self::S {
51        S::zero()
52    }
53    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
54        *a + *b
55    }
56}
57
58pub struct Multiplicative<S>(Infallible, PhantomData<fn() -> S>);
59impl<S> Monoid for Multiplicative<S>
60where
61    S: Copy + Mul<Output = S> + One,
62{
63    type S = S;
64    fn identity() -> Self::S {
65        S::one()
66    }
67    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
68        *a * *b
69    }
70}
71
72pub struct BitwiseOr<S>(Infallible, PhantomData<fn() -> S>);
73impl<S> Monoid for BitwiseOr<S>
74where
75    S: Copy + BitOr<Output = S> + Zero,
76{
77    type S = S;
78    fn identity() -> Self::S {
79        S::zero()
80    }
81    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
82        *a | *b
83    }
84}
85
86pub struct BitwiseAnd<S>(Infallible, PhantomData<fn() -> S>);
87impl<S> Monoid for BitwiseAnd<S>
88where
89    S: Copy + BitAnd<Output = S> + Not<Output = S> + Zero,
90{
91    type S = S;
92    fn identity() -> Self::S {
93        !S::zero()
94    }
95    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
96        *a & *b
97    }
98}
99
100pub struct BitwiseXor<S>(Infallible, PhantomData<fn() -> S>);
101impl<S> Monoid for BitwiseXor<S>
102where
103    S: Copy + BitXor<Output = S> + Zero,
104{
105    type S = S;
106    fn identity() -> Self::S {
107        S::zero()
108    }
109    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
110        *a ^ *b
111    }
112}
113
114impl<M: Monoid> Default for Segtree<M> {
115    fn default() -> Self {
116        Segtree::new(0)
117    }
118}
119impl<M: Monoid> Segtree<M> {
120    pub fn new(n: usize) -> Segtree<M> {
121        vec![M::identity(); n].into()
122    }
123}
124impl<M: Monoid> From<Vec<M::S>> for Segtree<M> {
125    fn from(v: Vec<M::S>) -> Self {
126        let n = v.len();
127        let log = ceil_pow2(n as u32) as usize;
128        let size = 1 << log;
129        let mut d = vec![M::identity(); 2 * size];
130        d[size..][..n].clone_from_slice(&v);
131        let mut ret = Segtree { n, size, log, d };
132        for i in (1..size).rev() {
133            ret.update(i);
134        }
135        ret
136    }
137}
138impl<M: Monoid> FromIterator<M::S> for Segtree<M> {
139    fn from_iter<T: IntoIterator<Item = M::S>>(iter: T) -> Self {
140        let v = iter.into_iter().collect::<Vec<_>>();
141        v.into()
142    }
143}
144impl<M: Monoid> Segtree<M> {
145    pub fn set(&mut self, mut p: usize, x: M::S) {
146        assert!(p < self.n);
147        p += self.size;
148        self.d[p] = x;
149        for i in 1..=self.log {
150            self.update(p >> i);
151        }
152    }
153
154    pub fn get(&self, p: usize) -> M::S {
155        assert!(p < self.n);
156        self.d[p + self.size].clone()
157    }
158
159    pub fn get_slice(&self) -> &[M::S] {
160        &self.d[self.size..][..self.n]
161    }
162
163    pub fn prod<R>(&self, range: R) -> M::S
164    where
165        R: RangeBounds<usize>,
166    {
167        // Trivial optimization
168        if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
169            return self.all_prod();
170        }
171
172        let mut r = match range.end_bound() {
173            Bound::Included(r) => r + 1,
174            Bound::Excluded(r) => *r,
175            Bound::Unbounded => self.n,
176        };
177        let mut l = match range.start_bound() {
178            Bound::Included(l) => *l,
179            Bound::Excluded(l) => l + 1,
180            // TODO: There are another way of optimizing [0..r)
181            Bound::Unbounded => 0,
182        };
183
184        assert!(l <= r && r <= self.n);
185        let mut sml = M::identity();
186        let mut smr = M::identity();
187        l += self.size;
188        r += self.size;
189
190        while l < r {
191            if l & 1 != 0 {
192                sml = M::binary_operation(&sml, &self.d[l]);
193                l += 1;
194            }
195            if r & 1 != 0 {
196                r -= 1;
197                smr = M::binary_operation(&self.d[r], &smr);
198            }
199            l >>= 1;
200            r >>= 1;
201        }
202
203        M::binary_operation(&sml, &smr)
204    }
205
206    pub fn all_prod(&self) -> M::S {
207        self.d[1].clone()
208    }
209
210    pub fn max_right<F>(&self, mut l: usize, f: F) -> usize
211    where
212        F: Fn(&M::S) -> bool,
213    {
214        assert!(l <= self.n);
215        assert!(f(&M::identity()));
216        if l == self.n {
217            return self.n;
218        }
219        l += self.size;
220        let mut sm = M::identity();
221        while {
222            // do
223            while l % 2 == 0 {
224                l >>= 1;
225            }
226            if !f(&M::binary_operation(&sm, &self.d[l])) {
227                while l < self.size {
228                    l *= 2;
229                    let res = M::binary_operation(&sm, &self.d[l]);
230                    if f(&res) {
231                        sm = res;
232                        l += 1;
233                    }
234                }
235                return l - self.size;
236            }
237            sm = M::binary_operation(&sm, &self.d[l]);
238            l += 1;
239            // while
240            {
241                let l = l as isize;
242                (l & -l) != l
243            }
244        } {}
245        self.n
246    }
247
248    pub fn min_left<F>(&self, mut r: usize, f: F) -> usize
249    where
250        F: Fn(&M::S) -> bool,
251    {
252        assert!(r <= self.n);
253        assert!(f(&M::identity()));
254        if r == 0 {
255            return 0;
256        }
257        r += self.size;
258        let mut sm = M::identity();
259        while {
260            // do
261            r -= 1;
262            while r > 1 && r % 2 == 1 {
263                r >>= 1;
264            }
265            if !f(&M::binary_operation(&self.d[r], &sm)) {
266                while r < self.size {
267                    r = 2 * r + 1;
268                    let res = M::binary_operation(&self.d[r], &sm);
269                    if f(&res) {
270                        sm = res;
271                        r -= 1;
272                    }
273                }
274                return r + 1 - self.size;
275            }
276            sm = M::binary_operation(&self.d[r], &sm);
277            // while
278            {
279                let r = r as isize;
280                (r & -r) != r
281            }
282        } {}
283        0
284    }
285
286    fn update(&mut self, k: usize) {
287        self.d[k] = M::binary_operation(&self.d[2 * k], &self.d[2 * k + 1]);
288    }
289}
290
291// Maybe we can use this someday
292// ```
293// for i in 0..=self.log {
294//     for j in 0..1 << i {
295//         print!("{}\t", self.d[(1 << i) + j]);
296//     }
297//     println!();
298// }
299// ```
300
301#[derive(Clone)]
302pub struct Segtree<M>
303where
304    M: Monoid,
305{
306    // variable name is _n in original library
307    n: usize,
308    size: usize,
309    log: usize,
310    d: Vec<M::S>,
311}
312
#[cfg(test)]
mod tests {
    use crate::segtree::Max;
    use crate::Segtree;
    use std::ops::{Bound::*, RangeBounds};

    /// Builds a max-tree from a vector and element by element, then overwrites one entry,
    /// checking the tree against a plain array after every step.
    #[test]
    fn test_max_segtree() {
        let base = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
        let n = base.len();
        let segtree: Segtree<Max<_>> = base.clone().into();
        check_segtree(&base, &segtree);

        let mut segtree = Segtree::<Max<_>>::new(n);
        let mut internal = vec![i32::MIN; n];
        for i in 0..n {
            segtree.set(i, base[i]);
            internal[i] = base[i];
            check_segtree(&internal, &segtree);
        }

        segtree.set(6, 5);
        internal[6] = 5;
        check_segtree(&internal, &segtree);

        segtree.set(6, 0);
        internal[6] = 0;
        check_segtree(&internal, &segtree);
    }

    /// Collecting an iterator must give the same tree as converting the collected `Vec`.
    #[test]
    fn test_segtree_fromiter() {
        let v = [1, 4, 1, 4, 2, 1, 3, 5, 6];
        let base = v
            .iter()
            .copied()
            .filter(|&x| x % 2 == 0)
            .collect::<Vec<_>>();
        let segtree: Segtree<Max<_>> = v.iter().copied().filter(|&x| x % 2 == 0).collect();
        check_segtree(&base, &segtree);
    }

    /// An empty tree must report the monoid identity everywhere.
    #[test]
    fn test_empty_segtree() {
        let base: Vec<i32> = Vec::new();
        let segtree: Segtree<Max<_>> = base.clone().into();
        check_segtree(&base, &segtree);
        check_segtree(&base, &Segtree::<Max<i32>>::new(0));
    }

    /// Compares every query of `segtree` against a brute-force answer computed from `base`.
    //noinspection DuplicatedCode
    fn check_segtree(base: &[i32], segtree: &Segtree<Max<i32>>) {
        let n = base.len();
        #[allow(clippy::needless_range_loop)]
        for i in 0..n {
            assert_eq!(segtree.get(i), base[i]);
        }

        check(base, segtree, ..);
        for i in 0..=n {
            check(base, segtree, ..i);
            check(base, segtree, i..);
            if i < n {
                check(base, segtree, ..=i);
            }
            for j in i..=n {
                check(base, segtree, i..j);
                if j < n {
                    check(base, segtree, i..=j);
                    check(base, segtree, (Excluded(i), Included(j)));
                }
            }
        }
        // The identity of `Max<i32>` is `i32::MIN`, so that is what an empty tree reports.
        assert_eq!(
            segtree.all_prod(),
            base.iter().max().copied().unwrap_or(i32::MIN)
        );
        for k in 0..=10 {
            let f = |&x: &i32| x < k;
            for i in 0..=n {
                assert_eq!(
                    Some(segtree.max_right(i, f)),
                    (i..=n)
                        .filter(|&j| f(&base[i..j].iter().max().copied().unwrap_or(i32::MIN)))
                        .max()
                );
            }
            for j in 0..=n {
                assert_eq!(
                    Some(segtree.min_left(j, f)),
                    (0..=j)
                        .filter(|&i| f(&base[i..j].iter().max().copied().unwrap_or(i32::MIN)))
                        .min()
                );
            }
        }
    }

    /// Asserts that `segtree.prod(range)` equals the maximum of `base` over `range`.
    fn check(base: &[i32], segtree: &Segtree<Max<i32>>, range: impl RangeBounds<usize>) {
        let expected = base
            .iter()
            .enumerate()
            .filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
            .max()
            .copied()
            .unwrap_or(i32::MIN);
        assert_eq!(segtree.prod(range), expected);
    }
}