// ac_library/fenwicktree.rs

use std::ops::{Bound, RangeBounds};

/// A Fenwick tree (binary indexed tree) over `n` logical elements `data[0..n]`.
///
/// Supports a point update (`data[idx] += val`) and prefix / range sums, each
/// costing `O(log n)` additions. Indices are 0-based.
///
/// Reference: <https://en.wikipedia.org/wiki/Fenwick_tree>
#[derive(Clone, Debug)]
pub struct FenwickTree<T> {
    /// Number of logical elements.
    n: usize,
    /// Tree storage: `ary[i]` holds the sum of `data[(i + 1) - lowbit(i + 1) ..= i]`,
    /// where `lowbit(x) = x & x.wrapping_neg()`.
    ary: Vec<T>,
    /// Identity element; seeds every accumulation.
    e: T,
}

impl<T: Clone + std::ops::AddAssign<T>> FenwickTree<T> {
    /// Creates a tree of `n` elements, each initialised to `e`.
    ///
    /// `e` should be the identity for `+=` (e.g. `0`); `accum(0)` returns it.
    pub fn new(n: usize, e: T) -> Self {
        FenwickTree {
            n,
            ary: vec![e.clone(); n],
            e,
        }
    }

    /// Returns the prefix sum `data[0] + ... + data[idx - 1]` (just `e` when `idx == 0`).
    ///
    /// # Panics
    ///
    /// Panics if `idx > n`.
    pub fn accum(&self, mut idx: usize) -> T {
        assert!(
            idx <= self.n,
            "FenwickTree::accum: index {} out of range for length {}",
            idx,
            self.n
        );
        let mut sum = self.e.clone();
        while idx > 0 {
            sum += self.ary[idx - 1].clone();
            // Clearing the lowest set bit jumps to the preceding disjoint node.
            idx &= idx - 1;
        }
        sum
    }

    /// Performs `data[idx] += val`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= n`.
    pub fn add<U: Clone>(&mut self, mut idx: usize, val: U)
    where
        T: std::ops::AddAssign<U>,
    {
        let n = self.n;
        // Without this check an out-of-range index would silently skip the loop
        // and the update would be lost.
        assert!(
            idx < n,
            "FenwickTree::add: index {} out of range for length {}",
            idx,
            n
        );
        // Work in 1-based positions so `idx & -idx` isolates the lowest set bit.
        idx += 1;
        while idx <= n {
            self.ary[idx - 1] += val.clone();
            // Move to the next node whose covered span includes this position.
            idx += idx & idx.wrapping_neg();
        }
    }

    /// Returns `data[l] + ... + data[r - 1]`, where `[l, r)` is the half-open span
    /// described by `range` (`l..r`, `l..=r`, `..r`, `l..`, `..`, or explicit `Bound`s).
    ///
    /// # Panics
    ///
    /// Panics if the resolved start exceeds the resolved end, or the end exceeds `n`.
    pub fn sum<R>(&self, range: R) -> T
    where
        T: std::ops::Sub<Output = T>,
        R: RangeBounds<usize>,
    {
        let r = match range.end_bound() {
            Bound::Included(r) => r + 1,
            Bound::Excluded(r) => *r,
            Bound::Unbounded => self.n,
        };
        let l = match range.start_bound() {
            Bound::Included(l) => *l,
            Bound::Excluded(l) => l + 1,
            // A pure prefix sum needs no subtraction.
            Bound::Unbounded => return self.accum(r),
        };
        assert!(
            l <= r,
            "FenwickTree::sum: range start {} is greater than end {}",
            l,
            r
        );
        self.accum(r) - self.accum(l)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Bound::*;

    /// Builds a tree holding `[1, 2, 3, 4, 5]`.
    fn one_to_five() -> FenwickTree<i64> {
        let mut bit = FenwickTree::new(5, 0i64);
        for i in 0..5 {
            bit.add(i, i as i64 + 1);
        }
        bit
    }

    #[test]
    fn fenwick_tree_works() {
        let bit = one_to_five();
        assert_eq!(bit.sum(0..5), 15);
        assert_eq!(bit.sum(0..4), 10);
        assert_eq!(bit.sum(1..3), 5);

        assert_eq!(bit.sum(..), 15);
        assert_eq!(bit.sum(..2), 3);
        assert_eq!(bit.sum(..=2), 6);
        assert_eq!(bit.sum(1..), 14);
        assert_eq!(bit.sum(1..=3), 9);
        assert_eq!(bit.sum((Excluded(0), Included(2))), 5);
    }

    #[test]
    fn accum_returns_prefix_sums() {
        let bit = one_to_five();
        let expected = [0, 1, 3, 6, 10, 15];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(bit.accum(idx), want);
        }
    }

    #[test]
    fn empty_ranges_and_empty_tree_yield_identity() {
        let bit = one_to_five();
        assert_eq!(bit.sum(0..0), 0);
        assert_eq!(bit.sum(3..3), 0);
        assert_eq!(bit.sum(5..), 0);

        let empty = FenwickTree::new(0, 0i64);
        assert_eq!(empty.sum(..), 0);
        assert_eq!(empty.accum(0), 0);
    }

    #[test]
    fn repeated_and_negative_updates_accumulate() {
        let mut bit = one_to_five();
        bit.add(2, 10i64);
        bit.add(2, -4i64);
        // data is now [1, 2, 9, 4, 5]
        assert_eq!(bit.sum(2..3), 9);
        assert_eq!(bit.sum(..), 21);
    }

    #[test]
    #[should_panic]
    fn accum_past_end_panics() {
        let bit = one_to_five();
        let _ = bit.accum(6);
    }
}