// datacortex_core/mixer/dual_mixer.rs

//! TripleMixer — three-level logistic mixer (fine + medium + coarse) in log-odds space.
//!
//! Phase 3+: PAQ8-style logistic mixer with multi-output ContextMap support.
//!
//! 28 model inputs (see `NUM_MODELS`):
//! - Order-0: 1 (state only)
//! - Order-1 through Order-9: 2 each (state + run) = 18
//! - Match, Word, Sparse, Run, JSON, Indirect, PPM, DMC, ISSE: 1 each = 9
//!
//! Fine mixer: 64K weight sets, learning rate eta=2.
//! Medium mixer: 16K weight sets, learning rate eta=3.
//! Coarse mixer: 4K weight sets, learning rate eta=4.
//!
//! Blend: fine * 0.5 + medium * 0.3 + coarse * 0.2 in log-odds space.

use crate::mixer::logistic::{squash, stretch};

/// Number of model predictions fed into the mixer.
///
/// Input layout:
/// `[O0, O1_s, O1_r, O2_s, O2_r, ..., O9_s, O9_r,
///   match, word, sparse, run, json, indirect, ppm, dmc, isse]`
/// (`_s` = state prediction, `_r` = run prediction), i.e. 1 + 9 * 2 + 9 = 28.
pub const NUM_MODELS: usize = 28;

/// Fine mixer: 64K weight sets.
/// Must be a power of two: `fine_context` masks its hash with `FINE_SETS - 1`.
const FINE_SETS: usize = 65536;

/// Medium mixer: 16K weight sets.
/// Must be a power of two: `medium_context` masks its hash with `MEDIUM_SETS - 1`.
const MEDIUM_SETS: usize = 16384;

/// Coarse mixer: 4K weight sets.
/// Must be a power of two: `coarse_context` masks its index with `COARSE_SETS - 1`.
const COARSE_SETS: usize = 4096;

/// Weight scale factor (2^12 = 4096).
/// Weighted sums of stretched predictions are divided by this value.
const W_SCALE: i32 = 4096;

/// Initial weights per model, copied into every weight set of every mixer.
/// Order models: state predictions get higher weight, run starts low.
/// Layout: O0, O1(s,r), O2(s,r), ..., O9(s,r), match, word, sparse, run, json, indirect, ppm, dmc, isse
const INITIAL_WEIGHTS: [i32; NUM_MODELS] = [
    200, // O0
    300, 60, // O1 (state, run)
    350, 60, // O2
    450, 60, // O3
    450, 60, // O4
    450, 60, // O5
    300, 60, // O6
    250, 60, // O7
    200, 60, // O8
    180, 60, // O9
    300, // match
    250, // word
    250, // sparse
    200, // run
    250, // json
    200, // indirect
    50,  // ppm
    30,  // dmc
    150, // isse
];

/// Fine mixer learning rate (multiplier on the weight delta in `update`).
const FINE_LR: i32 = 2;

/// Medium mixer learning rate (multiplier on the weight delta in `update`).
const MEDIUM_LR: i32 = 3;

/// Coarse mixer learning rate (multiplier on the weight delta in `update`).
const COARSE_LR: i32 = 4;

70/// Triple logistic mixer (fine + medium + coarse).
71pub struct DualMixer {
72    /// Fine mixer weights: [FINE_SETS][NUM_MODELS].
73    fine_weights: Vec<[i32; NUM_MODELS]>,
74    /// Medium mixer weights: [MEDIUM_SETS][NUM_MODELS].
75    medium_weights: Vec<[i32; NUM_MODELS]>,
76    /// Coarse mixer weights: [COARSE_SETS][NUM_MODELS].
77    coarse_weights: Vec<[i32; NUM_MODELS]>,
78    /// Cached stretched predictions from last predict() call.
79    last_d: [i32; NUM_MODELS],
80    /// Cached fine context index.
81    last_fine_ctx: usize,
82    /// Cached medium context index.
83    last_medium_ctx: usize,
84    /// Cached coarse context index.
85    last_coarse_ctx: usize,
86    /// Cached blended output probability.
87    last_p: u32,
88}
89
90impl DualMixer {
91    pub fn new() -> Self {
92        DualMixer {
93            fine_weights: vec![INITIAL_WEIGHTS; FINE_SETS],
94            medium_weights: vec![INITIAL_WEIGHTS; MEDIUM_SETS],
95            coarse_weights: vec![INITIAL_WEIGHTS; COARSE_SETS],
96            last_d: [0; NUM_MODELS],
97            last_fine_ctx: 0,
98            last_medium_ctx: 0,
99            last_coarse_ctx: 0,
100            last_p: 2048,
101        }
102    }
103
104    /// Mix model predictions to produce a final 12-bit probability.
105    #[inline(always)]
106    #[allow(clippy::needless_range_loop)]
107    #[allow(clippy::too_many_arguments)]
108    pub fn predict(
109        &mut self,
110        predictions: &[u32; NUM_MODELS],
111        c0: u32,
112        c1: u8,
113        bpos: u8,
114        byte_class: u8,
115        match_len_q: u8,
116        run_q: u8,
117        _xml_state: u8,
118    ) -> u32 {
119        // Stretch all predictions to log-odds.
120        for i in 0..NUM_MODELS {
121            self.last_d[i] = stretch(predictions[i]);
122        }
123
124        // Fine mixer context.
125        self.last_fine_ctx = fine_context(c0, c1, bpos, byte_class, match_len_q, run_q);
126        // Medium mixer context: (c0, c1_top4, bpos, bclass, run_q, match_q).
127        self.last_medium_ctx = medium_context(c0, c1, bpos, run_q, match_len_q);
128        // Coarse mixer context: (c0, bpos).
129        self.last_coarse_ctx = coarse_context(c0, bpos);
130
131        // Compute weighted sums in log-odds space.
132        let fw = &self.fine_weights[self.last_fine_ctx];
133        let mw = &self.medium_weights[self.last_medium_ctx];
134        let cw = &self.coarse_weights[self.last_coarse_ctx];
135
136        let mut fine_sum: i64 = 0;
137        let mut medium_sum: i64 = 0;
138        let mut coarse_sum: i64 = 0;
139        for i in 0..NUM_MODELS {
140            let d = self.last_d[i] as i64;
141            fine_sum += fw[i] as i64 * d;
142            medium_sum += mw[i] as i64 * d;
143            coarse_sum += cw[i] as i64 * d;
144        }
145        let fine_d = (fine_sum / W_SCALE as i64) as i32;
146        let medium_d = (medium_sum / W_SCALE as i64) as i32;
147        let coarse_d = (coarse_sum / W_SCALE as i64) as i32;
148
149        // Blend: fine * 0.5 + medium * 0.3 + coarse * 0.2 in log-odds space.
150        // Use integer: (fine * 5 + medium * 3 + coarse * 2) / 10
151        let blended_d = (fine_d as i64 * 5 + medium_d as i64 * 3 + coarse_d as i64 * 2) / 10;
152        let p = squash(blended_d as i32).clamp(1, 4095);
153        self.last_p = p;
154        p
155    }
156
157    /// Update weights after observing `bit`.
158    #[inline(always)]
159    #[allow(clippy::needless_range_loop)]
160    pub fn update(&mut self, bit: u8) {
161        let error = (bit as i32) * 4096 - self.last_p as i32;
162
163        // Fine mixer update.
164        let fw = &mut self.fine_weights[self.last_fine_ctx];
165        for i in 0..NUM_MODELS {
166            let delta = (FINE_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
167            fw[i] = (fw[i] as i64 + delta).clamp(-32768, 32767) as i32;
168        }
169
170        // Medium mixer update.
171        let mw = &mut self.medium_weights[self.last_medium_ctx];
172        for i in 0..NUM_MODELS {
173            let delta = (MEDIUM_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
174            mw[i] = (mw[i] as i64 + delta).clamp(-32768, 32767) as i32;
175        }
176
177        // Coarse mixer update.
178        let cw = &mut self.coarse_weights[self.last_coarse_ctx];
179        for i in 0..NUM_MODELS {
180            let delta = (COARSE_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
181            cw[i] = (cw[i] as i64 + delta).clamp(-32768, 32767) as i32;
182        }
183    }
184}
185
186impl Default for DualMixer {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
/// Classify a byte into categories for mixer context.
/// Returns a class in `0..=11` based on byte value ranges.
///
/// The high byte range (128-255) is split into 4 groups to properly
/// discriminate WRT word codes. Without this, all WRT codes (0x80-0xFE)
/// get the SAME mixer context, preventing the mixer from learning
/// different weights for different word codes.
///
/// Class 11 is returned only for 0x7F (DEL), the one byte not covered by
/// any other range.
#[inline]
pub fn byte_class(b: u8) -> u8 {
    match b {
        0..=31 => 0,      // control chars
        b' ' => 1,        // space
        b'0'..=b'9' => 2, // digits
        b'A'..=b'Z' => 3, // uppercase
        b'a'..=b'z' => 4, // lowercase
        // All ASCII punctuation shares one class.
        b'!'..=b'/' | b':'..=b'@' | b'['..=b'`' | b'{'..=b'~' => 5,
        0x80..=0x9F => 6, // WRT word codes group 1 (common words: the, of, and...)
        0xA0..=0xBF => 7, // WRT word codes group 2
        0xC0..=0xDF => 8, // WRT word codes group 3
        0xE0..=0xFE => 9, // WRT word codes group 4 / high bytes
        0xFF => 10,       // WRT escape byte
        0x7F => 11,       // DEL
    }
}

220/// Compute fine mixer context index (0..65535).
221/// Uses c0 partial byte, c1 top 4 bits, bpos, byte class, match info, run length.
222#[inline]
223fn fine_context(c0: u32, c1: u8, bpos: u8, bclass: u8, match_q: u8, run_q: u8) -> usize {
224    // Hash together: c0(8b) + c1_top4(4b) + bpos(3b) + bclass(3b) + match_q(2b) + run_q(2b)
225    // = 22 bits, fold to 16 bits for 64K sets
226    let mut h: usize = c0 as usize & 0xFF;
227    h = h.wrapping_mul(97) + (c1 as usize >> 4);
228    h = h.wrapping_mul(97) + bpos as usize;
229    h = h.wrapping_mul(97) + (bclass as usize & 0x7);
230    h = h.wrapping_mul(97) + (match_q as usize & 0x3);
231    h = h.wrapping_mul(97) + (run_q as usize & 0x3);
232    h & (FINE_SETS - 1)
233}
234
235/// Compute medium mixer context index (0..16383).
236/// Uses c0, c1 top nibble, bpos, byte class, run quantized, match length quantized.
237#[inline]
238fn medium_context(c0: u32, c1: u8, bpos: u8, run_q: u8, match_q: u8) -> usize {
239    // c0 (8 bits) + c1_top4 (4 bits) + bpos (3 bits) + bclass (3 bits) + run_q (2 bits) + match_q (2 bits) = 22 bits -> hash to 14 bits
240    let bclass = byte_class(c1);
241    let mut h: usize = c0 as usize & 0xFF;
242    h = h.wrapping_mul(67) + (c1 as usize >> 4);
243    h = h.wrapping_mul(67) + bpos as usize;
244    h = h.wrapping_mul(67) + bclass as usize;
245    h = h.wrapping_mul(67) + (run_q as usize & 0x3);
246    h = h.wrapping_mul(67) + (match_q as usize & 0x3);
247    h & (MEDIUM_SETS - 1)
248}
249
250/// Compute coarse mixer context index (0..4095).
251#[inline]
252fn coarse_context(c0: u32, bpos: u8) -> usize {
253    ((c0 as usize & 0xFF) | ((bpos as usize) << 8)) & (COARSE_SETS - 1)
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// With every model at 2048 the blended output should stay near the midpoint.
    #[test]
    fn initial_prediction_near_balanced() {
        let mut mixer = DualMixer::new();
        let preds = [2048u32; NUM_MODELS];
        let p = mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        assert!(
            (1900..=2100).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Mixed inputs must still yield a probability inside the clamped range.
    #[test]
    fn prediction_in_range() {
        let mut mixer = DualMixer::new();
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 100;
        preds[1] = 4000;
        preds[4] = 3000;
        preds[7] = 500;
        let p = mixer.predict(&preds, 128, b'a', 3, 4, 1, 0, 0);
        assert!((1..=4095).contains(&p), "prediction out of range: {p}");
    }

    /// An update after a biased prediction must modify the selected weight sets.
    ///
    /// Uses the same biased input as `mixer_adapts_to_biased_input`; neutral
    /// inputs are avoided because they may contribute nothing to the weight delta.
    #[test]
    fn update_changes_weights() {
        let mut mixer = DualMixer::new();
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 3500;
        mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        let before = (
            mixer.fine_weights[mixer.last_fine_ctx],
            mixer.medium_weights[mixer.last_medium_ctx],
            mixer.coarse_weights[mixer.last_coarse_ctx],
        );
        mixer.update(1);
        let after = (
            mixer.fine_weights[mixer.last_fine_ctx],
            mixer.medium_weights[mixer.last_medium_ctx],
            mixer.coarse_weights[mixer.last_coarse_ctx],
        );
        assert_ne!(
            before, after,
            "update should change the weights of the selected sets"
        );
    }

    /// Repeatedly rewarding a confident model should pull the output its way.
    #[test]
    fn mixer_adapts_to_biased_input() {
        let mut mixer = DualMixer::new();
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 3500;
        for _ in 0..100 {
            mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
            mixer.update(1);
        }
        let p = mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        assert!(p > 2500, "mixer should have learned to trust model 0: {p}");
    }

    #[test]
    fn byte_class_categories() {
        assert_eq!(byte_class(0), 0); // control
        assert_eq!(byte_class(b' '), 1); // space
        assert_eq!(byte_class(b'5'), 2); // digit
        assert_eq!(byte_class(b'A'), 3); // uppercase
        assert_eq!(byte_class(b'z'), 4); // lowercase
        assert_eq!(byte_class(b'.'), 5); // punctuation
        assert_eq!(byte_class(0x7F), 11); // DEL
        assert_eq!(byte_class(0x80), 6); // WRT group 1
        assert_eq!(byte_class(0x90), 6); // WRT group 1
        assert_eq!(byte_class(0xA0), 7); // WRT group 2
        assert_eq!(byte_class(0xC0), 8); // WRT group 3
        assert_eq!(byte_class(0xE0), 9); // WRT group 4
        assert_eq!(byte_class(0xFF), 10); // WRT escape
    }

    #[test]
    fn fine_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = fine_context(c0, 0xFF, bpos, 7, 3, 3);
                assert!(ctx < FINE_SETS, "fine context out of range: {ctx}");
            }
        }
    }

    #[test]
    fn medium_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = medium_context(c0, 0xFF, bpos, 3, 3);
                assert!(ctx < MEDIUM_SETS, "medium context out of range: {ctx}");
            }
        }
    }

    #[test]
    fn coarse_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = coarse_context(c0, bpos);
                assert!(ctx < COARSE_SETS, "coarse context out of range: {ctx}");
            }
        }
    }
}
355}