Skip to main content

entrenar/yaml_mode/manifest/
training.rs

//! Training Configuration
//!
//! Training-loop configuration types used by training manifests.
4
5use serde::{Deserialize, Serialize};
6
7/// Training loop configuration
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct TrainingConfig {
10    /// Number of epochs
11    #[serde(default, skip_serializing_if = "Option::is_none")]
12    pub epochs: Option<usize>,
13
14    /// Maximum training steps (mutually exclusive with epochs)
15    #[serde(default, skip_serializing_if = "Option::is_none")]
16    pub max_steps: Option<usize>,
17
18    /// Maximum wall-clock duration (mutually exclusive with epochs/max_steps)
19    #[serde(default, skip_serializing_if = "Option::is_none")]
20    pub duration: Option<String>,
21
22    /// Gradient settings
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub gradient: Option<GradientConfig>,
25
26    /// Mixed precision training
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub mixed_precision: Option<MixedPrecisionConfig>,
29
30    /// Distributed training
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub distributed: Option<DistributedConfig>,
33
34    /// Checkpointing
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub checkpoint: Option<CheckpointConfig>,
37
38    /// Early stopping (Jidoka)
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub early_stopping: Option<EarlyStoppingConfig>,
41
42    /// Validation configuration
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub validation: Option<ValidationConfig>,
45
46    /// Deterministic mode
47    #[serde(default, skip_serializing_if = "Option::is_none")]
48    pub deterministic: Option<bool>,
49
50    /// Benchmark mode (cuDNN autotuner)
51    #[serde(default, skip_serializing_if = "Option::is_none")]
52    pub benchmark: Option<bool>,
53
54    /// R-023: Curriculum learning — multi-stage data mixing
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    pub curriculum: Option<Vec<crate::config::CurriculumStage>>,
57}
58
59/// Gradient settings
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct GradientConfig {
62    /// Gradient accumulation steps
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub accumulation_steps: Option<usize>,
65
66    /// Gradient clipping (L2 norm)
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub clip_norm: Option<f64>,
69
70    /// Gradient clipping (absolute value)
71    #[serde(default, skip_serializing_if = "Option::is_none")]
72    pub clip_value: Option<f64>,
73}
74
75/// Mixed precision training configuration
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct MixedPrecisionConfig {
78    /// Enable mixed precision
79    pub enabled: bool,
80
81    /// Data type (float16, bfloat16)
82    #[serde(default, skip_serializing_if = "Option::is_none")]
83    pub dtype: Option<String>,
84
85    /// Loss scale (dynamic, static, or float)
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    pub loss_scale: Option<String>,
88}
89
90/// Distributed training configuration
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct DistributedConfig {
93    /// Strategy (ddp, fsdp, deepspeed)
94    pub strategy: String,
95
96    /// World size
97    #[serde(default, skip_serializing_if = "Option::is_none")]
98    pub world_size: Option<usize>,
99
100    /// Gradient as bucket view
101    #[serde(default, skip_serializing_if = "Option::is_none")]
102    pub gradient_as_bucket_view: Option<bool>,
103
104    /// Find unused parameters
105    #[serde(default, skip_serializing_if = "Option::is_none")]
106    pub find_unused_parameters: Option<bool>,
107}
108
109/// Checkpoint configuration
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct CheckpointConfig {
112    /// Save every N steps
113    #[serde(default, skip_serializing_if = "Option::is_none")]
114    pub save_every: Option<usize>,
115
116    /// Keep last N checkpoints
117    #[serde(default, skip_serializing_if = "Option::is_none")]
118    pub keep_last: Option<usize>,
119
120    /// Save best model by metric
121    #[serde(default, skip_serializing_if = "Option::is_none")]
122    pub save_best: Option<bool>,
123
124    /// Metric for best model selection
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    pub metric: Option<String>,
127
128    /// Metric mode (min, max)
129    #[serde(default, skip_serializing_if = "Option::is_none")]
130    pub mode: Option<String>,
131}
132
133/// Early stopping configuration (Jidoka - automatic halt on quality degradation)
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct EarlyStoppingConfig {
136    /// Enable early stopping
137    pub enabled: bool,
138
139    /// Metric to monitor
140    #[serde(default, skip_serializing_if = "Option::is_none")]
141    pub metric: Option<String>,
142
143    /// Patience (epochs without improvement)
144    #[serde(default, skip_serializing_if = "Option::is_none")]
145    pub patience: Option<usize>,
146
147    /// Minimum delta for improvement
148    #[serde(default, skip_serializing_if = "Option::is_none")]
149    pub min_delta: Option<f64>,
150
151    /// Metric mode (min, max)
152    #[serde(default, skip_serializing_if = "Option::is_none")]
153    pub mode: Option<String>,
154}
155
156/// Validation configuration
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct ValidationConfig {
159    /// Validate every N steps
160    #[serde(default, skip_serializing_if = "Option::is_none")]
161    pub every: Option<usize>,
162
163    /// Validate each epoch
164    #[serde(default, skip_serializing_if = "Option::is_none")]
165    pub every_epoch: Option<bool>,
166
167    /// Metrics to compute
168    #[serde(default, skip_serializing_if = "Option::is_none")]
169    pub metrics: Option<Vec<String>>,
170
171    /// Cross-validation configuration
172    #[serde(default, skip_serializing_if = "Option::is_none")]
173    pub cross_validation: Option<CrossValidationConfig>,
174}
175
176/// Cross-validation configuration
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct CrossValidationConfig {
179    /// Number of folds
180    pub folds: usize,
181
182    /// Stratified sampling
183    #[serde(default, skip_serializing_if = "Option::is_none")]
184    pub stratified: Option<bool>,
185
186    /// Shuffle data
187    #[serde(default, skip_serializing_if = "Option::is_none")]
188    pub shuffle: Option<bool>,
189}