use super::high_level::*;
use super::sequential::*;
use super::sequential_basic::*;
use crate::autograd::Variable;
use crate::nn::{Linear, Module};
use crate::tensor::Tensor;
// Tests for the `Sequential` container: layer bookkeeping, builder-style
// construction, summaries, parameter counts, validation, and a forward pass.
#[cfg(test)]
mod sequential_integration_tests {
    use super::*;

    /// A fresh model is empty and uncompiled; `add_layer` grows it, `remove`
    /// shrinks it, and removing an out-of-range index returns `Err`.
    #[test]
    fn test_sequential_basic_operations() {
        let mut net: Sequential<f32> = Sequential::new();
        assert_eq!(net.len(), 0);
        assert!(net.is_empty());
        assert!(!net.is_compiled());

        net.add_layer(Linear::new(10, 5));
        assert_eq!(net.len(), 1);
        assert!(!net.is_empty());

        net.add_layer(Linear::new(5, 2));
        assert_eq!(net.len(), 2);

        // Index 1 exists; index 5 does not.
        assert!(net.remove(1).is_ok());
        assert_eq!(net.len(), 1);
        assert!(net.remove(5).is_err());
    }

    /// Chained `add` calls build a three-layer, uncompiled model that reports
    /// a non-zero parameter count.
    #[test]
    fn test_sequential_builder_pattern() {
        let net: Sequential<f32> = Sequential::new()
            .add(Linear::new(784, 128))
            .add(Linear::new(128, 64))
            .add(Linear::new(64, 10));

        assert_eq!(net.len(), 3);
        assert!(!net.is_compiled());
        assert!(net.total_parameters() > 0);
    }

    /// `SequentialBuilder::with_name` produces a model whose summary mentions
    /// the given name.
    #[test]
    fn test_sequential_builder_with_name() {
        let net: Sequential<f32> = SequentialBuilder::with_name("test_model")
            .add(Linear::new(100, 50))
            .add(Linear::new(50, 25))
            .build();

        assert_eq!(net.len(), 2);
        assert!(net.summary().contains("test_model"));
    }

    /// Runs a `[1, 5]` input through two `Linear` layers and checks that the
    /// output tensor reports a non-empty shape.
    #[test]
    fn test_sequential_forward_pass() {
        let net: Sequential<f32> = Sequential::new()
            .add(Linear::new(5, 3))
            .add(Linear::new(3, 1));

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]),
            false,
        );
        let output = net.forward(&input);

        assert!(!output.data().read().unwrap().shape().is_empty());
    }

    /// Switching with `train()` and `eval()` leaves the layer count unchanged.
    /// The mode itself is not inspected here.
    #[test]
    fn test_sequential_training_mode() {
        let mut net: Sequential<f32> = Sequential::new().add(Linear::new(10, 5));

        net.train();
        net.eval();

        assert_eq!(net.len(), 1);
    }

    /// The summary of a named model contains the model name and the
    /// "Total params" / "Trainable params" labels.
    #[test]
    fn test_sequential_summary() {
        let net: Sequential<f32> = Sequential::with_name("summary_test")
            .add(Linear::new(784, 256))
            .add(Linear::new(256, 128))
            .add(Linear::new(128, 10));

        let report = net.summary();
        assert!(report.contains("summary_test"));
        assert!(report.contains("Total params"));
        assert!(report.contains("Trainable params"));
    }

    /// For a model built only from `Linear` layers, the total and trainable
    /// parameter counts are equal and non-zero.
    #[test]
    fn test_sequential_parameter_counting() {
        let net: Sequential<f32> = Sequential::new()
            .add(Linear::new(10, 5))
            .add(Linear::new(5, 2));

        let total = net.total_parameters();
        let trainable = net.trainable_parameters();

        assert!(total > 0);
        assert_eq!(total, trainable);
    }

    /// `validate()` returns `Err` for a model with no layers and for a model
    /// holding one `Linear` layer that is never compiled in this test.
    ///
    /// NOTE(review): the second failure presumably stems from the model being
    /// uncompiled — confirm against `Sequential::validate`.
    #[test]
    fn test_sequential_validation() {
        let empty: Sequential<f32> = Sequential::new();
        assert!(empty.validate().is_err());

        let single: Sequential<f32> = Sequential::new().add(Linear::new(10, 5));
        assert!(single.validate().is_err());
    }

    /// `clear()` removes every layer and leaves the model uncompiled.
    #[test]
    fn test_sequential_clear() {
        let mut net: Sequential<f32> = Sequential::new()
            .add(Linear::new(10, 5))
            .add(Linear::new(5, 2));
        assert_eq!(net.len(), 2);

        net.clear();

        assert_eq!(net.len(), 0);
        assert!(net.is_empty());
        assert!(!net.is_compiled());
    }

    /// `get_layer` yields `Some` for valid indices and `None` one past the
    /// end; `layers()` exposes both layers.
    #[test]
    fn test_sequential_layer_access() {
        let net: Sequential<f32> = Sequential::new()
            .add(Linear::new(10, 5))
            .add(Linear::new(5, 2));

        assert!(net.get_layer(0).is_some());
        assert!(net.get_layer(1).is_some());
        assert!(net.get_layer(2).is_none());

        assert_eq!(net.layers().len(), 2);
    }

    /// `insert` succeeds at an interior index and returns `Err` for an index
    /// far past the end.
    #[test]
    fn test_sequential_insert() {
        let mut net: Sequential<f32> = Sequential::new()
            .add(Linear::new(10, 5))
            .add(Linear::new(5, 2));

        assert!(net.insert(1, Linear::new(5, 3)).is_ok());
        assert_eq!(net.len(), 3);
        assert!(net.insert(10, Linear::new(2, 1)).is_err());
    }
}
// Tests for the high-level training helpers: `TrainingHistory` bookkeeping
// and the `FitConfig` builder/defaults.
#[cfg(test)]
mod high_level_api_tests {
    use super::*;
    use std::collections::HashMap;

    /// A new history has no recorded losses or metrics, zero training time,
    /// and no best-epoch information.
    #[test]
    fn test_training_history_creation() {
        let history = TrainingHistory::<f32>::new();

        assert!(history.train_loss.is_empty());
        assert!(history.val_loss.is_empty());
        assert!(history.metrics.is_empty());
        assert_eq!(history.training_time, 0.0);
        assert!(history.best_val_loss.is_none());
        assert!(history.best_epoch.is_none());
    }

    /// Each `add_epoch` call appends one train/validation loss, and the best
    /// validation loss and its epoch index follow the lower value.
    #[test]
    fn test_training_history_add_epoch() {
        let mut history = TrainingHistory::<f32>::new();

        let mut epoch_metrics = HashMap::new();
        epoch_metrics.insert("accuracy".to_string(), 0.85);
        epoch_metrics.insert("precision".to_string(), 0.82);

        // The map is needed again below, so the first call gets a copy.
        history.add_epoch(0.6, Some(0.7), epoch_metrics.clone());
        assert_eq!(history.train_loss.len(), 1);
        assert_eq!(history.val_loss.len(), 1);
        assert_eq!(history.best_val_loss, Some(0.7));
        assert_eq!(history.best_epoch, Some(0));

        // Final use: hand the map over instead of cloning it again.
        history.add_epoch(0.5, Some(0.6), epoch_metrics);
        assert_eq!(history.train_loss.len(), 2);
        assert_eq!(history.val_loss.len(), 2);
        assert_eq!(history.best_val_loss, Some(0.6));
        assert_eq!(history.best_epoch, Some(1));
    }

    /// The textual summary contains the heading, the epoch count, the
    /// formatted training time, and the final/best loss labels.
    #[test]
    fn test_training_history_summary() {
        let mut history = TrainingHistory::<f32>::new();
        history.training_time = 120.5;

        let mut epoch_metrics = HashMap::new();
        epoch_metrics.insert("accuracy".to_string(), 0.85);
        history.add_epoch(0.5, Some(0.6), epoch_metrics);
        history.add_epoch(0.4, Some(0.5), HashMap::new());

        let report = history.summary();
        assert!(report.contains("Training History Summary"));
        assert!(report.contains("Total epochs: 2"));
        assert!(report.contains("Training time: 120.50 seconds"));
        assert!(report.contains("Final training loss"));
        assert!(report.contains("Best validation loss"));
    }

    /// `plot_data` returns 1-based epoch numbers alongside the recorded
    /// training and validation losses, in insertion order.
    #[test]
    fn test_training_history_plot_data() {
        let mut history = TrainingHistory::<f32>::new();
        history.add_epoch(0.6, Some(0.7), HashMap::new());
        history.add_epoch(0.5, Some(0.6), HashMap::new());
        history.add_epoch(0.4, Some(0.5), HashMap::new());

        let (epochs, train_losses, val_losses) = history.plot_data();

        assert_eq!(epochs.len(), 3);
        assert_eq!(train_losses.len(), 3);
        assert_eq!(val_losses.len(), 3);
        assert_eq!(epochs, vec![1.0, 2.0, 3.0]);

        // Losses are compared with a tolerance rather than exact equality.
        assert!((train_losses[0] - 0.6).abs() < 1e-6);
        assert!((train_losses[1] - 0.5).abs() < 1e-6);
        assert!((train_losses[2] - 0.4).abs() < 1e-6);
        assert!((val_losses[0] - 0.7).abs() < 1e-6);
        assert!((val_losses[1] - 0.6).abs() < 1e-6);
        assert!((val_losses[2] - 0.5).abs() < 1e-6);
    }

    /// Builder setters override their fields; `validation_freq`, which is not
    /// set here, reads 1.
    #[test]
    fn test_fit_config_builder() {
        let config = FitConfig::new()
            .epochs(20)
            .batch_size(64)
            .verbose(false)
            .early_stopping(5);

        assert_eq!(config.epochs, 20);
        assert_eq!(config.batch_size, 64);
        assert!(!config.verbose);
        assert_eq!(config.patience, Some(5));
        assert_eq!(config.validation_freq, 1);
    }

    /// `FitConfig::default()` yields 10 epochs, batch size 32, verbose output,
    /// validation every epoch, and no early-stopping patience.
    #[test]
    fn test_fit_config_default() {
        let config = FitConfig::default();

        assert_eq!(config.epochs, 10);
        assert_eq!(config.batch_size, 32);
        assert!(config.verbose);
        assert_eq!(config.validation_freq, 1);
        assert!(config.patience.is_none());
    }
}
// Workflow-style tests for `BasicSequential`. Every model here is created
// without adding any layers.
#[cfg(test)]
mod integration_tests {
    use super::*;

    /// A named model exposes its name in the summary, accepts `train`/`eval`
    /// mode switches through the `Module` trait, and produces an output with
    /// a non-empty shape for a `[1, 5]` input.
    #[test]
    fn test_sequential_with_training() {
        let mut model: BasicSequential<f32> = BasicSequential::with_name("integration_test");
        assert!(model.summary().contains("integration_test"));

        Module::train(&mut model);
        Module::eval(&mut model);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]),
            false,
        );
        let output = Module::forward(&model, &input);

        assert!(!output.data().read().unwrap().shape().is_empty());
    }

    /// Forwarding a `[1, 3]` input through a freshly constructed model yields
    /// a prediction with a non-empty shape.
    #[test]
    fn test_model_prediction_workflow() {
        let model: BasicSequential<f32> = BasicSequential::new();

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]), false);
        let prediction = Module::forward(&model, &input);

        assert!(!prediction.data().read().unwrap().shape().is_empty());
    }

    /// Checks that a named model reports its name in the summary and starts
    /// with zero layers.
    ///
    /// NOTE(review): despite the name, nothing is saved or loaded here —
    /// presumably a placeholder for a persistence round-trip; confirm.
    #[test]
    fn test_model_save_load_workflow() {
        let model: BasicSequential<f32> = BasicSequential::with_name("save_load_test");

        assert!(model.summary().contains("save_load_test"));
        assert_eq!(model.len(), 0);
    }

    /// Combines a forward pass on a `[1, 4]` input, a summary check, and
    /// length/emptiness checks on one named model.
    #[test]
    fn test_full_workflow_simulation() {
        let model: BasicSequential<f32> = BasicSequential::with_name("full_workflow");

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]),
            false,
        );
        let output = Module::forward(&model, &input);
        assert!(!output.data().read().unwrap().shape().is_empty());

        assert!(model.summary().contains("full_workflow"));

        assert_eq!(model.len(), 0);
        assert!(model.is_empty());
    }
}