1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
//! ML Training Configuration Types
use serde::{Deserialize, Serialize};
/// Top-level hyperparameters controlling a training run.
///
/// Bundles the iteration budget, learning rate, batching, early-stopping
/// policy, cross-validation and data-split settings, and the optimizer choice.
/// No validation is performed on any field; consumers must check ranges.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TrainingConfig {
    /// Upper bound on the number of training iterations.
    pub max_iterations: usize,
    /// Learning rate.
    pub learning_rate: f64,
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Early-stopping policy; see [`EarlyStoppingConfig`].
    pub early_stopping: EarlyStoppingConfig,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Train/test data split setting.
    ///
    /// NOTE(review): presumably the fraction of data used for training
    /// (e.g. `0.8`) — confirm against the consumer of this config.
    pub train_test_split: f64,
    /// Optimization algorithm used for training; see [`TrainingOptimizer`].
    pub optimizer: TrainingOptimizer,
}
/// Early-stopping policy for a training run.
///
/// Derives `PartialEq` so configurations can be compared directly (e.g. in
/// tests or when detecting config changes). `Eq` is intentionally omitted
/// because `min_improvement` is an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Whether early stopping is enabled at all.
    pub enable_early_stopping: bool,
    /// Number of iterations without improvement tolerated before stopping.
    pub patience: usize,
    /// Minimum improvement that counts as progress.
    ///
    /// NOTE(review): whether this is an absolute or relative threshold is
    /// decided by the consumer of this config — confirm.
    pub min_improvement: f64,
    /// Whether the best weights seen should be restored once training stops.
    pub restore_best_weights: bool,
}
/// Optimization algorithm used for training.
///
/// The enum is fieldless, so it derives `Copy` (no `.clone()` needed to pass
/// it around) and `Hash` (usable as a `HashMap`/`HashSet` key).
///
/// Under serde's default (externally tagged) representation each unit variant
/// serializes as its name (e.g. `"SGD"`), so renaming a variant changes the
/// serialized format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TrainingOptimizer {
    /// Stochastic gradient descent.
    SGD,
    /// Adam.
    Adam,
    /// AdamW (Adam with decoupled weight decay).
    AdamW,
    /// RMSprop.
    RMSprop,
    /// Adagrad.
    Adagrad,
    /// L-BFGS (limited-memory BFGS).
    LBFGS,
}
/// Regularization settings applied during training.
///
/// Derives `PartialEq` so configurations can be compared directly. `Eq` is
/// intentionally omitted because the numeric fields are all `f64`.
/// No validation is performed on any field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RegularizationConfig {
    /// L1 regularization strength.
    pub l1_lambda: f64,
    /// L2 regularization strength.
    pub l2_lambda: f64,
    /// Dropout rate.
    ///
    /// NOTE(review): presumably a probability in `[0.0, 1.0)` — confirm; the
    /// range is not enforced here.
    pub dropout_rate: f64,
    /// Whether batch normalization is enabled.
    pub batch_normalization: bool,
    /// Weight decay coefficient.
    ///
    /// NOTE(review): overlaps conceptually with `l2_lambda`; how the two are
    /// combined is up to the consumer of this config — confirm.
    pub weight_decay: f64,
}