datacortex_core/model/gru_model.rs
1//! GRU (Gated Recurrent Unit) byte-level predictor with truncated BPTT.
2//!
3//! A byte-level neural predictor providing a DIFFERENT signal from the bit-level
4//! CM engine. The GRU captures cross-byte sequential patterns via a recurrent
5//! hidden state trained with backpropagation through time (BPTT).
6//!
7//! Architecture:
8//! Input: one-hot byte embedding (256 → 32 via embedding matrix)
9//! GRU: 128 hidden cells, 1 layer
10//! Output: 128 → 256 linear → softmax → byte probabilities
11//!
12//! Training: truncated BPTT-10. At each byte completion, gradients propagate
13//! back through the last 10 steps of GRU history. This is the same strategy
14//! used by cmix (which uses BPTT-100) and gives the majority of the gain at
15//! 10% of the BPTT-100 cost.
16//!
17//! ~43K parameters (~170KB at f32). History + gradient buffers: ~260KB.
18//!
19//! CRITICAL: Encoder and decoder must maintain IDENTICAL GRU state.
20//! Both must call train(byte) then forward(byte) in the same order on the
21//! same bytes so that history buffers and weight updates are identical.
22
/// Width of the learned byte embedding: each of the 256 byte values maps to a
/// 32-dimensional input vector.
const EMBED_DIM: usize = 32;
/// Width of the recurrent hidden state (single GRU layer).
const HIDDEN_DIM: usize = 128;
/// Output alphabet size: one softmax logit per possible next byte.
const VOCAB_SIZE: usize = 256;

/// BPTT truncation horizon: backpropagate gradients through this many steps.
///
/// 10 steps captures most sequential byte patterns with manageable overhead.
/// cmix achieves -0.07 bpb using BPTT-100; 10 steps provides ~60-70% of that
/// gain at 1/10 the training cost (~5× total slowdown vs no-BPTT).
const BPTT_HORIZON: usize = 10;

/// Online SGD learning rate.
///
/// 0.01 is conservative — the GRU sees each byte only once (online learning).
/// Lower than typical offline training LR to avoid overshooting on rare bytes.
const LEARNING_RATE: f32 = 0.01;

/// Gradient clip magnitude.
///
/// BPTT through 10 steps can accumulate gradients. Clipping at 5.0 prevents
/// exploding gradients while preserving the direction of improvement.
const GRAD_CLIP: f32 = 5.0;
45
/// GRU byte-level predictor with BPTT-10 online training.
pub struct GruModel {
    // ─── Parameters ──────────────────────────────────────────────────────────
    /// Embedding matrix: [VOCAB_SIZE * EMBED_DIM]
    embedding: Vec<f32>,

    /// Update gate input weights: [HIDDEN_DIM * EMBED_DIM]
    w_z: Vec<f32>,
    /// Update gate recurrent weights: [HIDDEN_DIM * HIDDEN_DIM]
    u_z: Vec<f32>,
    /// Update gate biases: [HIDDEN_DIM]
    b_z: Vec<f32>,

    /// Reset gate input weights: [HIDDEN_DIM * EMBED_DIM]
    w_r: Vec<f32>,
    /// Reset gate recurrent weights: [HIDDEN_DIM * HIDDEN_DIM]
    u_r: Vec<f32>,
    /// Reset gate biases: [HIDDEN_DIM]
    b_r: Vec<f32>,

    /// Candidate hidden input weights: [HIDDEN_DIM * EMBED_DIM]
    w_h: Vec<f32>,
    /// Candidate hidden recurrent weights: [HIDDEN_DIM * HIDDEN_DIM]
    u_h: Vec<f32>,
    /// Candidate hidden biases: [HIDDEN_DIM]
    b_h: Vec<f32>,

    /// Output projection weights: [VOCAB_SIZE * HIDDEN_DIM]
    w_o: Vec<f32>,
    /// Output projection biases: [VOCAB_SIZE]
    b_o: Vec<f32>,

    // ─── Recurrent state ─────────────────────────────────────────────────────
    /// Hidden state: [HIDDEN_DIM]
    h: Vec<f32>,

    // ─── Cached forward pass values (for the most recent step) ───────────────
    /// Input embedding for the most recent forward step: [EMBED_DIM]
    last_x: Vec<f32>,
    /// Hidden state before the most recent forward step: [HIDDEN_DIM]
    last_h_prev: Vec<f32>,
    /// Update gate output from the most recent step: [HIDDEN_DIM]
    last_z: Vec<f32>,
    /// Reset gate output from the most recent step: [HIDDEN_DIM]
    last_r: Vec<f32>,
    /// Candidate hidden from the most recent step: [HIDDEN_DIM]
    last_h_tilde: Vec<f32>,

    /// Cached softmax probabilities over the next byte: [VOCAB_SIZE]
    byte_probs: Vec<f32>,
    /// Whether byte_probs has been computed for the current step.
    // NOTE(review): written by forward() but never read anywhere in this
    // file — confirm it is still needed before removing (removal would touch
    // the struct, new(), and forward() together).
    probs_valid: bool,
    /// Whether at least one byte has been processed (have valid hidden state).
    has_context: bool,

    // ─── BPTT history ring buffer ─────────────────────────────────────────────
    // Flat layout: entry at ring position `p` occupies [p*DIM .. (p+1)*DIM].
    // hist_pos is the next WRITE position (circular). After writing,
    // hist_pos = (hist_pos + 1) % BPTT_HORIZON.
    //
    /// Saved input embeddings: [BPTT_HORIZON * EMBED_DIM]
    hist_x: Vec<f32>,
    /// Saved h_prev (hidden before each step): [BPTT_HORIZON * HIDDEN_DIM]
    hist_h_prev: Vec<f32>,
    /// Saved update gate outputs: [BPTT_HORIZON * HIDDEN_DIM]
    hist_z: Vec<f32>,
    /// Saved reset gate outputs: [BPTT_HORIZON * HIDDEN_DIM]
    hist_r: Vec<f32>,
    /// Saved candidate hidden values: [BPTT_HORIZON * HIDDEN_DIM]
    hist_h_tilde: Vec<f32>,
    /// Next write position in the ring (0..BPTT_HORIZON).
    hist_pos: usize,
    /// Number of valid entries in the ring (0..=BPTT_HORIZON).
    hist_count: usize,

    // ─── Pre-allocated gradient accumulators ─────────────────────────────────
    // Zeroed at the start of each train() call and accumulated across all BPTT
    // steps before a single weight update. Stored in the struct to avoid
    // per-call heap allocation in the hot path.
    //
    /// Accumulated gradient for w_z: [HIDDEN_DIM * EMBED_DIM]
    grad_w_z: Vec<f32>,
    /// Accumulated gradient for u_z: [HIDDEN_DIM * HIDDEN_DIM]
    grad_u_z: Vec<f32>,
    /// Accumulated gradient for b_z: [HIDDEN_DIM]
    grad_b_z: Vec<f32>,
    /// Accumulated gradient for w_r: [HIDDEN_DIM * EMBED_DIM]
    grad_w_r: Vec<f32>,
    /// Accumulated gradient for u_r: [HIDDEN_DIM * HIDDEN_DIM]
    grad_u_r: Vec<f32>,
    /// Accumulated gradient for b_r: [HIDDEN_DIM]
    grad_b_r: Vec<f32>,
    /// Accumulated gradient for w_h: [HIDDEN_DIM * EMBED_DIM]
    grad_w_h: Vec<f32>,
    /// Accumulated gradient for u_h: [HIDDEN_DIM * HIDDEN_DIM]
    grad_u_h: Vec<f32>,
    /// Accumulated gradient for b_h: [HIDDEN_DIM]
    grad_b_h: Vec<f32>,
}
145
146impl GruModel {
147 /// Create a new GRU model with Xavier-initialized weights and zeroed buffers.
148 pub fn new() -> Self {
149 let mut model = GruModel {
150 embedding: vec![0.0; VOCAB_SIZE * EMBED_DIM],
151 w_z: vec![0.0; HIDDEN_DIM * EMBED_DIM],
152 u_z: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
153 b_z: vec![0.0; HIDDEN_DIM],
154 w_r: vec![0.0; HIDDEN_DIM * EMBED_DIM],
155 u_r: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
156 b_r: vec![0.0; HIDDEN_DIM],
157 w_h: vec![0.0; HIDDEN_DIM * EMBED_DIM],
158 u_h: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
159 b_h: vec![0.0; HIDDEN_DIM],
160 w_o: vec![0.0; VOCAB_SIZE * HIDDEN_DIM],
161 b_o: vec![0.0; VOCAB_SIZE],
162 h: vec![0.0; HIDDEN_DIM],
163 last_x: vec![0.0; EMBED_DIM],
164 last_h_prev: vec![0.0; HIDDEN_DIM],
165 last_z: vec![0.0; HIDDEN_DIM],
166 last_r: vec![0.0; HIDDEN_DIM],
167 last_h_tilde: vec![0.0; HIDDEN_DIM],
168 byte_probs: vec![1.0 / VOCAB_SIZE as f32; VOCAB_SIZE],
169 probs_valid: false,
170 has_context: false,
171 // History ring buffers — zeroed.
172 hist_x: vec![0.0; BPTT_HORIZON * EMBED_DIM],
173 hist_h_prev: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
174 hist_z: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
175 hist_r: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
176 hist_h_tilde: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
177 hist_pos: 0,
178 hist_count: 0,
179 // Gradient accumulators — zeroed (will be explicitly zeroed each train() call).
180 grad_w_z: vec![0.0; HIDDEN_DIM * EMBED_DIM],
181 grad_u_z: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
182 grad_b_z: vec![0.0; HIDDEN_DIM],
183 grad_w_r: vec![0.0; HIDDEN_DIM * EMBED_DIM],
184 grad_u_r: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
185 grad_b_r: vec![0.0; HIDDEN_DIM],
186 grad_w_h: vec![0.0; HIDDEN_DIM * EMBED_DIM],
187 grad_u_h: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
188 grad_b_h: vec![0.0; HIDDEN_DIM],
189 };
190 model.init_weights();
191 model
192 }
193
194 /// Initialize weights using a deterministic pseudo-random scheme.
195 /// Xavier/Glorot initialization scaled by fan_in + fan_out.
196 fn init_weights(&mut self) {
197 // Deterministic PRNG for reproducibility (encoder = decoder).
198 let mut seed: u64 = 0xDEAD_BEEF_CAFE_1234;
199
200 // Xavier scale for embedding: sqrt(2 / (256 + 32))
201 let embed_scale = (2.0 / (VOCAB_SIZE + EMBED_DIM) as f32).sqrt();
202 fill_xavier(&mut self.embedding, embed_scale, &mut seed);
203
204 // Xavier scale for input weights: sqrt(2 / (32 + 128))
205 let wx_scale = (2.0 / (EMBED_DIM + HIDDEN_DIM) as f32).sqrt();
206 fill_xavier(&mut self.w_z, wx_scale, &mut seed);
207 fill_xavier(&mut self.w_r, wx_scale, &mut seed);
208 fill_xavier(&mut self.w_h, wx_scale, &mut seed);
209
210 // Xavier scale for recurrent weights: sqrt(2 / (128 + 128))
211 let uh_scale = (2.0 / (HIDDEN_DIM + HIDDEN_DIM) as f32).sqrt();
212 fill_xavier(&mut self.u_z, uh_scale, &mut seed);
213 fill_xavier(&mut self.u_r, uh_scale, &mut seed);
214 fill_xavier(&mut self.u_h, uh_scale, &mut seed);
215
216 // Xavier scale for output weights: sqrt(2 / (128 + 256))
217 let wo_scale = (2.0 / (HIDDEN_DIM + VOCAB_SIZE) as f32).sqrt();
218 fill_xavier(&mut self.w_o, wo_scale, &mut seed);
219
220 // Bias update gate to slightly positive so it starts in "remember" mode
221 // (z → 1 means keep old hidden state). Helps gradient flow early in training.
222 for b in self.b_z.iter_mut() {
223 *b = 1.0;
224 }
225 // Reset gate and candidate biases stay at 0.
226 }
227
    /// Forward pass: process one byte, update hidden state, compute output probs.
    ///
    /// Call this with the byte that was just OBSERVED. The resulting byte_probs
    /// predict the NEXT byte. After forward(), call predict_bit() to get bit
    /// probabilities for the next byte.
    ///
    /// Also saves the step into the BPTT history ring for train() to use.
    ///
    /// Parity-critical: every floating-point operation here must execute in the
    /// same order on encoder and decoder; do not reorder statements.
    #[inline(never)]
    #[allow(
        clippy::needless_range_loop,
        reason = "matrix ops are clearer with explicit indices"
    )]
    pub fn forward(&mut self, byte: u8) {
        // Save previous hidden state for backprop.
        self.last_h_prev.copy_from_slice(&self.h);

        // Get embedding for the input byte (row lookup).
        let byte_idx = byte as usize;
        let embed_start = byte_idx * EMBED_DIM;
        self.last_x
            .copy_from_slice(&self.embedding[embed_start..embed_start + EMBED_DIM]);

        // Compute update gate z, reset gate r, and candidate h_tilde in a fused
        // loop for cache locality.
        for i in 0..HIDDEN_DIM {
            let w_off = i * EMBED_DIM;
            let wz_row = &self.w_z[w_off..w_off + EMBED_DIM];
            let wr_row = &self.w_r[w_off..w_off + EMBED_DIM];
            let wh_row = &self.w_h[w_off..w_off + EMBED_DIM];

            let mut val_z = self.b_z[i];
            let mut val_r = self.b_r[i];
            let mut val_h = self.b_h[i];

            // W @ x for all three gates.
            for j in 0..EMBED_DIM {
                let xj = self.last_x[j];
                val_z += wz_row[j] * xj;
                val_r += wr_row[j] * xj;
                val_h += wh_row[j] * xj;
            }

            // U_z @ h_{t-1} and U_r @ h_{t-1}
            let u_off = i * HIDDEN_DIM;
            let uz_row = &self.u_z[u_off..u_off + HIDDEN_DIM];
            let ur_row = &self.u_r[u_off..u_off + HIDDEN_DIM];

            for j in 0..HIDDEN_DIM {
                let hj = self.last_h_prev[j];
                val_z += uz_row[j] * hj;
                val_r += ur_row[j] * hj;
            }

            let z_i = sigmoid(val_z);
            let r_i = sigmoid(val_r);
            self.last_z[i] = z_i;
            self.last_r[i] = r_i;

            // Candidate pre-activation uses r[i] on the OUTPUT dimension:
            // pre_h[i] = … + r[i] * (U_h @ h_{t-1})[i]. This differs from the
            // standard GRU (which gates h_prev elementwise with r[j] BEFORE the
            // matmul); train()'s backward pass must follow this same convention.
            let uh_row = &self.u_h[u_off..u_off + HIDDEN_DIM];
            for j in 0..HIDDEN_DIM {
                val_h += uh_row[j] * (r_i * self.last_h_prev[j]);
            }

            let h_tilde_i = tanh_approx(val_h);
            self.last_h_tilde[i] = h_tilde_i;

            // h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde
            // (reads last_h_prev, not self.h, so in-place update is safe)
            self.h[i] = (1.0 - z_i) * self.last_h_prev[i] + z_i * h_tilde_i;
        }

        // Compute output probabilities.
        self.compute_output_probs();
        self.probs_valid = true;
        self.has_context = true;

        // ─── Save this step into the BPTT history ring ───────────────────────
        let x_base = self.hist_pos * EMBED_DIM;
        self.hist_x[x_base..x_base + EMBED_DIM].copy_from_slice(&self.last_x);

        let h_base = self.hist_pos * HIDDEN_DIM;
        self.hist_h_prev[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_h_prev);
        self.hist_z[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_z);
        self.hist_r[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_r);
        self.hist_h_tilde[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_h_tilde);

        // Advance circular write head.
        self.hist_pos = (self.hist_pos + 1) % BPTT_HORIZON;
        if self.hist_count < BPTT_HORIZON {
            self.hist_count += 1;
        }
    }
320
321 /// Compute softmax output probabilities from current hidden state.
322 #[inline(never)]
323 #[allow(
324 clippy::needless_range_loop,
325 reason = "matrix ops are clearer with explicit indices"
326 )]
327 fn compute_output_probs(&mut self) {
328 let mut max_logit: f32 = f32::NEG_INFINITY;
329 for i in 0..VOCAB_SIZE {
330 let w_row = &self.w_o[i * HIDDEN_DIM..(i + 1) * HIDDEN_DIM];
331 let mut logit = self.b_o[i];
332 for j in 0..HIDDEN_DIM {
333 logit += w_row[j] * self.h[j];
334 }
335 self.byte_probs[i] = logit;
336 if logit > max_logit {
337 max_logit = logit;
338 }
339 }
340
341 // Numerically stable softmax: subtract max before exp.
342 let mut sum: f32 = 0.0;
343 for p in self.byte_probs.iter_mut() {
344 let e = (*p - max_logit).exp();
345 *p = e;
346 sum += e;
347 }
348
349 // Normalize with epsilon guard.
350 let inv_sum = 1.0 / (sum + 1e-30);
351 for p in self.byte_probs.iter_mut() {
352 *p *= inv_sum;
353 // Clamp to avoid log(0) in training.
354 if *p < 1e-8 {
355 *p = 1e-8;
356 }
357 }
358 }
359
360 /// Convert byte probabilities to a bit prediction for the CM MetaMixer.
361 ///
362 /// `bpos`: bit position 0-7 (0 = MSB).
363 /// `c0`: partial byte being built (starts at 1, accumulates bits MSB-first).
364 ///
365 /// Returns: 12-bit probability [1, 4095] of next bit being 1.
366 #[inline]
367 pub fn predict_bit(&self, bpos: u8, c0: u32) -> u32 {
368 if !self.has_context {
369 return 2048; // Uniform before first byte.
370 }
371
372 let bit_pos = 7 - bpos;
373 let mask = 1u8 << bit_pos;
374
375 let mut sum_one: f64 = 0.0;
376 let mut sum_zero: f64 = 0.0;
377
378 if bpos == 0 {
379 // No bits decoded yet — sum over all 256.
380 for b in 0..VOCAB_SIZE {
381 let p = self.byte_probs[b] as f64;
382 if (b as u8) & mask != 0 {
383 sum_one += p;
384 } else {
385 sum_zero += p;
386 }
387 }
388 } else {
389 // Some bits decoded. Only consider bytes matching the partial prefix.
390 let partial = (c0 & ((1u32 << bpos) - 1)) as u8;
391 let shift = 8 - bpos;
392 let base = (partial as usize) << shift;
393 let count = 1usize << shift;
394
395 for i in 0..count {
396 let b = base | i;
397 let p = self.byte_probs[b] as f64;
398 if (b as u8) & mask != 0 {
399 sum_one += p;
400 } else {
401 sum_zero += p;
402 }
403 }
404 }
405
406 let total = sum_one + sum_zero;
407 if total < 1e-15 {
408 return 2048;
409 }
410
411 let p = ((sum_one * 4096.0) / total) as u32;
412 p.clamp(1, 4095)
413 }
414
415 /// Online training with truncated BPTT.
416 ///
417 /// Computes the output-layer gradient for `actual_byte`, then propagates
418 /// d_h backwards through up to BPTT_HORIZON stored steps. Gradients are
419 /// accumulated into pre-allocated buffers and applied in a single weight
420 /// update at the end. This avoids the weight-corruption from immediate
421 /// per-step updates that would otherwise occur in BPTT.
422 ///
423 /// Call this BEFORE forward(actual_byte) (matches codec.rs flow). Both
424 /// encoder and decoder see the same byte sequence so their weights and
425 /// history buffers evolve identically — parity is preserved.
426 #[inline(never)]
427 #[allow(
428 clippy::needless_range_loop,
429 reason = "matrix ops are clearer with explicit indices"
430 )]
431 pub fn train(&mut self, actual_byte: u8) {
432 if !self.has_context {
433 return;
434 }
435
436 let target = actual_byte as usize;
437
438 // ─── Output layer: update W_o, b_o and compute initial d_h ──────────
439 // Cross-entropy + softmax gradient: d_logits[i] = probs[i] - (i==target).
440 let mut d_h = [0.0f32; HIDDEN_DIM];
441
442 for i in 0..VOCAB_SIZE {
443 let dl = clip_grad(self.byte_probs[i] - if i == target { 1.0 } else { 0.0 });
444 if dl.abs() < 1e-7 {
445 // Skip near-zero gradient — most of the 256 outputs.
446 continue;
447 }
448 let w_row = &mut self.w_o[i * HIDDEN_DIM..(i + 1) * HIDDEN_DIM];
449 let lr_dl = LEARNING_RATE * dl;
450 for j in 0..HIDDEN_DIM {
451 d_h[j] += dl * w_row[j];
452 w_row[j] -= lr_dl * self.h[j];
453 }
454 self.b_o[i] -= LEARNING_RATE * dl;
455 }
456
457 // ─── Zero gradient accumulators ──────────────────────────────────────
458 self.grad_w_z.fill(0.0);
459 self.grad_u_z.fill(0.0);
460 self.grad_b_z.fill(0.0);
461 self.grad_w_r.fill(0.0);
462 self.grad_u_r.fill(0.0);
463 self.grad_b_r.fill(0.0);
464 self.grad_w_h.fill(0.0);
465 self.grad_u_h.fill(0.0);
466 self.grad_b_h.fill(0.0);
467
468 // ─── BPTT: propagate d_h backwards through history ───────────────────
469 // Iterate from most-recent (step_back=0) to oldest (step_back=steps-1).
470 // At each step, accumulate weight gradients and compute d_h_prev to
471 // pass to the step before it.
472 let steps = self.hist_count;
473
474 // Gate gradients at step 0 — needed for the embedding update below.
475 let mut d_pre_z_s0 = [0.0f32; HIDDEN_DIM];
476 let mut d_pre_r_s0 = [0.0f32; HIDDEN_DIM];
477 let mut d_pre_h_s0 = [0.0f32; HIDDEN_DIM];
478
479 for step_back in 0..steps {
480 // Read from ring: most-recent step is at hist_pos - 1 (mod BPTT_HORIZON).
481 let ring_idx = (self.hist_pos + BPTT_HORIZON - 1 - step_back) % BPTT_HORIZON;
482
483 let x_base = ring_idx * EMBED_DIM;
484 let h_base = ring_idx * HIDDEN_DIM;
485
486 // ── GRU cell backward (one step) ─────────────────────────────────
487 // Given d_h (gradient of loss w.r.t. h_t), compute:
488 // d_h_tilde, d_pre_z, d_pre_h (upstream gradients through h_t)
489 // d_pre_r (upstream gradient through the reset gate path)
490 // Accumulate dW, dU, db for all three gates.
491 // Compute d_h_prev to pass to the previous step.
492
493 let mut d_pre_z = [0.0f32; HIDDEN_DIM];
494 let mut d_pre_r = [0.0f32; HIDDEN_DIM];
495 let mut d_pre_h = [0.0f32; HIDDEN_DIM];
496
497 for i in 0..HIDDEN_DIM {
498 let dhi = clip_grad(d_h[i]);
499 let z_i = self.hist_z[h_base + i];
500 let r_i = self.hist_r[h_base + i];
501 let h_tilde_i = self.hist_h_tilde[h_base + i];
502 let h_prev_i = self.hist_h_prev[h_base + i];
503
504 // h_t = (1-z)*h_prev + z*h_tilde ⟹ dh_tilde = dh * z
505 let d_h_tilde_i = dhi * z_i;
506 // dz = dh * (h_tilde - h_prev)
507 let dz_i = dhi * (h_tilde_i - h_prev_i);
508
509 // Sigmoid backward: d_pre_z = dz * z * (1-z)
510 d_pre_z[i] = clip_grad(dz_i * z_i * (1.0 - z_i));
511 // Tanh backward: d_pre_h = d_h_tilde * (1 - h_tilde²)
512 d_pre_h[i] = clip_grad(d_h_tilde_i * (1.0 - h_tilde_i * h_tilde_i));
513
514 // Bias gradients.
515 self.grad_b_z[i] += d_pre_z[i];
516 self.grad_b_h[i] += d_pre_h[i];
517
518 // Input weight gradients: dW_z += d_pre_z[i] * x^T
519 let w_off = i * EMBED_DIM;
520 let lr_dpz = d_pre_z[i];
521 let lr_dph = d_pre_h[i];
522 for j in 0..EMBED_DIM {
523 let xj = self.hist_x[x_base + j];
524 self.grad_w_z[w_off + j] += lr_dpz * xj;
525 self.grad_w_h[w_off + j] += lr_dph * xj;
526 }
527
528 // Recurrent weight gradients: dU_z += d_pre_z[i] * h_prev^T
529 // Also accumulate d_rh = (U_h^T @ d_pre_h) for the reset gate.
530 // Note: d_rh is NOT used for d_h_prev here — that uses
531 // sum_i(d_pre_h[i] * r[i] * U_h[i,j]) — see loop below.
532 let u_off = i * HIDDEN_DIM;
533 for j in 0..HIDDEN_DIM {
534 let hj = self.hist_h_prev[h_base + j];
535 self.grad_u_z[u_off + j] += d_pre_z[i] * hj;
536 // dU_h[i,j] = d_pre_h[i] * r[i] * h_prev[j]
537 self.grad_u_h[u_off + j] += d_pre_h[i] * r_i * hj;
538 }
539 }
540
541 // Reset gate backward.
542 // d_rh[j] = sum_i(d_pre_h[i] * U_h[i,j]) = (U_h^T @ d_pre_h)[j]
543 // dr[j] = d_rh[j] * h_prev[j] (gradient w.r.t. r[j])
544 // d_pre_r[j] = dr[j] * r[j] * (1-r[j]) (sigmoid backward)
545 let mut d_rh = [0.0f32; HIDDEN_DIM];
546 for i in 0..HIDDEN_DIM {
547 let u_off = i * HIDDEN_DIM;
548 for j in 0..HIDDEN_DIM {
549 d_rh[j] += d_pre_h[i] * self.u_h[u_off + j];
550 }
551 }
552 for j in 0..HIDDEN_DIM {
553 let dr = clip_grad(d_rh[j] * self.hist_h_prev[h_base + j]);
554 d_pre_r[j] =
555 clip_grad(dr * self.hist_r[h_base + j] * (1.0 - self.hist_r[h_base + j]));
556 self.grad_b_r[j] += d_pre_r[j];
557 }
558 // Accumulate dW_r and dU_r.
559 for i in 0..HIDDEN_DIM {
560 let dp = d_pre_r[i];
561 let w_off = i * EMBED_DIM;
562 let u_off = i * HIDDEN_DIM;
563 for j in 0..EMBED_DIM {
564 self.grad_w_r[w_off + j] += dp * self.hist_x[x_base + j];
565 }
566 for j in 0..HIDDEN_DIM {
567 self.grad_u_r[u_off + j] += dp * self.hist_h_prev[h_base + j];
568 }
569 }
570
571 // Save step-0 gate gradients for the embedding update.
572 if step_back == 0 {
573 d_pre_z_s0.copy_from_slice(&d_pre_z);
574 d_pre_r_s0.copy_from_slice(&d_pre_r);
575 d_pre_h_s0.copy_from_slice(&d_pre_h);
576 }
577
578 // ── Propagate d_h to the previous step ───────────────────────────
579 // d_h_prev[j] = d_h[j] * (1 - z[j]) (direct path)
580 // + sum_i(d_pre_z[i] * U_z[i,j]) (update gate path)
581 // + sum_i(d_pre_r[i] * U_r[i,j]) (reset gate path)
582 // + sum_i(d_pre_h[i] * r[i] * U_h[i,j]) (candidate path)
583 let mut d_h_prev = [0.0f32; HIDDEN_DIM];
584
585 // Direct path.
586 for j in 0..HIDDEN_DIM {
587 d_h_prev[j] = clip_grad(d_h[j]) * (1.0 - self.hist_z[h_base + j]);
588 }
589
590 // Gate recurrent paths: loop over hidden units i, accumulate into j.
591 for i in 0..HIDDEN_DIM {
592 let dpz = d_pre_z[i];
593 let dpr = d_pre_r[i];
594 // d_pre_h[i] * r[i] — used for the U_h candidate path.
595 let dph_r = d_pre_h[i] * self.hist_r[h_base + i];
596 let u_off = i * HIDDEN_DIM;
597 for j in 0..HIDDEN_DIM {
598 d_h_prev[j] += dpz * self.u_z[u_off + j];
599 d_h_prev[j] += dpr * self.u_r[u_off + j];
600 d_h_prev[j] += dph_r * self.u_h[u_off + j];
601 }
602 }
603
604 // Clip for stability before passing to the next step.
605 for j in 0..HIDDEN_DIM {
606 d_h_prev[j] = clip_grad(d_h_prev[j]);
607 }
608 d_h.copy_from_slice(&d_h_prev);
609 }
610
611 // ─── Apply accumulated weight gradients in one shot ───────────────────
612 // Using current weights (not stored snapshots) for the update is standard
613 // "online BPTT" practice — the per-step weight change from LR=0.01 is
614 // small enough that the approximation is negligible.
615 for i in 0..HIDDEN_DIM {
616 let w_off = i * EMBED_DIM;
617 let u_off = i * HIDDEN_DIM;
618 for j in 0..EMBED_DIM {
619 self.w_z[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_z[w_off + j]);
620 self.w_r[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_r[w_off + j]);
621 self.w_h[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_h[w_off + j]);
622 }
623 for j in 0..HIDDEN_DIM {
624 self.u_z[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_z[u_off + j]);
625 self.u_r[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_r[u_off + j]);
626 self.u_h[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_h[u_off + j]);
627 }
628 self.b_z[i] -= LEARNING_RATE * clip_grad(self.grad_b_z[i]);
629 self.b_r[i] -= LEARNING_RATE * clip_grad(self.grad_b_r[i]);
630 self.b_h[i] -= LEARNING_RATE * clip_grad(self.grad_b_h[i]);
631 }
632
633 // ─── Embedding gradient (current step only) ───────────────────────────
634 // Update the embedding for the target byte using the step-0 gate gradients.
635 // Only the current step's input embedding is updated — historical embeddings
636 // are treated as fixed inputs in BPTT (standard practice for online learning).
637 let embed_start = target * EMBED_DIM;
638 for j in 0..EMBED_DIM {
639 let mut d_xj: f32 = 0.0;
640 for i in 0..HIDDEN_DIM {
641 let off = i * EMBED_DIM + j;
642 d_xj += d_pre_z_s0[i] * self.w_z[off];
643 d_xj += d_pre_r_s0[i] * self.w_r[off];
644 d_xj += d_pre_h_s0[i] * self.w_h[off];
645 }
646 self.embedding[embed_start + j] -= LEARNING_RATE * clip_grad(d_xj);
647 }
648 }
649}
650
impl Default for GruModel {
    /// Equivalent to [`GruModel::new`]: deterministic Xavier-initialized weights.
    fn default() -> Self {
        Self::new()
    }
}
656
657// ─── Activation functions ────────────────────────────────────────────────────
658// CRITICAL: These must produce IDENTICAL results in encoder and decoder.
659// The same f32 operations guarantee bit-exact results across both paths.
660
/// Sigmoid activation: 1 / (1 + exp(-x)).
///
/// The argument is clamped to ±15 so exp() can neither overflow nor underflow;
/// past that range the output is already saturated to within ~1e-6 of 0 or 1.
/// CRITICAL: encoder and decoder must get bit-identical results, so the
/// operation order here must not change.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let clamped = x.clamp(-15.0, 15.0);
    let exp_neg = (-clamped).exp();
    1.0 / (1.0 + exp_neg)
}
667
668/// Tanh using the identity tanh(x) = 2*sigmoid(2x) - 1.
669/// Reusing sigmoid ensures tanh and sigmoid use the SAME exp() path.
670#[inline]
671fn tanh_approx(x: f32) -> f32 {
672 let x = x.clamp(-7.5, 7.5);
673 2.0 * sigmoid(2.0 * x) - 1.0
674}
675
/// Clip gradient magnitude to prevent explosion during BPTT.
///
/// Symmetric clamp to ±GRAD_CLIP (5.0), applied per element wherever gradients
/// are produced or propagated in train(). Uses f32::clamp so NaN inputs pass
/// through unchanged rather than being silently replaced.
#[inline]
fn clip_grad(g: f32) -> f32 {
    g.clamp(-GRAD_CLIP, GRAD_CLIP)
}
681
/// Fill a weight slice with deterministic pseudo-random Xavier initialization.
///
/// Advances an xorshift64 PRNG through `seed`, so consecutive calls continue
/// the same stream. Each weight is drawn uniformly (at f32 precision) from
/// [-scale, scale]. Determinism is load-bearing: encoder and decoder must
/// construct identical weight tensors from the same starting seed.
fn fill_xavier(weights: &mut [f32], scale: f32, seed: &mut u64) {
    for slot in weights.iter_mut() {
        // xorshift64: three shift-xor steps (Marsaglia); full period for
        // any nonzero seed, and trivially reproducible across platforms.
        let mut s = *seed;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        *seed = s;
        // Map the 64-bit state to [-1, 1], then apply the Xavier scale.
        let unit = s as f32 / u64::MAX as f32;
        *slot = (unit * 2.0 - 1.0) * scale;
    }
}
693
#[cfg(test)]
mod tests {
    use super::*;

    // Activation sanity: midpoint and saturation behavior.
    #[test]
    fn sigmoid_basic() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(15.0) > 0.999);
        assert!(sigmoid(-15.0) < 0.001);
    }

    #[test]
    fn tanh_basic() {
        assert!((tanh_approx(0.0)).abs() < 1e-6);
        assert!(tanh_approx(7.0) > 0.99);
        assert!(tanh_approx(-7.0) < -0.99);
    }

    // Two fresh models must have bit-identical weights (fixed PRNG seed) —
    // this is the foundation of encoder/decoder parity.
    #[test]
    fn deterministic_init() {
        let m1 = GruModel::new();
        let m2 = GruModel::new();
        assert_eq!(m1.embedding, m2.embedding);
        assert_eq!(m1.w_z, m2.w_z);
        assert_eq!(m1.w_o, m2.w_o);
    }

    #[test]
    fn initial_predict_bit_uniform() {
        let model = GruModel::new();
        let p = model.predict_bit(0, 1);
        assert_eq!(p, 2048, "before any forward pass, should return 2048");
    }

    // byte_probs is a softmax output (plus a 1e-8 floor), so it should sum
    // to ~1 and contain no negatives.
    #[test]
    fn forward_produces_valid_probs() {
        let mut model = GruModel::new();
        model.forward(b'A');
        let sum: f64 = model.byte_probs.iter().map(|&p| p as f64).sum();
        assert!(
            (sum - 1.0).abs() < 0.01,
            "byte_probs should sum to ~1.0, got {sum}"
        );
        for &p in &model.byte_probs {
            assert!(p >= 0.0, "negative probability: {p}");
        }
    }

    // Exercise all 8 bit positions with a realistic c0 (leading marker 1,
    // then MSB-first bits of b'B') and check the 12-bit output range.
    #[test]
    fn predict_bit_in_range() {
        let mut model = GruModel::new();
        model.forward(b'A');
        for bpos in 0..8u8 {
            let c0 = if bpos == 0 {
                1u32
            } else {
                let mut p = 1u32;
                for prev in 0..bpos {
                    p = (p << 1) | ((b'B' >> (7 - prev)) & 1) as u32;
                }
                p
            };
            let p = model.predict_bit(bpos, c0);
            assert!(
                (1..=4095).contains(&p),
                "predict_bit out of range at bpos {bpos}: {p}"
            );
        }
    }

    #[test]
    fn train_does_not_crash() {
        let mut model = GruModel::new();
        model.forward(b'A');
        model.train(b'B');
        // Should still produce valid output.
        model.forward(b'B');
        let sum: f64 = model.byte_probs.iter().map(|&p| p as f64).sum();
        assert!(
            (sum - 1.0).abs() < 0.01,
            "probs after training should sum to ~1.0, got {sum}"
        );
    }

    // Verifies the ring-buffer bookkeeping: count saturates at BPTT_HORIZON
    // and the write head wraps.
    #[test]
    fn history_ring_fills_correctly() {
        let mut model = GruModel::new();
        // Before any forward, history is empty.
        assert_eq!(model.hist_count, 0);

        // After N forward passes, history fills up to BPTT_HORIZON.
        for i in 0..BPTT_HORIZON + 3 {
            model.forward(b'A' + (i % 26) as u8);
            let expected = (i + 1).min(BPTT_HORIZON);
            assert_eq!(model.hist_count, expected, "hist_count wrong at step {i}");
        }
        // hist_pos should have wrapped.
        assert_eq!(model.hist_pos, 3);
    }

    #[test]
    fn bptt_does_not_produce_nan() {
        let mut model = GruModel::new();
        let data = b"Hello, World! This is a BPTT test. Let's check for NaN.";
        for &byte in data {
            model.forward(byte);
            model.train(byte);
            for j in 0..HIDDEN_DIM {
                assert!(!model.h[j].is_nan(), "hidden state has NaN at j={j}");
            }
            for &p in &model.byte_probs {
                assert!(!p.is_nan(), "byte_probs has NaN");
            }
        }
    }

    #[test]
    fn encoder_decoder_identical() {
        // Encoder and decoder must produce bit-identical predictions throughout.
        let mut enc = GruModel::new();
        let mut dec = GruModel::new();
        let data = b"Hello, World! Testing BPTT encoder-decoder parity.";

        for &byte in data {
            enc.forward(byte);
            dec.forward(byte);

            // Predictions must match exactly.
            for bpos in 0..8u8 {
                let c0 = if bpos == 0 {
                    1u32
                } else {
                    let mut p = 1u32;
                    for prev in 0..bpos {
                        p = (p << 1) | ((byte >> (7 - prev)) & 1) as u32;
                    }
                    p
                };
                let pe = enc.predict_bit(bpos, c0);
                let pd = dec.predict_bit(bpos, c0);
                assert_eq!(pe, pd, "encoder/decoder diverged at bpos {bpos}");
            }

            // Train both with the same byte (same as codec flow).
            enc.train(byte);
            dec.train(byte);
        }

        // After training, hidden states must match.
        assert_eq!(enc.h, dec.h, "hidden states diverged after training");
        // History ring must match.
        assert_eq!(enc.hist_count, dec.hist_count, "hist_count diverged");
        assert_eq!(enc.hist_pos, dec.hist_pos, "hist_pos diverged");
    }

    #[test]
    fn bptt_improves_over_1step() {
        // BPTT-trained model should learn patterns faster than 1-step SGD.
        // Test: repeated pattern "ab" — after BPTT training, should predict 'b'
        // after 'a' with high confidence.
        // NOTE(review): despite the name, this test does not actually compare
        // against a 1-step baseline — it only checks an absolute threshold,
        // same as adapts_to_pattern below (with fewer repetitions).
        let mut model = GruModel::new();
        let pattern: Vec<u8> = b"ab".repeat(200);
        for &byte in &pattern {
            model.train(byte);
            model.forward(byte);
        }
        // After 'a' in the pattern, next byte should be 'b'.
        model.train(b'a');
        model.forward(b'a');
        let p_b = model.byte_probs[b'b' as usize];
        assert!(
            p_b > 0.1,
            "after 'a' in ab pattern with BPTT, P('b')={p_b} should be significant"
        );
    }

    #[test]
    fn adapts_to_pattern() {
        let mut model = GruModel::new();
        let pattern: Vec<u8> = b"ab".repeat(500);
        for &byte in &pattern {
            model.train(byte);
            model.forward(byte);
        }
        model.train(b'a');
        model.forward(b'a');
        let p_b = model.byte_probs[b'b' as usize];
        assert!(
            p_b > 0.1,
            "after 'a' in ab pattern, P('b')={p_b} should be significant"
        );
    }
}
886}