use ruvector_cnn::layers::{
Activation, ActivationType, BatchNorm, Conv2d, DepthwiseSeparableConv,
GlobalAvgPool, HardSwish, Layer, MaxPool2d, AvgPool2d, ReLU, ReLU6,
Sigmoid, Swish, TensorShape,
};
use ruvector_cnn::{simd, Tensor};
#[test]
fn test_conv2d_output_shape_no_padding() {
let conv = Conv2d::new(3, 16, 3, 1, 0);
let input = Tensor::ones(&[1, 8, 8, 3]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 6, 6, 16]);
}
#[test]
fn test_conv2d_output_shape_with_padding() {
let conv = Conv2d::new(3, 16, 3, 1, 1);
let input = Tensor::ones(&[1, 8, 8, 3]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 8, 8, 16]);
}
#[test]
fn test_conv2d_output_shape_with_stride() {
let conv = Conv2d::new(3, 16, 3, 2, 1);
let input = Tensor::ones(&[1, 8, 8, 3]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 4, 4, 16]);
}
#[test]
fn test_conv2d_batch_processing() {
let conv = Conv2d::new(3, 8, 3, 1, 1);
let input = Tensor::ones(&[4, 16, 16, 3]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[4, 16, 16, 8]);
}
#[test]
fn test_conv2d_1x1_pointwise() {
let conv = Conv2d::new(64, 128, 1, 1, 0);
let input = Tensor::ones(&[1, 7, 7, 64]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 7, 7, 128]);
}
#[test]
fn test_conv2d_output_shape_method() {
let conv = Conv2d::new(3, 64, 3, 1, 1);
let shape = conv.output_shape(&[1, 224, 224, 3]).unwrap();
assert_eq!(shape, vec![1, 224, 224, 64]);
}
#[test]
fn test_conv2d_output_shape_stride2() {
let conv = Conv2d::new(3, 64, 3, 2, 1);
let shape = conv.output_shape(&[1, 224, 224, 3]).unwrap();
assert_eq!(shape, vec![1, 112, 112, 64]);
}
#[test]
fn test_depthwise_separable_conv_shape() {
let dw_conv = DepthwiseSeparableConv::new(32, 64, 3, 1, 1);
let input = Tensor::ones(&[1, 14, 14, 32]);
let output = dw_conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 14, 14, 64]);
}
#[test]
fn test_depthwise_separable_conv_params() {
let conv = DepthwiseSeparableConv::new(16, 32, 3, 1, 1);
assert_eq!(conv.num_params(), 144 + 512);
}
#[test]
fn test_batchnorm_output_shape() {
let bn = BatchNorm::new(64);
let input = Tensor::ones(&[2, 8, 8, 64]);
let output = bn.forward(&input).unwrap();
assert_eq!(output.shape(), input.shape());
}
#[test]
fn test_batchnorm_creation() {
let bn = BatchNorm::new(64);
assert_eq!(bn.num_features(), 64);
assert_eq!(bn.gamma().len(), 64);
assert_eq!(bn.beta().len(), 64);
assert_eq!(bn.num_params(), 128);
}
#[test]
fn test_batchnorm_default_params() {
let bn = BatchNorm::new(4);
for i in 0..4 {
assert!((bn.gamma()[i] - 1.0).abs() < 1e-6);
assert!((bn.beta()[i]).abs() < 1e-6);
}
}
#[test]
fn test_batchnorm_with_running_stats() {
let mut bn = BatchNorm::new(2);
bn.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]).unwrap();
let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 1, 2]).unwrap();
let output = bn.forward(&input).unwrap();
assert!(output.data()[0].abs() < 0.01); assert!(output.data()[1].abs() < 0.01); assert!((output.data()[2] - 2.0).abs() < 0.01); assert!((output.data()[3] - 1.0).abs() < 0.01); }
#[test]
fn test_batchnorm_numerical_stability() {
let mut bn = BatchNorm::new(4);
bn.set_running_stats(vec![0.0; 4], vec![1e-10; 4]).unwrap();
let input = Tensor::ones(&[1, 2, 2, 4]);
let output = bn.forward(&input).unwrap();
for &val in output.data() {
assert!(val.is_finite(), "Output should be finite, got {}", val);
}
}
#[test]
fn test_batchnorm_invalid_channels() {
let bn = BatchNorm::new(4);
let input = Tensor::ones(&[1, 8, 8, 8]);
let result = bn.forward(&input);
assert!(result.is_err());
}
#[test]
fn test_relu_positive_unchanged() {
let relu = ReLU::new();
let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
let output = relu.forward(&input).unwrap();
assert_eq!(output.data(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_relu_negative_zeroed() {
let relu = ReLU::new();
let input = Tensor::from_data(vec![-1.0, -2.0, -3.0, -4.0], &[4]).unwrap();
let output = relu.forward(&input).unwrap();
assert_eq!(output.data(), &[0.0, 0.0, 0.0, 0.0]);
}
#[test]
fn test_relu_mixed() {
let relu = ReLU::new();
let input = Tensor::from_data(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5]).unwrap();
let output = relu.forward(&input).unwrap();
assert_eq!(output.data(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
}
#[test]
fn test_relu6_clamps_at_6() {
let relu6 = ReLU6::new();
let input = Tensor::from_data(vec![-1.0, 0.0, 3.0, 6.0, 10.0], &[5]).unwrap();
let output = relu6.forward(&input).unwrap();
assert_eq!(output.data(), &[0.0, 0.0, 3.0, 6.0, 6.0]);
}
#[test]
fn test_swish_properties() {
let swish = Swish::new();
let input = Tensor::from_data(vec![0.0, 1.0, -1.0], &[3]).unwrap();
let output = swish.forward(&input).unwrap();
assert!(output.data()[0].abs() < 0.001);
assert!((output.data()[1] - 0.731).abs() < 0.01);
}
#[test]
fn test_hard_swish() {
let hs = HardSwish::new();
let input = Tensor::from_data(vec![-4.0, -3.0, 0.0, 3.0, 4.0], &[5]).unwrap();
let output = hs.forward(&input).unwrap();
assert!(output.data()[0].abs() < 0.001);
assert!(output.data()[1].abs() < 0.001);
assert!(output.data()[2].abs() < 0.001);
assert!((output.data()[3] - 3.0).abs() < 0.001);
}
#[test]
fn test_sigmoid_at_zero() {
let sigmoid = Sigmoid::new();
let input = Tensor::from_data(vec![0.0], &[1]).unwrap();
let output = sigmoid.forward(&input).unwrap();
assert!((output.data()[0] - 0.5).abs() < 0.001);
}
#[test]
fn test_activation_generic() {
let activation = Activation::new(ActivationType::ReLU);
let mut data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
activation.apply_inplace(&mut data);
assert_eq!(data, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
}
#[test]
fn test_activation_identity() {
let activation = Activation::new(ActivationType::Identity);
let mut data = vec![-2.0, 0.0, 2.0];
activation.apply_inplace(&mut data);
assert_eq!(data, vec![-2.0, 0.0, 2.0]);
}
#[test]
fn test_global_avg_pool_output_shape() {
let pool = GlobalAvgPool::new();
let input = Tensor::ones(&[1, 7, 7, 512]);
let output = pool.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 1, 1, 512]);
}
#[test]
fn test_global_avg_pool_computes_average() {
let pool = GlobalAvgPool::new();
let mut data = vec![0.0; 2 * 2 * 2];
for i in 0..4 {
data[i * 2] = 1.0; data[i * 2 + 1] = 2.0; }
let input = Tensor::from_data(data, &[1, 2, 2, 2]).unwrap();
let output = pool.forward(&input).unwrap();
assert!((output.data()[0] - 1.0).abs() < 0.001);
assert!((output.data()[1] - 2.0).abs() < 0.001);
}
#[test]
fn test_global_avg_pool_batch() {
let pool = GlobalAvgPool::new();
let input = Tensor::ones(&[3, 4, 4, 8]);
let output = pool.forward(&input).unwrap();
assert_eq!(output.shape(), &[3, 1, 1, 8]);
for &val in output.data() {
assert!((val - 1.0).abs() < 0.001);
}
}
#[test]
fn test_max_pool_2d_output_shape() {
let pool = MaxPool2d::new(2, 2, 0);
let input = Tensor::ones(&[1, 8, 8, 32]);
let output = pool.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 4, 4, 32]);
}
#[test]
fn test_max_pool_2d_finds_maximum() {
let pool = MaxPool2d::new(2, 2, 0);
let data = vec![1.0, 2.0, 3.0, 4.0];
let input = Tensor::from_data(data, &[1, 2, 2, 1]).unwrap();
let output = pool.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 1, 1, 1]);
assert_eq!(output.data()[0], 4.0);
}
#[test]
fn test_max_pool_2d_output_shape_method() {
let pool = MaxPool2d::new(2, 2, 0);
let shape = pool.output_shape(&[1, 224, 224, 64]).unwrap();
assert_eq!(shape, vec![1, 112, 112, 64]);
}
#[test]
fn test_avg_pool_2d_output_shape() {
let pool = AvgPool2d::new(2, 2, 0);
let input = Tensor::ones(&[1, 8, 8, 4]);
let output = pool.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 4, 4, 4]);
}
#[test]
fn test_avg_pool_2d_computes_average() {
let pool = AvgPool2d::new(2, 2, 0);
let data = vec![1.0, 2.0, 3.0, 4.0];
let input = Tensor::from_data(data, &[1, 2, 2, 1]).unwrap();
let output = pool.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 1, 1, 1]);
assert!((output.data()[0] - 2.5).abs() < 0.001); }
#[test]
fn test_max_pool_with_stride1() {
let pool = MaxPool2d::new(2, 1, 0);
let shape = pool.output_shape(&[1, 4, 4, 1]).unwrap();
assert_eq!(shape, vec![1, 3, 3, 1]);
}
#[test]
fn test_tensor_shape() {
let shape = TensorShape::new(2, 64, 7, 7);
assert_eq!(shape.n, 2);
assert_eq!(shape.c, 64);
assert_eq!(shape.h, 7);
assert_eq!(shape.w, 7);
assert_eq!(shape.numel(), 2 * 64 * 7 * 7);
}
#[test]
fn test_simd_relu() {
let input = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
let mut output = vec![0.0; 8];
simd::relu_simd(&input, &mut output);
assert_eq!(output, vec![0.0, 2.0, 0.0, 4.0, 0.0, 6.0, 0.0, 8.0]);
}
#[test]
fn test_simd_relu6() {
let input = vec![-1.0, 2.0, 7.0, 4.0, -5.0, 10.0, 3.0, 8.0];
let mut output = vec![0.0; 8];
simd::relu6_simd(&input, &mut output);
assert_eq!(output, vec![0.0, 2.0, 6.0, 4.0, 0.0, 6.0, 3.0, 6.0]);
}
#[test]
fn test_simd_dot_product() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let result = simd::dot_product_simd(&a, &b);
let expected = simd::scalar::dot_product_scalar(&a, &b);
assert!((result - expected).abs() < 0.001);
}
#[test]
fn test_conv_bn_relu_pipeline() {
let conv = Conv2d::new(3, 32, 3, 1, 1);
let bn = BatchNorm::new(32);
let relu = ReLU::new();
let input = Tensor::full(&[1, 32, 32, 3], 0.5);
let conv_out = conv.forward(&input).unwrap();
assert_eq!(conv_out.shape(), &[1, 32, 32, 32]);
let bn_out = bn.forward(&conv_out).unwrap();
assert_eq!(bn_out.shape(), conv_out.shape());
let relu_out = relu.forward(&bn_out).unwrap();
assert_eq!(relu_out.shape(), bn_out.shape());
for &val in relu_out.data() {
assert!(val >= 0.0);
}
}