quantrs2_device/ml_optimization/training.rs

//! ML Training Configuration Types

use serde::{Deserialize, Serialize};

/// Training configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Maximum training iterations
    pub max_iterations: usize,
    /// Learning rate
    pub learning_rate: f64,
    /// Batch size
    pub batch_size: usize,
    /// Early stopping criteria
    pub early_stopping: EarlyStoppingConfig,
    /// Number of cross-validation folds
    pub cv_folds: usize,
    /// Fraction of the data used for training (the rest is held out for testing)
    pub train_test_split: f64,
    /// Optimization algorithm for training
    pub optimizer: TrainingOptimizer,
}

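// A minimal sketch of sensible defaults. The values below are chosen for
// illustration only (assumed, not defaults defined by this crate).
impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1_000,
            learning_rate: 1e-3,
            batch_size: 32,
            early_stopping: EarlyStoppingConfig {
                enable_early_stopping: true,
                patience: 10,
                min_improvement: 1e-4,
                restore_best_weights: true,
            },
            cv_folds: 5,
            train_test_split: 0.8,
            optimizer: TrainingOptimizer::Adam,
        }
    }
}
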
/// Early stopping configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Enable early stopping
    pub enable_early_stopping: bool,
    /// Patience (iterations without improvement)
    pub patience: usize,
    /// Minimum improvement threshold
    pub min_improvement: f64,
    /// Restore the best weights observed when stopping early
    pub restore_best_weights: bool,
}

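// Sketch of how these fields could drive a stopping decision. The
// `EarlyStopper` helper below is illustrative and not part of this module's
// original API; it tracks the best validation loss seen so far and counts
// iterations that fail to beat it by at least `min_improvement`.
pub struct EarlyStopper {
    config: EarlyStoppingConfig,
    best_loss: f64,
    stale_iterations: usize,
}

impl EarlyStopper {
    pub fn new(config: EarlyStoppingConfig) -> Self {
        Self {
            config,
            best_loss: f64::INFINITY,
            stale_iterations: 0,
        }
    }

    /// Record one iteration's validation loss; returns `true` when training
    /// should stop (patience exhausted without a sufficient improvement).
    pub fn observe(&mut self, loss: f64) -> bool {
        if self.best_loss - loss > self.config.min_improvement {
            self.best_loss = loss;
            self.stale_iterations = 0;
        } else {
            self.stale_iterations += 1;
        }
        self.config.enable_early_stopping && self.stale_iterations >= self.config.patience
    }
}
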
/// Training optimizers
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TrainingOptimizer {
    SGD,
    Adam,
    AdamW,
    RMSprop,
    Adagrad,
    LBFGS,
}

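// Illustrative only: `typical_learning_rate` is an assumed helper, not part
// of the original API. It maps each variant to a learning rate commonly used
// as a starting point for that optimizer family.
impl TrainingOptimizer {
    pub fn typical_learning_rate(&self) -> f64 {
        match self {
            TrainingOptimizer::SGD | TrainingOptimizer::Adagrad => 1e-2,
            TrainingOptimizer::Adam
            | TrainingOptimizer::AdamW
            | TrainingOptimizer::RMSprop => 1e-3,
            // L-BFGS typically relies on a line search, so its step scale
            // starts near 1.0 rather than at a small gradient-descent rate.
            TrainingOptimizer::LBFGS => 1.0,
        }
    }
}
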
/// Regularization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularizationConfig {
    /// L1 regularization strength
    pub l1_lambda: f64,
    /// L2 regularization strength
    pub l2_lambda: f64,
    /// Dropout rate
    pub dropout_rate: f64,
    /// Enable batch normalization
    pub batch_normalization: bool,
    /// Weight decay coefficient
    pub weight_decay: f64,
}
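
// A minimal sketch (assumed helper, not part of the original API) of the
// elastic-net style penalty the strengths above describe:
// l1_lambda * sum(|w|) + l2_lambda * sum(w^2).
impl RegularizationConfig {
    pub fn penalty(&self, weights: &[f64]) -> f64 {
        let l1: f64 = weights.iter().map(|w| w.abs()).sum();
        let l2: f64 = weights.iter().map(|w| w * w).sum();
        self.l1_lambda * l1 + self.l2_lambda * l2
    }
}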