/// Length of the embedding vector stored for each byte value.
const EMBED_DIM: usize = 32;
/// Number of units in the recurrent hidden state.
const HIDDEN_DIM: usize = 128;
/// Number of distinct symbols: one embedding row and one output logit per byte value.
const VOCAB_SIZE: usize = 256;
/// Number of forward steps kept in the history ring, and therefore the
/// maximum number of steps `train` walks back through.
const BPTT_HORIZON: usize = 10;
/// Step size applied to every (clipped) gradient during `train`.
const LEARNING_RATE: f32 = 0.01;
/// Symmetric bound used by `clip_grad`: gradients are clamped to `[-GRAD_CLIP, GRAD_CLIP]`.
const GRAD_CLIP: f32 = 5.0;
/// Byte-level GRU predictor that is trained online, one byte at a time.
///
/// For every hidden unit `i`, [`GruModel::forward`] evaluates
///
/// ```text
/// z_i  = sigmoid(b_z[i] + W_z[i]·x + U_z[i]·h_prev)
/// r_i  = sigmoid(b_r[i] + W_r[i]·x + U_r[i]·h_prev)
/// h~_i = tanh   (b_h[i] + W_h[i]·x + r_i * (U_h[i]·h_prev))
/// h_i  = (1 - z_i) * h_prev_i + z_i * h~_i
/// ```
///
/// followed by a softmax over `W_o·h + b_o`. Note that the reset gate is
/// indexed by the *output* row `i` (it scales the recurrent dot product), and
/// [`GruModel::train`] differentiates exactly this formulation.
///
/// All matrices are stored row-major in flat `Vec<f32>` buffers.
pub struct GruModel {
    /// Embedding table, `VOCAB_SIZE` rows of `EMBED_DIM` values.
    embedding: Vec<f32>,
    /// Update-gate input weights, `HIDDEN_DIM x EMBED_DIM`.
    w_z: Vec<f32>,
    /// Update-gate recurrent weights, `HIDDEN_DIM x HIDDEN_DIM`.
    u_z: Vec<f32>,
    /// Update-gate bias (initialised to 1.0).
    b_z: Vec<f32>,
    /// Reset-gate input weights, `HIDDEN_DIM x EMBED_DIM`.
    w_r: Vec<f32>,
    /// Reset-gate recurrent weights, `HIDDEN_DIM x HIDDEN_DIM`.
    u_r: Vec<f32>,
    /// Reset-gate bias.
    b_r: Vec<f32>,
    /// Candidate-state input weights, `HIDDEN_DIM x EMBED_DIM`.
    w_h: Vec<f32>,
    /// Candidate-state recurrent weights, `HIDDEN_DIM x HIDDEN_DIM`.
    u_h: Vec<f32>,
    /// Candidate-state bias.
    b_h: Vec<f32>,
    /// Output projection, `VOCAB_SIZE x HIDDEN_DIM`.
    w_o: Vec<f32>,
    /// Output bias, one entry per byte value.
    b_o: Vec<f32>,
    /// Current hidden state.
    h: Vec<f32>,
    /// Embedding vector that was fed to the most recent `forward` call.
    last_x: Vec<f32>,
    /// Byte value (as an embedding row index) fed to the most recent
    /// `forward` call; `train` needs it to update the correct embedding row.
    last_byte: usize,
    /// Hidden state as it was before the most recent `forward` call.
    last_h_prev: Vec<f32>,
    /// Update-gate activations of the most recent step.
    last_z: Vec<f32>,
    /// Reset-gate activations of the most recent step.
    last_r: Vec<f32>,
    /// Candidate-state activations of the most recent step.
    last_h_tilde: Vec<f32>,
    /// Softmax distribution over the next byte (each entry floored at 1e-8).
    byte_probs: Vec<f32>,
    /// Set to `true` by `forward`; not read anywhere inside this impl.
    probs_valid: bool,
    /// `true` once `forward` has run at least once; gates `predict_bit` and `train`.
    has_context: bool,
    /// Ring buffer of input embeddings, `BPTT_HORIZON` slots of `EMBED_DIM`.
    hist_x: Vec<f32>,
    /// Ring buffer of pre-step hidden states, `BPTT_HORIZON` slots of `HIDDEN_DIM`.
    hist_h_prev: Vec<f32>,
    /// Ring buffer of update-gate activations.
    hist_z: Vec<f32>,
    /// Ring buffer of reset-gate activations.
    hist_r: Vec<f32>,
    /// Ring buffer of candidate-state activations.
    hist_h_tilde: Vec<f32>,
    /// Ring slot the next `forward` call will write.
    hist_pos: usize,
    /// Number of valid ring slots, saturating at `BPTT_HORIZON`.
    hist_count: usize,
    /// Gradient accumulators, zeroed at the start of every `train` call and
    /// summed over all back-propagated steps before being applied.
    grad_w_z: Vec<f32>,
    grad_u_z: Vec<f32>,
    grad_b_z: Vec<f32>,
    grad_w_r: Vec<f32>,
    grad_u_r: Vec<f32>,
    grad_b_r: Vec<f32>,
    grad_w_h: Vec<f32>,
    grad_u_h: Vec<f32>,
    grad_b_h: Vec<f32>,
}

impl GruModel {
    /// Creates a model with deterministically initialised weights, a zero
    /// hidden state, a uniform output distribution and an empty history.
    pub fn new() -> Self {
        let mut model = GruModel {
            embedding: vec![0.0; VOCAB_SIZE * EMBED_DIM],
            w_z: vec![0.0; HIDDEN_DIM * EMBED_DIM],
            u_z: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
            b_z: vec![0.0; HIDDEN_DIM],
            w_r: vec![0.0; HIDDEN_DIM * EMBED_DIM],
            u_r: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
            b_r: vec![0.0; HIDDEN_DIM],
            w_h: vec![0.0; HIDDEN_DIM * EMBED_DIM],
            u_h: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
            b_h: vec![0.0; HIDDEN_DIM],
            w_o: vec![0.0; VOCAB_SIZE * HIDDEN_DIM],
            b_o: vec![0.0; VOCAB_SIZE],
            h: vec![0.0; HIDDEN_DIM],
            last_x: vec![0.0; EMBED_DIM],
            last_byte: 0,
            last_h_prev: vec![0.0; HIDDEN_DIM],
            last_z: vec![0.0; HIDDEN_DIM],
            last_r: vec![0.0; HIDDEN_DIM],
            last_h_tilde: vec![0.0; HIDDEN_DIM],
            byte_probs: vec![1.0 / VOCAB_SIZE as f32; VOCAB_SIZE],
            probs_valid: false,
            has_context: false,
            hist_x: vec![0.0; BPTT_HORIZON * EMBED_DIM],
            hist_h_prev: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
            hist_z: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
            hist_r: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
            hist_h_tilde: vec![0.0; BPTT_HORIZON * HIDDEN_DIM],
            hist_pos: 0,
            hist_count: 0,
            grad_w_z: vec![0.0; HIDDEN_DIM * EMBED_DIM],
            grad_u_z: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
            grad_b_z: vec![0.0; HIDDEN_DIM],
            grad_w_r: vec![0.0; HIDDEN_DIM * EMBED_DIM],
            grad_u_r: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
            grad_b_r: vec![0.0; HIDDEN_DIM],
            grad_w_h: vec![0.0; HIDDEN_DIM * EMBED_DIM],
            grad_u_h: vec![0.0; HIDDEN_DIM * HIDDEN_DIM],
            grad_b_h: vec![0.0; HIDDEN_DIM],
        };
        model.init_weights();
        model
    }

    /// Fills all weight matrices from one fixed-seed pseudo-random stream.
    ///
    /// The fill order below determines which part of the stream each matrix
    /// receives, so it must not be reordered if two models are expected to
    /// start out identical (e.g. an encoder and its decoder).
    fn init_weights(&mut self) {
        let mut seed: u64 = 0xDEAD_BEEF_CAFE_1234;

        let embed_scale = (2.0 / (VOCAB_SIZE + EMBED_DIM) as f32).sqrt();
        fill_xavier(&mut self.embedding, embed_scale, &mut seed);

        let wx_scale = (2.0 / (EMBED_DIM + HIDDEN_DIM) as f32).sqrt();
        fill_xavier(&mut self.w_z, wx_scale, &mut seed);
        fill_xavier(&mut self.w_r, wx_scale, &mut seed);
        fill_xavier(&mut self.w_h, wx_scale, &mut seed);

        let uh_scale = (2.0 / (HIDDEN_DIM + HIDDEN_DIM) as f32).sqrt();
        fill_xavier(&mut self.u_z, uh_scale, &mut seed);
        fill_xavier(&mut self.u_r, uh_scale, &mut seed);
        fill_xavier(&mut self.u_h, uh_scale, &mut seed);

        let wo_scale = (2.0 / (HIDDEN_DIM + VOCAB_SIZE) as f32).sqrt();
        fill_xavier(&mut self.w_o, wo_scale, &mut seed);

        // Only the update-gate bias starts non-zero.
        self.b_z.fill(1.0);
    }

    /// Feeds one byte through the network.
    ///
    /// Advances the hidden state, recomputes `byte_probs`, and records the
    /// step's input and activations in the history ring for later use by
    /// [`GruModel::train`].
    #[inline(never)]
    #[allow(
        clippy::needless_range_loop,
        reason = "matrix ops are clearer with explicit indices"
    )]
    pub fn forward(&mut self, byte: u8) {
        self.last_h_prev.copy_from_slice(&self.h);

        let byte_idx = usize::from(byte);
        // Remember which embedding row produced `last_x`; `train` updates it.
        self.last_byte = byte_idx;
        let embed_start = byte_idx * EMBED_DIM;
        self.last_x
            .copy_from_slice(&self.embedding[embed_start..embed_start + EMBED_DIM]);

        for i in 0..HIDDEN_DIM {
            let w_off = i * EMBED_DIM;
            let wz_row = &self.w_z[w_off..w_off + EMBED_DIM];
            let wr_row = &self.w_r[w_off..w_off + EMBED_DIM];
            let wh_row = &self.w_h[w_off..w_off + EMBED_DIM];
            let mut val_z = self.b_z[i];
            let mut val_r = self.b_r[i];
            let mut val_h = self.b_h[i];
            // Input contribution for all three pre-activations in one pass.
            for j in 0..EMBED_DIM {
                let xj = self.last_x[j];
                val_z += wz_row[j] * xj;
                val_r += wr_row[j] * xj;
                val_h += wh_row[j] * xj;
            }

            let u_off = i * HIDDEN_DIM;
            let uz_row = &self.u_z[u_off..u_off + HIDDEN_DIM];
            let ur_row = &self.u_r[u_off..u_off + HIDDEN_DIM];
            for j in 0..HIDDEN_DIM {
                let hj = self.last_h_prev[j];
                val_z += uz_row[j] * hj;
                val_r += ur_row[j] * hj;
            }
            let z_i = sigmoid(val_z);
            let r_i = sigmoid(val_r);
            self.last_z[i] = z_i;
            self.last_r[i] = r_i;

            // The candidate needs r_i, so its recurrent term is a second pass.
            let uh_row = &self.u_h[u_off..u_off + HIDDEN_DIM];
            for j in 0..HIDDEN_DIM {
                val_h += uh_row[j] * (r_i * self.last_h_prev[j]);
            }
            let h_tilde_i = tanh_approx(val_h);
            self.last_h_tilde[i] = h_tilde_i;
            self.h[i] = (1.0 - z_i) * self.last_h_prev[i] + z_i * h_tilde_i;
        }

        self.compute_output_probs();
        self.probs_valid = true;
        self.has_context = true;

        // Snapshot this step into the ring slot at `hist_pos`.
        let x_base = self.hist_pos * EMBED_DIM;
        self.hist_x[x_base..x_base + EMBED_DIM].copy_from_slice(&self.last_x);
        let h_base = self.hist_pos * HIDDEN_DIM;
        self.hist_h_prev[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_h_prev);
        self.hist_z[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_z);
        self.hist_r[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_r);
        self.hist_h_tilde[h_base..h_base + HIDDEN_DIM].copy_from_slice(&self.last_h_tilde);
        self.hist_pos = (self.hist_pos + 1) % BPTT_HORIZON;
        if self.hist_count < BPTT_HORIZON {
            self.hist_count += 1;
        }
    }

    /// Recomputes `byte_probs` as `softmax(W_o·h + b_o)`.
    ///
    /// The maximum logit is subtracted before exponentiation, and every
    /// resulting probability is floored at 1e-8 (so the entries may sum to
    /// very slightly more than 1).
    #[inline(never)]
    #[allow(
        clippy::needless_range_loop,
        reason = "matrix ops are clearer with explicit indices"
    )]
    fn compute_output_probs(&mut self) {
        // Pass 1: logits (stored in place) and their maximum.
        let mut max_logit: f32 = f32::NEG_INFINITY;
        for i in 0..VOCAB_SIZE {
            let w_row = &self.w_o[i * HIDDEN_DIM..(i + 1) * HIDDEN_DIM];
            let mut logit = self.b_o[i];
            for j in 0..HIDDEN_DIM {
                logit += w_row[j] * self.h[j];
            }
            self.byte_probs[i] = logit;
            if logit > max_logit {
                max_logit = logit;
            }
        }

        // Pass 2: shifted exponentials and their sum.
        let mut sum: f32 = 0.0;
        for p in self.byte_probs.iter_mut() {
            let e = (*p - max_logit).exp();
            *p = e;
            sum += e;
        }

        // Pass 3: normalise and apply the probability floor.
        let inv_sum = 1.0 / (sum + 1e-30);
        for p in self.byte_probs.iter_mut() {
            *p *= inv_sum;
            if *p < 1e-8 {
                *p = 1e-8;
            }
        }
    }

    /// Returns the probability that the next bit is 1, scaled to `1..=4095`.
    ///
    /// * `bpos` - number of bits of the current byte already known (`0..=7`,
    ///   most significant bit first).
    /// * `c0` - partial byte; only its low `bpos` bits (the known bits) are used.
    ///
    /// The result is the share of `byte_probs` mass, among bytes that agree
    /// with the known bits, whose next bit is set. Returns the neutral value
    /// 2048 before the first `forward` call, when that mass is below 1e-15,
    /// or (in release builds) when `bpos` is out of range; debug builds
    /// assert on an out-of-range `bpos`.
    #[inline]
    pub fn predict_bit(&self, bpos: u8, c0: u32) -> u32 {
        if !self.has_context {
            return 2048;
        }
        debug_assert!(bpos < 8, "bpos must be in 0..8, got {bpos}");
        if bpos > 7 {
            return 2048;
        }

        // Bytes consistent with the known prefix form one contiguous range.
        let known_bits = (c0 & ((1u32 << bpos) - 1)) as usize;
        let shift = 8 - u32::from(bpos);
        let base = known_bits << shift;
        let count = 1usize << shift;

        // Within that range the bit being predicted is the top free bit, so
        // the lower half of the range has it clear and the upper half set.
        let (zero_half, one_half) = self.byte_probs[base..base + count].split_at(count / 2);
        let sum_zero = zero_half.iter().fold(0.0f64, |acc, &p| acc + f64::from(p));
        let sum_one = one_half.iter().fold(0.0f64, |acc, &p| acc + f64::from(p));

        let total = sum_one + sum_zero;
        if total < 1e-15 {
            return 2048;
        }
        let p = ((sum_one * 4096.0) / total) as u32;
        p.clamp(1, 4095)
    }

    /// Performs one online SGD step toward `actual_byte`.
    ///
    /// `actual_byte` is the target for the distribution produced by the most
    /// recent [`GruModel::forward`] call. The output layer is updated
    /// immediately; the recurrent parameters are updated with gradients
    /// accumulated by back-propagating through up to `BPTT_HORIZON` recorded
    /// steps; finally the embedding row of the byte that was *fed* to the
    /// most recent `forward` call is updated. Does nothing before the first
    /// `forward` call.
    #[inline(never)]
    #[allow(
        clippy::needless_range_loop,
        reason = "matrix ops are clearer with explicit indices"
    )]
    pub fn train(&mut self, actual_byte: u8) {
        if !self.has_context {
            return;
        }
        let target = usize::from(actual_byte);

        // ---- Output layer: softmax cross-entropy gradient is p - one_hot.
        let mut d_h = [0.0f32; HIDDEN_DIM];
        for i in 0..VOCAB_SIZE {
            let one_hot = if i == target { 1.0 } else { 0.0 };
            let dl = clip_grad(self.byte_probs[i] - one_hot);
            if dl.abs() < 1e-7 {
                continue;
            }
            let w_row = &mut self.w_o[i * HIDDEN_DIM..(i + 1) * HIDDEN_DIM];
            let lr_dl = LEARNING_RATE * dl;
            for j in 0..HIDDEN_DIM {
                // Read the weight for back-propagation before updating it.
                d_h[j] += dl * w_row[j];
                w_row[j] -= lr_dl * self.h[j];
            }
            self.b_o[i] -= LEARNING_RATE * dl;
        }

        self.grad_w_z.fill(0.0);
        self.grad_u_z.fill(0.0);
        self.grad_b_z.fill(0.0);
        self.grad_w_r.fill(0.0);
        self.grad_u_r.fill(0.0);
        self.grad_b_r.fill(0.0);
        self.grad_w_h.fill(0.0);
        self.grad_u_h.fill(0.0);
        self.grad_b_h.fill(0.0);

        // Pre-activation gradients of the most recent step, kept for the
        // embedding update after the loop.
        let mut d_pre_z_s0 = [0.0f32; HIDDEN_DIM];
        let mut d_pre_r_s0 = [0.0f32; HIDDEN_DIM];
        let mut d_pre_h_s0 = [0.0f32; HIDDEN_DIM];

        // ---- Truncated BPTT: newest recorded step first.
        let steps = self.hist_count;
        for step_back in 0..steps {
            let ring_idx = (self.hist_pos + BPTT_HORIZON - 1 - step_back) % BPTT_HORIZON;
            let x_base = ring_idx * EMBED_DIM;
            let h_base = ring_idx * HIDDEN_DIM;

            let mut d_pre_z = [0.0f32; HIDDEN_DIM];
            let mut d_pre_r = [0.0f32; HIDDEN_DIM];
            let mut d_pre_h = [0.0f32; HIDDEN_DIM];

            for i in 0..HIDDEN_DIM {
                let dhi = clip_grad(d_h[i]);
                let z_i = self.hist_z[h_base + i];
                let r_i = self.hist_r[h_base + i];
                let h_tilde_i = self.hist_h_tilde[h_base + i];
                let h_prev_i = self.hist_h_prev[h_base + i];

                // h_i = (1 - z_i) * h_prev_i + z_i * h~_i
                let d_h_tilde_i = dhi * z_i;
                let dz_i = dhi * (h_tilde_i - h_prev_i);
                let dpz = clip_grad(dz_i * z_i * (1.0 - z_i));
                let dph = clip_grad(d_h_tilde_i * (1.0 - h_tilde_i * h_tilde_i));
                d_pre_z[i] = dpz;
                d_pre_h[i] = dph;
                self.grad_b_z[i] += dpz;
                self.grad_b_h[i] += dph;

                let u_off = i * HIDDEN_DIM;
                // Recurrent dot product U_h[i]·h_prev: the factor r_i
                // multiplies in the forward pass, hence d(pre_h_i)/d(r_i).
                let mut uh_dot = 0.0f32;
                for j in 0..HIDDEN_DIM {
                    let hj = self.hist_h_prev[h_base + j];
                    self.grad_u_z[u_off + j] += dpz * hj;
                    self.grad_u_h[u_off + j] += dph * r_i * hj;
                    uh_dot += self.u_h[u_off + j] * hj;
                }

                // Reset gate: r_i only influences unit i's own candidate.
                let dr_i = clip_grad(dph * uh_dot);
                let dpr = clip_grad(dr_i * r_i * (1.0 - r_i));
                d_pre_r[i] = dpr;
                self.grad_b_r[i] += dpr;

                let w_off = i * EMBED_DIM;
                for j in 0..EMBED_DIM {
                    let xj = self.hist_x[x_base + j];
                    self.grad_w_z[w_off + j] += dpz * xj;
                    self.grad_w_r[w_off + j] += dpr * xj;
                    self.grad_w_h[w_off + j] += dph * xj;
                }
                for j in 0..HIDDEN_DIM {
                    self.grad_u_r[u_off + j] += dpr * self.hist_h_prev[h_base + j];
                }
            }

            if step_back == 0 {
                d_pre_z_s0.copy_from_slice(&d_pre_z);
                d_pre_r_s0.copy_from_slice(&d_pre_r);
                d_pre_h_s0.copy_from_slice(&d_pre_h);
            }

            // Gradient w.r.t. this step's incoming hidden state: the direct
            // (1 - z) path plus the three recurrent matrices.
            let mut d_h_prev = [0.0f32; HIDDEN_DIM];
            for j in 0..HIDDEN_DIM {
                d_h_prev[j] = clip_grad(d_h[j]) * (1.0 - self.hist_z[h_base + j]);
            }
            for i in 0..HIDDEN_DIM {
                let dpz = d_pre_z[i];
                let dpr = d_pre_r[i];
                let dph_r = d_pre_h[i] * self.hist_r[h_base + i];
                let u_off = i * HIDDEN_DIM;
                for j in 0..HIDDEN_DIM {
                    d_h_prev[j] += dpz * self.u_z[u_off + j];
                    d_h_prev[j] += dpr * self.u_r[u_off + j];
                    d_h_prev[j] += dph_r * self.u_h[u_off + j];
                }
            }
            for j in 0..HIDDEN_DIM {
                d_h_prev[j] = clip_grad(d_h_prev[j]);
            }
            d_h.copy_from_slice(&d_h_prev);
        }

        // ---- Embedding of the byte that was fed to the latest forward().
        // Done before the weight update so the gradient is taken through the
        // same W_z / W_r / W_h that produced the activations.
        let embed_start = self.last_byte * EMBED_DIM;
        for j in 0..EMBED_DIM {
            let mut d_xj: f32 = 0.0;
            for i in 0..HIDDEN_DIM {
                let off = i * EMBED_DIM + j;
                d_xj += d_pre_z_s0[i] * self.w_z[off];
                d_xj += d_pre_r_s0[i] * self.w_r[off];
                d_xj += d_pre_h_s0[i] * self.w_h[off];
            }
            self.embedding[embed_start + j] -= LEARNING_RATE * clip_grad(d_xj);
        }

        // ---- Apply the accumulated recurrent-layer gradients.
        for i in 0..HIDDEN_DIM {
            let w_off = i * EMBED_DIM;
            let u_off = i * HIDDEN_DIM;
            for j in 0..EMBED_DIM {
                self.w_z[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_z[w_off + j]);
                self.w_r[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_r[w_off + j]);
                self.w_h[w_off + j] -= LEARNING_RATE * clip_grad(self.grad_w_h[w_off + j]);
            }
            for j in 0..HIDDEN_DIM {
                self.u_z[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_z[u_off + j]);
                self.u_r[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_r[u_off + j]);
                self.u_h[u_off + j] -= LEARNING_RATE * clip_grad(self.grad_u_h[u_off + j]);
            }
            self.b_z[i] -= LEARNING_RATE * clip_grad(self.grad_b_z[i]);
            self.b_r[i] -= LEARNING_RATE * clip_grad(self.grad_b_r[i]);
            self.b_h[i] -= LEARNING_RATE * clip_grad(self.grad_b_h[i]);
        }
    }
}
impl Default for GruModel {
fn default() -> Self {
Self::new()
}
}
/// Logistic function `1 / (1 + e^-x)`.
///
/// The argument is clamped to `[-15, 15]` before exponentiation, so the
/// result saturates outside that range. NaN propagates.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let bounded = x.clamp(-15.0, 15.0);
    let denominator = 1.0 + f32::exp(-bounded);
    denominator.recip()
}
/// Hyperbolic tangent evaluated through the identity `tanh(x) = 2·sigmoid(2x) − 1`.
///
/// The argument is clamped to `[-7.5, 7.5]` first, which keeps `2x` inside
/// the range `sigmoid` itself clamps to.
#[inline]
fn tanh_approx(x: f32) -> f32 {
    let bounded = x.clamp(-7.5, 7.5);
    let gate = sigmoid(bounded * 2.0);
    gate * 2.0 - 1.0
}
#[inline]
fn clip_grad(g: f32) -> f32 {
g.clamp(-GRAD_CLIP, GRAD_CLIP)
}
/// Overwrites `weights` with pseudo-random values in `[-scale, scale]`.
///
/// Values come from a 64-bit xorshift generator (shifts 13, 7, 17) whose
/// state is `seed`. The state is advanced once per weight and written back,
/// so consecutive calls sharing one `seed` continue the same stream.
fn fill_xavier(weights: &mut [f32], scale: f32, seed: &mut u64) {
    let mut state = *seed;
    for slot in weights {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Map the raw state to [0, 1], then to [-1, 1].
        let unit = state as f32 / u64::MAX as f32;
        *slot = (unit * 2.0 - 1.0) * scale;
    }
    *seed = state;
}
// Unit tests for the activation helpers and for GruModel's forward /
// predict_bit / train cycle.
#[cfg(test)]
mod tests {
use super::*;
// sigmoid is 0.5 at zero and saturates towards 1 / 0 at the clamp limits.
#[test]
fn sigmoid_basic() {
assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
assert!(sigmoid(15.0) > 0.999);
assert!(sigmoid(-15.0) < 0.001);
}
// tanh_approx is 0 at zero and close to +/-1 for large-magnitude inputs.
#[test]
fn tanh_basic() {
assert!((tanh_approx(0.0)).abs() < 1e-6);
assert!(tanh_approx(7.0) > 0.99);
assert!(tanh_approx(-7.0) < -0.99);
}
// Two freshly constructed models must carry identical weights.
#[test]
fn deterministic_init() {
let m1 = GruModel::new();
let m2 = GruModel::new();
assert_eq!(m1.embedding, m2.embedding);
assert_eq!(m1.w_z, m2.w_z);
assert_eq!(m1.w_o, m2.w_o);
}
// Without any forward pass predict_bit returns the neutral value 2048.
#[test]
fn initial_predict_bit_uniform() {
let model = GruModel::new();
let p = model.predict_bit(0, 1);
assert_eq!(p, 2048, "before any forward pass, should return 2048");
}
// After one forward pass byte_probs is a (near-)normalised distribution.
#[test]
fn forward_produces_valid_probs() {
let mut model = GruModel::new();
model.forward(b'A');
let sum: f64 = model.byte_probs.iter().map(|&p| p as f64).sum();
assert!(
(sum - 1.0).abs() < 0.01,
"byte_probs should sum to ~1.0, got {sum}"
);
for &p in &model.byte_probs {
assert!(p >= 0.0, "negative probability: {p}");
}
}
// predict_bit stays within 1..=4095 for every bit position.
#[test]
fn predict_bit_in_range() {
let mut model = GruModel::new();
model.forward(b'A');
for bpos in 0..8u8 {
// c0 is a leading 1 followed by the first `bpos` bits of b'B' (MSB first).
let c0 = if bpos == 0 {
1u32
} else {
let mut p = 1u32;
for prev in 0..bpos {
p = (p << 1) | ((b'B' >> (7 - prev)) & 1) as u32;
}
p
};
let p = model.predict_bit(bpos, c0);
assert!(
(1..=4095).contains(&p),
"predict_bit out of range at bpos {bpos}: {p}"
);
}
}
// A forward / train / forward sequence runs and still yields a valid distribution.
#[test]
fn train_does_not_crash() {
let mut model = GruModel::new();
model.forward(b'A');
model.train(b'B');
model.forward(b'B');
let sum: f64 = model.byte_probs.iter().map(|&p| p as f64).sum();
assert!(
(sum - 1.0).abs() < 0.01,
"probs after training should sum to ~1.0, got {sum}"
);
}
// hist_count saturates at BPTT_HORIZON while hist_pos wraps around the ring.
#[test]
fn history_ring_fills_correctly() {
let mut model = GruModel::new();
assert_eq!(model.hist_count, 0);
for i in 0..BPTT_HORIZON + 3 {
model.forward(b'A' + (i % 26) as u8);
let expected = (i + 1).min(BPTT_HORIZON);
assert_eq!(model.hist_count, expected, "hist_count wrong at step {i}");
}
// BPTT_HORIZON + 3 writes leave the cursor 3 slots past the wrap point.
assert_eq!(model.hist_pos, 3);
}
// Repeated forward/train steps must never introduce NaN into the state or output.
#[test]
fn bptt_does_not_produce_nan() {
let mut model = GruModel::new();
let data = b"Hello, World! This is a BPTT test. Let's check for NaN.";
for &byte in data {
model.forward(byte);
model.train(byte);
for j in 0..HIDDEN_DIM {
assert!(!model.h[j].is_nan(), "hidden state has NaN at j={j}");
}
for &p in &model.byte_probs {
assert!(!p.is_nan(), "byte_probs has NaN");
}
}
}
// Two models driven with the same calls must stay in lock-step: identical
// bit predictions at every position and identical state at the end.
#[test]
fn encoder_decoder_identical() {
let mut enc = GruModel::new();
let mut dec = GruModel::new();
let data = b"Hello, World! Testing BPTT encoder-decoder parity.";
for &byte in data {
enc.forward(byte);
dec.forward(byte);
for bpos in 0..8u8 {
// c0 is a leading 1 followed by the first `bpos` bits of `byte` (MSB first).
let c0 = if bpos == 0 {
1u32
} else {
let mut p = 1u32;
for prev in 0..bpos {
p = (p << 1) | ((byte >> (7 - prev)) & 1) as u32;
}
p
};
let pe = enc.predict_bit(bpos, c0);
let pd = dec.predict_bit(bpos, c0);
assert_eq!(pe, pd, "encoder/decoder diverged at bpos {bpos}");
}
enc.train(byte);
dec.train(byte);
}
assert_eq!(enc.h, dec.h, "hidden states diverged after training");
assert_eq!(enc.hist_count, dec.hist_count, "hist_count diverged");
assert_eq!(enc.hist_pos, dec.hist_pos, "hist_pos diverged");
}
// After 200 repetitions of "ab" (train then forward per byte), feeding 'a'
// should put noticeable probability on 'b'.
#[test]
fn bptt_improves_over_1step() {
let mut model = GruModel::new();
let pattern: Vec<u8> = b"ab".repeat(200);
for &byte in &pattern {
model.train(byte);
model.forward(byte);
}
model.train(b'a');
model.forward(b'a');
let p_b = model.byte_probs[b'b' as usize];
assert!(
p_b > 0.1,
"after 'a' in ab pattern with BPTT, P('b')={p_b} should be significant"
);
}
// Same check as above with a longer run (500 repetitions of "ab").
#[test]
fn adapts_to_pattern() {
let mut model = GruModel::new();
let pattern: Vec<u8> = b"ab".repeat(500);
for &byte in &pattern {
model.train(byte);
model.forward(byte);
}
model.train(b'a');
model.forward(b'a');
let p_b = model.byte_probs[b'b' as usize];
assert!(
p_b > 0.1,
"after 'a' in ab pattern, P('b')={p_b} should be significant"
);
}
}