use ndarray::Array2;
use rust_lstm::{
LSTMNetwork,
persistence::{ModelPersistence, PersistentModel, ModelMetadata},
training::create_basic_trainer,
};
use tempfile::tempdir;
/// Builds a `ModelMetadata` by hand and checks that every field reads back
/// exactly as it was written.
#[test]
fn test_model_metadata_creation() {
    let metadata = ModelMetadata {
        model_name: "test_model".to_string(),
        version: "0.2.0".to_string(),
        created_at: "2024-01-01T00:00:00Z".to_string(),
        input_size: 10,
        hidden_size: 20,
        num_layers: 2,
        total_epochs: 100,
        final_loss: Some(0.01),
        description: Some("Test model for validation".to_string()),
    };

    // Identification fields.
    assert_eq!(metadata.model_name, "test_model");
    assert_eq!(metadata.version, "0.2.0");
    assert_eq!(metadata.created_at, "2024-01-01T00:00:00Z");

    // Architecture fields.
    assert_eq!(metadata.input_size, 10);
    assert_eq!(metadata.hidden_size, 20);
    assert_eq!(metadata.num_layers, 2);

    // Training summary fields.
    assert_eq!(metadata.total_epochs, 100);
    assert_eq!(metadata.final_loss, Some(0.01));
    assert_eq!(
        metadata.description.as_deref(),
        Some("Test model for validation")
    );
}
/// Round-trips a freshly constructed network through a `.json` file and
/// checks that the metadata, the network dimensions and the forward output
/// all match after loading.
#[test]
fn test_network_save_load_json() {
    let workspace = tempdir().unwrap();
    let model_file = workspace.path().join("test_model.json");

    // Reference output from the in-memory network, captured before saving.
    let mut original = LSTMNetwork::new(3, 4, 2);
    let sample = Array2::ones((3, 1));
    let hidden = Array2::zeros((4, 1));
    let cell = Array2::zeros((4, 1));
    let (expected, _) = original.forward(&sample, &hidden, &cell);

    let meta = ModelMetadata {
        model_name: "test_json_model".to_string(),
        version: "0.2.0".to_string(),
        created_at: chrono::Utc::now().to_rfc3339(),
        input_size: 3,
        hidden_size: 4,
        num_layers: 2,
        total_epochs: 0,
        final_loss: None,
        description: Some("Test JSON persistence".to_string()),
    };

    let outcome = original.save(&model_file, meta.clone());
    assert!(outcome.is_ok());
    assert!(model_file.exists());

    let (mut restored, restored_meta) = LSTMNetwork::load(&model_file).unwrap();

    // Metadata must come back as it was written.
    assert_eq!(restored_meta.model_name, meta.model_name);
    assert_eq!(restored_meta.input_size, meta.input_size);
    assert_eq!(restored_meta.hidden_size, meta.hidden_size);
    assert_eq!(restored_meta.num_layers, meta.num_layers);

    // The rebuilt network must report the dimensions it was created with.
    assert_eq!(restored.input_size, 3);
    assert_eq!(restored.hidden_size, 4);
    assert_eq!(restored.num_layers, 2);

    // Same input and initial state must give (numerically) the same output.
    let (actual, _) = restored.forward(&sample, &hidden, &cell);
    assert_eq!(expected.shape(), actual.shape());

    let delta = &expected - &actual;
    let total_abs_diff = delta.mapv(|v| v.abs()).sum();
    assert!(
        total_abs_diff < 1e-10,
        "Loaded network output differs significantly from original"
    );
}
/// Round-trips a dropout-configured network through a `.bin` file and checks
/// that the metadata, the network dimensions and the eval-mode forward output
/// all match after loading.
#[test]
fn test_network_save_load_binary() {
    let workspace = tempdir().unwrap();
    let model_file = workspace.path().join("test_model.bin");

    let mut original = LSTMNetwork::new(2, 3, 1)
        .with_input_dropout(0.1, false)
        .with_output_dropout(0.1);

    let sample = Array2::ones((2, 1));
    let hidden = Array2::zeros((3, 1));
    let cell = Array2::zeros((3, 1));

    // Switch to eval mode before the reference pass; presumably this keeps
    // the configured dropout from perturbing the comparison.
    original.eval();
    let (expected, _) = original.forward(&sample, &hidden, &cell);

    let meta = ModelMetadata {
        model_name: "test_binary_model".to_string(),
        version: "0.2.0".to_string(),
        created_at: chrono::Utc::now().to_rfc3339(),
        input_size: 2,
        hidden_size: 3,
        num_layers: 1,
        total_epochs: 50,
        final_loss: Some(0.05),
        description: Some("Test binary persistence".to_string()),
    };

    let outcome = original.save(&model_file, meta.clone());
    assert!(outcome.is_ok());
    assert!(model_file.exists());

    let (mut restored, restored_meta) = LSTMNetwork::load(&model_file).unwrap();

    // Training summary must come back as it was written.
    assert_eq!(restored_meta.model_name, meta.model_name);
    assert_eq!(restored_meta.total_epochs, meta.total_epochs);
    assert_eq!(restored_meta.final_loss, meta.final_loss);

    // The rebuilt network must report the dimensions it was created with.
    assert_eq!(restored.input_size, 2);
    assert_eq!(restored.hidden_size, 3);
    assert_eq!(restored.num_layers, 1);

    // Compare outputs with the restored network in eval mode as well.
    restored.eval();
    let (actual, _) = restored.forward(&sample, &hidden, &cell);
    assert_eq!(expected.shape(), actual.shape());

    let delta = &expected - &actual;
    let total_abs_diff = delta.mapv(|v| v.abs()).sum();
    assert!(
        total_abs_diff < 1e-10,
        "Loaded network output differs significantly from original"
    );
}
/// Checks that `ModelPersistence::create_saved_model` stores the caller's
/// name, epoch count and loss, and fills in the dimensions of the network
/// it was given.
#[test]
fn test_model_persistence_create_saved_model() {
    let source_network = LSTMNetwork::new(5, 10, 3);

    let snapshot = ModelPersistence::create_saved_model(
        &source_network,
        "test_create_model".to_string(),
        200,
        Some(0.001),
        Some("Created via ModelPersistence".to_string()),
    );
    let meta = &snapshot.metadata;

    // Values passed in explicitly.
    assert_eq!(meta.model_name, "test_create_model");
    assert_eq!(meta.total_epochs, 200);
    assert_eq!(meta.final_loss, Some(0.001));

    // Values that match the arguments given to `LSTMNetwork::new`.
    assert_eq!(meta.input_size, 5);
    assert_eq!(meta.hidden_size, 10);
    assert_eq!(meta.num_layers, 3);
}
/// Trains a tiny network for two epochs, saves it as JSON, reloads it and
/// checks that the metadata survives and the reloaded network still produces
/// an output of the expected shape.
#[test]
fn test_persistence_with_trained_model() {
    let workspace = tempdir().unwrap();
    let model_file = workspace.path().join("trained_model.json");

    let training_config = rust_lstm::training::TrainingConfig {
        epochs: 2,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: None,
    };
    let mut trainer =
        create_basic_trainer(LSTMNetwork::new(1, 2, 1), 0.01).with_config(training_config);

    // Two single-step (inputs, targets) pairs.
    let samples = vec![
        (vec![Array2::ones((1, 1))], vec![Array2::zeros((2, 1))]),
        (vec![Array2::zeros((1, 1))], vec![Array2::ones((2, 1))]),
    ];
    trainer.train(&samples, None);

    let last_train_loss = trainer.get_latest_metrics().map(|metrics| metrics.train_loss);
    let meta = ModelMetadata {
        model_name: "trained_test_model".to_string(),
        version: "0.2.0".to_string(),
        created_at: chrono::Utc::now().to_rfc3339(),
        input_size: 1,
        hidden_size: 2,
        num_layers: 1,
        total_epochs: 2,
        final_loss: last_train_loss,
        description: Some("Trained model test".to_string()),
    };

    let outcome = trainer.network.save(&model_file, meta);
    assert!(outcome.is_ok());

    let (mut restored, restored_meta) = LSTMNetwork::load(&model_file).unwrap();
    assert_eq!(restored_meta.model_name, "trained_test_model");
    assert_eq!(restored_meta.total_epochs, 2);

    // Run one forward pass in eval mode on the reloaded network.
    let probe = Array2::ones((1, 1));
    restored.eval();
    let hidden = Array2::zeros((2, 1));
    let cell = Array2::zeros((2, 1));
    let (prediction, _) = restored.forward(&probe, &hidden, &cell);
    assert_eq!(prediction.shape(), &[2, 1]);
}
/// Saves the same network under several file names that differ only in
/// their extension and expects every save to succeed and create the file.
///
/// The cases run in a single table-driven loop; each assertion names the
/// file it was checking so a failure identifies the offending extension.
#[test]
fn test_file_extension_detection() {
    let dir = tempdir().unwrap();
    let network = LSTMNetwork::new(2, 3, 1);
    let metadata = ModelMetadata {
        model_name: "extension_test".to_string(),
        version: "0.2.0".to_string(),
        created_at: chrono::Utc::now().to_rfc3339(),
        input_size: 2,
        hidden_size: 3,
        num_layers: 1,
        total_epochs: 0,
        final_loss: None,
        description: None,
    };

    // `.xyz` is the case the test treats as an unknown extension; saving to
    // it is still expected to succeed.
    let file_names = ["model.json", "model.bin", "model.model", "model.xyz"];

    for file_name in file_names.iter() {
        let target = dir.path().join(file_name);
        let outcome = network.save(&target, metadata.clone());
        assert!(outcome.is_ok(), "saving to {} returned an error", file_name);
        assert!(target.exists(), "{} was not created on disk", file_name);
    }
}
/// Checks that both `load` and `save` report an error rather than succeeding
/// when given paths that are expected not to exist on the test machine.
#[test]
fn test_error_handling() {
    // Loading: the file is expected to be absent.
    let missing_file = "/non/existent/path.json";
    assert!(LSTMNetwork::load(missing_file).is_err());

    // Saving: the target's parent directories are expected to be absent.
    let unwritable_target = "/invalid/path/that/does/not/exist.json";
    let meta = ModelMetadata {
        model_name: "error_test".to_string(),
        version: "0.2.0".to_string(),
        created_at: chrono::Utc::now().to_rfc3339(),
        input_size: 1,
        hidden_size: 1,
        num_layers: 1,
        total_epochs: 0,
        final_loss: None,
        description: None,
    };
    let tiny_network = LSTMNetwork::new(1, 1, 1);
    let save_outcome = tiny_network.save(unwritable_target, meta);
    assert!(save_outcome.is_err());
}