use yscv_tensor::Tensor;
use crate::Graph;
/// Transposed convolution (stride 1, no bias) of a 2x2 single-channel input by a 2x2 kernel.
///
/// Under `loss = sum(out)` every input pixel meets each of the four kernel taps exactly once,
/// so its gradient is the kernel sum `1 + 0.5 + 0.5 + 0.25 = 2.25`; symmetrically every tap
/// meets each input pixel once, so its gradient is the input sum `1 + 2 + 3 + 4 = 10`.
#[test]
fn backward_conv_transpose2d_nhwc_computes_weight_and_input_grads() {
    let mut tape = Graph::new();
    let x = tape.variable(Tensor::from_vec(vec![1, 2, 2, 1], vec![1.0, 2.0, 3.0, 4.0]).unwrap());
    let kernel =
        tape.variable(Tensor::from_vec(vec![2, 2, 1, 1], vec![1.0, 0.5, 0.5, 0.25]).unwrap());
    let y = tape.conv_transpose2d_nhwc(x, kernel, None, 1, 1).unwrap();
    assert_eq!(tape.value(y).unwrap().shape(), &[1, 3, 3, 1]);

    let total = tape.sum(y).unwrap();
    tape.backward(total).unwrap();

    let d_kernel = tape.grad(kernel).unwrap().unwrap();
    assert_eq!(d_kernel.shape(), &[2, 2, 1, 1]);
    let d_x = tape.grad(x).unwrap().unwrap();
    assert_eq!(d_x.shape(), &[1, 2, 2, 1]);

    for g in d_x.data().iter().copied() {
        assert!(
            (g - 2.25).abs() < 1e-4,
            "input grad mismatch: got {g}, expected 2.25"
        );
    }
    for g in d_kernel.data().iter().copied() {
        assert!(
            (g - 10.0).abs() < 1e-4,
            "weight grad mismatch: got {g}, expected 10.0"
        );
    }
}
/// Two-channel input reduced to one output channel by a transposed convolution with bias.
///
/// The output is 3x3 with a single channel and the scalar bias is broadcast over all nine
/// output cells, so under `loss = sum(out)` the bias gradient is 9.
#[test]
fn backward_conv_transpose2d_nhwc_with_bias() {
    let mut tape = Graph::new();
    let x = tape.variable(Tensor::filled(vec![1, 2, 2, 2], 1.0).unwrap());
    let kernel = tape.variable(Tensor::filled(vec![2, 2, 1, 2], 0.5).unwrap());
    let shift = tape.variable(Tensor::from_vec(vec![1], vec![0.1]).unwrap());

    let y = tape.conv_transpose2d_nhwc(x, kernel, Some(shift), 1, 1).unwrap();
    assert_eq!(tape.value(y).unwrap().shape(), &[1, 3, 3, 1]);

    let total = tape.sum(y).unwrap();
    tape.backward(total).unwrap();

    let d_shift = tape.grad(shift).unwrap().unwrap();
    assert_eq!(d_shift.shape(), &[1]);
    let d_shift0 = d_shift.data()[0];
    assert!((d_shift0 - 9.0).abs() < 1e-4);
}
/// Stride-2 transposed convolution of a 2x2 input by an all-ones 2x2 kernel.
///
/// The output is 4x4, i.e. four non-overlapping 2x2 tiles; each input pixel feeds the four
/// cells of its tile with weight 1, so under `loss = sum(out)` its gradient is 4.
#[test]
fn backward_conv_transpose2d_nhwc_stride2() {
    let mut tape = Graph::new();
    let pixels = vec![1.0, 2.0, 3.0, 4.0];
    let x = tape.variable(Tensor::from_vec(vec![1, 2, 2, 1], pixels).unwrap());
    let ones_kernel = tape.variable(Tensor::filled(vec![2, 2, 1, 1], 1.0).unwrap());

    let y = tape.conv_transpose2d_nhwc(x, ones_kernel, None, 2, 2).unwrap();
    assert_eq!(tape.value(y).unwrap().shape(), &[1, 4, 4, 1]);

    let total = tape.sum(y).unwrap();
    tape.backward(total).unwrap();

    let d_x = tape.grad(x).unwrap().unwrap();
    assert_eq!(d_x.shape(), &[1, 2, 2, 1]);
    for g in d_x.data().iter().copied() {
        assert!((g - 4.0).abs() < 1e-4, "stride2 input grad: got {g}");
    }
}
/// Adaptive average pooling 4x4 -> 2x2 on a single-channel ramp 1..=16.
///
/// Each output averages four inputs, so under `loss = sum(out)` every input receives
/// gradient 1/4 regardless of its value.
#[test]
fn backward_adaptive_avg_pool2d_nhwc_distributes_uniformly() {
    let mut tape = Graph::new();
    let ramp = (1u8..=16).map(f32::from).collect::<Vec<f32>>();
    let x = tape.variable(Tensor::from_vec(vec![1, 4, 4, 1], ramp).unwrap());

    let pooled = tape.adaptive_avg_pool2d_nhwc(x, 2, 2).unwrap();
    assert_eq!(tape.value(pooled).unwrap().shape(), &[1, 2, 2, 1]);

    let total = tape.sum(pooled).unwrap();
    tape.backward(total).unwrap();

    let d_x = tape.grad(x).unwrap().unwrap();
    assert_eq!(d_x.shape(), &[1, 4, 4, 1]);
    for g in d_x.data().iter().copied() {
        assert!(
            (g - 0.25).abs() < 1e-6,
            "adaptive avg pool grad: got {g}, expected 0.25"
        );
    }
}
/// Adaptive average pooling 3x3 -> 1x1 (a global average) over two channels.
///
/// Each of the nine pixels of a channel contributes 1/9 to that channel's single output, so
/// under `loss = sum(out)` every input element receives gradient 1/9.
#[test]
fn backward_adaptive_avg_pool2d_to_1x1() {
    let mut graph = Graph::new();
    let input = graph.variable(Tensor::filled(vec![1, 3, 3, 2], 1.0).unwrap());
    let out = graph.adaptive_avg_pool2d_nhwc(input, 1, 1).unwrap();
    assert_eq!(graph.value(out).unwrap().shape(), &[1, 1, 1, 2]);
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();
    let i_grad = graph.grad(input).unwrap().unwrap();
    // Pin the shape first: without it the per-element loop below would pass vacuously on an
    // empty or truncated gradient.
    assert_eq!(i_grad.shape(), &[1, 3, 3, 2]);
    for &g in i_grad.data() {
        assert!(
            (g - 1.0 / 9.0).abs() < 1e-6,
            "global avg pool grad: got {g}"
        );
    }
}
/// Adaptive max pooling 4x4 -> 2x2 over the single-channel ramp 1..=16 (row-major).
///
/// Values strictly increase along rows and columns, so each 2x2 window's maximum is its
/// bottom-right element, at flat indices 5, 7, 13 and 15. Under `loss = sum(out)` exactly
/// those positions must receive gradient 1, and every other position must receive 0.
#[test]
fn backward_adaptive_max_pool2d_nhwc_scatters_to_argmax() {
    let mut graph = Graph::new();
    let data: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let input = graph.variable(Tensor::from_vec(vec![1, 4, 4, 1], data).unwrap());
    let out = graph.adaptive_max_pool2d_nhwc(input, 2, 2).unwrap();
    assert_eq!(graph.value(out).unwrap().shape(), &[1, 2, 2, 1]);
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();
    let i_grad = graph.grad(input).unwrap().unwrap();
    assert_eq!(i_grad.shape(), &[1, 4, 4, 1]);

    // Checking exact positions (not just "four elements above 0.5") catches a gradient that is
    // scattered to the wrong element of a window or split across several elements.
    let argmax = [5usize, 7, 13, 15];
    for (idx, &g) in i_grad.data().iter().enumerate() {
        let expected = if argmax.contains(&idx) { 1.0 } else { 0.0 };
        assert!(
            (g - expected).abs() < 1e-6,
            "adaptive max pool grad at {idx}: got {g}, expected {expected}"
        );
    }
}
/// Instance norm over one 2x2 sample with two channels, `gamma = 1`, `beta = 0`.
///
/// With `loss = sum(out)`:
/// * `d loss / d beta[c]` counts the H*W = 4 positions of channel `c`, so it is 4;
/// * `d loss / d gamma[c]` is the sum of channel `c`'s normalised values, which is 0 because
///   normalisation subtracts the per-instance mean;
/// * the input gradient is also 0, since that sum does not depend on the input.
#[test]
fn backward_instance_norm_nhwc_computes_gamma_beta_input_grads() {
    let mut graph = Graph::new();
    let input = graph.variable(
        Tensor::from_vec(
            vec![1, 2, 2, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .unwrap(),
    );
    let gamma = graph.variable(Tensor::from_vec(vec![2], vec![1.0, 1.0]).unwrap());
    let beta = graph.variable(Tensor::from_vec(vec![2], vec![0.0, 0.0]).unwrap());
    let out = graph.instance_norm_nhwc(input, gamma, beta, 1e-5).unwrap();
    assert_eq!(graph.value(out).unwrap().shape(), &[1, 2, 2, 2]);
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();

    let g_grad = graph.grad(gamma).unwrap().unwrap();
    assert_eq!(g_grad.shape(), &[2]);
    for &g in g_grad.data() {
        assert!(
            g.abs() < 1e-4,
            "instance norm gamma grad should be near 0, got {g}"
        );
    }

    let b_grad = graph.grad(beta).unwrap().unwrap();
    assert_eq!(b_grad.shape(), &[2]);
    for &g in b_grad.data() {
        assert!(
            (g - 4.0).abs() < 1e-4,
            "instance norm beta grad: got {g}, expected 4"
        );
    }

    let i_grad = graph.grad(input).unwrap().unwrap();
    assert_eq!(i_grad.shape(), &[1, 2, 2, 2]);
    for &g in i_grad.data() {
        assert!(
            g.abs() < 1e-3,
            "instance norm input grad should be near 0, got {g}"
        );
    }
}
/// Instance norm over two samples of a 2x2 single-channel image, `gamma = 2`, `beta = 0`.
///
/// Under `loss = sum(out)` the beta gradient counts all N*H*W = 8 positions. Normalised
/// values sum to zero within each sample, so the gamma gradient (their total) and the input
/// gradient both vanish.
#[test]
fn backward_instance_norm_nhwc_multi_batch() {
    let mut graph = Graph::new();
    let input = graph.variable(
        Tensor::from_vec(
            vec![2, 2, 2, 1],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .unwrap(),
    );
    let gamma = graph.variable(Tensor::from_vec(vec![1], vec![2.0]).unwrap());
    let beta = graph.variable(Tensor::from_vec(vec![1], vec![0.0]).unwrap());
    let out = graph.instance_norm_nhwc(input, gamma, beta, 1e-5).unwrap();
    assert_eq!(graph.value(out).unwrap().shape(), &[2, 2, 2, 1]);
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();

    let g_grad = graph.grad(gamma).unwrap().unwrap();
    assert_eq!(g_grad.shape(), &[1]);
    let g0 = g_grad.data()[0];
    assert!(
        g0.abs() < 1e-4,
        "instance norm gamma grad should be near 0, got {g0}"
    );

    let b_grad = graph.grad(beta).unwrap().unwrap();
    assert_eq!(b_grad.shape(), &[1]);
    let b0 = b_grad.data()[0];
    assert!(
        (b0 - 8.0).abs() < 1e-4,
        "instance norm beta grad: got {b0}, expected 8"
    );

    let i_grad = graph.grad(input).unwrap().unwrap();
    assert_eq!(i_grad.shape(), &[2, 2, 2, 1]);
    for &g in i_grad.data() {
        assert!(
            g.abs() < 1e-3,
            "instance norm input grad should be near 0, got {g}"
        );
    }
}
/// PReLU with a single shared slope of 0.1.
///
/// The input gradient is 1 for positive inputs and the slope for negative ones; the slope
/// gradient is the sum of the negative inputs, `-2 + -4 = -6`.
#[test]
fn backward_prelu_scalar_alpha() {
    let mut tape = Graph::new();
    let x = tape.variable(Tensor::from_vec(vec![1, 4], vec![1.0, -2.0, 3.0, -4.0]).unwrap());
    let slope = tape.variable(Tensor::from_vec(vec![1], vec![0.1]).unwrap());
    let y = tape.prelu(x, slope).unwrap();
    let total = tape.sum(y).unwrap();
    tape.backward(total).unwrap();

    let d_x = tape.grad(x).unwrap().unwrap();
    let expected_dx: [f32; 4] = [1.0, 0.1, 1.0, 0.1];
    for (idx, want) in expected_dx.iter().copied().enumerate() {
        assert!((d_x.data()[idx] - want).abs() < 1e-6);
    }

    let d_slope = tape.grad(slope).unwrap().unwrap();
    assert!((d_slope.data()[0] - (-6.0)).abs() < 1e-6);
}
/// PReLU with one slope per feature.
///
/// The input gradient of feature `c` is 1 where `x[c] > 0` and `alpha[c]` where `x[c] < 0`;
/// the slope gradient of feature `c` is `x[c]` where `x[c] < 0` and 0 otherwise.
#[test]
fn backward_prelu_per_channel_alpha() {
    let mut graph = Graph::new();
    let input = graph.variable(Tensor::from_vec(vec![1, 4], vec![-1.0, 2.0, -3.0, 4.0]).unwrap());
    let alpha = graph.variable(Tensor::from_vec(vec![4], vec![0.1, 0.2, 0.3, 0.4]).unwrap());
    let out = graph.prelu(input, alpha).unwrap();
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();

    let expected_input_grad: [f32; 4] = [0.1, 1.0, 0.3, 1.0];
    let i_grad = graph.grad(input).unwrap().unwrap();
    assert_eq!(i_grad.shape(), &[1, 4]);
    for (c, (&got, &want)) in i_grad.data().iter().zip(&expected_input_grad).enumerate() {
        assert!(
            (got - want).abs() < 1e-6,
            "prelu input grad [{c}]: got {got}, expected {want}"
        );
    }

    let expected_alpha_grad: [f32; 4] = [-1.0, 0.0, -3.0, 0.0];
    let a_grad = graph.grad(alpha).unwrap().unwrap();
    assert_eq!(a_grad.shape(), &[4]);
    for (c, (&got, &want)) in a_grad.data().iter().zip(&expected_alpha_grad).enumerate() {
        assert!(
            (got - want).abs() < 1e-6,
            "prelu alpha grad [{c}]: got {got}, expected {want}"
        );
    }
}
/// Separable conv built as a depthwise 2x2 conv followed by a pointwise 1x1 conv.
///
/// The test uses an all-ones input, depthwise weights of 0.5 and pointwise weights of 1,
/// with `loss = sum(out)`. Every depthwise output is `4 * 0.5 = 2`. This gives:
/// * each pointwise weight gradient is the sum of its channel's four depthwise outputs, 8;
/// * each depthwise weight sees four all-ones pixels with upstream gradient 1, giving 4;
/// * each input pixel's gradient is `0.5` times the number of 2x2 windows covering it
///   (1 at corners, 2 on edges, 4 at the centre).
#[test]
fn backward_separable_conv2d_produces_grads_through_composition() {
    let mut graph = Graph::new();
    let input = graph.variable(Tensor::filled(vec![1, 3, 3, 2], 1.0).unwrap());
    let dw_weight = graph.variable(Tensor::filled(vec![2, 2, 2, 1], 0.5).unwrap());
    let pw_weight = graph.variable(Tensor::filled(vec![1, 1, 2, 1], 1.0).unwrap());
    let dw_out = graph
        .depthwise_conv2d_nhwc(input, dw_weight, None, 1, 1)
        .unwrap();
    assert_eq!(graph.value(dw_out).unwrap().shape(), &[1, 2, 2, 2]);
    let out = graph.conv2d_nhwc(dw_out, pw_weight, None, 1, 1).unwrap();
    assert_eq!(graph.value(out).unwrap().shape(), &[1, 2, 2, 1]);
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();

    let pw_grad = graph
        .grad(pw_weight)
        .unwrap()
        .expect("pointwise weight grad missing");
    assert_eq!(pw_grad.shape(), &[1, 1, 2, 1]);
    for &g in pw_grad.data() {
        assert!(
            (g - 8.0).abs() < 1e-4,
            "pointwise weight grad: got {g}, expected 8"
        );
    }

    let dw_grad = graph
        .grad(dw_weight)
        .unwrap()
        .expect("depthwise weight grad missing");
    assert_eq!(dw_grad.shape(), &[2, 2, 2, 1]);
    for &g in dw_grad.data() {
        assert!(
            (g - 4.0).abs() < 1e-4,
            "depthwise weight grad: got {g}, expected 4"
        );
    }

    let i_grad = graph.grad(input).unwrap().expect("input grad missing");
    assert_eq!(i_grad.shape(), &[1, 3, 3, 2]);
    // Number of valid 2x2 windows covering row/column 0, 1 and 2 of the 3x3 grid.
    let coverage = [1.0_f32, 2.0, 1.0];
    for (idx, &g) in i_grad.data().iter().enumerate() {
        // Assumes row-major NHWC storage with 2 channels: idx = (h * 3 + w) * 2 + c.
        let (h, w) = (idx / 6, (idx / 2) % 3);
        let expected = 0.5 * coverage[h] * coverage[w];
        assert!(
            (g - expected).abs() < 1e-4,
            "separable conv input grad at (h={h}, w={w}): got {g}, expected {expected}"
        );
    }
}
/// Central-difference check of `conv_transpose2d_nhwc` gradients (stride 1, no bias).
///
/// The test compares the analytic gradients of `loss = sum(out)` against
/// `(L(p + eps) - L(p - eps)) / (2 * eps)`, for every weight element and every input
/// element. Each perturbed loss is evaluated on a fresh graph.
#[test]
fn backward_conv_transpose2d_numerical_gradient_check() {
    const EPS: f32 = 1e-3;
    // Loose tolerance: f32 rounding in the two losses is amplified by 1 / (2 * EPS).
    const TOL: f32 = 1e-2;
    let input_data = vec![1.0_f32, 2.0, 3.0, 4.0];
    let weight_data = vec![1.0_f32, 0.5, 0.5, 0.25];

    // Forward-only loss on a fresh graph, used for the finite differences.
    let forward_loss = |inp: &[f32], wt: &[f32]| {
        let mut g = Graph::new();
        let x = g.variable(Tensor::from_vec(vec![1, 2, 2, 1], inp.to_vec()).unwrap());
        let k = g.variable(Tensor::from_vec(vec![2, 2, 1, 1], wt.to_vec()).unwrap());
        let out = g.conv_transpose2d_nhwc(x, k, None, 1, 1).unwrap();
        // Bound to a local so the borrow of `g` ends before `g` is dropped.
        let loss = g.value(out).unwrap().sum();
        loss
    };
    // Copy of `base` with `delta` added to element `idx`.
    let perturbed = |base: &[f32], idx: usize, delta: f32| {
        let mut v = base.to_vec();
        v[idx] += delta;
        v
    };

    let mut graph = Graph::new();
    let input = graph.variable(Tensor::from_vec(vec![1, 2, 2, 1], input_data.clone()).unwrap());
    let weight = graph.variable(Tensor::from_vec(vec![2, 2, 1, 1], weight_data.clone()).unwrap());
    let out = graph
        .conv_transpose2d_nhwc(input, weight, None, 1, 1)
        .unwrap();
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();
    let analytic_grad = graph.grad(weight).unwrap().unwrap().data().to_vec();
    let analytic_input_grad = graph.grad(input).unwrap().unwrap().data().to_vec();
    assert_eq!(analytic_grad.len(), weight_data.len());
    assert_eq!(analytic_input_grad.len(), input_data.len());

    for (w_idx, &analytic) in analytic_grad.iter().enumerate() {
        let loss_p = forward_loss(&input_data, &perturbed(&weight_data, w_idx, EPS));
        let loss_m = forward_loss(&input_data, &perturbed(&weight_data, w_idx, -EPS));
        let numerical = (loss_p - loss_m) / (2.0 * EPS);
        assert!(
            (analytic - numerical).abs() < TOL,
            "conv_transpose2d weight grad mismatch at {w_idx}: analytic={analytic}, numerical={numerical}"
        );
    }

    for (i_idx, &analytic) in analytic_input_grad.iter().enumerate() {
        let loss_p = forward_loss(&perturbed(&input_data, i_idx, EPS), &weight_data);
        let loss_m = forward_loss(&perturbed(&input_data, i_idx, -EPS), &weight_data);
        let numerical = (loss_p - loss_m) / (2.0 * EPS);
        assert!(
            (analytic - numerical).abs() < TOL,
            "conv_transpose2d input grad mismatch at {i_idx}: analytic={analytic}, numerical={numerical}"
        );
    }
}
/// Central-difference check of `prelu` gradients with a single shared slope.
///
/// The test compares the analytic gradients of `loss = sum(out)` against
/// `(L(p + eps) - L(p - eps)) / (2 * eps)`, for `alpha` and for every input element.
/// Every input is at least 1 away from the kink at 0, so a perturbation of `eps` never
/// crosses it.
#[test]
fn backward_prelu_numerical_gradient_check() {
    const EPS: f32 = 1e-3;
    const TOL: f32 = 1e-2;
    let input_data = vec![1.0_f32, -2.0, 3.0, -4.0];
    let alpha_val = 0.1f32;

    // Forward-only loss on a fresh graph, used for the finite differences.
    let forward_loss = |inp: &[f32], slope: f32| {
        let mut g = Graph::new();
        let x = g.variable(Tensor::from_vec(vec![1, 4], inp.to_vec()).unwrap());
        let a = g.variable(Tensor::from_vec(vec![1], vec![slope]).unwrap());
        let out = g.prelu(x, a).unwrap();
        // Bound to a local so the borrow of `g` ends before `g` is dropped.
        let loss = g.value(out).unwrap().sum();
        loss
    };

    let mut graph = Graph::new();
    let input = graph.variable(Tensor::from_vec(vec![1, 4], input_data.clone()).unwrap());
    let alpha = graph.variable(Tensor::from_vec(vec![1], vec![alpha_val]).unwrap());
    let out = graph.prelu(input, alpha).unwrap();
    let loss = graph.sum(out).unwrap();
    graph.backward(loss).unwrap();
    let analytic_alpha_grad = graph.grad(alpha).unwrap().unwrap().data()[0];
    let analytic_input_grad = graph.grad(input).unwrap().unwrap().data().to_vec();
    assert_eq!(analytic_input_grad.len(), input_data.len());

    let numerical = (forward_loss(&input_data, alpha_val + EPS)
        - forward_loss(&input_data, alpha_val - EPS))
        / (2.0 * EPS);
    assert!(
        (analytic_alpha_grad - numerical).abs() < TOL,
        "prelu alpha grad: analytic={analytic_alpha_grad}, numerical={numerical}"
    );

    for (i_idx, &analytic) in analytic_input_grad.iter().enumerate() {
        let mut plus = input_data.clone();
        plus[i_idx] += EPS;
        let mut minus = input_data.clone();
        minus[i_idx] -= EPS;
        let numerical =
            (forward_loss(&plus, alpha_val) - forward_loss(&minus, alpha_val)) / (2.0 * EPS);
        assert!(
            (analytic - numerical).abs() < TOL,
            "prelu input grad mismatch at {i_idx}: analytic={analytic}, numerical={numerical}"
        );
    }
}