adder_codec_core/codec/compressed/fenwick/
mod.rs1use std::ops::Range;
5
6pub mod context_switching;
7pub mod simple;
8
9#[derive(Debug, Clone)]
12pub struct Weights {
13 fenwick_counts: Vec<u64>,
14 total: u64,
15}
16
17impl Weights {
18 pub fn new(n: usize) -> Self {
19 let mut fenwick_counts = vec![0; n + 1];
21
22 for i in 0..fenwick_counts.len() {
23 fenwick::array::update(&mut fenwick_counts, i, 1);
24 }
25
26 let total = fenwick_counts.len() as u64;
27 Self {
28 fenwick_counts,
29 total,
30 }
31 }
32
33 pub fn new_with_counts(n: usize, counts: &[u64]) -> Self {
35 let fenwick_counts = vec![0; n + 1];
37
38 let mut weights = Self {
39 fenwick_counts,
40 total: 0,
41 };
42
43 for (i, &count) in counts.iter().enumerate() {
44 weights.update(Some(i), count);
45 }
46 weights.update(None, 1);
47 weights
48 }
49
50 fn update(&mut self, i: Option<usize>, delta: u64) {
51 let index = i.map(|i| i + 1).unwrap_or_default();
52 fenwick::array::update(&mut self.fenwick_counts, index, delta);
53 self.total += delta;
54 }
55
56 fn prefix_sum(&self, i: Option<usize>) -> u64 {
57 let index = i.map(|i| i + 1).unwrap_or_default();
58 fenwick::array::prefix_sum(&self.fenwick_counts, index)
59 }
60
61 pub(crate) fn range(&self, i: Option<usize>) -> Range<u64> {
63 let index = i.map(|i| i + 1).unwrap_or_default();
65
66 let upper = fenwick::array::prefix_sum(&self.fenwick_counts, index);
67
68 let lower = if index == 0 {
69 0
70 } else {
71 fenwick::array::prefix_sum(&self.fenwick_counts, index - 1)
72 };
73 lower..upper
74 }
75
76 pub fn len(&self) -> usize {
77 self.fenwick_counts.len() - 1
78 }
79
80 fn symbol(&self, prefix_sum: u64) -> Option<usize> {
82 if prefix_sum < self.prefix_sum(None) {
83 return None;
84 }
85
86 let mut low = 0;
89 let mut high = self.len();
90 debug_assert!(low < high);
91 debug_assert!(prefix_sum < self.prefix_sum(Some(high - 1)));
92 while low + 1 < high {
93 let i = (low + high - 1) / 2;
94 if self.prefix_sum(Some(i)) > prefix_sum {
95 high = i + 1;
97 } else {
98 low = i + 1;
100 }
101 }
102 Some(low)
103 }
104
105 const fn total(&self) -> u64 {
106 self.total
107 }
108}
109
/// Error carrying a symbol value that was rejected as invalid.
///
/// The wrapped `usize` is the offending symbol. It is also included in the
/// `Display` message.
#[derive(Debug, thiserror::Error)]
#[error("invalid symbol received: {0}")]
pub struct ValueError(pub usize);
113
#[cfg(test)]
mod tests {
    use super::Weights;

    #[test]
    fn total() {
        let weights = Weights::new(3);
        assert_eq!(weights.total(), 4);
    }

    #[test]
    fn range() {
        let weights = Weights::new(3);
        assert_eq!(weights.range(None), 0..1);
        assert_eq!(weights.range(Some(0)), 1..2);
        assert_eq!(weights.range(Some(1)), 2..3);
        assert_eq!(weights.range(Some(2)), 3..4);
    }

    #[test]
    #[should_panic]
    fn range_out_of_bounds() {
        let weights = Weights::new(3);
        weights.range(Some(3));
    }

    #[test]
    fn symbol() {
        let weights = Weights::new(3);
        assert_eq!(weights.symbol(0), None);
        assert_eq!(weights.symbol(1), Some(0));
        assert_eq!(weights.symbol(2), Some(1));
        assert_eq!(weights.symbol(3), Some(2));
    }

    #[test]
    #[should_panic]
    fn symbol_out_of_bounds() {
        let weights = Weights::new(3);
        weights.symbol(4);
    }

    /// `len` counts only the regular symbols, not the `None` slot.
    #[test]
    fn len_excludes_none_symbol() {
        assert_eq!(Weights::new(3).len(), 3);
        assert_eq!(Weights::new(0).len(), 0);
    }

    /// Initial counts map to symbol slots, and `None` always gets a count of 1.
    #[test]
    fn new_with_counts() {
        let weights = Weights::new_with_counts(3, &[2, 0, 5]);
        assert_eq!(weights.total(), 8);
        assert_eq!(weights.range(None), 0..1);
        assert_eq!(weights.range(Some(0)), 1..3);
        assert_eq!(weights.range(Some(1)), 3..3);
        assert_eq!(weights.range(Some(2)), 3..8);
    }

    /// A zero-count symbol has an empty interval and is never decoded.
    #[test]
    fn symbol_skips_zero_count() {
        let weights = Weights::new_with_counts(3, &[2, 0, 5]);
        assert_eq!(weights.symbol(0), None);
        assert_eq!(weights.symbol(1), Some(0));
        assert_eq!(weights.symbol(2), Some(0));
        assert_eq!(weights.symbol(3), Some(2));
        assert_eq!(weights.symbol(7), Some(2));
    }

    /// Updating one symbol widens its interval and shifts the ones after it.
    #[test]
    fn update_shifts_ranges() {
        let mut weights = Weights::new(2);
        weights.update(Some(0), 3);
        assert_eq!(weights.total(), 6);
        assert_eq!(weights.range(Some(0)), 1..5);
        assert_eq!(weights.range(Some(1)), 5..6);
        assert_eq!(weights.symbol(4), Some(0));
        assert_eq!(weights.symbol(5), Some(1));
    }
}