#![cfg(feature = "neural_network")]
use ndarray::Array5;
use rustyml::neural_network::layer::ReLU;
use rustyml::neural_network::layer::TrainingParameters;
use rustyml::neural_network::layer::activation_layer::sigmoid::Sigmoid;
use rustyml::neural_network::layer::activation_layer::tanh::Tanh;
use rustyml::neural_network::layer::convolution_layer::PaddingType;
use rustyml::neural_network::layer::convolution_layer::conv_3d::Conv3D;
use rustyml::neural_network::loss_function::mean_squared_error::MeanSquaredError;
use rustyml::neural_network::neural_network_trait::Layer;
use rustyml::neural_network::optimizer::adam::Adam;
use rustyml::neural_network::optimizer::rms_prop::RMSprop;
use rustyml::neural_network::optimizer::sgd::SGD;
use rustyml::neural_network::sequential::Sequential;
/// Trains a single `Conv3D` layer (3 filters, 3x3x3 kernel, ReLU) with SGD and checks
/// the prediction shape and that every output respects ReLU's lower bound.
///
/// Shape: a `Valid` 3x3x3 convolution with unit stride maps each spatial axis
/// 8 -> 8 - 3 + 1 = 6, so `(2, 1, 8, 8, 8)` becomes `(2, 3, 6, 6, 6)`.
#[test]
fn test_conv3d_sequential_with_sgd() {
    let x = Array5::ones((2, 1, 8, 8, 8)).into_dyn();
    let y = Array5::ones((2, 3, 6, 6, 6)).into_dyn();

    let mut model = Sequential::new();
    model
        .add(
            Conv3D::new(
                3,                   // filters (becomes the output channel count)
                (3, 3, 3),           // kernel size
                vec![2, 1, 8, 8, 8], // input shape, matching `x`
                (1, 1, 1),           // strides
                PaddingType::Valid,
                ReLU::new(),
            )
            .unwrap(),
        )
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    model.summary();

    // Use `expect` so a training failure reports the underlying error instead of a bare
    // "assertion failed: result.is_ok()".
    model
        .fit(&x, &y, 3)
        .expect("fitting Conv3D with SGD should succeed");

    let prediction = model.predict(&x).unwrap();
    assert_eq!(prediction.shape(), &[2, 3, 6, 6, 6]);
    // ReLU never yields negatives; a NaN also fails this comparison.
    assert!(
        prediction.iter().all(|&v| v >= 0.0),
        "ReLU output contained a negative or NaN value"
    );
}
/// Trains a 2-filter `Conv3D` layer with a Tanh activation using RMSprop and checks the
/// prediction shape and that all outputs lie in Tanh's closed range `[-1, 1]`.
///
/// Shape: `Valid` 3x3x3 convolution with unit stride maps 6 -> 6 - 3 + 1 = 4 per
/// spatial axis, so `(2, 2, 6, 6, 6)` becomes `(2, 2, 4, 4, 4)`.
#[test]
fn test_conv3d_sequential_with_rmsprop() {
    // Deterministic, non-constant input in [-0.5, 0.5] so the gradients are non-trivial.
    let x = Array5::from_shape_fn((2, 2, 6, 6, 6), |(b, c, d, h, w)| {
        ((b * 2 + c * 3 + d + h + w) as f32).sin() * 0.5
    })
    .into_dyn();
    let y = Array5::zeros((2, 2, 4, 4, 4)).into_dyn();

    let mut model = Sequential::new();
    model
        .add(
            Conv3D::new(
                2,                   // filters (becomes the output channel count)
                (3, 3, 3),           // kernel size
                vec![2, 2, 6, 6, 6], // input shape, matching `x`
                (1, 1, 1),           // strides
                PaddingType::Valid,
                Tanh::new(),
            )
            .unwrap(),
        )
        .compile(
            RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
            MeanSquaredError::new(),
        );
    model.summary();

    // Use `expect` so a training failure reports the underlying error.
    model
        .fit(&x, &y, 4)
        .expect("fitting Conv3D with RMSprop should succeed");

    let prediction = model.predict(&x).unwrap();
    assert_eq!(prediction.shape(), &[2, 2, 4, 4, 4]);
    // Inclusive bounds: f32 tanh may saturate to exactly +/-1. NaN fails `contains`.
    assert!(
        prediction.iter().all(|&v| (-1.0..=1.0).contains(&v)),
        "Tanh output left the [-1, 1] range or was NaN"
    );
}
/// Checks the output shape of a `Valid` 3x3x3 convolution with stride 2.
///
/// Per spatial axis: floor((10 - 3) / 2) + 1 = 3 + 1 = 4, so a `(1, 1, 10, 10, 10)`
/// input yields `(1, 1, 4, 4, 4)`.
#[test]
fn test_conv3d_different_strides() {
    let input = Array5::ones((1, 1, 10, 10, 10)).into_dyn();

    let mut net = Sequential::new();
    net.add(
        Conv3D::new(
            1,
            (3, 3, 3),
            vec![1, 1, 10, 10, 10],
            (2, 2, 2),
            PaddingType::Valid,
            ReLU::new(),
        )
        .unwrap(),
    )
    .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());

    // Only the forward pass matters here; no training is performed.
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[1, 1, 4, 4, 4]);
}
/// Trains a `Conv3D` layer mapping 3 input channels to 5 filters and checks the
/// prediction shape and ReLU's non-negativity.
///
/// Shape: `Valid` 3x3x3 convolution with unit stride maps 6 -> 4 per spatial axis;
/// the channel axis becomes the filter count, so `(2, 3, 6, 6, 6)` -> `(2, 5, 4, 4, 4)`.
#[test]
fn test_conv3d_multiple_channels() {
    let x = Array5::from_shape_fn((2, 3, 6, 6, 6), |(b, c, d, h, w)| {
        (b + c + d + h + w) as f32 * 0.1
    })
    .into_dyn();
    let y = Array5::ones((2, 5, 4, 4, 4)).into_dyn();

    let mut model = Sequential::new();
    model
        .add(
            Conv3D::new(
                5,                   // filters (becomes the output channel count)
                (3, 3, 3),           // kernel size
                vec![2, 3, 6, 6, 6], // input shape, matching `x`
                (1, 1, 1),           // strides
                PaddingType::Valid,
                ReLU::new(),
            )
            .unwrap(),
        )
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    model.summary();

    // Use `expect` so a training failure reports the underlying error.
    model
        .fit(&x, &y, 2)
        .expect("fitting multi-channel Conv3D with SGD should succeed");

    let prediction = model.predict(&x).unwrap();
    assert_eq!(prediction.shape(), &[2, 5, 4, 4, 4]);
    // Same ReLU sanity check as the single-channel SGD test; also catches NaNs.
    assert!(
        prediction.iter().all(|&v| v >= 0.0),
        "ReLU output contained a negative or NaN value"
    );
}
/// Verifies that the activation attached to an untrained `Conv3D` layer bounds its
/// output: ReLU gives values >= 0 and Sigmoid gives values in `[0, 1]`.
///
/// The input varies only along depth (`d - 2` in `[-2, 1]`), so it includes negative
/// values for the activation to act on.
#[test]
fn test_conv3d_activation_functions() {
    let input = Array5::from_shape_fn((1, 1, 4, 4, 4), |(_, _, depth, _, _)| {
        depth as f32 - 2.0
    })
    .into_dyn();

    // ReLU-activated layer.
    let mut relu_net = Sequential::new();
    relu_net
        .add(
            Conv3D::new(
                1,
                (2, 2, 2),
                vec![1, 1, 4, 4, 4],
                (1, 1, 1),
                PaddingType::Valid,
                ReLU::new(),
            )
            .unwrap(),
        )
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    let relu_out = relu_net.predict(&input).unwrap();
    assert!(relu_out.iter().all(|&v| v >= 0.0));

    // Sigmoid-activated layer; bounds are inclusive because f32 sigmoid can saturate.
    let mut sigmoid_net = Sequential::new();
    sigmoid_net
        .add(
            Conv3D::new(
                1,
                (2, 2, 2),
                vec![1, 1, 4, 4, 4],
                (1, 1, 1),
                PaddingType::Valid,
                Sigmoid::new(),
            )
            .unwrap(),
        )
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    let sigmoid_out = sigmoid_net.predict(&input).unwrap();
    assert!(sigmoid_out.iter().all(|&v| v >= 0.0 && v <= 1.0));
}
/// Checks the trainable parameter count of a `Conv3D` layer.
///
/// With 4 filters, a 3x3x3 kernel and 2 input channels (from the input shape), the
/// layer has 4 * 2 * 27 = 216 weights plus 4 biases, for 220 in total.
#[test]
fn test_conv3d_parameter_count() {
    let layer = Conv3D::new(
        4,
        (3, 3, 3),
        vec![2, 2, 5, 5, 5],
        (1, 1, 1),
        PaddingType::Valid,
        ReLU::new(),
    )
    .unwrap();

    let weights = 4 * 2 * 3 * 3 * 3;
    let biases = 4;
    assert_eq!(
        layer.param_count(),
        TrainingParameters::Trainable(weights + biases)
    );
}
/// Checks that `Same` padding with unit stride preserves the spatial dimensions
/// (8x8x8 stays 8x8x8) while the channel axis becomes the filter count (2).
#[test]
fn test_conv3d_same_padding() {
    let input = Array5::ones((1, 1, 8, 8, 8)).into_dyn();

    let layer = Conv3D::new(
        2,
        (3, 3, 3),
        vec![1, 1, 8, 8, 8],
        (1, 1, 1),
        PaddingType::Same,
        ReLU::new(),
    )
    .unwrap();

    let mut net = Sequential::new();
    net.add(layer)
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());

    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &[1, 2, 8, 8, 8]);
}
/// Trains a 2-filter ReLU `Conv3D` layer with the Adam optimizer and checks the
/// prediction shape and ReLU's non-negativity.
///
/// Shape: `Valid` 3x3x3 convolution with unit stride maps 6 -> 4 per spatial axis,
/// so `(2, 1, 6, 6, 6)` becomes `(2, 2, 4, 4, 4)`.
#[test]
fn test_conv3d_with_adam() {
    let x = Array5::from_shape_fn((2, 1, 6, 6, 6), |(b, _, d, h, w)| {
        ((b + d + h + w) as f32) * 0.1
    })
    .into_dyn();
    let y = Array5::ones((2, 2, 4, 4, 4)).into_dyn();

    let mut model = Sequential::new();
    model
        .add(
            Conv3D::new(
                2,                   // filters (becomes the output channel count)
                (3, 3, 3),           // kernel size
                vec![2, 1, 6, 6, 6], // input shape, matching `x`
                (1, 1, 1),           // strides
                PaddingType::Valid,
                ReLU::new(),
            )
            .unwrap(),
        )
        .compile(
            Adam::new(0.001, 0.9, 0.999, 1e-8).unwrap(),
            MeanSquaredError::new(),
        );
    model.summary();

    // Use `expect` so a training failure reports the underlying error.
    model
        .fit(&x, &y, 3)
        .expect("fitting Conv3D with Adam should succeed");

    let prediction = model.predict(&x).unwrap();
    assert_eq!(prediction.shape(), &[2, 2, 4, 4, 4]);
    // Same ReLU sanity check as the SGD test; also catches NaNs from the optimizer.
    assert!(
        prediction.iter().all(|&v| v >= 0.0),
        "ReLU output contained a negative or NaN value"
    );
}
/// Checks output shapes when stride and input size differ per spatial axis.
///
/// With a 3x3x3 `Valid` kernel and strides `(2, 1, 2)`:
/// depth  floor((12 - 3) / 2) + 1 = 5,
/// height floor((8 - 3) / 1) + 1 = 6,
/// width  floor((10 - 3) / 2) + 1 = 4.
#[test]
fn test_conv3d_asymmetric_stride() {
    let input = Array5::ones((1, 1, 12, 8, 10)).into_dyn();

    let layer = Conv3D::new(
        1,
        (3, 3, 3),
        vec![1, 1, 12, 8, 10],
        (2, 1, 2),
        PaddingType::Valid,
        ReLU::new(),
    )
    .unwrap();

    let mut net = Sequential::new();
    net.add(layer)
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());

    let expected: [usize; 5] = [1, 1, 5, 6, 4];
    let output = net.predict(&input).unwrap();
    assert_eq!(output.shape(), &expected);
}