1use std::cmp::max;
2
3use crate::ans::spec::Spec;
4use crate::ans::{AnsState, Symbol};
5use crate::constants::{Bitlen, Weight};
6
7#[derive(Clone, Debug)]
8struct SymbolInfo {
9 renorm_bit_cutoff: AnsState,
10 min_renorm_bits: Bitlen,
11 next_states: Vec<AnsState>,
12}
13
14impl SymbolInfo {
15 #[inline]
16 fn next_state_for(&self, x_s: AnsState) -> AnsState {
17 self.next_states[x_s as usize - self.next_states.len()]
18 }
19}
20
21#[derive(Clone, Debug)]
22pub struct Encoder {
23 symbol_infos: Vec<SymbolInfo>,
24 size_log: Bitlen,
25}
26
27impl Encoder {
28 pub fn new(spec: &Spec) -> Self {
29 let table_size = spec.table_size();
30
31 let mut symbol_infos = spec
32 .symbol_weights
33 .iter()
34 .map(|&weight| {
35 let max_x_s = 2 * weight - 1;
41 let min_renorm_bits = spec.size_log - max_x_s.ilog2() as Bitlen;
42 let renorm_bit_cutoff = (2 * weight * (1 << min_renorm_bits)) as AnsState;
43 SymbolInfo {
44 renorm_bit_cutoff,
45 min_renorm_bits,
46 next_states: Vec::with_capacity(weight as usize),
47 }
48 })
49 .collect::<Vec<_>>();
50
51 for (state_idx, &symbol) in spec.state_symbols.iter().enumerate() {
52 symbol_infos[symbol as usize]
53 .next_states
54 .push((table_size + state_idx) as AnsState);
55 }
56
57 Self {
58 symbol_infos,
62 size_log: spec.size_log,
63 }
64 }
65
66 #[inline]
72 pub fn encode(&self, state: AnsState, symbol: Symbol) -> (AnsState, Bitlen) {
73 let symbol_info = &self.symbol_infos[symbol as usize];
74 let renorm_bits = if state >= symbol_info.renorm_bit_cutoff {
75 symbol_info.min_renorm_bits + 1
76 } else {
77 symbol_info.min_renorm_bits
78 };
79 (
80 symbol_info.next_state_for(state >> renorm_bits),
81 renorm_bits,
82 )
83 }
84
85 pub fn size_log(&self) -> Bitlen {
86 self.size_log
87 }
88
89 pub fn default_state(&self) -> AnsState {
90 1 << self.size_log
91 }
92}
93
94fn quantize_weights_to(counts: &[Weight], total_count: usize, size_log: Bitlen) -> Vec<Weight> {
96 if size_log == 0 {
97 return vec![1];
98 }
99
100 let required_weight_sum = 1 << size_log;
101 let multiplier = required_weight_sum as f32 / total_count as f32;
102 let desired_surplus_per_bin = counts
105 .iter()
106 .map(|&count| (count as f32 * multiplier - 1.0).max(0.0))
107 .collect::<Vec<_>>();
108 let desired_surplus = desired_surplus_per_bin.iter().sum::<f32>();
109 let required_surplus = required_weight_sum - counts.len() as Weight;
110
111 let surplus_mult = if desired_surplus == 0.0 {
114 0.0
115 } else {
116 required_surplus as f32 / desired_surplus
117 };
118 let float_weights = desired_surplus_per_bin
119 .iter()
120 .map(|&surplus| 1.0 + surplus * surplus_mult)
121 .collect::<Vec<_>>();
122
123 let mut weights = float_weights
126 .iter()
127 .map(|&weight| weight.round() as Weight)
128 .collect::<Vec<_>>();
129 let mut weight_sum = weights.iter().sum::<Weight>();
130
131 let mut i = 0;
134 while weight_sum > required_weight_sum {
135 if weights[i] > 1 && weights[i] as f32 > float_weights[i] {
136 weights[i] -= 1;
137 weight_sum -= 1;
138 }
139 i += 1;
140 }
141 i = 0;
142 while weight_sum < required_weight_sum {
143 if (weights[i] as f32) < float_weights[i] {
144 weights[i] += 1;
145 weight_sum += 1;
146 }
147 i += 1;
148 }
149
150 weights
151}
152
153pub fn quantize_weights(
157 counts: Vec<Weight>,
158 total_count: usize,
159 max_size_log: Bitlen,
160) -> (Bitlen, Vec<Weight>) {
161 if counts.len() == 1 {
162 return (0, vec![1]);
163 }
164
165 let min_size_log = (usize::BITS - (counts.len() - 1).leading_zeros()) as Bitlen;
166 let mut size_log = max(min_size_log, max_size_log);
167 let mut weights = quantize_weights_to(&counts, total_count, size_log);
168
169 let power_of_2 = weights.iter().map(|&w| w.trailing_zeros()).min().unwrap() as Bitlen;
170 size_log -= power_of_2;
171 for weight in &mut weights {
172 *weight >>= power_of_2;
173 }
174 (size_log, weights)
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-size quantization: `(counts, total_count, size_log, expected)`.
    #[test]
    fn test_quantize_weights_to() {
        let cases: [(&[Weight], usize, Bitlen, &[Weight]); 5] = [
            (&[777], 777, 0, &[1]),
            (&[777, 1], 778, 1, &[1, 1]),
            (&[777, 1], 778, 2, &[3, 1]),
            (&[2, 3, 6, 5, 1], 17, 3, &[1, 1, 3, 2, 1]),
            (&[1, 1], 2, 1, &[1, 1]),
        ];
        for (counts, total_count, size_log, expected) in cases {
            assert_eq!(
                quantize_weights_to(counts, total_count, size_log),
                expected
            );
        }
    }

    /// Full quantization, including removal of a shared power of two.
    #[test]
    fn test_quantize_weights() {
        // 7 and 9 share no factor of two, so the table keeps its full size.
        assert_eq!(
            quantize_weights(vec![77, 100], 177, 4),
            (4, vec![7, 9])
        );
        // Equal counts quantize to 8 and 8, which reduce to 1 and 1.
        assert_eq!(
            quantize_weights(vec![77, 77], 154, 4),
            (1, vec![1, 1])
        );
    }
}