#![cfg(feature = "neural_network")]
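//! Integration tests for the SimpleRNN layer: output shapes, activation ranges,
//! composition with Dense, overfitting on a tiny dataset, sequence memory, and
//! susceptibility to vanishing gradients.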
use approx::assert_abs_diff_eq;
use ndarray::{Array, Array2, Array3};
use rustyml::neural_network::Tensor;
use rustyml::neural_network::layer::activation_layer::relu::ReLU;
use rustyml::neural_network::layer::activation_layer::sigmoid::Sigmoid;
use rustyml::neural_network::layer::activation_layer::softmax::Softmax;
use rustyml::neural_network::layer::activation_layer::tanh::Tanh;
use rustyml::neural_network::layer::dense::Dense;
use rustyml::neural_network::layer::recurrent_layer::simple_rnn::SimpleRNN;
use rustyml::neural_network::loss_function::mean_squared_error::MeanSquaredError;
use rustyml::neural_network::optimizer::adam::Adam;
use rustyml::neural_network::optimizer::rms_prop::RMSprop;
use rustyml::neural_network::sequential::Sequential;
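
/// Smoke test: fits a single SimpleRNN layer for one epoch on constant data and
/// prints the resulting prediction.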
#[test]
fn test_simple_rnn_layer() {
let x = Array::ones((2, 5, 4)).into_dyn();
let y = Array::ones((2, 3)).into_dyn();
let mut model = Sequential::new();
model
.add(SimpleRNN::new(4, 3, Tanh::new()).unwrap())
.compile(
RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
MeanSquaredError::new(),
);
model.summary();
model.fit(&x, &y, 1).unwrap();
    let pred = model.predict(&x).unwrap();
    println!("SimpleRNN prediction:\n{:#?}\n", pred);
}
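
/// Same setup as above, but asserts the prediction shape instead of printing it.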
#[test]
fn test_simple_rnn_layer_basic() {
let x = Array::ones((2, 5, 4)).into_dyn();
let y = Array::ones((2, 3)).into_dyn();
let mut model = Sequential::new();
model
.add(SimpleRNN::new(4, 3, Tanh::new()).unwrap())
.compile(
RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
MeanSquaredError::new(),
);
model.summary();
model.fit(&x, &y, 1).unwrap();
let pred = model.predict(&x).unwrap();
assert_eq!(pred.shape(), &[2, 3]);
}
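
/// Trains SimpleRNN layers with ReLU and Sigmoid activations and verifies that
/// the outputs respect each activation's output range.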
#[test]
fn test_simple_rnn_different_activations() {
let x = Array::ones((3, 4, 2)).into_dyn();
let y = Array::ones((3, 6)).into_dyn();
let mut model_relu = Sequential::new();
model_relu
.add(SimpleRNN::new(2, 6, ReLU::new()).unwrap())
.compile(
RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
MeanSquaredError::new(),
);
model_relu.fit(&x, &y, 3).unwrap();
let pred_relu = model_relu.predict(&x).unwrap();
let mut model_sigmoid = Sequential::new();
model_sigmoid
.add(SimpleRNN::new(2, 6, Sigmoid::new()).unwrap())
.compile(
RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
MeanSquaredError::new(),
);
model_sigmoid.fit(&x, &y, 3).unwrap();
let pred_sigmoid = model_sigmoid.predict(&x).unwrap();
assert_eq!(pred_relu.shape(), &[3, 6]);
assert_eq!(pred_sigmoid.shape(), &[3, 6]);
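    // ReLU can never produce negative values; sigmoid is bounded to [0, 1].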
for v in pred_relu.iter() {
assert!(*v >= 0.0);
}
for v in pred_sigmoid.iter() {
assert!(*v >= 0.0 && *v <= 1.0);
}
}
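
/// Composes a SimpleRNN layer with a Dense + Sigmoid head and checks the final
/// output shape and that every value lies in [0, 1].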
#[test]
fn test_simple_rnn_sequential_composition() {
let x = Array::ones((2, 5, 3)).into_dyn();
let y = Array::ones((2, 4)).into_dyn();
let mut model = Sequential::new();
model
.add(SimpleRNN::new(3, 6, Tanh::new()).unwrap())
.add(Dense::new(6, 4, Sigmoid::new()).unwrap())
.compile(
RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
MeanSquaredError::new(),
);
model.summary();
model.fit(&x, &y, 5).unwrap();
let pred = model.predict(&x).unwrap();
assert_eq!(pred.shape(), &[2, 4]);
for v in pred.iter() {
assert!(*v >= 0.0 && *v <= 1.0);
}
}
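
/// Overfitting check: with a tiny constant dataset and 200 epochs the model
/// should drive its predictions to within 0.3 of the targets.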
#[test]
fn test_simple_rnn_overfitting() {
let x = Array::ones((2, 4, 3)).into_dyn();
let y = Array::ones((2, 7)).into_dyn();
let mut model = Sequential::new();
model
.add(SimpleRNN::new(3, 7, Tanh::new()).unwrap())
.compile(
RMSprop::new(0.01, 0.9, 1e-8).unwrap(),
MeanSquaredError::new(),
);
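    // Many epochs on a tiny constant dataset: the model should memorise the targets.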
model.fit(&x, &y, 200).unwrap();
let pred = model.predict(&x).unwrap();
for (pred_val, target_val) in pred.iter().zip(y.iter()) {
assert_abs_diff_eq!(*pred_val, *target_val, epsilon = 0.3);
}
}
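
/// Sequence-memory task: the class label is encoded only in the first timestep,
/// so the network must carry that information across the whole sequence to
/// classify correctly.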
#[test]
fn test_simple_rnn_sequence_memory() {
let batch_size = 8;
let seq_len = 6;
let input_dim = 3;
let mut x = Array3::<f32>::zeros((batch_size, seq_len, input_dim));
let mut y = Array2::<f32>::zeros((batch_size, 2));
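    // The class label is encoded only at t = 0: the first half of the batch is
    // class 0, the second half class 1. Later timesteps hold small distractor values.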
for b in 0..batch_size {
if b < 4 {
            x[[b, 0, 0]] = 1.0;
            x[[b, 0, 1]] = 0.0;
x[[b, 0, 2]] = 0.0;
y[[b, 0]] = 1.0;
y[[b, 1]] = 0.0;
} else {
            x[[b, 0, 0]] = 0.0;
            x[[b, 0, 1]] = 1.0;
x[[b, 0, 2]] = 0.0;
y[[b, 0]] = 0.0;
y[[b, 1]] = 1.0;
}
for t in 1..seq_len {
x[[b, t, 0]] = 0.1 * ((b * t) as f32).sin();
x[[b, t, 1]] = 0.1 * ((b * t) as f32).cos();
x[[b, t, 2]] = 0.05 * (t as f32);
}
}
let x = x.into_dyn();
let y = y.into_dyn();
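    // SimpleRNN encoder followed by a two-unit softmax classification head.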
let mut model = Sequential::new();
model
.add(SimpleRNN::new(input_dim, 12, Tanh::new()).unwrap())
.add(Dense::new(12, 2, Softmax::new()).unwrap())
.compile(
Adam::new(0.005, 0.9, 0.999, 1e-8).unwrap(),
MeanSquaredError::new(),
);
model.fit(&x, &y, 100).unwrap();
let pred = model.predict(&x).unwrap();
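    // Count how many samples have the predicted argmax matching the true class.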
let mut correct_predictions = 0;
for b in 0..batch_size {
let pred_class = if pred[[b, 0]] > pred[[b, 1]] { 0 } else { 1 };
let true_class = if y[[b, 0]] > y[[b, 1]] { 0 } else { 1 };
if pred_class == true_class {
correct_predictions += 1;
}
println!(
"Sample {}: True class {}, Pred class {}, Confidence [{:.3}, {:.3}]",
b,
true_class,
pred_class,
pred[[b, 0]],
pred[[b, 1]]
);
}
let accuracy = correct_predictions as f32 / batch_size as f32;
assert!(
accuracy >= 0.6,
"Memory task accuracy too low: {:.2}",
accuracy
);
println!("SimpleRNN Memory Task Accuracy: {:.1}%", accuracy * 100.0);
}
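
/// Compares learning progress on short versus long sequences whose target depends
/// only on the first timestep. A reduced improvement on the long sequences is
/// reported as a sign of the vanishing-gradient weakness of simple RNNs.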
#[test]
fn test_simple_rnn_vanishing_gradient_susceptibility() {
let batch_size = 3;
let short_seq_len = 5;
let long_seq_len = 15;
let input_dim = 2;
let units = 4;
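    // Builds a batch whose target depends only on the first timestep's sign;
    // intermediate timesteps carry only small sinusoidal filler values.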
let create_sequence_data = |seq_len: usize| -> (Tensor, Tensor) {
let mut x = Array3::<f32>::zeros((batch_size, seq_len, input_dim));
let mut y = Array2::<f32>::zeros((batch_size, units));
for b in 0..batch_size {
x[[b, 0, 0]] = if b % 2 == 0 { 1.0 } else { -1.0 };
x[[b, 0, 1]] = 0.5;
for t in 1..seq_len - 1 {
x[[b, t, 0]] = 0.1 * (t as f32).sin();
x[[b, t, 1]] = 0.1 * (t as f32).cos();
}
for u in 0..units {
y[[b, u]] = 0.5 + 0.3 * x[[b, 0, 0]];
}
}
(x.into_dyn(), y.into_dyn())
};
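    // Train on the short sequences, recording the MSE before and after 25 epochs.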
let (x_short, y_short) = create_sequence_data(short_seq_len);
let mut model_short = Sequential::new();
model_short
.add(SimpleRNN::new(input_dim, units, Tanh::new()).unwrap())
.compile(
Adam::new(0.01, 0.9, 0.999, 1e-8).unwrap(),
MeanSquaredError::new(),
);
let initial_loss_short = {
let pred = model_short.predict(&x_short).unwrap();
let diff = &pred - &y_short;
diff.mapv(|x| x.powi(2)).sum() / pred.len() as f32
};
model_short.fit(&x_short, &y_short, 25).unwrap();
let final_loss_short = {
let pred = model_short.predict(&x_short).unwrap();
let diff = &pred - &y_short;
diff.mapv(|x| x.powi(2)).sum() / pred.len() as f32
};
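    // Repeat the same training with the longer sequences.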
let (x_long, y_long) = create_sequence_data(long_seq_len);
let mut model_long = Sequential::new();
model_long
.add(SimpleRNN::new(input_dim, units, Tanh::new()).unwrap())
.compile(
Adam::new(0.01, 0.9, 0.999, 1e-8).unwrap(),
MeanSquaredError::new(),
);
let initial_loss_long = {
let pred = model_long.predict(&x_long).unwrap();
let diff = &pred - &y_long;
diff.mapv(|x| x.powi(2)).sum() / pred.len() as f32
};
model_long.fit(&x_long, &y_long, 25).unwrap();
let final_loss_long = {
let pred = model_long.predict(&x_long).unwrap();
let diff = &pred - &y_long;
diff.mapv(|x| x.powi(2)).sum() / pred.len() as f32
};
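    // Relative loss improvement for each sequence length.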
let improvement_short = (initial_loss_short - final_loss_short) / initial_loss_short;
let improvement_long = (initial_loss_long - final_loss_long) / initial_loss_long;
println!("SimpleRNN Vanishing Gradient Analysis:");
println!(
" Short sequence ({} steps): {:.6} -> {:.6} (improvement: {:.1}%)",
short_seq_len,
initial_loss_short,
final_loss_short,
improvement_short * 100.0
);
println!(
" Long sequence ({} steps): {:.6} -> {:.6} (improvement: {:.1}%)",
long_seq_len,
initial_loss_long,
final_loss_long,
improvement_long * 100.0
);
    assert!(
        improvement_short > 0.05,
        "Short sequence should show significant improvement, got {:.1}%",
        improvement_short * 100.0
    );
if improvement_long < improvement_short * 0.7 {
println!(
" Note: Long sequence shows reduced learning, indicating vanishing gradient effects"
);
}
}