use serde::{Deserialize, Serialize};
/// Hyper-parameters for a training run.
///
/// The type is (de)serializable with serde. Because of the container-level
/// `#[serde(default)]`, any field missing from the input is taken from
/// [`TrainingConfig::default`], so a partial config document is accepted.
///
/// `PartialEq` is derived so configs can be compared directly, e.g. in
/// tests or when checking a deserialized config against the defaults.
///
/// NOTE(review): `learning_rate` and `weight_decay` are `f64`, while
/// `warmup_ratio` and `max_grad_norm` are `f32`. This is left as-is because
/// changing a public field type would break callers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct TrainingConfig {
    /// Number of training epochs (default `3`).
    pub epochs: usize,
    /// Batch size (default `16`).
    pub batch_size: usize,
    /// Learning rate (default `2e-4`).
    pub learning_rate: f64,
    /// Weight-decay coefficient (default `0.01`).
    pub weight_decay: f64,
    /// Warmup ratio (default `0.03`).
    /// Presumably a fraction of total training steps — confirm against the consumer.
    pub warmup_ratio: f32,
    /// Gradient accumulation steps (default `1`).
    pub gradient_accumulation_steps: usize,
    /// Maximum gradient norm (default `1.0`).
    /// Presumably the clipping threshold — confirm against the consumer.
    pub max_grad_norm: f32,
    /// Whether gradient checkpointing is enabled (default `false`).
    pub gradient_checkpointing: bool,
    /// Mixed-precision mode as a free-form string; `None` by default.
    ///
    /// NOTE(review): the accepted values are not defined in this type.
    /// Confirm them against the consumer and consider an enum.
    pub mixed_precision: Option<String>,
    /// Random seed (default `42`).
    pub seed: u64,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
epochs: 3,
batch_size: 16,
learning_rate: 2e-4,
weight_decay: 0.01,
warmup_ratio: 0.03,
gradient_accumulation_steps: 1,
max_grad_norm: 1.0,
gradient_checkpointing: false,
mixed_precision: None,
seed: 42,
}
}
}