libpuri/
lazy_seg_tree.rs

1use super::algebra::{Act, Monoid};
2use super::util::IntoIndex;
3use std::iter::FromIterator;
4use std::ptr;
5// TODO(yujingaya) Rewrite tests/doctests to reuse MinMax, Add structs when cfg(doctest) is stable.
6// reference: https://github.com/rust-lang/rust/issues/67295
7
/// A segment tree that supports range query and range update.
///
/// # Use a lazy segment tree when:
/// - You want to efficiently query on an interval property.
/// - You want to efficiently update properties in an interval.
///
/// # Requirements
///
/// A lazy segment tree requires two monoids and an action of one on the other:
///
/// - `M`: A [`monoid`](crate::Monoid) that represents the interval property.
/// - `A`: A [`monoid`](crate::Monoid) that represents the interval update.
/// - `L`: An [`action`](crate::Act) that represents the application of an update on a property.
///
/// For a lazy segment tree to work, the action should also satisfy the following properties.
///
/// ```ignore
/// // An element `f` of `A` should be a homomorphism from `M` to `M`.
/// f(m * n) == f(m) * f(n)
/// f(M::ID) == M::ID
/// // Acting must agree with the monoid structure of `A`, because pending updates are
/// // composed with `A::op` before they are pushed down to the children.
/// (f * g)(m) == f(g(m))
/// A::ID(m) == m
/// ```
///
/// These properties cannot be checked by the compiler, so implementers must verify them
/// themselves.
///
/// # Examples
/// The following example supports two operations:
///
/// - Query minimum and maximum numbers within an interval.
/// - Add a number to each element within an interval.
///
/// ```
/// use libpuri::{Monoid, Act, LazySegTree};
///
/// #[derive(Clone, Debug, PartialEq, Eq)]
/// struct MinMax(i64, i64);
///
/// #[derive(Clone, Debug, PartialEq, Eq)]
/// struct Add(i64);
///
/// impl Monoid for MinMax {
///     const ID: Self = MinMax(i64::MAX, i64::MIN);
///     fn op(&self, rhs: &Self) -> Self {
///         MinMax(self.0.min(rhs.0), self.1.max(rhs.1))
///     }
/// }
///
/// impl Monoid for Add {
///     const ID: Self = Add(0);
///     fn op(&self, rhs: &Self) -> Self {
///         Add(self.0 + rhs.0)
///     }
/// }
///
/// impl Act<MinMax> for Add {
///     fn act(&self, m: &MinMax) -> MinMax {
///         if m == &MinMax::ID {
///             MinMax::ID
///         } else {
///             MinMax(m.0 + self.0, m.1 + self.0)
///         }
///     }
/// }
///
/// // Initialize with [0, 0, 0, 0, 0, 0]
/// let mut lazy_tree: LazySegTree<MinMax, Add> = (0..6).map(|_| MinMax(0, 0)).collect();
/// assert_eq!(lazy_tree.get(..), MinMax(0, 0));
///
/// // Range update [5, 5, 5, 5, 0, 0]
/// lazy_tree.act(0..4, Add(5));
///
/// // Another range update [5, 5, 47, 47, 42, 42]
/// lazy_tree.act(2..6, Add(42));
///
/// assert_eq!(lazy_tree.get(1..3), MinMax(5,  47));
/// assert_eq!(lazy_tree.get(3..5), MinMax(42, 47));
///
/// // Set index 3 to 0 [5, 5, 47, 0, 42, 42]
/// lazy_tree.set(3, MinMax(0, 0));
///
/// assert_eq!(lazy_tree.get(..), MinMax(0,  47));
/// assert_eq!(lazy_tree.get(3..5), MinMax(0,  42));
/// assert_eq!(lazy_tree.get(0), MinMax(5, 5));
/// ```
///
// Layout: the vector holds a complete binary tree stored 1-indexed. Node `i` has children
// `2 * i` and `2 * i + 1`, and the leaves occupy the second half `[len / 2, len)`. Slot 0 is not
// part of the tree.
//
// Each node is `(value, pending)`. `pending` is an action that is already committed to the
// node's own `value` but not yet committed to its children.
// Could be refactored into `Vec<(M, Option<A>)>`.
//
// Trait bounds live on the `impl` blocks rather than on the type itself.
#[derive(Clone, Debug)]
pub struct LazySegTree<M, A>(Vec<(M, A)>);
95
96impl<M: Monoid + Clone, A: Monoid + Act<M> + Clone> LazySegTree<M, A> {
97    fn size(&self) -> usize {
98        self.0.len() / 2
99    }
100
101    fn height(&self) -> u32 {
102        self.0.len().trailing_zeros()
103    }
104
105    fn act_lazy(node: &mut (M, A), a: &A) {
106        *node = (a.act(&node.0), a.op(&node.1));
107    }
108
109    fn propagate_to_children(&mut self, i: usize) {
110        let a = self.0[i].1.clone();
111        Self::act_lazy(&mut self.0[i * 2], &a);
112        Self::act_lazy(&mut self.0[i * 2 + 1], &a);
113
114        self.0[i].1 = A::ID;
115    }
116
117    fn propagate(&mut self, start: usize, end: usize) {
118        for i in (1..self.height()).rev() {
119            if (start >> i) << i != start {
120                self.propagate_to_children(start >> i);
121            }
122            if (end >> i) << i != end {
123                self.propagate_to_children((end - 1) >> i);
124            }
125        }
126    }
127
128    fn update(&mut self, i: usize) {
129        self.0[i].0 = self.0[i * 2].0.op(&self.0[i * 2 + 1].0);
130    }
131}
132
133impl<M: Monoid + Clone, A: Monoid + Act<M> + Clone> LazySegTree<M, A> {
134    /// Constructs a new lazy segment tree with at least given number of intervals can be stored.
135    ///
136    /// The segment tree will be initialized with the identity elements.
137    ///
138    /// # Complexity
139    /// O(n).
140    ///
141    /// If you know the initial elements in advance,
142    /// [`from_iter_sized()`](LazySegTree::from_iter_sized) should be preferred over `new()`.
143    ///
144    /// Initializing with the identity elements and updating n elements will tax you O(nlog(n)),
145    /// whereas `from_iter_sized()` is O(n) by computing the interval properties only once.
146    ///
147    /// # Examples
148    /// ```
149    /// # use libpuri::{LazySegTree, Monoid, Act};
150    /// #
151    /// # #[derive(Clone, Debug, PartialEq, Eq)]
152    /// # struct MinMax(i64, i64);
153    /// # impl Monoid for MinMax {
154    /// #     const ID: Self = MinMax(i64::MAX, i64::MIN);
155    /// #     fn op(&self, rhs: &Self) -> Self {
156    /// #         MinMax(self.0.min(rhs.0), self.1.max(rhs.1))
157    /// #     }
158    /// # }
159    /// #
160    /// # #[derive(Clone, Debug, PartialEq, Eq)]
161    /// # struct Add(i64);
162    /// # impl Monoid for Add {
163    /// #     const ID: Self = Add(0);
164    /// #     fn op(&self, rhs: &Self) -> Self {
165    /// #         Add(self.0.saturating_add(rhs.0))
166    /// #     }
167    /// # }
168    /// #
169    /// # impl Act<MinMax> for Add {
170    /// #     fn act(&self, m: &MinMax) -> MinMax {
171    /// #         if m == &MinMax::ID {
172    /// #             MinMax::ID
173    /// #         } else {
174    /// #             MinMax(m.0 + self.0, m.1 + self.0)
175    /// #         }
176    /// #     }
177    /// # }
178    /// let mut lazy_tree: LazySegTree<MinMax, Add> = LazySegTree::new(5);
179    ///
180    /// // Initialized with [id, id, id, id, id]
181    /// assert_eq!(lazy_tree.get(..), MinMax::ID);
182    /// ```
183    pub fn new(size: usize) -> Self {
184        LazySegTree(vec![(M::ID, A::ID); size.next_power_of_two() * 2])
185    }
186
187    /// Constructs a new lazy segment tree with given intervals properties.
188    ///
189    /// # Complexity
190    /// O(n).
191    ///
192    /// # Examples
193    /// ```
194    /// # use libpuri::{LazySegTree, Monoid, Act};
195    /// #
196    /// # #[derive(Clone, Debug, PartialEq, Eq)]
197    /// # struct MinMax(i64, i64);
198    /// # impl Monoid for MinMax {
199    /// #     const ID: Self = MinMax(i64::MAX, i64::MIN);
200    /// #     fn op(&self, rhs: &Self) -> Self {
201    /// #         MinMax(self.0.min(rhs.0), self.1.max(rhs.1))
202    /// #     }
203    /// # }
204    /// #
205    /// # #[derive(Clone, Debug, PartialEq, Eq)]
206    /// # struct Add(i64);
207    /// # impl Monoid for Add {
208    /// #     const ID: Self = Add(0);
209    /// #     fn op(&self, rhs: &Self) -> Self {
210    /// #         Add(self.0.saturating_add(rhs.0))
211    /// #     }
212    /// # }
213    /// #
214    /// # impl Act<MinMax> for Add {
215    /// #     fn act(&self, m: &MinMax) -> MinMax {
216    /// #         if m == &MinMax::ID {
217    /// #             MinMax::ID
218    /// #         } else {
219    /// #             MinMax(m.0 + self.0, m.1 + self.0)
220    /// #         }
221    /// #     }
222    /// # }
223    /// let v = [0, 42, 17, 6, -11].iter().map(|&i| MinMax(i, i));
224    /// let mut lazy_tree: LazySegTree<MinMax, Add> = LazySegTree::from_iter_sized(v, 5);
225    ///
226    /// // Initialized with [0, 42, 17, 6, -11]
227    /// assert_eq!(lazy_tree.get(..), MinMax(-11, 42));
228    /// ```
229    pub fn from_iter_sized<I: IntoIterator<Item = M>>(iter: I, size: usize) -> Self {
230        let mut iter = iter.into_iter();
231        let size = size.next_power_of_two();
232        let mut v = Vec::with_capacity(size * 2);
233
234        let v_ptr: *mut (M, A) = v.as_mut_ptr();
235
236        unsafe {
237            v.set_len(size * 2);
238
239            for i in 0..size {
240                ptr::write(
241                    v_ptr.add(size + i),
242                    if let Some(m) = iter.next() {
243                        (m, A::ID)
244                    } else {
245                        (M::ID, A::ID)
246                    },
247                );
248            }
249
250            for i in (1..size).rev() {
251                ptr::write(v_ptr.add(i), (v[i * 2].0.op(&v[i * 2 + 1].0), A::ID));
252            }
253        }
254
255        LazySegTree(v)
256    }
257
258    /// Queries on the given interval.
259    ///
260    /// Note that any [`RangeBounds`](std::ops::RangeBounds) can be used including
261    /// `..`, `a..`, `..b`, `..=c`, `d..e`, or `f..=g`.
262    /// You can just `seg_tree.get(..)` to get the interval property of the entire elements and
263    /// `lazy_tree.get(a)` to get a specific element.
264    /// # Examples
265    /// ```
266    /// # use libpuri::{LazySegTree, Monoid, Act};
267    /// #
268    /// # #[derive(Clone, Debug, PartialEq, Eq)]
269    /// # struct MinMax(i64, i64);
270    /// # impl Monoid for MinMax {
271    /// #     const ID: Self = MinMax(i64::MAX, i64::MIN);
272    /// #     fn op(&self, rhs: &Self) -> Self {
273    /// #         MinMax(self.0.min(rhs.0), self.1.max(rhs.1))
274    /// #     }
275    /// # }
276    /// #
277    /// # #[derive(Clone, Debug, PartialEq, Eq)]
278    /// # struct Add(i64);
279    /// # impl Monoid for Add {
280    /// #     const ID: Self = Add(0);
281    /// #     fn op(&self, rhs: &Self) -> Self {
282    /// #         Add(self.0.saturating_add(rhs.0))
283    /// #     }
284    /// # }
285    /// #
286    /// # impl Act<MinMax> for Add {
287    /// #     fn act(&self, m: &MinMax) -> MinMax {
288    /// #         if m == &MinMax::ID {
289    /// #             MinMax::ID
290    /// #         } else {
291    /// #             MinMax(m.0 + self.0, m.1 + self.0)
292    /// #         }
293    /// #     }
294    /// # }
295    /// // [0, 42, 6, 7, 2]
296    /// let mut lazy_tree: LazySegTree<MinMax, Add> = [0, 42, 6, 7, 2].iter()
297    ///     .map(|&n| MinMax(n, n))
298    ///     .collect();
299    ///
300    /// assert_eq!(lazy_tree.get(..), MinMax(0, 42));
301    ///
302    /// // [5, 47, 11, 7, 2]
303    /// lazy_tree.act(0..3, Add(5));
304    ///
305    /// // [5, 47, 4, 0, -5]
306    /// lazy_tree.act(2..5, Add(-7));
307    ///
308    /// assert_eq!(lazy_tree.get(..), MinMax(-5, 47));
309    /// assert_eq!(lazy_tree.get(..4), MinMax(0, 47));
310    /// assert_eq!(lazy_tree.get(2), MinMax(4, 4));
311    /// ```
312    pub fn get<R>(&mut self, range: R) -> M
313    where
314        R: IntoIndex,
315    {
316        let (mut start, mut end) = range.into_index(self.size());
317        start += self.size();
318        end += self.size();
319
320        self.propagate(start, end);
321
322        let mut m = M::ID;
323
324        while start < end {
325            if start % 2 == 1 {
326                m = self.0[start].0.op(&m);
327                start += 1;
328            }
329
330            if end % 2 == 1 {
331                end -= 1;
332                m = self.0[end].0.op(&m);
333            }
334
335            start /= 2;
336            end /= 2;
337        }
338
339        m
340    }
341
342    /// Sets an element with given index to the value. It propagates its update to its ancestors.
343    ///
344    /// It takes O(log(n)) to propagate the update as the height of the tree is log(n).
345    ///
346    /// # Examples
347    /// ```
348    /// # use libpuri::{LazySegTree, Monoid, Act};
349    /// #
350    /// # #[derive(Clone, Debug, PartialEq, Eq)]
351    /// # struct MinMax(i64, i64);
352    /// # impl Monoid for MinMax {
353    /// #     const ID: Self = MinMax(i64::MAX, i64::MIN);
354    /// #     fn op(&self, rhs: &Self) -> Self {
355    /// #         MinMax(self.0.min(rhs.0), self.1.max(rhs.1))
356    /// #     }
357    /// # }
358    /// #
359    /// # #[derive(Clone, Debug, PartialEq, Eq)]
360    /// # struct Add(i64);
361    /// # impl Monoid for Add {
362    /// #     const ID: Self = Add(0);
363    /// #     fn op(&self, rhs: &Self) -> Self {
364    /// #         Add(self.0.saturating_add(rhs.0))
365    /// #     }
366    /// # }
367    /// #
368    /// # impl Act<MinMax> for Add {
369    /// #     fn act(&self, m: &MinMax) -> MinMax {
370    /// #         if m == &MinMax::ID {
371    /// #             MinMax::ID
372    /// #         } else {
373    /// #             MinMax(m.0 + self.0, m.1 + self.0)
374    /// #         }
375    /// #     }
376    /// # }
377    /// // [0, 42, 6, 7, 2]
378    /// let mut lazy_tree: LazySegTree<MinMax, Add> = [0, 42, 6, 7, 2].iter()
379    ///     .map(|&n| MinMax(n, n))
380    ///     .collect();
381    ///
382    /// assert_eq!(lazy_tree.get(..), MinMax(0, 42));
383    ///
384    /// // [0, 1, 6, 7, 2]
385    /// lazy_tree.set(1, MinMax(1, 1));
386    ///
387    /// assert_eq!(lazy_tree.get(1), MinMax(1, 1));
388    /// assert_eq!(lazy_tree.get(..), MinMax(0, 7));
389    /// assert_eq!(lazy_tree.get(2..), MinMax(2, 7));
390    /// ```
391    pub fn set(&mut self, i: usize, m: M) {
392        let i = i + self.size();
393
394        for h in (1..=self.height()).rev() {
395            self.propagate_to_children(i >> h);
396        }
397
398        self.0[i] = (m, A::ID);
399
400        for h in 1..=self.height() {
401            self.update(i >> h);
402        }
403    }
404
405    /// Apply an action to elements within given range.
406    ///
407    /// It takes O(log(n)).
408    ///
409    /// # Examples
410    /// ```
411    /// # use libpuri::{LazySegTree, Monoid, Act};
412    /// #
413    /// # #[derive(Clone, Debug, PartialEq, Eq)]
414    /// # struct MinMax(i64, i64);
415    /// # impl Monoid for MinMax {
416    /// #     const ID: Self = MinMax(i64::MAX, i64::MIN);
417    /// #     fn op(&self, rhs: &Self) -> Self {
418    /// #         MinMax(self.0.min(rhs.0), self.1.max(rhs.1))
419    /// #     }
420    /// # }
421    /// #
422    /// # #[derive(Clone, Debug, PartialEq, Eq)]
423    /// # struct Add(i64);
424    /// # impl Monoid for Add {
425    /// #     const ID: Self = Add(0);
426    /// #     fn op(&self, rhs: &Self) -> Self {
427    /// #         Add(self.0.saturating_add(rhs.0))
428    /// #     }
429    /// # }
430    /// #
431    /// # impl Act<MinMax> for Add {
432    /// #     fn act(&self, m: &MinMax) -> MinMax {
433    /// #         if m == &MinMax::ID {
434    /// #             MinMax::ID
435    /// #         } else {
436    /// #             MinMax(m.0 + self.0, m.1 + self.0)
437    /// #         }
438    /// #     }
439    /// # }
440    /// // [0, 42, 6, 7, 2]
441    /// let mut lazy_tree: LazySegTree<MinMax, Add> = [0, 42, 6, 7, 2].iter()
442    ///     .map(|&n| MinMax(n, n))
443    ///     .collect();
444    ///
445    /// assert_eq!(lazy_tree.get(..), MinMax(0, 42));
446    ///
447    /// // [0, 30, -6, 7, 2]
448    /// lazy_tree.act(1..3, Add(-12));
449    ///
450    /// assert_eq!(lazy_tree.get(1), MinMax(30, 30));
451    /// assert_eq!(lazy_tree.get(..), MinMax(-6, 30));
452    /// assert_eq!(lazy_tree.get(2..), MinMax(-6, 7));
453    /// ```
454    pub fn act<R>(&mut self, range: R, a: A)
455    where
456        R: IntoIndex,
457    {
458        let (mut start, mut end) = range.into_index(self.size());
459        start += self.size();
460        end += self.size();
461
462        self.propagate(start, end);
463
464        {
465            let mut start = start;
466            let mut end = end;
467
468            while start < end {
469                if start % 2 == 1 {
470                    Self::act_lazy(&mut self.0[start], &a);
471                    start += 1;
472                }
473                if end % 2 == 1 {
474                    end -= 1;
475                    Self::act_lazy(&mut self.0[end], &a);
476                }
477
478                start /= 2;
479                end /= 2;
480            }
481        }
482
483        for i in 1..=self.height() {
484            if (start >> i) << i != start {
485                self.update(start >> i);
486            }
487            if (end >> i) << i != end {
488                self.update((end - 1) >> i);
489            }
490        }
491    }
492}
493
494// TODO(yujingaya) New way to print without Eq
495// impl<M, A> Debug for LazySegTree<M, A>
496// where
497//     M: Debug + Monoid,
498//     A: Debug + Act<M>,
499// {
500//     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501//         let mut tree = "LazySegTree\n".to_owned();
502//         for h in 0..self.height() {
503//             for i in 1 << h..1 << (h + 1) {
504//                 tree.push_str(&if self.0[i].0 == M::ID {
505//                     "(id, ".to_owned()
506//                 } else {
507//                     format!("({:?}, ", self.0[i].0)
508//                 });
509//                 tree.push_str(&if self.0[i].1 == A::ID {
510//                     "id) ".to_owned()
511//                 } else {
512//                     format!("{:?}) ", self.0[i].1)
513//                 });
514//             }
515//             tree.pop();
516//             tree.push('\n');
517//         }
518
519//         f.write_str(&tree)
520//     }
521// }
522
523/// You can `collect` into a lazy segment tree.
524impl<M, A> FromIterator<M> for LazySegTree<M, A>
525where
526    M: Monoid + Clone,
527    A: Monoid + Act<M> + Clone,
528{
529    fn from_iter<I: IntoIterator<Item = M>>(iter: I) -> Self {
530        let v: Vec<M> = iter.into_iter().collect();
531        let len = v.len();
532
533        LazySegTree::from_iter_sized(v, len)
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Act;

    /// Interval property: the minimum and the maximum of an interval.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct MinMax(i64, i64);

    impl Monoid for MinMax {
        const ID: Self = MinMax(i64::MAX, i64::MIN);
        fn op(&self, rhs: &Self) -> Self {
            let (lowest, highest) = (self.0.min(rhs.0), self.1.max(rhs.1));
            MinMax(lowest, highest)
        }
    }

    /// Interval update: add a constant to every element.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct Add(i64);

    impl Monoid for Add {
        const ID: Self = Add(0);
        fn op(&self, rhs: &Self) -> Self {
            Add(self.0.saturating_add(rhs.0))
        }
    }

    impl Act<MinMax> for Add {
        fn act(&self, m: &MinMax) -> MinMax {
            // The identity is left untouched so unset elements stay neutral after an update.
            if *m == MinMax::ID {
                return MinMax::ID;
            }
            MinMax(m.0 + self.0, m.1 + self.0)
        }
    }

    #[test]
    fn min_max_and_range_add() {
        // [id, id, id, id, id, id, id, id]
        let mut tree: LazySegTree<MinMax, Add> = LazySegTree::new(8);
        assert_eq!(tree.get(..), MinMax::ID);

        // [0,  0,  0,  0,  0,  0, id, id]
        (0..6).for_each(|idx| tree.set(idx, MinMax(0, 0)));

        // [5,  5,  5,  5,  0,  0, id, id]
        tree.act(0..=3, Add(5));
        // [5,  5, 47, 47, 42, 42, id, id]
        tree.act(2..=5, Add(42));

        assert_eq!(tree.get(0..=1), MinMax(5, 5));
        assert_eq!(tree.get(1..=2), MinMax(5, 47));
        assert_eq!(tree.get(2..=3), MinMax(47, 47));
        assert_eq!(tree.get(3..=4), MinMax(42, 47));
        assert_eq!(tree.get(4..=5), MinMax(42, 42));
        assert_eq!(tree.get(5..=6), MinMax(42, 42));
        assert_eq!(tree.get(0..=5), MinMax(5, 47));
        assert_eq!(tree.get(6..=7), MinMax::ID);
        assert_eq!(tree.get(5), MinMax(42, 42));
    }

    #[test]
    fn many_intervals() {
        let mut tree: LazySegTree<MinMax, Add> = LazySegTree::new(88);
        (0..88).for_each(|idx| tree.set(idx, MinMax(0, 0)));

        tree.act(0..20, Add(5));
        tree.act(20..40, Add(42));
        tree.act(40..60, Add(-5));
        tree.act(60..88, Add(17));
        tree.act(10..70, Add(1));

        assert_eq!(tree.get(..), MinMax(-4, 43));
        assert_eq!(tree.get(0..20), MinMax(5, 6));
        assert_eq!(tree.get(70..88), MinMax(17, 17));
        assert_eq!(tree.get(40), MinMax(-4, -4));
    }
}