use irithyll_core::rng::standard_normal;
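
/// Test-time-training (TTT) layer: slow key/value/query projections wrap a set
/// of fast weights that take one gradient step on every forward call, so the
/// layer keeps adapting to the stream it is processing. The fast learner is
/// either a single linear map (`w_fast`) or, when `mlp_hidden_dim > 0`, a small
/// two-layer GELU MLP (`mlp_w1`/`mlp_w2`).
///
/// A minimal usage sketch (arguments mirror the unit tests at the bottom of
/// this file):
///
/// ```ignore
/// let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
/// let out = layer.forward(&[1.0, 2.0, 3.0]); // adapts the fast weights in place
/// assert_eq!(out.len(), layer.output_dim());
/// ```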
pub(crate) struct TTTLayer {
    // Slow (outer-loop) projections for keys, values, and queries.
    w_k: Vec<f64>,
    w_v: Vec<f64>,
    w_q: Vec<f64>,
    // Linear-path fast weights, their momentum buffer, and the batch-mode
    // gradient accumulator (empty when the MLP path is used).
    w_fast: Vec<f64>,
    momentum_buf: Vec<f64>,
    accumulated_grad: Vec<f64>,
    n_accumulated: usize,
    pub(crate) batch_mode: bool,
    // Dimensions and inner-loop hyperparameters.
    d_model: usize,
    d_state: usize,
    eta: f64,
    alpha: f64,
    use_momentum: bool,
    momentum_decay: f64,
    nesterov: bool,
    alpha_warmup: usize,
    step_count: u64,
    effective_alpha: f64,
    // Optional two-layer MLP fast weights and their momentum buffers.
    mlp_w1: Option<Vec<f64>>,
    mlp_w2: Option<Vec<f64>>,
    mlp_v1: Option<Vec<f64>>,
    mlp_v2: Option<Vec<f64>>,
    mlp_hidden_dim: usize,
    // Externally supplied prediction error, used as the inner-loop error signal when non-zero.
    pub prediction_feedback: f64,
    initialized: bool,
    rng_state: u64,
}
impl TTTLayer {
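    /// Builds an uninitialized layer. The K/V/Q projections are allocated lazily
    /// on the first `forward` call, once the input width is known;
    /// `mlp_hidden_dim == 0` selects the linear fast-weight variant.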
#[allow(clippy::too_many_arguments)]
pub fn new(
d_state: usize,
eta: f64,
alpha: f64,
use_momentum: bool,
momentum_decay: f64,
nesterov: bool,
alpha_warmup: usize,
mlp_hidden_dim: usize,
seed: u64,
) -> Self {
let (mlp_w1, mlp_w2, mlp_v1, mlp_v2) = if mlp_hidden_dim > 0 {
let mut rng_init = if seed == 0 { 1 } else { seed };
let scale_w1 = (2.0 / (d_state + mlp_hidden_dim) as f64).sqrt();
let scale_w2 = (2.0 / (mlp_hidden_dim + d_state) as f64).sqrt();
let w1 = random_matrix(&mut rng_init, mlp_hidden_dim, d_state, scale_w1);
let w2 = random_matrix(&mut rng_init, d_state, mlp_hidden_dim, scale_w2);
let v1 = vec![0.0; mlp_hidden_dim * d_state];
let v2 = vec![0.0; d_state * mlp_hidden_dim];
(Some(w1), Some(w2), Some(v1), Some(v2))
} else {
(None, None, None, None)
};
Self {
w_k: Vec::new(),
w_v: Vec::new(),
w_q: Vec::new(),
            w_fast: if mlp_hidden_dim == 0 {
                vec![0.0; d_state * d_state]
            } else {
                Vec::new()
            },
momentum_buf: if use_momentum && mlp_hidden_dim == 0 {
vec![0.0; d_state * d_state]
} else {
Vec::new()
},
            accumulated_grad: if mlp_hidden_dim == 0 {
                vec![0.0; d_state * d_state]
            } else {
                Vec::new()
            },
n_accumulated: 0,
batch_mode: false,
d_model: 0,
d_state,
eta,
alpha,
use_momentum,
momentum_decay,
nesterov,
alpha_warmup,
step_count: 0,
effective_alpha: alpha,
mlp_w1,
mlp_w2,
mlp_v1,
mlp_v2,
mlp_hidden_dim,
prediction_feedback: 0.0,
initialized: false,
rng_state: if seed == 0 { 1 } else { seed },
}
}
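    /// Online forward pass: project the input to keys/values/queries, read the
    /// key through the fast weights, take one clipped gradient step on the
    /// self-supervised reconstruction loss (or on `prediction_feedback` when it
    /// is set), and return a layer-normalized GELU readout with a residual
    /// connection to the query.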
#[allow(clippy::needless_range_loop)]
pub fn forward(&mut self, features: &[f64]) -> Vec<f64> {
self.ensure_init(features.len());
let d = self.d_state;
self.step_count += 1;
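        // Warm the weight-decay factor alpha up linearly over `alpha_warmup` steps.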
let effective_alpha = if self.alpha_warmup > 0 {
self.alpha * (self.step_count as f64 / self.alpha_warmup as f64).min(1.0)
} else {
self.alpha
};
self.effective_alpha = effective_alpha;
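        // Normalize the input (downscaling only) and project it; keys and queries
        // are additionally rescaled to unit L2 norm.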
let input_norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt().max(1.0);
let normalized: Vec<f64> = features.iter().map(|x| x / input_norm).collect();
let mut k = mat_vec_mul(&self.w_k, &normalized, d);
let v = mat_vec_mul(&self.w_v, &normalized, d);
let mut q = mat_vec_mul(&self.w_q, &normalized, d);
let k_norm = k.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-8);
for ki in k.iter_mut() {
*ki /= k_norm;
}
let q_norm = q.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-8);
for qi in q.iter_mut() {
*qi /= q_norm;
}
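        // Read the key through the current fast weights: MLP path, Nesterov
        // look-ahead weights when Nesterov momentum is enabled, or the plain
        // linear map.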
let (z, mlp_h) = if self.mlp_hidden_dim > 0 {
let h_dim = self.mlp_hidden_dim;
let w1 = self.mlp_w1.as_ref().unwrap();
let w2 = self.mlp_w2.as_ref().unwrap();
let h_raw = mat_vec_mul_sq(w1, &k, h_dim, d);
let h: Vec<f64> = h_raw.iter().map(|&x| gelu(x)).collect();
let z_mlp = mat_vec_mul_sq(w2, &h, d, h_dim);
(z_mlp, Some(h))
} else if self.nesterov && self.use_momentum {
let mut z_nesterov = vec![0.0; d];
for i in 0..d {
let mut sum = 0.0;
for j in 0..d {
let idx = i * d + j;
let w_look = (1.0 - effective_alpha) * self.w_fast[idx]
+ self.momentum_decay * self.momentum_buf[idx];
sum += w_look * k[j];
}
z_nesterov[i] = sum;
}
(z_nesterov, None)
} else {
(fast_mat_vec(&self.w_fast, &k, d), None)
};
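        // Inner-loop error signal: either the externally supplied prediction error
        // projected along the query, or the reconstruction residual z - (v - k).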
let mut residual = vec![0.0; d];
if self.prediction_feedback.abs() > 1e-15 {
let pred_err = self.prediction_feedback;
for i in 0..d {
residual[i] = -pred_err * q[i];
}
} else {
for i in 0..d {
residual[i] = z[i] - (v[i] - k[i]);
}
}
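        // One gradient step on the fast weights (per-element gradients clipped to
        // [-1, 1]). In batch mode the gradient is only accumulated here and applied
        // later by `flush_batch`.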
if self.mlp_hidden_dim > 0 {
            let h = mlp_h.unwrap();
            let h_dim = self.mlp_hidden_dim;
let w1 = self.mlp_w1.as_ref().unwrap();
let w2 = self.mlp_w2.as_mut().unwrap();
for i in 0..d {
for j in 0..h_dim {
let idx = i * h_dim + j;
let g = (residual[i] * h[j]).clamp(-1.0, 1.0);
if self.use_momentum {
let v2 = self.mlp_v2.as_mut().unwrap();
v2[idx] = self.momentum_decay * v2[idx] - self.eta * g;
w2[idx] = (1.0 - effective_alpha) * w2[idx] + v2[idx];
} else {
w2[idx] = (1.0 - effective_alpha) * w2[idx] - self.eta * g;
}
}
}
let w2_snap = self.mlp_w2.as_ref().unwrap();
let mut d_h = vec![0.0f64; h_dim];
for j in 0..h_dim {
let mut s = 0.0;
for i in 0..d {
s += w2_snap[i * h_dim + j] * residual[i];
}
d_h[j] = s;
}
            let h_pre_raw = mat_vec_mul_sq(w1, &k, h_dim, d);
let d_h_pre: Vec<f64> = h_pre_raw
.iter()
.zip(d_h.iter())
.map(|(&x, &dh)| dh * gelu_grad(x))
.collect();
let w1_mut = self.mlp_w1.as_mut().unwrap();
for i in 0..h_dim {
for j in 0..d {
let idx = i * d + j;
let g = (d_h_pre[i] * k[j]).clamp(-1.0, 1.0);
if self.use_momentum {
let v1 = self.mlp_v1.as_mut().unwrap();
v1[idx] = self.momentum_decay * v1[idx] - self.eta * g;
w1_mut[idx] = (1.0 - effective_alpha) * w1_mut[idx] + v1[idx];
} else {
w1_mut[idx] = (1.0 - effective_alpha) * w1_mut[idx] - self.eta * g;
}
}
}
} else if self.batch_mode {
for i in 0..d {
for j in 0..d {
let idx = i * d + j;
let grad = residual[i] * k[j];
let clipped_grad = grad.clamp(-1.0, 1.0);
self.accumulated_grad[idx] += clipped_grad;
}
}
self.n_accumulated += 1;
} else if self.use_momentum {
for i in 0..d {
for j in 0..d {
let idx = i * d + j;
let grad = residual[i] * k[j];
let clipped_grad = grad.clamp(-1.0, 1.0);
self.momentum_buf[idx] =
self.momentum_decay * self.momentum_buf[idx] - self.eta * clipped_grad;
self.w_fast[idx] =
(1.0 - effective_alpha) * self.w_fast[idx] + self.momentum_buf[idx];
}
}
} else {
for i in 0..d {
for j in 0..d {
let idx = i * d + j;
let grad = residual[i] * k[j];
let clipped_grad = grad.clamp(-1.0, 1.0);
self.w_fast[idx] =
(1.0 - effective_alpha) * self.w_fast[idx] - self.eta * clipped_grad;
}
}
}
if self.mlp_hidden_dim == 0 {
            // Guard against blow-up: reset the fast weights if any entry went
            // non-finite, otherwise rescale when the largest magnitude exceeds 1e4.
            if self.w_fast.iter().any(|w| !w.is_finite()) {
                self.w_fast.fill(0.0);
            } else {
                let w_max = self.w_fast.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
                if w_max > 1e4 {
                    let scale = 1e3 / w_max;
                    for w in &mut self.w_fast {
                        *w *= scale;
                    }
                }
            }
}
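        // Second read with the just-updated fast weights, then GELU, layer norm,
        // and a residual connection back to the query.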
let wk = if self.mlp_hidden_dim > 0 {
let h_dim = self.mlp_hidden_dim;
let w1 = self.mlp_w1.as_ref().unwrap();
let w2 = self.mlp_w2.as_ref().unwrap();
let h_raw = mat_vec_mul_sq(w1, &k, h_dim, d);
let h: Vec<f64> = h_raw.iter().map(|&x| gelu(x)).collect();
mat_vec_mul_sq(w2, &h, d, h_dim)
} else {
fast_mat_vec(&self.w_fast, &k, d)
};
let mut activated = vec![0.0; d];
for i in 0..d {
activated[i] = gelu(q[i] + wk[i]);
}
let mean = activated.iter().sum::<f64>() / d as f64;
let var = activated
.iter()
.map(|x| (x - mean) * (x - mean))
.sum::<f64>()
/ d as f64;
let std_inv = 1.0 / (var + 1e-8).sqrt();
for a in activated.iter_mut() {
*a = (*a - mean) * std_inv;
}
let mut output = vec![0.0; d];
for i in 0..d {
output[i] = q[i] + activated[i];
}
output
}
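    /// Inference-only pass: same read path as `forward` but with no fast-weight
    /// update. Returns zeros if the layer has never been initialized.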
pub fn forward_predict(&self, features: &[f64]) -> Vec<f64> {
if !self.initialized {
return vec![0.0; self.d_state];
}
let d = self.d_state;
let input_norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt().max(1.0);
let normalized: Vec<f64> = features.iter().map(|x| x / input_norm).collect();
let mut k = mat_vec_mul(&self.w_k, &normalized, d);
let mut q = mat_vec_mul(&self.w_q, &normalized, d);
let k_norm = k.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-8);
for ki in k.iter_mut() {
*ki /= k_norm;
}
let q_norm = q.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-8);
for qi in q.iter_mut() {
*qi /= q_norm;
}
let wk = if self.mlp_hidden_dim > 0 {
let h_dim = self.mlp_hidden_dim;
let w1 = self.mlp_w1.as_ref().unwrap();
let w2 = self.mlp_w2.as_ref().unwrap();
let h_raw = mat_vec_mul_sq(w1, &k, h_dim, d);
let h: Vec<f64> = h_raw.iter().map(|&x| gelu(x)).collect();
mat_vec_mul_sq(w2, &h, d, h_dim)
} else {
fast_mat_vec(&self.w_fast, &k, d)
};
let mut activated = vec![0.0; d];
for i in 0..d {
activated[i] = gelu(q[i] + wk[i]);
}
let mean = activated.iter().sum::<f64>() / d as f64;
let var = activated
.iter()
.map(|x| (x - mean) * (x - mean))
.sum::<f64>()
/ d as f64;
let std_inv = 1.0 / (var + 1e-8).sqrt();
for a in activated.iter_mut() {
*a = (*a - mean) * std_inv;
}
let mut output = vec![0.0; d];
for i in 0..d {
output[i] = q[i] + activated[i];
}
output
}
pub fn output_dim(&self) -> usize {
self.d_state
}
#[allow(dead_code)]
#[inline]
pub fn fast_weights(&self) -> &[f64] {
&self.w_fast
}
#[inline]
pub fn set_eta(&mut self, eta: f64) {
self.eta = eta;
}
#[inline]
pub fn set_alpha(&mut self, alpha: f64) {
self.alpha = alpha;
}
#[inline]
pub fn effective_alpha(&self) -> f64 {
self.effective_alpha
}
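    /// Restores the fast learner to a fresh state (re-drawn MLP weights or zeroed
    /// linear weights) without touching the slow K/V/Q projections.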
pub fn reset_fast_weights(&mut self) {
if self.mlp_hidden_dim > 0 {
            // Zeroing the MLP would leave every inner-loop gradient stuck at zero,
            // so re-draw fresh Xavier-scaled weights and clear the momentum buffers.
            let h_dim = self.mlp_hidden_dim;
            let d = self.d_state;
            let scale_w1 = (2.0 / (d + h_dim) as f64).sqrt();
            let scale_w2 = (2.0 / (h_dim + d) as f64).sqrt();
            self.mlp_w1 = Some(random_matrix(&mut self.rng_state, h_dim, d, scale_w1));
            self.mlp_w2 = Some(random_matrix(&mut self.rng_state, d, h_dim, scale_w2));
            if let Some(v) = &mut self.mlp_v1 {
                v.fill(0.0);
            }
            if let Some(v) = &mut self.mlp_v2 {
                v.fill(0.0);
            }
} else {
self.w_fast.fill(0.0);
if self.use_momentum {
self.momentum_buf.fill(0.0);
}
self.accumulated_grad.fill(0.0);
self.n_accumulated = 0;
}
self.prediction_feedback = 0.0;
self.step_count = 0;
}
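    /// Re-randomizes one output row of the linear fast weights and clears its
    /// momentum and accumulated gradient.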
    pub fn reinitialize_unit(&mut self, j: usize, rng: &mut u64) {
        // Per-unit reinitialization only applies to the linear fast-weight path;
        // in MLP mode `w_fast` and `accumulated_grad` are empty, so indexing them
        // below would panic.
        if self.mlp_hidden_dim > 0 {
            return;
        }
assert!(
j < self.d_state,
"unit index {} out of range (d_state={})",
j,
self.d_state
);
let scale = (2.0 / (self.d_state + self.d_state) as f64).sqrt();
let row_start = j * self.d_state;
for col in 0..self.d_state {
self.w_fast[row_start + col] = standard_normal(rng) * scale;
}
if self.use_momentum {
for col in 0..self.d_state {
self.momentum_buf[row_start + col] = 0.0;
}
}
for col in 0..self.d_state {
self.accumulated_grad[row_start + col] = 0.0;
}
}
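    /// Full reset: clears the fast weights and the K/V/Q projections so the next
    /// `forward` call re-initializes the layer from scratch.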
pub fn reset_full(&mut self) {
self.reset_fast_weights();
self.w_fast.fill(0.0);
self.w_k.clear();
self.w_v.clear();
self.w_q.clear();
self.d_model = 0;
self.initialized = false;
}
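    /// Applies the gradient accumulated while `batch_mode` was active as a single
    /// averaged update to the linear fast weights, then clears the accumulator.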
pub(crate) fn flush_batch(&mut self) {
if self.n_accumulated == 0 {
return;
}
let d = self.d_state;
let n = self.n_accumulated as f64;
let effective_alpha = if self.alpha_warmup > 0 {
self.alpha * (self.step_count as f64 / self.alpha_warmup as f64).min(1.0)
} else {
self.alpha
};
if self.use_momentum {
for i in 0..d {
for j in 0..d {
let idx = i * d + j;
let avg_grad = self.accumulated_grad[idx] / n;
self.momentum_buf[idx] =
self.momentum_decay * self.momentum_buf[idx] - self.eta * avg_grad;
self.w_fast[idx] =
(1.0 - effective_alpha) * self.w_fast[idx] + self.momentum_buf[idx];
}
}
} else {
for i in 0..d {
for j in 0..d {
let idx = i * d + j;
let avg_grad = self.accumulated_grad[idx] / n;
self.w_fast[idx] =
(1.0 - effective_alpha) * self.w_fast[idx] - self.eta * avg_grad;
}
}
}
self.accumulated_grad.fill(0.0);
self.n_accumulated = 0;
}
}
impl TTTLayer {
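    /// Lazily allocates the K/V/Q projections: identity matrices when the input
    /// width already equals `d_state`, Xavier-scaled random matrices otherwise.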
fn ensure_init(&mut self, d_model: usize) {
if self.initialized {
return;
}
self.d_model = d_model;
let d = self.d_state;
if d == d_model {
self.w_k = (0..d * d)
.map(|idx| if idx / d == idx % d { 1.0 } else { 0.0 })
.collect();
self.w_v = self.w_k.clone();
self.w_q = self.w_k.clone();
} else {
let scale = (2.0 / (d_model + d) as f64).sqrt();
self.w_k = random_matrix(&mut self.rng_state, d, d_model, scale);
self.w_v = random_matrix(&mut self.rng_state, d, d_model, scale);
self.w_q = random_matrix(&mut self.rng_state, d, d_model, scale);
}
self.initialized = true;
}
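    /// Gradients of the outer losses with respect to the slow projections for a
    /// single input: W_q gets the squared-prediction-error term, W_k and W_v get
    /// the fast-weight reconstruction term, and the Jacobian of the key/query
    /// normalization is ignored. Returns `(grad_wq, grad_wk, grad_wv)`, each
    /// row-major `[d_state x input_len]`.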
#[allow(clippy::needless_range_loop)]
pub(crate) fn compute_projection_gradients(
&self,
features: &[f64],
pred_error: f64,
readout_weights: &[f64],
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
let d = self.d_state;
let input_norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt().max(1.0);
let x_norm: Vec<f64> = features.iter().map(|x| x / input_norm).collect();
let n_input = x_norm.len();
let mut k = mat_vec_mul(&self.w_k, &x_norm, d);
let v = mat_vec_mul(&self.w_v, &x_norm, d);
let mut q = mat_vec_mul(&self.w_q, &x_norm, d);
let k_norm = k.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-8);
for ki in k.iter_mut() {
*ki /= k_norm;
}
let q_norm = q.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-8);
for qi in q.iter_mut() {
*qi /= q_norm;
}
let mut d_loss_d_q = vec![0.0; d];
if self.mlp_hidden_dim == 0 {
for i in 0..d {
let mut sum = 0.0;
for k_idx in 0..d {
let w_fast_ki = self.w_fast[k_idx * d + i];
let identity = if k_idx == i { 1.0 } else { 0.0 };
sum +=
readout_weights.get(k_idx).copied().unwrap_or(0.0) * (identity + w_fast_ki);
}
d_loss_d_q[i] = -2.0 * pred_error * sum;
}
} else {
for i in 0..d {
let identity_contrib = readout_weights.get(i).copied().unwrap_or(0.0);
d_loss_d_q[i] = -2.0 * pred_error * identity_contrib;
}
}
let mut grad_wq = vec![0.0; d * n_input];
for i in 0..d {
for j in 0..n_input {
grad_wq[i * n_input + j] = d_loss_d_q[i] * x_norm[j];
}
}
let z = if self.mlp_hidden_dim == 0 {
fast_mat_vec(&self.w_fast, &k, d)
} else {
let h_dim = self.mlp_hidden_dim;
let w1 = self.mlp_w1.as_ref().unwrap();
let w2 = self.mlp_w2.as_ref().unwrap();
let h_raw = mat_vec_mul_sq(w1, &k, h_dim, d);
let h: Vec<f64> = h_raw.iter().map(|&x| gelu(x)).collect();
mat_vec_mul_sq(w2, &h, d, h_dim)
};
let mut residual = vec![0.0; d];
for i in 0..d {
residual[i] = z[i] - v[i] + k[i];
}
let mut d_recon_d_k = vec![0.0; d];
if self.mlp_hidden_dim == 0 {
for i in 0..d {
let mut sum = 0.0;
for j in 0..d {
let w_fast_ji = self.w_fast[j * d + i];
let identity = if j == i { 1.0 } else { 0.0 };
sum += residual[j] * (w_fast_ji + identity);
}
d_recon_d_k[i] = 2.0 * sum;
}
} else {
for i in 0..d {
d_recon_d_k[i] = 2.0 * residual[i];
}
}
let mut grad_wk = vec![0.0; d * n_input];
for i in 0..d {
for j in 0..n_input {
grad_wk[i * n_input + j] = d_recon_d_k[i] * x_norm[j];
}
}
let mut grad_wv = vec![0.0; d * n_input];
for i in 0..d {
for j in 0..n_input {
grad_wv[i * n_input + j] = -2.0 * residual[i] * x_norm[j];
}
}
(grad_wq, grad_wk, grad_wv)
}
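    /// Applies one plain SGD step with learning rate `lr` to the slow projections.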
pub(crate) fn update_projections(
&mut self,
grad_wq: &[f64],
grad_wk: &[f64],
grad_wv: &[f64],
lr: f64,
) {
for (w, g) in self.w_q.iter_mut().zip(grad_wq.iter()) {
*w -= lr * g;
}
for (w, g) in self.w_k.iter_mut().zip(grad_wk.iter()) {
*w -= lr * g;
}
for (w, g) in self.w_v.iter_mut().zip(grad_wv.iter()) {
*w -= lr * g;
}
}
pub(crate) fn ensure_initialized(&mut self, d_model: usize) {
self.ensure_init(d_model);
}
pub(crate) fn set_projections(&mut self, w_k: Vec<f64>, w_v: Vec<f64>, w_q: Vec<f64>) {
self.w_k = w_k;
self.w_v = w_v;
self.w_q = w_q;
self.initialized = true;
}
}
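/// Tanh approximation of the GELU activation.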
#[inline]
fn gelu(x: f64) -> f64 {
let inner = (2.0_f64 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
0.5 * x * (1.0 + inner.tanh())
}
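/// Row-major `rows x cols` matrix with i.i.d. standard-normal entries scaled by `scale`.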
fn random_matrix(rng: &mut u64, rows: usize, cols: usize, scale: f64) -> Vec<f64> {
let n = rows * cols;
let mut mat = Vec::with_capacity(n);
for _ in 0..n {
mat.push(standard_normal(rng) * scale);
}
mat
}
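/// Derivative of the tanh-approximated GELU above.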
#[inline]
fn gelu_grad(x: f64) -> f64 {
let c = (2.0_f64 / std::f64::consts::PI).sqrt();
let a = 0.044715_f64;
let inner = c * (x + a * x * x * x);
let tanh_inner = inner.tanh();
let sech2 = 1.0 - tanh_inner * tanh_inner;
0.5 * (1.0 + tanh_inner) + 0.5 * x * sech2 * c * (1.0 + 3.0 * a * x * x)
}
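/// Matrix-vector product with an explicit (documentation-only) column count; the
/// actual column count used is `x.len()`.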
#[inline]
fn mat_vec_mul_sq(w: &[f64], x: &[f64], rows: usize, _cols: usize) -> Vec<f64> {
mat_vec_mul(w, x, rows)
}
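/// Row-major matrix-vector product; the column count is inferred from `x`.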
fn mat_vec_mul(w: &[f64], x: &[f64], rows: usize) -> Vec<f64> {
let cols = x.len();
let mut result = vec![0.0; rows];
for (i, out) in result.iter_mut().enumerate() {
let row_start = i * cols;
let mut sum = 0.0;
for j in 0..cols {
sum += w[row_start + j] * x[j];
}
*out = sum;
}
result
}
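/// Square `d x d` row-major matrix-vector product (hot path for the fast weights).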
fn fast_mat_vec(w: &[f64], x: &[f64], d: usize) -> Vec<f64> {
let mut result = vec![0.0; d];
for (i, out) in result.iter_mut().enumerate() {
let row_start = i * d;
let mut sum = 0.0;
for j in 0..d {
sum += w[row_start + j] * x[j];
}
*out = sum;
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_creates_uninit() {
let layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
assert!(!layer.initialized, "should be uninitialized after new()");
assert_eq!(layer.d_state, 8, "d_state should be 8");
assert!(
layer.w_fast.iter().all(|&v| v == 0.0),
"w_fast should be all zeros initially"
);
}
#[test]
fn forward_initializes_projections() {
let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
assert!(!layer.initialized, "should start uninitialized");
assert!(
layer.w_k.is_empty(),
"w_k should be empty before first forward"
);
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let _ = layer.forward(&input);
assert!(
layer.initialized,
"should be initialized after first forward"
);
assert_eq!(layer.d_model, 5, "d_model should be set to input length");
assert_eq!(
layer.w_k.len(),
8 * 5,
"w_k should be [d_state x d_model] = [8 x 5]"
);
assert_eq!(
layer.w_v.len(),
8 * 5,
"w_v should be [d_state x d_model] = [8 x 5]"
);
assert_eq!(
layer.w_q.len(),
8 * 5,
"w_q should be [d_state x d_model] = [8 x 5]"
);
}
#[test]
fn forward_output_dimension() {
let mut layer = TTTLayer::new(16, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let input = vec![1.0, 2.0, 3.0];
let output = layer.forward(&input);
assert_eq!(
output.len(),
16,
"output dimension should equal d_state=16, got {}",
output.len()
);
}
#[test]
fn forward_output_finite() {
let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let input = vec![0.5, -0.3, 1.2, 0.0, -1.0];
let output = layer.forward(&input);
for (i, &v) in output.iter().enumerate() {
assert!(v.is_finite(), "output[{}] = {} is not finite", i, v);
}
}
#[test]
fn fast_weights_update() {
let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let _ = layer.forward(&input);
let changed = layer.w_fast.iter().any(|&v| v != 0.0);
assert!(
changed,
"w_fast should no longer be all zeros after forward pass"
);
}
#[test]
fn reset_zeros_fast_weights() {
let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let input = vec![1.0, 2.0, 3.0];
let _ = layer.forward(&input);
assert!(
layer.w_fast.iter().any(|&v| v != 0.0),
"w_fast should be non-zero after forward"
);
layer.reset_full();
assert!(
layer.w_fast.iter().all(|&v| v == 0.0),
"w_fast should be all zeros after reset"
);
}
#[test]
fn reset_full_clears_projections() {
let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let input = vec![1.0, 2.0, 3.0];
let _ = layer.forward(&input);
layer.reset_full();
assert!(
!layer.initialized,
"initialized should be false after reset_full"
);
assert!(
layer.w_k.is_empty(),
"w_k should be cleared after reset_full"
);
let output = layer.forward(&input);
assert_eq!(output.len(), 8, "forward should still work after reset");
for (i, &v) in output.iter().enumerate() {
assert!(
v.is_finite(),
"output[{}] = {} is not finite after reset",
i,
v
);
}
}
#[test]
fn reset_full_clears_everything() {
let mut layer = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let input = vec![1.0, 2.0, 3.0];
let _ = layer.forward(&input);
layer.reset_full();
assert!(
!layer.initialized,
"initialized should be false after reset_full"
);
assert!(
layer.w_k.is_empty(),
"w_k should be cleared after reset_full"
);
assert!(
layer.w_v.is_empty(),
"w_v should be cleared after reset_full"
);
assert!(
layer.w_q.is_empty(),
"w_q should be cleared after reset_full"
);
assert_eq!(layer.d_model, 0, "d_model should be 0 after reset_full");
assert!(
layer.w_fast.iter().all(|&v| v == 0.0),
"w_fast should be all zeros after reset_full"
);
}
#[test]
fn momentum_changes_behavior() {
let input = vec![1.0, -0.5, 0.3, 2.0];
let mut layer_no_mom = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 42);
let _ = layer_no_mom.forward(&input);
let _ = layer_no_mom.forward(&input);
let out_no_mom = layer_no_mom.forward(&input);
let mut layer_mom = TTTLayer::new(8, 0.01, 0.0, true, 0.9, false, 0, 0, 42);
let _ = layer_mom.forward(&input);
let _ = layer_mom.forward(&input);
let out_mom = layer_mom.forward(&input);
let diff: f64 = out_no_mom
.iter()
.zip(out_mom.iter())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(
diff > 1e-10,
"momentum should produce different output after multiple steps, total diff = {}",
diff
);
}
#[test]
fn deterministic_with_seed() {
let input = vec![0.5, -1.0, 2.0, 0.3];
let mut layer_a = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 12345);
let out_a1 = layer_a.forward(&input);
let out_a2 = layer_a.forward(&input);
let mut layer_b = TTTLayer::new(8, 0.01, 0.0, false, 0.0, false, 0, 0, 12345);
let out_b1 = layer_b.forward(&input);
let out_b2 = layer_b.forward(&input);
for i in 0..8 {
assert!(
(out_a1[i] - out_b1[i]).abs() < 1e-15,
"step 1 output[{}] differs: {} vs {}",
i,
out_a1[i],
out_b1[i]
);
assert!(
(out_a2[i] - out_b2[i]).abs() < 1e-15,
"step 2 output[{}] differs: {} vs {}",
i,
out_a2[i],
out_b2[i]
);
}
}
#[test]
fn convergence_on_pattern() {
let mut layer = TTTLayer::new(4, 0.05, 0.0, false, 0.0, false, 0, 0, 42);
let pattern = vec![1.0, 0.0, 0.5, -0.5, 0.3];
let mut errors: Vec<f64> = Vec::new();
for _ in 0..50 {
let _ = layer.forward(&pattern);
let k = mat_vec_mul(&layer.w_k, &pattern, layer.d_state);
let v = mat_vec_mul(&layer.w_v, &pattern, layer.d_state);
let z = fast_mat_vec(&layer.w_fast, &k, layer.d_state);
let err: f64 = (0..layer.d_state)
.map(|i| {
let r = z[i] - (v[i] - k[i]);
r * r
})
.sum();
errors.push(err);
}
let first_half_avg: f64 = errors[..25].iter().sum::<f64>() / 25.0;
let second_half_avg: f64 = errors[25..].iter().sum::<f64>() / 25.0;
assert!(
second_half_avg < first_half_avg,
"reconstruction error should decrease over time: first_half_avg={}, second_half_avg={}",
first_half_avg,
second_half_avg
);
}
}