use std::path::Path;
use tabicl_model::autodiff::linear3d_backward;
/// Reads a raw little-endian `f32` fixture file into a vector.
///
/// Returns `None` if the file cannot be read, or if its length is not a
/// multiple of 4 bytes. A trailing partial value means the fixture is
/// truncated or is not `f32` data, so it is rejected rather than having the
/// remainder silently discarded.
fn read_f32(path: &Path) -> Option<Vec<f32>> {
    let bytes = std::fs::read(path).ok()?;
    // `chunks_exact` would otherwise ignore the leftover bytes without complaint.
    if bytes.len() % 4 != 0 {
        return None;
    }
    Some(
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    )
}
/// Parity check: `layer_norm_backward` versus gradients exported from PyTorch autograd.
///
/// Reads raw little-endian `f32` fixtures from `/tmp/tabicl_grad_ln`:
/// - `x`, `gamma`, `beta` and `target` are the inputs.
/// - `dx`, `dgamma` and `dbeta` are the reference gradients.
///
/// The test skips itself when `x.bin` is absent. The input is shaped as a
/// `(1, n, d)` tensor. The upstream gradient fed to the backward pass is
/// `y - target`. NOTE(review): that corresponds to a loss like
/// `0.5 * sum((y - target)^2)`; confirm it matches the fixture generator.
#[test]
fn layer_norm_gradient_matches_pytorch_autograd() {
    /// Max-style fold step in which NaN wins, so a NaN difference makes the
    /// `< tolerance` assertions fail instead of being silently skipped
    /// (`f32::max` and `>` comparisons both ignore NaN).
    fn nan_max(acc: f32, v: f32) -> f32 {
        if v.is_nan() || v > acc {
            v
        } else {
            acc
        }
    }

    let base = Path::new("/tmp/tabicl_grad_ln");
    if !base.join("x.bin").exists() {
        eprintln!("LN grad parity: fixture not at {base:?}, skipping");
        return;
    }
    let d = 6;
    let n = 4;
    // Name the unreadable file instead of a bare `unwrap` on `None`.
    let load = |name: &str| {
        read_f32(&base.join(name)).unwrap_or_else(|| {
            panic!("LN grad parity: cannot read {} in {:?}", name, base)
        })
    };
    let x = load("x.bin");
    let gamma = load("gamma.bin");
    let beta = load("beta.bin");
    let target = load("target.bin");
    let py_dx = load("dx.bin");
    let py_dgamma = load("dgamma.bin");
    let py_dbeta = load("dbeta.bin");

    let x_arr = ndarray::Array3::from_shape_vec((1, n, d), x).unwrap();
    let target_arr = ndarray::Array3::from_shape_vec((1, n, d), target).unwrap();
    let y = tabicl_model::layers::layer_norm_last(x_arr.view(), &gamma, Some(&beta), 1e-5);

    // Upstream gradient dL/dy = y - target.
    let mut dy = ndarray::Array3::<f32>::zeros(y.dim());
    for i in 0..n {
        for k in 0..d {
            dy[(0, i, k)] = y[(0, i, k)] - target_arr[(0, i, k)];
        }
    }
    let (dx, dgamma, dbeta) =
        tabicl_model::autodiff::layer_norm_backward(x_arr.view(), &gamma, dy.view(), 1e-5);

    // `zip` stops at the shorter side, so without these a mis-sized reference
    // would compare only a prefix and pass vacuously.
    assert_eq!(py_dx.len(), n * d, "dx.bin length mismatch");
    assert_eq!(dgamma.len(), py_dgamma.len(), "dgamma length mismatch");
    assert_eq!(dbeta.len(), py_dbeta.len(), "dbeta length mismatch");

    // The reference dx is stored row-major for the (1, n, d) tensor.
    let max_dx_diff = (0..n)
        .flat_map(|i| (0..d).map(move |k| (i, k)))
        .map(|(i, k)| (dx[(0, i, k)] - py_dx[i * d + k]).abs())
        .fold(0.0_f32, nan_max);
    let max_dgamma_diff = dgamma
        .iter()
        .zip(py_dgamma.iter())
        .map(|(r, p)| (r - p).abs())
        .fold(0.0_f32, nan_max);
    let max_dbeta_diff = dbeta
        .iter()
        .zip(py_dbeta.iter())
        .map(|(r, p)| (r - p).abs())
        .fold(0.0_f32, nan_max);
    eprintln!(
        "LN grad max diffs: dx={max_dx_diff} dgamma={max_dgamma_diff} dbeta={max_dbeta_diff}"
    );
    assert!(max_dx_diff < 1e-4, "dx mismatch: {max_dx_diff}");
    assert!(max_dgamma_diff < 1e-5, "dgamma mismatch: {max_dgamma_diff}");
    assert!(max_dbeta_diff < 1e-5, "dbeta mismatch: {max_dbeta_diff}");
}
/// Parity check: `linear3d_backward` versus gradients exported from PyTorch autograd.
///
/// Reads raw little-endian `f32` fixtures from `/tmp/tabicl_grad`:
/// - `x`, `y`, `w` and `b` are the inputs.
/// - `dw` and `db` are the reference gradients.
///
/// The test skips itself when `x.bin` is absent. The upstream gradient is
/// `(pred - y) / (n * out_dim)`. NOTE(review): PyTorch's mean-reduced
/// `MSELoss` would give `2 * (pred - y) / N`; confirm which loss the fixture
/// generator used. Only `dW` and `db` are compared; `dx` is not checked
/// because no `dx` reference is loaded.
#[test]
fn linear_gradient_matches_pytorch_autograd() {
    /// Max-style fold step in which NaN wins, so a NaN difference makes the
    /// `< tolerance` assertions fail instead of being silently skipped.
    fn nan_max(acc: f32, v: f32) -> f32 {
        if v.is_nan() || v > acc {
            v
        } else {
            acc
        }
    }

    let base = Path::new("/tmp/tabicl_grad");
    if !base.join("x.bin").exists() {
        eprintln!("gradient parity: fixture not at {base:?}, skipping");
        return;
    }
    let n = 8;
    let in_dim = 4;
    let out_dim = 3;
    // Name the unreadable file instead of a bare `unwrap` on `None`.
    let load = |name: &str| {
        read_f32(&base.join(name)).unwrap_or_else(|| {
            panic!("gradient parity: cannot read {} in {:?}", name, base)
        })
    };
    let x = load("x.bin");
    let y = load("y.bin");
    let w = load("w.bin");
    let b = load("b.bin");
    let py_dw = load("dw.bin");
    let py_db = load("db.bin");

    let x_arr = ndarray::Array3::from_shape_vec((1, n, in_dim), x).unwrap();
    let y_arr = ndarray::Array3::from_shape_vec((1, n, out_dim), y).unwrap();
    let w_arr = ndarray::Array2::from_shape_vec((out_dim, in_dim), w).unwrap();
    let pred = tabicl_model::layers::linear3d(x_arr.view(), w_arr.view(), Some(&b));

    // Upstream gradient of a mean-reduced loss over all n * out_dim outputs.
    let total = (n * out_dim) as f32;
    let mut dpred = ndarray::Array3::<f32>::zeros((1, n, out_dim));
    for i in 0..n {
        for o in 0..out_dim {
            dpred[(0, i, o)] = (pred[(0, i, o)] - y_arr[(0, i, o)]) / total;
        }
    }
    let (_dx, dw, db) = linear3d_backward(x_arr.view(), w_arr.view(), dpred.view(), true);
    let db = db.expect("bias gradient was requested (with_bias = true)");

    // Reject mis-sized references up front instead of comparing a subset.
    assert_eq!(py_dw.len(), out_dim * in_dim, "dw.bin length mismatch");
    assert_eq!(py_db.len(), out_dim, "db.bin length mismatch");
    assert_eq!(db.len(), py_db.len(), "db length mismatch");

    // The reference dW is stored row-major as (out_dim, in_dim).
    let max_dw_diff = (0..out_dim)
        .flat_map(|o| (0..in_dim).map(move |k| (o, k)))
        .map(|(o, k)| (dw[(o, k)] - py_dw[o * in_dim + k]).abs())
        .fold(0.0_f32, nan_max);
    let max_db_diff = db
        .iter()
        .zip(py_db.iter())
        .map(|(r, p)| (r - p).abs())
        .fold(0.0_f32, nan_max);
    eprintln!("Max dW diff: {max_dw_diff}");
    eprintln!("Max db diff: {max_db_diff}");
    assert!(max_dw_diff < 1e-5, "dW mismatch: {max_dw_diff}");
    assert!(max_db_diff < 1e-5, "db mismatch: {max_db_diff}");
}