1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/// A Fenwick tree (binary indexed tree) supporting point updates and
/// prefix/range sums over `n` elements in `O(log n)` time each.
///
/// Public indices are 0-based. Internally `nodes` is 1-based: for trees built
/// with `new`/`from` it has `n + 1` entries and slot 0 is never read by the
/// query loops. Node `i` holds the sum of the elements at 1-based positions
/// `i - lowbit(i) + 1 ..= i`. A `Default` tree is empty (`n == 0`).
#[derive(Debug, Default)]
struct BinaryIndexedTree<T> {
    /// Number of logical elements; valid indices are `0..n`.
    n: usize,
    /// 1-based tree storage (slot 0 unused).
    nodes: Vec<T>,
}

impl<T> BinaryIndexedTree<T> {
    /// Returns the lowest set bit of `i` (the classic `i & -i`).
    ///
    /// Computed on `usize` directly, so there are no signed casts that would
    /// truncate indices of large trees.
    fn lowbit(i: usize) -> usize {
        i & i.wrapping_neg()
    }
}

impl<T> From<Vec<T>> for BinaryIndexedTree<T>
where
    T: Default + Clone + std::ops::AddAssign,
{
    /// Builds a tree whose element `i` is `v[i]`, in `O(n)`.
    ///
    /// Every child `j` of node `i` satisfies `j < i`. Visiting nodes in
    /// ascending order therefore completes each node before it is folded
    /// into its parent `i + lowbit(i)`.
    fn from(v: Vec<T>) -> Self {
        let n = v.len();
        let mut nodes = Vec::with_capacity(n + 1);
        nodes.push(T::default());
        nodes.extend(v);
        for i in 1..=n {
            let parent = i + Self::lowbit(i);
            if parent <= n {
                let child = nodes[i].clone();
                nodes[parent] += child;
            }
        }
        Self { n, nodes }
    }
}

impl<T> BinaryIndexedTree<T>
where
    T: Default + Clone + std::ops::AddAssign + std::ops::Sub<Output = T>,
{
    /// Creates a tree of `n` elements, all equal to `T::default()`.
    pub fn new(n: usize) -> BinaryIndexedTree<T> {
        BinaryIndexedTree {
            n,
            nodes: vec![T::default(); n + 1],
        }
    }

    /// Builds a tree whose element `i` is `v[i]`.
    ///
    /// This is equivalent to the `From<Vec<T>>` impl. It is kept as an
    /// inherent constructor so `BinaryIndexedTree::from(v)` keeps working.
    fn from(v: Vec<T>) -> BinaryIndexedTree<T> {
        <Self as From<Vec<T>>>::from(v)
    }

    /// Adds `delta` to the element at 0-based index `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= n`. An out-of-range update is a caller bug and must
    /// not be silently dropped.
    pub fn update(&mut self, idx: usize, delta: T) {
        assert!(
            idx < self.n,
            "BinaryIndexedTree::update: index {} out of range (len {})",
            idx,
            self.n
        );
        let mut i = idx + 1;
        while i <= self.n {
            self.nodes[i] += delta.clone();
            i += Self::lowbit(i);
        }
    }

    /// Returns the sum of elements `0..=idx`.
    ///
    /// # Panics
    /// Panics if `idx >= n`.
    pub fn sum_to(&self, idx: usize) -> T {
        assert!(
            idx < self.n,
            "BinaryIndexedTree::sum_to: index {} out of range (len {})",
            idx,
            self.n
        );
        let mut i = idx + 1;
        let mut sum = T::default();
        while i > 0 {
            sum += self.nodes[i].clone();
            i -= Self::lowbit(i);
        }
        sum
    }

    /// Returns the sum of elements `idx..=end` (both bounds inclusive).
    ///
    /// # Panics
    /// Panics if `idx > end` or `end >= n`.
    pub fn sum_of_range(&self, idx: usize, end: usize) -> T {
        assert!(
            idx <= end,
            "BinaryIndexedTree::sum_of_range: start {} is after end {}",
            idx,
            end
        );
        let upper = self.sum_to(end);
        if idx == 0 {
            upper
        } else {
            upper - self.sum_to(idx - 1)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shared sample input: the values 1..=8.
    fn sample() -> Vec<i32> {
        vec![1, 2, 3, 4, 5, 6, 7, 8]
    }

    #[test]
    fn test_bit_sum() {
        let mut bit = BinaryIndexedTree::from(sample());
        let expected = [1, 3, 6, 10, 15, 21, 28, 36];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(bit.sum_to(idx), want, "sum_to({})", idx);
        }
    }

    #[test]
    fn test_bit_sum_of_range() {
        let mut bit = BinaryIndexedTree::from(sample());
        assert_eq!(bit.sum_of_range(0, 0), 1);
        assert_eq!(bit.sum_of_range(1, 1), 2);
        assert_eq!(bit.sum_of_range(0, 1), 3);
        assert_eq!(bit.sum_of_range(0, 2), 6);
        assert_eq!(bit.sum_of_range(2, 3), 7);
        assert_eq!(bit.sum_of_range(0, 7), 36);
        assert_eq!(bit.sum_of_range(6, 7), 15);
    }

    #[test]
    fn test_bit_update() {
        let mut bit = BinaryIndexedTree::new(8);
        for (idx, value) in sample().into_iter().enumerate() {
            bit.update(idx, value);
        }
        assert_eq!(bit.sum_to(7), 36);

        // Element 3 goes from 4 to 14; only prefixes covering it change.
        bit.update(3, 10);
        assert_eq!(bit.sum_to(2), 6);
        assert_eq!(bit.sum_to(3), 20);
        assert_eq!(bit.sum_of_range(3, 4), 19);
        assert_eq!(bit.sum_to(7), 46);
    }

    #[test]
    fn test_bit_matches_naive_sums() {
        // Exercise many tree shapes, including non-powers-of-two and a
        // single element, against a brute-force sum with negative values.
        for len in 1..=17usize {
            let values: Vec<i64> = (0..len as i64).map(|x| (x * 5 % 7) - 3).collect();
            let mut bit = BinaryIndexedTree::from(values.clone());
            for start in 0..len {
                for end in start..len {
                    let naive: i64 = values[start..=end].iter().sum();
                    assert_eq!(bit.sum_of_range(start, end), naive);
                }
            }
        }
    }
}