use arithmetic_coding::{Decoder, Encoder, Model};
use bitstream_io::{BigEndian, BitRead, BitReader, BitWrite, BitWriter};
use symbolic::Symbol;
/// Bit precision handed to both `Encoder::with_precision` and
/// `Decoder::with_precision`; sharing one constant keeps the two sides in agreement.
const PRECISION: u32 = 12;
mod integer {
    //! A fixed-frequency model over the integer symbols `1`, `2` and `3`, plus the
    //! `None` end-of-stream marker. Each of the four outcomes occupies an interval
    //! of width 1 out of `max_denominator() == 4`.
    use std::ops::Range;

    /// Fixed-frequency model over the `u8` symbols `1..=3`.
    pub struct Model;

    /// Returned by `probability` when asked about a symbol outside `1..=3`.
    #[derive(Debug, thiserror::Error)]
    #[error("invalid symbol: {0}")]
    pub struct Error(u8);

    impl arithmetic_coding::Model for Model {
        type B = u32;
        type Symbol = u8;
        type ValueError = Error;

        /// Returns the half-open interval assigned to `symbol` (`None` = end of stream).
        ///
        /// # Errors
        ///
        /// Returns [`Error`] for any symbol other than `1`, `2` or `3`.
        fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, Error> {
            // The intervals must be disjoint and must agree exactly with `symbol()` below.
            // Otherwise a value produced while encoding one symbol decodes as another.
            match symbol {
                None => Ok(0..1),
                Some(&1) => Ok(1..2),
                Some(&2) => Ok(2..3),
                Some(&3) => Ok(3..4),
                Some(x) => Err(Error(*x)),
            }
        }

        /// Maps a value in `0..4` back to its symbol; the inverse of `probability`.
        ///
        /// # Panics
        ///
        /// Panics if `value >= 4`; the value is expected to lie below `max_denominator()`.
        fn symbol(&self, value: u32) -> Option<Self::Symbol> {
            match value {
                0..1 => None,
                1..2 => Some(1),
                2..3 => Some(2),
                3..4 => Some(3),
                _ => unreachable!(),
            }
        }

        fn max_denominator(&self) -> u32 {
            4
        }
    }
}
mod symbolic {
    //! A fixed-frequency model over three named symbols plus the `None`
    //! end-of-stream marker, each occupying one unit out of a denominator of 4.
    use std::{convert::Infallible, ops::Range};

    /// The alphabet encoded by [`Model`].
    #[derive(Debug, PartialEq, Eq)]
    pub enum Symbol {
        A,
        B,
        C,
    }

    /// Fixed-frequency model over [`Symbol`]; it can never fail.
    pub struct Model;

    impl arithmetic_coding::Model for Model {
        type B = u32;
        type Symbol = Symbol;
        type ValueError = Infallible;

        /// Returns the unit-width interval for `symbol`: slot 0 is end-of-stream,
        /// then `A`, `B` and `C` in order.
        fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, Infallible> {
            let slot = match symbol {
                None => 0,
                Some(Symbol::A) => 1,
                Some(Symbol::B) => 2,
                Some(Symbol::C) => 3,
            };
            Ok(slot..slot + 1)
        }

        /// Inverse of `probability`: maps a slot index back to its symbol.
        ///
        /// # Panics
        ///
        /// Panics if `value >= 4`.
        fn symbol(&self, value: u32) -> Option<Self::Symbol> {
            match value {
                0 => None,
                1 => Some(Symbol::A),
                2 => Some(Symbol::B),
                3 => Some(Symbol::C),
                _ => unreachable!(),
            }
        }

        fn max_denominator(&self) -> u32 {
            4
        }
    }
}
/// Encodes two sequences with different models into one shared buffer and
/// checks that decoding with the same models recovers both unchanged.
#[test]
fn round_trip() {
    let symbols = vec![Symbol::A, Symbol::B, Symbol::C];
    let integers = vec![2, 1, 1, 2, 2];

    let encoded = encode2(symbolic::Model, &symbols, integer::Model, &integers);
    let (decoded_symbols, decoded_integers) =
        decode2(symbolic::Model, integer::Model, &encoded);

    assert_eq!(symbols, decoded_symbols);
    assert_eq!(integers, decoded_integers);
}
/// Encodes `input1` with `model1`, then `input2` with `model2`, into a single
/// big-endian byte buffer. The second encoder is obtained by chaining onto the
/// first, with no flush in between.
///
/// # Panics
///
/// Panics if any encoding, flushing, or byte-alignment step fails.
fn encode2<M, N>(model1: M, input1: &[M::Symbol], model2: N, input2: &[N::Symbol]) -> Vec<u8>
where
    M: Model<B = N::B>,
    N: Model,
{
    let mut writer = BitWriter::endian(Vec::new(), BigEndian);

    // First stream, terminated by its own end-of-stream marker inside `encode`.
    let mut first = Encoder::with_precision(model1, &mut writer, PRECISION);
    encode(&mut first, input1);

    // `chain` consumes `first` and continues writing with `model2`; only the
    // final encoder is flushed.
    let mut second = first.chain(model2);
    encode(&mut second, input2);
    second.flush().unwrap();

    writer.byte_align().unwrap();
    writer.into_writer()
}
/// Feeds every symbol of `input` to `encoder`, followed by the `None`
/// end-of-stream marker.
///
/// # Panics
///
/// Panics on the first encoder error.
fn encode<M, W>(encoder: &mut Encoder<M, W>, input: &[M::Symbol])
where
    M: Model,
    W: BitWrite,
{
    let stream = input.iter().map(Some).chain(std::iter::once(None));
    for item in stream {
        encoder.encode(item).unwrap();
    }
}
/// Inverse of `encode2`: decodes the first sequence with `model1`, then chains a
/// decoder with `model2` to decode the second sequence from the same buffer.
///
/// # Panics
///
/// Panics on any decode error.
fn decode2<M, N>(model1: M, model2: N, buffer: &[u8]) -> (Vec<M::Symbol>, Vec<N::Symbol>)
where
    M: Model<B = N::B>,
    N: Model,
{
    let reader = BitReader::endian(buffer, BigEndian);
    let mut first = Decoder::with_precision(model1, reader, PRECISION);

    // The first stream must be fully drained before chaining.
    let leading = decode(&mut first);

    let mut second = first.chain(model2);
    (leading, decode(&mut second))
}
/// Collects every symbol yielded by `decoder.decode_all()` (presumably up to the
/// end-of-stream marker; confirm against the crate docs).
///
/// # Panics
///
/// Panics on the first decode error.
fn decode<M, R>(decoder: &mut Decoder<M, R>) -> Vec<M::Symbol>
where
    M: Model,
    R: BitRead,
{
    let mut symbols = Vec::new();
    for decoded in decoder.decode_all() {
        symbols.push(decoded.unwrap());
    }
    symbols
}