// competitive_programming_rs/data_structure/segment_tree.rs

/// Segment tree supporting point updates and range folds of an associative
/// operation `op`.
///
/// Positions that were never updated hold `None`, which the internal combiner
/// treats as an identity element, so `op` needs no explicit neutral value.
pub struct SegmentTree<T, Op> {
    /// Nodes in heap order: node `k` has children `2k + 1` and `2k + 2`, and
    /// leaf `i` lives at index `n - 1 + i`.
    seg: Vec<Option<T>>,
    /// Number of leaves; a power of two strictly greater than the requested size.
    n: usize,
    /// The associative combining function.
    op: Op,
}

impl<T, Op> SegmentTree<T, Op>
where
    T: Copy,
    Op: Fn(T, T) -> T + Copy,
{
    /// Creates a tree with room for at least `size` positions, all unset.
    ///
    /// The leaf count is `size.next_power_of_two()`, doubled when `size` is
    /// already a power of two, so it always strictly exceeds `size`.
    pub fn new(size: usize, op: Op) -> SegmentTree<T, Op> {
        let mut m = size.next_power_of_two();
        if m == size {
            m *= 2;
        }
        SegmentTree {
            seg: vec![None; m * 2],
            n: m,
            op,
        }
    }

    /// Sets position `k` to `value` and recomputes all of its ancestors.
    ///
    /// # Panics
    /// Panics if `k` is not below the leaf count.
    pub fn update(&mut self, k: usize, value: T) {
        assert!(k < self.n, "index {} out of range (leaf count {})", k, self.n);
        let mut k = k + self.n - 1;
        self.seg[k] = Some(value);
        while k > 0 {
            k = (k - 1) >> 1;
            let left = self.seg[k * 2 + 1];
            let right = self.seg[k * 2 + 2];
            self.seg[k] = Self::op(left, right, self.op);
        }
    }

    /// Folds `op` over every set position in `range`.
    ///
    /// Returns `None` if the range is empty or holds no set position. Bounds
    /// beyond the leaf count are harmless; they simply cover nothing extra.
    pub fn query<R: std::ops::RangeBounds<usize>>(&self, range: R) -> Option<T> {
        use std::ops::Bound;

        // `saturating_add` keeps a bound of `usize::MAX` from overflowing,
        // which would panic in debug builds and wrap to 0 in release builds.
        let start = match range.start_bound() {
            Bound::Included(&t) => t,
            Bound::Excluded(&t) => t.saturating_add(1),
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&t) => t.saturating_add(1),
            Bound::Excluded(&t) => t,
            Bound::Unbounded => self.n,
        };

        self.query_range(start..end, 0, 0..self.n)
    }

    /// Recursive helper: folds the part of `range` covered by node `k`, whose
    /// span of leaves is `seg_range`.
    fn query_range(
        &self,
        range: std::ops::Range<usize>,
        k: usize,
        seg_range: std::ops::Range<usize>,
    ) -> Option<T> {
        if seg_range.end <= range.start || range.end <= seg_range.start {
            // Disjoint: this node contributes nothing.
            None
        } else if range.start <= seg_range.start && seg_range.end <= range.end {
            // Fully covered: the node's stored fold is the answer for its span.
            self.seg[k]
        } else {
            // Partial overlap: split the span in half and combine both sides.
            let mid = (seg_range.start + seg_range.end) >> 1;
            let x = self.query_range(range.clone(), k * 2 + 1, seg_range.start..mid);
            let y = self.query_range(range, k * 2 + 2, mid..seg_range.end);
            Self::op(x, y, self.op)
        }
    }

    /// Combines two optional values with `f`, treating `None` as identity.
    fn op(a: Option<T>, b: Option<T>, f: Op) -> Option<T> {
        match (a, b) {
            (Some(a), Some(b)) => Some(f(a, b)),
            _ => a.or(b),
        }
    }
}
79
80pub struct SegmentTree2d<T, Op> {
81 n: usize,
82 seg: Vec<SegmentTree<T, Op>>,
83 op: Op,
84}
85
86impl<T, Op> SegmentTree2d<T, Op>
87where
88 T: Copy,
89 Op: Fn(T, T) -> T + Copy,
90{
91 pub fn new(h: usize, w: usize, op: Op) -> Self {
92 let mut n = h.next_power_of_two();
93 if n == h {
94 n *= 2;
95 }
96 let mut seg = Vec::with_capacity(n * 2);
97 for _ in 0..(n * 2) {
98 seg.push(SegmentTree::new(w, op));
99 }
100 Self { seg, n, op }
101 }
102
103 pub fn update(&mut self, i: usize, j: usize, value: T) {
104 let mut k = i;
105 k += self.n - 1;
106 self.seg[k].update(j, value);
107 while k > 0 {
108 k = (k - 1) >> 1;
109 let left = self.seg[k * 2 + 1].query(j..(j + 1));
110 let right = self.seg[k * 2 + 2].query(j..(j + 1));
111 if let Some(value) = Self::op(left, right, self.op) {
112 self.seg[k].update(j, value);
113 }
114 }
115 }
116
117 pub fn query<C, R>(&self, r: R, c: C) -> Option<T>
118 where
119 C: std::ops::RangeBounds<usize>,
120 R: std::ops::RangeBounds<usize>,
121 {
122 let start = |s: std::ops::Bound<&usize>| match s {
123 std::ops::Bound::Included(t) => *t,
124 std::ops::Bound::Excluded(t) => *t+1,
125 std::ops::Bound::Unbounded => 0,
126 };
127
128 let end = |e: std::ops::Bound<&usize>| match e {
129 std::ops::Bound::Included(t) => *t+1,
130 std::ops::Bound::Excluded(t) => *t,
131 std::ops::Bound::Unbounded => self.n,
132 };
133
134 let r_start = start(r.start_bound());
135 let c_start = start(c.start_bound());
136 let r_end = end(r.end_bound());
137 let c_end = end(c.end_bound());
138
139 self.query_range(r_start..r_end, 0, 0..self.n, c_start..c_end)
140 }
141
142 fn query_range(
143 &self,
144 range: std::ops::Range<usize>,
145 k: usize,
146 seg_range: std::ops::Range<usize>,
147 c: std::ops::Range<usize>,
148 ) -> Option<T> {
149 if seg_range.end <= range.start || range.end <= seg_range.start {
150 None
151 } else if range.start <= seg_range.start && seg_range.end <= range.end {
152 self.seg[k].query(c)
153 } else {
154 let mid = (seg_range.start + seg_range.end) >> 1;
155 let x = self.query_range(range.clone(), k * 2 + 1, seg_range.start..mid, c.clone());
156 let y = self.query_range(range, k * 2 + 2, mid..seg_range.end, c);
157 Self::op(x, y, self.op)
158 }
159 }
160 fn op(a: Option<T>, b: Option<T>, f: Op) -> Option<T> {
161 match (a, b) {
162 (Some(a), Some(b)) => Some(f(a, b)),
163 _ => a.or(b),
164 }
165 }
166}
167
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;

    const INF: i64 = 1 << 60;

    /// Filling positions left to right, every prefix query (and the whole-range
    /// query) must equal the running minimum of the values inserted so far.
    #[test]
    fn random_array() {
        const N: usize = 1000;
        let mut rng = thread_rng();

        for _ in 0..5 {
            let arr: Vec<i64> = (0..N).map(|_| rng.gen_range(0, INF)).collect();

            let mut seg = SegmentTree::new(N, |a: i64, b: i64| a.min(b));
            let mut minimum = INF;
            for (i, &value) in arr.iter().enumerate() {
                minimum = minimum.min(value);
                seg.update(i, value);
                assert_eq!(seg.query(0..N), Some(minimum));
                assert_eq!(seg.query(0..(i + 1)), Some(minimum));
            }
        }
    }

    /// Random point overwrites; after each one the whole-range minimum must
    /// match a brute-force scan, whichever range syntax is used.
    #[test]
    fn random_array_online_update() {
        const N: usize = 1000;
        let mut rng = thread_rng();

        for _ in 0..5 {
            let mut arr = vec![INF; N];
            let mut seg = SegmentTree::new(N, |a: i64, b: i64| a.min(b));

            for _ in 0..N {
                let value = rng.gen_range(0, INF);
                let k = rng.gen_range(0, N);
                seg.update(k, value);
                arr[k] = value;

                let minimum = arr.iter().copied().min().unwrap_or(INF);
                assert_eq!(seg.query(0..N), Some(minimum));
                assert_eq!(seg.query(0..=(N - 1)), Some(minimum));
            }

            assert_eq!(seg.query(0..N), seg.query(0..=(N - 1)));
            assert_eq!(seg.query(0..N), seg.query(..));
        }
    }

    /// Brute-force check of every sub-rectangle of a random square grid.
    #[test]
    fn random_array_2d() {
        const N: usize = 30;
        let mut rng = thread_rng();

        let mut arr = vec![vec![0; N]; N];
        let mut seg = SegmentTree2d::new(N, N, |a: i64, b: i64| a.min(b));
        for i in 0..N {
            for j in 0..N {
                arr[i][j] = rng.gen_range(0, INF);
                seg.update(i, j, arr[i][j]);
            }
        }

        for i1 in 0..N {
            for j1 in 0..N {
                for i2 in (i1 + 1)..=N {
                    for j2 in (j1 + 1)..=N {
                        let mut minimum = INF;

                        for row in &arr[i1..i2] {
                            for &cell in &row[j1..j2] {
                                minimum = minimum.min(cell);
                            }
                        }

                        assert_eq!(seg.query(i1..i2, j1..j2), Some(minimum));
                        assert_eq!(seg.query(i1..=(i2 - 1), j1..j2), Some(minimum));
                    }
                }
            }
        }

        assert_eq!(seg.query(0..N, ..), seg.query(.., 0..N));
        assert_eq!(seg.query(0..N, ..), seg.query(0..=N, ..));
    }

    /// Same brute-force check on a non-square grid, so the row tree and the
    /// column trees end up with different leaf counts.
    #[test]
    fn random_array_2d_rectangular() {
        const H: usize = 5;
        const W: usize = 12;
        let mut rng = thread_rng();

        let mut arr = vec![vec![0; W]; H];
        let mut seg = SegmentTree2d::new(H, W, |a: i64, b: i64| a.min(b));
        for i in 0..H {
            for j in 0..W {
                arr[i][j] = rng.gen_range(0, INF);
                seg.update(i, j, arr[i][j]);
            }
        }

        for i1 in 0..H {
            for i2 in (i1 + 1)..=H {
                for j1 in 0..W {
                    for j2 in (j1 + 1)..=W {
                        let minimum = arr[i1..i2]
                            .iter()
                            .flat_map(|row| row[j1..j2].iter().copied())
                            .min()
                            .unwrap();
                        assert_eq!(seg.query(i1..i2, j1..j2), Some(minimum));
                    }
                }
            }
        }
    }
}