use num::{Integer, One, Unsigned, Zero, rational::Ratio};
use num_traits::{NumAssignOps, NumOps};
use std::{collections::HashMap, hash::Hash, iter::Sum};
/// Maps each symbol in `weights` to its half-open cumulative range `[low, high)`.
///
/// A symbol's range has width `weight / total`, where `total` is the sum of all
/// weights. Ranges are laid out back to back in the order the symbols appear in
/// `weights`, so for non-negative weights they tile `[0, 1)`.
///
/// If a symbol occurs more than once, the map keeps only its last entry. Earlier
/// entries still consume their share of the interval.
///
/// Panics (inside `Ratio::new`) if `weights` is non-empty and the weights sum to zero.
fn weights_to_ranges<T: Hash + Eq, U: Integer + Clone + NumOps + NumAssignOps + Sum>(
    weights: &[(T, U)],
) -> HashMap<&T, (Ratio<U>, Ratio<U>)> {
    let total: U = weights.iter().map(|(_, w)| w.clone()).sum();
    // Running integer sum of the weights seen so far; each boundary is cum / total.
    let mut cumulative = U::zero();
    weights
        .iter()
        .map(|(symbol, w)| {
            let low = Ratio::new(cumulative.clone(), total.clone());
            cumulative += w.clone();
            let high = Ratio::new(cumulative.clone(), total.clone());
            (symbol, (low, high))
        })
        .collect()
}
/// Arithmetic-encodes `input` into a single exact rational number.
///
/// Each symbol's probability is its weight divided by the sum of all weights.
/// The interval `[0, 1)` is narrowed once per input symbol to that symbol's
/// sub-range, and the midpoint of the final interval is returned. Passing the
/// result to [`arithmetic_decode`] with the same `weights` and `input.len()`
/// recovers `input`.
///
/// The computation is exact, so numerators and denominators grow with the input
/// length. A narrow `U` (e.g. `u32`) may overflow on long inputs.
///
/// # Panics
///
/// - If `input` contains a symbol that does not appear in `weights`.
/// - If `input` contains a symbol whose weight is zero. Such a symbol has an
///   empty range, and the resulting code would not decode back to it.
/// - If `weights` is non-empty but its weights sum to zero.
pub fn arithmetic_encode<
    T: Hash + Eq,
    U: Unsigned + Integer + Clone + NumOps + NumAssignOps + Sum,
>(
    input: &[T],
    weights: &[(T, U)],
) -> Ratio<U> {
    let ranges = weights_to_ranges(weights);
    let mut l = Ratio::zero();
    let mut r = Ratio::one();
    for symbol in input {
        let (l_weight, r_weight) = ranges
            .get(symbol)
            .expect("arithmetic_encode: input symbol is missing from `weights`");
        // A zero-weight symbol owns the empty range [w, w). Narrowing to it would collapse
        // the interval to a single point that the decoder attributes to another symbol.
        assert!(
            l_weight < r_weight,
            "arithmetic_encode: input symbol has zero weight and cannot be encoded"
        );
        let range = r - l.clone();
        r = l.clone() + range.clone() * r_weight;
        l = l + range * l_weight;
    }
    // Any value in [l, r) identifies the sequence. The midpoint lies strictly inside
    // because every step above kept l < r.
    (r + l) / (U::one() + U::one())
}
/// Decodes up to `length` symbols from a code produced by [`arithmetic_encode`].
///
/// `weights` must be the same table that was used for encoding. At each step:
/// - the current interval `[l, r)` is rescaled to `[0, 1)`;
/// - the symbol whose range contains the rescaled value is emitted;
/// - the interval is narrowed to that symbol's sub-range.
///
/// Sometimes the rescaled value falls outside every symbol's range, for example
/// when `input >= 1` or `weights` is empty. Nothing further can then be decoded,
/// and the symbols found so far are returned, so the result may be shorter than
/// `length`.
///
/// # Panics
///
/// Panics if `weights` is non-empty but its weights sum to zero.
pub fn arithmetic_decode<
    T: Hash + Eq + Clone,
    U: Unsigned + Integer + Clone + NumOps + NumAssignOps + Sum,
>(
    input: Ratio<U>,
    weights: &[(T, U)],
    length: usize,
) -> Vec<T> {
    let ranges = weights_to_ranges(weights);
    let mut l = Ratio::zero();
    let mut r = Ratio::one();
    let mut output: Vec<T> = Vec::with_capacity(length);
    for _ in 0..length {
        let d = r.clone() - l.clone();
        let x = (input.clone() - l.clone()) / d.clone();
        // Symbol ranges are disjoint, so at most one matches regardless of map order.
        let matched = ranges
            .iter()
            .find(|(_, (l_weight, r_weight))| x >= *l_weight && x < *r_weight);
        let (key, (l_weight, r_weight)) = match matched {
            Some(entry) => entry,
            // A miss leaves [l, r) unchanged, so every later step would miss as well.
            // Stop now instead of repeating the same futile search.
            None => break,
        };
        output.push((*key).clone());
        r = l.clone() + d.clone() * r_weight;
        l = l + d * l_weight;
    }
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four equiprobable byte symbols shared by the fixed-value tests.
    const UNIFORM: &[(u8, u32)] = &[(b'a', 1), (b'b', 1), (b'c', 1), (b'd', 1)];

    #[test]
    fn test_arithmetic_encode() {
        let input = "abcd";
        let encoded = arithmetic_encode(input.as_bytes(), UNIFORM);
        assert_eq!(encoded, Ratio::new(55, 512));
    }

    #[test]
    fn test_arithmetic_decode() {
        let input = Ratio::new(55, 512);
        let length = 4;
        let decoded = arithmetic_decode(input, UNIFORM, length);
        assert_eq!(decoded, b"abcd");
    }

    #[test]
    fn test_round_trip_non_uniform_weights() {
        // Ranges: a -> [0, 3/4), b -> [3/4, 1).
        let weights: &[(u8, u32)] = &[(b'a', 3), (b'b', 1)];
        let input: &[u8] = b"aab";
        let encoded = arithmetic_encode(input, weights);
        assert_eq!(encoded, Ratio::new(63, 128));
        assert_eq!(arithmetic_decode(encoded, weights, input.len()), input);
    }

    #[test]
    fn test_empty_input_round_trips() {
        let empty: &[u8] = &[];
        let encoded = arithmetic_encode(empty, UNIFORM);
        // With no narrowing, the code is the midpoint of [0, 1).
        assert_eq!(encoded, Ratio::new(1, 2));
        assert!(arithmetic_decode(encoded, UNIFORM, 0).is_empty());
    }

    #[test]
    fn test_decode_out_of_range_yields_nothing() {
        // 1 lies outside every half-open symbol range, so nothing can be decoded.
        let decoded = arithmetic_decode(Ratio::new(1, 1), UNIFORM, 4);
        assert!(decoded.is_empty());
    }
}