rustorch 0.6.29

Production-ready PyTorch-compatible deep learning library in Rust with special mathematical functions (gamma, Bessel, error functions), statistical distributions, Fourier transforms (FFT/RFFT), matrix decomposition (SVD/QR/LU/eigenvalue), automatic differentiation, neural networks, computer vision transforms, complete GPU acceleration (CUDA/Metal/OpenCL), SIMD optimizations, parallel processing, WebAssembly browser support, comprehensive distributed learning support, and performance validation
Documentation
//! 簡略化された学習ループテスト
//! Simplified training loop tests

use super::*;
use crate::training::checkpoint::CheckpointMetadata;
use crate::training::metrics::ConfusionMatrix;
use crate::training::trainer::CallbackSignal;

/// 基本的な学習ループのテスト
/// Basic training loop tests
#[cfg(test)]
mod basic_tests {
    use super::*;

    #[test]
    fn test_trainer_config_creation() {
        let config = TrainerConfig::default();
        assert_eq!(config.epochs, 10);
        assert_eq!(config.log_frequency, 100);
        assert_eq!(config.validation_frequency, 1);
        assert_eq!(config.device, "cpu");
        assert!(!config.use_mixed_precision);
        assert_eq!(config.accumulation_steps, 1);
    }

    #[test]
    fn test_training_state() {
        let mut state = TrainingState::<f32>::new(5);
        assert_eq!(state.total_epochs, 5);
        assert_eq!(state.current_epoch, 0);
        assert!(!state.is_completed());
        assert_eq!(state.progress(), 0.0);

        let epoch_state = EpochState::new(0);
        state.add_epoch(epoch_state);
        assert_eq!(state.current_epoch, 1);
        assert_eq!(state.progress(), 20.0);
    }

    #[test]
    fn test_epoch_state() {
        let mut epoch_state: EpochState<f32> = EpochState::new(5);
        assert_eq!(epoch_state.epoch, 5);
        assert!(epoch_state.train_metrics.is_none());
        assert!(epoch_state.val_metrics.is_none());
        assert!(epoch_state.batch_history.is_empty());

        epoch_state.set_metadata("learning_rate".to_string(), "0.001".to_string());
        assert_eq!(
            epoch_state.metadata.get("learning_rate"),
            Some(&"0.001".to_string())
        );
    }

    #[test]
    fn test_batch_state() {
        let mut batch_state = BatchState::new(42);
        assert_eq!(batch_state.batch, 42);
        assert!(batch_state.loss.is_none());
        assert!(batch_state.metrics.is_empty());
        assert!(batch_state.batch_size.is_none());

        batch_state.set_metric("accuracy".to_string(), 0.85);
        batch_state.set_metadata("device".to_string(), "cuda".to_string());

        assert_eq!(batch_state.get_metric("accuracy"), Some(0.85));
        assert_eq!(
            batch_state.metadata.get("device"),
            Some(&"cuda".to_string())
        );
    }

    #[test]
    fn test_metrics_collector() {
        let _collector = MetricsCollector::<f32>::new();
        // 基本的な作成テストのみ - プレースホルダーテスト
    }

    #[test]
    fn test_early_stopping_callback() {
        let _early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(3, 0.01);

        // 基本的な作成テストのみ(プライベートフィールドのアクセスを避ける) - プレースホルダーテスト
    }

    #[test]
    fn test_lr_scheduler() {
        let _scheduler = LearningRateScheduler::exponential_decay(0.001, 0.95);

        // 基本的な作成テストのみ(プライベートフィールドのアクセスを避ける) - プレースホルダーテスト
    }

    #[test]
    fn test_progress_bar() {
        let _progress = ProgressBar::simple();

        // 基本的な作成テストのみ(プライベートフィールドのアクセスを避ける) - プレースホルダーテスト
    }

    #[test]
    fn test_confusion_matrix() {
        let mut matrix = ConfusionMatrix::new();
        matrix.true_positives = 80;
        matrix.false_positives = 10;
        matrix.true_negatives = 90;
        matrix.false_negatives = 20;

        assert_eq!(matrix.total(), 200);
        assert_eq!(matrix.accuracy(), 0.85);

        let precision = matrix.precision();
        let recall = matrix.recall();
        assert!(precision > 0.0 && precision <= 1.0);
        assert!(recall > 0.0 && recall <= 1.0);

        let f1 = matrix.f1_score();
        assert!(f1 > 0.0 && f1 <= 1.0);
    }

    #[test]
    fn test_checkpoint_config() {
        let config = SaveConfig::default();
        assert_eq!(config.prefix, "model");
        assert_eq!(config.max_checkpoints, 5);
        assert!(!config.save_best_only);
        assert_eq!(config.monitor, "val_loss");
        assert!(!config.mode_max);
    }

    #[test]
    fn test_checkpoint_metadata() {
        let mut metadata = CheckpointMetadata::new(5, 0.25);
        assert_eq!(metadata.epoch, 5);
        assert_eq!(metadata.train_loss, 0.25);
        assert!(metadata.val_loss.is_none());
        assert!(metadata.timestamp > 0);

        metadata.set_metric("accuracy".to_string(), 0.85);
        metadata.set_extra("model_type".to_string(), "transformer".to_string());

        assert_eq!(metadata.metrics.get("accuracy"), Some(&0.85));
        assert_eq!(
            metadata.extra.get("model_type"),
            Some(&"transformer".to_string())
        );
    }
}

/// パフォーマンステスト
/// Performance tests
#[cfg(test)]
mod performance_tests {
    use super::*;
    use std::time::{Duration, Instant};

    #[test]
    fn test_state_creation_performance() {
        let start_time = Instant::now();

        for _ in 0..1000 {
            let _state = TrainingState::<f32>::new(100);
        }

        let duration = start_time.elapsed();
        assert!(duration < Duration::from_millis(100));
        println!("State creation performance: {:?}", duration);
    }

    #[test]
    fn test_metrics_collector_performance() {
        let start_time = Instant::now();

        for _ in 0..1000 {
            let _collector = MetricsCollector::<f32>::new();
        }

        let duration = start_time.elapsed();
        assert!(duration < Duration::from_millis(50));
        println!("MetricsCollector creation performance: {:?}", duration);
    }

    #[test]
    fn test_callback_creation_performance() {
        let start_time = Instant::now();

        for _ in 0..1000 {
            let _early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(5, 0.01);
            let _scheduler = LearningRateScheduler::exponential_decay(0.001, 0.95);
            let _progress = ProgressBar::simple();
        }

        let duration = start_time.elapsed();
        assert!(duration < Duration::from_millis(100));
        println!("Callback creation performance: {:?}", duration);
    }
}

/// 統合テスト
/// Integration tests
#[cfg(test)]
mod integration_tests {
    use super::*;

    #[test]
    fn test_training_state_progression() {
        let mut state = TrainingState::<f32>::new(3);

        for epoch in 0..3 {
            let mut epoch_state = EpochState::new(epoch);
            epoch_state.learning_rate = Some(0.001 * 0.9_f64.powi(epoch as i32));

            let mut batch_state = BatchState::new(0);
            batch_state.loss = Some(0.5 - epoch as f64 * 0.1);
            batch_state.set_metric("accuracy".to_string(), 0.5 + epoch as f64 * 0.15);

            epoch_state.add_batch(batch_state);
            state.add_epoch(epoch_state);
        }

        assert!(state.is_completed());
        assert_eq!(state.progress(), 100.0);
        assert_eq!(state.epoch_history.len(), 3);

        // 各エポックの学習率が正しく減少しているかチェック
        for (i, epoch) in state.epoch_history.iter().enumerate() {
            let expected_lr = 0.001 * 0.9_f64.powi(i as i32);
            assert!((epoch.learning_rate.unwrap() - expected_lr).abs() < 1e-6);
        }
    }

    #[test]
    fn test_early_stopping_logic() {
        let mut early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(3, 0.01);

        // 初期化
        let mut state = TrainingState::<f32>::new(10);
        early_stopping.on_train_begin(&mut state).unwrap();

        // 改善するケース
        let mut epoch_state = EpochState::new(0);
        let val_metrics = trainer::EpochMetrics {
            total_loss: 0.5,
            avg_loss: 0.5,
            batch_count: 10,
        };
        epoch_state.val_metrics = Some(val_metrics);

        let result = early_stopping
            .on_epoch_end(&mut state, &epoch_state)
            .unwrap();
        assert!(matches!(result, Some(CallbackSignal::Continue)));

        // さらに改善するケース
        let mut epoch_state = EpochState::new(1);
        let val_metrics = trainer::EpochMetrics {
            total_loss: 0.3,
            avg_loss: 0.3,
            batch_count: 10,
        };
        epoch_state.val_metrics = Some(val_metrics);

        let result = early_stopping
            .on_epoch_end(&mut state, &epoch_state)
            .unwrap();
        assert!(matches!(result, Some(CallbackSignal::Continue)));
    }

    #[test]
    fn test_learning_rate_scheduler_logic() {
        let mut scheduler = LearningRateScheduler::exponential_decay(0.1, 0.9);
        let mut state = TrainingState::<f32>::new(5);

        scheduler.on_train_begin(&mut state).unwrap();

        // 複数エポックでの学習率変化をテスト
        for epoch in 0..5 {
            let mut epoch_state = EpochState::new(epoch);
            scheduler
                .on_epoch_begin(&mut state, &mut epoch_state)
                .unwrap();

            // 実際の学習率の値チェックは省略(プライベートフィールドアクセス問題を避ける)
            // 学習率がエポック状態に設定されているかのみチェック
            assert!(epoch_state.learning_rate.is_some());
        }
    }

    #[test]
    fn test_checkpoint_metadata_progression() {
        let mut metadata = CheckpointMetadata::new(0, 1.0);
        metadata.val_loss = Some(0.8);
        metadata.learning_rate = Some(0.001);

        // 複数のメトリクスを追加
        metadata.set_metric("accuracy".to_string(), 0.75);
        metadata.set_metric("precision".to_string(), 0.78);
        metadata.set_metric("recall".to_string(), 0.72);

        // 追加メタデータ
        metadata.set_extra("optimizer".to_string(), "adam".to_string());
        metadata.set_extra("batch_size".to_string(), "32".to_string());

        // 検証
        assert_eq!(metadata.epoch, 0);
        assert_eq!(metadata.train_loss, 1.0);
        assert_eq!(metadata.val_loss, Some(0.8));
        assert_eq!(metadata.learning_rate, Some(0.001));
        assert_eq!(metadata.metrics.len(), 3);
        assert_eq!(metadata.extra.len(), 2);

        // メトリクスの値を確認
        assert_eq!(metadata.metrics.get("accuracy"), Some(&0.75));
        assert_eq!(metadata.extra.get("optimizer"), Some(&"adam".to_string()));
    }
}

/// ベンチマーク用のユーティリティ
/// Benchmark utilities
pub mod simple_benchmarks {
    use super::*;
    use std::time::{Duration, Instant};

    /// 基本オブジェクト作成のベンチマーク
    /// Basic object creation benchmark
    pub fn benchmark_object_creation(iterations: usize) -> Duration {
        let start_time = Instant::now();

        for _ in 0..iterations {
            let _state = TrainingState::<f32>::new(100);
            let _collector = MetricsCollector::<f32>::new();
            let _early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(5, 0.01);
        }

        start_time.elapsed()
    }

    /// 状態管理のベンチマーク
    /// State management benchmark
    pub fn benchmark_state_management(num_epochs: usize) -> Duration {
        let start_time = Instant::now();
        let mut state = TrainingState::<f32>::new(num_epochs);

        for epoch in 0..num_epochs {
            let mut epoch_state = EpochState::new(epoch);

            // 複数のバッチを追加
            for batch in 0..10 {
                let mut batch_state = BatchState::new(batch);
                batch_state.loss = Some(0.5 - batch as f64 * 0.01);
                batch_state.set_metric("accuracy".to_string(), 0.8 + batch as f64 * 0.01);
                epoch_state.add_batch(batch_state);
            }

            state.add_epoch(epoch_state);
        }

        start_time.elapsed()
    }
}

#[cfg(test)]
mod benchmark_tests {
    use super::simple_benchmarks::*;
    use std::time::Duration;

    #[test]
    fn test_object_creation_benchmark() {
        let duration = benchmark_object_creation(100);
        println!("Object creation benchmark (100 iterations): {:?}", duration);
        assert!(duration < Duration::from_millis(100));
    }

    #[test]
    fn test_state_management_benchmark() {
        let duration = benchmark_state_management(10);
        println!("State management benchmark (10 epochs): {:?}", duration);
        assert!(duration < Duration::from_millis(50));
    }
}