1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
/// A Fenwick tree (binary indexed tree) over `n` elements.
///
/// Supports adding a delta to a single element and querying prefix or range
/// sums, each in O(log n) steps. Indices in the public API are 0-based.
#[derive(Debug, Default)]
struct BinaryIndexedTree<T> {
    /// Number of logical elements stored in the tree.
    n: usize,
    /// 1-based Fenwick array: `nodes[i]` holds the sum of the
    /// `lowest_set_bit(i)` elements ending at (1-based) position `i`.
    /// `nodes[0]` is unused padding. Trees built with `new`/`from` have
    /// `n + 1` entries; the derived `Default` value is an empty tree with none.
    nodes: Vec<T>,
}

/// Returns the value of the lowest set bit of `i` (0 when `i == 0`).
///
/// This is the step used to walk a Fenwick tree's implicit parent/child links.
fn lowest_set_bit(i: usize) -> usize {
    // Two's-complement `i & -i`, done in `usize` so no narrowing casts
    // (and no signed overflow) are involved.
    i & i.wrapping_neg()
}

impl<T> BinaryIndexedTree<T>
where
    T: Default + Clone + std::ops::AddAssign + std::ops::Sub<Output = T>,
{
    /// Creates a tree of `n` elements, all equal to `T::default()`.
    pub fn new(n: usize) -> BinaryIndexedTree<T> {
        BinaryIndexedTree {
            n,
            nodes: vec![T::default(); n + 1],
        }
    }

    /// Adds `delta` to the element at 0-based index `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= n`.
    pub fn update(&mut self, idx: usize, delta: T) {
        assert!(
            idx < self.n,
            "update index {} out of bounds for length {}",
            idx,
            self.n
        );
        let mut i = idx + 1;
        while i <= self.n {
            self.nodes[i] += delta.clone();
            i += lowest_set_bit(i);
        }
    }

    /// Returns the sum of the elements at 0-based indices `0..=idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= n`.
    pub fn sum_to(&self, idx: usize) -> T {
        assert!(
            idx < self.n,
            "sum_to index {} out of bounds for length {}",
            idx,
            self.n
        );
        let mut sum = T::default();
        let mut i = idx + 1;
        while i > 0 {
            sum += self.nodes[i].clone();
            i -= lowest_set_bit(i);
        }
        sum
    }

    /// Returns the sum of the elements at 0-based indices `idx..=end`.
    ///
    /// Callers are expected to pass `idx <= end`. The result is computed as
    /// `sum_to(end) - sum_to(idx - 1)`, so `idx == end + 1` yields an empty
    /// (zero) sum and a larger `idx` yields a negated sum. The subtraction
    /// overflows for unsigned `T`.
    ///
    /// # Panics
    ///
    /// Panics if `end >= n`, or if `idx > 0` and `idx - 1 >= n`.
    pub fn sum_of_range(&self, idx: usize, end: usize) -> T {
        let through_end = self.sum_to(end);
        match idx.checked_sub(1) {
            // Subtract the prefix that precedes the range.
            Some(before) => through_end - self.sum_to(before),
            // Range starts at 0: the prefix sum is already the answer.
            None => through_end,
        }
    }
}

impl<T> From<Vec<T>> for BinaryIndexedTree<T>
where
    T: Default + Clone + std::ops::AddAssign,
{
    /// Builds a tree whose elements are `values`, in O(n).
    ///
    /// Each node's partial sum is pushed once into its immediate Fenwick
    /// parent. Nodes are visited in increasing index order, and every child
    /// has a smaller index than its parent, so each node is complete before
    /// it is propagated.
    fn from(values: Vec<T>) -> Self {
        let n = values.len();
        let mut nodes: Vec<T> = Vec::with_capacity(n + 1);
        nodes.push(T::default());
        nodes.extend(values);
        for i in 1..=n {
            let parent = i + lowest_set_bit(i);
            if parent <= n {
                // Clone first so the mutable borrow of `parent` doesn't overlap.
                let child = nodes[i].clone();
                nodes[parent] += child;
            }
        }
        BinaryIndexedTree { n, nodes }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_sum() {
        let bit = BinaryIndexedTree::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(bit.sum_to(0), 1);
        assert_eq!(bit.sum_to(1), 3);
        assert_eq!(bit.sum_to(2), 6);
        assert_eq!(bit.sum_to(3), 10);
        assert_eq!(bit.sum_to(4), 15);
        assert_eq!(bit.sum_to(5), 21);
        assert_eq!(bit.sum_to(6), 28);
        assert_eq!(bit.sum_to(7), 36);
    }

    #[test]
    fn test_bit_sum_of_range() {
        let bit = BinaryIndexedTree::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(bit.sum_of_range(0, 0), 1);
        assert_eq!(bit.sum_of_range(1, 1), 2);
        assert_eq!(bit.sum_of_range(0, 1), 3);
        assert_eq!(bit.sum_of_range(0, 2), 6);
        assert_eq!(bit.sum_of_range(2, 3), 7);
        assert_eq!(bit.sum_of_range(0, 7), 36);
        assert_eq!(bit.sum_of_range(6, 7), 15);
    }

    #[test]
    fn test_bit_update() {
        let mut bit = BinaryIndexedTree::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        bit.update(3, 10);
        assert_eq!(bit.sum_to(2), 6);
        assert_eq!(bit.sum_to(3), 20);
        assert_eq!(bit.sum_of_range(3, 7), 40);
    }

    #[test]
    fn test_bit_from_matches_incremental_updates() {
        let values: Vec<i64> = (0..37).map(|x| (x * 7 % 11) - 5).collect();
        let built = BinaryIndexedTree::from(values.clone());
        let mut incremental = BinaryIndexedTree::new(values.len());
        for (i, v) in values.iter().enumerate() {
            incremental.update(i, *v);
        }
        assert_eq!(built.nodes, incremental.nodes);
    }

    #[test]
    #[should_panic]
    fn test_bit_update_out_of_range_panics() {
        let mut bit: BinaryIndexedTree<i32> = BinaryIndexedTree::new(4);
        bit.update(4, 1);
    }
}