/// Top-level bundle of every setting used to drive a training run.
///
/// Each field groups one concern: learning-rate schedule, optimizer, loss,
/// regularization, the data loop (batch size / epochs / validation split) and
/// early stopping. Baseline values are provided by the `Default` impl.
#[derive(Debug, Clone)]
pub struct TrainingConfiguration {
    /// Learning-rate schedule applied over the run.
    pub learning_rate: LearningRateSchedule,

    /// Parameter-update rule together with its hyper-parameters.
    pub optimizer: Optimizer,

    /// Objective to minimise.
    pub loss_function: LossFunction,

    /// Penalty weights, dropout, augmentations and label smoothing.
    pub regularization: RegularizationConfig,

    /// Number of samples per batch.
    pub batch_size: usize,

    /// Number of passes over the training data.
    pub epochs: usize,

    /// Share of the data held out for validation; presumably a fraction in
    /// `[0, 1]` — it is not range-checked here.
    pub validation_split: f64,

    /// Criteria for ending training before `epochs` is reached.
    pub early_stopping: EarlyStoppingConfig,
}
/// How the learning rate evolves over the course of training.
///
/// This type only carries parameters. The formula that turns them into a
/// per-step or per-epoch rate is implemented by whatever consumes the enum.
/// The per-variant notes below describe the conventional meaning of each
/// field — confirm them against that consumer.
///
/// `PartialEq` is derived so configurations can be compared, e.g. in tests or
/// to detect changed settings.
#[derive(Debug, Clone, PartialEq)]
pub enum LearningRateSchedule {
    /// A fixed learning rate for the whole run.
    Constant(f64),
    /// Exponential decay starting at `initial_lr`; presumably multiplies by
    /// `decay_rate` every `decay_steps` steps — TODO confirm.
    ExponentialDecay {
        initial_lr: f64,
        decay_rate: f64,
        decay_steps: usize,
    },
    /// Cosine annealing from `initial_lr` down to `min_lr`; presumably over a
    /// cycle of `cycle_length` steps/epochs — TODO confirm the unit.
    CosineAnnealing {
        initial_lr: f64,
        min_lr: f64,
        cycle_length: usize,
    },
    /// Step decay starting at `initial_lr`; presumably scaled by `drop_rate`
    /// every `epochs_drop` epochs — TODO confirm.
    StepDecay {
        initial_lr: f64,
        drop_rate: f64,
        epochs_drop: usize,
    },
    /// Adaptive schedule starting at `initial_lr`; presumably scaled by
    /// `factor` after `patience` evaluations without improvement
    /// (reduce-on-plateau style) — NOTE(review): confirm with the consumer.
    Adaptive {
        initial_lr: f64,
        patience: usize,
        factor: f64,
    },
}
/// Parameter-update rule and its hyper-parameters.
///
/// Field names follow the conventional hyper-parameter names of each
/// algorithm. How they are applied is decided by the code that consumes this
/// enum, which is not shown here.
///
/// `PartialEq` is derived so optimizer settings can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub enum Optimizer {
    /// Stochastic gradient descent with a `momentum` coefficient; `nesterov`
    /// toggles the Nesterov variant.
    SGD { momentum: f64, nesterov: bool },
    /// Adam with first/second moment coefficients `beta1`/`beta2` and
    /// stability term `epsilon`.
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// AdamW: Adam hyper-parameters plus a `weight_decay` coefficient.
    AdamW {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        weight_decay: f64,
    },
    /// RMSprop with smoothing constant `alpha` and stability term `epsilon`.
    RMSprop { alpha: f64, epsilon: f64 },
    /// AdaGrad with stability term `epsilon`.
    AdaGrad { epsilon: f64 },
}
/// Objective minimised during training.
///
/// The tuple variants carry unnamed parameters whose meaning is not defined
/// in this section. The notes below give the conventional interpretation —
/// confirm against the code that evaluates the loss.
///
/// `PartialEq` is derived so loss selections can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LossFunction {
    /// Mean squared error.
    MSE,
    /// Cross-entropy.
    CrossEntropy,
    /// Focal loss; the two parameters are presumably `(alpha, gamma)` —
    /// TODO confirm order.
    FocalLoss(f64, f64),
    /// Huber loss; the parameter is presumably the `delta` threshold —
    /// TODO confirm.
    HuberLoss(f64),
    /// Weighted mean squared error. NOTE(review): where the weights come from
    /// is not visible here.
    WeightedMSE,
}
/// Regularization settings applied during training.
///
/// None of the values are range-checked here. Baseline values are provided by
/// the `Default` impl.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// Weight of the L1 penalty term (0.0 presumably disables it).
    pub l1_lambda: f64,

    /// Weight of the L2 penalty term (0.0 presumably disables it).
    pub l2_lambda: f64,

    /// Dropout probability; presumably expected in `[0, 1)`.
    pub dropout_prob: f64,

    /// Augmentations to apply to training data. NOTE(review): whether list
    /// order matters depends on the consumer.
    pub data_augmentation: Vec<DataAugmentation>,

    /// Label-smoothing amount (0.0 presumably means no smoothing).
    pub label_smoothing: f64,
}
/// A single data-augmentation step.
///
/// Parameters are unnamed, and how they are applied is decided by the
/// consumer of this enum. The notes below give the presumed meaning only.
///
/// `PartialEq` is derived so augmentation lists can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub enum DataAugmentation {
    /// Additive Gaussian noise; the parameter is presumably its standard
    /// deviation — TODO confirm.
    GaussianNoise(f64),
    /// Temporal shift; the parameter is presumably the maximum shift (absolute
    /// or fractional) — TODO confirm.
    TimeShift(f64),
    /// Random scaling; the parameters are presumably a `(min, max)` factor
    /// range — TODO confirm.
    Scaling(f64, f64),
    /// Random permutation of features.
    FeaturePermutation,
    /// Mixup; the parameter is presumably the Beta-distribution `alpha` —
    /// TODO confirm.
    Mixup(f64),
}
/// Settings for stopping training early based on a monitored metric.
///
/// This type only holds the settings; the stopping logic lives elsewhere.
/// `PartialEq` is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct EarlyStoppingConfig {
    /// Whether early stopping is active.
    pub enabled: bool,
    /// Name of the metric to watch (stringly typed; the accepted names are
    /// defined by the consumer — TODO confirm the valid set).
    pub monitor: String,
    /// Minimum change in the monitored metric that presumably counts as an
    /// improvement.
    pub min_delta: f64,
    /// Number of evaluations without improvement that presumably trigger a
    /// stop.
    pub patience: usize,
    /// `true` if larger metric values are presumably better, `false` if
    /// smaller are.
    pub maximize: bool,
}
impl Default for TrainingConfiguration {
fn default() -> Self {
Self {
learning_rate: LearningRateSchedule::Constant(0.001),
optimizer: Optimizer::Adam {
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
},
loss_function: LossFunction::CrossEntropy,
regularization: RegularizationConfig::default(),
batch_size: 32,
epochs: 100,
validation_split: 0.2,
early_stopping: EarlyStoppingConfig::default(),
}
}
}
impl Default for RegularizationConfig {
fn default() -> Self {
Self {
l1_lambda: 0.0,
l2_lambda: 0.001,
dropout_prob: 0.1,
data_augmentation: Vec::new(),
label_smoothing: 0.0,
}
}
}
impl Default for EarlyStoppingConfig {
fn default() -> Self {
Self {
enabled: true,
monitor: "val_loss".to_string(),
min_delta: 1e-4,
patience: 10,
maximize: false,
}
}
}