competitive_programming_rs/data_structure/
lazy_segment_tree.rs

pub mod lazy_segment_tree {
    /// Half-open index interval `[start, end)` accepted by the range operations.
    type Range = std::ops::Range<usize>;

    /// Segment tree with lazy propagation over a user-supplied monoid and action.
    ///
    /// The tree stores `n` values of type `S` and pending updates of type `F`.
    /// All algebra is supplied as closures:
    ///
    /// * `op(a, b)` combines two adjacent aggregates (left operand first);
    /// * `e()` produces the identity element of `op`;
    /// * `mapping(f, x)` applies the update `f` to the aggregate `x`;
    /// * `composition(f, g)` produces the update equivalent to applying `g`
    ///   first and `f` afterwards (the newer update is always passed as `f`);
    /// * `id()` produces the update that leaves every aggregate unchanged.
    ///
    /// `mapping` is also applied to aggregates of whole subtrees, so results are
    /// only meaningful when `mapping(f, op(a, b))` equals
    /// `op(mapping(f, a), mapping(f, b))`.
    pub struct LazySegmentTree<S, Op, E, F, Mapping, Composition, Id> {
        /// Number of logical elements.
        n: usize,
        /// Number of leaves: `n` rounded up to a power of two.
        size: usize,
        /// Height of the tree, i.e. `size == 1 << log`.
        log: usize,
        /// Node aggregates in heap layout: root at 1, leaves at `size..2 * size`.
        data: Vec<S>,
        /// Pending updates of the internal nodes (indices `1..size`).
        lazy: Vec<F>,
        op: Op,
        e: E,
        mapping: Mapping,
        composition: Composition,
        id: Id,
    }

    impl<S, Op, E, F, Mapping, Composition, Id> LazySegmentTree<S, Op, E, F, Mapping, Composition, Id>
    where
        S: Clone,
        E: Fn() -> S,
        F: Clone,
        Op: Fn(&S, &S) -> S,
        Mapping: Fn(&F, &S) -> S,
        Composition: Fn(&F, &F) -> F,
        Id: Fn() -> F,
    {
        /// Creates a tree of `n` elements, every one initialised to `e()`.
        ///
        /// See the type-level documentation for the meaning of each closure.
        pub fn new(
            n: usize,
            e: E,
            op: Op,
            mapping: Mapping,
            composition: Composition,
            id: Id,
        ) -> Self {
            let size = n.next_power_of_two();
            let log = size.trailing_zeros() as usize;
            // Build the storage before the closures are moved into the struct.
            let data = vec![e(); 2 * size];
            let lazy = vec![id(); size];
            LazySegmentTree {
                n,
                size,
                log,
                data,
                lazy,
                e,
                op,
                mapping,
                composition,
                id,
            }
        }

        /// Overwrites the element at `index` with `value`.
        ///
        /// # Panics
        ///
        /// Panics if `index >= n`.
        pub fn set(&mut self, index: usize, value: S) {
            assert!(index < self.n);
            let leaf = index + self.size;
            self.push_to_leaf(leaf);
            self.data[leaf] = value;
            self.update_from_leaf(leaf);
        }

        /// Returns a clone of the element at `index`, after flushing the pending
        /// updates on its root-to-leaf path.
        ///
        /// # Panics
        ///
        /// Panics if `index >= n`.
        pub fn get(&mut self, index: usize) -> S {
            assert!(index < self.n);
            let leaf = index + self.size;
            self.push_to_leaf(leaf);
            self.data[leaf].clone()
        }

        /// Returns `op` folded over the elements in `range`, in index order.
        ///
        /// An empty range (`start == end`) yields the identity `e()`, mirroring
        /// `apply_range`, which treats an empty range as a no-op.
        ///
        /// # Panics
        ///
        /// Panics if `range.start > range.end` or `range.end > n`.
        pub fn prod(&mut self, range: Range) -> S {
            let Range { start, end } = range;
            assert!(start <= end && end <= self.n);
            if start == end {
                return (self.e)();
            }

            let mut l = start + self.size;
            let mut r = end + self.size;
            self.push_boundaries(l, r);

            // Keep separate left and right accumulators so that operand order is
            // preserved for non-commutative `op`.
            let mut sum_l = (self.e)();
            let mut sum_r = (self.e)();
            while l < r {
                if l & 1 != 0 {
                    sum_l = (self.op)(&sum_l, &self.data[l]);
                    l += 1;
                }
                if r & 1 != 0 {
                    r -= 1;
                    sum_r = (self.op)(&self.data[r], &sum_r);
                }
                l >>= 1;
                r >>= 1;
            }

            (self.op)(&sum_l, &sum_r)
        }

        /// Returns a clone of the aggregate of the whole tree (the root node).
        pub fn all_prod(&self) -> S {
            self.data[1].clone()
        }

        /// Applies the update `f` to the single element at `index`.
        ///
        /// # Panics
        ///
        /// Panics if `index >= n`.
        pub fn apply(&mut self, index: usize, f: F) {
            assert!(index < self.n);
            let leaf = index + self.size;
            self.push_to_leaf(leaf);
            self.data[leaf] = (self.mapping)(&f, &self.data[leaf]);
            self.update_from_leaf(leaf);
        }

        /// Applies the update `f` to every element in `range`.
        ///
        /// An empty range is a no-op.
        ///
        /// # Panics
        ///
        /// Panics if `range.start > range.end` or `range.end > n`.
        pub fn apply_range(&mut self, range: Range, f: F) {
            let Range { start, end } = range;
            assert!(start <= end && end <= self.n);
            if start == end {
                return;
            }

            let l = start + self.size;
            let r = end + self.size;
            self.push_boundaries(l, r);

            // Tag the maximal subtrees that lie entirely inside [l, r).
            let (mut lo, mut hi) = (l, r);
            while lo < hi {
                if lo & 1 != 0 {
                    self.all_apply(lo, &f);
                    lo += 1;
                }
                if hi & 1 != 0 {
                    hi -= 1;
                    self.all_apply(hi, &f);
                }
                lo >>= 1;
                hi >>= 1;
            }

            self.update_boundaries(l, r);
        }

        /// Flushes pending updates on every proper ancestor of `leaf`, from the
        /// root downwards.
        fn push_to_leaf(&mut self, leaf: usize) {
            for i in (1..=self.log).rev() {
                self.push(leaf >> i);
            }
        }

        /// Recomputes every proper ancestor of `leaf`, from the bottom upwards.
        fn update_from_leaf(&mut self, leaf: usize) {
            for i in 1..=self.log {
                self.update(leaf >> i);
            }
        }

        /// Flushes pending updates on the ancestors that only partially overlap
        /// the node interval `[l, r)`, from the root downwards.
        fn push_boundaries(&mut self, l: usize, r: usize) {
            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                // When `r` is not a multiple of `1 << i`, `(r - 1) >> i` and
                // `r >> i` name the same ancestor.
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }
        }

        /// Recomputes the ancestors that only partially overlap the node interval
        /// `[l, r)`, from the bottom upwards.
        fn update_boundaries(&mut self, l: usize, r: usize) {
            for i in 1..=self.log {
                if ((l >> i) << i) != l {
                    self.update(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.update((r - 1) >> i);
                }
            }
        }

        /// Recomputes node `k` from its two children.
        fn update(&mut self, k: usize) {
            self.data[k] = (self.op)(&self.data[2 * k], &self.data[2 * k + 1]);
        }

        /// Applies `f` to the aggregate of node `k` and, for internal nodes,
        /// composes it onto the node's pending update.
        fn all_apply(&mut self, k: usize, f: &F) {
            self.data[k] = (self.mapping)(f, &self.data[k]);
            if k < self.size {
                self.lazy[k] = (self.composition)(f, &self.lazy[k]);
            }
        }

        /// Moves the pending update of node `k` down to both children and resets
        /// it to `id()`.
        fn push(&mut self, k: usize) {
            let identity = (self.id)();
            // Take the tag out of the node so it can be borrowed while `self` is
            // mutated, without cloning it once per child.
            let pending = std::mem::replace(&mut self.lazy[k], identity);
            self.all_apply(2 * k, &pending);
            self.all_apply(2 * k + 1, &pending);
        }
    }
}
183
#[cfg(test)]
mod test {
    use super::lazy_segment_tree::*;
    use rand::prelude::*;

    /// Stand-in for "infinity": identity of `min`, and negated, of `max`.
    const INF: i64 = 1 << 60;

    /// Range-add / range-min on a length that is not a power of two, mixing
    /// `apply_range` with a point `set`, checked against a plain vector.
    #[test]
    fn edge_case() {
        let n = 5;
        let mut seg_min = LazySegmentTree::new(
            n,
            || INF,
            |&s, &t| s.min(t),
            |&f, &x| f + x,
            |&f, &g| f + g,
            || 0,
        );
        let mut values = vec![0; n];
        for i in 0..n {
            values[i] = i as i64;
            seg_min.set(i, i as i64);
        }

        let from = 1;
        let to = 4;
        let add = 2;
        for i in from..to {
            values[i] += add;
        }
        seg_min.apply_range(from..to, add);

        // Overwrite one element that currently carries a pending update.
        let pos = 2;
        let value = 1;
        let cur = seg_min.prod(pos..(pos + 1));
        seg_min.set(pos, cur - value);
        values[pos] -= value;

        for l in 0..n {
            for r in (l + 1)..(n + 1) {
                let min1 = seg_min.prod(l..r);
                let &min2 = values[l..r].iter().min().unwrap();
                assert_eq!(min1, min2);
            }
        }
    }

    /// Range-add with range-min and range-max, checked against a plain vector
    /// after every update.
    #[test]
    fn random_add() {
        let mut rng = thread_rng();
        let n = 32;
        let mut array = vec![0; n];
        let mut seg_min = LazySegmentTree::new(
            n,
            || INF,
            |&s, &t| s.min(t),
            |&f, &x| f + x,
            |&f, &g| f + g,
            || 0,
        );
        let mut seg_max = LazySegmentTree::new(
            n,
            || -INF,
            |&s, &t| s.max(t),
            |&f, &x| f + x,
            |&f, &g| f + g,
            || 0,
        );
        for i in 0..n {
            let value = rng.gen_range(-1000, 1000);
            array[i] = value;
            seg_min.set(i, value);
            seg_max.set(i, value);
        }

        // Ranges are half-open, so the right end must reach `n` for the last
        // element to be updated and queried.
        for l in 0..n {
            for r in (l + 1)..=n {
                let value = rng.gen_range(-1000, 1000);
                seg_min.apply_range(l..r, value);
                seg_max.apply_range(l..r, value);

                for i in l..r {
                    array[i] += value;
                }

                for query_l in 0..n {
                    for query_r in (query_l + 1)..=n {
                        let mut min = INF;
                        let mut max = -INF;
                        for i in query_l..query_r {
                            min = std::cmp::min(min, array[i]);
                            max = std::cmp::max(max, array[i]);
                        }

                        assert_eq!(seg_min.prod(query_l..query_r), min);
                        assert_eq!(seg_max.prod(query_l..query_r), max);
                    }
                }
            }
        }
    }

    /// Range-assign of a decimal digit with a non-commutative "concatenate
    /// digits" aggregate, checked against a plain vector.
    #[test]
    fn random_update() {
        let mut rng = thread_rng();
        /// Decimal number formed by `len` consecutive digits.
        #[derive(Clone)]
        struct Num {
            len: u32,
            value: i64,
        }
        let n = 15;
        let mut array = vec![0; n];
        let mut seg_update = LazySegmentTree::new(
            n,
            || Num { len: 0, value: 0 },
            // Concatenation: shift the left number by the right number's length.
            |left: &Num, right: &Num| Num {
                value: left.value * 10i64.pow(right.len) + right.value,
                len: left.len + right.len,
            },
            // `Some(d)` replaces every digit of the segment by `d`.
            |f: &Option<i64>, x: &Num| {
                if let &Some(f) = f {
                    let mut value = 0;
                    for _ in 0..x.len {
                        value = value * 10 + f;
                    }
                    Num { len: x.len, value }
                } else {
                    Num {
                        len: x.len,
                        value: x.value,
                    }
                }
            },
            // The newer assignment wins over the older pending one.
            |f: &Option<i64>, g: &Option<i64>| {
                if f.is_some() {
                    f.clone()
                } else {
                    g.clone()
                }
            },
            || None,
        );
        for i in 0..n {
            array[i] = 1;
            seg_update.set(i, Num { len: 1, value: 1 });
        }

        for _ in 0..1000 {
            let digit = rng.gen_range(0, 10);
            let left = rng.gen_range(0, n);
            let right = rng.gen_range(left + 1, n + 1);
            for i in left..right {
                array[i] = digit;
            }
            seg_update.apply_range(left..right, Some(digit));

            let mut sum = 0;
            for i in 0..n {
                sum = sum * 10 + array[i];
            }

            assert_eq!(sum, seg_update.all_prod().value);
        }
    }
}