// competitive_programming_rs/data_structure/fenwick_tree.rs

pub mod fenwick_tree {
    /// `FenwickTree` is a data structure that can efficiently update elements
    /// and calculate prefix sums in a table of numbers.
    ///
    /// Indices are 0-based: a tree created with `new(size, ..)` holds the
    /// elements `0..size`, all starting at zero. [`FenwickTree::add`] and
    /// [`FenwickTree::sum_one`] each touch `O(log size)` cells.
    ///
    /// [https://en.wikipedia.org/wiki/Fenwick_tree](https://en.wikipedia.org/wiki/Fenwick_tree)
    pub struct FenwickTree<T, F> {
        /// Internal tree array: `data[i]` holds the sum of the elements in
        /// `(i & (i + 1))..=i`.
        data: Vec<T>,
        /// Produces the additive identity ("zero") of `T`.
        initialize: F,
    }

    impl<T, F> FenwickTree<T, F>
    where
        T: Copy + std::ops::AddAssign + std::ops::Sub<Output = T>,
        F: Fn() -> T,
    {
        /// Constructs a new `FenwickTree` holding `size` elements.
        ///
        /// `initialize` must return the additive identity of `T` (e.g. `|| 0`).
        /// Every cell starts with that value, and every query starts from it.
        pub fn new(size: usize, initialize: F) -> FenwickTree<T, F> {
            FenwickTree {
                data: vec![initialize(); size],
                initialize,
            }
        }

        /// Adds `value` to the element at index `k`.
        ///
        /// # Panics
        ///
        /// Panics if `k >= size`.
        pub fn add(&mut self, k: usize, value: T) {
            let len = self.data.len();
            assert!(k < len, "Index {} is out of range for size {}", k, len);
            let mut x = k;
            while x < len {
                self.data[x] += value;
                // Move to the next cell whose covered range also contains `k`.
                x |= x + 1;
            }
        }

        /// Returns the sum of the range `[l, r)`.
        ///
        /// # Panics
        ///
        /// Panics if `l > r` or `r > size`.
        pub fn sum(&self, l: usize, r: usize) -> T {
            assert!(l <= r, "Invalid range [{}, {})", l, r);
            self.sum_one(r) - self.sum_one(l)
        }

        /// Returns the sum of the range `[0, k)`.
        ///
        /// # Panics
        ///
        /// Panics if `k > size`.
        pub fn sum_one(&self, k: usize) -> T {
            let len = self.data.len();
            assert!(
                k <= len,
                "Cannot calculate for range [0, {}) with size {}",
                k,
                len
            );
            let mut result = (self.initialize)();
            // `j` is the 1-based position just past the block being added.
            // Keeping it in `usize` avoids truncating large indices through a
            // signed cast. Clearing the lowest set bit of `j` jumps to the end
            // of the next disjoint block.
            let mut j = k;
            while j > 0 {
                result += self.data[j - 1];
                j &= j - 1;
            }
            result
        }
    }
}

#[cfg(test)]
mod test {
    use super::fenwick_tree::FenwickTree;
    use rand::{thread_rng, Rng};

    /// Randomly updates a tree and, after each update, checks every prefix sum
    /// and one random range `[l, r)` against a naive array.
    #[test]
    fn random_array() {
        let n = 1000;
        let mut rng = thread_rng();
        let mut bit = FenwickTree::new(n, || 0);
        let mut v = vec![0; n];

        for _ in 0..10000 {
            let value = rng.gen_range(0, 1000);
            let k = rng.gen_range(0, n);
            v[k] += value;
            bit.add(k, value);

            let mut sum = 0;
            for (i, &x) in v.iter().enumerate() {
                sum += x;
                assert_eq!(sum, bit.sum(0, i + 1));
            }

            // Prefix checks above never exercise `sum(l, r)` with `l > 0`.
            // `l <= n < n + 1` keeps both half-open ranges non-empty.
            let l = rng.gen_range(0, n + 1);
            let r = rng.gen_range(l, n + 1);
            assert_eq!(v[l..r].iter().sum::<i32>(), bit.sum(l, r));
        }
    }
}