entrenar/yaml_mode/manifest/
training.rs1use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct TrainingConfig {
10 #[serde(default, skip_serializing_if = "Option::is_none")]
12 pub epochs: Option<usize>,
13
14 #[serde(default, skip_serializing_if = "Option::is_none")]
16 pub max_steps: Option<usize>,
17
18 #[serde(default, skip_serializing_if = "Option::is_none")]
20 pub duration: Option<String>,
21
22 #[serde(default, skip_serializing_if = "Option::is_none")]
24 pub gradient: Option<GradientConfig>,
25
26 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub mixed_precision: Option<MixedPrecisionConfig>,
29
30 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub distributed: Option<DistributedConfig>,
33
34 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub checkpoint: Option<CheckpointConfig>,
37
38 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub early_stopping: Option<EarlyStoppingConfig>,
41
42 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub validation: Option<ValidationConfig>,
45
46 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub deterministic: Option<bool>,
49
50 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub benchmark: Option<bool>,
53
54 #[serde(default, skip_serializing_if = "Option::is_none")]
56 pub curriculum: Option<Vec<crate::config::CurriculumStage>>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct GradientConfig {
62 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub accumulation_steps: Option<usize>,
65
66 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub clip_norm: Option<f64>,
69
70 #[serde(default, skip_serializing_if = "Option::is_none")]
72 pub clip_value: Option<f64>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct MixedPrecisionConfig {
78 pub enabled: bool,
80
81 #[serde(default, skip_serializing_if = "Option::is_none")]
83 pub dtype: Option<String>,
84
85 #[serde(default, skip_serializing_if = "Option::is_none")]
87 pub loss_scale: Option<String>,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct DistributedConfig {
93 pub strategy: String,
95
96 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub world_size: Option<usize>,
99
100 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub gradient_as_bucket_view: Option<bool>,
103
104 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub find_unused_parameters: Option<bool>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct CheckpointConfig {
112 #[serde(default, skip_serializing_if = "Option::is_none")]
114 pub save_every: Option<usize>,
115
116 #[serde(default, skip_serializing_if = "Option::is_none")]
118 pub keep_last: Option<usize>,
119
120 #[serde(default, skip_serializing_if = "Option::is_none")]
122 pub save_best: Option<bool>,
123
124 #[serde(default, skip_serializing_if = "Option::is_none")]
126 pub metric: Option<String>,
127
128 #[serde(default, skip_serializing_if = "Option::is_none")]
130 pub mode: Option<String>,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct EarlyStoppingConfig {
136 pub enabled: bool,
138
139 #[serde(default, skip_serializing_if = "Option::is_none")]
141 pub metric: Option<String>,
142
143 #[serde(default, skip_serializing_if = "Option::is_none")]
145 pub patience: Option<usize>,
146
147 #[serde(default, skip_serializing_if = "Option::is_none")]
149 pub min_delta: Option<f64>,
150
151 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub mode: Option<String>,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct ValidationConfig {
159 #[serde(default, skip_serializing_if = "Option::is_none")]
161 pub every: Option<usize>,
162
163 #[serde(default, skip_serializing_if = "Option::is_none")]
165 pub every_epoch: Option<bool>,
166
167 #[serde(default, skip_serializing_if = "Option::is_none")]
169 pub metrics: Option<Vec<String>>,
170
171 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub cross_validation: Option<CrossValidationConfig>,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct CrossValidationConfig {
179 pub folds: usize,
181
182 #[serde(default, skip_serializing_if = "Option::is_none")]
184 pub stratified: Option<bool>,
185
186 #[serde(default, skip_serializing_if = "Option::is_none")]
188 pub shuffle: Option<bool>,
189}