pco/ans/
encoding.rs

1use std::cmp::max;
2
3use crate::ans::spec::Spec;
4use crate::ans::{AnsState, Symbol};
5use crate::constants::{Bitlen, Weight};
6
7#[derive(Clone, Debug)]
8struct SymbolInfo {
9  renorm_bit_cutoff: AnsState,
10  min_renorm_bits: Bitlen,
11  next_states: Vec<AnsState>,
12}
13
14impl SymbolInfo {
15  #[inline]
16  fn next_state_for(&self, x_s: AnsState) -> AnsState {
17    self.next_states[x_s as usize - self.next_states.len()]
18  }
19}
20
21#[derive(Clone, Debug)]
22pub struct Encoder {
23  symbol_infos: Vec<SymbolInfo>,
24  size_log: Bitlen,
25}
26
27impl Encoder {
28  pub fn new(spec: &Spec) -> Self {
29    let table_size = spec.table_size();
30
31    let mut symbol_infos = spec
32      .symbol_weights
33      .iter()
34      .map(|&weight| {
35        // e.g. If the symbol count is 3 and table size is 16, so the x_s values
36        // are in [3, 6).
37        // We find the power of 2 in this range (4), then compare its log to 16
38        // to find the min renormalization bits (4 - 2 = 2).
39        // Finally we choose the cutoff as 2 * 3 * 2 ^ renorm_bits = 24.
40        let max_x_s = 2 * weight - 1;
41        let min_renorm_bits = spec.size_log - max_x_s.ilog2() as Bitlen;
42        let renorm_bit_cutoff = (2 * weight * (1 << min_renorm_bits)) as AnsState;
43        SymbolInfo {
44          renorm_bit_cutoff,
45          min_renorm_bits,
46          next_states: Vec::with_capacity(weight as usize),
47        }
48      })
49      .collect::<Vec<_>>();
50
51    for (state_idx, &symbol) in spec.state_symbols.iter().enumerate() {
52      symbol_infos[symbol as usize]
53        .next_states
54        .push((table_size + state_idx) as AnsState);
55    }
56
57    Self {
58      // We choose the initial state from [table_size, 2 * table_size)
59      // to be the minimum as this tends to require fewer bits to encode
60      // the first symbol.
61      symbol_infos,
62      size_log: spec.size_log,
63    }
64  }
65
66  // Returns the new state, and how many bits of the existing state to write.
67  // The value of those bits may contain larger significant bits that must be
68  // ignored.
69  // We don't write to a BitWriter directly because ANS operates in a LIFO
70  // manner. We need to write these in reverse order.
71  #[inline]
72  pub fn encode(&self, state: AnsState, symbol: Symbol) -> (AnsState, Bitlen) {
73    let symbol_info = &self.symbol_infos[symbol as usize];
74    let renorm_bits = if state >= symbol_info.renorm_bit_cutoff {
75      symbol_info.min_renorm_bits + 1
76    } else {
77      symbol_info.min_renorm_bits
78    };
79    (
80      symbol_info.next_state_for(state >> renorm_bits),
81      renorm_bits,
82    )
83  }
84
85  pub fn size_log(&self) -> Bitlen {
86    self.size_log
87  }
88
89  pub fn default_state(&self) -> AnsState {
90    1 << self.size_log
91  }
92}
93
94// given size_log, quantize the counts
95fn quantize_weights_to(counts: &[Weight], total_count: usize, size_log: Bitlen) -> Vec<Weight> {
96  if size_log == 0 {
97    return vec![1];
98  }
99
100  let required_weight_sum = 1 << size_log;
101  let multiplier = required_weight_sum as f32 / total_count as f32;
102  // We need to give each bin a weight of at least 1, so we first calculate
103  // how much surplus weight each bin wants above 1.
104  let desired_surplus_per_bin = counts
105    .iter()
106    .map(|&count| (count as f32 * multiplier - 1.0).max(0.0))
107    .collect::<Vec<_>>();
108  let desired_surplus = desired_surplus_per_bin.iter().sum::<f32>();
109  let required_surplus = required_weight_sum - counts.len() as Weight;
110
111  // Divide the available surplus among the bins, proportional to their desired
112  // surplus.
113  let surplus_mult = if desired_surplus == 0.0 {
114    0.0
115  } else {
116    required_surplus as f32 / desired_surplus
117  };
118  let float_weights = desired_surplus_per_bin
119    .iter()
120    .map(|&surplus| 1.0 + surplus * surplus_mult)
121    .collect::<Vec<_>>();
122
123  // Round the float weights to integers. This doesn't give us the exact right
124  // sum, so we further adjust afterward.
125  let mut weights = float_weights
126    .iter()
127    .map(|&weight| weight.round() as Weight)
128    .collect::<Vec<_>>();
129  let mut weight_sum = weights.iter().sum::<Weight>();
130
131  // Take weight away from bins that got rounded up or give it out to bins that
132  // got rounded down until we have the exact right weight sum.
133  let mut i = 0;
134  while weight_sum > required_weight_sum {
135    if weights[i] > 1 && weights[i] as f32 > float_weights[i] {
136      weights[i] -= 1;
137      weight_sum -= 1;
138    }
139    i += 1;
140  }
141  i = 0;
142  while weight_sum < required_weight_sum {
143    if (weights[i] as f32) < float_weights[i] {
144      weights[i] += 1;
145      weight_sum += 1;
146    }
147    i += 1;
148  }
149
150  weights
151}
152
153// choose both size_log and weights
154// increase size_log if it's insufficient to encode all bins;
155// decrease it if all the weights are divisible by 2^k
156pub fn quantize_weights(
157  counts: Vec<Weight>,
158  total_count: usize,
159  max_size_log: Bitlen,
160) -> (Bitlen, Vec<Weight>) {
161  if counts.len() == 1 {
162    return (0, vec![1]);
163  }
164
165  let min_size_log = (usize::BITS - (counts.len() - 1).leading_zeros()) as Bitlen;
166  let mut size_log = max(min_size_log, max_size_log);
167  let mut weights = quantize_weights_to(&counts, total_count, size_log);
168
169  let power_of_2 = weights.iter().map(|&w| w.trailing_zeros()).min().unwrap() as Bitlen;
170  size_log -= power_of_2;
171  for weight in &mut weights {
172    *weight >>= power_of_2;
173  }
174  (size_log, weights)
175}
176
#[cfg(test)]
mod tests {
  use super::*;

  // Asserts that quantizing `counts` at a fixed `size_log` yields `expected`.
  fn check_fixed(counts: &[Weight], total_count: usize, size_log: Bitlen, expected: &[Weight]) {
    let quantized = quantize_weights_to(counts, total_count, size_log);
    assert_eq!(quantized, expected.to_vec());
  }

  // Asserts both the chosen size log and the weights for `counts`.
  fn check_auto(
    counts: &[Weight],
    total_count: usize,
    max_size_log: Bitlen,
    expected_size_log: Bitlen,
    expected_weights: &[Weight],
  ) {
    let quantized = quantize_weights(counts.to_vec(), total_count, max_size_log);
    assert_eq!(
      quantized,
      (expected_size_log, expected_weights.to_vec())
    );
  }

  #[test]
  fn test_quantize_weights_to() {
    // size_log of 0 collapses to a single weight
    check_fixed(&[777], 777, 0, &[1]);
    // every bin keeps a weight of at least 1, however rare it is
    check_fixed(&[777, 1], 778, 1, &[1, 1]);
    check_fixed(&[777, 1], 778, 2, &[3, 1]);
    check_fixed(&[2, 3, 6, 5, 1], 17, 3, &[1, 1, 3, 2, 1]);
    check_fixed(&[1, 1], 2, 1, &[1, 1]);
  }

  #[test]
  fn test_quantize_weights() {
    check_auto(&[77, 100], 177, 4, 4, &[7, 9]);
    // equal counts share a power of 2, so the size log shrinks
    check_auto(&[77, 77], 154, 4, 1, &[1, 1]);
  }
}