entrenar/config/validate/
json_schema.rs1use serde_json::{json, Value};
9
10#[allow(dead_code)]
12pub fn training_config_json_schema() -> Value {
18 json!({
19 "$schema": "http://json-schema.org/draft-07/schema#",
20 "title": "entrenar Training Configuration",
21 "description": "Schema for entrenar YAML training configuration files",
22 "type": "object",
23 "required": ["model", "data", "optimizer", "training"],
24 "properties": {
25 "model": {
26 "type": "object",
27 "required": ["path"],
28 "properties": {
29 "path": {
30 "type": "string",
31 "description": "Path to model weights or HuggingFace repo ID"
32 },
33 "hidden_size": { "type": "integer", "minimum": 1 },
34 "num_layers": { "type": "integer", "minimum": 1 },
35 "num_heads": { "type": "integer", "minimum": 1 },
36 "num_kv_heads": { "type": "integer", "minimum": 1 },
37 "intermediate_size": { "type": "integer", "minimum": 1 },
38 "vocab_size": { "type": "integer", "minimum": 1 },
39 "max_position_embeddings": { "type": "integer", "minimum": 1 }
40 }
41 },
42 "data": {
43 "type": "object",
44 "required": ["train", "batch_size"],
45 "properties": {
46 "train": { "type": "string", "description": "Training data path" },
47 "val": { "type": "string", "description": "Validation data path" },
48 "batch_size": { "type": "integer", "minimum": 1 },
49 "seq_len": { "type": "integer", "minimum": 1 }
50 }
51 },
52 "optimizer": {
53 "type": "object",
54 "required": ["name", "lr"],
55 "properties": {
56 "name": {
57 "type": "string",
58 "enum": ["adam", "adamw", "sgd", "rmsprop", "adagrad", "lamb"]
59 },
60 "lr": { "type": "number", "exclusiveMinimum": 0, "maximum": 1 },
61 "beta1": { "type": "number", "minimum": 0, "maximum": 1 },
62 "beta2": { "type": "number", "minimum": 0, "maximum": 1 },
63 "epsilon": { "type": "number", "exclusiveMinimum": 0 },
64 "weight_decay": { "type": "number", "minimum": 0 }
65 }
66 },
67 "training": {
68 "type": "object",
69 "required": ["epochs"],
70 "properties": {
71 "epochs": { "type": "integer", "minimum": 1 },
72 "max_steps": { "type": "integer", "minimum": 1 },
73 "grad_clip": { "type": "number", "exclusiveMinimum": 0 },
74 "gradient_accumulation": { "type": "integer", "minimum": 1 },
75 "save_interval": { "type": "integer", "minimum": 1 },
76 "log_interval": { "type": "integer", "minimum": 1 },
77 "lr_scheduler": {
78 "type": "string",
79 "enum": ["cosine", "linear", "constant", "step", "exponential", "one_cycle", "plateau"]
80 },
81 "warmup_steps": { "type": "integer", "minimum": 0 },
82 "mixed_precision": {
83 "type": "string",
84 "enum": ["bf16", "fp16", "fp32"]
85 },
86 "deterministic": { "type": "boolean" },
87 "seed": { "type": "integer" },
88 "eval_interval": { "type": "integer", "minimum": 0 },
89 "patience": { "type": "integer", "minimum": 0 }
90 }
91 },
92 "lora": {
93 "type": "object",
94 "properties": {
95 "rank": { "type": "integer", "minimum": 1, "maximum": 1024 },
96 "alpha": { "type": "number", "exclusiveMinimum": 0 },
97 "dropout": { "type": "number", "minimum": 0, "exclusiveMaximum": 1 },
98 "target_modules": {
99 "type": "array",
100 "items": { "type": "string" },
101 "minItems": 1
102 }
103 }
104 },
105 "distributed": {
106 "type": "object",
107 "properties": {
108 "world_size": { "type": "integer", "minimum": 1 },
109 "rank": { "type": "integer", "minimum": 0 },
110 "local_rank": { "type": "integer", "minimum": 0 },
111 "role": { "type": "string", "enum": ["coordinator", "worker"] },
112 "backend": { "type": "string", "enum": ["cuda", "wgpu", "auto"] },
113 "coordinator_addr": { "type": "string" }
114 }
115 }
116 }
117 })
118}
119
120#[allow(dead_code)]
122pub fn validate_yaml_against_schema(yaml_str: &str) -> Result<(), Vec<String>> {
123 let value: serde_json::Value = match serde_yaml::from_str(yaml_str) {
124 Ok(v) => v,
125 Err(e) => return Err(vec![format!("YAML parse error: {e}")]),
126 };
127
128 let schema = training_config_json_schema();
129 let validator = jsonschema::validator_for(&schema)
130 .map_err(|e| vec![format!("Schema compilation error: {e}")])?;
131
132 let mut errors: Vec<String> = validator
133 .iter_errors(&value)
134 .map(|error| {
135 let path = error.instance_path.to_string();
136 if path.is_empty() {
137 error.to_string()
138 } else {
139 format!("{path}: {error}")
140 }
141 })
142 .collect();
143
144 semantic_checks(&value, &mut errors);
146
147 if errors.is_empty() {
148 Ok(())
149 } else {
150 Err(errors)
151 }
152}
153
154#[allow(dead_code)]
156fn semantic_checks(value: &serde_json::Value, errors: &mut Vec<String>) {
157 let Some(obj) = value.as_object() else {
158 return;
159 };
160
161 if let Some(lr) =
162 obj.get("optimizer").and_then(|o| o.get("lr")).and_then(serde_json::Value::as_f64)
163 {
164 if lr <= 0.0 || lr > 1.0 {
165 errors.push(format!("optimizer.lr must be in (0, 1], got {lr}"));
166 }
167 }
168
169 if let Some(epochs) =
170 obj.get("training").and_then(|t| t.get("epochs")).and_then(serde_json::Value::as_u64)
171 {
172 if epochs == 0 {
173 errors.push("training.epochs must be >= 1".to_string());
174 }
175 }
176
177 if let Some(bs) =
178 obj.get("data").and_then(|d| d.get("batch_size")).and_then(serde_json::Value::as_u64)
179 {
180 if bs == 0 {
181 errors.push("data.batch_size must be >= 1".to_string());
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration satisfying every `required` entry in the schema.
    const VALID_YAML: &str = r"
model:
  path: /tmp/model
data:
  train: /tmp/train
  batch_size: 4
optimizer:
  name: adamw
  lr: 0.001
training:
  epochs: 10
";

    /// Validate `yaml` and return its error list, failing the test if validation passed.
    fn errors_for(yaml: &str) -> Vec<String> {
        validate_yaml_against_schema(yaml).expect_err("expected validation to fail")
    }

    #[test]
    fn test_json_schema_has_required_fields() {
        let schema = training_config_json_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("model")));
        assert!(required.contains(&json!("data")));
        assert!(required.contains(&json!("optimizer")));
        assert!(required.contains(&json!("training")));
    }

    #[test]
    fn test_json_schema_optimizer_enum() {
        let schema = training_config_json_schema();
        let opt_enum = &schema["properties"]["optimizer"]["properties"]["name"]["enum"];
        assert!(opt_enum.as_array().unwrap().contains(&json!("adamw")));
    }

    #[test]
    fn test_validate_valid_yaml() {
        // assert_eq! (rather than is_ok) prints the error list if this ever regresses.
        assert_eq!(validate_yaml_against_schema(VALID_YAML), Ok(()));
    }

    #[test]
    fn test_validate_missing_required() {
        let yaml = r"
model:
  path: /tmp/model
";
        let errors = errors_for(yaml);
        assert!(errors.iter().any(|e| e.contains("data")));
    }

    #[test]
    fn test_validate_invalid_lr() {
        let errors = errors_for(&VALID_YAML.replace("lr: 0.001", "lr: -0.1"));
        assert!(errors.iter().any(|e| e.contains("lr")), "{errors:?}");
    }

    #[test]
    fn test_validate_zero_epochs() {
        let errors = errors_for(&VALID_YAML.replace("epochs: 10", "epochs: 0"));
        assert!(errors.iter().any(|e| e.contains("epochs")), "{errors:?}");
    }

    #[test]
    fn test_validate_unknown_optimizer() {
        let errors = errors_for(&VALID_YAML.replace("name: adamw", "name: adamx"));
        assert!(errors.iter().any(|e| e.contains("adamx")), "{errors:?}");
    }

    #[test]
    fn test_validate_yaml_parse_error() {
        // A parse failure short-circuits: exactly one message, no schema errors.
        let errors = errors_for("model: [");
        assert_eq!(errors.len(), 1, "{errors:?}");
        assert!(errors[0].starts_with("YAML parse error"), "{errors:?}");
    }
}