use std::{io, ops::Range};
use bitstream_io::BitRead;
use crate::{common, BitStore, Model};
/// An arithmetic decoder: reads bits from `R` and turns them back into
/// symbols using the probability model `M`.
#[derive(Debug)]
pub struct Decoder<M: Model, R: BitRead> {
    /// Model consulted (and updated) for every decoded symbol.
    model: M,
    /// Coding interval plus the input stream and buffered code value.
    state: State<M::B, R>,
}
/// Bit-reading helper for which running out of input is not an error.
trait BitReadExt {
    /// Reads a single bit.
    ///
    /// Yields `Ok(None)` when the reader reports
    /// [`io::ErrorKind::UnexpectedEof`]; every other error is passed through.
    fn next_bit(&mut self) -> io::Result<Option<bool>>;
}

impl<R: BitRead> BitReadExt for R {
    fn next_bit(&mut self) -> io::Result<Option<bool>> {
        self.read_bit().map(Some).or_else(|err| {
            if err.kind() == io::ErrorKind::UnexpectedEof {
                // End of input is a normal condition for the decoder.
                Ok(None)
            } else {
                Err(err)
            }
        })
    }
}
impl<M, R> Decoder<M, R>
where
    M: Model,
    R: BitRead,
{
    /// Number of bits used for frequencies: `max_denominator().log2() + 1`.
    fn frequency_bits(model: &M) -> u32 {
        model.max_denominator().log2() + 1
    }

    /// Creates a decoder whose precision is every bit of `M::B` not needed
    /// for the model's frequency bits.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the derived precision fails the checks
    /// described on [`Decoder::with_precision`].
    pub fn new(model: M, input: R) -> Self {
        let precision = M::B::BITS - Self::frequency_bits(&model);
        Self::with_precision(model, input, precision)
    }

    /// Creates a decoder with an explicit coding `precision`, in bits.
    ///
    /// # Panics
    ///
    /// In debug builds only, panics if `precision` is smaller than the model's
    /// frequency bits plus two, or if `precision` plus the frequency bits
    /// exceeds `M::B::BITS`. Release builds skip these checks.
    pub fn with_precision(model: M, input: R, precision: u32) -> Self {
        let frequency_bits = Self::frequency_bits(&model);
        debug_assert!(
            precision >= frequency_bits + 2,
            "not enough bits of precision to prevent overflow/underflow",
        );
        debug_assert!(
            (frequency_bits + precision) <= M::B::BITS,
            "not enough bits in BitStore to support the required precision",
        );
        let state = State::new(precision, input);
        Self { model, state }
    }

    /// Creates a decoder from an existing coding `state` and a `model`.
    ///
    /// This is the inverse of [`Decoder::into_inner`].
    pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
        Self { model, state }
    }

    /// Returns an iterator that repeatedly calls [`Decoder::decode`].
    ///
    /// The iterator yields `Ok(symbol)` for each decoded symbol and `Err(_)`
    /// for I/O errors. It returns `None` whenever `decode` returns `Ok(None)`.
    #[must_use = "iterators are lazy and do nothing unless consumed"]
    pub fn decode_all(&mut self) -> DecodeIter<'_, M, R> {
        DecodeIter { decoder: self }
    }

    /// Decodes the next symbol from the input.
    ///
    /// The first call reads the initial `precision` bits of input. Each call
    /// then asks the model for the symbol matching the current value, narrows
    /// the coding interval by that symbol's probability, and updates the model.
    /// Returns `Ok(None)` when the model maps the current value to `None`
    /// (presumably the model's end-of-stream marker — confirm against `Model`).
    ///
    /// # Errors
    ///
    /// Returns any I/O error from the reader other than
    /// [`io::ErrorKind::UnexpectedEof`]. Running out of input is not an error;
    /// missing bits are read as zeros.
    ///
    /// # Panics
    ///
    /// Panics if the model's `probability` fails for a symbol that its own
    /// `symbol` method returned, which means the `Model` implementation is
    /// inconsistent. In debug builds, also panics if `denominator()` exceeds
    /// `max_denominator()`.
    pub fn decode(&mut self) -> io::Result<Option<M::Symbol>> {
        self.state.initialise()?;
        let denominator = self.model.denominator();
        debug_assert!(
            denominator <= self.model.max_denominator(),
            "denominator is greater than maximum!"
        );
        let value = self.state.value(denominator);
        let symbol = self.model.symbol(value);
        let p = self
            .model
            .probability(symbol.as_ref())
            .expect("this should not be able to fail. Check the implementation of the model.");
        self.state.scale(p, denominator)?;
        self.model.update(symbol.as_ref());
        Ok(symbol)
    }

    /// Swaps in a different `model` and continues from the current coding state.
    ///
    /// The new model must use the same `BitStore` type as the old one.
    pub fn chain<X>(self, model: X) -> Decoder<X, R>
    where
        X: Model<B = M::B>,
    {
        Decoder {
            model,
            state: self.state,
        }
    }

    /// Splits the decoder into its model and coding state.
    pub fn into_inner(self) -> (M, State<M::B, R>) {
        (self.model, self.state)
    }
}
/// Iterator over the symbols decoded by a mutably borrowed [`Decoder`].
///
/// Obtained from [`Decoder::decode_all`].
// NOTE(review): `Debug` is not implemented for this iterator, so the lint is
// silenced rather than satisfied — confirm whether a `Debug` impl is wanted.
#[allow(missing_debug_implementations)]
pub struct DecodeIter<'a, M: Model, R: BitRead> {
    /// Decoder that each call to `next` advances.
    decoder: &'a mut Decoder<M, R>,
}
impl<M: Model, R: BitRead> Iterator for DecodeIter<'_, M, R> {
    type Item = io::Result<M::Symbol>;

    /// Decodes one symbol.
    ///
    /// `Ok(None)` from the decoder becomes `None`, and an I/O error becomes
    /// `Some(Err(_))`.
    fn next(&mut self) -> Option<Self::Item> {
        match self.decoder.decode() {
            Ok(Some(symbol)) => Some(Ok(symbol)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        }
    }
}
/// Decoder-side coding state: interval bounds plus the input bits buffered
/// so far.
#[derive(Debug)]
pub struct State<B: BitStore, R: BitRead> {
    /// Interval bounds (`low`/`high`) and precision.
    state: common::State<B>,
    /// Stream the encoded bits are read from.
    input: R,
    /// Code value assembled from input bits one bit at a time; compared
    /// against `low`/`high` to select the next symbol.
    x: B,
    /// `true` until the first `precision` bits have been read into `x`.
    uninitialised: bool,
}
impl<B, R> State<B, R>
where
    B: BitStore,
    R: BitRead,
{
    /// Creates an uninitialised state with the given `precision`.
    ///
    /// No input is read here. The first `precision` bits are read lazily by
    /// `initialise`.
    pub fn new(precision: u32, input: R) -> Self {
        let state = common::State::new(precision);
        let x = B::ZERO;
        Self {
            state,
            input,
            x,
            uninitialised: true,
        }
    }

    /// Reads one input bit and, if it is `1`, adds it to `x`.
    ///
    /// Callers shift `x` left first to make room. At end of input the bit is
    /// taken to be `0`, so `x` is left unchanged.
    fn append_input_bit(&mut self) -> io::Result<()> {
        if self.input.next_bit()? == Some(true) {
            self.x += B::ONE;
        }
        Ok(())
    }

    /// Rescales the interval and `x` together, reading one input bit per
    /// doubling.
    ///
    /// Returns once the interval straddles `half()` and is no longer contained
    /// in `[quarter(), three_quarter())`. `half`, `quarter` and `three_quarter`
    /// come from `common::State` (presumably 1/2, 1/4 and 3/4 of the full
    /// range — confirm there).
    fn normalise(&mut self) -> io::Result<()> {
        while self.state.high < self.state.half() || self.state.low >= self.state.half() {
            if self.state.high < self.state.half() {
                // Whole interval is in the lower half: double in place.
                self.state.high <<= 1;
                self.state.low <<= 1;
                self.x <<= 1;
            } else {
                // Whole interval is in the upper half: drop the top half, then double.
                self.state.low = (self.state.low - self.state.half()) << 1;
                self.state.high = (self.state.high - self.state.half()) << 1;
                self.x = (self.x - self.state.half()) << 1;
            }
            self.append_input_bit()?;
        }
        while self.state.low >= self.state.quarter()
            && self.state.high < (self.state.three_quarter())
        {
            // Interval sits in the middle half: shift it down a quarter, then double.
            self.state.low = (self.state.low - self.state.quarter()) << 1;
            self.state.high = (self.state.high - self.state.quarter()) << 1;
            self.x = (self.x - self.state.quarter()) << 1;
            self.append_input_bit()?;
        }
        Ok(())
    }

    /// Narrows the interval via `common::State::scale` with probability
    /// `p / denominator`, then renormalises.
    fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
        self.state.scale(p, denominator);
        self.normalise()
    }

    /// Maps `x` onto the model's scale relative to the current interval.
    ///
    /// If `low <= x <= high`, the result lies in `[0, denominator)`.
    fn value(&self, denominator: B) -> B {
        let range = self.state.high - self.state.low + B::ONE;
        ((self.x - self.state.low + B::ONE) * denominator - B::ONE) / range
    }

    /// Shifts `precision` input bits into `x`.
    fn fill(&mut self) -> io::Result<()> {
        for _ in 0..self.state.precision {
            self.x <<= 1;
            self.append_input_bit()?;
        }
        Ok(())
    }

    /// On the first successful call, fills `x` from the input; later calls do nothing.
    fn initialise(&mut self) -> io::Result<()> {
        if self.uninitialised {
            self.fill()?;
            self.uninitialised = false;
        }
        Ok(())
    }
}