libpuri/
seg_tree.rs

1use super::algebra::Monoid;
2use super::util::IntoIndex;
3use std::iter::FromIterator;
4use std::ptr;
5
/// A segment tree that supports range queries.
///
/// A segment tree is used when you want to query a property of an interval. For instance, you have
/// a list of numbers and you want the minimum value of a certain interval. Computing it on the fly
/// takes (m - 1) comparisons for an interval of length m. A segment tree computes the minimum
/// element of any interval much more efficiently.
///
/// Each node of a segment tree represents the union of the intervals of its child nodes, and each
/// leaf node represents an interval containing only one element. The following is an example
/// segment tree of the elements [1, 42, 16, 3, 5].
/// <pre>
///                          +-----+
///                          |  1  |
///                          +-----+
///                          [0, 8)
///                     /               \
///                    /                 \
///            +---+                         +---+
///            | 1 |                         | 5 |
///            +---+                         +---+
///            [0, 4)                        [4, 8)
///           /     \                       /     \
///          /       \                     /       \
///      +---+       +---+             +---+       +----+
///      | 1 |       | 3 |             | 5 |       | id |
///      +---+       +---+             +---+       +----+
///      [0, 2)      [2, 4)            [4, 6)      [6, 8)
///     /    |       |    \           /    |       |    \
///    /     |       |     \         /     |       |     \
/// +---+  +----+  +----+  +---+  +---+  +----+  +----+  +----+
/// | 1 |  | 42 |  | 16 |  | 3 |  | 5 |  | id |  | id |  | id |
/// +---+  +----+  +----+  +---+  +---+  +----+  +----+  +----+
/// [0, 1) [1, 2)  [2, 3)  [3, 4) [4, 5) [5, 6)  [6, 7)  [7, 8)
/// </pre>
///
/// When you update an element, the change propagates from the leaf up to the root. If we update
/// 16 to 2, the parent node [2, 4) changes from 3 to 2. The values of [0, 4) and of the root stay
/// 1, because 1 is less than 2.
///
/// When querying, it visits non-overlapping nodes that together cover the interval and combines
/// their values. For example, to query the minimum element of the interval [1, 4), it first
/// visits the [1, 2) node and then the [2, 4) node, and computes min(42, 3), which is 3. It
/// visits at most two nodes at each level of the tree, hence the O(log(n)) time complexity.
///
/// # Use a segment tree when:
/// - You want to efficiently query an interval property.
/// - You only update one element at a time.
///
/// # Requirements
/// Here _the operation_ is how we get the interval property of a parent node from its child
/// nodes. For instance, the minimum element of the interval [0, 4) is the minimum of the minima
/// of the intervals [0, 2) and [2, 4), so `min` is the operation.
/// - The interval property has an identity with respect to the operation.
/// - The operation is [associative][1].
/// - The interval property of the union of two adjacent intervals is the result of performing
///   the operation on the interval properties of the left and the right interval, in that order.
///
/// [1]: https://en.wikipedia.org/wiki/Associative_property
///
/// In our example, every element of `i32` is less than or equal to `i32::MAX`, so `i32::MAX` is
/// the identity. Also, `min(a, min(b, c)) == min(min(a, b), c)`, so the operation is associative.
/// Finally, if the minima of [a1, a2, ... , an] and [b1, b2, ... , bn] are a and b respectively,
/// then the minimum of [a1, a2, ..., an, b1, b2, ... , bn] is min(a, b). Therefore we can use a
/// segment tree to efficiently query the minimum element of an interval.
///
/// To capture these requirements in Rust, a segment tree requires its elements to implement the
/// [`Monoid`] trait.
///
/// # Performance
/// Given n elements, it computes an interval property in O(log(n)) at the expense of O(log(n))
/// update time.
///
/// If we stored the elements in a plain `Vec` instead, a query on an interval of length m would
/// take O(m) and an update would take O(1).
///
/// # Examples
/// ```
/// use libpuri::{Monoid, SegTree};
///
/// // We'll use the segment tree to compute interval sum.
/// #[derive(Clone, Debug, PartialEq, Eq)]
/// struct Sum(i64);
///
/// impl Monoid for Sum {
///     const ID: Self = Sum(0);
///     fn op(&self, rhs: &Self) -> Self { Sum(self.0 + rhs.0) }
/// }
///
/// // Segment tree can be initialized from an iterator of the monoid
/// let mut seg_tree: SegTree<Sum> = [1, 2, 3, 4, 5].iter().map(|&n| Sum(n)).collect();
///
/// // [1, 2, 3, 4]
/// assert_eq!(seg_tree.get(0..4), Sum(10));
///
/// // [1, 2, 42, 4, 5]
/// seg_tree.set(2, Sum(42));
///
/// // [1, 2, 42, 4]
/// assert_eq!(seg_tree.get(0..4), Sum(49));
/// ```
// TODO(yujingaya): remove the identity requirement with a non-perfect binary tree?
// The element type could then be a semigroup.
// Reference: https://codeforces.com/blog/entry/18051
// The identity requirement could be lifted if we implemented the tree as a complete tree instead
// of a perfect tree, relaxing the trait to a semigroup. For the sake of simplicity, we leave it
// this way for now.
//
// Layout: the `Vec` holds `2 * n` nodes, where `n` is the number of leaves (a power of two).
// Leaves occupy `n..2n` and node `i` has children `2i` and `2i + 1`, so the root is node 1.
//
// No trait bounds on the struct itself: the `impl` blocks that need `Monoid + Clone` state them,
// and keeping the definition unbounded lets the derived traits apply more widely.
#[derive(Debug, Clone)]
pub struct SegTree<M>(Vec<M>);
114
115impl<M: Monoid + Clone> SegTree<M> {
116    fn size(&self) -> usize {
117        self.0.len() / 2
118    }
119}
120
121impl<M: Monoid + Clone> SegTree<M> {
122    /// Constructs a new segment tree with at least given number of interval propertiess can be stored.
123    ///
124    /// The segment tree will be initialized with the identity elements.
125    /// 
126    /// # Complexity
127    /// O(n).
128    /// 
129    /// If you know the initial elements in advance, [`from_iter_sized()`](SegTree::from_iter_sized) should be preferred over
130    /// `new()`.
131    ///
132    /// Initializing with the identity elements and updating n elements will tax you O(nlog(n)),
133    /// whereas `from_iter_sized()` is O(n) by computing the interval properties only once.
134    ///
135    /// # Examples
136    /// ```
137    /// use libpuri::{Monoid, SegTree};
138    ///
139    /// #[derive(Clone, Debug, PartialEq, Eq)]
140    /// struct MinMax(i64, i64);
141    ///
142    /// impl Monoid for MinMax {
143    ///     const ID: Self = Self(i64::MAX, i64::MIN);
144    ///     fn op(&self, rhs: &Self) -> Self { Self(self.0.min(rhs.0), self.1.max(rhs.1)) }
145    /// }
146    ///
147    /// // ⚠️ Use `from_iter_sized` whenever possible.
148    /// // See the Complexity paragraph above for how it differs.
149    /// let mut seg_tree: SegTree<MinMax> = SegTree::new(4);
150    ///
151    /// seg_tree.set(0, MinMax(1, 1));
152    /// assert_eq!(seg_tree.get(0..4), MinMax(1, 1));
153    ///
154    /// seg_tree.set(3, MinMax(4, 4));
155    /// assert_eq!(seg_tree.get(0..2), MinMax(1, 1));
156    /// assert_eq!(seg_tree.get(2..4), MinMax(4, 4));
157    ///
158    /// seg_tree.set(2, MinMax(3, 3));
159    /// assert_eq!(seg_tree.get(0..3), MinMax(1, 3));
160    /// assert_eq!(seg_tree.get(0..4), MinMax(1, 4));
161    /// ```
162    pub fn new(size: usize) -> Self {
163        SegTree(vec![M::ID; size.next_power_of_two() * 2])
164    }
165
166    /// Constructs a new segment tree with given interval properties.
167    ///
168    /// # Complexity
169    /// O(n).
170    ///
171    /// # Examples
172    /// ```
173    /// use libpuri::{Monoid, SegTree};
174    ///
175    /// #[derive(Clone, Debug, PartialEq, Eq)]
176    /// struct Sum(i64);
177    ///
178    /// impl Monoid for Sum {
179    ///     const ID: Self = Self(0);
180    ///     fn op(&self, rhs: &Self) -> Self { Self(self.0 + rhs.0) }
181    /// }
182    ///
183    /// let mut seg_tree: SegTree<Sum> = SegTree::from_iter_sized(
184    ///     [1, 2, 3, 42].iter().map(|&i| Sum(i)),
185    ///     4
186    /// );
187    ///
188    /// assert_eq!(seg_tree.get(0..4), Sum(48));
189    ///
190    /// // [1, 2, 3, 0]
191    /// seg_tree.set(3, Sum(0));
192    /// assert_eq!(seg_tree.get(0..2), Sum(3));
193    /// assert_eq!(seg_tree.get(2..4), Sum(3));
194    ///
195    /// // [1, 2, 2, 0]
196    /// seg_tree.set(2, Sum(2));
197    /// assert_eq!(seg_tree.get(0..3), Sum(5));
198    /// assert_eq!(seg_tree.get(0..4), Sum(5));
199    /// ```
200    pub fn from_iter_sized<I: IntoIterator<Item = M>>(iter: I, size: usize) -> Self {
201        let mut iter = iter.into_iter();
202        let size = size.next_power_of_two();
203        let mut v = Vec::with_capacity(size * 2);
204
205        let v_ptr: *mut M = v.as_mut_ptr();
206
207        unsafe {
208            v.set_len(size * 2);
209
210            for i in 0..size {
211                ptr::write(
212                    v_ptr.add(size + i),
213                    if let Some(m) = iter.next() { m } else { M::ID },
214                );
215            }
216
217            for i in (1..size).rev() {
218                ptr::write(v_ptr.add(i), v[i * 2].op(&v[i * 2 + 1]));
219            }
220        }
221
222        SegTree(v)
223    }
224
225    /// Queries on the given interval.
226    ///
227    /// Note that any [`RangeBounds`](std::ops::RangeBounds) can be used including
228    /// `..`, `a..`, `..b`, `..=c`, `d..e`, or `f..=g`.
229    /// You can just `seg_tree.get(..)` to get the interval property of the entire elements and
230    /// `seg_tree.get(a)` to get a specific element.
231    /// # Examples
232    /// ```
233    /// # use libpuri::{Monoid, SegTree};
234    /// # #[derive(Clone, Debug, PartialEq, Eq)]
235    /// # struct MinMax(i64, i64);
236    /// # impl Monoid for MinMax {
237    /// #     const ID: Self = Self(i64::MAX, i64::MIN);
238    /// #     fn op(&self, rhs: &Self) -> Self { Self(self.0.min(rhs.0), self.1.max(rhs.1)) }
239    /// # }
240    /// let mut seg_tree: SegTree<MinMax> = SegTree::new(4);
241    ///
242    /// assert_eq!(seg_tree.get(..), MinMax::ID);
243    ///
244    /// seg_tree.set(0, MinMax(42, 42));
245    ///
246    /// assert_eq!(seg_tree.get(..), MinMax(42, 42));
247    /// assert_eq!(seg_tree.get(1..), MinMax::ID);
248    /// assert_eq!(seg_tree.get(0), MinMax(42, 42));
249    /// ```
250    pub fn get<R>(&self, range: R) -> M
251    where
252        R: IntoIndex,
253    {
254        let (mut start, mut end) = range.into_index(self.size());
255        start += self.size();
256        end += self.size();
257
258        let mut m = M::ID;
259        while start < end {
260            if start % 2 == 1 {
261                m = self.0[start].op(&m);
262                start += 1;
263            }
264
265            if end % 2 == 1 {
266                end -= 1;
267                m = self.0[end].op(&m);
268            }
269
270            start /= 2;
271            end /= 2;
272        }
273
274        m
275    }
276
277    /// Sets an element with index i to a value m. It propagates its update to its parent.
278    ///
279    /// It takes O(log(n)) to propagate the update as the height of the tree is log(n).
280    ///
281    /// # Examples
282    /// ```
283    /// # use libpuri::{Monoid, SegTree};
284    /// # #[derive(Clone, Debug, PartialEq, Eq)]
285    /// # struct MinMax(i64, i64);
286    /// # impl Monoid for MinMax {
287    /// #     const ID: Self = Self(i64::MAX, i64::MIN);
288    /// #     fn op(&self, rhs: &Self) -> Self { Self(self.0.min(rhs.0), self.1.max(rhs.1)) }
289    /// # }
290    /// let mut seg_tree: SegTree<MinMax> = SegTree::new(4);
291    ///
292    /// seg_tree.set(0, MinMax(4, 4));
293    /// ```
294    pub fn set(&mut self, mut i: usize, m: M) {
295        i += self.size();
296        self.0[i] = m;
297
298        while i > 1 {
299            i /= 2;
300            self.0[i] = self.0[i * 2].op(&self.0[i * 2 + 1]);
301        }
302    }
303}
304
305/// You can `collect` into a segment tree.
306impl<M: Monoid + Clone> FromIterator<M> for SegTree<M> {
307    fn from_iter<I: IntoIterator<Item = M>>(iter: I) -> Self {
308        let v: Vec<M> = iter.into_iter().collect();
309        let len = v.len();
310
311        SegTree::from_iter_sized(v, len)
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Interval-sum monoid used to exercise the tree.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct Sum(i64);

    impl Monoid for Sum {
        const ID: Self = Sum(0);
        fn op(&self, rhs: &Self) -> Self {
            Sum(self.0 + rhs.0)
        }
    }

    #[test]
    fn interval_sum() {
        // [1, 2, 3, 4, 5]
        let mut seg_tree: SegTree<Sum> = (1..=5).map(Sum).collect();

        // [1, 2, 6, 4, 5] -> 2 + 6 + 4 + 5
        seg_tree.set(2, Sum(6));
        assert_eq!(seg_tree.get(1..=4), Sum(17));

        // [1, 2, 6, 4, 2] -> 6 + 4 + 2
        seg_tree.set(4, Sum(2));
        assert_eq!(seg_tree.get(2..=4), Sum(12));
    }
}
338        }
339    }
340}