use ndarray::Array2;
use rust_lstm::{
LSTMNetwork, LayerDropoutConfig, LSTMTrainer, TrainingConfig,
layers::dropout::{Dropout, Zoneout},
layers::peephole_lstm_cell::PeepholeLSTMCell,
optimizers::{SGD, Adam, RMSprop},
loss::{MSELoss, MAELoss, CrossEntropyLoss},
training::create_basic_trainer,
};
/// Runs a single forward step through a two-layer `LSTMNetwork` and checks that the
/// first element of the returned pair has shape `[hidden_size, 1]`.
#[test]
fn test_basic_forward_pass_example() {
    const INPUT_SIZE: usize = 3;
    const HIDDEN_SIZE: usize = 2;
    let num_layers = 2;

    let mut network = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, num_layers);

    // Column vector holding one value per input feature.
    let features = vec![0.5, 0.1, -0.3];
    let input = Array2::from_shape_vec((INPUT_SIZE, 1), features).unwrap();

    // Zero-filled hidden and cell arrays passed as the incoming state.
    let initial_hidden = Array2::zeros((HIDDEN_SIZE, 1));
    let initial_cell = Array2::zeros((HIDDEN_SIZE, 1));

    let (output, _) = network.forward(&input, &initial_hidden, &initial_cell);

    let expected_shape = [HIDDEN_SIZE, 1];
    assert_eq!(output.shape(), &expected_shape);
}
/// Builds one network configured through the network-level dropout builders and one
/// configured through a list of `LayerDropoutConfig`s, then checks that a forward step
/// on each returns an output of shape `[hidden_size, 1]`.
#[test]
fn test_dropout_regularization_example() {
    const INPUT_SIZE: usize = 10;
    const HIDDEN_SIZE: usize = 20;
    let num_layers = 3;

    // Dropout configured once on the network via the builder methods.
    let mut globally_configured = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, num_layers)
        .with_input_dropout(0.2, true)
        .with_recurrent_dropout(0.3, true)
        .with_output_dropout(0.1)
        .with_zoneout(0.05, 0.1);

    // Three separate configurations, the same count as `num_layers`
    // (presumably applied one per layer -- confirm against `with_layer_dropout`).
    let first_config = LayerDropoutConfig::new().with_input_dropout(0.1, false);
    let second_config = LayerDropoutConfig::new()
        .with_recurrent_dropout(0.2, true)
        .with_zoneout(0.05, 0.1);
    let third_config = LayerDropoutConfig::new().with_output_dropout(0.1);

    let mut layer_configured = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, num_layers)
        .with_layer_dropout(vec![first_config, second_config, third_config]);

    // Exercise both mode toggles; `eval()` is the last one called before the forward step.
    globally_configured.train();
    globally_configured.eval();

    let input = Array2::zeros((INPUT_SIZE, 1));
    let initial_hidden = Array2::zeros((HIDDEN_SIZE, 1));
    let initial_cell = Array2::zeros((HIDDEN_SIZE, 1));

    let (global_output, _) = globally_configured.forward(&input, &initial_hidden, &initial_cell);
    let (layered_output, _) = layer_configured.forward(&input, &initial_hidden, &initial_cell);

    let expected_shape = [HIDDEN_SIZE, 1];
    assert_eq!(global_output.shape(), &expected_shape);
    assert_eq!(layered_output.shape(), &expected_shape);
}
/// Trains a single-layer network with `MSELoss` and `Adam` for two epochs on the data
/// from `generate_test_data`, then checks the count and shape of the predictions made
/// for a two-step input sequence.
#[test]
fn test_training_example() {
    const INPUT_SIZE: usize = 1;
    const HIDDEN_SIZE: usize = 4;

    let network = LSTMNetwork::new(INPUT_SIZE, HIDDEN_SIZE, 1)
        .with_input_dropout(0.2, true)
        .with_recurrent_dropout(0.3, true)
        .with_output_dropout(0.1);

    let config = TrainingConfig {
        epochs: 2,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: None,
    };

    let mut trainer =
        LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(config);

    // No validation data is supplied (second argument is `None`).
    let training_set = generate_test_data();
    trainer.train(&training_set, None);

    let sequence = vec![
        Array2::zeros((INPUT_SIZE, 1)),
        Array2::ones((INPUT_SIZE, 1)),
    ];
    let predictions = trainer.predict(&sequence);

    // One prediction per input step, each a `[HIDDEN_SIZE, 1]` column.
    assert_eq!(predictions.len(), 2);
    assert_eq!(predictions[0].shape(), &[HIDDEN_SIZE, 1]);
}
/// Smoke test: calls `forward` on a `Dropout::new` and a `Dropout::variational`
/// instance (each after `train()`), and `apply_cell_zoneout` on a `Zoneout`.
/// The results are not inspected; the test passes if none of the calls panic.
#[test]
fn test_dropout_types_example() {
    let shape: (usize, usize) = (3, 1);

    let mut standard = Dropout::new(0.3);
    let mut variational = Dropout::variational(0.3);
    let zoneout = Zoneout::new(0.1, 0.15);

    let ones = Array2::ones(shape);

    standard.train();
    let _standard_out = standard.forward(&ones);

    variational.train();
    let _variational_out = variational.forward(&ones);

    // Zoneout is given the all-ones array plus an all-zeros previous state.
    let previous_state = Array2::zeros(shape);
    let _zoneout_out = zoneout.apply_cell_zoneout(&ones, &previous_state);
}
/// Smoke test: constructs each optimizer (`SGD`, `Adam` with explicit parameters,
/// `RMSprop`). There is nothing to assert on; the test passes when every constructor
/// returns without panicking.
#[test]
fn test_optimizers_example() {
    let _sgd = SGD::new(0.01);
    let _adam = Adam::with_params(0.001, 0.9, 0.999, 1e-8);
    let _rmsprop = RMSprop::new(0.01);
    // No `assert!(true)`: a constant assertion checks nothing and is flagged by
    // clippy's `assertions_on_constants` lint.
}
/// Smoke test: binds a value of each loss type (`MSELoss`, `MAELoss`,
/// `CrossEntropyLoss`). The test is effectively a compile-time check that the three
/// types are exported and can be named as values.
#[test]
fn test_loss_functions_example() {
    let _mse_loss = MSELoss;
    let _mae_loss = MAELoss;
    let _ce_loss = CrossEntropyLoss;
    // No `assert!(true)`: a constant assertion checks nothing and is flagged by
    // clippy's `assertions_on_constants` lint.
}
/// Runs one forward step of a `PeepholeLSTMCell` and checks that both returned arrays
/// have shape `[hidden_size, 1]`.
#[test]
fn test_peephole_lstm_example() {
    const INPUT_SIZE: usize = 3;
    const HIDDEN_SIZE: usize = 4;

    let peephole_cell = PeepholeLSTMCell::new(INPUT_SIZE, HIDDEN_SIZE);

    let step_input = Array2::ones((INPUT_SIZE, 1));
    let previous_hidden = Array2::zeros((HIDDEN_SIZE, 1));
    let previous_cell = Array2::zeros((HIDDEN_SIZE, 1));

    let (next_hidden, next_cell) =
        peephole_cell.forward(&step_input, &previous_hidden, &previous_cell);

    let expected_shape = [HIDDEN_SIZE, 1];
    assert_eq!(next_hidden.shape(), &expected_shape);
    assert_eq!(next_cell.shape(), &expected_shape);
}
/// Smoke test: passes a freshly built network and the value `0.01` to
/// `create_basic_trainer`. The returned trainer is not used; the test passes when the
/// call returns without panicking.
#[test]
fn test_create_basic_trainer() {
    let network = LSTMNetwork::new(2, 3, 1);
    let _trainer = create_basic_trainer(network, 0.01);
    // No `assert!(true)`: a constant assertion checks nothing and is flagged by
    // clippy's `assertions_on_constants` lint.
}
/// Builds a small, fixed training set.
///
/// Returns three `(inputs, targets)` pairs. Each pair holds two time steps: every
/// input is a 1x1 array of ones and every target is a 4x1 array filled with `0.5`.
fn generate_test_data() -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
    const SEQUENCE_COUNT: usize = 3;
    const STEPS_PER_SEQUENCE: usize = 2;

    (0..SEQUENCE_COUNT)
        .map(|_| {
            // Build the (input, target) pair for each step, then split the pairs
            // into the two parallel vectors the caller expects.
            let (inputs, targets): (Vec<_>, Vec<_>) = (0..STEPS_PER_SEQUENCE)
                .map(|_| {
                    let input = Array2::<f64>::ones((1, 1));
                    let target = Array2::<f64>::ones((4, 1)) * 0.5;
                    (input, target)
                })
                .unzip();
            (inputs, targets)
        })
        .collect()
}