// datacortex_core/model/gru_model.rs
//! GRU (Gated Recurrent Unit) byte-level predictor with truncated BPTT.
//!
//! A byte-level neural predictor providing a DIFFERENT signal from the bit-level
//! CM engine. The GRU captures cross-byte sequential patterns via a recurrent
//! hidden state trained with backpropagation through time (BPTT).
//!
//! Architecture:
//!   Input: byte embedding lookup (256 → 32 via embedding matrix)
//!   GRU: 128 hidden cells, 1 layer
//!   Output: 128 → 256 linear → softmax → byte probabilities
//!
//! Training: truncated BPTT-10. At each byte completion, gradients propagate
//! back through the last 10 steps of GRU history. This is the same strategy
//! used by cmix (which uses BPTT-100) and gives the majority of the gain at
//! 10% of the BPTT-100 cost.
//!
//! ~43K parameters (~170KB at f32). History + gradient buffers: ~260KB.
//!
//! CRITICAL: Encoder and decoder must maintain IDENTICAL GRU state.
//! Both must call train(byte) then forward(byte) in the same order on the
//! same bytes so that history buffers and weight updates are identical.

/// Width of the learned byte embedding (each byte maps to a 32-dim vector).
const EMBED_DIM: usize = 32;
/// Width of the GRU hidden state.
const HIDDEN_DIM: usize = 128;
/// Output alphabet size: all possible byte values.
const VOCAB_SIZE: usize = 256;

/// BPTT truncation horizon: backpropagate gradients through this many steps.
///
/// 10 steps captures most sequential byte patterns with manageable overhead.
/// cmix achieves -0.07 bpb using BPTT-100; 10 steps provides ~60-70% of that
/// gain at 1/10 the training cost (~5× total slowdown vs no-BPTT).
const BPTT_HORIZON: usize = 10;

/// Online SGD learning rate.
///
/// 0.01 is conservative — the GRU sees each byte only once (online learning).
/// Lower than typical offline training LR to avoid overshooting on rare bytes.
const LEARNING_RATE: f32 = 0.01;

/// Gradient clip magnitude.
///
/// BPTT through 10 steps can accumulate gradients. Clipping at 5.0 prevents
/// exploding gradients while preserving the direction of improvement.
const GRAD_CLIP: f32 = 5.0;
45
/// GRU byte-level predictor with BPTT-10 online training.
///
/// All weights are plain `Vec<f32>` in row-major layout. The struct also owns
/// its BPTT history ring and pre-allocated gradient accumulators so that the
/// hot path (forward/train) never allocates.
pub struct GruModel {
    // ─── Parameters ──────────────────────────────────────────────────────────
    /// Embedding matrix: [VOCAB_SIZE * EMBED_DIM], row per byte value.
    embedding: Vec<f32>,

    /// Update gate input weights: [HIDDEN_DIM * EMBED_DIM]
    w_z: Vec<f32>,
    /// Update gate recurrent weights: [HIDDEN_DIM * HIDDEN_DIM]
    u_z: Vec<f32>,
    /// Update gate biases: [HIDDEN_DIM]
    b_z: Vec<f32>,

    /// Reset gate input weights: [HIDDEN_DIM * EMBED_DIM]
    w_r: Vec<f32>,
    /// Reset gate recurrent weights: [HIDDEN_DIM * HIDDEN_DIM]
    u_r: Vec<f32>,
    /// Reset gate biases: [HIDDEN_DIM]
    b_r: Vec<f32>,

    /// Candidate hidden input weights: [HIDDEN_DIM * EMBED_DIM]
    w_h: Vec<f32>,
    /// Candidate hidden recurrent weights: [HIDDEN_DIM * HIDDEN_DIM]
    u_h: Vec<f32>,
    /// Candidate hidden biases: [HIDDEN_DIM]
    b_h: Vec<f32>,

    /// Output projection weights: [VOCAB_SIZE * HIDDEN_DIM]
    w_o: Vec<f32>,
    /// Output projection biases: [VOCAB_SIZE]
    b_o: Vec<f32>,

    // ─── Recurrent state ─────────────────────────────────────────────────────
    /// Hidden state: [HIDDEN_DIM]
    h: Vec<f32>,

    // ─── Cached forward pass values (for the most recent step) ───────────────
    /// Input embedding for the most recent forward step: [EMBED_DIM]
    last_x: Vec<f32>,
    /// Hidden state before the most recent forward step: [HIDDEN_DIM]
    last_h_prev: Vec<f32>,
    /// Update gate output from the most recent step: [HIDDEN_DIM]
    last_z: Vec<f32>,
    /// Reset gate output from the most recent step: [HIDDEN_DIM]
    last_r: Vec<f32>,
    /// Candidate hidden from the most recent step: [HIDDEN_DIM]
    last_h_tilde: Vec<f32>,

    /// Cached softmax probabilities over the next byte: [VOCAB_SIZE]
    byte_probs: Vec<f32>,
    /// Whether byte_probs has been computed for the current step.
    probs_valid: bool,
    /// Whether at least one byte has been processed (have valid hidden state).
    has_context: bool,

    // ─── BPTT history ring buffer ─────────────────────────────────────────────
    // Flat layout: entry at ring position `p` occupies [p*DIM .. (p+1)*DIM].
    // hist_pos is the next WRITE position (circular). After writing,
    // hist_pos = (hist_pos + 1) % BPTT_HORIZON.
    //
    /// Saved input embeddings: [BPTT_HORIZON * EMBED_DIM]
    hist_x: Vec<f32>,
    /// Saved h_prev (hidden before each step): [BPTT_HORIZON * HIDDEN_DIM]
    hist_h_prev: Vec<f32>,
    /// Saved update gate outputs: [BPTT_HORIZON * HIDDEN_DIM]
    hist_z: Vec<f32>,
    /// Saved reset gate outputs: [BPTT_HORIZON * HIDDEN_DIM]
    hist_r: Vec<f32>,
    /// Saved candidate hidden values: [BPTT_HORIZON * HIDDEN_DIM]
    hist_h_tilde: Vec<f32>,
    /// Next write position in the ring (0..BPTT_HORIZON).
    hist_pos: usize,
    /// Number of valid entries in the ring (0..=BPTT_HORIZON).
    hist_count: usize,

    // ─── Pre-allocated gradient accumulators ─────────────────────────────────
    // Zeroed at the start of each train() call and accumulated across all BPTT
    // steps before a single weight update. Stored in the struct to avoid
    // per-call heap allocation in the hot path.
    //
    /// Accumulated gradient for w_z: [HIDDEN_DIM * EMBED_DIM]
    grad_w_z: Vec<f32>,
    /// Accumulated gradient for u_z: [HIDDEN_DIM * HIDDEN_DIM]
    grad_u_z: Vec<f32>,
    /// Accumulated gradient for b_z: [HIDDEN_DIM]
    grad_b_z: Vec<f32>,
    /// Accumulated gradient for w_r: [HIDDEN_DIM * EMBED_DIM]
    grad_w_r: Vec<f32>,
    /// Accumulated gradient for u_r: [HIDDEN_DIM * HIDDEN_DIM]
    grad_u_r: Vec<f32>,
    /// Accumulated gradient for b_r: [HIDDEN_DIM]
    grad_b_r: Vec<f32>,
    /// Accumulated gradient for w_h: [HIDDEN_DIM * EMBED_DIM]
    grad_w_h: Vec<f32>,
    /// Accumulated gradient for u_h: [HIDDEN_DIM * HIDDEN_DIM]
    grad_u_h: Vec<f32>,
    /// Accumulated gradient for b_h: [HIDDEN_DIM]
    grad_b_h: Vec<f32>,
}
145
146impl GruModel {
147    /// Create a new GRU model with Xavier-initialized weights and zeroed buffers.
148    pub fn new() -> Self {
149        let mut model = GruModel {
150            embedding: vec![0.0; VOCAB_SIZE * EMBED_DIM],
151            w_z: vec![0.0; HIDDEN_DIM * EMBED_DIM],
152            u_z: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
153            b_z: vec![0.0; HIDDEN_DIM],
154            w_r: vec![0.0; HIDDEN_DIM * EMBED_DIM],
155            u_r: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
156            b_r: vec![0.0; HIDDEN_DIM],
157            w_h: vec![0.0; HIDDEN_DIM * EMBED_DIM],
158            u_h: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
159            b_h: vec![0.0; HIDDEN_DIM],
160            w_o: vec![0.0; VOCAB_SIZE * HIDDEN_DIM],
161            b_o: vec![0.0; VOCAB_SIZE],
162            h: vec![0.0; HIDDEN_DIM],
163            last_x: vec![0.0; EMBED_DIM],
164            last_h_prev: vec![0.0; HIDDEN_DIM],
165            last_z: vec![0.0; HIDDEN_DIM],
166            last_r: vec![0.0; HIDDEN_DIM],
167            last_h_tilde: vec![0.0; HIDDEN_DIM],
168            byte_probs: vec![1.0 / VOCAB_SIZE as f32; VOCAB_SIZE],
169            probs_valid: false,
170            has_context: false,
171            // History ring buffers — zeroed.
172            hist_x: vec![0.0; BPTT_HORIZON * EMBED_DIM],
173            hist_h_prev: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
174            hist_z: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
175            hist_r: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
176            hist_h_tilde: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
177            hist_pos: 0,
178            hist_count: 0,
179            // Gradient accumulators — zeroed (will be explicitly zeroed each train() call).
180            grad_w_z: vec![0.0; HIDDEN_DIM * EMBED_DIM],
181            grad_u_z: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
182            grad_b_z: vec![0.0; HIDDEN_DIM],
183            grad_w_r: vec![0.0; HIDDEN_DIM * EMBED_DIM],
184            grad_u_r: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
185            grad_b_r: vec![0.0; HIDDEN_DIM],
186            grad_w_h: vec![0.0; HIDDEN_DIM * EMBED_DIM],
187            grad_u_h: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
188            grad_b_h: vec![0.0; HIDDEN_DIM],
189        };
190        model.init_weights();
191        model
192    }
193
    /// Initialize weights using a deterministic pseudo-random scheme.
    /// Xavier/Glorot initialization scaled by fan_in + fan_out.
    fn init_weights(&mut self) {
        // Deterministic PRNG for reproducibility (encoder = decoder). A single
        // seed threaded through all fill_xavier calls keeps the full weight
        // sequence fixed.
        let mut seed: u64 = 0xDEAD_BEEF_CAFE_1234;

        // Xavier scale for embedding: sqrt(2 / (256 + 32))
        let embed_scale = (2.0 / (VOCAB_SIZE + EMBED_DIM) as f32).sqrt();
        fill_xavier(&mut self.embedding, embed_scale, &mut seed);

        // Xavier scale for input weights: sqrt(2 / (32 + 128))
        let wx_scale = (2.0 / (EMBED_DIM + HIDDEN_DIM) as f32).sqrt();
        fill_xavier(&mut self.w_z, wx_scale, &mut seed);
        fill_xavier(&mut self.w_r, wx_scale, &mut seed);
        fill_xavier(&mut self.w_h, wx_scale, &mut seed);

        // Xavier scale for recurrent weights: sqrt(2 / (128 + 128))
        let uh_scale = (2.0 / (HIDDEN_DIM + HIDDEN_DIM) as f32).sqrt();
        fill_xavier(&mut self.u_z, uh_scale, &mut seed);
        fill_xavier(&mut self.u_r, uh_scale, &mut seed);
        fill_xavier(&mut self.u_h, uh_scale, &mut seed);

        // Xavier scale for output weights: sqrt(2 / (128 + 256))
        let wo_scale = (2.0 / (HIDDEN_DIM + VOCAB_SIZE) as f32).sqrt();
        fill_xavier(&mut self.w_o, wo_scale, &mut seed);

        // Bias the update gate positive so z starts near sigmoid(1) ≈ 0.73.
        //
        // NOTE(review): with the update rule used in forward(),
        // h_t = (1-z)*h_prev + z*h_tilde, z → 1 selects the NEW candidate —
        // so this starts the cell in "overwrite" mode, not "remember" mode as
        // the original comment claimed. If "remember" (keep h_prev) was the
        // intent, the bias should be negative. TODO: confirm intent; changing
        // this value alters every compressed stream, so it must be coordinated
        // with a format version bump.
        for b in self.b_z.iter_mut() {
            *b = 1.0;
        }
        // Reset gate and candidate biases stay at 0.
    }
227
    /// Forward pass: process one byte, update hidden state, compute output probs.
    ///
    /// Call this with the byte that was just OBSERVED. The resulting byte_probs
    /// predict the NEXT byte. After forward(), call predict_bit() to get bit
    /// probabilities for the next byte.
    ///
    /// Also saves the step (x, h_prev, z, r, h_tilde) into the BPTT history
    /// ring for train() to use.
    ///
    /// NOTE: the exact f32 operation order here is part of the format —
    /// encoder and decoder must produce bit-identical hidden states.
    #[inline(never)]
    #[allow(
        clippy::needless_range_loop,
        reason = "matrix ops are clearer with explicit indices"
    )]
    pub fn forward(&mut self, byte: u8) {
        // Save previous hidden state for backprop.
        self.last_h_prev.copy_from_slice(&self.h);

        // Get embedding for the input byte (row lookup, no matmul).
        let byte_idx = byte as usize;
        let embed_start = byte_idx * EMBED_DIM;
        self.last_x
            .copy_from_slice(&self.embedding[embed_start..embed_start + EMBED_DIM]);

        // Compute update gate z, reset gate r, and candidate h_tilde in a fused
        // loop for cache locality (one pass over the weight rows per unit i).
        for i in 0..HIDDEN_DIM {
            let w_off = i * EMBED_DIM;
            let wz_row = &self.w_z[w_off..w_off + EMBED_DIM];
            let wr_row = &self.w_r[w_off..w_off + EMBED_DIM];
            let wh_row = &self.w_h[w_off..w_off + EMBED_DIM];

            // Start each pre-activation from its bias.
            let mut val_z = self.b_z[i];
            let mut val_r = self.b_r[i];
            let mut val_h = self.b_h[i];

            // W @ x for all three gates.
            for j in 0..EMBED_DIM {
                let xj = self.last_x[j];
                val_z += wz_row[j] * xj;
                val_r += wr_row[j] * xj;
                val_h += wh_row[j] * xj;
            }

            // U_z @ h_{t-1} and U_r @ h_{t-1}
            let u_off = i * HIDDEN_DIM;
            let uz_row = &self.u_z[u_off..u_off + HIDDEN_DIM];
            let ur_row = &self.u_r[u_off..u_off + HIDDEN_DIM];

            for j in 0..HIDDEN_DIM {
                let hj = self.last_h_prev[j];
                val_z += uz_row[j] * hj;
                val_r += ur_row[j] * hj;
            }

            let z_i = sigmoid(val_z);
            let r_i = sigmoid(val_r);
            self.last_z[i] = z_i;
            self.last_r[i] = r_i;

            // Candidate recurrent term — OUTPUT-DIMENSION reset convention:
            //   pre_h[i] = b_h[i] + (W_h x)[i] + r[i] * (U_h @ h_prev)[i]
            // i.e. r[i] scales unit i's entire recurrent contribution. This is
            // NOT the standard GRU (which gates each h_prev[j] by r[j] before
            // the matmul); any backward pass must use the matching gradient.
            let uh_row = &self.u_h[u_off..u_off + HIDDEN_DIM];
            for j in 0..HIDDEN_DIM {
                val_h += uh_row[j] * (r_i * self.last_h_prev[j]);
            }

            let h_tilde_i = tanh_approx(val_h);
            self.last_h_tilde[i] = h_tilde_i;

            // h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde  (z → 1 takes candidate)
            self.h[i] = (1.0 - z_i) * self.last_h_prev[i] + z_i * h_tilde_i;
        }

        // Compute output probabilities for the NEXT byte.
        self.compute_output_probs();
        self.probs_valid = true;
        self.has_context = true;

        // ─── Save this step into the BPTT history ring ───────────────────────
        let x_base = self.hist_pos * EMBED_DIM;
        self.hist_x[x_base..x_base + EMBED_DIM].copy_from_slice(&self.last_x);

        let h_base = self.hist_pos * HIDDEN_DIM;
        self.hist_h_prev[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_h_prev);
        self.hist_z[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_z);
        self.hist_r[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_r);
        self.hist_h_tilde[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_h_tilde);

        // Advance circular write head; count saturates at the horizon.
        self.hist_pos = (self.hist_pos + 1) % BPTT_HORIZON;
        if self.hist_count < BPTT_HORIZON {
            self.hist_count += 1;
        }
    }
320
    /// Compute softmax output probabilities from the current hidden state.
    ///
    /// Works in place in `byte_probs`: pass 1 stores raw logits (W_o @ h + b_o)
    /// while tracking the max; pass 2 exponentiates and sums; pass 3 normalizes
    /// and floors each probability at 1e-8 so train() never sees log(0). After
    /// flooring, the distribution may sum to marginally more than 1.0.
    #[inline(never)]
    #[allow(
        clippy::needless_range_loop,
        reason = "matrix ops are clearer with explicit indices"
    )]
    fn compute_output_probs(&mut self) {
        let mut max_logit: f32 = f32::NEG_INFINITY;
        for i in 0..VOCAB_SIZE {
            let w_row = &self.w_o[i * HIDDEN_DIM..(i + 1) * HIDDEN_DIM];
            let mut logit = self.b_o[i];
            for j in 0..HIDDEN_DIM {
                logit += w_row[j] * self.h[j];
            }
            // Reuse byte_probs as the logit buffer until the exp pass below.
            self.byte_probs[i] = logit;
            if logit > max_logit {
                max_logit = logit;
            }
        }

        // Numerically stable softmax: subtract max before exp.
        let mut sum: f32 = 0.0;
        for p in self.byte_probs.iter_mut() {
            let e = (*p - max_logit).exp();
            *p = e;
            sum += e;
        }

        // Normalize with epsilon guard (sum > 0 in practice; guard against 0).
        let inv_sum = 1.0 / (sum + 1e-30);
        for p in self.byte_probs.iter_mut() {
            *p *= inv_sum;
            // Clamp to avoid log(0) in training.
            if *p < 1e-8 {
                *p = 1e-8;
            }
        }
    }
359
360    /// Convert byte probabilities to a bit prediction for the CM MetaMixer.
361    ///
362    /// `bpos`: bit position 0-7 (0 = MSB).
363    /// `c0`: partial byte being built (starts at 1, accumulates bits MSB-first).
364    ///
365    /// Returns: 12-bit probability [1, 4095] of next bit being 1.
366    #[inline]
367    pub fn predict_bit(&self, bpos: u8, c0: u32) -> u32 {
368        if !self.has_context {
369            return 2048; // Uniform before first byte.
370        }
371
372        let bit_pos = 7 - bpos;
373        let mask = 1u8 << bit_pos;
374
375        let mut sum_one: f64 = 0.0;
376        let mut sum_zero: f64 = 0.0;
377
378        if bpos == 0 {
379            // No bits decoded yet — sum over all 256.
380            for b in 0..VOCAB_SIZE {
381                let p = self.byte_probs[b] as f64;
382                if (b as u8) & mask != 0 {
383                    sum_one += p;
384                } else {
385                    sum_zero += p;
386                }
387            }
388        } else {
389            // Some bits decoded. Only consider bytes matching the partial prefix.
390            let partial = (c0 & ((1u32 << bpos) - 1)) as u8;
391            let shift = 8 - bpos;
392            let base = (partial as usize) << shift;
393            let count = 1usize << shift;
394
395            for i in 0..count {
396                let b = base | i;
397                let p = self.byte_probs[b] as f64;
398                if (b as u8) & mask != 0 {
399                    sum_one += p;
400                } else {
401                    sum_zero += p;
402                }
403            }
404        }
405
406        let total = sum_one + sum_zero;
407        if total < 1e-15 {
408            return 2048;
409        }
410
411        let p = ((sum_one * 4096.0) / total) as u32;
412        p.clamp(1, 4095)
413    }
414
415    /// Online training with truncated BPTT.
416    ///
417    /// Computes the output-layer gradient for `actual_byte`, then propagates
418    /// d_h backwards through up to BPTT_HORIZON stored steps. Gradients are
419    /// accumulated into pre-allocated buffers and applied in a single weight
420    /// update at the end. This avoids the weight-corruption from immediate
421    /// per-step updates that would otherwise occur in BPTT.
422    ///
423    /// Call this BEFORE forward(actual_byte) (matches codec.rs flow). Both
424    /// encoder and decoder see the same byte sequence so their weights and
425    /// history buffers evolve identically — parity is preserved.
426    #[inline(never)]
427    #[allow(
428        clippy::needless_range_loop,
429        reason = "matrix ops are clearer with explicit indices"
430    )]
431    pub fn train(&mut self, actual_byte: u8) {
432        if !self.has_context {
433            return;
434        }
435
436        let target = actual_byte as usize;
437
438        // ─── Output layer: update W_o, b_o and compute initial d_h ──────────
439        // Cross-entropy + softmax gradient: d_logits[i] = probs[i] - (i==target).
440        let mut d_h = [0.0f32; HIDDEN_DIM];
441
442        for i in 0..VOCAB_SIZE {
443            let dl = clip_grad(self.byte_probs[i] - if i == target { 1.0 } else { 0.0 });
444            if dl.abs() < 1e-7 {
445                // Skip near-zero gradient — most of the 256 outputs.
446                continue;
447            }
448            let w_row = &mut self.w_o[i * HIDDEN_DIM..(i + 1) * HIDDEN_DIM];
449            let lr_dl = LEARNING_RATE * dl;
450            for j in 0..HIDDEN_DIM {
451                d_h[j] += dl * w_row[j];
452                w_row[j] -= lr_dl * self.h[j];
453            }
454            self.b_o[i] -= LEARNING_RATE * dl;
455        }
456
457        // ─── Zero gradient accumulators ──────────────────────────────────────
458        self.grad_w_z.fill(0.0);
459        self.grad_u_z.fill(0.0);
460        self.grad_b_z.fill(0.0);
461        self.grad_w_r.fill(0.0);
462        self.grad_u_r.fill(0.0);
463        self.grad_b_r.fill(0.0);
464        self.grad_w_h.fill(0.0);
465        self.grad_u_h.fill(0.0);
466        self.grad_b_h.fill(0.0);
467
468        // ─── BPTT: propagate d_h backwards through history ───────────────────
469        // Iterate from most-recent (step_back=0) to oldest (step_back=steps-1).
470        // At each step, accumulate weight gradients and compute d_h_prev to
471        // pass to the step before it.
472        let steps = self.hist_count;
473
474        // Gate gradients at step 0 — needed for the embedding update below.
475        let mut d_pre_z_s0 = [0.0f32; HIDDEN_DIM];
476        let mut d_pre_r_s0 = [0.0f32; HIDDEN_DIM];
477        let mut d_pre_h_s0 = [0.0f32; HIDDEN_DIM];
478
479        for step_back in 0..steps {
480            // Read from ring: most-recent step is at hist_pos - 1 (mod BPTT_HORIZON).
481            let ring_idx = (self.hist_pos + BPTT_HORIZON - 1 - step_back) % BPTT_HORIZON;
482
483            let x_base = ring_idx * EMBED_DIM;
484            let h_base = ring_idx * HIDDEN_DIM;
485
486            // ── GRU cell backward (one step) ─────────────────────────────────
487            // Given d_h (gradient of loss w.r.t. h_t), compute:
488            //   d_h_tilde, d_pre_z, d_pre_h (upstream gradients through h_t)
489            //   d_pre_r (upstream gradient through the reset gate path)
490            //   Accumulate dW, dU, db for all three gates.
491            //   Compute d_h_prev to pass to the previous step.
492
493            let mut d_pre_z = [0.0f32; HIDDEN_DIM];
494            let mut d_pre_r = [0.0f32; HIDDEN_DIM];
495            let mut d_pre_h = [0.0f32; HIDDEN_DIM];
496
497            for i in 0..HIDDEN_DIM {
498                let dhi = clip_grad(d_h[i]);
499                let z_i = self.hist_z[h_base + i];
500                let r_i = self.hist_r[h_base + i];
501                let h_tilde_i = self.hist_h_tilde[h_base + i];
502                let h_prev_i = self.hist_h_prev[h_base + i];
503
504                // h_t = (1-z)*h_prev + z*h_tilde  ⟹  dh_tilde = dh * z
505                let d_h_tilde_i = dhi * z_i;
506                // dz = dh * (h_tilde - h_prev)
507                let dz_i = dhi * (h_tilde_i - h_prev_i);
508
509                // Sigmoid backward: d_pre_z = dz * z * (1-z)
510                d_pre_z[i] = clip_grad(dz_i * z_i * (1.0 - z_i));
511                // Tanh backward: d_pre_h = d_h_tilde * (1 - h_tilde²)
512                d_pre_h[i] = clip_grad(d_h_tilde_i * (1.0 - h_tilde_i * h_tilde_i));
513
514                // Bias gradients.
515                self.grad_b_z[i] += d_pre_z[i];
516                self.grad_b_h[i] += d_pre_h[i];
517
518                // Input weight gradients: dW_z += d_pre_z[i] * x^T
519                let w_off = i * EMBED_DIM;
520                let lr_dpz = d_pre_z[i];
521                let lr_dph = d_pre_h[i];
522                for j in 0..EMBED_DIM {
523                    let xj = self.hist_x[x_base + j];
524                    self.grad_w_z[w_off + j] += lr_dpz * xj;
525                    self.grad_w_h[w_off + j] += lr_dph * xj;
526                }
527
528                // Recurrent weight gradients: dU_z += d_pre_z[i] * h_prev^T
529                // Also accumulate d_rh = (U_h^T @ d_pre_h) for the reset gate.
530                // Note: d_rh is NOT used for d_h_prev here — that uses
531                // sum_i(d_pre_h[i] * r[i] * U_h[i,j]) — see loop below.
532                let u_off = i * HIDDEN_DIM;
533                for j in 0..HIDDEN_DIM {
534                    let hj = self.hist_h_prev[h_base + j];
535                    self.grad_u_z[u_off + j] += d_pre_z[i] * hj;
536                    // dU_h[i,j] = d_pre_h[i] * r[i] * h_prev[j]
537                    self.grad_u_h[u_off + j] += d_pre_h[i] * r_i * hj;
538                }
539            }
540
541            // Reset gate backward.
542            // d_rh[j] = sum_i(d_pre_h[i] * U_h[i,j]) = (U_h^T @ d_pre_h)[j]
543            // dr[j] = d_rh[j] * h_prev[j]   (gradient w.r.t. r[j])
544            // d_pre_r[j] = dr[j] * r[j] * (1-r[j])   (sigmoid backward)
545            let mut d_rh = [0.0f32; HIDDEN_DIM];
546            for i in 0..HIDDEN_DIM {
547                let u_off = i * HIDDEN_DIM;
548                for j in 0..HIDDEN_DIM {
549                    d_rh[j] += d_pre_h[i] * self.u_h[u_off + j];
550                }
551            }
552            for j in 0..HIDDEN_DIM {
553                let dr = clip_grad(d_rh[j] * self.hist_h_prev[h_base + j]);
554                d_pre_r[j] =
555                    clip_grad(dr * self.hist_r[h_base + j] * (1.0 - self.hist_r[h_base + j]));
556                self.grad_b_r[j] += d_pre_r[j];
557            }
558            // Accumulate dW_r and dU_r.
559            for i in 0..HIDDEN_DIM {
560                let dp = d_pre_r[i];
561                let w_off = i * EMBED_DIM;
562                let u_off = i * HIDDEN_DIM;
563                for j in 0..EMBED_DIM {
564                    self.grad_w_r[w_off + j] += dp * self.hist_x[x_base + j];
565                }
566                for j in 0..HIDDEN_DIM {
567                    self.grad_u_r[u_off + j] += dp * self.hist_h_prev[h_base + j];
568                }
569            }
570
571            // Save step-0 gate gradients for the embedding update.
572            if step_back == 0 {
573                d_pre_z_s0.copy_from_slice(&d_pre_z);
574                d_pre_r_s0.copy_from_slice(&d_pre_r);
575                d_pre_h_s0.copy_from_slice(&d_pre_h);
576            }
577
578            // ── Propagate d_h to the previous step ───────────────────────────
579            // d_h_prev[j] = d_h[j] * (1 - z[j])          (direct path)
580            //             + sum_i(d_pre_z[i] * U_z[i,j])  (update gate path)
581            //             + sum_i(d_pre_r[i] * U_r[i,j])  (reset gate path)
582            //             + sum_i(d_pre_h[i] * r[i] * U_h[i,j])  (candidate path)
583            let mut d_h_prev = [0.0f32; HIDDEN_DIM];
584
585            // Direct path.
586            for j in 0..HIDDEN_DIM {
587                d_h_prev[j] = clip_grad(d_h[j]) * (1.0 - self.hist_z[h_base + j]);
588            }
589
590            // Gate recurrent paths: loop over hidden units i, accumulate into j.
591            for i in 0..HIDDEN_DIM {
592                let dpz = d_pre_z[i];
593                let dpr = d_pre_r[i];
594                // d_pre_h[i] * r[i] — used for the U_h candidate path.
595                let dph_r = d_pre_h[i] * self.hist_r[h_base + i];
596                let u_off = i * HIDDEN_DIM;
597                for j in 0..HIDDEN_DIM {
598                    d_h_prev[j] += dpz * self.u_z[u_off + j];
599                    d_h_prev[j] += dpr * self.u_r[u_off + j];
600                    d_h_prev[j] += dph_r * self.u_h[u_off + j];
601                }
602            }
603
604            // Clip for stability before passing to the next step.
605            for j in 0..HIDDEN_DIM {
606                d_h_prev[j] = clip_grad(d_h_prev[j]);
607            }
608            d_h.copy_from_slice(&d_h_prev);
609        }
610
611        // ─── Apply accumulated weight gradients in one shot ───────────────────
612        // Using current weights (not stored snapshots) for the update is standard
613        // "online BPTT" practice — the per-step weight change from LR=0.01 is
614        // small enough that the approximation is negligible.
615        for i in 0..HIDDEN_DIM {
616            let w_off = i * EMBED_DIM;
617            let u_off = i * HIDDEN_DIM;
618            for j in 0..EMBED_DIM {
619                self.w_z[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_z[w_off + j]);
620                self.w_r[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_r[w_off + j]);
621                self.w_h[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_h[w_off + j]);
622            }
623            for j in 0..HIDDEN_DIM {
624                self.u_z[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_z[u_off + j]);
625                self.u_r[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_r[u_off + j]);
626                self.u_h[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_h[u_off + j]);
627            }
628            self.b_z[i] -= LEARNING_RATE * clip_grad(self.grad_b_z[i]);
629            self.b_r[i] -= LEARNING_RATE * clip_grad(self.grad_b_r[i]);
630            self.b_h[i] -= LEARNING_RATE * clip_grad(self.grad_b_h[i]);
631        }
632
633        // ─── Embedding gradient (current step only) ───────────────────────────
634        // Update the embedding for the target byte using the step-0 gate gradients.
635        // Only the current step's input embedding is updated — historical embeddings
636        // are treated as fixed inputs in BPTT (standard practice for online learning).
637        let embed_start = target * EMBED_DIM;
638        for j in 0..EMBED_DIM {
639            let mut d_xj: f32 = 0.0;
640            for i in 0..HIDDEN_DIM {
641                let off = i * EMBED_DIM + j;
642                d_xj += d_pre_z_s0[i] * self.w_z[off];
643                d_xj += d_pre_r_s0[i] * self.w_r[off];
644                d_xj += d_pre_h_s0[i] * self.w_h[off];
645            }
646            self.embedding[embed_start + j] -= LEARNING_RATE * clip_grad(d_xj);
647        }
648    }
649}
650
651impl Default for GruModel {
652    fn default() -> Self {
653        Self::new()
654    }
655}
656
// ─── Activation functions ────────────────────────────────────────────────────
// CRITICAL: These must produce IDENTICAL results in encoder and decoder.
// The same f32 operations guarantee bit-exact results across both paths.
660
/// Sigmoid activation: 1 / (1 + exp(-x)), clamped to prevent overflow.
#[inline]
fn sigmoid(x: f32) -> f32 {
    // Past ±15 sigmoid is saturated anyway; clamping keeps exp() finite.
    let t = x.clamp(-15.0, 15.0);
    // recip() is exactly 1.0 / x, and f32 addition is commutative, so this is
    // bit-identical to 1.0 / (1.0 + (-t).exp()).
    ((-t).exp() + 1.0).recip()
}
667
668/// Tanh using the identity tanh(x) = 2*sigmoid(2x) - 1.
669/// Reusing sigmoid ensures tanh and sigmoid use the SAME exp() path.
670#[inline]
671fn tanh_approx(x: f32) -> f32 {
672    let x = x.clamp(-7.5, 7.5);
673    2.0 * sigmoid(2.0 * x) - 1.0
674}
675
676/// Clip gradient magnitude to prevent explosion during BPTT.
677#[inline]
678fn clip_grad(g: f32) -> f32 {
679    g.clamp(-GRAD_CLIP, GRAD_CLIP)
680}
681
/// Fill a weight slice with deterministic pseudo-random Xavier initialization.
///
/// Advances an xorshift64 PRNG through `seed`, so successive calls sharing one
/// seed produce one continuous, reproducible weight stream (encoder = decoder).
/// Each weight lands in [-scale, scale].
fn fill_xavier(weights: &mut [f32], scale: f32, seed: &mut u64) {
    // Work on a local copy of the PRNG state; write it back once at the end.
    let mut state = *seed;
    for w in weights.iter_mut() {
        // xorshift64 step — deterministic and fast.
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Map the 64-bit state onto [-1, 1], then apply the Xavier scale.
        let uniform = (state as f32 / u64::MAX as f32) * 2.0 - 1.0;
        *w = uniform * scale;
    }
    *seed = state;
}
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `predict_bit` context for bit position `bpos` of `byte`:
    /// a leading 1 bit followed by the already-seen high bits (MSB first).
    fn bit_context(byte: u8, bpos: u8) -> u32 {
        (0..bpos).fold(1u32, |ctx, prev| (ctx << 1) | u32::from((byte >> (7 - prev)) & 1))
    }

    #[test]
    fn sigmoid_basic() {
        // Midpoint and both saturation tails.
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(-15.0) < 0.001);
        assert!(sigmoid(15.0) > 0.999);
    }

    #[test]
    fn tanh_basic() {
        // Zero crossing and both saturation tails.
        assert!((tanh_approx(0.0)).abs() < 1e-6);
        assert!(tanh_approx(-7.0) < -0.99);
        assert!(tanh_approx(7.0) > 0.99);
    }

    #[test]
    fn deterministic_init() {
        // Two fresh models must share identical weights (fixed-seed init).
        let (a, b) = (GruModel::new(), GruModel::new());
        assert_eq!(a.embedding, b.embedding);
        assert_eq!(a.w_z, b.w_z);
        assert_eq!(a.w_o, b.w_o);
    }

    #[test]
    fn initial_predict_bit_uniform() {
        let model = GruModel::new();
        assert_eq!(
            model.predict_bit(0, 1),
            2048,
            "before any forward pass, should return 2048"
        );
    }

    #[test]
    fn forward_produces_valid_probs() {
        let mut model = GruModel::new();
        model.forward(b'A');
        // The output must be a proper distribution: non-negative, sums to ~1.
        let sum: f64 = model.byte_probs.iter().map(|&x| f64::from(x)).sum();
        assert!(
            (sum - 1.0).abs() < 0.01,
            "byte_probs should sum to ~1.0, got {sum}"
        );
        for &p in model.byte_probs.iter() {
            assert!(p >= 0.0, "negative probability: {p}");
        }
    }

    #[test]
    fn predict_bit_in_range() {
        let mut model = GruModel::new();
        model.forward(b'A');
        // Every bit position must map into the coder's open probability range.
        for bpos in 0..8u8 {
            let p = model.predict_bit(bpos, bit_context(b'B', bpos));
            assert!(
                (1..=4095).contains(&p),
                "predict_bit out of range at bpos {bpos}: {p}"
            );
        }
    }

    #[test]
    fn train_does_not_crash() {
        let mut model = GruModel::new();
        model.forward(b'A');
        model.train(b'B');
        // A forward pass after training must still yield a normalized distribution.
        model.forward(b'B');
        let sum: f64 = model.byte_probs.iter().map(|&x| f64::from(x)).sum();
        assert!(
            (sum - 1.0).abs() < 0.01,
            "probs after training should sum to ~1.0, got {sum}"
        );
    }

    #[test]
    fn history_ring_fills_correctly() {
        let mut model = GruModel::new();
        // No forward passes yet: the ring buffer starts empty.
        assert_eq!(model.hist_count, 0);

        // The count grows by one per step until it saturates at the horizon.
        for i in 0..BPTT_HORIZON + 3 {
            model.forward(b'A' + (i % 26) as u8);
            assert_eq!(
                model.hist_count,
                (i + 1).min(BPTT_HORIZON),
                "hist_count wrong at step {i}"
            );
        }
        // Three steps past the horizon, so the write cursor wrapped to 3.
        assert_eq!(model.hist_pos, 3);
    }

    #[test]
    fn bptt_does_not_produce_nan() {
        let mut model = GruModel::new();
        for &byte in b"Hello, World! This is a BPTT test. Let's check for NaN.".iter() {
            model.forward(byte);
            model.train(byte);
            // Both the hidden state and the output must stay finite each step.
            for j in 0..HIDDEN_DIM {
                assert!(!model.h[j].is_nan(), "hidden state has NaN at j={j}");
            }
            for &p in model.byte_probs.iter() {
                assert!(!p.is_nan(), "byte_probs has NaN");
            }
        }
    }

    #[test]
    fn encoder_decoder_identical() {
        // The codec relies on both sides computing bit-identical predictions.
        let mut enc = GruModel::new();
        let mut dec = GruModel::new();

        for &byte in b"Hello, World! Testing BPTT encoder-decoder parity.".iter() {
            enc.forward(byte);
            dec.forward(byte);

            // Every per-bit prediction must match exactly.
            for bpos in 0..8u8 {
                let ctx = bit_context(byte, bpos);
                assert_eq!(
                    enc.predict_bit(bpos, ctx),
                    dec.predict_bit(bpos, ctx),
                    "encoder/decoder diverged at bpos {bpos}"
                );
            }

            // Mirror the codec flow: identical training on both sides.
            enc.train(byte);
            dec.train(byte);
        }

        // Internal state must also remain in lockstep.
        assert_eq!(enc.h, dec.h, "hidden states diverged after training");
        assert_eq!(enc.hist_count, dec.hist_count, "hist_count diverged");
        assert_eq!(enc.hist_pos, dec.hist_pos, "hist_pos diverged");
    }

    #[test]
    fn bptt_improves_over_1step() {
        // A repeating "ab" stream should teach the model 'a' → 'b' quickly
        // when gradients flow back through the BPTT history.
        let mut model = GruModel::new();
        for &b in b"ab".repeat(200).iter() {
            model.train(b);
            model.forward(b);
        }
        model.train(b'a');
        model.forward(b'a');
        let p_b = model.byte_probs[b'b' as usize];
        assert!(
            p_b > 0.1,
            "after 'a' in ab pattern with BPTT, P('b')={p_b} should be significant"
        );
    }

    #[test]
    fn adapts_to_pattern() {
        // Longer exposure to the same "ab" pattern; checks online adaptation.
        let mut model = GruModel::new();
        for &b in b"ab".repeat(500).iter() {
            model.train(b);
            model.forward(b);
        }
        model.train(b'a');
        model.forward(b'a');
        let p_b = model.byte_probs[b'b' as usize];
        assert!(
            p_b > 0.1,
            "after 'a' in ab pattern, P('b')={p_b} should be significant"
        );
    }
}