use ndarray::{Array4, array};
use crate::Tensor;
use crate::nn::{Conv1d, LSTMCell, Layer, Reduction, ResizeMode};
use crate::test::helpers::RealizeTestExt;
/// Returns the dimensions of `tensor`'s underlying UOp shape as plain `usize` values.
///
/// Panics if the shape query fails, yields no shape, or any dimension is not a constant.
fn get_shape(tensor: &Tensor) -> Vec<usize> {
    let uop = tensor.uop();
    let shape = uop.shape().unwrap().unwrap();
    let mut dims = Vec::new();
    for dim in shape.iter() {
        dims.push(dim.as_const().unwrap());
    }
    dims
}
/// `pool` on a 4x4 input with a 2x2 window and unit second/third arguments must
/// produce a `[1, 1, 3, 3, 2, 2]` shape.
#[test]
fn test_pool_2d_basic() {
    let input = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 4, 4)));
    let windows = input.pool(&[2, 2], &[1, 1], &[1, 1]).unwrap();
    let dims: Vec<usize> = windows
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    assert_eq!(dims, [1, 1, 3, 3, 2, 2]);
}
/// `pool` on a 6x6 input with a 3x3 window and `[2, 2]` as the second argument
/// (presumably the stride, going by the test name) must produce `[1, 1, 2, 2, 3, 3]`.
#[test]
fn test_pool_2d_stride() {
    let input = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 6, 6)));
    let windows = input.pool(&[3, 3], &[2, 2], &[1, 1]).unwrap();
    let dims: Vec<usize> = windows
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    assert_eq!(dims, [1, 1, 2, 2, 3, 3]);
}
/// `pool` on a 7x7 input with a 3x3 window and `[2, 2]` as the third argument
/// (presumably the dilation, going by the test name) must produce `[1, 1, 3, 3, 3, 3]`.
#[test]
fn test_pool_2d_dilation() {
    let input = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 7, 7)));
    let windows = input.pool(&[3, 3], &[1, 1], &[2, 2]).unwrap();
    let dims: Vec<usize> = windows
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    assert_eq!(dims, [1, 1, 3, 3, 3, 3]);
}
/// Average pooling a 7x7 input (kernel 2, stride 3) with `ceil_mode(true)` must
/// report a 3x3 spatial output.
#[test]
fn test_avg_pool2d_ceil_mode_shape() {
    let input = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 7, 7)));
    let pooled = input
        .avg_pool2d()
        .kernel_size(&[2, 2])
        .stride(&[3, 3])
        .ceil_mode(true)
        .call()
        .unwrap();
    let dims: Vec<usize> = pooled
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    assert_eq!(dims, [1, 1, 3, 3]);
}
/// Max pooling a 7x7 input (kernel 2, stride 3) with `ceil_mode(true)` must
/// report a 3x3 spatial output.
#[test]
fn test_max_pool2d_ceil_mode_shape() {
    let input = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 7, 7)));
    let pooled = input
        .max_pool2d()
        .kernel_size(&[2, 2])
        .stride(&[3, 3])
        .ceil_mode(true)
        .call()
        .unwrap();
    let dims: Vec<usize> = pooled
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    assert_eq!(dims, [1, 1, 3, 3]);
}
/// With a 1x3 input, kernel `[1, 2]`, stride `[1, 3]` and `ceil_mode(true)`, the
/// last output dimension must be 1.
#[test]
fn test_avg_pool2d_ceil_mode_large_stride() {
    let input = Tensor::from_ndarray(&array![[[[1.0f32, 2.0, 3.0]]]]);
    let pooled = input
        .avg_pool2d()
        .kernel_size(&[1, 2])
        .stride(&[1, 3])
        .ceil_mode(true)
        .call()
        .unwrap();
    let dims: Vec<usize> = pooled
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    let width = dims[3];
    assert_eq!(width, 1);
}
/// `linspace` with zero steps must yield a tensor of shape `[0]`.
#[test]
fn test_linspace_zero() {
    let empty = Tensor::linspace(0.0, 1.0, 0, svod_dtype::DType::Float32).unwrap();
    let dims = get_shape(&empty);
    assert_eq!(dims, vec![0]);
}
/// Asserts that `result` is an error whose display text contains `substr`.
///
/// Panics with "expected an error" when `result` is `Ok`, and with the full error
/// text when the substring is missing.
///
/// `#[track_caller]` makes those panics report the calling test's line rather
/// than a line inside this helper, so a failure points at the offending check.
#[track_caller]
fn expect_err_msg<T>(result: crate::Result<T>, substr: &str) {
    let msg = result.err().expect("expected an error").to_string();
    assert!(msg.contains(substr), "error should contain '{substr}', got: {msg}");
}
/// `depth_to_space` on a rank-3 tensor must fail with an "exactly 4D" error.
#[test]
fn test_depth_to_space_rejects_3d() {
    let rank3 = Tensor::from_slice([0.0f32; 24]).try_reshape([2, 3, 4]).unwrap();
    let outcome = rank3.depth_to_space().blocksize(2).call();
    expect_err_msg(outcome, "exactly 4D");
}
/// `depth_to_space` with blocksize 2 on a 3-channel input must fail with a
/// "divisible" error.
#[test]
fn test_depth_to_space_rejects_indivisible_channels() {
    let three_channels = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 3, 2, 2)));
    let outcome = three_channels.depth_to_space().blocksize(2).call();
    expect_err_msg(outcome, "divisible");
}
/// `space_to_depth(2)` on a 3x3 spatial input must fail with a "divisible" error.
#[test]
fn test_space_to_depth_rejects_indivisible_spatial() {
    let odd_spatial = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 3, 3)));
    let outcome = odd_spatial.space_to_depth(2);
    expect_err_msg(outcome, "divisible");
}
/// `dropout` must reject `p` values of 1.5 and -0.1.
///
/// NOTE(review): the expected substring is just "p", which almost any message
/// would satisfy — consider matching a more specific fragment once the exact
/// error text is confirmed.
#[test]
fn test_dropout_rejects_invalid_p() {
    let input = Tensor::from_slice([1.0f32, 2.0, 3.0]);
    let too_large = input.dropout().p(1.5).call();
    expect_err_msg(too_large, "p");
    let negative = input.dropout().p(-0.1).call();
    expect_err_msg(negative, "p");
}
/// `lp_pool` must reject `p = 0`.
///
/// NOTE(review): the expected substring is just "p", which is a very loose match —
/// confirm the real error text and tighten if possible.
#[test]
fn test_lp_pool_rejects_p_zero() {
    let input = Tensor::from_ndarray(&Array4::<f32>::zeros((1, 1, 4, 4)));
    let outcome = input.lp_pool().kernel_shape(&[2, 2]).p(0).call();
    expect_err_msg(outcome, "p");
}
/// `group_norm` on a rank-1 tensor must fail with an "at least 2D" error.
#[test]
fn test_group_norm_rejects_1d() {
    let bias = Tensor::from_slice([0.0f32]);
    let scale = Tensor::from_slice([1.0f32]);
    let flat = Tensor::from_slice([1.0f32, 2.0, 3.0, 4.0]);
    let outcome = flat.group_norm().scale(&scale).bias(&bias).num_groups(1).call();
    expect_err_msg(outcome, "at least 2D");
}
/// `lrn` on a rank-3 tensor must fail with an "exactly 4D" error.
#[test]
fn test_lrn_rejects_3d() {
    let rank3 = Tensor::from_slice([0.0f32; 24]).try_reshape([2, 3, 4]).unwrap();
    let outcome = rank3.lrn().size(5).call();
    expect_err_msg(outcome, "exactly 4D");
}
crate::codegen_tests! {
// Padding one element on each side with -inf must leave the original values in
// the middle and negative infinity at both ends.
fn test_pad_value_neg_inf(config) {
    let source = Tensor::from_slice([1.0f32, 2.0, 3.0]);
    let mut padded = source.try_pad_value(&[(1, 1)], f64::NEG_INFINITY).unwrap();
    let values = padded.realize_with_and(&config).as_vec::<f32>().unwrap();
    assert_eq!(values.len(), 5);
    let is_neg_inf = |v: f32| v.is_infinite() && v < 0.0;
    assert!(is_neg_inf(values[0]));
    for (offset, want) in [1.0f32, 2.0, 3.0].into_iter().enumerate() {
        assert_eq!(values[offset + 1], want);
    }
    assert!(is_neg_inf(values[4]));
}
// Padding one element on each side with 0.0 must yield [0, 1, 2, 3, 0].
// Every element is checked, including the middle one, so a corrupted interior
// value cannot slip through.
fn test_pad_value_zero_delegates(config) {
    let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
    let mut padded = x.try_pad_value(&[(1, 1)], 0.0).unwrap();
    let result = padded.realize_with_and(&config).as_vec::<f32>().unwrap();
    assert_eq!(result.len(), 5);
    assert_eq!(result[0], 0.0);
    assert_eq!(result[1], 1.0);
    assert_eq!(result[2], 2.0);
    assert_eq!(result[3], 3.0);
    assert_eq!(result[4], 0.0);
}
// A 1x1 convolution with weight 2.0 must double every input element.
fn test_conv2d_1x1(config) {
    let input: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let x = Tensor::from_ndarray(&Array4::from_shape_vec((1, 1, 3, 3), input).unwrap());
    let w = Tensor::from_ndarray(&array![[[[2.0f32]]]]);
    let mut out = x.conv2d().weight(&w).call().unwrap().contiguous();
    out.realize_with(&config).unwrap();
    let view = out.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 3, 3]);
    for (got, n) in view.iter().zip(1..=9) {
        let exp = n as f32 * 2.0;
        assert!((got - exp).abs() < 1e-5, "got {got}, expected {exp}");
    }
}
// A 3x3 all-ones kernel over a 4x4 ramp (0..16) must give the four window sums.
fn test_conv2d_3x3(config) {
    let ramp: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let x = Tensor::from_ndarray(&Array4::from_shape_vec((1, 1, 4, 4), ramp).unwrap());
    let w = Tensor::from_ndarray(&Array4::<f32>::ones((1, 1, 3, 3)));
    let mut out = x.conv2d().weight(&w).call().unwrap().contiguous();
    out.realize_with(&config).unwrap();
    let view = out.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 2, 2]);
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(view[[0, 0, 0, 0]], 45.0));
    assert!(close(view[[0, 0, 0, 1]], 54.0));
    assert!(close(view[[0, 0, 1, 0]], 81.0));
    assert!(close(view[[0, 0, 1, 1]], 90.0));
}
// A 2x2 all-ones kernel with stride 2 over a 4x4 ramp (0..16) sums each
// non-overlapping 2x2 tile. All four tiles are checked so that an error in the
// second output row cannot go unnoticed.
fn test_conv2d_stride(config) {
    let x_data: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let x = Tensor::from_ndarray(&Array4::from_shape_vec((1, 1, 4, 4), x_data).unwrap());
    let w = Tensor::from_ndarray(&Array4::<f32>::ones((1, 1, 2, 2)));
    let result = x.conv2d().weight(&w).stride(&[2, 2]).call().unwrap();
    let mut result = result.contiguous();
    result.realize_with(&config).unwrap();
    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 2, 2]);
    // Tile sums: 0+1+4+5, 2+3+6+7, 8+9+12+13, 10+11+14+15.
    assert!((view[[0, 0, 0, 0]] - 10.0).abs() < 1e-4);
    assert!((view[[0, 0, 0, 1]] - 18.0).abs() < 1e-4);
    assert!((view[[0, 0, 1, 0]] - 42.0).abs() < 1e-4);
    assert!((view[[0, 0, 1, 1]] - 50.0).abs() < 1e-4);
}
// Grouped 1x1 convolution over an all-ones 2-channel input: each channel must be
// scaled by its own weight (2.0 and 3.0). Currently ignored; see the attribute.
#[ignore = "blocked by CONTIGUOUS realization range-leak bug in rangeify pipeline"]
fn test_conv2d_groups(config) {
    let w = Tensor::from_ndarray(&array![[[[2.0f32]]], [[[3.0f32]]]]);
    let x = Tensor::from_ndarray(&Array4::<f32>::ones((1, 2, 3, 3)));
    let mut out = x.conv2d().weight(&w).groups(2).call().unwrap().contiguous();
    out.realize_with(&config).unwrap();
    let view = out.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 2, 3, 3]);
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(view[[0, 0, 0, 0]], 2.0));
    assert!(close(view[[0, 1, 0, 0]], 3.0));
}
// A 1x1 identity convolution plus a bias of 10 over all-ones input must give 11.
fn test_conv2d_bias(config) {
    let b = Tensor::from_slice([10.0f32]);
    let w = Tensor::from_ndarray(&array![[[[1.0f32]]]]);
    let x = Tensor::from_ndarray(&Array4::<f32>::ones((1, 1, 2, 2)));
    let mut out = x.conv2d().weight(&w).bias(&b).call().unwrap().contiguous();
    out.realize_with(&config).unwrap();
    let view = out.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 2, 2]);
    let first = view[[0, 0, 0, 0]];
    assert!((first - 11.0).abs() < 1e-4);
}
// A 3x3 all-ones kernel with padding 1 over a 3x3 all-ones input keeps the 3x3
// shape; the centre sees 9 ones and a corner sees 4.
fn test_conv2d_padding(config) {
    let x = Tensor::from_ndarray(&Array4::<f32>::ones((1, 1, 3, 3)));
    let w = Tensor::from_ndarray(&Array4::<f32>::ones((1, 1, 3, 3)));
    let padded = x.conv2d().weight(&w).padding(&[(1, 1), (1, 1)]).call().unwrap();
    let dims: Vec<usize> = padded
        .shape()
        .unwrap()
        .iter()
        .map(|d| d.as_const().unwrap())
        .collect();
    assert_eq!(dims, vec![1, 1, 3, 3]);
    let mut realized = padded.contiguous();
    realized.realize_with(&config).unwrap();
    let view = realized.array_view::<f32>().unwrap();
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(view[[0, 0, 1, 1]], 9.0));
    assert!(close(view[[0, 0, 0, 0]], 4.0));
}
// 2x2 average pooling with stride 2 over a 4x4 ramp (0..16) must give the mean
// of each non-overlapping tile.
fn test_avg_pool2d(config) {
    let ramp: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let x = Tensor::from_ndarray(&Array4::from_shape_vec((1, 1, 4, 4), ramp).unwrap());
    let mut pooled = x
        .avg_pool2d()
        .kernel_size(&[2, 2])
        .stride(&[2, 2])
        .call()
        .unwrap()
        .contiguous();
    pooled.realize_with(&config).unwrap();
    let view = pooled.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 2, 2]);
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(view[[0, 0, 0, 0]], 2.5));
    assert!(close(view[[0, 0, 0, 1]], 4.5));
    assert!(close(view[[0, 0, 1, 0]], 10.5));
    assert!(close(view[[0, 0, 1, 1]], 12.5));
}
// 2x2 max pooling with stride 2 over a mixed-sign 4x4 input must pick the
// largest value in each non-overlapping tile.
fn test_max_pool2d(config) {
    let grid = Array4::from_shape_vec(
        (1, 1, 4, 4),
        vec![-1.0f32, 2.0, 3.0, -4.0, 5.0, -6.0, 7.0, 8.0, 9.0, 10.0, -11.0, 12.0, 13.0, -14.0, 15.0, 16.0],
    )
    .unwrap();
    let x = Tensor::from_ndarray(&grid);
    let mut pooled = x
        .max_pool2d()
        .kernel_size(&[2, 2])
        .stride(&[2, 2])
        .call()
        .unwrap()
        .contiguous();
    pooled.realize_with(&config).unwrap();
    let view = pooled.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 2, 2]);
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(view[[0, 0, 0, 0]], 5.0));
    assert!(close(view[[0, 0, 0, 1]], 8.0));
    assert!(close(view[[0, 0, 1, 0]], 13.0));
    assert!(close(view[[0, 0, 1, 1]], 16.0));
}
// Max pooling an all -5 input with padding must still return -5 everywhere: if
// the padded cells were filled with 0 instead of -inf, the maximum would be 0.
fn test_max_pool2d_pad(config) {
    let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), -5.0f32));
    let mut result = x.max_pool2d().kernel_size(&[3, 3]).stride(&[1, 1]).padding(&[(1, 1), (1, 1)]).call().unwrap();
    result.realize_with(&config).unwrap();
    let result = result.as_vec::<f32>().unwrap();
    // Kernel 3, stride 1, padding 1 on a 3x3 input keeps a 3x3 output (9 values).
    // Without this check an empty output would make the loop below pass vacuously.
    assert_eq!(result.len(), 9, "max_pool2d with padding should keep the 3x3 spatial size");
    for val in result.iter() {
        assert!((*val - (-5.0)).abs() < 1e-4, "max_pool2d with padding should use -inf fill, got {val}");
    }
}
// `max_pool2d_with_indices` must return both the pooled maxima and the index of
// each maximum; the first output row is checked for values (5, 8) and indices (4, 7).
fn test_max_pool2d_with_indices_basic(config) {
    let grid = Array4::from_shape_vec(
        (1, 1, 4, 4),
        vec![-1.0f32, 2.0, 3.0, -4.0, 5.0, -6.0, 7.0, 8.0, 9.0, 10.0, -11.0, 12.0, 13.0, -14.0, 15.0, 16.0],
    )
    .unwrap();
    let x = Tensor::from_ndarray(&grid);
    let (maxima, positions) = x
        .max_pool2d_with_indices()
        .kernel_size(&[2, 2])
        .stride(&[2, 2])
        .call()
        .unwrap();

    let mut maxima = maxima.contiguous();
    maxima.realize_with(&config).unwrap();
    let max_view = maxima.array_view::<f32>().unwrap();
    assert_eq!(max_view.shape(), &[1, 1, 2, 2]);
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(max_view[[0, 0, 0, 0]], 5.0));
    assert!(close(max_view[[0, 0, 0, 1]], 8.0));

    let mut positions = positions.contiguous();
    positions.realize_with(&config).unwrap();
    let pos_view = positions.array_view::<i32>().unwrap();
    assert_eq!(pos_view.shape(), &[1, 1, 2, 2]);
    assert_eq!(pos_view[[0, 0, 0, 0]], 4);
    assert_eq!(pos_view[[0, 0, 0, 1]], 7);
}
// Layer norm over the last axis: every row of the output must have mean ~0 and
// (population) variance ~1.
fn test_layernorm(config) {
    let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
    let mut normed = x.layernorm(-1, 1e-5).unwrap().contiguous();
    normed.realize_with(&config).unwrap();
    let view = normed.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[2, 4]);
    for row in 0..2 {
        let mut total = 0.0f32;
        for col in 0..4 {
            total += view[[row, col]];
        }
        let mean: f32 = total / 4.0;

        let mut squared = 0.0f32;
        for col in 0..4 {
            let value = view[[row, col]];
            squared += (value - mean) * (value - mean);
        }
        let var: f32 = squared / 4.0;

        assert!(mean.abs() < 1e-4, "mean should be ~0, got {mean}");
        assert!((var - 1.0).abs() < 0.1, "var should be ~1, got {var}");
    }
}
// Layer norm called with axis -2 on a (2, 3, 4) tensor: the mean over the last
// two axes of each batch entry must be ~0.
fn test_layernorm_2d(config) {
    let ramp: Vec<f32> = (0..24).map(|v| v as f32).collect();
    let x = Tensor::from_ndarray(&ndarray::Array3::from_shape_vec((2, 3, 4), ramp).unwrap());
    let mut normed = x.layernorm(-2, 1e-5).unwrap().contiguous();
    normed.realize_with(&config).unwrap();
    let view = normed.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[2, 3, 4]);
    for b in 0..2 {
        // Accumulate in the same h-major, w-minor order as a nested loop would.
        let total = (0..3)
            .flat_map(|h| (0..4).map(move |w| (h, w)))
            .fold(0.0f32, |acc, (h, w)| acc + view[[b, h, w]]);
        let mean = total / 12.0;
        assert!(mean.abs() < 1e-3, "mean should be ~0, got {mean}");
    }
}
// Nearest-mode resize with spatial scale 2 must turn a 2x2 input into 4x4.
// Only the realized shape is checked here.
fn test_resize_nearest_upsample(config) {
    let input = Tensor::from_ndarray(&array![[[[1.0f32, 2.0], [3.0, 4.0]]]]);
    let mut resized = input
        .resize()
        .scales(&[1.0, 1.0, 2.0, 2.0])
        .mode(ResizeMode::Nearest)
        .call()
        .unwrap();
    resized.realize_with(&config).unwrap();
    let dims = get_shape(&resized);
    assert_eq!(dims, vec![1, 1, 4, 4]);
}
// Linear-mode resize with spatial scale 2 must turn a 2x2 input into 4x4.
// Only the realized shape is checked here.
fn test_resize_linear_upsample(config) {
    let input = Tensor::from_ndarray(&array![[[[1.0f32, 2.0], [3.0, 4.0]]]]);
    let mut resized = input
        .resize()
        .scales(&[1.0, 1.0, 2.0, 2.0])
        .mode(ResizeMode::Linear)
        .call()
        .unwrap();
    resized.realize_with(&config).unwrap();
    let dims = get_shape(&resized);
    assert_eq!(dims, vec![1, 1, 4, 4]);
}
// Nearest-mode resize to an explicit size must shrink a 3x3 input to 2x2.
// Only the realized shape is checked here.
fn test_resize_nearest_downsample(config) {
    let ramp: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let input = Tensor::from_ndarray(&Array4::from_shape_vec((1, 1, 3, 3), ramp).unwrap());
    let mut resized = input
        .resize()
        .sizes(&[1, 1, 2, 2])
        .mode(ResizeMode::Nearest)
        .call()
        .unwrap();
    resized.realize_with(&config).unwrap();
    let dims = get_shape(&resized);
    assert_eq!(dims, vec![1, 1, 2, 2]);
}
// `linspace(-1, 1, 5)` must produce five evenly spaced values from -1 to 1.
fn test_linspace_basic(config) {
    let mut t = Tensor::linspace(-1.0, 1.0, 5, svod_dtype::DType::Float32).unwrap();
    assert_eq!(get_shape(&t), vec![5]);
    t.realize_with(&config).unwrap();
    let result = t.as_vec::<f32>().unwrap();
    let expected = [-1.0f32, -0.5, 0.0, 0.5, 1.0];
    // `zip` stops at the shorter side, so check the realized length explicitly;
    // otherwise a truncated or empty result would pass the loop below.
    assert_eq!(result.len(), expected.len(), "realized linspace has the wrong element count");
    for (got, exp) in result.iter().zip(expected.iter()) {
        assert!((got - exp).abs() < 1e-5, "got {got}, expected {exp}");
    }
}
// `linspace` with a single step must be a 1-D tensor holding only the start value.
fn test_linspace_single(config) {
    let mut single = Tensor::linspace(3.0, 7.0, 1, svod_dtype::DType::Float32).unwrap();
    let dims = get_shape(&single);
    assert_eq!(dims, vec![1], "steps=1 must produce 1-D shape [1]");
    single.realize_with(&config).unwrap();
    let values = single.as_vec::<f32>().unwrap();
    assert_eq!(values.len(), 1);
    let only = values[0];
    assert!((only - 3.0).abs() < 1e-5);
}
// Default NLL loss: targets (0, 2) pick -0.5 and -0.8; the expected result is
// 0.65, the mean of 0.5 and 0.8.
fn test_nll_loss_basic(config) {
    let log_probs = Tensor::from_ndarray(&array![
        [-0.5f32, -1.0, -2.0],
        [-0.3, -1.5, -0.8]
    ]);
    let target = Tensor::from_slice([0i64, 2]);
    let mut loss = log_probs.nll_loss().target(&target).call().unwrap();
    let realized = loss.realize_with_and(&config).as_vec::<f32>().unwrap();
    let val = realized[0];
    assert!((val - 0.65).abs() < 1e-4, "got {val}");
}
// With `Reduction::None`, NLL loss must return one value per sample: 0.5 and 0.8.
fn test_nll_loss_none_reduction(config) {
    let log_probs = Tensor::from_ndarray(&array![
        [-0.5f32, -1.0, -2.0],
        [-0.3, -1.5, -0.8]
    ]);
    let target = Tensor::from_slice([0i64, 2]);
    let mut loss = log_probs
        .nll_loss()
        .target(&target)
        .reduction(Reduction::None)
        .call()
        .unwrap();
    let per_sample = loss.realize_with_and(&config).as_vec::<f32>().unwrap();
    assert_eq!(per_sample.len(), 2);
    let close = |got: f32, want: f32| (got - want).abs() < 1e-4;
    assert!(close(per_sample[0], 0.5));
    assert!(close(per_sample[1], 0.8));
}
// Class-weighted NLL loss: the expected 0.68 equals (2*0.5 + 3*0.8) / (2 + 3).
fn test_nll_loss_weighted(config) {
    let log_probs = Tensor::from_ndarray(&array![
        [-0.5f32, -1.0, -2.0],
        [-0.3, -1.5, -0.8]
    ]);
    let target = Tensor::from_slice([0i64, 2]);
    let weight = Tensor::from_slice([2.0f32, 1.0, 3.0]);
    let mut loss = log_probs
        .nll_loss()
        .target(&target)
        .weight(&weight)
        .call()
        .unwrap();
    let realized = loss.realize_with_and(&config).as_vec::<f32>().unwrap();
    let val = realized[0];
    assert!((val - 0.68).abs() < 1e-4, "got {val}");
}
// With `ignore_index(2)` the second sample (target 2) is skipped, leaving only
// the first sample's loss of 0.5.
fn test_nll_loss_ignore_index(config) {
    let log_probs = Tensor::from_ndarray(&array![
        [-0.5f32, -1.0, -2.0],
        [-0.3, -1.5, -0.8]
    ]);
    let target = Tensor::from_slice([0i64, 2]);
    let mut loss = log_probs
        .nll_loss()
        .target(&target)
        .ignore_index(2)
        .call()
        .unwrap();
    let realized = loss.realize_with_and(&config).as_vec::<f32>().unwrap();
    let val = realized[0];
    assert!((val - 0.5).abs() < 1e-4, "got {val}");
}
// Calling `dropout().p(0.5)` without further options must return the input
// unchanged together with an all-true mask.
fn test_dropout_inference(config) {
    let input = Tensor::from_slice([1.0f32, 2.0, 3.0, 4.0]);
    let (mut passed_through, mut keep_mask) = input.dropout().p(0.5).call().unwrap();

    passed_through.realize_with(&config).unwrap();
    let values = passed_through.as_vec::<f32>().unwrap();
    assert_eq!(values, &[1.0, 2.0, 3.0, 4.0]);

    keep_mask.realize_with(&config).unwrap();
    let flags = keep_mask.as_vec::<bool>().unwrap();
    assert!(flags.iter().all(|&kept| kept));
}
// The `Conv1d` module (stride 2, padding (1, 1), with bias) must produce exactly
// the same values as the equivalent explicit `conv2d` builder call.
fn test_conv1d_module_matches_explicit_conv2d(config) {
    let signal: Vec<f32> = (0..8).map(|v| v as f32 * 0.1).collect();
    let kernel: Vec<f32> = (0..18).map(|v| (v as f32 * 0.05).sin()).collect();
    let x = Tensor::from_slice(&signal).try_reshape([1isize, 2, 4]).unwrap();
    let w = Tensor::from_slice(&kernel).try_reshape([3isize, 2, 3]).unwrap();
    let b = Tensor::from_slice([0.1f32, 0.2, 0.3]);

    let module = Conv1d::new(w.clone(), Some(b.clone())).with_stride(2).with_padding((1, 1));
    let mut via_module = module.forward(&x).unwrap().contiguous();
    via_module.realize_with(&config).unwrap();

    let mut via_builder = x
        .conv2d()
        .weight(&w)
        .bias(&b)
        .stride(&[2])
        .padding(&[(1, 1)])
        .call()
        .unwrap()
        .contiguous();
    via_builder.realize_with(&config).unwrap();

    let got = via_module.as_vec::<f32>().unwrap();
    let want = via_builder.as_vec::<f32>().unwrap();
    assert_eq!(got, want);
}
// A bias-free `Conv1d` with default settings must match the plain `conv2d`
// builder call on the same input and weight.
fn test_conv1d_no_bias(config) {
    let w = Tensor::from_slice([0.5f32, 0.25]).try_reshape([1isize, 1, 2]).unwrap();
    let x = Tensor::from_slice([1.0f32, 2.0, 3.0, 4.0]).try_reshape([1isize, 1, 4]).unwrap();

    let module = Conv1d::new(w.clone(), None);
    let mut via_module = module.forward(&x).unwrap().contiguous();
    via_module.realize_with(&config).unwrap();

    let mut via_builder = x.conv2d().weight(&w).call().unwrap().contiguous();
    via_builder.realize_with(&config).unwrap();

    let got = via_module.as_vec::<f32>().unwrap();
    let want = via_builder.as_vec::<f32>().unwrap();
    assert_eq!(got, want);
}
// One `LSTMCell::step` must equal the hand-written LSTM update built from
// `linear`, `split`, `sigmoid` and `tanh` on the same weights and state.
fn test_lstm_cell_step_matches_explicit(config) {
    let in_dim = 3usize;
    let hid = 2usize;
    let gate_rows = 4 * hid;

    // Deterministic weights and biases.
    let w_ih_vals: Vec<f32> = (0..gate_rows * in_dim).map(|i| (i as f32 * 0.1).sin()).collect();
    let w_hh_vals: Vec<f32> = (0..gate_rows * hid).map(|i| (i as f32 * 0.07).cos()).collect();
    let w_ih = Tensor::from_slice(&w_ih_vals).try_reshape([gate_rows as isize, in_dim as isize]).unwrap();
    let w_hh = Tensor::from_slice(&w_hh_vals).try_reshape([gate_rows as isize, hid as isize]).unwrap();
    let b_ih = Tensor::from_slice([0.01f32, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08]);
    let b_hh = Tensor::from_slice([-0.01f32, -0.02, -0.03, -0.04, -0.05, -0.06, -0.07, -0.08]);

    // Single-row input and initial hidden/cell state.
    let x = Tensor::from_slice([0.5f32, -0.25, 0.125]).try_reshape([1isize, in_dim as isize]).unwrap();
    let h0 = Tensor::from_slice([0.1f32, -0.2]).try_reshape([1isize, hid as isize]).unwrap();
    let c0 = Tensor::from_slice([0.3f32, 0.4]).try_reshape([1isize, hid as isize]).unwrap();

    // Module under test.
    let cell = LSTMCell::new(w_ih.clone(), w_hh.clone(), b_ih.clone(), b_hh.clone());
    assert_eq!(cell.hidden_size(), hid);
    let (stepped_h, stepped_c) = cell.step(&x, &h0, &c0).unwrap();
    let mut stepped_h = stepped_h.contiguous();
    let mut stepped_c = stepped_c.contiguous();
    stepped_h.realize_with(&config).unwrap();
    stepped_c.realize_with(&config).unwrap();

    // Reference computation: the summed gate pre-activations are split into four
    // chunks, taken in the order input, forget, cell candidate, output.
    let from_input = x.linear().weight(&w_ih).bias(&b_ih).call().unwrap();
    let from_hidden = h0.linear().weight(&w_hh).bias(&b_hh).call().unwrap();
    let pre_activations = from_input.try_add(&from_hidden).unwrap();
    let chunks = pre_activations.split(&[hid, hid, hid, hid], 1).unwrap();
    let in_gate = chunks[0].sigmoid().unwrap();
    let forget_gate = chunks[1].sigmoid().unwrap();
    let candidate = chunks[2].tanh().unwrap();
    let out_gate = chunks[3].sigmoid().unwrap();
    let kept = forget_gate.try_mul(&c0).unwrap();
    let added = in_gate.try_mul(&candidate).unwrap();
    let ref_c = kept.try_add(&added).unwrap();
    let ref_h = out_gate.try_mul(&ref_c.tanh().unwrap()).unwrap();
    let mut ref_h = ref_h.contiguous();
    let mut ref_c = ref_c.contiguous();
    ref_h.realize_with(&config).unwrap();
    ref_c.realize_with(&config).unwrap();

    assert_eq!(stepped_h.as_vec::<f32>().unwrap(), ref_h.as_vec::<f32>().unwrap());
    assert_eq!(stepped_c.as_vec::<f32>().unwrap(), ref_c.as_vec::<f32>().unwrap());
}
}
/// Builds two DenseNet-style layers (batchnorm + ReLU, 1x1 conv, batchnorm + ReLU,
/// padded 3x3 conv, channel concat) without realizing them, runs the graph through
/// schedule normalization, rangeify and kernel splitting, and expects exactly six
/// `Op::Call` nodes in the resulting kernel graph.
#[test]
fn test_densenet_two_layer_kernel_count() {
    use ndarray::Array4;

    // Per-channel batchnorm parameters: zero mean, unit variance, unit scale, zero
    // bias, with invstd = 1 / sqrt(var + 1e-5).
    let bn_params = |channels: usize| {
        let mean = Tensor::from_slice(vec![0.0f32; channels]);
        let var = Tensor::from_slice(vec![1.0f32; channels]);
        let gamma = Tensor::from_slice(vec![1.0f32; channels]);
        let beta = Tensor::from_slice(vec![0.0f32; channels]);
        let eps = Tensor::const_(1e-5f64, svod_dtype::DType::Float32);
        let invstd = (&var + eps).try_sqrt().unwrap().reciprocal().unwrap();
        (mean, invstd, gamma, beta)
    };

    // First dense layer: 64 -> 128 (1x1) -> 32 (3x3), concatenated onto the input.
    let x0 = Tensor::from_ndarray(&Array4::<f32>::ones((1, 64, 14, 14)));
    let (m1, inv1, g1, b1) = bn_params(64);
    let bn1 = x0.batchnorm().mean(&m1).invstd(&inv1).scale(&g1).bias(&b1).call().unwrap().relu().unwrap();
    let w1x1 = Tensor::from_ndarray(&Array4::<f32>::ones((128, 64, 1, 1)));
    let conv1x1 = bn1.conv2d().weight(&w1x1).call().unwrap();
    let (m2, inv2, g2, b2) = bn_params(128);
    let bn2 = conv1x1.batchnorm().mean(&m2).invstd(&inv2).scale(&g2).bias(&b2).call().unwrap().relu().unwrap();
    let w3x3 = Tensor::from_ndarray(&Array4::<f32>::ones((32, 128, 3, 3)));
    let conv3x3 = bn2.conv2d().weight(&w3x3).padding(&[(1, 1), (1, 1)]).call().unwrap();
    let cat1 = Tensor::cat(&[&x0, &conv3x3], 1).unwrap();

    // Second dense layer: 96 -> 128 (1x1) -> 32 (3x3), concatenated onto `cat1`.
    let (m3, inv3, g3, b3) = bn_params(96);
    let bn3 = cat1.batchnorm().mean(&m3).invstd(&inv3).scale(&g3).bias(&b3).call().unwrap().relu().unwrap();
    let w1x1_2 = Tensor::from_ndarray(&Array4::<f32>::ones((128, 96, 1, 1)));
    let conv1x1_2 = bn3.conv2d().weight(&w1x1_2).call().unwrap();
    let (m4, inv4, g4, b4) = bn_params(128);
    let bn4 = conv1x1_2.batchnorm().mean(&m4).invstd(&inv4).scale(&g4).bias(&b4).call().unwrap().relu().unwrap();
    let w3x3_2 = Tensor::from_ndarray(&Array4::<f32>::ones((32, 128, 3, 3)));
    let conv3x3_2 = bn4.conv2d().weight(&w3x3_2).padding(&[(1, 1), (1, 1)]).call().unwrap();
    let result = Tensor::cat(&[&cat1, &conv3x3_2], 1).unwrap();

    // Lower the lazy graph to a kernel graph and count the kernel call nodes.
    let uop = result.uop();
    let sink = svod_ir::UOp::sink(vec![uop.clone()]);
    let normalization = crate::realize::normalize_for_schedule_cache(&sink).expect("normalize schedule cache");
    let (rangeified, _ctx) = svod_schedule::rangeify::rangeify(normalization.normalized, None).unwrap();
    let (kernels_root, _kctx) = svod_schedule::rangeify::try_get_kernel_graph(rangeified)
        .expect("kernel split pipeline should succeed for dense layer kernel count");
    let kernel_count = kernels_root
        .toposort()
        .into_iter()
        .filter(|node| matches!(node.op(), svod_ir::Op::Call { .. }))
        .count();
    assert_eq!(kernel_count, 6, "Expected 6 kernels for 2 dense layers, got {}", kernel_count);
}