use crate::code::{Alphabet, Letter};
use crate::model::Model;
use crate::tokens::{Token, TokenPacker, Tokenizer};
use anyhow::{anyhow, Result};
use log::{debug, log_enabled, Level};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::Cursor;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub mod balanced_tree;
pub mod fano;
pub mod huffman;
pub mod shannon;
/// Selects which algorithm `new_encoder` uses to build an `Encoding` from a
/// `Model`. Each variant maps to the constructor in the correspondingly named
/// submodule (see the dispatch table in `new_encoder`).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum EncodingScheme {
/// Constructed via `balanced_tree::new`.
BalancedTree,
/// Constructed via `fano::new`.
Fano,
/// Constructed via `shannon::new`.
Shannon,
/// Constructed via `huffman::new`.
Huffman,
}
/// Builds an [`Encoding`] for `model` using the algorithm selected by
/// `encoding_scheme`.
///
/// # Errors
/// Propagates any failure from the chosen scheme's constructor.
pub fn new_encoder<T: Token>(
    encoding_scheme: &EncodingScheme,
    model: Model<T>,
) -> Result<Encoding<T>> {
    // Dispatch directly per variant; each arm consumes the model.
    match encoding_scheme {
        EncodingScheme::BalancedTree => balanced_tree::new(model),
        EncodingScheme::Fano => fano::new(model),
        EncodingScheme::Huffman => huffman::new(model),
        EncodingScheme::Shannon => shannon::new(model),
    }
}
/// A code: a mapping from tokens to `Letter`s, together with the `Alphabet`
/// assembled from those letters in sorted order (see `Encoding::new`).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Encoding<T: Token> {
// Token -> letter lookup; the inverse is built on demand by `reverse_map`.
map: HashMap<T, Letter>,
// The letters of `map`, sorted and wrapped as an `Alphabet`.
alphabet: Alphabet,
}
impl<T: Token> Encoding<T> {
    /// Creates an encoding from a token-to-letter map.
    ///
    /// The alphabet is assembled from the map's letters in sorted order, so
    /// construction is deterministic despite `HashMap`'s arbitrary iteration
    /// order.
    ///
    /// # Errors
    /// Fails if `Alphabet::new` rejects the collected letters.
    fn new(map: HashMap<T, Letter>) -> Result<Self> {
        let mut letters: Vec<&Letter> = map.values().collect();
        letters.sort();
        let alphabet = Alphabet::new(letters.into_iter().cloned().collect())?;
        log_encoder_ring(&map);
        Ok(Self { map, alphabet })
    }

    /// The alphabet of letters used by this encoding (sorted at construction).
    pub fn alphabet(&self) -> &Alphabet {
        &self.alphabet
    }

    /// The token-to-letter mapping backing this encoding.
    pub fn map(&self) -> &HashMap<T, Letter> {
        &self.map
    }

    /// Builds the inverse (letter-to-token) lookup, borrowing from `self`.
    pub fn reverse_map(&self) -> HashMap<&Letter, &T> {
        // Pre-size: the inverse has exactly one entry per forward entry.
        let mut m = HashMap::with_capacity(self.map.len());
        for (t, l) in &self.map {
            m.insert(l, t);
        }
        m
    }

    /// Serializes this encoding to `w`.
    ///
    /// Layout: an 8-byte big-endian count of packed token bytes, the packed
    /// tokens (ordered by their sorted letters), then the packed alphabet.
    /// `unpack` reads the same layout back.
    ///
    /// # Errors
    /// Propagates any I/O or packing failure.
    pub fn pack<W: std::io::Write>(&self, mut w: W) -> Result<()> {
        let tokens = self.tokens();
        // NOTE(review): the integer division assumes the summed bit count is
        // byte-aligned — confirm this invariant for all `Token` impls.
        let size = tokens.iter().fold(0, |sum, t| sum + t.bit_count()) / 8;
        w.write_all(&pack_u64(size as u64))?;
        T::Packer::pack(tokens.into_iter(), &mut w)?;
        self.alphabet().clone().pack(w)?;
        Ok(())
    }

    /// Deserializes an encoding previously written by `pack`, consuming
    /// exactly the packed bytes and leaving any trailing data in `r` unread.
    ///
    /// # Errors
    /// Fails on I/O errors, tokenizer failures, malformed alphabet data, or a
    /// letter/token count mismatch.
    pub fn unpack<R: std::io::Read>(mut r: R) -> Result<Self> {
        let size = unpack_u64(&mut r)?;
        let safe_size = usize::try_from(size)?;
        let mut buf = vec![0u8; safe_size];
        r.read_exact(&mut buf)?;
        // Propagate tokenizer construction failure instead of panicking:
        // malformed input is a recoverable error, not a bug.
        let tokens: Result<Vec<T>> = T::Tokenizer::tokenize(Cursor::new(buf))?.collect();
        let tokens = tokens?;
        let alphabet = Alphabet::unpack(r)?;
        let letters = alphabet.letters().into_iter().cloned();
        if letters.len() != tokens.len() {
            return Err(anyhow!(
                "Extracted letter count {} does not match token count {}",
                letters.len(),
                tokens.len(),
            ));
        }
        let map: HashMap<T, Letter> = tokens.iter().cloned().zip(letters).collect();
        log_encoder_ring(&map);
        Ok(Self { map, alphabet })
    }

    /// The tokens of this encoding, ordered by their sorted letters — the
    /// order `pack` writes and `unpack` relies on to realign tokens with
    /// alphabet letters.
    fn tokens(&self) -> Vec<T> {
        let reverse = self.reverse_map();
        let mut letters: Vec<&Letter> = self.map.values().collect();
        letters.sort();
        letters.into_iter().map(|l| reverse[l].clone()).collect()
    }
}
#[allow(dead_code)]
fn from_pairs<T: Token>(data: &[(T, Letter)]) -> Result<Encoding<T>> {
Encoding::new(data.iter().cloned().collect())
}
/// Dumps the token-to-letter map at debug level; a no-op unless debug
/// logging is enabled (the guard avoids formatting work otherwise).
fn log_encoder_ring<T: Token>(m: &HashMap<T, Letter>) {
    if log_enabled!(Level::Debug) {
        debug!("Encoder ring:");
        for (token, letter) in m {
            debug!(" |{:?}|: |{:?}|", token, letter);
        }
    }
}
/// Encodes `s` as its 8 big-endian bytes (the fixed-width length prefix
/// format read back by `unpack_u64`).
fn pack_u64(s: u64) -> Vec<u8> {
    Vec::from(s.to_be_bytes())
}
/// Reads exactly 8 bytes from `r` and decodes them as a big-endian u64
/// (inverse of `pack_u64`).
///
/// # Errors
/// Fails if `r` yields fewer than 8 bytes or the read itself errors.
fn unpack_u64<R: std::io::Read>(mut r: R) -> Result<u64> {
    let mut bytes = [0u8; 8];
    r.read_exact(&mut bytes)?;
    Ok(u64::from_be_bytes(bytes))
}
#[cfg(test)]
mod roundtrip_with_len_tests {
    use super::*;
    use crate::tokens::{bytes::Byte, graphemes::Grapheme};
    use std::io::{Cursor, Read};

    /// An encoding with no tokens survives a pack/unpack round trip.
    #[test]
    fn empty() {
        let encoding: Encoding<Byte> = Encoding::new(HashMap::new()).unwrap();
        let mut buf = Vec::<u8>::new();
        assert!(encoding.pack(&mut buf).is_ok());
        let got: Encoding<Byte> = Encoding::unpack(&mut Cursor::new(&mut buf)).unwrap();
        assert_eq!(got.map(), encoding.map());
    }

    /// A populated encoding survives a pack/unpack round trip.
    #[test]
    fn non_empty() {
        // Fix: the original fixture listed Byte(0) twice, so its first pair
        // (Byte(0) -> [0, 1]) was silently dropped when the vec was collected
        // into a HashMap. Only the surviving entries are kept here; the
        // resulting map is identical.
        let map = (vec![
            (Byte::from(1), Letter::from_bytes(&vec![0u8, 0u8, 1u8])),
            (Byte::from(2), Letter::from_bytes(&vec![0u8, 0u8, 0u8, 1u8])),
            (
                Byte::from(3),
                Letter::from_bytes(&vec![0u8, 0u8, 0u8, 0u8, 1u8]),
            ),
            (Byte::from(0), Letter::from_bytes(&vec![1u8, 1u8])),
            (Byte::from(4), Letter::from_bytes(&vec![1u8, 0u8, 1u8])),
            (Byte::from(5), Letter::from_bytes(&vec![1u8, 0u8, 0u8, 1u8])),
        ])
        .into_iter()
        .collect();
        let encoding: Encoding<Byte> = Encoding::new(map).unwrap();
        let mut buf = Vec::<u8>::new();
        assert!(encoding.pack(&mut buf).is_ok());
        let got: Encoding<Byte> = Encoding::unpack(&mut Cursor::new(&mut buf)).unwrap();
        assert_eq!(got.map(), encoding.map());
    }

    /// Unpacking consumes exactly the packed bytes, leaving trailing data in
    /// the reader untouched (Byte tokens).
    #[test]
    fn trailing_data_byte() {
        let map = (vec![
            (Byte::from(0), Letter::from_bytes(&vec![0u8, 1u8])),
            (Byte::from(1), Letter::from_bytes(&vec![0u8, 0u8, 1u8])),
            (Byte::from(2), Letter::from_bytes(&vec![0u8, 0u8, 0u8, 1u8])),
        ])
        .into_iter()
        .collect();
        let encoding: Encoding<Byte> = Encoding::new(map).unwrap();
        let mut buf = Vec::<u8>::new();
        assert!(encoding.pack(&mut buf).is_ok());
        buf.push(0b1111_1111);
        let mut r = Cursor::new(buf);
        let got: Encoding<Byte> = Encoding::unpack(&mut r).unwrap();
        assert_eq!(got.map(), encoding.map());
        // The sentinel byte appended after the packed data must still be
        // available for subsequent reads.
        let mut buf = Vec::<u8>::new();
        assert_eq!(r.read_to_end(&mut buf).unwrap(), 1);
        assert_eq!(buf, vec![0b1111_1111u8]);
    }

    /// Same as `trailing_data_byte`, but with Grapheme tokens (variable-width
    /// token packing).
    #[test]
    fn trailing_data_grapheme() {
        let map = (vec![
            (
                Grapheme::from("a".to_owned()),
                Letter::from_bytes(&vec![0u8, 1u8]),
            ),
            (
                Grapheme::from("b".to_owned()),
                Letter::from_bytes(&vec![0u8, 0u8, 1u8]),
            ),
        ])
        .into_iter()
        .collect();
        let encoding: Encoding<Grapheme> = Encoding::new(map).unwrap();
        let mut buf = Vec::<u8>::new();
        assert!(encoding.pack(&mut buf).is_ok());
        buf.push(0b1111_1111);
        let mut r = Cursor::new(buf);
        let got: Encoding<Grapheme> = Encoding::unpack(&mut r).unwrap();
        assert_eq!(got.map(), encoding.map());
        // The sentinel byte must remain unread after unpacking.
        let mut buf = Vec::<u8>::new();
        assert_eq!(r.read_to_end(&mut buf).unwrap(), 1);
        assert_eq!(buf, vec![0b1111_1111u8]);
    }
}