// redux/model/adaptive_tree.rs

//! Symbol frequency model implemented with a Binary Indexed Tree (Fenwick tree).
2
3use std::boxed::Box;
4use std::vec::Vec;
5use super::Model;
6use super::Parameters;
7use super::super::Result;
8use super::super::Error;
9
10/// Adaptive model that uses a Binary Indexed Tree for storing cumulative frequencies.
11pub struct AdaptiveTreeModel {
12    /// Tree of cumulative frequencies.
13    tree: Vec<u64>,
14    /// Cache of total frequency, as is is needed quite often.
15    /// This could be otherwise calculated as `self.get_frequency_single(self.params.symbol_count)`.
16    count: u64,
17    /// Arithmetic parameters.
18    params: Parameters,
19}
20
/// Isolates the lowest set bit of a number, clearing all other bits:
/// `10110100` becomes `00000100`. Zero has no set bit and stays zero.
trait LastOne<T> {
    fn last_one(self) -> T;
}

/// `usize` implementation built on the two's-complement identity
/// `-x == !x + 1`; ANDing a value with its negation keeps only its lowest set bit.
impl LastOne<usize> for usize {
    fn last_one(self) -> usize {
        let negated = (!self).wrapping_add(1);
        self & negated
    }
}
33
34impl AdaptiveTreeModel {
35    /// Initializes the model with the given parameters.
36    pub fn new(p: Parameters) -> Box<Model> {
37        let mut m = AdaptiveTreeModel {
38            tree: vec![0; p.symbol_count + 1],
39            count: p.symbol_count as u64,
40            params: p,
41        };
42
43        for i in 0..m.tree.len() {
44            m.tree[i] = i.last_one() as u64;
45        }
46
47        return Box::new(m);
48    }
49
50    /// Returns the cumulated frequency of the symbol.
51    fn get_frequency_single(&self, symbol: usize) -> u64 {
52        let mut i = symbol;
53        let mut sum = self.tree[0];
54        while i > 0 {
55            sum += self.tree[i];
56            i -= i.last_one();
57        }
58        return sum;
59    }
60
61    /// Returns cumulated frequency range of the symbol.
62    /// Uses an optimized algorithm to walk the common part of the tree only once.
63    fn get_frequency_range(&mut self, symbol: usize) -> (u64, u64) {
64        let mut sumh = 0u64;
65        let mut suml = 0u64;
66        let mut h = symbol + 1;
67        let mut l = symbol;
68        while h != l {
69            if h > l {
70                sumh += self.tree[h];
71                h -= h.last_one();
72            } else {
73                suml += self.tree[l];
74                l -= l.last_one();
75            }
76        }
77
78        let sumr = self.get_frequency_single(h);
79        (suml + sumr, sumh + sumr)
80    }
81
82    /// Updates the cumulative frequencies for the given symbol.
83    fn update(&mut self, symbol: usize) {
84        if self.total_frequency() < self.params.freq_max {
85            let mut i = symbol;
86            while i <= self.params.symbol_count {
87                self.tree[i] += 1;
88                i += i.last_one();
89            }
90            self.count += 1;
91        }
92    }
93}
94
95impl Model for AdaptiveTreeModel {
96    fn parameters<'a>(&'a self) -> &'a Parameters {
97        &self.params
98    }
99
100    fn total_frequency(&self) -> u64 {
101        debug_assert!(self.count == self.get_frequency_single(self.params.symbol_count));
102        self.count
103    }
104
105    fn get_frequency(&mut self, symbol: usize) -> Result<(u64, u64)> {
106        if symbol > self.params.symbol_eof {
107            Err(Error::InvalidInput)
108        } else {
109            let result = self.get_frequency_range(symbol);
110            self.update(symbol + 1);
111            Ok(result)
112        }
113    }
114
115    fn get_symbol(&mut self, value: u64) -> Result<(usize, u64, u64)> {
116        let mut m = self.params.symbol_eof;
117        let mut i = 0usize;
118        let mut v = value;
119        while (m > 0) && (i < self.params.symbol_eof) {
120            let ti = i + m;
121            let tv = self.tree[ti];
122            if v >= tv {
123                i = ti;
124                v -= tv;
125            }
126            m >>= 1;
127        }
128
129        let (l, h) = self.get_frequency_range(i);
130        if value >= h {
131            Err(Error::InvalidInput)
132        } else {
133            self.update(i + 1);
134            Ok((i, l, h))
135        }
136    }
137
138    #[cfg(debug_assertions)]
139    fn get_freq_table(&self) -> Vec<(u64, u64)> {
140        let mut res = vec![(0u64, 0u64); self.params.symbol_count];
141        for i in 0..self.params.symbol_count {
142            res[i] = (self.get_frequency_single(i), self.get_frequency_single(i + 1));
143        }
144        res
145    }
146}
147