1use crate::mixer::logistic::{squash, stretch};
21use crate::state::state_map::StateMap;
22use crate::state::state_table::StateTable;
23
/// Number of one-byte slots in each level's hash table (2 Mi entries, a power of two).
const HT_SIZE: usize = 2 * 1024 * 1024;
/// Bit mask that folds a 32-bit context hash into a valid `HT_SIZE` index.
const HT_MASK: usize = HT_SIZE - 1;

/// Number of distinct state values, i.e. every value a `u8` state byte can hold.
const NUM_STATES: usize = u8::MAX as usize + 1;

/// Multiplier used by the context hashes (`0x0100_0193`, the 32-bit FNV prime).
const FNV_PRIME: u32 = 16_777_619;
34
/// Two fixed-point weights kept per state by an ISSE level.
#[derive(Copy, Clone)]
struct WeightPair {
    /// Gain multiplied with the stretched input prediction.
    w0: i32,
    /// Bias term; it is multiplied by `BIAS_SCALE` when a prediction is formed.
    w1: i32,
}
41
/// Number of fractional bits in the fixed-point ISSE weights.
const W_SHIFT: i32 = 16;
/// Fixed-point value of 1.0; every `w0` weight starts here.
const W_UNITY: i32 = 1 << W_SHIFT;
/// Factor applied to the bias weight `w1` before the final `W_SHIFT` right shift.
const BIAS_SCALE: i64 = 1 << 6;
/// Symmetric saturation bound for weight updates (2^19 - 1).
const W_CLAMP: i64 = (1 << 19) - 1;
47
48struct IcmLevel {
50 ht: Vec<u8>,
51 smap: StateMap,
52 last_hash: u32,
53 last_state: u8,
54}
55
56impl IcmLevel {
57 fn new() -> Self {
58 IcmLevel {
59 ht: vec![0u8; HT_SIZE],
60 smap: StateMap::new(),
61 last_hash: 0,
62 last_state: 0,
63 }
64 }
65
66 #[inline]
67 fn predict(&mut self, ctx_hash: u32) -> u32 {
68 self.last_hash = ctx_hash;
69 let state = self.ht[ctx_hash as usize & HT_MASK];
70 self.last_state = state;
71 self.smap.predict(state)
72 }
73
74 #[inline]
75 fn update(&mut self, bit: u8) {
76 self.smap.update(self.last_state, bit);
77 let new_state = StateTable::next(self.last_state, bit);
78 self.ht[self.last_hash as usize & HT_MASK] = new_state;
79 }
80}
81
82struct IsseLevel {
84 ht: Vec<u8>,
85 weights: [WeightPair; NUM_STATES],
86 last_hash: u32,
87 last_state: u8,
88 last_d_in: i32,
89 last_p_out: i32,
90}
91
92impl IsseLevel {
93 fn new() -> Self {
94 let mut weights = [WeightPair { w0: W_UNITY, w1: 0 }; NUM_STATES];
95
96 for (s, wt) in weights.iter_mut().enumerate() {
98 let state_p = StateTable::prob(s as u8);
99 let state_d = stretch(state_p as u32);
100 wt.w1 = state_d * 256;
101 }
102
103 IsseLevel {
104 ht: vec![0u8; HT_SIZE],
105 weights,
106 last_hash: 0,
107 last_state: 0,
108 last_d_in: 0,
109 last_p_out: 2048,
110 }
111 }
112
113 #[inline]
114 fn predict(&mut self, p_in: u32, ctx_hash: u32) -> u32 {
115 self.last_hash = ctx_hash;
116 let state = self.ht[ctx_hash as usize & HT_MASK];
117 self.last_state = state;
118
119 let d_in = stretch(p_in);
120 self.last_d_in = d_in;
121
122 let wt = &self.weights[state as usize];
123 let d_out = (wt.w0 as i64 * d_in as i64 + wt.w1 as i64 * BIAS_SCALE) >> W_SHIFT;
124 let p_out = squash(d_out as i32).clamp(1, 4095) as i32;
125 self.last_p_out = p_out;
126 p_out as u32
127 }
128
129 #[inline]
130 fn update(&mut self, bit: u8) {
131 let err = (bit as i32) * 32767 - self.last_p_out * 8;
132 let wt = &mut self.weights[self.last_state as usize];
133
134 let delta_w0 = (err as i64 * self.last_d_in as i64 + (1i64 << 12)) >> 13;
135 wt.w0 = (wt.w0 as i64 + delta_w0).clamp(-W_CLAMP, W_CLAMP) as i32;
136
137 let delta_w1 = (err + 16) >> 5;
138 wt.w1 = (wt.w1 as i64 + delta_w1 as i64).clamp(-W_CLAMP, W_CLAMP) as i32;
139
140 let new_state = StateTable::next(self.last_state, bit);
141 self.ht[self.last_hash as usize & HT_MASK] = new_state;
142 }
143}
144
145pub struct IsseChain {
154 icm: IcmLevel,
155 isse1: IsseLevel,
156 isse2: IsseLevel,
157 word_pos: u8,
159}
160
161impl IsseChain {
162 pub fn new() -> Self {
163 IsseChain {
164 icm: IcmLevel::new(),
165 isse1: IsseLevel::new(),
166 isse2: IsseLevel::new(),
167 word_pos: 0,
168 }
169 }
170
171 #[inline]
173 #[allow(clippy::too_many_arguments)]
174 pub fn predict(&mut self, c0: u32, c1: u8, c2: u8, c3: u8, bpos: u8) -> u32 {
175 let h0 = word_pos_hash(self.word_pos, c0, bpos);
179 let p0 = self.icm.predict(h0);
180
181 let h1 = class_transition_hash(c1, c2, c0, bpos);
185 let p1 = self.isse1.predict(p0, h1);
186
187 let h2 = sparse_skip2_hash(c1, c3, c0, bpos);
191 let p2 = self.isse2.predict(p1, h2);
192
193 p2.clamp(1, 4095)
194 }
195
196 #[inline]
198 pub fn update(&mut self, bit: u8, c0: u32, bpos: u8) {
199 self.isse2.update(bit);
200 self.isse1.update(bit);
201 self.icm.update(bit);
202
203 if bpos == 7 {
205 let byte = ((c0 << 1 | bit as u32) & 0xFF) as u8;
206 if is_word_boundary(byte) {
207 self.word_pos = 0;
208 } else {
209 self.word_pos = self.word_pos.saturating_add(1);
210 }
211 }
212 }
213}
214
215impl Default for IsseChain {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
/// Returns `true` when `b` is one of the bytes that end a word: space, line
/// feed, carriage return, tab, or one of ``. , ; : ! ? ( ) [ ] { } < > " ' / =``.
#[inline]
fn is_word_boundary(b: u8) -> bool {
    // Every boundary byte, listed once: whitespace first, then punctuation.
    const BOUNDARY_BYTES: &[u8] = b" \n\r\t.,;:!?()[]{}<>\"'/=";
    BOUNDARY_BYTES.contains(&b)
}
249
/// Maps a byte to a coarse character class:
///
/// | class | bytes                                  |
/// |-------|----------------------------------------|
/// | 0     | control bytes `0..=31`                 |
/// | 1     | space                                  |
/// | 2     | ASCII digits                           |
/// | 3     | ASCII upper-case letters               |
/// | 4     | ASCII lower-case letters               |
/// | 5     | ASCII punctuation (`!`..=`/`, `:`..=`@`, `[`..=`` ` ``, `{`..=`~`) |
/// | 6     | bytes `128..=255`                      |
/// | 7     | anything left, which is only DEL (127) |
#[inline]
fn classify(b: u8) -> u8 {
    if b < b' ' {
        0
    } else if b == b' ' {
        1
    } else if b.is_ascii_digit() {
        2
    } else if b.is_ascii_uppercase() {
        3
    } else if b.is_ascii_lowercase() {
        4
    } else if b.is_ascii_punctuation() {
        5
    } else if !b.is_ascii() {
        6
    } else {
        7
    }
}
264
/// Starting value of the word-position hash.
const SEED_WP: u32 = 0xA5A5_A5A5;
/// Starting value of the class-transition hash.
const SEED_CT: u32 = 0x5A5A_5A5A;
/// Starting value of the skip hash over `c1` and `c3`.
const SEED_SK: u32 = 0x3C3C_3C3C;
270
271#[inline]
273fn word_pos_hash(word_pos: u8, c0: u32, bpos: u8) -> u32 {
274 let mut h = SEED_WP;
275 h ^= word_pos as u32;
276 h = h.wrapping_mul(FNV_PRIME);
277 h ^= c0 & 0x1FF;
278 h = h.wrapping_mul(FNV_PRIME);
279 h ^= bpos as u32;
280 h = h.wrapping_mul(FNV_PRIME);
281 h
282}
283
284#[inline]
286fn class_transition_hash(c1: u8, c2: u8, c0: u32, bpos: u8) -> u32 {
287 let mut h = SEED_CT;
288 h ^= classify(c1) as u32;
289 h = h.wrapping_mul(FNV_PRIME);
290 h ^= classify(c2) as u32;
291 h = h.wrapping_mul(FNV_PRIME);
292 h ^= c0 & 0x1FF;
293 h = h.wrapping_mul(FNV_PRIME);
294 h ^= bpos as u32;
295 h = h.wrapping_mul(FNV_PRIME);
296 h
297}
298
299#[inline]
301fn sparse_skip2_hash(c1: u8, c3: u8, c0: u32, bpos: u8) -> u32 {
302 let mut h = SEED_SK;
303 h ^= c1 as u32;
304 h = h.wrapping_mul(FNV_PRIME);
305 h ^= c3 as u32;
306 h = h.wrapping_mul(FNV_PRIME);
307 h ^= c0 & 0x1FF;
308 h = h.wrapping_mul(FNV_PRIME);
309 h ^= bpos as u32;
310 h = h.wrapping_mul(FNV_PRIME);
311 h
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains a fresh chain on 200 copies of `bit` in one fixed context and
    /// returns the prediction made right before the final update.
    fn last_prediction_after_constant(bit: u8) -> u32 {
        let mut chain = IsseChain::new();
        let mut last_p = 0u32;
        for _ in 0..200 {
            last_p = chain.predict(1, 0, 0, 0, 0);
            chain.update(bit, 1, 0);
        }
        last_p
    }

    /// A brand-new chain must already produce a valid probability.
    #[test]
    fn initial_prediction_in_range() {
        let p = IsseChain::new().predict(1, 0, 0, 0, 0);
        assert!(
            matches!(p, 1..=4095),
            "initial prediction out of range: {p}"
        );
    }

    /// Predictions stay valid at every bit position while training.
    #[test]
    fn prediction_always_in_range() {
        let mut chain = IsseChain::new();
        for bpos in 0u8..8 {
            let p = chain.predict(1, 65, 66, 67, bpos);
            assert!(matches!(p, 1..=4095), "out of range: {p}");
            chain.update(1, 1, bpos);
        }
    }

    /// A stream of ones pulls the prediction above the midpoint.
    #[test]
    fn adapts_to_ones() {
        let last_p = last_prediction_after_constant(1);
        assert!(last_p > 2200, "should adapt toward 1: got {last_p}");
    }

    /// A stream of zeros pulls the prediction below the midpoint.
    #[test]
    fn adapts_to_zeros() {
        let last_p = last_prediction_after_constant(0);
        assert!(last_p < 1800, "should adapt toward 0: got {last_p}");
    }

    /// Two contexts trained on opposite bits end up with ordered predictions.
    #[test]
    fn different_contexts_diverge() {
        let mut chain = IsseChain::new();
        // Context byte 65 only ever sees ones, context byte 66 only zeros.
        for (c1, bit) in [(65u8, 1u8), (66u8, 0u8)] {
            for _ in 0..100 {
                chain.predict(1, c1, 0, 0, 0);
                chain.update(bit, 1, 0);
            }
        }
        let p_a = chain.predict(1, 65, 0, 0, 0);
        let p_b = chain.predict(1, 66, 0, 0, 0);
        assert!(
            p_a > p_b,
            "trained contexts should diverge: p_a={p_a}, p_b={p_b}"
        );
    }

    /// Two chains fed the same input must agree on every prediction.
    #[test]
    fn deterministic() {
        let mut ch1 = IsseChain::new();
        let mut ch2 = IsseChain::new();
        for &byte in b"ISSE determinism" {
            for bpos in 0u8..8 {
                let bit = (byte >> (7 - bpos)) & 1;
                // Leading 1 marker followed by the `bpos` bits already coded.
                let c0 = (1u32 << bpos) | (u32::from(byte) >> (8 - bpos));
                let p1 = ch1.predict(c0, byte, 0, 0, bpos);
                let p2 = ch2.predict(c0, byte, 0, 0, bpos);
                assert_eq!(p1, p2, "chains diverged at bpos {bpos}");
                ch1.update(bit, c0, bpos);
                ch2.update(bit, c0, bpos);
            }
        }
    }

    /// Spot checks of the word-boundary classifier.
    #[test]
    fn word_boundary_detection() {
        for boundary in [b' ', b'\n', b'.'] {
            assert!(is_word_boundary(boundary));
        }
        for inside_word in [b'a', b'5'] {
            assert!(!is_word_boundary(inside_word));
        }
    }
}