redux/model/
adaptive_tree.rs1use std::boxed::Box;
4use std::vec::Vec;
5use super::Model;
6use super::Parameters;
7use super::super::Result;
8use super::super::Error;
9
10pub struct AdaptiveTreeModel {
12 tree: Vec<u64>,
14 count: u64,
17 params: Parameters,
19}
20
/// Isolation of the lowest set bit of a value (`x & -x`), the step size used to
/// walk up and down a Fenwick tree.
trait LastOne<T> {
    /// Returns `self` with every bit cleared except the least-significant set bit
    /// (zero stays zero).
    fn last_one(self) -> T;
}

27impl LastOne<usize> for usize {
29 fn last_one(self) -> usize {
30 self & self.wrapping_neg()
31 }
32}
33
34impl AdaptiveTreeModel {
35 pub fn new(p: Parameters) -> Box<Model> {
37 let mut m = AdaptiveTreeModel {
38 tree: vec![0; p.symbol_count + 1],
39 count: p.symbol_count as u64,
40 params: p,
41 };
42
43 for i in 0..m.tree.len() {
44 m.tree[i] = i.last_one() as u64;
45 }
46
47 return Box::new(m);
48 }
49
50 fn get_frequency_single(&self, symbol: usize) -> u64 {
52 let mut i = symbol;
53 let mut sum = self.tree[0];
54 while i > 0 {
55 sum += self.tree[i];
56 i -= i.last_one();
57 }
58 return sum;
59 }
60
61 fn get_frequency_range(&mut self, symbol: usize) -> (u64, u64) {
64 let mut sumh = 0u64;
65 let mut suml = 0u64;
66 let mut h = symbol + 1;
67 let mut l = symbol;
68 while h != l {
69 if h > l {
70 sumh += self.tree[h];
71 h -= h.last_one();
72 } else {
73 suml += self.tree[l];
74 l -= l.last_one();
75 }
76 }
77
78 let sumr = self.get_frequency_single(h);
79 (suml + sumr, sumh + sumr)
80 }
81
82 fn update(&mut self, symbol: usize) {
84 if self.total_frequency() < self.params.freq_max {
85 let mut i = symbol;
86 while i <= self.params.symbol_count {
87 self.tree[i] += 1;
88 i += i.last_one();
89 }
90 self.count += 1;
91 }
92 }
93}
94
95impl Model for AdaptiveTreeModel {
96 fn parameters<'a>(&'a self) -> &'a Parameters {
97 &self.params
98 }
99
100 fn total_frequency(&self) -> u64 {
101 debug_assert!(self.count == self.get_frequency_single(self.params.symbol_count));
102 self.count
103 }
104
105 fn get_frequency(&mut self, symbol: usize) -> Result<(u64, u64)> {
106 if symbol > self.params.symbol_eof {
107 Err(Error::InvalidInput)
108 } else {
109 let result = self.get_frequency_range(symbol);
110 self.update(symbol + 1);
111 Ok(result)
112 }
113 }
114
115 fn get_symbol(&mut self, value: u64) -> Result<(usize, u64, u64)> {
116 let mut m = self.params.symbol_eof;
117 let mut i = 0usize;
118 let mut v = value;
119 while (m > 0) && (i < self.params.symbol_eof) {
120 let ti = i + m;
121 let tv = self.tree[ti];
122 if v >= tv {
123 i = ti;
124 v -= tv;
125 }
126 m >>= 1;
127 }
128
129 let (l, h) = self.get_frequency_range(i);
130 if value >= h {
131 Err(Error::InvalidInput)
132 } else {
133 self.update(i + 1);
134 Ok((i, l, h))
135 }
136 }
137
138 #[cfg(debug_assertions)]
139 fn get_freq_table(&self) -> Vec<(u64, u64)> {
140 let mut res = vec![(0u64, 0u64); self.params.symbol_count];
141 for i in 0..self.params.symbol_count {
142 res[i] = (self.get_frequency_single(i), self.get_frequency_single(i + 1));
143 }
144 res
145 }
146}
147