pub(crate) use super::*;
#[test]
fn test_conv1d_shape() {
    // Valid (unpadded) convolution: length 100 with kernel 3 -> 100 - 3 + 1 = 98.
    let layer = Conv1d::new(16, 32, 3);
    let output = layer.forward(&Tensor::ones(&[4, 16, 100]));
    assert_eq!(output.shape(), &[4, 32, 98]);
}
#[test]
fn test_conv1d_with_padding() {
    // "Same" padding for kernel 3: padding 1 keeps the spatial length at 100.
    let layer = Conv1d::with_padding(16, 32, 3, 1);
    let output = layer.forward(&Tensor::ones(&[4, 16, 100]));
    assert_eq!(output.shape(), &[4, 32, 100]);
}
#[test]
fn test_conv1d_with_stride() {
    // Stride 2, no padding: floor((100 - 3) / 2) + 1 = 49.
    let layer = Conv1d::with_stride(16, 32, 3, 2);
    let output = layer.forward(&Tensor::ones(&[4, 16, 100]));
    assert_eq!(output.shape(), &[4, 32, 49]);
}
#[test]
fn test_conv1d_parameters() {
    // Default construction includes a bias: weight [out, in, k] plus bias [out].
    let layer = Conv1d::new(16, 32, 3);
    let parameters = layer.parameters();
    assert_eq!(parameters.len(), 2);
    assert_eq!(parameters[0].shape(), &[32, 16, 3]);
    assert_eq!(parameters[1].shape(), &[32]);
}
#[test]
fn test_conv2d_shape() {
    // Unpadded 3x3 convolution shrinks each spatial dim by 2: 32 -> 30.
    let layer = Conv2d::new(3, 64, 3);
    let output = layer.forward(&Tensor::ones(&[4, 3, 32, 32]));
    assert_eq!(output.shape(), &[4, 64, 30, 30]);
}
#[test]
fn test_conv2d_with_padding() {
    // 3x3 kernel with padding 1 preserves the 32x32 spatial size.
    let layer = Conv2d::with_padding(3, 64, 3, 1);
    let output = layer.forward(&Tensor::ones(&[4, 3, 32, 32]));
    assert_eq!(output.shape(), &[4, 64, 32, 32]);
}
#[test]
fn test_conv2d_with_stride() {
    // Stride 2, no padding: floor((32 - 3) / 2) + 1 = 15 per spatial dim.
    let layer = Conv2d::with_stride(3, 64, 3, 2);
    let output = layer.forward(&Tensor::ones(&[4, 3, 32, 32]));
    assert_eq!(output.shape(), &[4, 64, 15, 15]);
}
#[test]
fn test_conv2d_parameters() {
    // Weight [out, in, kh, kw] and bias [out] are both exposed by default.
    let layer = Conv2d::new(3, 64, 3);
    let parameters = layer.parameters();
    assert_eq!(parameters.len(), 2);
    assert_eq!(parameters[0].shape(), &[64, 3, 3, 3]);
    assert_eq!(parameters[1].shape(), &[64]);
}
#[test]
fn test_maxpool1d_shape() {
    // Window 2 (stride defaults to window): length halves, 100 -> 50.
    let layer = MaxPool1d::new(2);
    let output = layer.forward(&Tensor::ones(&[4, 16, 100]));
    assert_eq!(output.shape(), &[4, 16, 50]);
}
#[test]
fn test_maxpool2d_shape() {
    // 2x2 pooling halves both spatial dims: 32x32 -> 16x16.
    let layer = MaxPool2d::new(2);
    let output = layer.forward(&Tensor::ones(&[4, 64, 32, 32]));
    assert_eq!(output.shape(), &[4, 64, 16, 16]);
}
#[test]
fn test_maxpool2d_values() {
    // A row-major 4x4 ramp; each 2x2 window's max sits in its bottom-right cell.
    let layer = MaxPool2d::new(2);
    let values: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let input = Tensor::new(&values, &[1, 1, 4, 4]);
    let output = layer.forward(&input);
    assert_eq!(output.shape(), &[1, 1, 2, 2]);
    assert_eq!(output.data(), &[6.0, 8.0, 14.0, 16.0]);
}
#[test]
fn test_avgpool2d_shape() {
    // Average pooling with window 2 halves both spatial dims.
    let layer = AvgPool2d::new(2);
    let output = layer.forward(&Tensor::ones(&[4, 64, 32, 32]));
    assert_eq!(output.shape(), &[4, 64, 16, 16]);
}
#[test]
fn test_avgpool2d_values() {
    // Row-major 4x4 ramp; e.g. top-left window (1,2,5,6) averages to 3.5.
    let layer = AvgPool2d::new(2);
    let values: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let output = layer.forward(&Tensor::new(&values, &[1, 1, 4, 4]));
    assert_eq!(output.data(), &[3.5, 5.5, 11.5, 13.5]);
}
#[test]
fn test_global_avg_pool2d() {
    // Collapses each 7x7 feature map to a scalar; an all-ones map averages to 1.
    let layer = GlobalAvgPool2d::new();
    let output = layer.forward(&Tensor::ones(&[2, 64, 7, 7]));
    assert_eq!(output.shape(), &[2, 64]);
    for &value in output.data() {
        assert!((value - 1.0).abs() < 1e-5);
    }
}
#[test]
fn test_flatten() {
    // Flatten::new keeps the batch dim and merges the rest: [4, 64, 7, 7] -> [4, 3136].
    let layer = Flatten::new();
    let output = layer.forward(&Tensor::ones(&[4, 64, 7, 7]));
    assert_eq!(output.shape(), &[4, 64 * 7 * 7]);
}
#[test]
fn test_flatten_from_dim() {
    // Flattening from dim 2 merges only the trailing dims: [4, 64, 7, 7] -> [4, 64, 49].
    let layer = Flatten::from_dim(2);
    let output = layer.forward(&Tensor::ones(&[4, 64, 7, 7]));
    assert_eq!(output.shape(), &[4, 64, 49]);
}
#[test]
fn test_conv1d_no_bias() {
    // With bias disabled only the weight tensor is reported.
    let layer = Conv1d::with_options(16, 32, 3, 1, 0, false);
    let parameters = layer.parameters();
    assert_eq!(parameters.len(), 1);
    assert_eq!(parameters[0].shape(), &[32, 16, 3]);
}
#[test]
fn test_conv1d_getters() {
    // Accessors echo back the configuration passed to with_options.
    let layer = Conv1d::with_options(16, 32, 5, 2, 3, true);
    assert_eq!(layer.kernel_size(), 5);
    assert_eq!(layer.stride(), 2);
    assert_eq!(layer.padding(), 3);
}
#[test]
fn test_conv1d_parameters_mut() {
    // Mutable access exposes the same two tensors as the immutable getter.
    let mut layer = Conv1d::new(8, 16, 3);
    let parameters = layer.parameters_mut();
    assert_eq!(parameters.len(), 2);
    assert_eq!(parameters[0].shape(), &[16, 8, 3]);
}
#[test]
fn test_conv1d_parameters_mut_no_bias() {
    // Disabling bias drops it from the mutable parameter list too.
    let mut layer = Conv1d::with_options(8, 16, 3, 1, 0, false);
    assert_eq!(layer.parameters_mut().len(), 1);
}
#[test]
fn test_conv1d_debug() {
    // The Debug output should name the type and its channel fields.
    let layer = Conv1d::new(16, 32, 3);
    let rendered = format!("{:?}", layer);
    for needle in ["Conv1d", "in_channels", "out_channels"] {
        assert!(rendered.contains(needle));
    }
}
#[test]
fn test_conv1d_forward_no_bias() {
    // Forward works without a bias term; output length 10 - 3 + 1 = 8.
    let layer = Conv1d::with_options(4, 8, 3, 1, 0, false);
    let output = layer.forward(&Tensor::ones(&[2, 4, 10]));
    assert_eq!(output.shape(), &[2, 8, 8]);
}
#[test]
fn test_conv1d_padding_zero_values() {
    // Padding 1 with kernel 3 keeps a length-3 input at length 3.
    let layer = Conv1d::with_padding(1, 1, 3, 1);
    let input = Tensor::new(&[1.0, 2.0, 3.0], &[1, 1, 3]);
    assert_eq!(layer.forward(&input).shape(), &[1, 1, 3]);
}
#[test]
fn test_conv2d_no_bias() {
    // Bias disabled: only the [out, in, kh, kw] weight remains.
    let layer = Conv2d::with_options(3, 64, (3, 3), (1, 1), (0, 0), false);
    let parameters = layer.parameters();
    assert_eq!(parameters.len(), 1);
    assert_eq!(parameters[0].shape(), &[64, 3, 3, 3]);
}
#[test]
fn test_conv2d_getters() {
    // 2D accessors return the (height, width) tuples given at construction.
    let layer = Conv2d::with_options(3, 64, (5, 7), (2, 3), (1, 2), true);
    assert_eq!(layer.kernel_size(), (5, 7));
    assert_eq!(layer.stride(), (2, 3));
    assert_eq!(layer.padding(), (1, 2));
}
#[test]
fn test_conv2d_parameters_mut() {
    // Mutable parameter list mirrors the immutable one: weight + bias.
    let mut layer = Conv2d::new(3, 32, 3);
    let parameters = layer.parameters_mut();
    assert_eq!(parameters.len(), 2);
    assert_eq!(parameters[0].shape(), &[32, 3, 3, 3]);
}
#[test]
fn test_conv2d_parameters_mut_no_bias() {
    // Without a bias, mutable access yields the weight alone.
    let mut layer = Conv2d::with_options(3, 32, (3, 3), (1, 1), (0, 0), false);
    assert_eq!(layer.parameters_mut().len(), 1);
}
#[test]
fn test_conv2d_debug() {
    // Debug formatting should mention the type name and key fields.
    let layer = Conv2d::new(3, 64, 3);
    let rendered = format!("{:?}", layer);
    for needle in ["Conv2d", "in_channels", "kernel_size"] {
        assert!(rendered.contains(needle));
    }
}
#[test]
fn test_conv2d_forward_no_bias() {
    // Forward pass without bias: 8x8 input with 3x3 kernel -> 6x6.
    let layer = Conv2d::with_options(2, 4, (3, 3), (1, 1), (0, 0), false);
    let output = layer.forward(&Tensor::ones(&[2, 2, 8, 8]));
    assert_eq!(output.shape(), &[2, 4, 6, 6]);
}
#[test]
fn test_conv2d_padding_zero_values() {
    // Padding 1 + kernel 3 preserves a 4x4 spatial size.
    let layer = Conv2d::with_padding(1, 1, 3, 1);
    let output = layer.forward(&Tensor::ones(&[1, 1, 4, 4]));
    assert_eq!(output.shape(), &[1, 1, 4, 4]);
}
#[test]
fn test_conv2d_non_square_kernel() {
    // 3x5 kernel on 10x12 input: height 10-3+1=8, width 12-5+1=8.
    let layer = Conv2d::with_options(2, 4, (3, 5), (1, 1), (0, 0), true);
    let output = layer.forward(&Tensor::ones(&[1, 2, 10, 12]));
    assert_eq!(output.shape(), &[1, 4, 8, 8]);
}
#[test]
fn test_maxpool1d_with_stride() {
    // Window 3, stride 2 on length 10: floor((10 - 3) / 2) + 1 = 4.
    let layer = MaxPool1d::with_stride(3, 2);
    let output = layer.forward(&Tensor::ones(&[2, 4, 10]));
    assert_eq!(output.shape(), &[2, 4, 4]);
}
#[test]
fn test_maxpool1d_values() {
    // Windows (1,4) and (2,3) reduce to their maxima 4 and 3.
    let layer = MaxPool1d::new(2);
    let input = Tensor::new(&[1.0, 4.0, 2.0, 3.0], &[1, 1, 4]);
    assert_eq!(layer.forward(&input).data(), &[4.0, 3.0]);
}
#[test]
fn test_maxpool1d_debug() {
    // Debug output carries the type name.
    let layer = MaxPool1d::new(2);
    assert!(format!("{:?}", layer).contains("MaxPool1d"));
}
#[test]
fn test_maxpool2d_with_stride() {
    // Window 3, stride 2 on 10x10: floor((10 - 3) / 2) + 1 = 4 per dim.
    let layer = MaxPool2d::with_stride(3, 2);
    let output = layer.forward(&Tensor::ones(&[2, 4, 10, 10]));
    assert_eq!(output.shape(), &[2, 4, 4, 4]);
}
#[test]
fn test_maxpool2d_with_options() {
    // Asymmetric window (3,5) and stride (2,3): H=(10-3)/2+1=4, W=(14-5)/3+1=4.
    let layer = MaxPool2d::with_options((3, 5), (2, 3));
    let output = layer.forward(&Tensor::ones(&[1, 2, 10, 14]));
    assert_eq!(output.shape(), &[1, 2, 4, 4]);
}
#[test]
fn test_maxpool2d_debug() {
    // Debug output carries the type name.
    let layer = MaxPool2d::new(2);
    assert!(format!("{:?}", layer).contains("MaxPool2d"));
}
#[test]
fn test_avgpool2d_with_stride() {
    // Window 3, stride 2 on 10x10 -> 4x4, same arithmetic as max pooling.
    let layer = AvgPool2d::with_stride(3, 2);
    let output = layer.forward(&Tensor::ones(&[2, 4, 10, 10]));
    assert_eq!(output.shape(), &[2, 4, 4, 4]);
}
#[test]
fn test_avgpool2d_debug() {
    // Debug output carries the type name.
    let layer = AvgPool2d::new(2);
    assert!(format!("{:?}", layer).contains("AvgPool2d"));
}
#[test]
fn test_global_avgpool2d_default() {
    // Default-constructed pool still collapses spatial dims to [batch, channels].
    let layer = GlobalAvgPool2d::default();
    let output = layer.forward(&Tensor::ones(&[2, 8, 4, 4]));
    assert_eq!(output.shape(), &[2, 8]);
}
#[test]
fn test_global_avgpool2d_debug() {
    // Debug output carries the type name.
    let layer = GlobalAvgPool2d::new();
    assert!(format!("{:?}", layer).contains("GlobalAvgPool2d"));
}
#[test]
fn test_global_avgpool2d_varied_values() {
    // Mean of 1..4 over a single 2x2 map is 2.5.
    let layer = GlobalAvgPool2d::new();
    let input = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
    let output = layer.forward(&input);
    assert!((output.data()[0] - 2.5).abs() < 1e-5);
}
#[test]
fn test_flatten_default() {
    // NOTE(review): unlike Flatten::new() (which keeps the batch dim — see
    // test_flatten), Default flattens everything into one dim: 4*8*8*8 = 2048.
    // Presumably Default starts flattening at dim 0 — confirm this asymmetry
    // with new() is intentional.
    let layer = Flatten::default();
    let output = layer.forward(&Tensor::ones(&[4, 8, 8, 8]));
    assert_eq!(output.shape(), &[2048]);
}
#[test]
fn test_flatten_debug() {
    // Debug output carries the type name.
    let layer = Flatten::new();
    assert!(format!("{:?}", layer).contains("Flatten"));
}
#[test]
fn test_flatten_no_op_for_2d() {
    // An already-2D tensor passes through unchanged.
    let layer = Flatten::new();
    let output = layer.forward(&Tensor::ones(&[4, 64]));
    assert_eq!(output.shape(), &[4, 64]);
}
#[path = "tests_conv1d_contract.rs"]
mod tests_conv1d_contract;
#[path = "tests_flatten.rs"]
mod tests_flatten;