use std::{io, ops::Range};
use bitstream_io::BitWrite;
use crate::{common, BitStore, Error, Model};
/// Arithmetic encoder: pairs a probability [`Model`] with the bit-level
/// encoder [`State`] that tracks the coding interval and writes output bits.
///
/// `'a` is the lifetime of the bit writer borrowed by the inner state.
#[derive(Debug)]
pub struct Encoder<'a, M: Model, W: BitWrite> {
    /// Supplies probability ranges for symbols; updated after each symbol.
    model: M,
    /// Interval bounds, deferred-bit count and the borrowed output writer.
    state: State<'a, M::B, W>,
}
impl<'a, M, W> Encoder<'a, M, W>
where
M: Model,
W: BitWrite,
{
pub fn new(model: M, bitwriter: &'a mut W) -> Self {
let frequency_bits = model.max_denominator().log2() + 1;
let precision = M::B::BITS - frequency_bits;
Self::with_precision(model, bitwriter, precision)
}
pub fn with_precision(model: M, bitwriter: &'a mut W, precision: u32) -> Self {
let frequency_bits = model.max_denominator().log2() + 1;
debug_assert!(
(precision >= (frequency_bits + 2)),
"not enough bits of precision to prevent overflow/underflow",
);
debug_assert!(
(frequency_bits + precision) <= M::B::BITS,
"not enough bits in BitStore to support the required precision",
);
Self {
model,
state: State::new(precision, bitwriter),
}
}
pub const fn with_state(state: State<'a, M::B, W>, model: M) -> Self {
Self { model, state }
}
pub fn encode_all(
&mut self,
symbols: impl IntoIterator<Item = M::Symbol>,
) -> Result<(), Error<M::ValueError>> {
for symbol in symbols {
self.encode(Some(&symbol))?;
}
self.encode(None)?;
self.flush()?;
Ok(())
}
pub fn encode(&mut self, symbol: Option<&M::Symbol>) -> Result<(), Error<M::ValueError>> {
let p = self.model.probability(symbol).map_err(Error::ValueError)?;
let denominator = self.model.denominator();
debug_assert!(
denominator <= self.model.max_denominator(),
"denominator is greater than maximum!"
);
self.state.scale(p, denominator)?;
self.model.update(symbol);
Ok(())
}
pub fn flush(&mut self) -> io::Result<()> {
self.state.flush()
}
pub fn into_inner(self) -> (M, State<'a, M::B, W>) {
(self.model, self.state)
}
pub fn chain<X>(self, model: X) -> Encoder<'a, X, W>
where
X: Model<B = M::B>,
{
Encoder {
model,
state: self.state,
}
}
}
/// Bit-level encoder state: the current coding interval, a count of deferred
/// bits, and the writer that encoded bits are sent to.
#[derive(Debug)]
pub struct State<'a, B: BitStore, W: BitWrite> {
    /// Current interval (`low`/`high`) at the configured precision.
    state: common::State<B>,
    /// Number of opposite-valued bits still owed after the next emitted bit.
    pending: u32,
    /// Borrowed destination for encoded bits.
    output: &'a mut W,
}
impl<'a, B, W> State<'a, B, W>
where
    B: BitStore,
    W: BitWrite,
{
    /// Creates encoder state with the given `precision` that writes to
    /// `output`. The interval comes from `common::State::new`; no bits are
    /// pending initially.
    pub fn new(precision: u32, output: &'a mut W) -> Self {
        Self {
            state: common::State::new(precision),
            pending: 0,
            output,
        }
    }

    /// Applies the probability `range` (out of `denominator`) to the interval
    /// via `common::State::scale`, then renormalises, writing any bits that
    /// became determined.
    fn scale(&mut self, range: Range<B>, denominator: B) -> io::Result<()> {
        self.state.scale(range, denominator);
        self.normalise()
    }

    /// Rescales the interval after narrowing, emitting every settled bit.
    ///
    /// Phase one: while the interval lies wholly below the midpoint
    /// (`high < half`) a 0 is emitted, and while it lies wholly at or above it
    /// (`low >= half`) a 1 is emitted; either way the interval is doubled.
    /// Phase two: while the interval sits inside the middle half
    /// (`quarter <= low` and `high < three_quarter`) the next bit is still
    /// undecided, so it is counted in `pending` and the interval is doubled
    /// around the centre.
    ///
    /// NOTE(review): `high` is shifted without setting its new low bit; this
    /// presumably mirrors the decoder and `common::State::scale` — confirm
    /// before changing either side.
    fn normalise(&mut self) -> io::Result<()> {
        loop {
            if self.state.high < self.state.half() {
                self.emit(false)?;
                self.state.high <<= 1;
                self.state.low <<= 1;
            } else if self.state.low >= self.state.half() {
                self.emit(true)?;
                self.state.low = (self.state.low - self.state.half()) << 1;
                self.state.high = (self.state.high - self.state.half()) << 1;
            } else {
                break;
            }
        }

        while self.state.low >= self.state.quarter()
            && self.state.high < self.state.three_quarter()
        {
            // Defer the undecided bit; `emit` will write its complement later.
            self.pending += 1;
            self.state.low = (self.state.low - self.state.quarter()) << 1;
            self.state.high = (self.state.high - self.state.quarter()) << 1;
        }

        Ok(())
    }

    /// Writes `bit` followed by `pending` copies of its complement, then
    /// clears the pending count. If any write fails, `pending` is left
    /// unchanged.
    fn emit(&mut self, bit: bool) -> io::Result<()> {
        self.output.write_bit(bit)?;
        let complement = !bit;
        (0..self.pending).try_for_each(|_| self.output.write_bit(complement))?;
        self.pending = 0;
        Ok(())
    }

    /// Terminates the stream.
    ///
    /// Adds one more pending bit, then emits 0 when `low` is at or below the
    /// quarter point and 1 otherwise (each followed by the pending
    /// complements). This does not byte-align the underlying writer.
    pub fn flush(&mut self) -> io::Result<()> {
        self.pending += 1;
        let bit = self.state.low > self.state.quarter();
        self.emit(bit)
    }
}