#![cfg(feature = "neural_network")]
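//! Tests for the `Flatten` layer: standalone forward/backward shape round-trips
//! on 3D/4D/5D inputs, and use inside `Sequential` models alongside `Conv2D`,
//! `MaxPooling2D`, and `Dense` layers.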
use ndarray::{Array2, Array3, Array4, Array5};
use rustyml::neural_network::layer::activation_layer::relu::ReLU;
use rustyml::neural_network::layer::activation_layer::sigmoid::Sigmoid;
use rustyml::neural_network::layer::activation_layer::softmax::Softmax;
use rustyml::neural_network::layer::convolution_layer::PaddingType;
use rustyml::neural_network::layer::convolution_layer::conv_2d::Conv2D;
use rustyml::neural_network::layer::dense::Dense;
use rustyml::neural_network::layer::flatten::Flatten;
use rustyml::neural_network::layer::pooling_layer::max_pooling_2d::MaxPooling2D;
use rustyml::neural_network::loss_function::categorical_cross_entropy::CategoricalCrossEntropy;
use rustyml::neural_network::loss_function::mean_squared_error::MeanSquaredError;
use rustyml::neural_network::neural_network_trait::Layer;
use rustyml::neural_network::optimizer::adam::Adam;
use rustyml::neural_network::optimizer::sgd::SGD;
use rustyml::neural_network::sequential::Sequential;
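
// Expected shape flow for a small CNN ending in a Dense head (assumes a
// channels-first NCHW layout, as the constructor shape vectors below suggest):
//   input (2, 3, 4, 4) -> Conv2D(6 filters, 3x3, valid) -> (2, 6, 2, 2)
//   -> MaxPooling2D(2x2)                                 -> (2, 6, 1, 1)
//   -> Flatten                                           -> (2, 6)
//   -> Dense(6 -> 1, sigmoid)                            -> (2, 1)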
#[test]
fn test_flatten_in_sequential() {
    let x = Array4::ones((2, 3, 4, 4)).into_dyn();
    let y = Array2::ones((2, 1)).into_dyn();

    let mut model = Sequential::new();
    model
        .add(
            Conv2D::new(
                6,
                (3, 3),
                vec![2, 3, 4, 4],
                (1, 1),
                PaddingType::Valid,
                ReLU::new(),
            )
            .unwrap(),
        )
        .add(MaxPooling2D::new((2, 2), vec![2, 6, 2, 2], None).unwrap())
        .add(Flatten::new(vec![2, 6, 1, 1]).unwrap())
        .add(Dense::new(6, 1, Sigmoid::new()).unwrap())
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    model.summary();

    model.fit(&x, &y, 3).unwrap();
    let prediction = model.predict(&x).unwrap();
    println!("CNN+Flatten prediction results: {:?}", prediction);
    assert_eq!(prediction.shape(), &[2, 1]);
}
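
// A model consisting of only a Flatten layer should act as a pure reshape:
// (2, 3, 4, 4) -> (2, 48), with the batch dimension preserved. The input is
// filled with deterministic values so each sample flattens to a distinct row.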
#[test]
fn test_flatten_only_model() {
    let x = Array4::from_shape_fn((2, 3, 4, 4), |(b, c, h, w)| {
        (b * 100 + c * 10 + h + w) as f32 / 10.0
    })
    .into_dyn();
    let flattened_size = 3 * 4 * 4;

    let mut model = Sequential::new();
    model
        .add(Flatten::new(vec![2, 3, 4, 4]).unwrap())
        .compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
    model.summary();

    let output = model.predict(&x).unwrap();
    assert_eq!(output.shape(), &[2, flattened_size]);
}
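
// Deeper stack; expected shape flow (same NCHW assumption as above):
//   input (2, 1, 8, 8) -> Conv2D(8, 3x3, same)  -> (2, 8, 8, 8)
//   -> Conv2D(16, 3x3, valid)                   -> (2, 16, 6, 6)
//   -> MaxPooling2D(2x2, stride 2)              -> (2, 16, 3, 3)
//   -> Flatten                                  -> (2, 16 * 3 * 3) = (2, 144)
//   -> Dense(144 -> 20, relu) -> Dense(20 -> 10, softmax) -> (2, 10)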
#[test]
fn test_multiple_layers_with_flatten() {
    let x = Array4::ones((2, 1, 8, 8)).into_dyn();
    let y = Array2::ones((2, 10)).into_dyn();

    let mut model = Sequential::new();
    model
        .add(
            Conv2D::new(
                8,
                (3, 3),
                vec![2, 1, 8, 8],
                (1, 1),
                PaddingType::Same,
                ReLU::new(),
            )
            .unwrap(),
        )
        .add(
            Conv2D::new(
                16,
                (3, 3),
                vec![2, 8, 8, 8],
                (1, 1),
                PaddingType::Valid,
                ReLU::new(),
            )
            .unwrap(),
        )
        .add(MaxPooling2D::new((2, 2), vec![2, 16, 6, 6], Some((2, 2))).unwrap())
        .add(Flatten::new(vec![2, 16, 3, 3]).unwrap())
        .add(Dense::new(16 * 3 * 3, 20, ReLU::new()).unwrap())
        .add(Dense::new(20, 10, Softmax::new()).unwrap())
        .compile(
            Adam::new(0.001, 0.9, 0.999, 1e-8).unwrap(),
            CategoricalCrossEntropy::new(),
        );
    model.summary();

    model.fit(&x, &y, 3).unwrap();
    let prediction = model.predict(&x).unwrap();
    println!(
        "Complex CNN with Flatten prediction results: {:?}",
        prediction
    );
    assert_eq!(prediction.shape(), &[2, 10]);
}
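
// Standalone Flatten on a 3D input: (2, 10, 5) -> (2, 50); the backward pass
// must restore the original input shape.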
#[test]
fn test_flatten_3d() {
    let input = Array3::ones((2, 10, 5)).into_dyn();
    let mut flatten = Flatten::new(vec![2, 10, 5]).unwrap();

    let output = flatten.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 50]);

    let grad_output = Array2::ones((2, 50)).into_dyn();
    let grad_input = flatten.backward(&grad_output).unwrap();
    assert_eq!(grad_input.shape(), input.shape());
}
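
// Standalone Flatten on a 4D input: (2, 3, 4, 4) -> (2, 48).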
#[test]
fn test_flatten_4d() {
    let input = Array4::ones((2, 3, 4, 4)).into_dyn();
    let mut flatten = Flatten::new(vec![2, 3, 4, 4]).unwrap();

    let output = flatten.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 48]);

    let grad_output = Array2::ones((2, 48)).into_dyn();
    let grad_input = flatten.backward(&grad_output).unwrap();
    assert_eq!(grad_input.shape(), input.shape());
}
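
// Standalone Flatten on a 5D input: (2, 3, 4, 8, 8) -> (2, 3 * 4 * 8 * 8) = (2, 768).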
#[test]
fn test_flatten_5d() {
    let input = Array5::ones((2, 3, 4, 8, 8)).into_dyn();
    let mut flatten = Flatten::new(vec![2, 3, 4, 8, 8]).unwrap();

    let output = flatten.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 768]);

    let grad_output = Array2::ones((2, 768)).into_dyn();
    let grad_input = flatten.backward(&grad_output).unwrap();
    assert_eq!(grad_input.shape(), input.shape());
}