1use crate::mixer::logistic::{squash, stretch};
17
/// Number of model predictions combined by the mixer on every call.
pub const NUM_MODELS: usize = 28;

/// Number of weight sets in the fine table (indexed by `fine_context`).
const FINE_SETS: usize = 1 << 16;

/// Number of weight sets in the medium table (indexed by `medium_context`).
const MEDIUM_SETS: usize = 1 << 14;

/// Number of weight sets in the coarse table (indexed by `coarse_context`).
const COARSE_SETS: usize = 1 << 12;

// The context functions reduce their hash with `& (SETS - 1)`. That mask only
// addresses the whole table (and nothing outside it) when the table size is a
// power of two, so enforce the invariant at compile time.
const _: () = assert!(
    FINE_SETS.is_power_of_two()
        && MEDIUM_SETS.is_power_of_two()
        && COARSE_SETS.is_power_of_two(),
    "weight table sizes must be powers of two"
);

/// Fixed-point denominator for weights: `predict` divides every weighted sum
/// by this value, so a weight equal to `W_SCALE` passes its input through
/// unchanged.
const W_SCALE: i32 = 1 << 12;

/// Weights every set starts with, one entry per model in the same order as
/// the `predictions` array handed to `predict`, in units of `1 / W_SCALE`.
const INITIAL_WEIGHTS: [i32; NUM_MODELS] = [
    200, 300, 60, 350, 60, 450, 60, //
    450, 60, 450, 60, 300, 60, 250, //
    60, 200, 60, 180, 60, 300, 250, //
    250, 200, 250, 200, 50, 30, 150,
];

/// Learning-rate multiplier applied to fine-table updates (before the
/// `>> 16` scaling in `update`).
const FINE_LR: i32 = 2;

/// Learning-rate multiplier applied to medium-table updates.
const MEDIUM_LR: i32 = 3;

/// Learning-rate multiplier applied to coarse-table updates.
const COARSE_LR: i32 = 4;
69
/// Mixer that combines `NUM_MODELS` model predictions through three
/// context-selected weight tables (fine, medium and coarse) and blends the
/// three results into one probability.
///
/// `predict` caches the stretched inputs, the selected weight-set indices and
/// the returned probability so that the following `update` call trains exactly
/// the weight sets that produced the prediction.
pub struct DualMixer {
    /// Weight sets selected by `fine_context`; `FINE_SETS` entries.
    fine_weights: Vec<[i32; NUM_MODELS]>,
    /// Weight sets selected by `medium_context`; `MEDIUM_SETS` entries.
    medium_weights: Vec<[i32; NUM_MODELS]>,
    /// Weight sets selected by `coarse_context`; `COARSE_SETS` entries.
    coarse_weights: Vec<[i32; NUM_MODELS]>,
    /// Stretched model predictions stored by the most recent `predict` call.
    last_d: [i32; NUM_MODELS],
    /// Index into `fine_weights` chosen by the most recent `predict` call.
    last_fine_ctx: usize,
    /// Index into `medium_weights` chosen by the most recent `predict` call.
    last_medium_ctx: usize,
    /// Index into `coarse_weights` chosen by the most recent `predict` call.
    last_coarse_ctx: usize,
    /// Probability returned by the most recent `predict` call (2048 initially).
    last_p: u32,
}
89
90impl DualMixer {
91 pub fn new() -> Self {
92 DualMixer {
93 fine_weights: vec![INITIAL_WEIGHTS; FINE_SETS],
94 medium_weights: vec![INITIAL_WEIGHTS; MEDIUM_SETS],
95 coarse_weights: vec![INITIAL_WEIGHTS; COARSE_SETS],
96 last_d: [0; NUM_MODELS],
97 last_fine_ctx: 0,
98 last_medium_ctx: 0,
99 last_coarse_ctx: 0,
100 last_p: 2048,
101 }
102 }
103
104 #[inline(always)]
106 #[allow(clippy::needless_range_loop)]
107 #[allow(clippy::too_many_arguments)]
108 pub fn predict(
109 &mut self,
110 predictions: &[u32; NUM_MODELS],
111 c0: u32,
112 c1: u8,
113 bpos: u8,
114 byte_class: u8,
115 match_len_q: u8,
116 run_q: u8,
117 _xml_state: u8,
118 ) -> u32 {
119 for i in 0..NUM_MODELS {
121 self.last_d[i] = stretch(predictions[i]);
122 }
123
124 self.last_fine_ctx = fine_context(c0, c1, bpos, byte_class, match_len_q, run_q);
126 self.last_medium_ctx = medium_context(c0, c1, bpos, run_q, match_len_q);
128 self.last_coarse_ctx = coarse_context(c0, bpos);
130
131 let fw = &self.fine_weights[self.last_fine_ctx];
133 let mw = &self.medium_weights[self.last_medium_ctx];
134 let cw = &self.coarse_weights[self.last_coarse_ctx];
135
136 let mut fine_sum: i64 = 0;
137 let mut medium_sum: i64 = 0;
138 let mut coarse_sum: i64 = 0;
139 for i in 0..NUM_MODELS {
140 let d = self.last_d[i] as i64;
141 fine_sum += fw[i] as i64 * d;
142 medium_sum += mw[i] as i64 * d;
143 coarse_sum += cw[i] as i64 * d;
144 }
145 let fine_d = (fine_sum / W_SCALE as i64) as i32;
146 let medium_d = (medium_sum / W_SCALE as i64) as i32;
147 let coarse_d = (coarse_sum / W_SCALE as i64) as i32;
148
149 let blended_d = (fine_d as i64 * 5 + medium_d as i64 * 3 + coarse_d as i64 * 2) / 10;
152 let p = squash(blended_d as i32).clamp(1, 4095);
153 self.last_p = p;
154 p
155 }
156
157 #[inline(always)]
159 #[allow(clippy::needless_range_loop)]
160 pub fn update(&mut self, bit: u8) {
161 let error = (bit as i32) * 4096 - self.last_p as i32;
162
163 let fw = &mut self.fine_weights[self.last_fine_ctx];
165 for i in 0..NUM_MODELS {
166 let delta = (FINE_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
167 fw[i] = (fw[i] as i64 + delta).clamp(-32768, 32767) as i32;
168 }
169
170 let mw = &mut self.medium_weights[self.last_medium_ctx];
172 for i in 0..NUM_MODELS {
173 let delta = (MEDIUM_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
174 mw[i] = (mw[i] as i64 + delta).clamp(-32768, 32767) as i32;
175 }
176
177 let cw = &mut self.coarse_weights[self.last_coarse_ctx];
179 for i in 0..NUM_MODELS {
180 let delta = (COARSE_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
181 cw[i] = (cw[i] as i64 + delta).clamp(-32768, 32767) as i32;
182 }
183 }
184}
185
186impl Default for DualMixer {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
/// Maps a byte to one of twelve coarse classes (`0..=11`).
///
/// | class | bytes                                   |
/// |-------|-----------------------------------------|
/// | 0     | `0x00..=0x1F`                           |
/// | 1     | space                                   |
/// | 2     | ASCII digits                            |
/// | 3     | ASCII upper-case letters                |
/// | 4     | ASCII lower-case letters                |
/// | 5     | ASCII punctuation (`!`..`/`, `:`..`@`, `[`..`` ` ``, `{`..`~`) |
/// | 6     | `0x80..=0x9F`                           |
/// | 7     | `0xA0..=0xBF`                           |
/// | 8     | `0xC0..=0xDF`                           |
/// | 9     | `0xE0..=0xFE`                           |
/// | 10    | `0xFF`                                  |
/// | 11    | `0x7F` (the only byte not listed above) |
#[inline]
pub fn byte_class(b: u8) -> u8 {
    // The arms cover every `u8` value, so no wildcard arm is needed.
    match b {
        0x00..=0x1F => 0,
        b' ' => 1,
        b'0'..=b'9' => 2,
        b'A'..=b'Z' => 3,
        b'a'..=b'z' => 4,
        b'!'..=b'/' | b':'..=b'@' | b'['..=b'`' | b'{'..=b'~' => 5,
        0x7F => 11,
        0x80..=0x9F => 6,
        0xA0..=0xBF => 7,
        0xC0..=0xDF => 8,
        0xE0..=0xFE => 9,
        0xFF => 10,
    }
}
219
220#[inline]
223fn fine_context(c0: u32, c1: u8, bpos: u8, bclass: u8, match_q: u8, run_q: u8) -> usize {
224 let mut h: usize = c0 as usize & 0xFF;
227 h = h.wrapping_mul(97) + (c1 as usize >> 4);
228 h = h.wrapping_mul(97) + bpos as usize;
229 h = h.wrapping_mul(97) + (bclass as usize & 0x7);
230 h = h.wrapping_mul(97) + (match_q as usize & 0x3);
231 h = h.wrapping_mul(97) + (run_q as usize & 0x3);
232 h & (FINE_SETS - 1)
233}
234
235#[inline]
238fn medium_context(c0: u32, c1: u8, bpos: u8, run_q: u8, match_q: u8) -> usize {
239 let bclass = byte_class(c1);
241 let mut h: usize = c0 as usize & 0xFF;
242 h = h.wrapping_mul(67) + (c1 as usize >> 4);
243 h = h.wrapping_mul(67) + bpos as usize;
244 h = h.wrapping_mul(67) + bclass as usize;
245 h = h.wrapping_mul(67) + (run_q as usize & 0x3);
246 h = h.wrapping_mul(67) + (match_q as usize & 0x3);
247 h & (MEDIUM_SETS - 1)
248}
249
250#[inline]
252fn coarse_context(c0: u32, bpos: u8) -> usize {
253 ((c0 as usize & 0xFF) | ((bpos as usize) << 8)) & (COARSE_SETS - 1)
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Predictions in which every model reports 2048 except model 0, which
    /// reports 3500.
    fn biased_predictions() -> [u32; NUM_MODELS] {
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 3500;
        preds
    }

    #[test]
    fn initial_prediction_near_balanced() {
        let mut mixer = DualMixer::new();
        let preds = [2048u32; NUM_MODELS];
        let p = mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        assert!(
            (1900..=2100).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    #[test]
    fn prediction_in_range() {
        let mut mixer = DualMixer::new();
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 100;
        preds[1] = 4000;
        preds[4] = 3000;
        preds[7] = 500;
        let p = mixer.predict(&preds, 128, b'a', 3, 4, 1, 0, 0);
        assert!((1..=4095).contains(&p), "prediction out of range: {p}");
    }

    #[test]
    fn update_changes_weights() {
        let mut mixer = DualMixer::new();
        // Use the same biased input as `mixer_adapts_to_biased_input`, which
        // already relies on updates moving the weights for this input, so the
        // update has a non-zero signal to learn from.
        let preds = biased_predictions();
        mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        let before = mixer.fine_weights[mixer.last_fine_ctx];
        mixer.update(1);
        let after = mixer.fine_weights[mixer.last_fine_ctx];
        assert_ne!(before, after, "update should have moved the fine weights");
    }

    #[test]
    fn mixer_adapts_to_biased_input() {
        let mut mixer = DualMixer::new();
        let preds = biased_predictions();
        for _ in 0..100 {
            mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
            mixer.update(1);
        }
        let p = mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        assert!(p > 2500, "mixer should have learned to trust model 0: {p}");
    }

    #[test]
    fn byte_class_categories() {
        assert_eq!(byte_class(0), 0);
        assert_eq!(byte_class(b' '), 1);
        assert_eq!(byte_class(b'5'), 2);
        assert_eq!(byte_class(b'A'), 3);
        assert_eq!(byte_class(b'z'), 4);
        assert_eq!(byte_class(b'.'), 5);
        assert_eq!(byte_class(b'~'), 5);
        assert_eq!(byte_class(0x7F), 11);
        assert_eq!(byte_class(0x80), 6);
        assert_eq!(byte_class(0x90), 6);
        assert_eq!(byte_class(0xA0), 7);
        assert_eq!(byte_class(0xC0), 8);
        assert_eq!(byte_class(0xE0), 9);
        assert_eq!(byte_class(0xFF), 10);
    }

    #[test]
    fn fine_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = fine_context(c0, 0xFF, bpos, 7, 3, 3);
                assert!(ctx < FINE_SETS, "fine context out of range: {ctx}");
            }
        }
    }

    #[test]
    fn medium_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = medium_context(c0, 0xFF, bpos, 3, 3);
                assert!(ctx < MEDIUM_SETS, "medium context out of range: {ctx}");
            }
        }
    }

    #[test]
    fn coarse_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = coarse_context(c0, bpos);
                assert!(ctx < COARSE_SETS, "coarse context out of range: {ctx}");
            }
        }
    }
}