// entrenar/hf_pipeline/config/training.rs

//! Training hyperparameters configuration.

3use serde::{Deserialize, Serialize};
4
5/// Training hyperparameters
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(default)]
8pub struct TrainingConfig {
9    /// Number of epochs
10    pub epochs: usize,
11    /// Batch size
12    pub batch_size: usize,
13    /// Learning rate
14    pub learning_rate: f64,
15    /// Weight decay
16    pub weight_decay: f64,
17    /// Warmup ratio
18    pub warmup_ratio: f32,
19    /// Gradient accumulation steps
20    pub gradient_accumulation_steps: usize,
21    /// Maximum gradient norm
22    pub max_grad_norm: f32,
23    /// Enable gradient checkpointing
24    pub gradient_checkpointing: bool,
25    /// Mixed precision mode
26    pub mixed_precision: Option<String>,
27    /// Random seed
28    pub seed: u64,
29}
30
31impl Default for TrainingConfig {
32    fn default() -> Self {
33        Self {
34            epochs: 3,
35            batch_size: 16,
36            learning_rate: 2e-4,
37            weight_decay: 0.01,
38            warmup_ratio: 0.03,
39            gradient_accumulation_steps: 1,
40            max_grad_norm: 1.0,
41            gradient_checkpointing: false,
42            mixed_precision: None,
43            seed: 42,
44        }
45    }
46}