Skip to main content

datacortex_core/mixer/
isse.rs

//! ISSE — Indirect Secondary Symbol Estimation model.
//!
//! Provides model #19 for the mixer: a 3-level ISSE chain using
//! cross-context hashes that combine dimensions the main n-gram
//! models don't capture.
//!
//! The main order-0..9 models use pure n-gram context (c1, c2, ..., c9).
//! This ISSE model instead uses (every level also hashes the partial byte c0 and bpos):
//!   Level 0 (ICM): word-position context (bytes since the last word boundary)
//!   Level 1 (ISSE): byte-class bigram transition context (class(c1), class(c2))
//!   Level 2 (ISSE): sparse skip-2 context (c1, c3 — skipping c2)
//!
//! These contexts are orthogonal to the main models, giving the mixer
//! genuinely new information rather than duplicating existing signals.
//!
//! Architecture from ZPAQ (Matt Mahoney):
//!   ICM: context_hash → bit_history_state → probability (StateMap)
//!   ISSE: (bit_history_state, p_in) → p_out via learned weights (w0, w1)
19
20use crate::mixer::logistic::{squash, stretch};
21use crate::state::state_map::StateMap;
22use crate::state::state_table::StateTable;
23
/// Log2 of the number of slots in each level's hash table.
const HT_BITS: u32 = 21;

/// Slots per level: 2^21 = 2M one-byte bit-history states = 2MB.
/// Across the three levels that is 6MB.
const HT_SIZE: usize = 1 << HT_BITS;

/// Reduces a context hash to a slot index in `0..HT_SIZE`.
const HT_MASK: usize = HT_SIZE - 1;

/// Number of bit history states — one per possible `u8` state value.
const NUM_STATES: usize = u8::MAX as usize + 1;

/// 32-bit FNV prime (16777619), used as the multiplier in the context hashes.
const FNV_PRIME: u32 = 0x0100_0193;
34
/// Per-state ISSE weights.
///
/// An ISSE level computes `d_out = (w0 * d_in + w1 * BIAS_SCALE) >> W_SHIFT`,
/// where `d_in` is the stretched input probability: `w0` scales the input and
/// `w1` acts as a bias.
#[derive(Clone, Copy)]
struct WeightPair {
    w0: i32,
    w1: i32,
}

// ZPAQ-style fixed-point scaling constants.

/// Fractional bits in the weight fixed-point format (right shift applied to the sum).
const W_SHIFT: i32 = 16;
/// `w0` value meaning 1.0, i.e. pass the input through unchanged.
const W_UNITY: i32 = 1 << W_SHIFT;
/// Multiplier applied to `w1` before the final shift.
const BIAS_SCALE: i64 = 64;
/// Symmetric clamp applied to both weights after each update: 2^19 - 1.
const W_CLAMP: i64 = (1 << 19) - 1;
47
48/// ICM level — base predictor.
49struct IcmLevel {
50    ht: Vec<u8>,
51    smap: StateMap,
52    last_hash: u32,
53    last_state: u8,
54}
55
56impl IcmLevel {
57    fn new() -> Self {
58        IcmLevel {
59            ht: vec![0u8; HT_SIZE],
60            smap: StateMap::new(),
61            last_hash: 0,
62            last_state: 0,
63        }
64    }
65
66    #[inline]
67    fn predict(&mut self, ctx_hash: u32) -> u32 {
68        self.last_hash = ctx_hash;
69        let state = self.ht[ctx_hash as usize & HT_MASK];
70        self.last_state = state;
71        self.smap.predict(state)
72    }
73
74    #[inline]
75    fn update(&mut self, bit: u8) {
76        self.smap.update(self.last_state, bit);
77        let new_state = StateTable::next(self.last_state, bit);
78        self.ht[self.last_hash as usize & HT_MASK] = new_state;
79    }
80}
81
82/// ISSE level — refines input probability.
83struct IsseLevel {
84    ht: Vec<u8>,
85    weights: [WeightPair; NUM_STATES],
86    last_hash: u32,
87    last_state: u8,
88    last_d_in: i32,
89    last_p_out: i32,
90}
91
92impl IsseLevel {
93    fn new() -> Self {
94        let mut weights = [WeightPair { w0: W_UNITY, w1: 0 }; NUM_STATES];
95
96        // Initialize w1 bias from state table (ZPAQ style).
97        for (s, wt) in weights.iter_mut().enumerate() {
98            let state_p = StateTable::prob(s as u8);
99            let state_d = stretch(state_p as u32);
100            wt.w1 = state_d * 256;
101        }
102
103        IsseLevel {
104            ht: vec![0u8; HT_SIZE],
105            weights,
106            last_hash: 0,
107            last_state: 0,
108            last_d_in: 0,
109            last_p_out: 2048,
110        }
111    }
112
113    #[inline]
114    fn predict(&mut self, p_in: u32, ctx_hash: u32) -> u32 {
115        self.last_hash = ctx_hash;
116        let state = self.ht[ctx_hash as usize & HT_MASK];
117        self.last_state = state;
118
119        let d_in = stretch(p_in);
120        self.last_d_in = d_in;
121
122        let wt = &self.weights[state as usize];
123        let d_out = (wt.w0 as i64 * d_in as i64 + wt.w1 as i64 * BIAS_SCALE) >> W_SHIFT;
124        let p_out = squash(d_out as i32).clamp(1, 4095) as i32;
125        self.last_p_out = p_out;
126        p_out as u32
127    }
128
129    #[inline]
130    fn update(&mut self, bit: u8) {
131        let err = (bit as i32) * 32767 - self.last_p_out * 8;
132        let wt = &mut self.weights[self.last_state as usize];
133
134        let delta_w0 = (err as i64 * self.last_d_in as i64 + (1i64 << 12)) >> 13;
135        wt.w0 = (wt.w0 as i64 + delta_w0).clamp(-W_CLAMP, W_CLAMP) as i32;
136
137        let delta_w1 = (err + 16) >> 5;
138        wt.w1 = (wt.w1 as i64 + delta_w1 as i64).clamp(-W_CLAMP, W_CLAMP) as i32;
139
140        let new_state = StateTable::next(self.last_state, bit);
141        self.ht[self.last_hash as usize & HT_MASK] = new_state;
142    }
143}
144
145/// ISSE model: 3-level chain with cross-context hashes.
146///
147/// Uses contexts orthogonal to the main n-gram models:
148/// - Word-position context (captures intra-word patterns)
149/// - Byte-class transition context (captures character class patterns)
150/// - Sparse skip-2 context (captures periodic patterns)
151///
152/// Memory: 3 * 2MB = 6MB.
153pub struct IsseChain {
154    icm: IcmLevel,
155    isse1: IsseLevel,
156    isse2: IsseLevel,
157    /// Word position: distance since last space/newline/punctuation (0-255).
158    word_pos: u8,
159}
160
161impl IsseChain {
162    pub fn new() -> Self {
163        IsseChain {
164            icm: IcmLevel::new(),
165            isse1: IsseLevel::new(),
166            isse2: IsseLevel::new(),
167            word_pos: 0,
168        }
169    }
170
171    /// Produce a prediction for the mixer.
172    #[inline]
173    #[allow(clippy::too_many_arguments)]
174    pub fn predict(&mut self, c0: u32, c1: u8, c2: u8, c3: u8, bpos: u8) -> u32 {
175        // Level 0 (ICM): word-position context.
176        // Context = (word_pos, c0_partial, bpos).
177        // This captures patterns like "3rd character in a word is usually lowercase".
178        let h0 = word_pos_hash(self.word_pos, c0, bpos);
179        let p0 = self.icm.predict(h0);
180
181        // Level 1 (ISSE): byte-class bigram transition.
182        // Context = (class(c1), class(c2), c0_partial, bpos).
183        // Captures character class transitions (letter->digit, punct->letter, etc.)
184        let h1 = class_transition_hash(c1, c2, c0, bpos);
185        let p1 = self.isse1.predict(p0, h1);
186
187        // Level 2 (ISSE): sparse skip-2 context.
188        // Context = (c1, c3, c0_partial, bpos) — skips c2.
189        // Captures periodic/skip patterns the sequential models miss.
190        let h2 = sparse_skip2_hash(c1, c3, c0, bpos);
191        let p2 = self.isse2.predict(p1, h2);
192
193        p2.clamp(1, 4095)
194    }
195
196    /// Update after observing bit.
197    #[inline]
198    pub fn update(&mut self, bit: u8, c0: u32, bpos: u8) {
199        self.isse2.update(bit);
200        self.isse1.update(bit);
201        self.icm.update(bit);
202
203        // Track word position (update after byte boundary).
204        if bpos == 7 {
205            let byte = ((c0 << 1 | bit as u32) & 0xFF) as u8;
206            if is_word_boundary(byte) {
207                self.word_pos = 0;
208            } else {
209                self.word_pos = self.word_pos.saturating_add(1);
210            }
211        }
212    }
213}
214
215impl Default for IsseChain {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
/// Check if byte is a word boundary.
///
/// Boundaries are the whitespace bytes `' '`, `'\n'`, `'\r'`, `'\t'` plus the
/// punctuation/bracket set below. Every other byte — letters, digits, `-`, `_`,
/// and all non-ASCII bytes — continues the current word.
#[inline]
fn is_word_boundary(b: u8) -> bool {
    const BOUNDARIES: &[u8] = b" \n\r\t.,;:!?()[]{}<>\"'/=";
    BOUNDARIES.contains(&b)
}
249
/// Byte classifier (0-7).
///
/// | class | bytes                                   |
/// |-------|-----------------------------------------|
/// | 0     | control bytes `0x00..=0x1F`             |
/// | 1     | space                                   |
/// | 2     | ASCII digits                            |
/// | 3     | ASCII uppercase                         |
/// | 4     | ASCII lowercase                         |
/// | 5     | printable ASCII punctuation and symbols |
/// | 6     | non-ASCII `0x80..=0xFF`                 |
/// | 7     | DEL `0x7F`                              |
#[inline]
fn classify(b: u8) -> u8 {
    match b {
        0x00..=0x1F => 0,
        b' ' => 1,
        b'0'..=b'9' => 2,
        b'A'..=b'Z' => 3,
        b'a'..=b'z' => 4,
        0x7F => 7,
        0x80..=0xFF => 6,
        // Everything left is printable ASCII punctuation/symbols.
        _ => 5,
    }
}
264
265// --- Hash functions with unique seeds ---
266
267const SEED_WP: u32 = 0xA5A5A5A5;
268const SEED_CT: u32 = 0x5A5A5A5A;
269const SEED_SK: u32 = 0x3C3C3C3C;
270
271/// Word-position context hash.
272#[inline]
273fn word_pos_hash(word_pos: u8, c0: u32, bpos: u8) -> u32 {
274    let mut h = SEED_WP;
275    h ^= word_pos as u32;
276    h = h.wrapping_mul(FNV_PRIME);
277    h ^= c0 & 0x1FF;
278    h = h.wrapping_mul(FNV_PRIME);
279    h ^= bpos as u32;
280    h = h.wrapping_mul(FNV_PRIME);
281    h
282}
283
284/// Class-transition context hash.
285#[inline]
286fn class_transition_hash(c1: u8, c2: u8, c0: u32, bpos: u8) -> u32 {
287    let mut h = SEED_CT;
288    h ^= classify(c1) as u32;
289    h = h.wrapping_mul(FNV_PRIME);
290    h ^= classify(c2) as u32;
291    h = h.wrapping_mul(FNV_PRIME);
292    h ^= c0 & 0x1FF;
293    h = h.wrapping_mul(FNV_PRIME);
294    h ^= bpos as u32;
295    h = h.wrapping_mul(FNV_PRIME);
296    h
297}
298
299/// Sparse skip-2 context hash (c1, c3 — skips c2).
300#[inline]
301fn sparse_skip2_hash(c1: u8, c3: u8, c0: u32, bpos: u8) -> u32 {
302    let mut h = SEED_SK;
303    h ^= c1 as u32;
304    h = h.wrapping_mul(FNV_PRIME);
305    h ^= c3 as u32;
306    h = h.wrapping_mul(FNV_PRIME);
307    h ^= c0 & 0x1FF;
308    h = h.wrapping_mul(FNV_PRIME);
309    h ^= bpos as u32;
310    h = h.wrapping_mul(FNV_PRIME);
311    h
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Code one whole byte MSB-first, building `c0` with a leading 1 bit
    /// exactly as the `deterministic` test does.
    fn feed_byte(chain: &mut IsseChain, byte: u8) {
        let mut c0 = 1u32;
        for bpos in 0..8u8 {
            let bit = (byte >> (7 - bpos)) & 1;
            chain.predict(c0, 0, 0, 0, bpos);
            chain.update(bit, c0, bpos);
            c0 = (c0 << 1) | u32::from(bit);
        }
    }

    #[test]
    fn initial_prediction_in_range() {
        let mut chain = IsseChain::new();
        let p = chain.predict(1, 0, 0, 0, 0);
        assert!(
            (1..=4095).contains(&p),
            "initial prediction out of range: {p}"
        );
    }

    #[test]
    fn prediction_always_in_range() {
        let mut chain = IsseChain::new();
        for bpos in 0..8u8 {
            let p = chain.predict(1, 65, 66, 67, bpos);
            assert!((1..=4095).contains(&p), "out of range: {p}");
            chain.update(1, 1, bpos);
        }
    }

    #[test]
    fn adapts_to_ones() {
        let mut chain = IsseChain::new();
        let mut last_p = 0u32;
        for i in 0..200 {
            let p = chain.predict(1, 0, 0, 0, 0);
            if i > 100 {
                last_p = p;
            }
            chain.update(1, 1, 0);
        }
        assert!(last_p > 2200, "should adapt toward 1: got {last_p}");
    }

    #[test]
    fn adapts_to_zeros() {
        let mut chain = IsseChain::new();
        let mut last_p = 0u32;
        for i in 0..200 {
            let p = chain.predict(1, 0, 0, 0, 0);
            if i > 100 {
                last_p = p;
            }
            chain.update(0, 1, 0);
        }
        assert!(last_p < 1800, "should adapt toward 0: got {last_p}");
    }

    #[test]
    fn different_contexts_diverge() {
        let mut chain = IsseChain::new();
        for _ in 0..100 {
            chain.predict(1, 65, 0, 0, 0);
            chain.update(1, 1, 0);
        }
        for _ in 0..100 {
            chain.predict(1, 66, 0, 0, 0);
            chain.update(0, 1, 0);
        }
        let p_a = chain.predict(1, 65, 0, 0, 0);
        let p_b = chain.predict(1, 66, 0, 0, 0);
        assert!(
            p_a > p_b,
            "trained contexts should diverge: p_a={p_a}, p_b={p_b}"
        );
    }

    #[test]
    fn deterministic() {
        let mut ch1 = IsseChain::new();
        let mut ch2 = IsseChain::new();
        let data = b"ISSE determinism";
        for &byte in data {
            for bpos in 0..8u8 {
                let bit = (byte >> (7 - bpos)) & 1;
                let c0 = if bpos == 0 {
                    1u32
                } else {
                    let mut p = 1u32;
                    for prev in 0..bpos {
                        p = (p << 1) | ((byte >> (7 - prev)) & 1) as u32;
                    }
                    p
                };
                let p1 = ch1.predict(c0, byte, 0, 0, bpos);
                let p2 = ch2.predict(c0, byte, 0, 0, bpos);
                assert_eq!(p1, p2, "chains diverged at bpos {bpos}");
                ch1.update(bit, c0, bpos);
                ch2.update(bit, c0, bpos);
            }
        }
    }

    #[test]
    fn word_boundary_detection() {
        assert!(is_word_boundary(b' '));
        assert!(is_word_boundary(b'\n'));
        assert!(is_word_boundary(b'.'));
        assert!(!is_word_boundary(b'a'));
        assert!(!is_word_boundary(b'5'));
    }

    /// The completed byte is rebuilt from `c0` at `bpos == 7`; boundary bytes
    /// reset the position and any other byte advances it.
    #[test]
    fn word_pos_tracks_boundaries() {
        let mut chain = IsseChain::new();
        feed_byte(&mut chain, b'a');
        assert_eq!(chain.word_pos, 1);
        feed_byte(&mut chain, b'b');
        assert_eq!(chain.word_pos, 2);
        feed_byte(&mut chain, b' ');
        assert_eq!(chain.word_pos, 0);
        feed_byte(&mut chain, b'c');
        assert_eq!(chain.word_pos, 1);
        feed_byte(&mut chain, b'.');
        assert_eq!(chain.word_pos, 0);
    }

    #[test]
    fn word_pos_saturates() {
        let mut chain = IsseChain::new();
        for _ in 0..300 {
            feed_byte(&mut chain, b'x');
        }
        assert_eq!(chain.word_pos, 255);
    }

    #[test]
    fn classify_buckets() {
        assert_eq!(classify(0x00), 0);
        assert_eq!(classify(0x1F), 0);
        assert_eq!(classify(b' '), 1);
        assert_eq!(classify(b'7'), 2);
        assert_eq!(classify(b'Q'), 3);
        assert_eq!(classify(b'q'), 4);
        assert_eq!(classify(b'@'), 5);
        assert_eq!(classify(b'~'), 5);
        assert_eq!(classify(0x7F), 7);
        assert_eq!(classify(0x80), 6);
        assert_eq!(classify(0xFF), 6);
    }
}