use ndarray::{Array, Array2, Array4, array};
use tabicl_model::attention::{AttentionConfig, AttentionParams, multi_head_attention_forward};
use tabicl_model::layers::{OneHotAndLinear, SkippableLinear, layer_norm_last, linear3d};
use tabicl_model::rope::{RopeConfig, RopeTables, apply_rotary_emb_ref};
/// LayerNorm over the last axis, with unit gamma and no beta, must agree with
/// the output recorded from PyTorch for the same input and eps (to 1e-5).
#[test]
fn layer_norm_matches_pytorch_output_bit_for_bit() {
    let input = array![[[1.0_f32, 2.0, 3.0, 4.0]]];
    let unit_gamma = [1.0, 1.0, 1.0, 1.0];
    let y = layer_norm_last(input.view(), &unit_gamma, None, 1e-5);

    // Reference values captured from PyTorch.
    let pytorch_truth = [
        -1.3416354656219482_f32,
        -0.4472118318080902,
        0.4472118318080902,
        1.3416354656219482,
    ];
    for (k, &want) in pytorch_truth.iter().enumerate() {
        let got = y[(0, 0, k)];
        let diff = (got - want).abs();
        assert!(
            diff < 1e-5,
            "LN[{k}] Rust={} PyTorch={} diff={}",
            got,
            want,
            diff
        );
    }
}
/// LayerNorm with per-feature gamma and beta must equal the closed-form
/// `(x - mean) / sqrt(var + eps) * gamma + beta`, computed here by hand.
#[test]
fn layer_norm_with_affine_matches_python_formula() {
    let x = array![[[1.0_f32, 2.0, 3.0]]];
    let gamma = [2.0, 0.5, 1.5];
    let beta = [10.0, 20.0, -5.0];
    let y = layer_norm_last(x.view(), &gamma, Some(&beta), 1e-5);

    // For x = [1, 2, 3]: mean = 2 and the biased variance is 2/3.
    let mean = 2.0_f32;
    let inv_std = 1.0 / (2.0_f32 / 3.0 + 1e-5).sqrt();
    for (k, (&g, &b)) in gamma.iter().zip(beta.iter()).enumerate() {
        let expected = (x[(0, 0, k)] - mean) * inv_std * g + b;
        assert!((y[(0, 0, k)] - expected).abs() < 1e-5);
    }
}
/// `linear3d` on a single row must produce the hand-computed affine output.
#[test]
fn linear3d_matches_python_formula() {
    let x = array![[[1.0_f32, 2.0]]];
    let w = array![[0.5_f32, 1.0], [-1.0, 2.0]];
    let b = [0.1_f32, 0.2];
    let y = linear3d(x.view(), w.view(), Some(&b));

    // Each output j is (row j of w) · x + b[j]:
    //   0.5*1 + 1*2 + 0.1 = 2.6   and   -1*1 + 2*2 + 0.2 = 3.2
    let expected = [2.6_f32, 3.2];
    for (j, want) in expected.iter().enumerate() {
        assert!((y[(0, 0, j)] - want).abs() < 1e-6);
    }
}
/// `OneHotAndLinear` must map each class index to the expected column of the
/// weight matrix plus bias.
#[test]
fn one_hot_and_linear_matches_python_formula() {
    let weight = array![[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let bias = vec![10.0, 20.0];
    let m = OneHotAndLinear::from_raw_weight(weight, Some(bias));
    let src = array![[0_usize, 2, 1]];
    let out = m.forward(src.view());

    // (position, output channel, expected). The expected values are consistent
    // with weight[channel][class] + bias[channel], where class = src[0][position].
    let cases = [
        (0_usize, 0_usize, 11.0),
        (0, 1, 24.0),
        (1, 0, 13.0),
        (2, 1, 25.0),
    ];
    for &(pos, ch, want) in cases.iter() {
        assert_eq!(out[(0, pos, ch)], want);
    }
}
/// `SkippableLinear` must leave an all-sentinel row untouched and apply the
/// linear map to every other row.
#[test]
fn skippable_linear_sentinel_matches_python() {
    let identity = array![[1.0_f32, 0.0], [0.0, 1.0]];
    // -100.0 appears to be the skip sentinel (row 1 below is filled with it).
    let m = SkippableLinear::new(identity, Some(vec![10.0, 20.0]), -100.0);
    let src = array![[[1.0_f32, 2.0], [-100.0, -100.0], [3.0, 4.0]]];
    let out = m.forward(src.view());

    // Rows 0 and 2 go through identity + bias; row 1 passes through unchanged.
    let expected = [[11.0, 22.0], [-100.0, -100.0], [13.0, 24.0]];
    for (row, values) in expected.iter().enumerate() {
        for (col, &want) in values.iter().enumerate() {
            assert_eq!(out[(0, row, col)], want);
        }
    }
}
/// Non-interleaved rotary embedding on a 2-position, head_dim=2 input must
/// match the reference Python output to within 1e-6.
///
/// With head_dim = 2, position 0 is expected to be unchanged. The Python values
/// for position 1 correspond to rotating `[0.5, 0.5]` by 1 radian.
#[test]
fn rope_non_interleaved_matches_python_bit_for_bit() {
    let tables = RopeTables::new(
        RopeConfig {
            head_dim: 2,
            base: 10_000.0,
            interleaved: false,
        },
        2,
    );
    let x = Array4::from_shape_vec((1, 1, 2, 2), vec![1.0_f32, 0.0, 0.5, 0.5]).unwrap();
    let y = apply_rotary_emb_ref(&x.view(), &tables);

    // Position 0: identity rotation.
    assert!((y[(0, 0, 0, 0)] - 1.0).abs() < 1e-6);
    assert!((y[(0, 0, 0, 1)] - 0.0).abs() < 1e-6);

    // Position 1: compare both halves against the recorded Python output.
    let (python_lo, python_hi) = (-0.15058433946987837_f32, 0.6908866453380181_f32);
    let (rust_lo, rust_hi) = (y[(0, 0, 1, 0)], y[(0, 0, 1, 1)]);
    let (diff_lo, diff_hi) = ((rust_lo - python_lo).abs(), (rust_hi - python_hi).abs());
    assert!(
        diff_lo < 1e-6,
        "RoPE lo Rust={} Python={} diff={}",
        rust_lo,
        python_lo,
        diff_lo
    );
    assert!(
        diff_hi < 1e-6,
        "RoPE hi Rust={} Python={} diff={}",
        rust_hi,
        python_hi,
        diff_hi
    );
}
/// Two-head attention with identity in/out projections must reproduce the
/// output recorded from PyTorch to within 1e-5.
///
/// Queries and keys both come from `qk_input`; values come from `v_input`.
/// Because every projection is an identity, each head's output is the
/// softmax-weighted sum of its slice of `v_input`.
#[test]
fn multi_head_attention_matches_pytorch_bit_for_bit() {
    let embed_dim = 4;
    let num_heads = 2;
    // NOTE(review): `bias: true` while both bias tensors below are `None`;
    // presumably `None` is treated as a zero bias — confirm.
    let cfg = AttentionConfig {
        embed_dim,
        num_heads,
        dropout: 0.0,
        bias: true,
    };

    // Packed in-projection of shape (3E, E): three stacked E×E identity blocks
    // (presumably Q, K, V in that order; all identical here, so order is moot).
    let mut in_proj_weight = Array2::<f32>::zeros((3 * embed_dim, embed_dim));
    for block in 0..3 {
        for i in 0..embed_dim {
            in_proj_weight[(block * embed_dim + i, i)] = 1.0;
        }
    }
    let params = AttentionParams {
        in_proj_weight,
        in_proj_bias: None,
        out_proj_weight: Array2::<f32>::eye(embed_dim),
        out_proj_bias: None,
    };

    let qk_input = Array::from_shape_vec(
        (1, 2, embed_dim),
        vec![1.0_f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
    )
    .unwrap();
    let v_input = Array::from_shape_vec(
        (1, 2, embed_dim),
        vec![0.5_f32, 1.0, 1.5, 2.0, 3.0, 2.5, 2.0, 1.5],
    )
    .unwrap();

    // The trailing `None`s leave the optional arguments unset
    // (presumably attention/padding masks — confirm).
    let out = multi_head_attention_forward(
        qk_input.view(),
        qk_input.view(),
        v_input.view(),
        &params,
        &cfg,
        None,
        None,
    );

    // Row-major (token, feature) reference output recorded from PyTorch.
    let pytorch_truth = [
        1.3255960941314697_f32,
        1.4953577518463135,
        1.75,
        1.75,
        2.1744039058685303,
        2.0046422481536865,
        1.75,
        1.75,
    ];
    for t in 0..2 {
        for e in 0..embed_dim {
            let got = out[(0, t, e)];
            let want = pytorch_truth[t * embed_dim + e];
            let diff = (got - want).abs();
            assert!(
                diff < 1e-5,
                "attn[{t},{e}] Rust={} PyTorch={} diff={}",
                got,
                want,
                diff
            );
        }
    }
}
/// Single-head self-attention with identity projections on orthonormal tokens
/// must equal `softmax(Q Kᵀ / sqrt(d)) V`, computed here by hand.
#[test]
fn attention_with_known_qkv_matches_softmax_v() {
    let embed_dim = 2;
    // NOTE(review): `bias: true` while both bias tensors below are `None`;
    // presumably `None` is treated as a zero bias — confirm.
    let cfg = AttentionConfig {
        embed_dim,
        num_heads: 1,
        dropout: 0.0,
        bias: true,
    };

    // Packed in-projection of shape (3E, E): three stacked identity blocks.
    let mut in_proj_weight = Array2::<f32>::zeros((3 * embed_dim, embed_dim));
    for block in 0..3 {
        for i in 0..embed_dim {
            in_proj_weight[(block * embed_dim + i, i)] = 1.0;
        }
    }
    let params = AttentionParams {
        in_proj_weight,
        in_proj_bias: None,
        out_proj_weight: Array2::<f32>::eye(embed_dim),
        out_proj_bias: None,
    };

    // Tokens are the unit vectors e0 and e1.
    let x = Array::from_shape_vec((1, 2, embed_dim), vec![1.0_f32, 0.0, 0.0, 1.0]).unwrap();
    let y = multi_head_attention_forward(x.view(), x.view(), x.view(), &params, &cfg, None, None);

    // Scaled scores: 1/sqrt(2) on the diagonal and 0 off it. Each row's softmax
    // is therefore [e^s, 1] / (e^s + 1); the output rows are those weights,
    // because V is the identity.
    let s = 1.0_f32 / 2.0_f32.sqrt();
    let e_s = s.exp();
    let z0 = e_s + 1.0;
    let a00 = e_s / z0;
    let a01 = 1.0 / z0;
    assert!((y[(0, 0, 0)] - a00).abs() < 1e-5);
    assert!((y[(0, 0, 1)] - a01).abs() < 1e-5);
    assert!((y[(0, 1, 0)] - a01).abs() < 1e-5);
    assert!((y[(0, 1, 1)] - a00).abs() < 1e-5);
}
/// GELU (tanh approximation) at 0.5 must match the value recorded from
/// PyTorch to within 1e-6.
///
/// The reference value is consistent with PyTorch's `approximate='tanh'`
/// variant. The exact erf-based GELU(0.5) is about 0.345731, which would fail
/// this tolerance.
///
/// NOTE(review): the formula is evaluated inline here; no `tabicl_model`
/// function is called, so the crate's own GELU is not exercised. Despite the
/// test name, the comparison is tolerance-based rather than bitwise.
#[test]
fn gelu_matches_pytorch_bit_for_bit() {
    let x = 0.5_f32;
    let sqrt_two_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    let cubic_term = 0.044715 * x * x * x;
    let tanh_arg = sqrt_two_over_pi * (x + cubic_term);
    let g = 0.5 * x * (1.0 + tanh_arg.tanh());

    let pytorch_truth = 0.3457140028476715_f32;
    let diff = (g - pytorch_truth).abs();
    assert!(
        diff < 1e-6,
        "GELU(0.5) Rust={g} PyTorch={pytorch_truth} diff={diff}"
    );
}
/// A max-shifted softmax over `[2, 1, 0]`, and the cross-entropy for target
/// class 0, must match the values recorded from PyTorch to within 1e-6.
///
/// NOTE(review): softmax and cross-entropy are computed inline here; no
/// `tabicl_model` function is called. The comparison is tolerance-based.
#[test]
fn softmax_and_cross_entropy_match_pytorch_bit_for_bit() {
    let logits = [2.0_f32, 1.0, 0.0];

    // Shift by the maximum logit before exponentiating, for numerical stability.
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - peak).exp()).collect();
    let partition: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|&e| e / partition).collect();

    let pytorch_softmax = [
        0.6652409434318542_f32,
        0.2447284758090973,
        0.09003057330846786,
    ];
    for (k, (&p, &want)) in probs.iter().zip(pytorch_softmax.iter()).enumerate() {
        let diff = (p - want).abs();
        assert!(
            diff < 1e-6,
            "softmax[{k}] Rust={} PyTorch={} diff={diff}",
            p,
            want
        );
    }

    // Cross-entropy with target class 0 is -ln(p[0]).
    let ce = -probs[0].ln();
    let pytorch_ce = 0.4076059829985112_f32;
    let diff = (ce - pytorch_ce).abs();
    assert!(diff < 1e-6, "CE Rust={ce} PyTorch={pytorch_ce} diff={diff}");
}