use crate::ans::Symbol;
use crate::constants::{Bitlen, Weight};
use crate::errors::{PcoError, PcoResult};
/// Specification of an ANS table: how many states each symbol owns and which
/// symbol every state of the table decodes to.
#[derive(Clone, Debug)]
pub struct Spec {
    /// Log base 2 of the number of states in the table (see `table_size`).
    pub size_log: Bitlen,
    /// The symbol assigned to each table state, indexed by state position.
    pub state_symbols: Vec<Symbol>,
    /// The number of states assigned to each symbol, indexed by symbol. When
    /// built by `Spec::from_weights`, these sum to the table size.
    pub symbol_weights: Vec<Weight>,
}
fn choose_stride(table_size: Weight) -> Weight {
let mut res = (3 * table_size) / 5;
if res.is_multiple_of(2) {
res += 1;
}
res
}
impl Spec {
    /// Assigns a symbol to every state of the table.
    ///
    /// Symbol `i` receives `symbol_weights[i]` states. Symbols are placed in
    /// order, but each successive placement advances the state index by an odd
    /// stride modulo the table size, so a symbol's states are scattered across
    /// the table instead of being contiguous.
    ///
    /// # Errors
    ///
    /// Returns a corruption error if summing the weights overflows `Weight`, or
    /// if the total weight is not exactly `2^size_log`.
    fn spread_state_symbols(size_log: Bitlen, symbol_weights: &[Weight]) -> PcoResult<Vec<Symbol>> {
        // The weights are validated here as possibly-corrupt data, so sum them
        // without panicking (debug) or silently wrapping (release).
        let table_size = symbol_weights
            .iter()
            .try_fold(0, |total: Weight, &weight| total.checked_add(weight))
            .ok_or_else(|| {
                PcoError::corruption(format!(
                    "total weight of {} symbols overflows",
                    symbol_weights.len(),
                ))
            })?;
        // `||` short-circuits, so the shift only runs once `size_log` is known
        // to fit within `Weight`.
        if size_log >= Weight::BITS as Bitlen || table_size != 1 << size_log {
            return Err(PcoError::corruption(format!(
                "table size log of {} does not agree with total weight of {}",
                size_log, table_size,
            )));
        }

        // `table_size` is a nonzero power of two here, so masking with
        // `table_size - 1` reduces an index modulo the table size.
        let mask = table_size - 1;
        let stride = choose_stride(table_size);
        let mut res = vec![0; table_size as usize];
        let mut state_idx: Weight = 0;
        for (symbol, &weight) in symbol_weights.iter().enumerate() {
            for _ in 0..weight {
                res[state_idx as usize] = symbol as Symbol;
                // Step incrementally rather than computing `stride * step`, which
                // can overflow for large tables. The stride is odd and the table
                // size is a power of two, so this visits every state exactly once.
                state_idx = state_idx.wrapping_add(stride) & mask;
            }
        }
        Ok(res)
    }

    /// Builds a spec from a table size log and per-symbol weights.
    ///
    /// An empty weight list is treated as a single symbol of weight 1. That
    /// only agrees with a `size_log` of 0.
    ///
    /// # Errors
    ///
    /// Returns a corruption error if the weights are inconsistent with
    /// `size_log`: their total overflows, or it is not `2^size_log`.
    pub fn from_weights(size_log: Bitlen, symbol_weights: Vec<Weight>) -> PcoResult<Self> {
        let symbol_weights = if symbol_weights.is_empty() {
            vec![1]
        } else {
            symbol_weights
        };
        let state_symbols = Self::spread_state_symbols(size_log, &symbol_weights)?;
        Ok(Self {
            size_log,
            state_symbols,
            symbol_weights,
        })
    }

    /// Returns the number of states in the table, `2^size_log`.
    pub fn table_size(&self) -> usize {
        1 << self.size_log
    }
}
#[cfg(test)]
mod tests {
    use crate::ans::spec::{Spec, Symbol};
    use crate::constants::Weight;
    use crate::errors::PcoResult;

    /// Builds a spec whose size log is inferred from the total weight and
    /// checks that its per-state symbols equal `expected`.
    fn assert_state_symbols(weights: Vec<Weight>, expected: Vec<Symbol>) -> PcoResult<()> {
        let table_size_log = weights.iter().sum::<Weight>().ilog2();
        let spec = Spec::from_weights(table_size_log, weights)?;
        assert_eq!(spec.state_symbols, expected);
        Ok(())
    }

    #[test]
    fn ans_spec_new() -> PcoResult<()> {
        assert_state_symbols(
            vec![1, 1, 3, 11],
            vec![0, 3, 2, 3, 2, 3, 3, 3, 3, 1, 3, 2, 3, 3, 3, 3],
        )
    }

    #[test]
    fn ans_spec_new_trivial() -> PcoResult<()> {
        assert_state_symbols(vec![1], vec![0])?;
        assert_state_symbols(vec![2], vec![0, 0])
    }

    #[test]
    fn ans_spec_empty_weights() -> PcoResult<()> {
        // No weights at all is treated as a single symbol of weight 1.
        let spec = Spec::from_weights(0, vec![])?;
        assert_eq!(spec.symbol_weights, vec![1]);
        assert_eq!(spec.state_symbols, vec![0]);
        assert_eq!(spec.table_size(), 1);
        Ok(())
    }

    #[test]
    fn ans_spec_rejects_mismatched_size() {
        // A total weight of 3 is not a power of two, so no size log matches it.
        assert!(Spec::from_weights(1, vec![1, 2]).is_err());
        assert!(Spec::from_weights(2, vec![1, 2]).is_err());
        // A power-of-two total must still match the declared size log.
        assert!(Spec::from_weights(3, vec![2, 2]).is_err());
    }
}