use ndarray::{Array1, Array2, Array3, ArrayView2, ArrayView3};
#[allow(unused_imports)]
use crate::layers::linear3d;
/// Fully connected layer applied over the last axis of a rank-3 activation.
///
/// The forward pass delegates to `crate::layers::linear3d`; the gradients produced by
/// [`linear3d_backward`] correspond to `y[.., k] = Σ_j weight[k, j] · x[.., j] + bias[k]`.
pub struct LinearLayer {
    /// Weight matrix of shape `(out_dim, in_dim)`.
    pub weight: ndarray::Array2<f32>,
    /// Bias vector of length `out_dim`.
    pub bias: Vec<f32>,
}
impl LinearLayer {
    /// Creates a layer whose weights and biases are all zero.
    pub fn new(out_dim: usize, in_dim: usize) -> Self {
        Self {
            weight: ndarray::Array2::<f32>::zeros((out_dim, in_dim)),
            bias: vec![0.0; out_dim],
        }
    }

    /// Applies the layer to `x` (expected `(batch, seq, in_dim)` — the backward pass asserts
    /// this) and returns the output activations.
    pub fn forward(&self, x: ndarray::ArrayView3<f32>) -> ndarray::Array3<f32> {
        crate::layers::linear3d(x, self.weight.view(), Some(&self.bias))
    }

    /// Runs one full-batch SGD step on the loss `0.5 · mean((pred - y)²)` and returns that
    /// loss, measured before the parameters are updated.
    ///
    /// # Panics
    /// Panics if `y` does not have the same shape as the layer's output for `x`.
    pub fn sgd_step(
        &mut self,
        x: ndarray::ArrayView3<f32>,
        y: ndarray::ArrayView3<f32>,
        lr: f32,
    ) -> f32 {
        let pred = self.forward(x);
        // `zip` below would silently truncate (or mis-pair) a target of the wrong shape.
        assert_eq!(
            y.shape(),
            pred.shape(),
            "target shape must match the layer output shape"
        );
        let n = pred.len() as f32;
        let mut loss = 0.0_f32;
        let mut dpred = ndarray::Array3::<f32>::zeros(pred.dim());
        for ((p, t), d) in pred.iter().zip(y.iter()).zip(dpred.iter_mut()) {
            let diff = p - t;
            loss += 0.5 * diff * diff;
            // Derivative of 0.5 · diff² / n with respect to this prediction.
            *d = diff / n;
        }
        loss /= n;
        let (_, dw, db) = linear3d_backward(x, self.weight.view(), dpred.view(), true);
        // weight += (-lr) · dw
        self.weight.scaled_add(-lr, &dw);
        if let Some(db) = db {
            for (b, g) in self.bias.iter_mut().zip(db.iter()) {
                *b -= lr * g;
            }
        }
        loss
    }
}
/// Pre-normalised two-layer MLP: `linear2(gelu(linear1(layer_norm(x))))`, using the exact
/// (erf-based) GELU.
pub struct MlpBlock {
    /// First projection, `in_dim -> hidden`.
    pub linear1: LinearLayer,
    /// Second projection, `hidden -> out_dim`.
    pub linear2: LinearLayer,
    /// LayerNorm scale over the input features (length `in_dim`).
    pub ln_gamma: Vec<f32>,
    /// LayerNorm shift over the input features (length `in_dim`).
    pub ln_beta: Vec<f32>,
    /// LayerNorm variance epsilon.
    pub ln_eps: f32,
}
impl MlpBlock {
    /// Builds a block with deterministic pseudo-random weights.
    ///
    /// Weights come from a fixed-seed linear congruential generator and are uniform in
    /// `[-1, 1) · sqrt(1 / fan_in)`. Biases start at zero, `ln_gamma` at one and `ln_beta`
    /// at zero, and `ln_eps` is `1e-5`.
    pub fn new(in_dim: usize, hidden: usize, out_dim: usize) -> Self {
        let mut s = 0x12345_u32;
        let mut sample = || {
            // LCG step; the top 24 bits become a float in [0, 1), then mapped to [-1, 1).
            s = s.wrapping_mul(1664525).wrapping_add(1013904223);
            (((s >> 8) as f32) / 16_777_216.0 - 0.5) * 2.0
        };
        let scale1 = (1.0_f32 / in_dim as f32).sqrt();
        let scale2 = (1.0_f32 / hidden as f32).sqrt();
        let mut l1 = LinearLayer::new(hidden, in_dim);
        for v in l1.weight.iter_mut() {
            *v = sample() * scale1;
        }
        let mut l2 = LinearLayer::new(out_dim, hidden);
        for v in l2.weight.iter_mut() {
            *v = sample() * scale2;
        }
        Self {
            linear1: l1,
            linear2: l2,
            ln_gamma: vec![1.0; in_dim],
            ln_beta: vec![0.0; in_dim],
            ln_eps: 1e-5,
        }
    }

    /// Runs the block on `x` and returns the output of `linear2`.
    pub fn forward(&self, x: ndarray::ArrayView3<f32>) -> ndarray::Array3<f32> {
        let normed =
            crate::layers::layer_norm_last(x, &self.ln_gamma, Some(&self.ln_beta), self.ln_eps);
        let h = self.linear1.forward(normed.view());
        let h_act = gelu_exact(h);
        self.linear2.forward(h_act.view())
    }

    /// Runs one full-batch SGD step on `0.5 · mean((pred - target)²)`. The step updates every
    /// parameter: both linear layers and the LayerNorm affine terms. Returns the loss
    /// measured before the update.
    ///
    /// # Panics
    /// Panics if `target` does not have the same shape as the block's output for `x`.
    pub fn sgd_step(
        &mut self,
        x: ndarray::ArrayView3<f32>,
        target: ndarray::ArrayView3<f32>,
        lr: f32,
    ) -> f32 {
        /// Applies `param -= lr * grad` element-wise.
        fn sgd_update<'a>(params: &mut [f32], grads: impl IntoIterator<Item = &'a f32>, lr: f32) {
            for (p, g) in params.iter_mut().zip(grads) {
                *p -= lr * g;
            }
        }

        // Forward pass, keeping every intermediate the backward pass needs.
        let normed =
            crate::layers::layer_norm_last(x, &self.ln_gamma, Some(&self.ln_beta), self.ln_eps);
        let h_pre = self.linear1.forward(normed.view());
        let h_act = gelu_exact(h_pre.clone());
        let pred = self.linear2.forward(h_act.view());
        // `zip` below would silently truncate (or mis-pair) a target of the wrong shape.
        assert_eq!(
            target.shape(),
            pred.shape(),
            "target shape must match the MLP output shape"
        );

        let n = pred.len() as f32;
        let mut loss = 0.0_f32;
        let mut dpred = ndarray::Array3::<f32>::zeros(pred.dim());
        for ((p, t), d) in pred.iter().zip(target.iter()).zip(dpred.iter_mut()) {
            let diff = p - t;
            loss += 0.5 * diff * diff;
            *d = diff / n;
        }
        loss /= n;

        // Backward pass: linear2 -> GELU -> linear1 -> LayerNorm.
        let (dh_act, dw2, db2) =
            linear3d_backward(h_act.view(), self.linear2.weight.view(), dpred.view(), true);
        let mut dh_pre = dh_act;
        // GELU's derivative is evaluated at the pre-activation.
        gelu_backward_inplace_exact(&mut dh_pre, h_pre.view());
        let (dnormed, dw1, db1) =
            linear3d_backward(normed.view(), self.linear1.weight.view(), dh_pre.view(), true);
        // The input gradient is discarded: `x` is not a trainable parameter.
        let (_dx, dgamma, dbeta) =
            layer_norm_backward(x, &self.ln_gamma, dnormed.view(), self.ln_eps);
        let db1 = db1.expect("linear3d_backward returns a bias gradient when has_bias is true");
        let db2 = db2.expect("linear3d_backward returns a bias gradient when has_bias is true");

        // Every gradient is computed before any parameter is modified.
        self.linear1.weight.scaled_add(-lr, &dw1);
        sgd_update(&mut self.linear1.bias, &db1, lr);
        self.linear2.weight.scaled_add(-lr, &dw2);
        sgd_update(&mut self.linear2.bias, &db2, lr);
        sgd_update(&mut self.ln_gamma, &dgamma, lr);
        sgd_update(&mut self.ln_beta, &dbeta, lr);
        loss
    }
}
/// Exact GELU applied element-wise: `v · ½(1 + erf(v / √2))`.
///
/// Consumes the input array and writes the activations into it in place.
fn gelu_exact(x: ndarray::Array3<f32>) -> ndarray::Array3<f32> {
    x.mapv_into(|v| 0.5 * v * (1.0 + erf_f32(v / std::f32::consts::SQRT_2)))
}
/// Chains `grad` (in place) through the exact GELU, evaluated at the pre-activation `x`.
///
/// For `GELU(x) = x·Φ(x)` the derivative is `Φ(x) + x·φ(x)`, with
/// `Φ(x) = ½(1 + erf(x/√2))` and the standard normal density `φ(x) = exp(-x²/2) / √(2π)`.
///
/// # Panics
/// Panics if `grad` and `x` have different shapes.
fn gelu_backward_inplace_exact(grad: &mut ndarray::Array3<f32>, x: ndarray::ArrayView3<f32>) {
    assert_eq!(grad.shape(), x.shape(), "gradient and pre-activation shapes differ");
    // 1/√(2π): normalising constant of φ. (The previous code used √(2/π)/(2√2) = 1/(2√π),
    // which under-scaled the density term by a factor of 1/√2.)
    let inv_sqrt_2pi = 1.0_f32 / (2.0_f32 * std::f32::consts::PI).sqrt();
    for (g, &xv) in grad.iter_mut().zip(x.iter()) {
        let cdf = 0.5 * (1.0 + erf_f32(xv / std::f32::consts::SQRT_2));
        let pdf = inv_sqrt_2pi * (-0.5 * xv * xv).exp();
        *g *= cdf + xv * pdf;
    }
}
/// Approximates the error function with the rational formula of Abramowitz & Stegun 7.1.26
/// (absolute error about 1.5e-7), using odd symmetry for negative inputs.
fn erf_f32(x: f32) -> f32 {
    // Polynomial coefficients a5, a4, a3, a2, a1 (highest degree first, for Horner's rule).
    const COEFFS: [f32; 5] = [1.061_405_4, -1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];
    const P: f32 = 0.327_591_1;
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c) * t;
    let erf_abs = 1.0 - poly * (-magnitude * magnitude).exp();
    x.signum() * erf_abs
}
/// Backward pass of the rank-3 linear map `y[b, t, k] = Σ_j weight[k, j] · x[b, t, j] (+ bias[k])`.
///
/// Given the upstream gradient `dy` of shape `(B, T, out)`, returns `(dx, dweight, dbias)`.
/// `dx` has `x`'s shape `(B, T, in)`, `dweight` has `weight`'s shape `(out, in)`, and `dbias`
/// (length `out`) is `Some` only when `has_bias` is true.
///
/// # Panics
/// Panics if `weight.shape()[1]` differs from `x`'s last dimension, or if `dy` is not
/// `(B, T, out)`.
pub fn linear3d_backward(
    x: ArrayView3<f32>,
    weight: ArrayView2<f32>,
    dy: ArrayView3<f32>,
    has_bias: bool,
) -> (Array3<f32>, Array2<f32>, Option<Array1<f32>>) {
    let (batch, seq, in_features) = x.dim();
    let out_features = weight.shape()[0];
    assert_eq!(weight.shape()[1], in_features);
    assert_eq!(dy.shape(), &[batch, seq, out_features]);

    // dx[b, t, j] = Σ_k dy[b, t, k] · weight[k, j]
    let grad_x = Array3::from_shape_fn((batch, seq, in_features), |(bi, ti, j)| {
        (0..out_features).fold(0.0_f32, |acc, k| acc + dy[(bi, ti, k)] * weight[(k, j)])
    });

    // dweight[k, j] = Σ_b Σ_t dy[b, t, k] · x[b, t, j]  (batch outer, time inner)
    let grad_w = Array2::from_shape_fn((out_features, in_features), |(k, j)| {
        (0..batch).fold(0.0_f32, |acc, bi| {
            (0..seq).fold(acc, |acc, ti| acc + dy[(bi, ti, k)] * x[(bi, ti, j)])
        })
    });

    // dbias[k] = Σ_b Σ_t dy[b, t, k]
    let grad_b = has_bias.then(|| {
        Array1::from_shape_fn(out_features, |k| {
            (0..batch).fold(0.0_f32, |acc, bi| {
                (0..seq).fold(acc, |acc, ti| acc + dy[(bi, ti, k)])
            })
        })
    });

    (grad_x, grad_w, grad_b)
}
/// Backward pass of LayerNorm over the last axis of a rank-3 input.
///
/// Each `(batch, time)` row is treated as `y = gamma · (x - mean) / sqrt(var + eps) + beta`,
/// where `mean` and the biased (divide-by-`d`) `var` are recomputed here from `x`.
/// Returns `(dx, dgamma, dbeta)`; `dgamma` and `dbeta` are summed over every row.
///
/// # Panics
/// Panics if `gamma.len()` differs from `x`'s last dimension, or (via indexing) if `dy` is
/// smaller than `x`.
pub fn layer_norm_backward(
    x: ArrayView3<f32>,
    gamma: &[f32],
    dy: ArrayView3<f32>,
    eps: f32,
) -> (Array3<f32>, Vec<f32>, Vec<f32>) {
    let (batch, seq, width) = x.dim();
    assert_eq!(gamma.len(), width);
    let inv_width = 1.0_f32 / width as f32;
    let mut grad_x = Array3::<f32>::zeros((batch, seq, width));
    let mut grad_gamma = vec![0.0_f32; width];
    let mut grad_beta = vec![0.0_f32; width];
    // Per-row scratch space, reused for every (batch, time) position.
    let mut centered = vec![0.0_f32; width];
    let mut x_hat = vec![0.0_f32; width];
    let mut g_hat = vec![0.0_f32; width];

    for bi in 0..batch {
        for ti in 0..seq {
            let mean = (0..width).fold(0.0_f32, |acc, k| acc + x[(bi, ti, k)]) * inv_width;
            for (k, c) in centered.iter_mut().enumerate() {
                *c = x[(bi, ti, k)] - mean;
            }
            let var = centered.iter().fold(0.0_f32, |acc, &c| acc + c * c) * inv_width;
            let inv_std = 1.0_f32 / (var + eps).sqrt();
            for k in 0..width {
                x_hat[k] = centered[k] * inv_std;
                // Gradient with respect to the normalised value.
                g_hat[k] = dy[(bi, ti, k)] * gamma[k];
            }
            let dvar: f32 = g_hat
                .iter()
                .zip(&centered)
                .map(|(&g, &c)| g * c)
                .sum::<f32>()
                * -0.5
                * inv_std.powi(3);
            // The second term is analytically zero (Σ(x - mean) = 0) but is kept so the
            // arithmetic follows the textbook chain rule exactly.
            let dmean: f32 = g_hat.iter().map(|&g| g * -inv_std).sum::<f32>()
                + dvar * -2.0 * centered.iter().sum::<f32>() * inv_width;
            for k in 0..width {
                grad_x[(bi, ti, k)] =
                    g_hat[k] * inv_std + dvar * 2.0 * centered[k] * inv_width + dmean * inv_width;
                grad_gamma[k] += dy[(bi, ti, k)] * x_hat[k];
                grad_beta[k] += dy[(bi, ti, k)];
            }
        }
    }
    (grad_x, grad_gamma, grad_beta)
}
/// Chains `grad` (in place) through the tanh approximation of GELU,
/// `0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))`, evaluated at the input `x`.
///
/// Elements are paired in iteration order. If the two arrays hold different numbers of
/// elements, the surplus elements of the larger one are left untouched or ignored.
pub fn gelu_backward_inplace(grad: &mut Array3<f32>, x: ArrayView3<f32>) {
    const CUBIC: f32 = 0.044715;
    let scale = (2.0_f32 / std::f32::consts::PI).sqrt();
    grad.iter_mut().zip(x.iter()).for_each(|(g, &v)| {
        let th = (scale * (v + CUBIC * v * v * v)).tanh();
        // Derivative of the tanh argument with respect to v.
        let d_inner = scale * (1.0 + 3.0 * CUBIC * v * v);
        *g *= 0.5 * (1.0 + th) + 0.5 * v * (1.0 - th * th) * d_inner;
    });
}
pub fn softmax_backward_last(probs: ArrayView3<f32>, dy: ArrayView3<f32>) -> Array3<f32> {
let (b, t, d) = (probs.shape()[0], probs.shape()[1], probs.shape()[2]);
let mut dx = Array3::<f32>::zeros((b, t, d));
for bi in 0..b {
for ti in 0..t {
let mut dot = 0.0_f32;
for k in 0..d {
dot += probs[(bi, ti, k)] * dy[(bi, ti, k)];
}
for k in 0..d {
dx[(bi, ti, k)] = probs[(bi, ti, k)] * (dy[(bi, ti, k)] - dot);
}
}
}
dx
}
/// Gradient of the mean-squared-error loss `mean((pred - target)²)` with respect to `pred`.
///
/// Element `i` is `2 · (pred[i] - target[i]) / pred.len()`. Empty inputs yield an empty
/// gradient.
///
/// # Panics
/// Panics if `pred` and `target` have different lengths. Silently truncating, as `zip` would,
/// still divides by `pred.len()` and so produces wrongly scaled gradients.
pub fn mse_backward(pred: &[f32], target: &[f32]) -> Vec<f32> {
    assert_eq!(
        pred.len(),
        target.len(),
        "mse_backward: pred and target must have the same length"
    );
    let n = pred.len() as f32;
    pred.iter()
        .zip(target)
        .map(|(p, t)| 2.0 * (p - t) / n)
        .collect()
}
/// Gradient of mean softmax cross-entropy with respect to the logits, computed from the
/// softmax probabilities as `(probs - one_hot(labels)) / N` for `N` rows.
///
/// Labels beyond the first `N` are ignored.
///
/// # Panics
/// Panics if `labels` has fewer than `N` entries or a label is not a valid column index.
pub fn cross_entropy_softmax_backward(probs: ArrayView2<f32>, labels: &[usize]) -> Array2<f32> {
    let inv_rows = 1.0 / probs.shape()[0] as f32;
    let mut grad = probs.to_owned();
    for (i, mut row) in grad.outer_iter_mut().enumerate() {
        row[labels[i]] -= 1.0;
        row.mapv_inplace(|v| v * inv_rows);
    }
    grad
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array, array};

    fn forward_linear3d(x: ArrayView3<f32>, w: ArrayView2<f32>, b: Option<&[f32]>) -> Array3<f32> {
        linear3d(x, w, b)
    }

    #[test]
    fn linear_backward_matches_finite_difference() {
        let x = Array::from_shape_vec((1, 2, 3), vec![0.5_f32, -1.0, 2.0, 0.0, 1.5, -0.5]).unwrap();
        let w = Array::from_shape_vec((2, 3), vec![1.0_f32, 0.5, -0.5, 0.0, 1.0, 0.5]).unwrap();
        let b = vec![0.1_f32, -0.2];
        let dy = Array::from_shape_vec((1, 2, 2), vec![1.0_f32, 0.5, -1.0, 0.25]).unwrap();
        let (dx, dw, db) = linear3d_backward(x.view(), w.view(), dy.view(), true);
        assert_eq!(dx.shape(), x.shape());
        assert_eq!(dw.shape(), w.shape());
        assert_eq!(db.as_ref().unwrap().len(), 2);
        let eps = 1e-3_f32;
        let y0 = forward_linear3d(x.view(), w.view(), Some(&b));
        let mut w_plus = w.clone();
        w_plus[(0, 1)] += eps;
        let y_plus = forward_linear3d(x.view(), w_plus.view(), Some(&b));
        let mut loss_plus = 0.0_f32;
        let mut loss0 = 0.0_f32;
        for (a, g) in y0.iter().zip(dy.iter()) {
            loss0 += a * g;
        }
        for (a, g) in y_plus.iter().zip(dy.iter()) {
            loss_plus += a * g;
        }
        let fd = (loss_plus - loss0) / eps;
        assert!(
            (fd - dw[(0, 1)]).abs() < 1e-2,
            "fd {fd} vs analytic {}",
            dw[(0, 1)]
        );
    }

    #[test]
    fn layer_norm_backward_matches_finite_difference() {
        let x = Array::from_shape_vec((1, 1, 4), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
        let gamma = vec![1.0_f32; 4];
        let dy = array![[[0.5_f32, -0.25, 1.0, 0.75]]];
        let (dx, dg, _db) = layer_norm_backward(x.view(), &gamma, dy.view(), 1e-5);
        let eps = 1e-3_f32;
        let y0 = crate::layers::layer_norm_last(x.view(), &gamma, None, 1e-5);
        let mut x_plus = x.clone();
        x_plus[(0, 0, 2)] += eps;
        let yp = crate::layers::layer_norm_last(x_plus.view(), &gamma, None, 1e-5);
        let mut l0 = 0.0_f32;
        let mut lp = 0.0_f32;
        for (a, g) in y0.iter().zip(dy.iter()) {
            l0 += a * g;
        }
        for (a, g) in yp.iter().zip(dy.iter()) {
            lp += a * g;
        }
        let fd = (lp - l0) / eps;
        assert!(
            (fd - dx[(0, 0, 2)]).abs() < 1e-2,
            "fd {fd} vs analytic {}",
            dx[(0, 0, 2)]
        );
        assert_eq!(dg.len(), 4);
    }

    #[test]
    fn softmax_backward_sums_to_zero_per_row() {
        let probs = array![[[0.2_f32, 0.5, 0.3], [0.1, 0.7, 0.2]]];
        let dy = array![[[1.0_f32, 1.0, 1.0], [1.0, 1.0, 1.0]]];
        let dx = softmax_backward_last(probs.view(), dy.view());
        for ti in 0..2 {
            let row_sum: f32 = (0..3).map(|k| dx[(0, ti, k)]).sum();
            assert!(row_sum.abs() < 1e-5, "row {ti} sum {row_sum}");
        }
    }

    #[test]
    fn cross_entropy_softmax_backward_matches_formula() {
        let probs = array![[0.7_f32, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let labels = vec![0_usize, 1];
        let grad = cross_entropy_softmax_backward(probs.view(), &labels);
        assert!((grad[(0, 0)] - (-0.15)).abs() < 1e-5);
        assert!((grad[(0, 1)] - 0.1).abs() < 1e-5);
        assert!((grad[(0, 2)] - 0.05).abs() < 1e-5);
    }

    #[test]
    fn mlp_block_full_backbone_sgd_reduces_loss() {
        let in_dim = 4;
        let hidden = 8;
        let out_dim = 2;
        let n = 32;
        let x = ndarray::Array::from_shape_fn((1, n, in_dim), |(_, i, j)| {
            ((i * in_dim + j) as f32) * 0.05 - 0.8
        });
        let target = ndarray::Array::from_shape_fn((1, n, out_dim), |(_, i, j)| {
            if j == 0 {
                (x[(0, i, 0)] * 2.0 + x[(0, i, 1)]).tanh()
            } else {
                (x[(0, i, 2)] - x[(0, i, 3)]).max(0.0)
            }
        });
        let mut block = MlpBlock::new(in_dim, hidden, out_dim);
        let l0 = block.sgd_step(x.view(), target.view(), 0.05);
        let mut last = l0;
        for _ in 0..200 {
            last = block.sgd_step(x.view(), target.view(), 0.05);
        }
        assert!(
            last < l0 * 0.5,
            "full-backbone MLP SGD did not reduce loss: l0={l0} last={last}"
        );
    }

    #[test]
    fn linear_layer_sgd_reduces_mse() {
        let in_dim = 3;
        let out_dim = 2;
        let n = 16;
        let target_w = ndarray::Array::from_shape_vec(
            (out_dim, in_dim),
            vec![1.0_f32, -2.0, 0.5, 0.3, 1.5, -1.0],
        )
        .unwrap();
        let x = ndarray::Array::from_shape_fn((1, n, in_dim), |(_, i, j)| {
            (i as f32) * 0.1 + (j as f32) * 0.2 - 1.0
        });
        let mut y = ndarray::Array3::<f32>::zeros((1, n, out_dim));
        for i in 0..n {
            for o in 0..out_dim {
                let mut s = 0.0_f32;
                for j in 0..in_dim {
                    s += target_w[(o, j)] * x[(0, i, j)];
                }
                y[(0, i, o)] = s;
            }
        }
        let mut layer = LinearLayer::new(out_dim, in_dim);
        let l0 = layer.sgd_step(x.view(), y.view(), 0.1);
        let mut last = l0;
        for _ in 0..40 {
            last = layer.sgd_step(x.view(), y.view(), 0.1);
        }
        assert!(
            last < l0 * 0.1,
            "linear SGD did not reduce loss: l0={l0} last={last}"
        );
    }

    #[test]
    fn mse_backward_matches_formula() {
        let pred = [1.0_f32, 2.0, 3.0];
        let target = [0.0_f32, 2.0, 4.0];
        let grad = mse_backward(&pred, &target);
        assert!((grad[0] - 2.0 / 3.0).abs() < 1e-5);
        assert!(grad[1].abs() < 1e-5);
        assert!((grad[2] - -2.0 / 3.0).abs() < 1e-5);
    }

    /// `erf_f32` should agree with tabulated erf values to within the A&S 7.1.26 error bound
    /// plus f32 rounding, and be exactly odd.
    #[test]
    fn erf_f32_matches_reference_values() {
        assert!(erf_f32(0.0).abs() < 1e-6);
        assert!((erf_f32(0.5) - 0.520_499_9).abs() < 2e-6);
        assert!((erf_f32(1.0) - 0.842_700_8).abs() < 2e-6);
        assert!((erf_f32(2.0) - 0.995_322_3).abs() < 2e-6);
        assert_eq!(erf_f32(-1.0), -erf_f32(1.0));
    }

    /// Central finite differences of the tanh-approximate GELU should match
    /// `gelu_backward_inplace` applied to an all-ones upstream gradient.
    #[test]
    fn gelu_tanh_backward_matches_finite_difference() {
        fn gelu_tanh(v: f32) -> f32 {
            let c = (2.0_f32 / std::f32::consts::PI).sqrt();
            0.5 * v * (1.0 + (c * (v + 0.044715 * v * v * v)).tanh())
        }
        let xs = [-2.0_f32, -0.5, 0.0, 0.7, 1.5];
        let x = Array::from_shape_vec((1, 1, xs.len()), xs.to_vec()).unwrap();
        let mut grad = Array3::<f32>::ones((1, 1, xs.len()));
        gelu_backward_inplace(&mut grad, x.view());
        let eps = 1e-3_f32;
        for (k, &v) in xs.iter().enumerate() {
            let fd = (gelu_tanh(v + eps) - gelu_tanh(v - eps)) / (2.0 * eps);
            assert!(
                (fd - grad[(0, 0, k)]).abs() < 1e-3,
                "x={v}: fd {fd} vs analytic {}",
                grad[(0, 0, k)]
            );
        }
    }
}