// competitive_programming_rs/data_structure/lazy_segment_tree.rs
pub mod lazy_segment_tree {
    type Range = std::ops::Range<usize>;

    /// Segment tree with lazy propagation over aggregates `S` acted on by maps `F`.
    ///
    /// The algebra is supplied as closures:
    /// - `e()` — identity element of `op`; used to fill unused leaves and as the
    ///   starting accumulator of queries.
    /// - `op(a, b)` — combines the aggregate of a left segment `a` with the
    ///   aggregate of the adjacent right segment `b` (order is preserved, so `op`
    ///   need not be commutative).
    /// - `mapping(f, x)` — applies the map `f` to the aggregate `x`.
    /// - `composition(f, g)` — the map "apply `g` first, then `f`".
    /// - `id()` — the map that leaves every aggregate unchanged.
    ///
    /// Leaf `i` is stored at `data[size + i]`; node `k` has children `2k` and
    /// `2k + 1`, and `lazy[k]` is the map still owed to both children of `k`.
    pub struct LazySegmentTree<S, Op, E, F, Mapping, Composition, Id> {
        /// Number of user-visible elements.
        n: usize,
        /// Number of leaves: `n` rounded up to a power of two.
        size: usize,
        /// `log2(size)`, i.e. the number of levels above the leaves.
        log: usize,
        data: Vec<S>,
        lazy: Vec<F>,
        op: Op,
        e: E,
        mapping: Mapping,
        composition: Composition,
        id: Id,
    }

    impl<S, Op, E, F, Mapping, Composition, Id> LazySegmentTree<S, Op, E, F, Mapping, Composition, Id>
    where
        S: Clone,
        E: Fn() -> S,
        F: Clone,
        Op: Fn(&S, &S) -> S,
        Mapping: Fn(&F, &S) -> S,
        Composition: Fn(&F, &F) -> F,
        Id: Fn() -> F,
    {
        /// Creates a tree of `n` elements, every element initialised to `e()`.
        pub fn new(
            n: usize,
            e: E,
            op: Op,
            mapping: Mapping,
            composition: Composition,
            id: Id,
        ) -> Self {
            let size = n.next_power_of_two();
            LazySegmentTree {
                n,
                size,
                log: size.trailing_zeros() as usize,
                data: vec![e(); 2 * size],
                lazy: vec![id(); size],
                e,
                op,
                mapping,
                composition,
                id,
            }
        }

        /// Replaces element `index` with `value`.
        ///
        /// # Panics
        /// Panics if `index >= n`.
        pub fn set(&mut self, index: usize, value: S) {
            assert!(index < self.n);
            let leaf = index + self.size;
            self.push_to_leaf(leaf);
            self.data[leaf] = value;
            self.update_from_leaf(leaf);
        }

        /// Returns element `index` with every pending map applied.
        ///
        /// # Panics
        /// Panics if `index >= n`.
        pub fn get(&mut self, index: usize) -> S {
            assert!(index < self.n);
            let leaf = index + self.size;
            self.push_to_leaf(leaf);
            self.data[leaf].clone()
        }

        /// Returns `op` folded left-to-right over the elements in `range`,
        /// or `e()` when the range is empty.
        ///
        /// # Panics
        /// Panics if `range.start > range.end` or `range.end > n`.
        pub fn prod(&mut self, range: Range) -> S {
            let (l, r) = (range.start, range.end);
            assert!(l <= r && r <= self.n);
            // An empty range folds to the identity, mirroring `apply_range`,
            // which already accepts `l == r`.
            if l == r {
                return (self.e)();
            }

            let (mut l, mut r) = (l + self.size, r + self.size);
            self.push_boundaries(l, r);

            // Left and right partial folds are kept apart so `op` is always
            // applied in element order.
            let mut sum_l = (self.e)();
            let mut sum_r = (self.e)();
            while l < r {
                if l & 1 == 1 {
                    sum_l = (self.op)(&sum_l, &self.data[l]);
                    l += 1;
                }
                if r & 1 == 1 {
                    r -= 1;
                    sum_r = (self.op)(&self.data[r], &sum_r);
                }
                l >>= 1;
                r >>= 1;
            }

            (self.op)(&sum_l, &sum_r)
        }

        /// Returns `op` folded over all elements (the root aggregate).
        pub fn all_prod(&self) -> S {
            self.data[1].clone()
        }

        /// Applies the map `f` to element `index`.
        ///
        /// # Panics
        /// Panics if `index >= n`.
        pub fn apply(&mut self, index: usize, f: F) {
            assert!(index < self.n);
            let leaf = index + self.size;
            self.push_to_leaf(leaf);
            self.data[leaf] = (self.mapping)(&f, &self.data[leaf]);
            self.update_from_leaf(leaf);
        }

        /// Applies the map `f` to every element in `range`; an empty range is a no-op.
        ///
        /// # Panics
        /// Panics if `range.start > range.end` or `range.end > n`.
        pub fn apply_range(&mut self, range: Range, f: F) {
            let (l, r) = (range.start, range.end);
            assert!(l <= r && r <= self.n);
            if l == r {
                return;
            }

            let (l, r) = (l + self.size, r + self.size);
            self.push_boundaries(l, r);

            // Walk a copy of the bounds so the originals are still available for
            // rebuilding the boundary ancestors afterwards.
            let (mut lo, mut hi) = (l, r);
            while lo < hi {
                if lo & 1 == 1 {
                    self.all_apply(lo, &f);
                    lo += 1;
                }
                if hi & 1 == 1 {
                    hi -= 1;
                    self.all_apply(hi, &f);
                }
                lo >>= 1;
                hi >>= 1;
            }

            self.update_boundaries(l, r);
        }

        /// Pushes pending maps down along the root-to-leaf path of `leaf`
        /// (a `data` index), top-down.
        fn push_to_leaf(&mut self, leaf: usize) {
            for i in (1..=self.log).rev() {
                self.push(leaf >> i);
            }
        }

        /// Recomputes every ancestor of `leaf` (a `data` index) from its children, bottom-up.
        fn update_from_leaf(&mut self, leaf: usize) {
            for i in 1..=self.log {
                self.update(leaf >> i);
            }
        }

        /// Pushes pending maps down to every ancestor that only partially covers
        /// the half-open leaf interval `[l, r)` (both `data` indices), top-down.
        fn push_boundaries(&mut self, l: usize, r: usize) {
            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }
        }

        /// Recomputes every ancestor that only partially covers the half-open leaf
        /// interval `[l, r)` (both `data` indices), bottom-up.
        fn update_boundaries(&mut self, l: usize, r: usize) {
            for i in 1..=self.log {
                if ((l >> i) << i) != l {
                    self.update(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.update((r - 1) >> i);
                }
            }
        }

        /// Recomputes node `k` from its two children.
        fn update(&mut self, k: usize) {
            self.data[k] = (self.op)(&self.data[2 * k], &self.data[2 * k + 1]);
        }

        /// Applies `f` to node `k` and, for internal nodes, records it as owed to the children.
        fn all_apply(&mut self, k: usize, f: &F) {
            self.data[k] = (self.mapping)(f, &self.data[k]);
            if k < self.size {
                self.lazy[k] = (self.composition)(f, &self.lazy[k]);
            }
        }

        /// Moves the pending map of internal node `k` onto both children.
        fn push(&mut self, k: usize) {
            // Taking the map out (leaving `id()`) avoids cloning it per child.
            let f = std::mem::replace(&mut self.lazy[k], (self.id)());
            self.all_apply(2 * k, &f);
            self.all_apply(2 * k + 1, &f);
        }
    }
}

#[cfg(test)]
mod test {
    use super::lazy_segment_tree::*;
    use rand::prelude::*;

    const INF: i64 = 1 << 60;

    /// Range add plus a point overwrite on a non-power-of-two tree, checked
    /// against a brute-force minimum over every non-empty range.
    #[test]
    fn edge_case() {
        let n = 5;
        let mut seg_min = LazySegmentTree::new(
            n,
            || INF,
            |&s, &t| s.min(t),
            |&f, &x| f + x,
            |&f, &g| f + g,
            || 0,
        );
        let mut values = vec![0; n];
        for i in 0..n {
            values[i] = i as i64;
            seg_min.set(i, i as i64);
        }

        let from = 1;
        let to = 4;
        let add = 2;
        for v in &mut values[from..to] {
            *v += add;
        }
        seg_min.apply_range(from..to, add);

        let pos = 2;
        let value = 1;
        let cur = seg_min.prod(pos..(pos + 1));
        seg_min.set(pos, cur - value);
        values[pos] -= value;

        for l in 0..n {
            for r in (l + 1)..=n {
                let min1 = seg_min.prod(l..r);
                let &min2 = values[l..r].iter().min().unwrap();
                assert_eq!(min1, min2);
            }
        }
    }

    /// Random range additions over every `l..r` (including ranges that end at
    /// `n`), checked against brute-force min/max over every non-empty range.
    #[test]
    fn random_add() {
        let mut rng = thread_rng();
        let n = 32;
        let mut array = vec![0; n];
        let mut seg_min = LazySegmentTree::new(
            n,
            || INF,
            |&s, &t| s.min(t),
            |&f, &x| f + x,
            |&f, &g| f + g,
            || 0,
        );
        let mut seg_max = LazySegmentTree::new(
            n,
            || -INF,
            |&s, &t| s.max(t),
            |&f, &x| f + x,
            |&f, &g| f + g,
            || 0,
        );
        for i in 0..n {
            let value = rng.gen_range(-1000, 1000);
            array[i] = value;
            seg_min.set(i, value);
            seg_max.set(i, value);
        }

        // `..=n` so that ranges touching the last element are exercised too.
        for l in 0..n {
            for r in (l + 1)..=n {
                let value = rng.gen_range(-1000, 1000);
                seg_min.apply_range(l..r, value);
                seg_max.apply_range(l..r, value);
                for v in &mut array[l..r] {
                    *v += value;
                }

                for ql in 0..n {
                    for qr in (ql + 1)..=n {
                        let window = &array[ql..qr];
                        let min = *window.iter().min().unwrap();
                        let max = *window.iter().max().unwrap();
                        assert_eq!(seg_min.prod(ql..qr), min);
                        assert_eq!(seg_max.prod(ql..qr), max);
                    }
                }
            }
        }
    }

    /// Random range assignments of a digit; the root aggregate is the decimal
    /// number formed by all digits, which checks that `op` keeps element order.
    #[test]
    fn random_update() {
        let mut rng = thread_rng();
        #[derive(Clone)]
        struct Num {
            len: u32,
            value: i64,
        }
        let n = 15;
        let mut array = vec![0; n];
        let mut seg_update = LazySegmentTree::new(
            n,
            || Num { len: 0, value: 0 },
            |left: &Num, right: &Num| Num {
                value: left.value * 10i64.pow(right.len) + right.value,
                len: left.len + right.len,
            },
            |f: &Option<i64>, x: &Num| match *f {
                Some(digit) => {
                    let mut value = 0;
                    for _ in 0..x.len {
                        value = value * 10 + digit;
                    }
                    Num { len: x.len, value }
                }
                None => x.clone(),
            },
            // The newer assignment `f` wins; otherwise keep the older one.
            |f: &Option<i64>, g: &Option<i64>| f.or(*g),
            || None,
        );
        for i in 0..n {
            array[i] = 1;
            seg_update.set(i, Num { len: 1, value: 1 });
        }

        for _ in 0..1000 {
            let digit = rng.gen_range(0, 10);
            let left = rng.gen_range(0, n);
            let right = rng.gen_range(left + 1, n + 1);
            for v in &mut array[left..right] {
                *v = digit;
            }
            seg_update.apply_range(left..right, Some(digit));

            let sum = array.iter().fold(0, |acc, &d| acc * 10 + d);
            assert_eq!(sum, seg_update.all_prod().value);
        }
    }
}