use std::cmp::max;
use crate::ans::spec::Spec;
use crate::ans::{AnsState, Symbol};
use crate::constants::{Bitlen, Weight};
/// Per-symbol lookup data consulted by `Encoder::encode`.
#[derive(Clone, Debug)]
struct SymbolInfo {
// States greater than or equal to this cutoff are renormalized by
// `min_renorm_bits + 1` bits; smaller states by `min_renorm_bits` bits.
// Built in `Encoder::new` as `2 * weight * (1 << min_renorm_bits)`.
renorm_bit_cutoff: AnsState,
// The smaller of the two possible renormalization bit counts for this symbol,
// built in `Encoder::new` as `size_log - ilog2(2 * weight - 1)`.
min_renorm_bits: Bitlen,
// Candidate next states for this symbol, in the order they appear in the
// spec's `state_symbols`; each is stored as `table_size + state_idx`.
// Looked up with the index `x_s - next_states.len()`.
next_states: Vec<AnsState>,
}
impl SymbolInfo {
    /// Looks up the next encoder state for the renormalized state `x_s`.
    ///
    /// The table is indexed by `x_s - next_states.len()`, so `x_s` has to lie in
    /// `next_states.len()..2 * next_states.len()`; anything else fails the
    /// subtraction or the bounds check.
    #[inline]
    fn next_state_for(&self, x_s: AnsState) -> AnsState {
        let offset = x_s as usize - self.next_states.len();
        self.next_states[offset]
    }
}
/// Encoding tables built from a `Spec`; see `Encoder::new` and `Encoder::encode`.
#[derive(Clone, Debug)]
pub struct Encoder {
// One entry per element of the spec's `symbol_weights`, indexed by symbol.
symbol_infos: Vec<SymbolInfo>,
// Copied from the spec's `size_log`; `default_state` returns `1 << size_log`.
size_log: Bitlen,
}
impl Encoder {
pub fn new(spec: &Spec) -> Self {
let table_size = spec.table_size();
let mut symbol_infos = spec
.symbol_weights
.iter()
.map(|&weight| {
let max_x_s = 2 * weight - 1;
let min_renorm_bits = spec.size_log - max_x_s.ilog2() as Bitlen;
let renorm_bit_cutoff = (2 * weight * (1 << min_renorm_bits)) as AnsState;
SymbolInfo {
renorm_bit_cutoff,
min_renorm_bits,
next_states: Vec::with_capacity(weight as usize),
}
})
.collect::<Vec<_>>();
for (state_idx, &symbol) in spec.state_symbols.iter().enumerate() {
symbol_infos[symbol as usize]
.next_states
.push((table_size + state_idx) as AnsState);
}
Self {
symbol_infos,
size_log: spec.size_log,
}
}
#[inline]
pub fn encode(&self, state: AnsState, symbol: Symbol) -> (AnsState, Bitlen) {
let symbol_info = &self.symbol_infos[symbol as usize];
let renorm_bits = if state >= symbol_info.renorm_bit_cutoff {
symbol_info.min_renorm_bits + 1
} else {
symbol_info.min_renorm_bits
};
(
symbol_info.next_state_for(state >> renorm_bits),
renorm_bits,
)
}
pub fn size_log(&self) -> Bitlen {
self.size_log
}
pub fn default_state(&self) -> AnsState {
1 << self.size_log
}
}
/// Quantizes bin `counts` into integer weights that sum to exactly `1 << size_log`,
/// giving every bin a weight of at least 1.
///
/// * `counts` - occurrence count of each bin.
/// * `total_count` - the total the counts are scaled against (normally their sum).
/// * `size_log` - log2 of the required weight sum. When it is 0 the result is
///   always `vec![1]`, regardless of `counts`.
///
/// Each bin first gets weight 1; the remaining `2^size_log - counts.len()` is
/// shared in proportion to how far each scaled count exceeds 1, and the result
/// is rounded to integers. Rounding can leave the sum slightly off, so the
/// weights are then nudged:
///
/// 1. A single preferred pass, in bin order, undoes roundings in the needed
///    direction (never taking a weight below 1).
/// 2. If that pass is not enough -- for example when f32 precision leaves no bin
///    looking "rounded", or `total_count` does not match `counts` -- further
///    round-robin passes adjust any bin until the sum is exact. Previously this
///    situation ran the index past the end of the vector and panicked.
///
/// # Panics
///
/// Panics if the required sum cannot be reached at all: `counts` is empty, or
/// there are more bins than `2^size_log` so every weight is already 1.
fn quantize_weights_to(counts: &[Weight], total_count: usize, size_log: Bitlen) -> Vec<Weight> {
    if size_log == 0 {
        return vec![1];
    }

    let required_weight_sum: Weight = 1 << size_log;
    let multiplier = required_weight_sum as f32 / total_count as f32;
    // Every bin is guaranteed a weight of 1, so only the part of its scaled
    // count above 1 competes for the remaining weight.
    let desired_surplus_per_bin = counts
        .iter()
        .map(|&count| (count as f32 * multiplier - 1.0).max(0.0))
        .collect::<Vec<_>>();
    let desired_surplus = desired_surplus_per_bin.iter().sum::<f32>();
    let required_surplus = required_weight_sum - counts.len() as Weight;
    // Guard the division: with no surplus requested every bin stays at weight 1.
    let surplus_mult = if desired_surplus == 0.0 {
        0.0
    } else {
        required_surplus as f32 / desired_surplus
    };
    let float_weights = desired_surplus_per_bin
        .iter()
        .map(|&surplus| 1.0 + surplus * surplus_mult)
        .collect::<Vec<_>>();
    let mut weights = float_weights
        .iter()
        .map(|&weight| weight.round() as Weight)
        .collect::<Vec<_>>();
    let mut weight_sum = weights.iter().sum::<Weight>();

    // Sum too large, preferred pass: take one back from bins that were rounded up.
    for (weight, &float_weight) in weights.iter_mut().zip(&float_weights) {
        if weight_sum <= required_weight_sum {
            break;
        }
        if *weight > 1 && *weight as f32 > float_weight {
            *weight -= 1;
            weight_sum -= 1;
        }
    }
    // Sum still too large: shave any bin above the minimum, one unit per bin per pass.
    while weight_sum > required_weight_sum {
        let sum_before_pass = weight_sum;
        for weight in weights.iter_mut() {
            if weight_sum == required_weight_sum {
                break;
            }
            if *weight > 1 {
                *weight -= 1;
                weight_sum -= 1;
            }
        }
        assert!(
            weight_sum < sum_before_pass,
            "cannot reduce {} weights to a sum of 2^{}: every weight is already 1",
            counts.len(),
            size_log
        );
    }

    // Sum too small, preferred pass: give one to bins that were rounded down.
    for (weight, &float_weight) in weights.iter_mut().zip(&float_weights) {
        if weight_sum >= required_weight_sum {
            break;
        }
        if (*weight as f32) < float_weight {
            *weight += 1;
            weight_sum += 1;
        }
    }
    // Sum still too small: top up bins round-robin until the sum is exact.
    while weight_sum < required_weight_sum {
        assert!(
            !weights.is_empty(),
            "cannot quantize an empty set of counts to a sum of 2^{}",
            size_log
        );
        for weight in weights.iter_mut() {
            if weight_sum == required_weight_sum {
                break;
            }
            *weight += 1;
            weight_sum += 1;
        }
    }

    weights
}
/// Quantizes bin `counts` into weights and returns `(size_log, weights)`, where
/// the weights sum to `1 << size_log`.
///
/// A single bin short-circuits to `(0, vec![1])`. Otherwise quantization starts
/// at the larger of `max_size_log` and the number of bits needed to give each of
/// the `counts.len()` bins its own table slot, so the starting size can exceed
/// `max_size_log` when there are many bins. Any power of two shared by all
/// resulting weights is then divided out, lowering `size_log` by the same amount.
///
/// `total_count` is forwarded to `quantize_weights_to` unchanged.
///
/// NOTE(review): an empty `counts` is not handled; `counts.len() - 1` underflows.
pub fn quantize_weights(
    counts: Vec<Weight>,
    total_count: usize,
    max_size_log: Bitlen,
) -> (Bitlen, Vec<Weight>) {
    if counts.len() == 1 {
        return (0, vec![1]);
    }

    // Bit width of `counts.len() - 1`, i.e. the smallest table that fits every bin.
    let bits_for_bins = (usize::BITS - (counts.len() - 1).leading_zeros()) as Bitlen;
    let initial_size_log = max(bits_for_bins, max_size_log);
    let raw_weights = quantize_weights_to(&counts, total_count, initial_size_log);

    // Largest power of two dividing every weight.
    let shared_shift = raw_weights
        .iter()
        .map(|&w| w.trailing_zeros())
        .min()
        .unwrap() as Bitlen;
    let final_size_log = initial_size_log - shared_shift;
    let reduced_weights = raw_weights
        .into_iter()
        .map(|w| w >> shared_shift)
        .collect();

    (final_size_log, reduced_weights)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `quantize_weights_to` against a table of
    /// `(counts, total_count, size_log, expected weights)` cases.
    #[test]
    fn test_quantize_weights_to() {
        let cases: Vec<(Vec<Weight>, usize, Bitlen, Vec<Weight>)> = vec![
            (vec![777], 777, 0, vec![1]),
            (vec![777, 1], 778, 1, vec![1, 1]),
            (vec![777, 1], 778, 2, vec![3, 1]),
            (vec![2, 3, 6, 5, 1], 17, 3, vec![1, 1, 3, 2, 1]),
            (vec![1, 1], 2, 1, vec![1, 1]),
        ];
        for (counts, total_count, size_log, expected) in cases {
            let quantized = quantize_weights_to(&counts, total_count, size_log);
            assert_eq!(quantized, expected);
        }
    }

    /// Checks `quantize_weights` against a table of
    /// `(counts, total_count, max_size_log, expected (size_log, weights))` cases.
    #[test]
    fn test_quantize_weights() {
        let cases: Vec<(Vec<Weight>, usize, Bitlen, (Bitlen, Vec<Weight>))> = vec![
            (vec![77, 100], 177, 4, (4, vec![7, 9])),
            // Equal counts quantize to [8, 8], which reduces to size_log 1.
            (vec![77, 77], 154, 4, (1, vec![1, 1])),
        ];
        for (counts, total_count, max_size_log, expected) in cases {
            let quantized = quantize_weights(counts, total_count, max_size_log);
            assert_eq!(quantized, expected);
        }
    }
}