// Source: quantrs2_device/ml_optimization/training.rs
use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct TrainingConfig {
8 pub max_iterations: usize,
10 pub learning_rate: f64,
12 pub batch_size: usize,
14 pub early_stopping: EarlyStoppingConfig,
16 pub cv_folds: usize,
18 pub train_test_split: f64,
20 pub optimizer: TrainingOptimizer,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EarlyStoppingConfig {
27 pub enable_early_stopping: bool,
29 pub patience: usize,
31 pub min_improvement: f64,
33 pub restore_best_weights: bool,
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39pub enum TrainingOptimizer {
40 SGD,
41 Adam,
42 AdamW,
43 RMSprop,
44 Adagrad,
45 LBFGS,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RegularizationConfig {
51 pub l1_lambda: f64,
53 pub l2_lambda: f64,
55 pub dropout_rate: f64,
57 pub batch_normalization: bool,
59 pub weight_decay: f64,
61}