#![allow(dead_code)]
use super::*;
use crate::layers::*;
use crate::models::sequential::Sequential;
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Float;
use scirs2_core::random::SeedableRng;
use std::fs;
use tempfile::tempdir;
/// Saves a two-layer dense network as JSON and reads it back.
///
/// The reloaded model must report the same number of layers as the one that
/// was written, and a forward pass over a `(2, 4)` tensor of ones must give
/// an output of the same shape from both models. Output values themselves
/// are not compared.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, loading, or
/// running the forward passes.
#[test]
#[allow(dead_code)]
fn test_dense_layer_serialization_roundtrip() -> Result<()> {
    let scratch = tempdir().expect("Failed to create temp directory");
    let json_file = scratch.path().join("test_dense_model.json");

    // Fixed seed for the RNG handed to the layer constructors.
    let mut seeded = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
    let mut network = Sequential::<f32>::new();
    network.add(Box::new(Dense::new(4, 8, Some("relu"), &mut seeded)?));
    network.add(Box::new(Dense::new(8, 3, Some("softmax"), &mut seeded)?));

    save_model(&network, &json_file, SerializationFormat::JSON)?;
    assert!(json_file.exists(), "Model file should exist after saving");

    let restored: Sequential<f32> = load_model(&json_file, SerializationFormat::JSON)?;
    let (written_layers, restored_layers) = (network.layers().len(), restored.layers().len());
    assert_eq!(written_layers, restored_layers);

    // Same probe tensor through both models; only the shapes are checked.
    let probe = Array2::<f32>::ones((2, 4)).into_dyn();
    let before = network.forward(&probe)?;
    let after = restored.forward(&probe)?;
    assert_eq!(before.shape(), after.shape());
    Ok(())
}
/// Saves a small convolutional stack (Conv2D, BatchNorm, MaxPool2D, Dropout)
/// as JSON and checks that the reloaded model has the same number of layers.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, or loading.
#[test]
#[allow(dead_code)]
fn test_cnn_model_serialization() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("test_cnn_model.json");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(123);

    let mut model = Sequential::<f32>::new();
    model.add(Box::new(Conv2D::new(
        3,
        16,
        (3, 3),
        (1, 1),
        PaddingMode::Same,
        &mut rng,
    )?));
    model.add(Box::new(BatchNorm::new(16, 0.1, 1e-5, &mut rng)?));
    model.add(Box::new(MaxPool2D::new((2, 2), (2, 2), None)?));
    model.add(Box::new(Dropout::new(0.25, &mut rng)?));

    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let loaded_model: Sequential<f32> = load_model(&model_path, SerializationFormat::JSON)?;
    assert_eq!(model.layers().len(), loaded_model.layers().len());
    Ok(())
}
/// Writes the same two-layer model in JSON, CBOR and MessagePack, loading
/// each file back immediately after it is written.
///
/// The test passes when every save and load succeeds and all three files
/// exist afterwards; the loaded models are not inspected further.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, or loading.
#[test]
#[allow(dead_code)]
fn test_multiple_serialization_formats() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(456);

    let mut model = Sequential::<f32>::new();
    model.add(Box::new(Dense::new(3, 5, Some("tanh"), &mut rng)?));
    model.add(Box::new(LayerNorm::new(5, 1e-5, &mut rng)?));

    let json_path = temp_dir.path().join("model.json");
    save_model(&model, &json_path, SerializationFormat::JSON)?;
    let _loaded_json: Sequential<f32> = load_model(&json_path, SerializationFormat::JSON)?;

    let cbor_path = temp_dir.path().join("model.cbor");
    save_model(&model, &cbor_path, SerializationFormat::CBOR)?;
    let _loaded_cbor: Sequential<f32> = load_model(&cbor_path, SerializationFormat::CBOR)?;

    let msgpack_path = temp_dir.path().join("model.msgpack");
    save_model(&model, &msgpack_path, SerializationFormat::MessagePack)?;
    let _loaded_msgpack: Sequential<f32> =
        load_model(&msgpack_path, SerializationFormat::MessagePack)?;

    assert!(json_path.exists());
    assert!(cbor_path.exists());
    assert!(msgpack_path.exists());
    Ok(())
}
/// Checks that a dense layer's parameters keep their count and shapes across
/// a JSON save/load cycle.
///
/// The parameters are captured before the layer is moved into the model, and
/// then compared with those of the first layer of the reloaded model. Only
/// the number of parameter arrays and each array's shape are compared.
///
/// # Panics
///
/// Panics if the first reloaded layer cannot be downcast to `Dense<f32>`.
///
/// # Errors
///
/// Propagates any error returned while building the layer, saving, or loading.
#[test]
#[allow(dead_code)]
fn test_parameter_preservation() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("param_test_model.json");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(789);

    let mut model = Sequential::<f32>::new();
    let dense_layer = Dense::<f32>::new(2, 3, None, &mut rng)?;
    let original_params = dense_layer.get_parameters();
    model.add(Box::new(dense_layer));

    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let loaded_model: Sequential<f32> = load_model(&model_path, SerializationFormat::JSON)?;

    let loaded_dense = loaded_model.layers()[0]
        .as_any()
        .downcast_ref::<Dense<f32>>()
        .expect("First layer should be Dense");
    let loaded_params = loaded_dense.get_parameters();

    assert_eq!(original_params.len(), loaded_params.len());
    for (orig, loaded) in original_params.iter().zip(loaded_params.iter()) {
        assert_eq!(orig.shape(), loaded.shape());
    }
    Ok(())
}
/// Verifies that `load_model` reports an error instead of succeeding for two
/// bad inputs: a path with no file behind it, and a file whose content is not
/// valid JSON.
///
/// # Errors
///
/// Propagates the I/O error if the malformed fixture file cannot be written.
#[test]
#[allow(dead_code)]
fn test_error_handling() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let invalid_path = temp_dir.path().join("invalid.json");

    // Case 1: nothing has been written to the path yet.
    let result = load_model::<f32>(&invalid_path, SerializationFormat::JSON);
    assert!(result.is_err());

    // Case 2: the file exists but does not hold a serialized model.
    fs::write(&invalid_path, "invalid json content")?;
    let result = load_model::<f32>(&invalid_path, SerializationFormat::JSON);
    assert!(result.is_err());
    Ok(())
}
/// Checks that `ActivationFunction::from_name` recognises every activation
/// name listed below, that each recognised function can be instantiated for
/// `f32`, and that the parameterised spellings `leaky_relu(0.2)` and
/// `elu(1.5)` are accepted as well.
#[test]
#[allow(dead_code)]
fn test_activation_function_serialization() -> Result<()> {
    let activations = [
        "relu", "sigmoid", "tanh", "softmax", "gelu", "swish", "mish",
    ];
    for activation_name in activations {
        let activation_fn = ActivationFunction::from_name(activation_name);
        assert!(
            activation_fn.is_some(),
            "Should recognize activation: {}",
            activation_name
        );
        if let Some(af) = activation_fn {
            let created = af.create::<f32>();
            // `TypeId` has no `is_zero`; comparing against the unit type's id
            // keeps the sanity check that a concrete object was produced.
            // NOTE(review): confirm this matches the intended check.
            assert!(
                created.as_any().type_id() != std::any::TypeId::of::<()>(),
                "Should create valid activation"
            );
        }
    }

    let leaky_relu = ActivationFunction::from_name("leaky_relu(0.2)");
    assert!(leaky_relu.is_some());
    let elu = ActivationFunction::from_name("elu(1.5)");
    assert!(elu.is_some());
    Ok(())
}
/// Round-trips a model mixing Dense, BatchNorm, Dropout and LayerNorm layers
/// through JSON.
///
/// The reloaded model must have the same number of layers, and a forward
/// pass over a `(3, 10)` tensor of ones must produce an output of the same
/// shape from both models. Output values are not compared.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, loading, or
/// running the forward passes.
#[test]
#[allow(dead_code)]
fn test_mixed_layer_model() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("mixed_model.json");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(999);

    let mut model = Sequential::<f32>::new();
    model.add(Box::new(Dense::new(10, 20, Some("relu"), &mut rng)?));
    model.add(Box::new(BatchNorm::new(20, 0.1, 1e-5, &mut rng)?));
    model.add(Box::new(Dropout::new(0.3, &mut rng)?));
    model.add(Box::new(LayerNorm::new(20, 1e-6, &mut rng)?));
    model.add(Box::new(Dense::new(20, 5, Some("softmax"), &mut rng)?));

    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let loaded_model: Sequential<f32> = load_model(&model_path, SerializationFormat::JSON)?;
    assert_eq!(model.layers().len(), loaded_model.layers().len());

    let input = Array2::<f32>::ones((3, 10)).into_dyn();
    let original_output = model.forward(&input)?;
    let loaded_output = loaded_model.forward(&input)?;
    assert_eq!(original_output.shape(), loaded_output.shape());
    Ok(())
}
/// Round-trips a two-layer dense model with `f64` elements through JSON.
///
/// The reloaded model must have the same number of layers, and a forward
/// pass over a `(2, 4)` tensor of ones must produce an output of the same
/// shape from both models. Output values are not compared.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, loading, or
/// running the forward passes.
#[test]
#[allow(dead_code)]
fn test_f64_model_serialization() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("f64_model.json");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(111);

    let mut model = Sequential::<f64>::new();
    model.add(Box::new(Dense::new(4, 6, Some("tanh"), &mut rng)?));
    model.add(Box::new(Dense::new(6, 2, None, &mut rng)?));

    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let loaded_model: Sequential<f64> = load_model(&model_path, SerializationFormat::JSON)?;
    assert_eq!(model.layers().len(), loaded_model.layers().len());

    let input = Array2::<f64>::ones((2, 4)).into_dyn();
    let original_output = model.forward(&input)?;
    let loaded_output = loaded_model.forward(&input)?;
    assert_eq!(original_output.shape(), loaded_output.shape());
    Ok(())
}
/// Saves a model with no layers as JSON and checks that the reloaded model
/// is also empty.
///
/// # Errors
///
/// Propagates any error returned by saving or loading.
#[test]
#[allow(dead_code)]
fn test_empty_model_serialization() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("empty_model.json");
    let model = Sequential::<f32>::new();

    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let loaded_model: Sequential<f32> = load_model(&model_path, SerializationFormat::JSON)?;
    assert_eq!(0, loaded_model.layers().len());
    Ok(())
}
/// Saves and reloads a six-layer model as JSON, asserting that each of the
/// two operations finishes in under five seconds of wall-clock time.
///
/// Saving and loading are timed with separate timers so the load measurement
/// does not include the time spent saving.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, or loading.
#[test]
#[allow(dead_code)]
fn test_large_model_serialization() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("large_model.json");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(222);

    let mut model = Sequential::<f32>::new();
    model.add(Box::new(Dense::new(100, 200, Some("relu"), &mut rng)?));
    model.add(Box::new(Dense::new(200, 400, Some("relu"), &mut rng)?));
    model.add(Box::new(Dropout::new(0.5, &mut rng)?));
    model.add(Box::new(Dense::new(400, 200, Some("relu"), &mut rng)?));
    model.add(Box::new(Dense::new(200, 50, Some("relu"), &mut rng)?));
    model.add(Box::new(Dense::new(50, 10, Some("softmax"), &mut rng)?));

    let save_start = std::time::Instant::now();
    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let save_duration = save_start.elapsed();

    let load_start = std::time::Instant::now();
    let _loaded_model: Sequential<f32> = load_model(&model_path, SerializationFormat::JSON)?;
    let load_duration = load_start.elapsed();

    assert!(save_duration.as_secs() < 5, "Save should complete within 5 seconds");
    assert!(load_duration.as_secs() < 5, "Load should complete within 5 seconds");
    Ok(())
}
/// Writes one model in JSON, CBOR and MessagePack and checks that each
/// resulting file is non-empty, printing the three sizes for comparison.
///
/// # Errors
///
/// Propagates any error returned while building layers, saving, or reading
/// file metadata.
#[test]
#[allow(dead_code)]
fn test_format_comparison() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    // `from_seed` takes a byte-array seed; `seed_from_u64` accepts an integer.
    let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(333);

    let mut model = Sequential::<f32>::new();
    model.add(Box::new(Dense::new(50, 100, Some("relu"), &mut rng)?));
    model.add(Box::new(Dense::new(100, 25, Some("softmax"), &mut rng)?));

    let json_path = temp_dir.path().join("model.json");
    let cbor_path = temp_dir.path().join("model.cbor");
    let msgpack_path = temp_dir.path().join("model.msgpack");
    save_model(&model, &json_path, SerializationFormat::JSON)?;
    save_model(&model, &cbor_path, SerializationFormat::CBOR)?;
    save_model(&model, &msgpack_path, SerializationFormat::MessagePack)?;

    let json_size = fs::metadata(&json_path)?.len();
    let cbor_size = fs::metadata(&cbor_path)?.len();
    let msgpack_size = fs::metadata(&msgpack_path)?.len();

    assert!(json_size > 0);
    assert!(cbor_size > 0);
    assert!(msgpack_size > 0);
    println!("File sizes - JSON: {}, CBOR: {}, MessagePack: {}", json_size, cbor_size, msgpack_size);
    Ok(())
}
#[allow(dead_code)]
fn create_test_model<F: Float + Debug + ScalarOperand + Send + Sync + 'static>() -> Result<Sequential<F>> {
let mut model = Sequential::<F>::new();
model.add(Box::new(Dropout::new(0.2, &mut rng)?));
model.add(Box::new(Dense::new(20, 10, Some("tanh"), &mut rng)?));
model.add(Box::new(Dense::new(10, 5, Some("softmax"), &mut rng)?));
Ok(model)
/// End-to-end check: build the shared test model, save it as JSON, load it
/// back, run the reloaded model on a `(5, 10)` tensor of ones, and finally
/// delete the model file.
///
/// The forward pass of the reloaded model must produce an output of shape
/// `[5, 5]`.
///
/// # Errors
///
/// Propagates any error returned while building the model, saving, loading,
/// running the forward pass, or removing the file.
#[test]
#[allow(dead_code)]
fn test_complete_workflow() -> Result<()> {
    let temp_dir = tempdir().expect("Failed to create temp directory");
    let model_path = temp_dir.path().join("workflow_model.json");

    let model = create_test_model::<f32>()?;
    save_model(&model, &model_path, SerializationFormat::JSON)?;
    let loaded_model: Sequential<f32> = load_model(&model_path, SerializationFormat::JSON)?;

    let input = Array2::<f32>::ones((5, 10)).into_dyn();
    let output = loaded_model.forward(&input)?;
    assert_eq!(output.shape(), &[5, 5]);

    fs::remove_file(&model_path)?;
    Ok(())
}