ac_library/
fenwicktree.rs

1use std::ops::{Bound, RangeBounds};
2
/// A Fenwick tree (binary indexed tree) over `n` logical elements
/// `data[0..n]`, supporting point updates and prefix / range sums in
/// `O(log n)` steps each.
///
/// Every element starts out equal to the value `e` given to
/// [`FenwickTree::new`].
///
/// Reference: <https://en.wikipedia.org/wiki/Fenwick_tree>
#[derive(Clone, Debug)]
pub struct FenwickTree<T> {
    /// Number of logical elements.
    n: usize,
    /// Internal storage. With 1-based position `j = i + 1`, `ary[i]` holds the
    /// sum of `data[j - lowbit(j) ..= i]`, where `lowbit(j) = j & j.wrapping_neg()`.
    ary: Vec<T>,
    /// Starting value of every element and of every sum; expected to be the
    /// additive identity of `T` (e.g. `0`).
    e: T,
}

impl<T: Clone + std::ops::AddAssign<T>> FenwickTree<T> {
    /// Creates a tree of `n` elements, each initialised to `e`.
    ///
    /// `e` is also used as the starting value of every sum, so it should be
    /// the additive identity of `T` (e.g. `0`).
    pub fn new(n: usize, e: T) -> Self {
        FenwickTree {
            n,
            ary: vec![e.clone(); n],
            e,
        }
    }

    /// Returns the prefix sum `data[0] + ... + data[idx - 1]`.
    /// It returns `e` when `idx == 0`.
    ///
    /// # Panics
    ///
    /// Panics if `idx > n`.
    pub fn accum(&self, mut idx: usize) -> T {
        assert!(
            idx <= self.n,
            "FenwickTree::accum: idx {} out of range (n = {})",
            idx,
            self.n
        );
        let mut sum = self.e.clone();
        while idx > 0 {
            sum += self.ary[idx - 1].clone();
            // Clearing the lowest set bit jumps to the next disjoint block.
            idx &= idx - 1;
        }
        sum
    }

    /// Performs `data[idx] += val`.
    ///
    /// `val` may be any type `U` with `T: AddAssign<U>` (for example `&T`).
    ///
    /// # Panics
    ///
    /// Panics if `idx >= n`. Previously such an update was silently dropped.
    pub fn add<U: Clone>(&mut self, idx: usize, val: U)
    where
        T: std::ops::AddAssign<U>,
    {
        assert!(
            idx < self.n,
            "FenwickTree::add: idx {} out of range (n = {})",
            idx,
            self.n
        );
        // Switch to 1-based positions so the lowbit arithmetic works.
        let mut pos = idx + 1;
        while pos <= self.n {
            self.ary[pos - 1] += val.clone();
            // Adding the lowest set bit moves to the next node that covers `idx`.
            pos += pos & pos.wrapping_neg();
        }
    }

    /// Returns `data[l] + ... + data[r - 1]`, where `[l, r)` is the half-open
    /// range resolved from `range`.
    ///
    /// `range` can be any of `l..r`, `l..=r`, `l..`, `..r`, `..=r`, `..`, or a
    /// `(Bound, Bound)` pair.
    ///
    /// # Panics
    ///
    /// Panics unless the resolved range satisfies `l <= r <= n`.
    pub fn sum<R>(&self, range: R) -> T
    where
        T: std::ops::Sub<Output = T>,
        R: RangeBounds<usize>,
    {
        // `saturating_add` turns `usize::MAX` bounds into an out-of-range value
        // that the assert below rejects, instead of wrapping to 0 in release.
        let r = match range.end_bound() {
            Bound::Included(&r) => r.saturating_add(1),
            Bound::Excluded(&r) => r,
            Bound::Unbounded => self.n,
        };
        let l = match range.start_bound() {
            Bound::Included(&l) => l,
            Bound::Excluded(&l) => l.saturating_add(1),
            Bound::Unbounded => 0,
        };
        assert!(
            l <= r && r <= self.n,
            "FenwickTree::sum: range [{}, {}) out of bounds (n = {})",
            l,
            r,
            self.n
        );
        if l == 0 {
            // A prefix sum needs only one traversal and no subtraction.
            return self.accum(r);
        }
        self.accum(r) - self.accum(l)
    }
}
57
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Bound::*;

    /// Range sums over `[1, 2, 3, 4, 5]` with every supported range form.
    #[test]
    fn fenwick_tree_works() {
        let mut bit = FenwickTree::new(5, 0i64);
        // [1, 2, 3, 4, 5]
        for i in 0..5 {
            bit.add(i, i as i64 + 1);
        }
        assert_eq!(bit.sum(0..5), 15);
        assert_eq!(bit.sum(0..4), 10);
        assert_eq!(bit.sum(1..3), 5);

        assert_eq!(bit.sum(..), 15);
        assert_eq!(bit.sum(..2), 3);
        assert_eq!(bit.sum(..=2), 6);
        assert_eq!(bit.sum(1..), 14);
        assert_eq!(bit.sum(1..=3), 9);
        assert_eq!(bit.sum((Excluded(0), Included(2))), 5);
    }

    /// Covers a fresh tree, repeated updates to one index, empty ranges,
    /// a trailing unbounded range, and a reference-typed `add` argument.
    #[test]
    fn fenwick_tree_edge_cases() {
        let mut bit = FenwickTree::new(4, 0i64);
        // A fresh tree sums to the identity everywhere.
        assert_eq!(bit.sum(..), 0);
        assert_eq!(bit.accum(0), 0);

        // Repeated updates to the same index accumulate.
        bit.add(2, 3i64);
        bit.add(2, -1i64);
        bit.add(0, 10i64);
        assert_eq!(bit.accum(4), 12);
        assert_eq!(bit.sum(2..=2), 2);

        // Empty ranges, including one ending at `n`, sum to zero.
        assert_eq!(bit.sum(1..1), 0);
        assert_eq!(bit.sum(3..), 0);

        // `add` accepts any `U` with `T: AddAssign<U>`, such as `&T`.
        bit.add(3, &5i64);
        assert_eq!(bit.sum(3..4), 5);
    }

    /// A zero-length tree still answers the full and empty prefix queries.
    #[test]
    fn fenwick_tree_empty() {
        let bit = FenwickTree::new(0, 0u32);
        assert_eq!(bit.sum(..), 0);
        assert_eq!(bit.accum(0), 0);
    }
}