adder_codec_core/codec/compressed/fenwick/
mod.rs

// Adapted from https://github.com/danieleades/arithmetic-coding; temporary, for initial testing only.
//! [`Models`](crate::Model) implemented using Fenwick trees.
3
4use std::ops::Range;
5
6pub mod context_switching;
7pub mod simple;
8
9/// A wrapper around a vector of fenwick counts, with one additional weight for
10/// EOF.
11#[derive(Debug, Clone)]
12pub struct Weights {
13    fenwick_counts: Vec<u64>,
14    total: u64,
15}
16
17impl Weights {
18    pub fn new(n: usize) -> Self {
19        // we add one extra value here to account for the EOF
20        let mut fenwick_counts = vec![0; n + 1];
21
22        for i in 0..fenwick_counts.len() {
23            fenwick::array::update(&mut fenwick_counts, i, 1);
24        }
25
26        let total = fenwick_counts.len() as u64;
27        Self {
28            fenwick_counts,
29            total,
30        }
31    }
32
33    /// Initialize the weights with the given counts
34    pub fn new_with_counts(n: usize, counts: &[u64]) -> Self {
35        // we add one extra value here to account for the EOF (stored at the FIRST index)
36        let fenwick_counts = vec![0; n + 1];
37
38        let mut weights = Self {
39            fenwick_counts,
40            total: 0,
41        };
42
43        for (i, &count) in counts.iter().enumerate() {
44            weights.update(Some(i), count);
45        }
46        weights.update(None, 1);
47        weights
48    }
49
50    fn update(&mut self, i: Option<usize>, delta: u64) {
51        let index = i.map(|i| i + 1).unwrap_or_default();
52        fenwick::array::update(&mut self.fenwick_counts, index, delta);
53        self.total += delta;
54    }
55
56    fn prefix_sum(&self, i: Option<usize>) -> u64 {
57        let index = i.map(|i| i + 1).unwrap_or_default();
58        fenwick::array::prefix_sum(&self.fenwick_counts, index)
59    }
60
61    /// Returns the probability range for the given symbol
62    pub(crate) fn range(&self, i: Option<usize>) -> Range<u64> {
63        // Increment the symbol index by one to account for the EOF?
64        let index = i.map(|i| i + 1).unwrap_or_default();
65
66        let upper = fenwick::array::prefix_sum(&self.fenwick_counts, index);
67
68        let lower = if index == 0 {
69            0
70        } else {
71            fenwick::array::prefix_sum(&self.fenwick_counts, index - 1)
72        };
73        lower..upper
74    }
75
76    pub fn len(&self) -> usize {
77        self.fenwick_counts.len() - 1
78    }
79
80    /// Used for decoding. Find the symbol index for the given `prefix_sum`
81    fn symbol(&self, prefix_sum: u64) -> Option<usize> {
82        if prefix_sum < self.prefix_sum(None) {
83            return None;
84        }
85
86        // invariant: low <= our answer < high
87        // we seek the lowest number i such that prefix_sum(i) > prefix_sum
88        let mut low = 0;
89        let mut high = self.len();
90        debug_assert!(low < high);
91        debug_assert!(prefix_sum < self.prefix_sum(Some(high - 1)));
92        while low + 1 < high {
93            let i = (low + high - 1) / 2;
94            if self.prefix_sum(Some(i)) > prefix_sum {
95                // i could be our answer, so set high just above it.
96                high = i + 1;
97            } else {
98                // i could not be our answer, so set low just above it.
99                low = i + 1;
100            }
101        }
102        Some(low)
103    }
104
105    const fn total(&self) -> u64 {
106        self.total
107    }
108}
109
110#[derive(Debug, thiserror::Error)]
111#[error("invalid symbol received: {0}")]
112pub struct ValueError(pub usize);
113
#[cfg(test)]
mod tests {
    use super::Weights;

    #[test]
    fn total() {
        let weights = Weights::new(3);
        assert_eq!(weights.total(), 4);
    }

    #[test]
    fn len() {
        // `len` counts symbols only; the EOF slot is excluded.
        let weights = Weights::new(3);
        assert_eq!(weights.len(), 3);
    }

    #[test]
    fn range() {
        let weights = Weights::new(3);
        assert_eq!(weights.range(None), 0..1);
        assert_eq!(weights.range(Some(0)), 1..2);
        assert_eq!(weights.range(Some(1)), 2..3);
        assert_eq!(weights.range(Some(2)), 3..4);
    }

    #[test]
    #[should_panic]
    fn range_out_of_bounds() {
        let weights = Weights::new(3);
        weights.range(Some(3));
    }

    #[test]
    fn symbol() {
        let weights = Weights::new(3);
        assert_eq!(weights.symbol(0), None);
        assert_eq!(weights.symbol(1), Some(0));
        assert_eq!(weights.symbol(2), Some(1));
        assert_eq!(weights.symbol(3), Some(2));
    }

    #[test]
    #[should_panic]
    fn symbol_out_of_bounds() {
        let weights = Weights::new(3);
        weights.symbol(4);
    }

    #[test]
    fn new_with_counts() {
        let weights = Weights::new_with_counts(3, &[2, 3, 5]);
        // EOF always receives a weight of 1 on top of the supplied counts.
        assert_eq!(weights.total(), 11);
        assert_eq!(weights.range(None), 0..1);
        assert_eq!(weights.range(Some(0)), 1..3);
        assert_eq!(weights.range(Some(1)), 3..6);
        assert_eq!(weights.range(Some(2)), 6..11);
    }

    #[test]
    fn symbol_inverts_range() {
        // Every value inside a symbol's range must decode back to that symbol,
        // including with non-uniform weights.
        let weights = Weights::new_with_counts(4, &[3, 1, 4, 1]);
        let symbols = std::iter::once(None).chain((0..weights.len()).map(Some));
        for symbol in symbols {
            for value in weights.range(symbol) {
                assert_eq!(weights.symbol(value), symbol, "value {}", value);
            }
        }
    }
}