Skip to main content

datacortex_core/mixer/
apm.rs

//! APM — Adaptive Probability Map for post-mixer refinement.
//!
//! Phase 3: two-stage APM cascade using the parameters proven in V2.
//!
//! An APM maps (context, input_probability) → refined_probability through a
//! per-context interpolation table that adapts over time.
//!
//! Stage 1: 512 contexts (bpos * byte_class), 50% blend with input.
//! Stage 2: 16K contexts (c1 * c0_top4), 25% blend with input.

/// Number of probability bins in the interpolation table.
/// 65 bins (64 equal intervals over the 12-bit range) give good resolution
/// while keeping the per-context memory footprint small.
const NUM_BINS: usize = 65;

/// An adaptive probability map stage.
///
/// Each context owns a row of `NUM_BINS` 12-bit probabilities.
/// [`APMStage::predict`] linearly interpolates between the two bins that
/// bracket the input probability and blends the result with the input;
/// [`APMStage::update`] then moves those two bins toward the observed bit.
pub struct APMStage {
    /// Table: `[num_contexts][NUM_BINS]` of 12-bit probabilities in `[1, 4095]`.
    /// Each entry is the mapped output probability for (context, bin).
    table: Vec<[u32; NUM_BINS]>,
    /// Number of contexts (rows of `table`); always non-zero.
    num_contexts: usize,
    /// Blend factor in `[0, 256]`: weight of the APM output vs. the input.
    /// `blend = 128` means 50% input + 50% APM output.
    blend: u32,
    /// Context selected by the last `predict` call (read by `update`).
    last_ctx: usize,
    /// Lower interpolation bin selected by the last `predict` call (read by
    /// `update`); always `<= NUM_BINS - 2`, so `last_bin + 1` is a valid index.
    last_bin: usize,
}

impl APMStage {
    /// Create a new APM stage with the given number of contexts and blend factor.
    ///
    /// * `num_contexts`: number of different contexts; `predict` reduces its
    ///   `context` argument modulo this value.
    /// * `blend_pct`: blend percentage (0-100). 50 means 50% APM output + 50% input;
    ///   values above 100 are treated as 100.
    ///
    /// Every row starts as the identity mapping (bin `i` ↦ `round(i * 4095 / (NUM_BINS - 1))`,
    /// clamped to `[1, 4095]`), so an untrained stage roughly passes its input through.
    ///
    /// # Panics
    ///
    /// Panics if `num_contexts` is zero.
    pub fn new(num_contexts: usize, blend_pct: u32) -> Self {
        // An empty table would otherwise only fail later, as a modulo-by-zero in `predict`.
        assert!(num_contexts > 0, "APMStage::new: num_contexts must be non-zero");

        let intervals = NUM_BINS as u64 - 1;
        let mut identity = [0u32; NUM_BINS];
        for (i, entry) in identity.iter_mut().enumerate() {
            // Rounded i * 4095 / intervals; the clamp keeps bin 0 at the minimum of 1.
            *entry = ((i as u64 * 4095 + intervals / 2) / intervals).clamp(1, 4095) as u32;
        }

        APMStage {
            table: vec![identity; num_contexts],
            num_contexts,
            // Clamp before scaling so the multiplication cannot overflow.
            blend: blend_pct.min(100) * 256 / 100,
            last_ctx: 0,
            last_bin: 0,
        }
    }

    /// Map an input probability through the APM.
    ///
    /// * `prob`: input 12-bit probability, expected in `[1, 4095]`.
    /// * `context`: context index; reduced modulo `num_contexts`.
    ///
    /// Returns the refined 12-bit probability, clamped to `[1, 4095]`.
    /// Records the selected context and bin so that the following
    /// [`APMStage::update`] trains the same table cells.
    #[inline(always)]
    pub fn predict(&mut self, prob: u32, context: usize) -> u32 {
        let ctx = context % self.num_contexts;
        self.last_ctx = ctx;

        // Position of `prob` on the bin axis, in units of 1/4095 of a bin.
        let scaled = prob.min(4095) as u64 * (NUM_BINS as u64 - 1);
        let whole = (scaled / 4095) as usize;
        // The top input lands exactly on the last bin. Express it as the last
        // interval with full weight so `bin + 1` stays in range and the
        // interpolation still yields the last bin's value (clamping the bin
        // while leaving the weight at 0 would return the second-to-last bin).
        let (bin, weight) = if whole >= NUM_BINS - 1 {
            (NUM_BINS - 2, 4095)
        } else {
            (whole, (scaled % 4095) as u32)
        };
        self.last_bin = bin;

        // Linear interpolation between row[bin] and row[bin + 1].
        let row = &self.table[ctx];
        let lo = row[bin] as i64;
        let hi = row[bin + 1] as i64;
        let apm_p = (lo + (hi - lo) * weight as i64 / 4095).clamp(1, 4095) as u32;

        // Blend APM output with the input probability (blend is out of 256).
        let blended =
            (apm_p as u64 * self.blend as u64 + prob as u64 * (256 - self.blend) as u64) / 256;
        (blended as u32).clamp(1, 4095)
    }

    /// Update the APM after observing `bit` (any non-zero value counts as 1).
    ///
    /// Must be called after [`APMStage::predict`]: it adjusts the two bins that
    /// `predict` interpolated between, in the context it selected. The lower bin
    /// moves 1/16 of the way toward the target (4095 for a 1, 1 for a 0) and the
    /// upper bin 1/32; the step sizes do not depend on the interpolation position.
    #[inline(always)]
    pub fn update(&mut self, bit: u8) {
        let target: i32 = if bit != 0 { 4095 } else { 1 };
        let bin = self.last_bin;
        let row = &mut self.table[self.last_ctx];

        // (bin index, right shift applied to the error).
        for (idx, shift) in [(bin, 4u32), (bin + 1, 5u32)] {
            if let Some(cell) = row.get_mut(idx) {
                let old = *cell as i32;
                *cell = (old + ((target - old) >> shift)).clamp(1, 4095) as u32;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `rounds` predict/update cycles at probability 2048 in context `ctx`,
    /// each time observing `bit`.
    fn train(stage: &mut APMStage, ctx: usize, bit: u8, rounds: usize) {
        (0..rounds).for_each(|_| {
            stage.predict(2048, ctx);
            stage.update(bit);
        });
    }

    #[test]
    fn initial_passthrough() {
        // 0% blend: the table is ignored and the input comes straight back.
        let mut stage = APMStage::new(1, 0);
        assert_eq!(stage.predict(2048, 0), 2048);
    }

    #[test]
    fn initial_50_blend_near_identity() {
        // The table starts as the identity, so a 50/50 blend stays near the input.
        let mut stage = APMStage::new(1, 50);
        let p = stage.predict(2048, 0);
        assert!(
            (2000..=2096).contains(&p),
            "50% blend of identity should be near input: {p}"
        );
    }

    #[test]
    fn prediction_in_range() {
        let mut stage = APMStage::new(512, 50);
        let probs = [1u32, 100, 1000, 2048, 3000, 4000, 4095];
        let ctxs = [0usize, 100, 511];
        let pairs = probs
            .iter()
            .flat_map(|&pr| ctxs.iter().map(move |&c| (pr, c)));
        for (prob, ctx) in pairs {
            let p = stage.predict(prob, ctx);
            assert!(
                (1..=4095).contains(&p),
                "out of range: prob={prob}, ctx={ctx}, got {p}"
            );
        }
    }

    #[test]
    fn update_adapts() {
        // Pure APM output, trained on a long run of 1 bits.
        let mut stage = APMStage::new(1, 100);
        train(&mut stage, 0, 1, 100);
        let p = stage.predict(2048, 0);
        assert!(p > 2048, "after many 1s, APM should predict higher: {p}");
    }

    #[test]
    fn different_contexts_independent() {
        // Training context 0 must leave context 1 untouched.
        let mut stage = APMStage::new(2, 100);
        train(&mut stage, 0, 1, 50);
        let p = stage.predict(2048, 1);
        assert!(
            (2000..=2096).contains(&p),
            "untrained context should be near 2048: {p}"
        );
    }

    #[test]
    fn extreme_inputs() {
        let mut stage = APMStage::new(1, 50);

        let p_low = stage.predict(1, 0);
        assert!((1..=100).contains(&p_low), "low input: {p_low}");

        let p_high = stage.predict(4095, 0);
        assert!((3995..=4095).contains(&p_high), "high input: {p_high}");
    }
}
185
186        let p_high = apm.predict(4095, 0);
187        assert!((3995..=4095).contains(&p_high), "high input: {p_high}");
188    }
189}