// datacortex_core/mixer/apm.rs

/// Number of interpolation points in each context row of the APM table.
///
/// With 65 points the 12-bit probability range `0..=4095` is split into
/// `NUM_BINS - 1 = 64` equal segments.
const NUM_BINS: usize = 65;
15pub struct APMStage {
17 table: Vec<[u32; NUM_BINS]>,
20 num_contexts: usize,
22 blend: u32,
25 last_ctx: usize,
27 last_bin: usize,
29 last_weight: u32,
31}
32
33impl APMStage {
34 pub fn new(num_contexts: usize, blend_pct: u32) -> Self {
39 let mut table = vec![[0u32; NUM_BINS]; num_contexts];
41 for ctx_row in table.iter_mut() {
42 for (i, entry) in ctx_row.iter_mut().enumerate() {
43 *entry = ((i as u64 * 4095 + (NUM_BINS as u64 - 1) / 2) / (NUM_BINS as u64 - 1))
44 .clamp(1, 4095) as u32;
45 }
46 }
47
48 APMStage {
49 table,
50 num_contexts,
51 blend: (blend_pct * 256 / 100).min(256),
52 last_ctx: 0,
53 last_bin: 0,
54 last_weight: 0,
55 }
56 }
57
58 #[inline(always)]
65 pub fn predict(&mut self, prob: u32, context: usize) -> u32 {
66 let ctx = context % self.num_contexts;
67 self.last_ctx = ctx;
68
69 let scaled = prob.min(4095) as u64 * (NUM_BINS as u64 - 1);
72 let bin = (scaled / 4095) as usize;
73 let bin = bin.min(NUM_BINS - 2); let weight = (scaled % 4095) as u32; self.last_bin = bin;
77 self.last_weight = weight;
78
79 let t = &self.table[ctx];
81 let interp = t[bin] as i64 + (t[bin + 1] as i64 - t[bin] as i64) * weight as i64 / 4095;
82 let apm_p = interp.clamp(1, 4095) as u32;
83
84 let blended =
86 (apm_p as u64 * self.blend as u64 + prob as u64 * (256 - self.blend) as u64) / 256;
87 (blended as u32).clamp(1, 4095)
88 }
89
90 #[inline(always)]
93 pub fn update(&mut self, bit: u8) {
94 let target = if bit != 0 { 4095u32 } else { 1u32 };
95 let t = &mut self.table[self.last_ctx];
96
97 let rate = 4; let old = t[self.last_bin];
103 let delta = (target as i32 - old as i32) >> rate;
104 t[self.last_bin] = (old as i32 + delta).clamp(1, 4095) as u32;
105
106 if self.last_bin + 1 < NUM_BINS {
108 let old2 = t[self.last_bin + 1];
109 let delta2 = (target as i32 - old2 as i32) >> (rate + 1);
110 t[self.last_bin + 1] = (old2 as i32 + delta2).clamp(1, 4095) as u32;
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initial_passthrough() {
        // With 0% blend the stage must return its input unchanged.
        let mut apm = APMStage::new(1, 0);
        let p = apm.predict(2048, 0);
        assert_eq!(p, 2048);
    }

    #[test]
    fn initial_50_blend_near_identity() {
        let mut apm = APMStage::new(1, 50);
        let p = apm.predict(2048, 0);
        assert!(
            (2000..=2096).contains(&p),
            "50% blend of identity should be near input: {p}"
        );
    }

    #[test]
    fn prediction_in_range() {
        let mut apm = APMStage::new(512, 50);
        for prob in [1u32, 100, 1000, 2048, 3000, 4000, 4095] {
            for ctx in [0usize, 100, 511] {
                let p = apm.predict(prob, ctx);
                assert!(
                    (1..=4095).contains(&p),
                    "out of range: prob={prob}, ctx={ctx}, got {p}"
                );
            }
        }
    }

    #[test]
    fn update_adapts() {
        let mut apm = APMStage::new(1, 100);
        for _ in 0..100 {
            apm.predict(2048, 0);
            apm.update(1);
        }
        let p = apm.predict(2048, 0);
        assert!(p > 2048, "after many 1s, APM should predict higher: {p}");
    }

    #[test]
    fn different_contexts_independent() {
        let mut apm = APMStage::new(2, 100);
        for _ in 0..50 {
            // Train only context 0.
            apm.predict(2048, 0);
            apm.update(1);
        }
        // Context 1 was never trained and should still be close to identity.
        let p = apm.predict(2048, 1);
        assert!(
            (2000..=2096).contains(&p),
            "untrained context should be near 2048: {p}"
        );
    }

    #[test]
    fn context_index_wraps_modulo_num_contexts() {
        let mut apm = APMStage::new(2, 100);
        for _ in 0..50 {
            apm.predict(2048, 0);
            apm.update(1);
        }
        // predict reduces the context modulo num_contexts, so 2 aliases 0.
        let p0 = apm.predict(2048, 0);
        let p2 = apm.predict(2048, 2);
        assert_eq!(p0, p2, "context 2 should share context 0's row");
        assert!(p0 > 2048, "trained context should predict higher: {p0}");
    }

    #[test]
    fn extreme_inputs() {
        let mut apm = APMStage::new(1, 50);
        let p_low = apm.predict(1, 0);
        assert!((1..=100).contains(&p_low), "low input: {p_low}");

        let p_high = apm.predict(4095, 0);
        assert!((3995..=4095).contains(&p_high), "high input: {p_high}");
    }
}