Skip to main content

entrenar/cli/commands/
validate.rs

1//! Validate command implementation
2
3use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{load_config, validate_config, TrainSpec, ValidateArgs};
6
7/// Format model information as a string
8pub fn format_model_info(spec: &TrainSpec) -> String {
9    let mode_str = format!("{:?}", spec.model.mode).to_lowercase();
10    let mut lines = vec![
11        format!("  Model path: {}", spec.model.path.display()),
12        format!("  Model mode: {mode_str}"),
13        format!("  Target layers: {:?}", spec.model.layers),
14    ];
15    if let Some(ref config) = spec.model.config {
16        lines.push(format!("  Config preset: {config}"));
17    }
18    lines.join("\n")
19}
20
21/// Format data configuration as a string
22pub fn format_data_info(spec: &TrainSpec) -> String {
23    let mut lines = vec![format!("  Training data: {}", spec.data.train.display())];
24    if let Some(val) = &spec.data.val {
25        lines.push(format!("  Validation data: {}", val.display()));
26    }
27    lines.push(format!("  Batch size: {}", spec.data.batch_size));
28    if let Some(ref tokenizer) = spec.data.tokenizer {
29        lines.push(format!("  Tokenizer: {}", tokenizer.display()));
30    }
31    if let Some(seq_len) = spec.data.seq_len {
32        lines.push(format!("  Sequence length: {seq_len}"));
33    }
34    if let Some(ref col) = spec.data.input_column {
35        lines.push(format!("  Input column: {col}"));
36    }
37    if let Some(ref col) = spec.data.output_column {
38        lines.push(format!("  Output column: {col}"));
39    }
40    if let Some(max_len) = spec.data.max_length {
41        lines.push(format!("  Max length: {max_len}"));
42    }
43    lines.join("\n")
44}
45
46/// Format optimizer configuration as a string
47pub fn format_optimizer_info(spec: &TrainSpec) -> String {
48    let mut lines = vec![
49        format!("  Optimizer: {}", spec.optimizer.name),
50        format!("  Learning rate: {}", spec.optimizer.lr),
51    ];
52    if let Some(wd) = spec.optimizer.params.get("weight_decay") {
53        lines.push(format!("  Weight decay: {wd}"));
54    }
55    lines.join("\n")
56}
57
58/// Format training configuration as a string
59pub fn format_training_info(spec: &TrainSpec) -> String {
60    let training_mode = format!("{:?}", spec.training.mode).to_lowercase();
61    let mut lines = vec![
62        format!("  Training mode: {training_mode}"),
63        format!("  Epochs: {}", spec.training.epochs),
64    ];
65    if let Some(clip) = spec.training.grad_clip {
66        lines.push(format!("  Gradient clipping: {clip}"));
67    }
68    if let Some(ref sched) = spec.training.lr_scheduler {
69        let mut sched_str = format!("  Scheduler: {sched}");
70        if spec.training.warmup_steps > 0 {
71            sched_str.push_str(&format!(" (warmup={} steps)", spec.training.warmup_steps));
72        }
73        lines.push(sched_str);
74        if let Some(ref params) = spec.training.scheduler_params {
75            for (k, v) in params {
76                lines.push(format!("    {k}: {v}"));
77            }
78        }
79    }
80    if let Some(ga) = spec.training.gradient_accumulation {
81        lines.push(format!("  Gradient accumulation: {ga}"));
82    }
83    if let Some(ref mp) = spec.training.mixed_precision {
84        lines.push(format!("  Mixed precision: {mp}"));
85    }
86    if let Some(seed) = spec.training.seed {
87        lines.push(format!("  Seed: {seed}"));
88    }
89    lines.push(format!("  Output dir: {}", spec.training.output_dir.display()));
90    lines.join("\n")
91}
92
93/// Format LoRA configuration as a string
94pub fn format_lora_info(spec: &TrainSpec) -> Option<String> {
95    spec.lora.as_ref().map(|lora| {
96        let mut lines = vec![
97            "  LoRA:".to_string(),
98            format!("    Rank: {}", lora.rank),
99            format!("    Alpha: {}", lora.alpha),
100        ];
101        if lora.dropout > 0.0 {
102            lines.push(format!("    Dropout: {}", lora.dropout));
103        }
104        lines.join("\n")
105    })
106}
107
108/// Format quantization configuration as a string
109pub fn format_quant_info(spec: &TrainSpec) -> Option<String> {
110    spec.quantize.as_ref().map(|quant| {
111        format!("  Quantization:\n    Bits: {}\n    Symmetric: {}", quant.bits, quant.symmetric)
112    })
113}
114
115/// Format merge configuration as a string
116pub fn format_merge_info(spec: &TrainSpec) -> Option<String> {
117    spec.merge.as_ref().map(|merge| {
118        let mut lines = vec!["  Merge:".to_string(), format!("    Method: {}", merge.method)];
119        if let Some(weight) = merge.params.get("weight") {
120            lines.push(format!("    Weight: {weight}"));
121        }
122        lines.join("\n")
123    })
124}
125
126/// Print detailed configuration summary
127pub fn print_detailed_summary(spec: &TrainSpec) {
128    println!();
129    println!("Configuration Summary:");
130    println!("{}", format_model_info(spec));
131    println!();
132    println!("{}", format_data_info(spec));
133    println!();
134    println!("{}", format_optimizer_info(spec));
135    println!();
136    println!("{}", format_training_info(spec));
137
138    if let Some(lora_info) = format_lora_info(spec) {
139        println!();
140        println!("{lora_info}");
141    }
142
143    if let Some(quant_info) = format_quant_info(spec) {
144        println!();
145        println!("{quant_info}");
146    }
147
148    if let Some(merge_info) = format_merge_info(spec) {
149        println!();
150        println!("{merge_info}");
151    }
152}
153
154pub fn run_validate(args: ValidateArgs, level: LogLevel) -> Result<(), String> {
155    log(level, LogLevel::Normal, &format!("Validating config: {}", args.config.display()));
156
157    let spec = load_config(&args.config).map_err(|e| format!("Config error: {e}"))?;
158
159    validate_config(&spec).map_err(|e| format!("Validation failed: {e}"))?;
160
161    log(level, LogLevel::Normal, "Configuration is valid");
162
163    if args.detailed {
164        print_detailed_summary(&spec);
165    }
166
167    Ok(())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        DataConfig, LoRASpec, MergeSpec, ModelRef, OptimSpec, QuantSpec, TrainingParams,
    };
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Build a parameter map holding exactly one entry.
    fn single_param(key: &str, value: serde_json::Value) -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(key.to_string(), value);
        map
    }

    /// Build a spec with every optional top-level section (LoRA, quantize,
    /// merge) populated, so individual tests only tweak what they exercise.
    fn make_test_spec() -> TrainSpec {
        let model = ModelRef {
            path: PathBuf::from("/model/path"),
            layers: vec!["layer1".to_string()],
            ..Default::default()
        };
        let data = DataConfig {
            train: PathBuf::from("/train.parquet"),
            val: Some(PathBuf::from("/val.parquet")),
            batch_size: 32,
            ..Default::default()
        };
        let optimizer = OptimSpec {
            name: "adam".to_string(),
            lr: 0.001,
            params: single_param("weight_decay", serde_json::json!(0.01)),
        };
        let training = TrainingParams {
            epochs: 10,
            grad_clip: Some(1.0),
            output_dir: PathBuf::from("/output"),
            ..Default::default()
        };
        let lora = LoRASpec {
            rank: 16,
            alpha: 32.0,
            dropout: 0.1,
            target_modules: vec!["q_proj".to_string()],
            lora_plus_ratio: 1.0,
            double_quantize: false,
            quantize_base: false,
        };
        let quantize = QuantSpec { bits: 4, symmetric: true, per_channel: true };
        let merge = MergeSpec {
            method: "slerp".to_string(),
            params: single_param("weight", serde_json::json!(0.5)),
        };

        TrainSpec {
            model,
            data,
            optimizer,
            training,
            lora: Some(lora),
            quantize: Some(quantize),
            merge: Some(merge),
            publish: None,
        }
    }

    /// Default-mode model section shows path, layers and the default mode name.
    #[test]
    fn test_format_model_info() {
        let text = format_model_info(&make_test_spec());
        for needle in ["/model/path", "layer1", "tabular"] {
            assert!(text.contains(needle));
        }
    }

    /// Transformer mode and a config preset both show up in the model section.
    #[test]
    fn test_format_model_info_transformer() {
        let mut spec = make_test_spec();
        spec.model.mode = crate::config::ModelMode::Transformer;
        spec.model.config = Some("qwen2_1_5b".into());

        let text = format_model_info(&spec);
        for needle in ["transformer", "qwen2_1_5b"] {
            assert!(text.contains(needle));
        }
    }

    /// Data section lists both data paths and the batch size.
    #[test]
    fn test_format_data_info() {
        let text = format_data_info(&make_test_spec());
        for needle in ["/train.parquet", "/val.parquet", "32"] {
            assert!(text.contains(needle));
        }
    }

    /// Without validation data, no "Validation" line is emitted.
    #[test]
    fn test_format_data_info_no_val() {
        let mut spec = make_test_spec();
        spec.data.val = None;

        let text = format_data_info(&spec);
        assert!(text.contains("/train.parquet"));
        assert!(!text.contains("Validation"));
    }

    /// All optional LLM-oriented data fields are rendered when set.
    #[test]
    fn test_format_data_info_llm_fields() {
        let mut spec = make_test_spec();
        spec.data.tokenizer = Some(PathBuf::from("./tokenizer.json"));
        spec.data.seq_len = Some(2048);
        spec.data.input_column = Some("text".into());
        spec.data.output_column = Some("label".into());
        spec.data.max_length = Some(512);

        let text = format_data_info(&spec);
        for needle in ["tokenizer.json", "2048", "text", "label", "512"] {
            assert!(text.contains(needle));
        }
    }

    /// Optimizer section shows name, learning rate and the weight-decay param.
    #[test]
    fn test_format_optimizer_info() {
        let text = format_optimizer_info(&make_test_spec());
        // "Weight decay" comes from the `weight_decay` entry in params.
        for needle in ["adam", "0.001", "Weight decay"] {
            assert!(text.contains(needle));
        }
    }

    /// Training section shows epochs, the default mode name and output dir.
    #[test]
    fn test_format_training_info() {
        let text = format_training_info(&make_test_spec());
        for needle in ["10", "regression", "/output"] {
            assert!(text.contains(needle));
        }
    }

    /// Every optional training field is rendered when set.
    #[test]
    fn test_format_training_info_full() {
        let mut scheduler_params = HashMap::new();
        scheduler_params.insert("t_max".into(), serde_json::json!(1000));

        let mut spec = make_test_spec();
        spec.training.mode = crate::config::TrainingMode::CausalLm;
        spec.training.lr_scheduler = Some("cosine".into());
        spec.training.warmup_steps = 200;
        spec.training.gradient_accumulation = Some(8);
        spec.training.mixed_precision = Some("bf16".into());
        spec.training.seed = Some(42);
        spec.training.scheduler_params = Some(scheduler_params);

        let text = format_training_info(&spec);
        for needle in ["causal", "cosine", "warmup=200", "t_max", "8", "bf16", "42"] {
            assert!(text.contains(needle));
        }
    }

    /// LoRA section shows rank, alpha and a non-zero dropout.
    #[test]
    fn test_format_lora_info() {
        let text = format_lora_info(&make_test_spec()).expect("operation should succeed");
        for needle in ["16", "32", "0.1"] {
            assert!(text.contains(needle));
        }
    }

    /// No LoRA config means no LoRA section.
    #[test]
    fn test_format_lora_info_none() {
        let spec = TrainSpec { lora: None, ..make_test_spec() };
        assert!(format_lora_info(&spec).is_none());
    }

    /// Quantization section shows the bit width and symmetric flag.
    #[test]
    fn test_format_quant_info() {
        let text = format_quant_info(&make_test_spec()).expect("operation should succeed");
        for needle in ["4", "true"] {
            assert!(text.contains(needle));
        }
    }

    /// No quantization config means no quantization section.
    #[test]
    fn test_format_quant_info_none() {
        let spec = TrainSpec { quantize: None, ..make_test_spec() };
        assert!(format_quant_info(&spec).is_none());
    }

    /// Merge section shows the method and the weight param.
    #[test]
    fn test_format_merge_info() {
        let text = format_merge_info(&make_test_spec()).expect("operation should succeed");
        for needle in ["slerp", "0.5"] {
            assert!(text.contains(needle));
        }
    }

    /// No merge config means no merge section.
    #[test]
    fn test_format_merge_info_none() {
        let spec = TrainSpec { merge: None, ..make_test_spec() };
        assert!(format_merge_info(&spec).is_none());
    }
}