use std::ops::Range;
use crate::BitStore;
/// A model for a symbol stream whose length is known in advance.
///
/// Implementors never see an explicit EOF marker. Instead they report how
/// many symbols the stream contains via [`Model::length`], and [`Wrapper`]
/// adapts them to the EOF-aware [`crate::Model`] interface.
pub trait Model {
    /// The type of the symbols handled by this model.
    type Symbol;

    /// Error returned by [`Model::probability`] when it cannot handle a symbol.
    type ValueError: std::error::Error;

    /// The [`BitStore`] type in which probability ranges and denominators
    /// are expressed.
    type B: BitStore;

    /// Returns the range of values assigned to `symbol`.
    ///
    /// # Errors
    ///
    /// Returns `Self::ValueError` if the model cannot assign `symbol` a range.
    fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError>;

    /// The maximum denominator this model may use.
    fn max_denominator(&self) -> Self::B;

    /// Returns the symbol that corresponds to `value`.
    fn symbol(&self, value: Self::B) -> Self::Symbol;

    /// Total number of symbols in the stream.
    fn length(&self) -> usize;

    /// The denominator currently in use; defaults to [`Model::max_denominator`].
    fn denominator(&self) -> Self::B {
        self.max_denominator()
    }

    /// Updates the model after `symbol` has been processed.
    ///
    /// The default implementation does nothing.
    fn update(&mut self, _symbol: &Self::Symbol) {}
}
#[derive(Debug, Clone)]
pub struct Wrapper<M>
where
M: Model,
{
model: M,
remaining: usize,
}
impl<M> Wrapper<M>
where
M: Model,
{
pub fn new(model: M) -> Self {
let remaining = model.length();
Self { model, remaining }
}
}
impl<M> crate::Model for Wrapper<M>
where
M: Model,
{
type B = M::B;
type Symbol = M::Symbol;
type ValueError = Error<M::ValueError>;
fn probability(
&self,
symbol: Option<&Self::Symbol>,
) -> Result<Range<Self::B>, Self::ValueError> {
if self.remaining > 0 {
symbol.map_or(
Err(Self::ValueError::UnexpectedEof),
|s| self.model.probability(s).map_err(Self::ValueError::Value),
)
} else if symbol.is_some() {
Err(Error::UnexpectedSymbol)
} else {
Ok(Self::B::ZERO..self.denominator())
}
}
fn max_denominator(&self) -> Self::B {
self.model.max_denominator()
}
fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
if self.remaining > 0 {
Some(self.model.symbol(value))
} else {
None
}
}
fn denominator(&self) -> Self::B {
self.model.denominator()
}
fn update(&mut self, symbol: Option<&Self::Symbol>) {
if let Some(s) = symbol {
self.model.update(s);
self.remaining -= 1;
}
}
}
/// Errors reported by [`Wrapper`] when adapting a fixed-length [`Model`].
#[derive(Debug, thiserror::Error)]
pub enum Error<E: std::error::Error> {
    /// EOF (`None`) was supplied while symbols were still expected.
    #[error("Unexpected EOF")]
    UnexpectedEof,

    /// A symbol was supplied after the fixed length had been exhausted.
    #[error("Unexpected Symbol")]
    UnexpectedSymbol,

    /// The wrapped model could not assign a probability to the symbol.
    #[error(transparent)]
    Value(E),
}