use attnres::AttnResConfig;
use burn::backend::NdArray;
use burn::prelude::*;
/// Backend shared by every test in this file: burn's `NdArray` backend. Each test
/// obtains its device via `Default::default()`.
type TestBackend = NdArray;
/// Differential check: a freshly initialised op should return the element-wise mean
/// of every depth source, i.e. the completed blocks plus the partial block.
///
/// NOTE(review): this relies on `init_op` producing a zero pseudo-query (as the test
/// name implies). That would make every depth logit equal and the mixing weights
/// uniform. Confirm against `AttnResConfig::init_op`.
#[test]
fn test_differential_zero_query_is_mean() {
    let device = Default::default();
    let config = AttnResConfig::new(4, 4, 2);
    let op = config.init_op::<TestBackend>(&device);

    let block0 = Tensor::<TestBackend, 3>::from_floats([[[1.0, 2.0, 3.0, 4.0]]], &device);
    let block1 = Tensor::<TestBackend, 3>::from_floats([[[5.0, 6.0, 7.0, 8.0]]], &device);
    let partial = Tensor::<TestBackend, 3>::from_floats([[[9.0, 10.0, 11.0, 12.0]]], &device);

    let output = op.forward(&[block0, block1], &partial);

    // burn's element-wise subtraction broadcasts size-1 axes. A wrongly shaped output
    // could therefore still produce a small diff below, so pin the shape first.
    assert_eq!(
        output.dims(),
        [1, 1, 4],
        "Differential test failed: output shape must match the inputs"
    );

    // Hand-computed (block0 + block1 + partial) / 3.
    let expected = Tensor::<TestBackend, 3>::from_floats([[[5.0, 6.0, 7.0, 8.0]]], &device);
    let diff: f32 = (output - expected).abs().max().into_scalar();
    // A NaN diff also fails here, because `NaN < 1e-4` is false.
    assert!(
        diff < 1e-4,
        "Differential test failed: expected mean of sources, diff={diff}"
    );
}
/// Differential check of `RmsNorm` against a hand-computed reference on one
/// 4-element vector: `y_i = x_i / sqrt(mean(x^2) + eps)`.
///
/// NOTE(review): the reference assumes the norm's scale starts at 1 and its epsilon is
/// about 1e-6. Confirm against `RmsNormConfig`. Any epsilon up to roughly 1e-4 is
/// still within the tolerance used here.
#[test]
fn test_differential_rmsnorm_known_input() {
    let device = Default::default();
    let norm = attnres::RmsNormConfig::new(4).init::<TestBackend>(&device);

    // One source of truth for the input: it feeds both the tensor and the reference.
    let input = [1.0_f32, 2.0, 3.0, 4.0];
    let x = Tensor::<TestBackend, 3>::from_floats([[input]], &device);
    let out = norm.forward(x);

    // `reshape([4])` only checks the element count, so check the layout explicitly.
    assert_eq!(out.dims(), [1, 1, 4], "RMSNorm must preserve the input shape");

    let data: Vec<f32> = out
        .reshape([4])
        .into_data()
        .to_vec()
        .expect("RMSNorm output should convert to Vec<f32>");

    // mean(x^2) = (1 + 4 + 9 + 16) / 4 = 7.5. Accumulate in f64 for a tight reference.
    let mean_sq = input
        .iter()
        .map(|&v| f64::from(v) * f64::from(v))
        .sum::<f64>()
        / input.len() as f64;
    let rms = (mean_sq + 1e-6).sqrt() as f32;
    let expected: Vec<f32> = input.iter().map(|&v| v / rms).collect();

    // `zip` stops at the shorter side, so a short output would otherwise pass vacuously.
    assert_eq!(data.len(), expected.len(), "RMSNorm output length mismatch");
    for (i, (got, want)) in data.iter().zip(&expected).enumerate() {
        assert!(
            (got - want).abs() < 1e-4,
            "RMSNorm element {i}: got {got}, want {want}"
        );
    }
}
/// Differential check: when every depth source (both blocks and the partial) is the
/// same vector, the output must equal that vector.
///
/// NOTE(review): this assumes the op mixes its sources with weights that sum to 1
/// (a softmax over depth, as the test name says). Any such convex combination of
/// identical vectors is the vector itself, whatever the weights are.
#[test]
fn test_differential_softmax_over_depth() {
    let device = Default::default();
    let config = AttnResConfig::new(4, 4, 2);
    let op = config.init_op::<TestBackend>(&device);

    let val = Tensor::<TestBackend, 3>::from_floats([[[1.0, 1.0, 1.0, 1.0]]], &device);
    let output = op.forward(&[val.clone(), val.clone()], &val);

    // Without this check, a size-1 axis in `output` would broadcast in the
    // subtraction below and could hide a wrong output shape.
    assert_eq!(
        output.dims(),
        val.dims(),
        "Identical sources should produce an output with the input's shape"
    );

    let diff: f32 = (output - val).abs().max().into_scalar();
    assert!(
        diff < 1e-5,
        "Identical sources should produce identical output, diff={diff}"
    );
}