1use crate::internal_bit::ceil_pow2;
2use crate::internal_type_traits::{BoundedAbove, BoundedBelow, One, Zero};
3use std::cmp::{max, min};
4use std::convert::Infallible;
5use std::iter::FromIterator;
6use std::marker::PhantomData;
7use std::ops::{Add, BitAnd, BitOr, BitXor, Bound, Mul, Not, RangeBounds};
8
/// Describes the algebraic structure a `Segtree` aggregates over.
///
/// Implementations are expected to make `binary_operation` associative and
/// `identity()` neutral with respect to it; the trait itself cannot enforce
/// either law.
pub trait Monoid {
    /// Element type stored in the tree. It must be `Clone` because nodes are
    /// copied out of the tree by value.
    type S: Clone;
    /// Returns the identity element.
    fn identity() -> Self::S;
    /// Combines `a` and `b`, in that order, into a new element.
    fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S;
}
15
16pub struct Max<S>(Infallible, PhantomData<fn() -> S>);
17impl<S> Monoid for Max<S>
18where
19 S: Copy + Ord + BoundedBelow,
20{
21 type S = S;
22 fn identity() -> Self::S {
23 S::min_value()
24 }
25 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
26 max(*a, *b)
27 }
28}
29
30pub struct Min<S>(Infallible, PhantomData<fn() -> S>);
31impl<S> Monoid for Min<S>
32where
33 S: Copy + Ord + BoundedAbove,
34{
35 type S = S;
36 fn identity() -> Self::S {
37 S::max_value()
38 }
39 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
40 min(*a, *b)
41 }
42}
43
44pub struct Additive<S>(Infallible, PhantomData<fn() -> S>);
45impl<S> Monoid for Additive<S>
46where
47 S: Copy + Add<Output = S> + Zero,
48{
49 type S = S;
50 fn identity() -> Self::S {
51 S::zero()
52 }
53 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
54 *a + *b
55 }
56}
57
58pub struct Multiplicative<S>(Infallible, PhantomData<fn() -> S>);
59impl<S> Monoid for Multiplicative<S>
60where
61 S: Copy + Mul<Output = S> + One,
62{
63 type S = S;
64 fn identity() -> Self::S {
65 S::one()
66 }
67 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
68 *a * *b
69 }
70}
71
72pub struct BitwiseOr<S>(Infallible, PhantomData<fn() -> S>);
73impl<S> Monoid for BitwiseOr<S>
74where
75 S: Copy + BitOr<Output = S> + Zero,
76{
77 type S = S;
78 fn identity() -> Self::S {
79 S::zero()
80 }
81 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
82 *a | *b
83 }
84}
85
86pub struct BitwiseAnd<S>(Infallible, PhantomData<fn() -> S>);
87impl<S> Monoid for BitwiseAnd<S>
88where
89 S: Copy + BitAnd<Output = S> + Not<Output = S> + Zero,
90{
91 type S = S;
92 fn identity() -> Self::S {
93 !S::zero()
94 }
95 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
96 *a & *b
97 }
98}
99
100pub struct BitwiseXor<S>(Infallible, PhantomData<fn() -> S>);
101impl<S> Monoid for BitwiseXor<S>
102where
103 S: Copy + BitXor<Output = S> + Zero,
104{
105 type S = S;
106 fn identity() -> Self::S {
107 S::zero()
108 }
109 fn binary_operation(a: &Self::S, b: &Self::S) -> Self::S {
110 *a ^ *b
111 }
112}
113
114impl<M: Monoid> Default for Segtree<M> {
115 fn default() -> Self {
116 Segtree::new(0)
117 }
118}
119impl<M: Monoid> Segtree<M> {
120 pub fn new(n: usize) -> Segtree<M> {
121 vec![M::identity(); n].into()
122 }
123}
124impl<M: Monoid> From<Vec<M::S>> for Segtree<M> {
125 fn from(v: Vec<M::S>) -> Self {
126 let n = v.len();
127 let log = ceil_pow2(n as u32) as usize;
128 let size = 1 << log;
129 let mut d = vec![M::identity(); 2 * size];
130 d[size..][..n].clone_from_slice(&v);
131 let mut ret = Segtree { n, size, log, d };
132 for i in (1..size).rev() {
133 ret.update(i);
134 }
135 ret
136 }
137}
138impl<M: Monoid> FromIterator<M::S> for Segtree<M> {
139 fn from_iter<T: IntoIterator<Item = M::S>>(iter: T) -> Self {
140 let v = iter.into_iter().collect::<Vec<_>>();
141 v.into()
142 }
143}
144impl<M: Monoid> Segtree<M> {
145 pub fn set(&mut self, mut p: usize, x: M::S) {
146 assert!(p < self.n);
147 p += self.size;
148 self.d[p] = x;
149 for i in 1..=self.log {
150 self.update(p >> i);
151 }
152 }
153
154 pub fn get(&self, p: usize) -> M::S {
155 assert!(p < self.n);
156 self.d[p + self.size].clone()
157 }
158
159 pub fn get_slice(&self) -> &[M::S] {
160 &self.d[self.size..][..self.n]
161 }
162
163 pub fn prod<R>(&self, range: R) -> M::S
164 where
165 R: RangeBounds<usize>,
166 {
167 if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
169 return self.all_prod();
170 }
171
172 let mut r = match range.end_bound() {
173 Bound::Included(r) => r + 1,
174 Bound::Excluded(r) => *r,
175 Bound::Unbounded => self.n,
176 };
177 let mut l = match range.start_bound() {
178 Bound::Included(l) => *l,
179 Bound::Excluded(l) => l + 1,
180 Bound::Unbounded => 0,
182 };
183
184 assert!(l <= r && r <= self.n);
185 let mut sml = M::identity();
186 let mut smr = M::identity();
187 l += self.size;
188 r += self.size;
189
190 while l < r {
191 if l & 1 != 0 {
192 sml = M::binary_operation(&sml, &self.d[l]);
193 l += 1;
194 }
195 if r & 1 != 0 {
196 r -= 1;
197 smr = M::binary_operation(&self.d[r], &smr);
198 }
199 l >>= 1;
200 r >>= 1;
201 }
202
203 M::binary_operation(&sml, &smr)
204 }
205
206 pub fn all_prod(&self) -> M::S {
207 self.d[1].clone()
208 }
209
210 pub fn max_right<F>(&self, mut l: usize, f: F) -> usize
211 where
212 F: Fn(&M::S) -> bool,
213 {
214 assert!(l <= self.n);
215 assert!(f(&M::identity()));
216 if l == self.n {
217 return self.n;
218 }
219 l += self.size;
220 let mut sm = M::identity();
221 while {
222 while l % 2 == 0 {
224 l >>= 1;
225 }
226 if !f(&M::binary_operation(&sm, &self.d[l])) {
227 while l < self.size {
228 l *= 2;
229 let res = M::binary_operation(&sm, &self.d[l]);
230 if f(&res) {
231 sm = res;
232 l += 1;
233 }
234 }
235 return l - self.size;
236 }
237 sm = M::binary_operation(&sm, &self.d[l]);
238 l += 1;
239 {
241 let l = l as isize;
242 (l & -l) != l
243 }
244 } {}
245 self.n
246 }
247
248 pub fn min_left<F>(&self, mut r: usize, f: F) -> usize
249 where
250 F: Fn(&M::S) -> bool,
251 {
252 assert!(r <= self.n);
253 assert!(f(&M::identity()));
254 if r == 0 {
255 return 0;
256 }
257 r += self.size;
258 let mut sm = M::identity();
259 while {
260 r -= 1;
262 while r > 1 && r % 2 == 1 {
263 r >>= 1;
264 }
265 if !f(&M::binary_operation(&self.d[r], &sm)) {
266 while r < self.size {
267 r = 2 * r + 1;
268 let res = M::binary_operation(&self.d[r], &sm);
269 if f(&res) {
270 sm = res;
271 r -= 1;
272 }
273 }
274 return r + 1 - self.size;
275 }
276 sm = M::binary_operation(&self.d[r], &sm);
277 {
279 let r = r as isize;
280 (r & -r) != r
281 }
282 } {}
283 0
284 }
285
286 fn update(&mut self, k: usize) {
287 self.d[k] = M::binary_operation(&self.d[2 * k], &self.d[2 * k + 1]);
288 }
289}
290
291#[derive(Clone)]
302pub struct Segtree<M>
303where
304 M: Monoid,
305{
306 n: usize,
308 size: usize,
309 log: usize,
310 d: Vec<M::S>,
311}
312
#[cfg(test)]
mod tests {
    use crate::segtree::Max;
    use crate::Segtree;
    use std::ops::{Bound::*, RangeBounds};

    /// Builds a max-tree both from a vector and element by element, checking
    /// it against a plain vector after every modification.
    #[test]
    fn test_max_segtree() {
        let base = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
        let n = base.len();
        let segtree: Segtree<Max<_>> = base.clone().into();
        check_segtree(&base, &segtree);

        let mut segtree = Segtree::<Max<_>>::new(n);
        let mut internal = vec![i32::MIN; n];
        for i in 0..n {
            segtree.set(i, base[i]);
            internal[i] = base[i];
            check_segtree(&internal, &segtree);
        }

        segtree.set(6, 5);
        internal[6] = 5;
        check_segtree(&internal, &segtree);

        segtree.set(6, 0);
        internal[6] = 0;
        check_segtree(&internal, &segtree);
    }

    /// Collecting an iterator must produce the same tree as the collected
    /// vector would.
    #[test]
    fn test_segtree_fromiter() {
        let v = [1, 4, 1, 4, 2, 1, 3, 5, 6];
        let base = v
            .iter()
            .copied()
            .filter(|&x| x % 2 == 0)
            .collect::<Vec<_>>();
        let segtree: Segtree<Max<_>> = v.iter().copied().filter(|&x| x % 2 == 0).collect();
        check_segtree(&base, &segtree);
    }

    /// An empty tree must report the monoid identity everywhere; this also
    /// exercises the empty-input fallbacks in `check_segtree`.
    #[test]
    fn test_empty_segtree() {
        let segtree = Segtree::<Max<i32>>::new(0);
        check_segtree(&[], &segtree);
    }

    /// Compares every query of `segtree` against a brute-force computation
    /// over `base`.
    fn check_segtree(base: &[i32], segtree: &Segtree<Max<i32>>) {
        let n = base.len();
        #[allow(clippy::needless_range_loop)]
        for i in 0..n {
            assert_eq!(segtree.get(i), base[i]);
        }

        check(base, segtree, ..);
        for i in 0..=n {
            check(base, segtree, ..i);
            check(base, segtree, i..);
            if i < n {
                check(base, segtree, ..=i);
            }
            for j in i..=n {
                check(base, segtree, i..j);
                if j < n {
                    check(base, segtree, i..=j);
                    check(base, segtree, (Excluded(i), Included(j)));
                }
            }
        }
        // The identity of `Max<i32>` is `i32::MIN`, so that is what an empty
        // tree reports as its total product.
        assert_eq!(
            segtree.all_prod(),
            base.iter().max().copied().unwrap_or(i32::MIN)
        );
        for k in 0..=10 {
            let f = |&x: &i32| x < k;
            for i in 0..=n {
                assert_eq!(
                    Some(segtree.max_right(i, f)),
                    (i..=n)
                        .filter(|&j| f(&base[i..j].iter().max().copied().unwrap_or(i32::MIN)))
                        .max()
                );
            }
            for j in 0..=n {
                assert_eq!(
                    Some(segtree.min_left(j, f)),
                    (0..=j)
                        .filter(|&i| f(&base[i..j].iter().max().copied().unwrap_or(i32::MIN)))
                        .min()
                );
            }
        }
    }

    /// Checks `segtree.prod(range)` against the maximum of the elements of
    /// `base` whose indices fall in `range`.
    fn check(base: &[i32], segtree: &Segtree<Max<i32>>, range: impl RangeBounds<usize>) {
        let expected = base
            .iter()
            .enumerate()
            .filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
            .max()
            .copied()
            .unwrap_or(i32::MIN);
        assert_eq!(segtree.prod(range), expected);
    }
}