libpuri/seg_tree.rs
1use super::algebra::Monoid;
2use super::util::IntoIndex;
3use std::iter::FromIterator;
4use std::ptr;
5
6/// A segment tree that supports range query.
7///
/// A segment tree is used when you want to query properties of intervals. For instance, you have
/// a list of numbers and you want to get the minimum value of a certain interval. If you compute it
/// on the fly, it requires (m - 1) comparisons for an interval of length m. A segment tree lets us
/// compute the minimum element much more efficiently.
12///
/// Each node of a segment tree represents the union of the intervals of its child nodes, and each
/// leaf node represents an interval containing only one element. The leaves are padded with the
/// identity (`id`) up to the next power of two. The following is an example segment tree of the
/// elements [1, 42, 16, 3, 5].
/// <pre>
///                               +-----+
///                               |  1  |
///                               +-----+
///                                [0, 8)
///                  _____________/      \_____________
///                 /                                  \
///              +---+                               +---+
///              | 1 |                               | 5 |
///              +---+                               +---+
///              [0, 4)                              [4, 8)
///         ____/      \____                    ____/      \____
///        /                \                  /                \
///     +---+             +---+             +---+             +----+
///     | 1 |             | 3 |             | 5 |             | id |
///     +---+             +---+             +---+             +----+
///     [0, 2)            [2, 4)            [4, 6)            [6, 8)
///    /      \          /      \          /      \          /      \
///   /        \        /        \        /        \        /        \
/// +---+    +----+   +----+   +---+    +---+    +----+   +----+   +----+
/// | 1 |    | 42 |   | 16 |   | 3 |    | 5 |    | id |   | id |   | id |
/// +---+    +----+   +----+   +---+    +---+    +----+   +----+   +----+
/// [0, 1)   [1, 2)   [2, 3)   [3, 4)   [4, 5)   [5, 6)   [6, 7)   [7, 8)
/// </pre>
40///
/// When you update an element, the change propagates from the leaf up to the root. If we update
/// 16 to 2, its parent [2, 4) changes from 3 to 2, while the [0, 4) node and the root keep the
/// value 1, since 1 is still less than 2.
44///
/// When querying, it visits non-overlapping nodes within the interval and computes the minimum
/// among those nodes. For example, if we query the minimum element in the interval [1, 4), it
/// first visits the [1, 2) node and then the [2, 4) node, and computes min(42, 3), which is 3.
/// Note that it visits at most two nodes at each height of the tree, hence the time complexity
/// O(log(n)).
49///
50/// # Use a segment tree when:
51/// - You want to efficiently query on an interval property.
52/// - You only update one element at a time.
53///
54/// # Requirements
/// Here _the operation_ is how we get the interval property of a parent node from its child
/// nodes. For instance, the minimum element within the interval [0, 4) is the minimum of the
/// minima of the intervals [0, 2) and [2, 4), so `min` is the operation.
58/// - The interval property has an identity with respect to the operation.
/// - The operation is [associative][1].
/// - The interval property of the union of two adjacent, disjoint intervals is the result of
///   performing the operation on the interval properties of the two intervals.
62///
63/// [1]: https://en.wikipedia.org/wiki/Associative_property
64///
/// In our example, every element of `i32` is less than or equal to `i32::MAX`, so `i32::MAX` is
/// the identity. And `min(a, min(b, c)) == min(min(a, b), c)`, so `min` is associative.
/// And if the minima of [a1, a2, ... , an] and [b1, b2, ... , bn] are a and b respectively, then
/// the minimum of [a1, a2, ..., an, b1, b2, ... , bn] is min(a, b).
/// Therefore we can use a segment tree to efficiently query the minimum element of an interval.
70///
71/// To capture the requirements in the Rust programming language,
72/// a segment tree requires the elements to implement the [`Monoid`] trait.
73///
74/// # Performance
75/// Given n elements, it computes the interval property in O(log(n)) at the expense of O(log(n))
76/// update time.
77///
78/// If we were to store the elements with `Vec`, it would take O(m) for length m interval query
79/// and O(1) to update.
80///
81/// # Examples
82/// ```
83/// use libpuri::{Monoid, SegTree};
84///
85/// // We'll use the segment tree to compute interval sum.
86/// #[derive(Clone, Debug, PartialEq, Eq)]
87/// struct Sum(i64);
88///
89/// impl Monoid for Sum {
90/// const ID: Self = Sum(0);
91/// fn op(&self, rhs: &Self) -> Self { Sum(self.0 + rhs.0) }
92/// }
93///
94/// // Segment tree can be initialized from an iterator of the monoid
95/// let mut seg_tree: SegTree<Sum> = [1, 2, 3, 4, 5].iter().map(|&n| Sum(n)).collect();
96///
97/// // [1, 2, 3, 4]
98/// assert_eq!(seg_tree.get(0..4), Sum(10));
99///
100/// // [1, 2, 42, 4, 5]
101/// seg_tree.set(2, Sum(42));
102///
103/// // [1, 2, 42, 4]
104/// assert_eq!(seg_tree.get(0..4), Sum(49));
105/// ```
// TODO(yujingaya): Remove the identity requirement by implementing the tree as a complete binary
// tree instead of a perfect one; the element trait could then be a semigroup instead of a monoid.
// Reference: https://codeforces.com/blog/entry/18051
// For the sake of simplicity, it is left this way for now.
//
// Storage layout: an implicit binary tree in a `Vec`. Slot 1 is the root, the children of slot
// `i` are `2 * i` and `2 * i + 1`, the leaves occupy the upper half of the vector, and slot 0 is
// unused.
//
// The `Monoid + Clone` bounds live on the `impl` blocks that need them rather than on the type
// itself (Rust API guidelines, C-STRUCT-BOUNDS).
#[derive(Debug)]
pub struct SegTree<M>(Vec<M>);
114
115impl<M: Monoid + Clone> SegTree<M> {
116 fn size(&self) -> usize {
117 self.0.len() / 2
118 }
119}
120
121impl<M: Monoid + Clone> SegTree<M> {
122 /// Constructs a new segment tree with at least given number of interval propertiess can be stored.
123 ///
124 /// The segment tree will be initialized with the identity elements.
125 ///
126 /// # Complexity
127 /// O(n).
128 ///
129 /// If you know the initial elements in advance, [`from_iter_sized()`](SegTree::from_iter_sized) should be preferred over
130 /// `new()`.
131 ///
132 /// Initializing with the identity elements and updating n elements will tax you O(nlog(n)),
133 /// whereas `from_iter_sized()` is O(n) by computing the interval properties only once.
134 ///
135 /// # Examples
136 /// ```
137 /// use libpuri::{Monoid, SegTree};
138 ///
139 /// #[derive(Clone, Debug, PartialEq, Eq)]
140 /// struct MinMax(i64, i64);
141 ///
142 /// impl Monoid for MinMax {
143 /// const ID: Self = Self(i64::MAX, i64::MIN);
144 /// fn op(&self, rhs: &Self) -> Self { Self(self.0.min(rhs.0), self.1.max(rhs.1)) }
145 /// }
146 ///
147 /// // ⚠️ Use `from_iter_sized` whenever possible.
148 /// // See the Complexity paragraph above for how it differs.
149 /// let mut seg_tree: SegTree<MinMax> = SegTree::new(4);
150 ///
151 /// seg_tree.set(0, MinMax(1, 1));
152 /// assert_eq!(seg_tree.get(0..4), MinMax(1, 1));
153 ///
154 /// seg_tree.set(3, MinMax(4, 4));
155 /// assert_eq!(seg_tree.get(0..2), MinMax(1, 1));
156 /// assert_eq!(seg_tree.get(2..4), MinMax(4, 4));
157 ///
158 /// seg_tree.set(2, MinMax(3, 3));
159 /// assert_eq!(seg_tree.get(0..3), MinMax(1, 3));
160 /// assert_eq!(seg_tree.get(0..4), MinMax(1, 4));
161 /// ```
162 pub fn new(size: usize) -> Self {
163 SegTree(vec![M::ID; size.next_power_of_two() * 2])
164 }
165
166 /// Constructs a new segment tree with given interval properties.
167 ///
168 /// # Complexity
169 /// O(n).
170 ///
171 /// # Examples
172 /// ```
173 /// use libpuri::{Monoid, SegTree};
174 ///
175 /// #[derive(Clone, Debug, PartialEq, Eq)]
176 /// struct Sum(i64);
177 ///
178 /// impl Monoid for Sum {
179 /// const ID: Self = Self(0);
180 /// fn op(&self, rhs: &Self) -> Self { Self(self.0 + rhs.0) }
181 /// }
182 ///
183 /// let mut seg_tree: SegTree<Sum> = SegTree::from_iter_sized(
184 /// [1, 2, 3, 42].iter().map(|&i| Sum(i)),
185 /// 4
186 /// );
187 ///
188 /// assert_eq!(seg_tree.get(0..4), Sum(48));
189 ///
190 /// // [1, 2, 3, 0]
191 /// seg_tree.set(3, Sum(0));
192 /// assert_eq!(seg_tree.get(0..2), Sum(3));
193 /// assert_eq!(seg_tree.get(2..4), Sum(3));
194 ///
195 /// // [1, 2, 2, 0]
196 /// seg_tree.set(2, Sum(2));
197 /// assert_eq!(seg_tree.get(0..3), Sum(5));
198 /// assert_eq!(seg_tree.get(0..4), Sum(5));
199 /// ```
200 pub fn from_iter_sized<I: IntoIterator<Item = M>>(iter: I, size: usize) -> Self {
201 let mut iter = iter.into_iter();
202 let size = size.next_power_of_two();
203 let mut v = Vec::with_capacity(size * 2);
204
205 let v_ptr: *mut M = v.as_mut_ptr();
206
207 unsafe {
208 v.set_len(size * 2);
209
210 for i in 0..size {
211 ptr::write(
212 v_ptr.add(size + i),
213 if let Some(m) = iter.next() { m } else { M::ID },
214 );
215 }
216
217 for i in (1..size).rev() {
218 ptr::write(v_ptr.add(i), v[i * 2].op(&v[i * 2 + 1]));
219 }
220 }
221
222 SegTree(v)
223 }
224
225 /// Queries on the given interval.
226 ///
227 /// Note that any [`RangeBounds`](std::ops::RangeBounds) can be used including
228 /// `..`, `a..`, `..b`, `..=c`, `d..e`, or `f..=g`.
229 /// You can just `seg_tree.get(..)` to get the interval property of the entire elements and
230 /// `seg_tree.get(a)` to get a specific element.
231 /// # Examples
232 /// ```
233 /// # use libpuri::{Monoid, SegTree};
234 /// # #[derive(Clone, Debug, PartialEq, Eq)]
235 /// # struct MinMax(i64, i64);
236 /// # impl Monoid for MinMax {
237 /// # const ID: Self = Self(i64::MAX, i64::MIN);
238 /// # fn op(&self, rhs: &Self) -> Self { Self(self.0.min(rhs.0), self.1.max(rhs.1)) }
239 /// # }
240 /// let mut seg_tree: SegTree<MinMax> = SegTree::new(4);
241 ///
242 /// assert_eq!(seg_tree.get(..), MinMax::ID);
243 ///
244 /// seg_tree.set(0, MinMax(42, 42));
245 ///
246 /// assert_eq!(seg_tree.get(..), MinMax(42, 42));
247 /// assert_eq!(seg_tree.get(1..), MinMax::ID);
248 /// assert_eq!(seg_tree.get(0), MinMax(42, 42));
249 /// ```
250 pub fn get<R>(&self, range: R) -> M
251 where
252 R: IntoIndex,
253 {
254 let (mut start, mut end) = range.into_index(self.size());
255 start += self.size();
256 end += self.size();
257
258 let mut m = M::ID;
259 while start < end {
260 if start % 2 == 1 {
261 m = self.0[start].op(&m);
262 start += 1;
263 }
264
265 if end % 2 == 1 {
266 end -= 1;
267 m = self.0[end].op(&m);
268 }
269
270 start /= 2;
271 end /= 2;
272 }
273
274 m
275 }
276
277 /// Sets an element with index i to a value m. It propagates its update to its parent.
278 ///
279 /// It takes O(log(n)) to propagate the update as the height of the tree is log(n).
280 ///
281 /// # Examples
282 /// ```
283 /// # use libpuri::{Monoid, SegTree};
284 /// # #[derive(Clone, Debug, PartialEq, Eq)]
285 /// # struct MinMax(i64, i64);
286 /// # impl Monoid for MinMax {
287 /// # const ID: Self = Self(i64::MAX, i64::MIN);
288 /// # fn op(&self, rhs: &Self) -> Self { Self(self.0.min(rhs.0), self.1.max(rhs.1)) }
289 /// # }
290 /// let mut seg_tree: SegTree<MinMax> = SegTree::new(4);
291 ///
292 /// seg_tree.set(0, MinMax(4, 4));
293 /// ```
294 pub fn set(&mut self, mut i: usize, m: M) {
295 i += self.size();
296 self.0[i] = m;
297
298 while i > 1 {
299 i /= 2;
300 self.0[i] = self.0[i * 2].op(&self.0[i * 2 + 1]);
301 }
302 }
303}
304
305/// You can `collect` into a segment tree.
306impl<M: Monoid + Clone> FromIterator<M> for SegTree<M> {
307 fn from_iter<I: IntoIterator<Item = M>>(iter: I) -> Self {
308 let v: Vec<M> = iter.into_iter().collect();
309 let len = v.len();
310
311 SegTree::from_iter_sized(v, len)
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Interval sums over `[1, 2, 3, 4, 5]`, with point updates interleaved between queries.
    #[test]
    fn interval_sum() {
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Sum(i64);

        impl Monoid for Sum {
            const ID: Self = Sum(0);
            fn op(&self, rhs: &Self) -> Self {
                Sum(self.0 + rhs.0)
            }
        }

        let mut seg_tree: SegTree<Sum> = (1..=5).map(Sum).collect();

        // [1, 2, 6, 4, 5]
        seg_tree.set(2, Sum(6));
        // 2 + 6 + 4 + 5
        assert_eq!(seg_tree.get(1..=4).0, 17);

        // [1, 2, 6, 4, 2]
        seg_tree.set(4, Sum(2));
        // 6 + 4 + 2
        assert_eq!(seg_tree.get(2..=4).0, 12);
    }
}