#![cfg(feature = "neural_network")]
use ndarray::Array4;
use rustyml::neural_network::layer::activation_layer::linear::Linear;
use rustyml::neural_network::layer::activation_layer::relu::ReLU;
use rustyml::neural_network::layer::activation_layer::sigmoid::Sigmoid;
use rustyml::neural_network::layer::activation_layer::tanh::Tanh;
use rustyml::neural_network::layer::convolution_layer::PaddingType;
use rustyml::neural_network::layer::convolution_layer::separable_conv_2d::SeparableConv2D;
use rustyml::neural_network::layer::layer_weight::LayerWeight;
use rustyml::neural_network::loss_function::mean_squared_error::MeanSquaredError;
use rustyml::neural_network::neural_network_trait::Layer;
use rustyml::neural_network::optimizer::adam::Adam;
use rustyml::neural_network::optimizer::rms_prop::RMSprop;
use rustyml::neural_network::optimizer::sgd::SGD;
use rustyml::neural_network::sequential::Sequential;
#[test]
fn test_separable_conv2d_basic() {
    // One-sample batch: 3 channels, 16x16 spatial, all ones.
    let input = Array4::ones((1, 3, 16, 16)).into_dyn();

    // Single SeparableConv2D layer: 32 filters, 3x3 kernel, stride 1,
    // same padding, depth multiplier 1, ReLU activation.
    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            32,
            (3, 3),
            vec![1, 3, 16, 16],
            (1, 1),
            PaddingType::Same,
            1,
            ReLU::new(),
        )
        .unwrap(),
    )
    .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    net.summary();

    // Same padding keeps the spatial dims; channels become the filter count.
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[1, 32, 16, 16]);
    // ReLU output is non-negative everywhere.
    assert!(output.iter().all(|&v| v >= 0.0));
}
#[test]
fn test_separable_conv2d_different_depth_multiplier() {
    // Batch of 2 samples, 4 input channels, 8x8 spatial.
    let input = Array4::ones((2, 4, 8, 8)).into_dyn();

    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            16,
            (3, 3),
            vec![2, 4, 8, 8],
            (1, 1),
            PaddingType::Same,
            2, // depth multiplier > 1 for the depthwise stage
            Linear::new(),
        )
        .unwrap(),
    )
    .compile(
        RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
        MeanSquaredError::new(),
    );

    // Output channel count equals the number of pointwise filters,
    // regardless of the depth multiplier.
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[2, 16, 8, 8]);
}
#[test]
fn test_separable_conv2d_valid_padding() {
    let input = Array4::ones((1, 2, 10, 10)).into_dyn();

    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            8,
            (3, 3),
            vec![1, 2, 10, 10],
            (1, 1),
            PaddingType::Valid,
            1,
            Sigmoid::new(),
        )
        .unwrap(),
    )
    .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());

    // Valid padding shrinks each spatial dim by kernel_size - 1: 10 - 2 = 8.
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[1, 8, 8, 8]);
    // Sigmoid maps every value into [0, 1].
    for &v in output.iter() {
        assert!((0.0..=1.0).contains(&v));
    }
}
#[test]
fn test_separable_conv2d_with_strides() {
    let input = Array4::ones((1, 3, 32, 32)).into_dyn();

    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            64,
            (3, 3),
            vec![1, 3, 32, 32],
            (2, 2), // stride 2 halves each spatial dimension under Same padding
            PaddingType::Same,
            1,
            Tanh::new(),
        )
        .unwrap(),
    )
    .compile(
        Adam::new(0.001, 0.9, 0.999, 1e-8).unwrap(),
        MeanSquaredError::new(),
    );

    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[1, 64, 16, 16]);
    // Tanh output lies within [-1, 1].
    assert!(output.iter().all(|&v| (-1.0..=1.0).contains(&v)));
}
#[test]
fn test_separable_conv2d_training() {
    // Constant-valued input/target pair for a short fit run.
    let input = Array4::from_elem((2, 3, 8, 8), 0.5).into_dyn();
    let target = Array4::from_elem((2, 16, 8, 8), 1.0).into_dyn();

    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            16,
            (3, 3),
            vec![2, 3, 8, 8],
            (1, 1),
            PaddingType::Same,
            1,
            ReLU::new(),
        )
        .unwrap(),
    )
    .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());

    // Three epochs of training should complete without error.
    assert!(net.fit(&input, &target, 3).is_ok());

    // After training, prediction still yields the expected shape.
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[2, 16, 8, 8]);
}
#[test]
fn test_separable_conv2d_multiple_training_calls() {
    let input = Array4::ones((1, 2, 4, 4)).into_dyn();
    let target = Array4::ones((1, 8, 4, 4)).into_dyn();

    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            8,
            (3, 3),
            vec![1, 2, 4, 4],
            (1, 1),
            PaddingType::Same,
            1,
            Linear::new(),
        )
        .unwrap(),
    )
    .compile(
        RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
        MeanSquaredError::new(),
    );

    // fit() must be callable repeatedly on the same compiled model.
    assert!(net.fit(&input, &target, 1).is_ok());
    assert!(net.fit(&input, &target, 1).is_ok());
}
#[test]
fn test_separable_conv2d_different_optimizers() {
    let input = Array4::ones((1, 3, 16, 16)).into_dyn();
    let target = Array4::ones((1, 8, 16, 16)).into_dyn();

    // All three models use an identical layer; only the optimizer differs.
    let build_layer = || {
        SeparableConv2D::new(
            8,
            (3, 3),
            vec![1, 3, 16, 16],
            (1, 1),
            PaddingType::Same,
            1,
            Linear::new(),
        )
        .unwrap()
    };

    // SGD
    let mut sgd_model = Sequential::new();
    sgd_model
        .add(build_layer())
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    assert!(sgd_model.fit(&input, &target, 1).is_ok());

    // Adam
    let mut adam_model = Sequential::new();
    adam_model.add(build_layer()).compile(
        Adam::new(0.001, 0.9, 0.999, 1e-8).unwrap(),
        MeanSquaredError::new(),
    );
    assert!(adam_model.fit(&input, &target, 1).is_ok());

    // RMSprop
    let mut rmsprop_model = Sequential::new();
    rmsprop_model.add(build_layer()).compile(
        RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
        MeanSquaredError::new(),
    );
    assert!(rmsprop_model.fit(&input, &target, 1).is_ok());
}
#[test]
fn test_separable_conv2d_batch_processing() {
    // Prediction should work across a range of batch sizes.
    for batch in [1usize, 2, 4, 8] {
        let input = Array4::ones((batch, 3, 8, 8)).into_dyn();

        let mut net = Sequential::new();
        net.add(
            SeparableConv2D::new(
                16,
                (3, 3),
                vec![batch, 3, 8, 8],
                (1, 1),
                PaddingType::Same,
                1,
                Linear::new(),
            )
            .unwrap(),
        )
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());

        // Batch dimension is preserved; channels become the filter count.
        let output = net.predict(&input).unwrap();
        assert_eq!(output.shape(), &[batch, 16, 8, 8]);
    }
}
#[test]
fn test_separable_conv2d_large_kernel() {
    let input = Array4::ones((1, 4, 32, 32)).into_dyn();

    let mut net = Sequential::new();
    net.add(
        SeparableConv2D::new(
            64,
            (5, 5), // larger 5x5 kernel than the other tests
            vec![1, 4, 32, 32],
            (1, 1),
            PaddingType::Same,
            1,
            ReLU::new(),
        )
        .unwrap(),
    )
    .compile(
        Adam::new(0.001, 0.9, 0.999, 1e-8).unwrap(),
        MeanSquaredError::new(),
    );

    // Same padding keeps 32x32 even with the 5x5 kernel.
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[1, 64, 32, 32]);
}
#[test]
fn test_separable_conv2d_get_weights() {
    // Model with a single SeparableConv2D layer; compile is not needed
    // for weight inspection.
    let mut model = Sequential::new();
    model.add(
        SeparableConv2D::new(
            8,
            (3, 3),
            vec![1, 3, 16, 16],
            (1, 1),
            PaddingType::Same,
            1,
            Linear::new(),
        )
        .unwrap(),
    );

    // Exactly one layer => exactly one weight entry.
    let weights = model.get_weights();
    assert_eq!(weights.len(), 1);

    // The weight variant must correspond to the layer type.
    // Fix: the panic message previously said "Conv2D" even though this
    // test exercises SeparableConv2D, which would mislead on failure.
    match &weights[0] {
        LayerWeight::SeparableConv2DLayer(_) => {}
        _ => panic!("Expected SeparableConv2D weight type"),
    }
}
#[test]
fn test_separable_conv2d_output_shape_calculation() {
    // Valid padding, stride 2, 64x64 input: floor((64 - 3) / 2) + 1 = 31.
    let layer = SeparableConv2D::new(
        32,
        (3, 3),
        vec![2, 16, 64, 64],
        (2, 2),
        PaddingType::Valid,
        1,
        Linear::new(),
    )
    .unwrap();

    // output_shape reports the formatted (batch, filters, h, w) tuple.
    assert_eq!(layer.output_shape(), "(2, 32, 31, 31)");
}