// competitive_programming_rs/data_structure/segment_tree.rs

/// Segment tree over `size` slots supporting point assignment and range folds.
///
/// `op` is expected to be associative. Slots that were never assigned hold no
/// value and are skipped when folding, so a query returns `None` only when no
/// slot in the requested range has been assigned.
///
/// Nodes are stored as an implicit binary heap: node `k` has children
/// `2k + 1` and `2k + 2`, and leaf `i` lives at index `i + n - 1`, where `n`
/// is `size` rounded up to a power of two.
pub struct SegmentTree<T, Op> {
    seg: Vec<Option<T>>,
    n: usize,
    op: Op,
}

impl<T, Op> SegmentTree<T, Op>
where
    T: Copy,
    Op: Fn(T, T) -> T + Copy,
{
    /// Creates a tree with `size` unassigned slots whose values are combined with `op`.
    pub fn new(size: usize, op: Op) -> SegmentTree<T, Op> {
        // A power-of-two leaf count keeps the tree perfect; a `size` that is
        // already a power of two needs no further padding.
        let n = size.next_power_of_two();
        SegmentTree {
            seg: vec![None; n * 2],
            n,
            op,
        }
    }

    /// Assigns `value` to slot `k` and recomputes every ancestor of that leaf.
    ///
    /// # Panics
    ///
    /// Panics if `k` is not below the padded leaf count.
    pub fn update(&mut self, k: usize, value: T) {
        assert!(k < self.n, "SegmentTree::update: index {} out of range", k);
        let mut k = k + self.n - 1;
        self.seg[k] = Some(value);
        while k > 0 {
            k = (k - 1) / 2;
            let left = self.seg[2 * k + 1];
            let right = self.seg[2 * k + 2];
            self.seg[k] = Self::combine(left, right, self.op);
        }
    }

    /// Folds the assigned values in `range` with `op`.
    ///
    /// Returns `None` when no slot in the range has been assigned, including
    /// for empty or reversed ranges. Bounds past the end of the tree are clipped.
    pub fn query<R: std::ops::RangeBounds<usize>>(&self, range: R) -> Option<T> {
        use std::ops::Bound;

        // Saturating arithmetic keeps bounds such as `..=usize::MAX` from
        // overflowing; anything past the last leaf is ignored anyway.
        let start = match range.start_bound() {
            Bound::Included(&t) => t,
            Bound::Excluded(&t) => t.saturating_add(1),
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&t) => t.saturating_add(1),
            Bound::Excluded(&t) => t,
            Bound::Unbounded => self.n,
        };

        self.query_range(start..end, 0, 0..self.n)
    }

    /// Folds the part of `range` that overlaps `seg_range`, the span covered by node `k`.
    fn query_range(
        &self,
        range: std::ops::Range<usize>,
        k: usize,
        seg_range: std::ops::Range<usize>,
    ) -> Option<T> {
        if seg_range.end <= range.start || range.end <= seg_range.start {
            // No overlap.
            None
        } else if range.start <= seg_range.start && seg_range.end <= range.end {
            // Node fully covered: its cached fold is the answer for this span.
            self.seg[k]
        } else {
            let mid = (seg_range.start + seg_range.end) / 2;
            let left = self.query_range(range.clone(), 2 * k + 1, seg_range.start..mid);
            let right = self.query_range(range, 2 * k + 2, mid..seg_range.end);
            Self::combine(left, right, self.op)
        }
    }

    /// Applies `f` when both sides hold a value; otherwise returns whichever one does.
    fn combine(a: Option<T>, b: Option<T>, f: Op) -> Option<T> {
        match (a, b) {
            (Some(a), Some(b)) => Some(f(a, b)),
            _ => a.or(b),
        }
    }
}

80pub struct SegmentTree2d<T, Op> {
81    n: usize,
82    seg: Vec<SegmentTree<T, Op>>,
83    op: Op,
84}
85
86impl<T, Op> SegmentTree2d<T, Op>
87where
88    T: Copy,
89    Op: Fn(T, T) -> T + Copy,
90{
91    pub fn new(h: usize, w: usize, op: Op) -> Self {
92        let mut n = h.next_power_of_two();
93        if n == h {
94            n *= 2;
95        }
96        let mut seg = Vec::with_capacity(n * 2);
97        for _ in 0..(n * 2) {
98            seg.push(SegmentTree::new(w, op));
99        }
100        Self { seg, n, op }
101    }
102
103    pub fn update(&mut self, i: usize, j: usize, value: T) {
104        let mut k = i;
105        k += self.n - 1;
106        self.seg[k].update(j, value);
107        while k > 0 {
108            k = (k - 1) >> 1;
109            let left = self.seg[k * 2 + 1].query(j..(j + 1));
110            let right = self.seg[k * 2 + 2].query(j..(j + 1));
111            if let Some(value) = Self::op(left, right, self.op) {
112                self.seg[k].update(j, value);
113            }
114        }
115    }
116
117    pub fn query<C, R>(&self, r: R, c: C) -> Option<T>
118    where
119        C: std::ops::RangeBounds<usize>,
120        R: std::ops::RangeBounds<usize>,
121    {
122        let start = |s: std::ops::Bound<&usize>| match s {
123            std::ops::Bound::Included(t) => *t,
124            std::ops::Bound::Excluded(t) => *t+1,
125            std::ops::Bound::Unbounded => 0,
126        };
127
128        let end = |e: std::ops::Bound<&usize>| match e {
129            std::ops::Bound::Included(t) => *t+1,
130            std::ops::Bound::Excluded(t) => *t,
131            std::ops::Bound::Unbounded => self.n,
132        };
133
134        let r_start = start(r.start_bound());
135        let c_start = start(c.start_bound());
136        let r_end = end(r.end_bound());
137        let c_end = end(c.end_bound());
138
139        self.query_range(r_start..r_end, 0, 0..self.n, c_start..c_end)
140    }
141
142    fn query_range(
143        &self,
144        range: std::ops::Range<usize>,
145        k: usize,
146        seg_range: std::ops::Range<usize>,
147        c: std::ops::Range<usize>,
148    ) -> Option<T> {
149        if seg_range.end <= range.start || range.end <= seg_range.start {
150            None
151        } else if range.start <= seg_range.start && seg_range.end <= range.end {
152            self.seg[k].query(c)
153        } else {
154            let mid = (seg_range.start + seg_range.end) >> 1;
155            let x = self.query_range(range.clone(), k * 2 + 1, seg_range.start..mid, c.clone());
156            let y = self.query_range(range, k * 2 + 2, mid..seg_range.end, c);
157            Self::op(x, y, self.op)
158        }
159    }
160    fn op(a: Option<T>, b: Option<T>, f: Op) -> Option<T> {
161        match (a, b) {
162            (Some(a), Some(b)) => Some(f(a, b)),
163            _ => a.or(b),
164        }
165    }
166}
167
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;

    const INF: i64 = 1 << 60;

    /// Filling slots left to right, the tree must agree with a running prefix minimum.
    #[test]
    fn random_array() {
        const N: usize = 1000;
        let mut rng = thread_rng();

        for _ in 0..5 {
            let arr: Vec<i64> = (0..N).map(|_| rng.gen_range(0, INF)).collect();

            let mut seg = SegmentTree::new(N, |a: i64, b: i64| a.min(b));
            let mut minimum = INF;
            for (i, &value) in arr.iter().enumerate() {
                minimum = minimum.min(value);
                seg.update(i, value);
                assert_eq!(seg.query(0..N), Some(minimum));
                assert_eq!(seg.query(0..(i + 1)), Some(minimum));
            }
        }
    }

    /// Random point overwrites must keep the global minimum in sync with a brute-force scan.
    #[test]
    fn random_array_online_update() {
        const N: usize = 1000;
        let mut rng = thread_rng();

        for _ in 0..5 {
            let mut arr = vec![INF; N];
            let mut seg = SegmentTree::new(N, |a: i64, b: i64| a.min(b));

            for _ in 0..N {
                let value = rng.gen_range(0, INF);
                // Reuse the existing generator instead of fetching a new handle per iteration.
                let k = rng.gen_range(0, N);
                seg.update(k, value);

                arr[k] = value;
                let minimum = arr.iter().copied().fold(INF, i64::min);
                assert_eq!(seg.query(0..N), Some(minimum));
                assert_eq!(seg.query(0..=(N - 1)), Some(minimum));
            }

            assert_eq!(seg.query(0..N), seg.query(0..=(N - 1)));
            assert_eq!(seg.query(0..N), seg.query(..));
        }
    }

    /// Every rectangle of a square grid must match a brute-force minimum.
    #[test]
    fn random_array_2d() {
        const N: usize = 30;
        let mut rng = thread_rng();

        let mut arr = vec![vec![0; N]; N];
        let mut seg = SegmentTree2d::new(N, N, |a: i64, b: i64| a.min(b));
        for i in 0..N {
            for j in 0..N {
                arr[i][j] = rng.gen_range(0, INF);
                seg.update(i, j, arr[i][j]);
            }
        }

        for i1 in 0..N {
            for j1 in 0..N {
                for i2 in (i1 + 1)..=N {
                    for j2 in (j1 + 1)..=N {
                        let minimum = arr[i1..i2]
                            .iter()
                            .flat_map(|row| row[j1..j2].iter().copied())
                            .fold(INF, i64::min);

                        assert_eq!(seg.query(i1..i2, j1..j2), Some(minimum));
                        assert_eq!(seg.query(i1..=(i2 - 1), j1..j2), Some(minimum));
                    }
                }
            }
        }

        assert_eq!(seg.query(0..N, ..), seg.query(.., 0..N));
        assert_eq!(seg.query(0..N, ..), seg.query(0..=N, ..));
    }

    /// Non-square grids exercise row and column trees of different padded sizes.
    #[test]
    fn random_array_2d_rectangular() {
        const H: usize = 3;
        const W: usize = 20;
        let mut rng = thread_rng();

        let mut arr = vec![vec![0; W]; H];
        let mut seg = SegmentTree2d::new(H, W, |a: i64, b: i64| a.min(b));
        for i in 0..H {
            for j in 0..W {
                arr[i][j] = rng.gen_range(0, INF);
                seg.update(i, j, arr[i][j]);
            }
        }

        for i1 in 0..H {
            for i2 in (i1 + 1)..=H {
                for j1 in 0..W {
                    for j2 in (j1 + 1)..=W {
                        let minimum = arr[i1..i2]
                            .iter()
                            .flat_map(|row| row[j1..j2].iter().copied())
                            .fold(INF, i64::min);
                        assert_eq!(seg.query(i1..i2, j1..j2), Some(minimum));
                    }
                }
            }
        }
    }
}