Skip to main content

entrenar/config/validate/
json_schema.rs

//! JSON Schema generation for YAML configuration validation
//!
//! Provides a hand-written JSON schema for the `TrainSpec` YAML training
//! configuration, for use by external validation tools and IDE autocompletion.
//!
//! Batuta: AI-05 (Declarative Schema Validation)

8use serde_json::{json, Value};
9
10/// Generate a JSON schema for the training configuration
11#[allow(dead_code)]
12///
13/// This schema can be used by:
14/// - IDE YAML plugins for autocompletion
15/// - CI validation of config files
16/// - Documentation generation
17pub fn training_config_json_schema() -> Value {
18    json!({
19        "$schema": "http://json-schema.org/draft-07/schema#",
20        "title": "entrenar Training Configuration",
21        "description": "Schema for entrenar YAML training configuration files",
22        "type": "object",
23        "required": ["model", "data", "optimizer", "training"],
24        "properties": {
25            "model": {
26                "type": "object",
27                "required": ["path"],
28                "properties": {
29                    "path": {
30                        "type": "string",
31                        "description": "Path to model weights or HuggingFace repo ID"
32                    },
33                    "hidden_size": { "type": "integer", "minimum": 1 },
34                    "num_layers": { "type": "integer", "minimum": 1 },
35                    "num_heads": { "type": "integer", "minimum": 1 },
36                    "num_kv_heads": { "type": "integer", "minimum": 1 },
37                    "intermediate_size": { "type": "integer", "minimum": 1 },
38                    "vocab_size": { "type": "integer", "minimum": 1 },
39                    "max_position_embeddings": { "type": "integer", "minimum": 1 }
40                }
41            },
42            "data": {
43                "type": "object",
44                "required": ["train", "batch_size"],
45                "properties": {
46                    "train": { "type": "string", "description": "Training data path" },
47                    "val": { "type": "string", "description": "Validation data path" },
48                    "batch_size": { "type": "integer", "minimum": 1 },
49                    "seq_len": { "type": "integer", "minimum": 1 }
50                }
51            },
52            "optimizer": {
53                "type": "object",
54                "required": ["name", "lr"],
55                "properties": {
56                    "name": {
57                        "type": "string",
58                        "enum": ["adam", "adamw", "sgd", "rmsprop", "adagrad", "lamb"]
59                    },
60                    "lr": { "type": "number", "exclusiveMinimum": 0, "maximum": 1 },
61                    "beta1": { "type": "number", "minimum": 0, "maximum": 1 },
62                    "beta2": { "type": "number", "minimum": 0, "maximum": 1 },
63                    "epsilon": { "type": "number", "exclusiveMinimum": 0 },
64                    "weight_decay": { "type": "number", "minimum": 0 }
65                }
66            },
67            "training": {
68                "type": "object",
69                "required": ["epochs"],
70                "properties": {
71                    "epochs": { "type": "integer", "minimum": 1 },
72                    "max_steps": { "type": "integer", "minimum": 1 },
73                    "grad_clip": { "type": "number", "exclusiveMinimum": 0 },
74                    "gradient_accumulation": { "type": "integer", "minimum": 1 },
75                    "save_interval": { "type": "integer", "minimum": 1 },
76                    "log_interval": { "type": "integer", "minimum": 1 },
77                    "lr_scheduler": {
78                        "type": "string",
79                        "enum": ["cosine", "linear", "constant", "step", "exponential", "one_cycle", "plateau"]
80                    },
81                    "warmup_steps": { "type": "integer", "minimum": 0 },
82                    "mixed_precision": {
83                        "type": "string",
84                        "enum": ["bf16", "fp16", "fp32"]
85                    },
86                    "deterministic": { "type": "boolean" },
87                    "seed": { "type": "integer" },
88                    "eval_interval": { "type": "integer", "minimum": 0 },
89                    "patience": { "type": "integer", "minimum": 0 }
90                }
91            },
92            "lora": {
93                "type": "object",
94                "properties": {
95                    "rank": { "type": "integer", "minimum": 1, "maximum": 1024 },
96                    "alpha": { "type": "number", "exclusiveMinimum": 0 },
97                    "dropout": { "type": "number", "minimum": 0, "exclusiveMaximum": 1 },
98                    "target_modules": {
99                        "type": "array",
100                        "items": { "type": "string" },
101                        "minItems": 1
102                    }
103                }
104            },
105            "distributed": {
106                "type": "object",
107                "properties": {
108                    "world_size": { "type": "integer", "minimum": 1 },
109                    "rank": { "type": "integer", "minimum": 0 },
110                    "local_rank": { "type": "integer", "minimum": 0 },
111                    "role": { "type": "string", "enum": ["coordinator", "worker"] },
112                    "backend": { "type": "string", "enum": ["cuda", "wgpu", "auto"] },
113                    "coordinator_addr": { "type": "string" }
114                }
115            }
116        }
117    })
118}
119
120/// Validate a YAML config string against the JSON schema using jsonschema crate
121#[allow(dead_code)]
122pub fn validate_yaml_against_schema(yaml_str: &str) -> Result<(), Vec<String>> {
123    let value: serde_json::Value = match serde_yaml::from_str(yaml_str) {
124        Ok(v) => v,
125        Err(e) => return Err(vec![format!("YAML parse error: {e}")]),
126    };
127
128    let schema = training_config_json_schema();
129    let validator = jsonschema::validator_for(&schema)
130        .map_err(|e| vec![format!("Schema compilation error: {e}")])?;
131
132    let mut errors: Vec<String> = validator
133        .iter_errors(&value)
134        .map(|error| {
135            let path = error.instance_path.to_string();
136            if path.is_empty() {
137                error.to_string()
138            } else {
139                format!("{path}: {error}")
140            }
141        })
142        .collect();
143
144    // Additional semantic checks beyond JSON Schema
145    semantic_checks(&value, &mut errors);
146
147    if errors.is_empty() {
148        Ok(())
149    } else {
150        Err(errors)
151    }
152}
153
154/// Semantic validation checks that go beyond what JSON Schema can express.
155#[allow(dead_code)]
156fn semantic_checks(value: &serde_json::Value, errors: &mut Vec<String>) {
157    let Some(obj) = value.as_object() else {
158        return;
159    };
160
161    if let Some(lr) =
162        obj.get("optimizer").and_then(|o| o.get("lr")).and_then(serde_json::Value::as_f64)
163    {
164        if lr <= 0.0 || lr > 1.0 {
165            errors.push(format!("optimizer.lr must be in (0, 1], got {lr}"));
166        }
167    }
168
169    if let Some(epochs) =
170        obj.get("training").and_then(|t| t.get("epochs")).and_then(serde_json::Value::as_u64)
171    {
172        if epochs == 0 {
173            errors.push("training.epochs must be >= 1".to_string());
174        }
175    }
176
177    if let Some(bs) =
178        obj.get("data").and_then(|d| d.get("batch_size")).and_then(serde_json::Value::as_u64)
179    {
180        if bs == 0 {
181            errors.push("data.batch_size must be >= 1".to_string());
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config containing every required key, with the three numeric fields the
    /// validation tests vary spliced in verbatim as YAML scalars.
    fn config_yaml(batch_size: &str, lr: &str, epochs: &str) -> String {
        format!(
            "
model:
  path: /tmp/model
data:
  train: /tmp/train
  batch_size: {batch_size}
optimizer:
  name: adamw
  lr: {lr}
training:
  epochs: {epochs}
"
        )
    }

    /// Assert that validation fails and that at least one message mentions `needle`, so a
    /// test cannot pass because of an unrelated failure (e.g. a typo in the fixture YAML).
    fn assert_rejected_mentioning(yaml: &str, needle: &str) {
        let errors =
            validate_yaml_against_schema(yaml).expect_err("config should have been rejected");
        assert!(
            errors.iter().any(|e| e.contains(needle)),
            "expected an error mentioning {needle:?}, got {errors:?}"
        );
    }

    #[test]
    fn test_json_schema_has_required_fields() {
        let schema = training_config_json_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("model")));
        assert!(required.contains(&json!("data")));
        assert!(required.contains(&json!("optimizer")));
        assert!(required.contains(&json!("training")));
    }

    #[test]
    fn test_json_schema_optimizer_enum() {
        let schema = training_config_json_schema();
        let opt_enum = &schema["properties"]["optimizer"]["properties"]["name"]["enum"];
        assert!(opt_enum.as_array().unwrap().contains(&json!("adamw")));
    }

    #[test]
    fn test_validate_valid_yaml() {
        // `assert_eq!` prints the error list if this ever regresses.
        assert_eq!(validate_yaml_against_schema(&config_yaml("4", "0.001", "10")), Ok(()));
    }

    #[test]
    fn test_validate_missing_required() {
        let yaml = r"
model:
  path: /tmp/model
";
        assert_rejected_mentioning(yaml, "data");
    }

    #[test]
    fn test_validate_invalid_lr() {
        assert_rejected_mentioning(&config_yaml("4", "-0.1", "10"), "lr");
    }

    #[test]
    fn test_validate_lr_above_one() {
        assert_rejected_mentioning(&config_yaml("4", "1.5", "10"), "lr");
    }

    #[test]
    fn test_validate_zero_epochs() {
        assert_rejected_mentioning(&config_yaml("4", "0.001", "0"), "epochs");
    }

    #[test]
    fn test_validate_zero_batch_size() {
        assert_rejected_mentioning(&config_yaml("0", "0.001", "10"), "batch_size");
    }

    #[test]
    fn test_validate_malformed_yaml() {
        assert_rejected_mentioning("model: [unclosed", "YAML parse error");
    }
}