use crate::mixer::apm::APMStage;
use crate::mixer::dual_mixer::{NUM_MODELS, byte_class};
use crate::mixer::isse::IsseChain;
use crate::mixer::multi_set_mixer::MultiSetMixer;
use crate::model::cm_model::{AssociativeContextModel, ChecksumContextModel, ContextModel};
use crate::model::dmc_model::DmcModel;
use crate::model::indirect_model::IndirectModel;
use crate::model::json_model::JsonModel;
use crate::model::match_model::MatchModel;
use crate::model::order0::Order0Model;
use crate::model::ppm_model::{PpmConfig, PpmModel};
use crate::model::run_model::RunModel;
use crate::model::sparse_model::SparseModel;
use crate::model::word_model::WordModel;
/// Sizing parameters handed to the sub-models of [`CMEngine`].
///
/// Every `*_size` value is forwarded unchanged to the constructor of the
/// corresponding model in [`CMEngine::with_config`].
// NOTE(review): the unit of each size (bytes vs. table slots) is decided by
// the individual model constructors and is not visible here — confirm there.
#[derive(Debug, Clone)]
pub struct CMConfig {
    /// Size argument for the order-1 `ContextModel`.
    pub order1_size: usize,
    /// Size argument for the order-2 `ContextModel`.
    pub order2_size: usize,
    /// Size argument for the order-3 `ChecksumContextModel`.
    pub order3_size: usize,
    /// Size argument for the order-4 `ChecksumContextModel`.
    pub order4_size: usize,
    /// Size argument for the order-5 `AssociativeContextModel`.
    pub order5_size: usize,
    /// Size argument for the order-6 `AssociativeContextModel`.
    pub order6_size: usize,
    /// Size argument for the order-7 `AssociativeContextModel`.
    pub order7_size: usize,
    /// Size argument for the order-8 `AssociativeContextModel`.
    pub order8_size: usize,
    /// Size argument for the order-9 `AssociativeContextModel`.
    pub order9_size: usize,
    /// First argument of `MatchModel::with_sizes`.
    pub match_ring_size: usize,
    /// Second argument of `MatchModel::with_sizes`.
    pub match_hash_size: usize,
    /// Size argument for `WordModel::with_size`.
    pub word_size: usize,
    /// Size argument for `SparseModel::with_size`.
    pub sparse_size: usize,
    /// Size argument for `RunModel::with_size`.
    pub run_size: usize,
    /// Size argument for `JsonModel::with_size`.
    pub json_size: usize,
    /// Configuration forwarded to `PpmModel::with_config`.
    pub ppm_config: PpmConfig,
}
impl CMConfig {
    /// Preset used by [`CMEngine::new`].
    ///
    /// Every size is exactly half of the matching value in [`CMConfig::max`];
    /// the PPM configuration is the same `PpmConfig::scaled_4x()`.
    pub fn balanced() -> Self {
        Self {
            order1_size: 1 << 25,
            order2_size: 1 << 24,
            order3_size: 1 << 25,
            order4_size: 1 << 25,
            order5_size: 1 << 25,
            order6_size: 1 << 24,
            order7_size: 1 << 25,
            order8_size: 1 << 25,
            order9_size: 1 << 24,
            match_ring_size: 16 << 20,
            match_hash_size: 8 << 20,
            word_size: 1 << 24,
            sparse_size: 1 << 23,
            run_size: 1 << 22,
            json_size: 1 << 23,
            ppm_config: PpmConfig::scaled_4x(),
        }
    }

    /// Largest preset: every size is twice the value in
    /// [`CMConfig::balanced`]; the PPM configuration is unchanged.
    pub fn max() -> Self {
        Self {
            order1_size: 1 << 26,
            order2_size: 1 << 25,
            order3_size: 1 << 26,
            order4_size: 1 << 26,
            order5_size: 1 << 26,
            order6_size: 1 << 25,
            order7_size: 1 << 26,
            order8_size: 1 << 26,
            order9_size: 1 << 25,
            match_ring_size: 32 << 20,
            match_hash_size: 16 << 20,
            word_size: 1 << 25,
            sparse_size: 1 << 24,
            run_size: 1 << 23,
            json_size: 1 << 24,
            ppm_config: PpmConfig::scaled_4x(),
        }
    }
}
/// Bit-level predictor that combines many sub-models through a mixer and a
/// chain of seven APM stages.
///
/// Drive it with [`CMEngine::predict`] and [`CMEngine::update`].
pub struct CMEngine {
    // ---- hashed / direct context models, one per context order ----
    order0: Order0Model,
    order1: ContextModel,
    order2: ContextModel,
    order3: ChecksumContextModel,
    order4: ChecksumContextModel,
    order5: AssociativeContextModel,
    order6: AssociativeContextModel,
    order7: AssociativeContextModel,
    order8: AssociativeContextModel,
    order9: AssociativeContextModel,

    // ---- specialised models ----
    match_model: MatchModel,
    word_model: WordModel,
    sparse_model: SparseModel,
    run_model: RunModel,
    json_model: JsonModel,
    indirect_model: IndirectModel,
    ppm_model: PpmModel,
    dmc_model: DmcModel,

    // ---- mixing and post-processing ----
    mixer: MultiSetMixer,
    apm1: APMStage,
    apm2: APMStage,
    apm3: APMStage,
    apm4: APMStage,
    apm5: APMStage,
    apm6: APMStage,
    apm7: APMStage,
    isse_model: IsseChain,

    // ---- coding state ----
    /// Bits of the byte in progress, preceded by a leading 1; reset to 1
    /// after every completed byte.
    c0: u32,
    /// Most recently completed byte.
    c1: u8,
    /// Earlier completed bytes: `c2` precedes `c1`, and so on up to `c9`.
    c2: u8,
    c3: u8,
    c4: u8,
    c5: u8,
    c6: u8,
    c7: u8,
    c8: u8,
    c9: u8,
    /// Number of bits already consumed from the byte in progress (0..=7).
    bpos: u8,
    /// Saturating length of the current run of identical bytes.
    run_len: u8,
    /// Saturating count of bytes seen since the last `b'\n'`.
    line_pos: u16,
    /// Wrapping count of `0x00` bytes seen so far.
    column_index: u16,
}
impl CMEngine {
pub fn new() -> Self {
Self::with_config(CMConfig::balanced())
}
pub fn with_config(config: CMConfig) -> Self {
CMEngine {
order0: Order0Model::new(),
order1: ContextModel::new(config.order1_size),
order2: ContextModel::new(config.order2_size),
order3: ChecksumContextModel::new(config.order3_size),
order4: ChecksumContextModel::new(config.order4_size),
order5: AssociativeContextModel::new(config.order5_size),
order6: AssociativeContextModel::new(config.order6_size),
order7: AssociativeContextModel::new(config.order7_size),
order8: AssociativeContextModel::new(config.order8_size),
order9: AssociativeContextModel::new(config.order9_size),
match_model: MatchModel::with_sizes(config.match_ring_size, config.match_hash_size),
word_model: WordModel::with_size(config.word_size),
sparse_model: SparseModel::with_size(config.sparse_size),
run_model: RunModel::with_size(config.run_size),
json_model: JsonModel::with_size(config.json_size),
indirect_model: IndirectModel::new(),
ppm_model: PpmModel::with_config(config.ppm_config),
dmc_model: DmcModel::new_single(),
mixer: MultiSetMixer::new(),
apm1: APMStage::new(2048, 55), apm2: APMStage::new(16384, 30), apm3: APMStage::new(4096, 25), apm4: APMStage::new(4096, 15), apm5: APMStage::new(4096, 15), apm6: APMStage::new(2048, 12), apm7: APMStage::new(4096, 12), isse_model: IsseChain::new(),
c0: 1,
c1: 0,
c2: 0,
c3: 0,
c4: 0,
c5: 0,
c6: 0,
c7: 0,
c8: 0,
c9: 0,
bpos: 0,
run_len: 0,
line_pos: 0,
column_index: 0,
}
}
#[inline(always)]
pub fn predict(&mut self) -> u32 {
let c0 = self.c0;
let c1 = self.c1;
let c2 = self.c2;
let c3 = self.c3;
let c4 = self.c4;
let c5 = self.c5;
let c6 = self.c6;
let c7 = self.c7;
let bpos = self.bpos;
let p0 = self.order0.predict(c0 as usize);
let h1 = order1_hash(c1, c0);
let (p1_s, p1_r) = self.order1.predict_multi(h1);
let h2 = order2_hash(c2, c1, c0);
let (p2_s, p2_r) = self.order2.predict_multi(h2);
let h3 = order3_hash(c3, c2, c1, c0);
let (p3_s, p3_r) = self.order3.predict_multi(h3);
let h4 = order4_hash(c4, c3, c2, c1, c0);
let (p4_s, p4_r) = self.order4.predict_multi(h4);
let h5 = order5_hash(c5, c4, c3, c2, c1, c0);
let (p5_s, p5_r) = self.order5.predict_multi(h5);
let h6 = order6_hash(c6, c5, c4, c3, c2, c1, c0);
let (p6_s, p6_r) = self.order6.predict_multi(h6);
let h7 = order7_hash(c7, c6, c5, c4, c3, c2, c1, c0);
let (p7_s, p7_r) = self.order7.predict_multi(h7);
let c8 = self.c8;
let h8 = order8_hash(c8, c7, c6, c5, c4, c3, c2, c1, c0);
let (p8_s, p8_r) = self.order8.predict_multi(h8);
let c9 = self.c9;
let h9 = order9_hash(c9, c8, c7, c6, c5, c4, c3, c2, c1, c0);
let (p9_s, p9_r) = self.order9.predict_multi(h9);
let p_match = self.match_model.predict(c0, bpos, c1, c2, c3);
let p_word = self.word_model.predict(c0, bpos, c1);
let p_sparse = self.sparse_model.predict(c0, c1, c2, c3);
let p_run = self.run_model.predict(c0, bpos, c1);
let p_json = self.json_model.predict(c0, bpos, c1);
let p_indirect = self.indirect_model.predict(c0, bpos, c1);
let p_ppm = self.ppm_model.predict_bit(bpos, c0);
let p_dmc = self.dmc_model.predict();
let p_isse = self.isse_model.predict(c0, c1, c2, c3, bpos);
let predictions: [u32; NUM_MODELS] = [
p0, p1_s, p1_r, p2_s, p2_r, p3_s, p3_r, p4_s, p4_r, p5_s, p5_r, p6_s, p6_r, p7_s, p7_r,
p8_s, p8_r, p9_s, p9_r, p_match, p_word, p_sparse, p_run, p_json, p_indirect, p_ppm,
p_dmc, p_isse,
];
let bclass = byte_class(c1);
let match_q = self.match_model.match_length_quantized();
let run_q = quantize_run_for_mixer(self.run_len);
let mixed = self
.mixer
.predict(&predictions, c0, c1, c2, bpos, bclass, match_q, run_q, 0);
let apm1_ctx = (((c0 as usize & 0xFF) << 3) | bpos as usize)
.wrapping_mul(5)
.wrapping_add(run_q as usize & 0x3)
& 2047;
let after_apm1 = self.apm1.predict(mixed, apm1_ctx);
let apm2_ctx = (((c1 as usize) << 3 | bpos as usize) * 8 + bclass as usize)
.wrapping_mul(17)
.wrapping_add(c2 as usize >> 4)
& 16383;
let after_apm2 = self.apm2.predict(after_apm1, apm2_ctx);
let apm3_ctx = ((match_q as usize * 512)
+ ((c2 as usize >> 6) << 7)
+ ((c1 as usize >> 4) << 3)
+ bpos as usize)
.wrapping_mul(5)
.wrapping_add(match_q as usize)
& 4095;
let after_apm3 = self.apm3.predict(after_apm2, apm3_ctx);
let bc2 = byte_class(c2);
let apm4_ctx = (bclass as usize * 8 + bc2 as usize)
.wrapping_mul(33)
.wrapping_add(bpos as usize * 4 + run_q as usize)
& 4095;
let after_apm4 = self.apm4.predict(after_apm3, apm4_ctx);
let apm5_ctx = ((c3 as usize >> 4).wrapping_mul(67) + (c2 as usize >> 4))
.wrapping_mul(67)
.wrapping_add((c1 as usize >> 6) * 8 + bpos as usize)
& 4095;
let after_apm5 = self.apm5.predict(after_apm4, apm5_ctx);
let apm6_ctx = (match_q as usize * 64 + bclass as usize * 8 + bpos as usize) & 2047;
let after_apm6 = self.apm6.predict(after_apm5, apm6_ctx);
let line_pos_q = quantize_line_pos(self.line_pos);
let pos_ctx = (line_pos_q as usize) ^ ((self.column_index as usize & 0xF) << 2);
let apm7_ctx = (pos_ctx.wrapping_mul(67) + (c0 as usize & 0xFF))
.wrapping_mul(67)
.wrapping_add(bpos as usize)
& 4095;
let final_p = self.apm7.predict(after_apm6, apm7_ctx);
final_p.clamp(1, 4095)
}
#[inline(always)]
pub fn update(&mut self, bit: u8) {
self.apm7.update(bit);
self.apm6.update(bit);
self.apm5.update(bit);
self.apm4.update(bit);
self.apm3.update(bit);
self.apm2.update(bit);
self.apm1.update(bit);
self.mixer.update(bit);
self.order0.update(self.c0 as usize, bit);
self.order1.update(bit);
self.order2.update(bit);
self.order3.update(bit);
self.order4.update(bit);
self.order5.update(bit);
self.order6.update(bit);
self.order7.update(bit);
self.order8.update(bit);
self.order9.update(bit);
self.match_model
.update(bit, self.bpos, self.c0, self.c1, self.c2);
self.word_model.update(bit);
self.sparse_model.update(bit);
self.run_model.update(bit);
self.json_model.update(bit);
self.indirect_model.update(bit);
self.dmc_model.update(bit);
self.isse_model.update(bit, self.c0, self.bpos);
self.c0 = (self.c0 << 1) | bit as u32;
self.bpos += 1;
if self.bpos >= 8 {
let byte = (self.c0 & 0xFF) as u8;
if byte == self.c1 {
self.run_len = self.run_len.saturating_add(1);
} else {
self.run_len = 1;
}
if byte == b'\n' {
self.line_pos = 0;
} else {
self.line_pos = self.line_pos.saturating_add(1);
}
if byte == 0x00 {
self.column_index = self.column_index.wrapping_add(1);
}
self.ppm_model.update_byte(byte);
self.dmc_model.on_byte_complete(byte);
self.c9 = self.c8;
self.c8 = self.c7;
self.c7 = self.c6;
self.c6 = self.c5;
self.c5 = self.c4;
self.c4 = self.c3;
self.c3 = self.c2;
self.c2 = self.c1;
self.c1 = byte;
self.c0 = 1; self.bpos = 0;
}
}
}
impl Default for CMEngine {
fn default() -> Self {
Self::new()
}
}
/// 32-bit FNV offset basis: the starting value of every context hash.
const FNV_OFFSET: u32 = 0x811C9DC5;
/// 32-bit FNV prime: the multiplier applied after each byte is mixed in.
const FNV_PRIME: u32 = 0x01000193;

/// Shared FNV-1a fold behind all `orderN_hash` functions.
///
/// Starting from [`FNV_OFFSET`], each byte of `history` (oldest first) is
/// XOR-ed into the hash, which is then multiplied by [`FNV_PRIME`] with
/// wrapping arithmetic. The low 8 bits of `c0` are mixed in last, the same
/// way. `N` is a compile-time constant, so each order gets its own
/// monomorphised copy of the loop.
#[inline]
fn fnv1a_fold<const N: usize>(history: [u8; N], c0: u32) -> u32 {
    let hashed_history = history
        .iter()
        .fold(FNV_OFFSET, |h, &byte| (h ^ (byte as u32)).wrapping_mul(FNV_PRIME));
    (hashed_history ^ (c0 & 0xFF)).wrapping_mul(FNV_PRIME)
}

/// Hash of the order-1 context: `c1` followed by the low byte of `c0`.
#[inline]
fn order1_hash(c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c1], c0)
}

/// Hash of the order-2 context: `c2, c1` followed by the low byte of `c0`.
#[inline]
fn order2_hash(c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c2, c1], c0)
}

/// Hash of the order-3 context: `c3..c1` followed by the low byte of `c0`.
#[inline]
fn order3_hash(c3: u8, c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c3, c2, c1], c0)
}

/// Hash of the order-4 context: `c4..c1` followed by the low byte of `c0`.
#[inline]
fn order4_hash(c4: u8, c3: u8, c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c4, c3, c2, c1], c0)
}

/// Hash of the order-5 context: `c5..c1` followed by the low byte of `c0`.
#[inline]
fn order5_hash(c5: u8, c4: u8, c3: u8, c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c5, c4, c3, c2, c1], c0)
}

/// Hash of the order-6 context: `c6..c1` followed by the low byte of `c0`.
#[inline]
fn order6_hash(c6: u8, c5: u8, c4: u8, c3: u8, c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c6, c5, c4, c3, c2, c1], c0)
}

/// Hash of the order-7 context: `c7..c1` followed by the low byte of `c0`.
#[inline]
// One argument per history byte keeps call sites parallel across orders.
#[allow(clippy::too_many_arguments)]
fn order7_hash(c7: u8, c6: u8, c5: u8, c4: u8, c3: u8, c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c7, c6, c5, c4, c3, c2, c1], c0)
}

/// Hash of the order-8 context: `c8..c1` followed by the low byte of `c0`.
#[inline]
// One argument per history byte keeps call sites parallel across orders.
#[allow(clippy::too_many_arguments)]
fn order8_hash(c8: u8, c7: u8, c6: u8, c5: u8, c4: u8, c3: u8, c2: u8, c1: u8, c0: u32) -> u32 {
    fnv1a_fold([c8, c7, c6, c5, c4, c3, c2, c1], c0)
}

/// Hash of the order-9 context: `c9..c1` followed by the low byte of `c0`.
#[inline]
// One argument per history byte keeps call sites parallel across orders.
#[allow(clippy::too_many_arguments)]
fn order9_hash(
    c9: u8,
    c8: u8,
    c7: u8,
    c6: u8,
    c5: u8,
    c4: u8,
    c3: u8,
    c2: u8,
    c1: u8,
    c0: u32,
) -> u32 {
    fnv1a_fold([c9, c8, c7, c6, c5, c4, c3, c2, c1], c0)
}
/// Maps a run length to one of four buckets:
/// `0..=1 -> 0`, `2..=3 -> 1`, `4..=8 -> 2`, anything longer `-> 3`.
#[inline]
fn quantize_run_for_mixer(run_len: u8) -> u8 {
    if run_len <= 1 {
        0
    } else if run_len <= 3 {
        1
    } else if run_len <= 8 {
        2
    } else {
        3
    }
}
/// Maps a line position to one of four buckets:
/// `0..=3 -> 0`, `4..=15 -> 1`, `16..=63 -> 2`, anything larger `-> 3`.
#[inline]
fn quantize_line_pos(line_pos: u16) -> u8 {
    if line_pos <= 3 {
        0
    } else if line_pos <= 15 {
        1
    } else if line_pos <= 63 {
        2
    } else {
        3
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields `(bpos, bit)` for the eight bits of `byte`, most significant
    /// bit first — the order in which every test feeds the engine.
    fn bits_msb_first(byte: u8) -> impl Iterator<Item = (u32, u8)> {
        (0..8u32).map(move |bpos| (bpos, (byte >> (7 - bpos)) & 1))
    }

    /// A fresh engine should start close to the midpoint of the 12-bit range.
    #[test]
    fn initial_prediction_is_balanced() {
        let mut engine = CMEngine::new();
        let p = engine.predict();
        assert!(
            (1800..=2200).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Every prediction must stay inside `1..=4095` while coding real text.
    #[test]
    fn prediction_always_in_range() {
        let mut engine = CMEngine::new();
        let data = b"Hello, World! This is a test of the CM engine.";
        for &byte in data {
            for (bpos, bit) in bits_msb_first(byte) {
                let p = engine.predict();
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
                engine.update(bit);
            }
        }
    }

    /// After one full byte the history holds it and the bit state is reset.
    #[test]
    fn context_state_tracks_correctly() {
        let mut engine = CMEngine::new();
        let byte: u8 = 0x42;
        for (_, bit) in bits_msb_first(byte) {
            let _p = engine.predict();
            engine.update(bit);
        }
        assert_eq!(engine.c1, 0x42);
        assert_eq!(engine.c0, 1);
        assert_eq!(engine.bpos, 0);
    }

    /// Coding the same byte 50 times must cost less on average than the
    /// first occurrence did.
    #[test]
    fn repeated_byte_adapts() {
        let mut engine = CMEngine::new();
        let byte: u8 = b'A';
        let mut total_bits: f64 = 0.0;
        let mut first_byte_bits: f64 = 0.0;
        for iteration in 0..50 {
            let mut byte_bits: f64 = 0.0;
            for (_, bit) in bits_msb_first(byte) {
                let p = engine.predict();
                let p_one = p as f64 / 4096.0;
                let prob_of_bit = if bit == 1 { p_one } else { 1.0 - p_one };
                // Cost in bits; the floor keeps the logarithm finite.
                byte_bits -= prob_of_bit.max(0.001).log2();
                engine.update(bit);
            }
            if iteration == 0 {
                first_byte_bits = byte_bits;
            }
            total_bits += byte_bits;
        }
        let avg = total_bits / 50.0;
        assert!(
            avg < first_byte_bits,
            "engine should improve: first={first_byte_bits:.2}, avg={avg:.2}"
        );
    }

    /// Zero-padded longer contexts must not hash like shorter ones.
    #[test]
    fn hash_functions_differ() {
        let h1 = order1_hash(65, 1);
        let h2 = order2_hash(0, 65, 1);
        let h3 = order3_hash(0, 0, 65, 1);
        assert_ne!(h1, h2);
        assert_ne!(h2, h3);
    }

    /// Two engines fed the same bits must produce identical predictions.
    #[test]
    fn engine_deterministic() {
        let data = b"determinism test";
        let mut e1 = CMEngine::new();
        let mut e2 = CMEngine::new();
        for &byte in data.iter() {
            for (bpos, bit) in bits_msb_first(byte) {
                let p1 = e1.predict();
                let p2 = e2.predict();
                assert_eq!(p1, p2, "engines diverged at bpos {bpos}");
                e1.update(bit);
                e2.update(bit);
            }
        }
    }
}