entrenar/cli/commands/
validate.rs1use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{load_config, validate_config, TrainSpec, ValidateArgs};
6
7pub fn format_model_info(spec: &TrainSpec) -> String {
9 let mode_str = format!("{:?}", spec.model.mode).to_lowercase();
10 let mut lines = vec![
11 format!(" Model path: {}", spec.model.path.display()),
12 format!(" Model mode: {mode_str}"),
13 format!(" Target layers: {:?}", spec.model.layers),
14 ];
15 if let Some(ref config) = spec.model.config {
16 lines.push(format!(" Config preset: {config}"));
17 }
18 lines.join("\n")
19}
20
21pub fn format_data_info(spec: &TrainSpec) -> String {
23 let mut lines = vec![format!(" Training data: {}", spec.data.train.display())];
24 if let Some(val) = &spec.data.val {
25 lines.push(format!(" Validation data: {}", val.display()));
26 }
27 lines.push(format!(" Batch size: {}", spec.data.batch_size));
28 if let Some(ref tokenizer) = spec.data.tokenizer {
29 lines.push(format!(" Tokenizer: {}", tokenizer.display()));
30 }
31 if let Some(seq_len) = spec.data.seq_len {
32 lines.push(format!(" Sequence length: {seq_len}"));
33 }
34 if let Some(ref col) = spec.data.input_column {
35 lines.push(format!(" Input column: {col}"));
36 }
37 if let Some(ref col) = spec.data.output_column {
38 lines.push(format!(" Output column: {col}"));
39 }
40 if let Some(max_len) = spec.data.max_length {
41 lines.push(format!(" Max length: {max_len}"));
42 }
43 lines.join("\n")
44}
45
46pub fn format_optimizer_info(spec: &TrainSpec) -> String {
48 let mut lines = vec![
49 format!(" Optimizer: {}", spec.optimizer.name),
50 format!(" Learning rate: {}", spec.optimizer.lr),
51 ];
52 if let Some(wd) = spec.optimizer.params.get("weight_decay") {
53 lines.push(format!(" Weight decay: {wd}"));
54 }
55 lines.join("\n")
56}
57
58pub fn format_training_info(spec: &TrainSpec) -> String {
60 let training_mode = format!("{:?}", spec.training.mode).to_lowercase();
61 let mut lines = vec![
62 format!(" Training mode: {training_mode}"),
63 format!(" Epochs: {}", spec.training.epochs),
64 ];
65 if let Some(clip) = spec.training.grad_clip {
66 lines.push(format!(" Gradient clipping: {clip}"));
67 }
68 if let Some(ref sched) = spec.training.lr_scheduler {
69 let mut sched_str = format!(" Scheduler: {sched}");
70 if spec.training.warmup_steps > 0 {
71 sched_str.push_str(&format!(" (warmup={} steps)", spec.training.warmup_steps));
72 }
73 lines.push(sched_str);
74 if let Some(ref params) = spec.training.scheduler_params {
75 for (k, v) in params {
76 lines.push(format!(" {k}: {v}"));
77 }
78 }
79 }
80 if let Some(ga) = spec.training.gradient_accumulation {
81 lines.push(format!(" Gradient accumulation: {ga}"));
82 }
83 if let Some(ref mp) = spec.training.mixed_precision {
84 lines.push(format!(" Mixed precision: {mp}"));
85 }
86 if let Some(seed) = spec.training.seed {
87 lines.push(format!(" Seed: {seed}"));
88 }
89 lines.push(format!(" Output dir: {}", spec.training.output_dir.display()));
90 lines.join("\n")
91}
92
93pub fn format_lora_info(spec: &TrainSpec) -> Option<String> {
95 spec.lora.as_ref().map(|lora| {
96 let mut lines = vec![
97 " LoRA:".to_string(),
98 format!(" Rank: {}", lora.rank),
99 format!(" Alpha: {}", lora.alpha),
100 ];
101 if lora.dropout > 0.0 {
102 lines.push(format!(" Dropout: {}", lora.dropout));
103 }
104 lines.join("\n")
105 })
106}
107
108pub fn format_quant_info(spec: &TrainSpec) -> Option<String> {
110 spec.quantize.as_ref().map(|quant| {
111 format!(" Quantization:\n Bits: {}\n Symmetric: {}", quant.bits, quant.symmetric)
112 })
113}
114
115pub fn format_merge_info(spec: &TrainSpec) -> Option<String> {
117 spec.merge.as_ref().map(|merge| {
118 let mut lines = vec![" Merge:".to_string(), format!(" Method: {}", merge.method)];
119 if let Some(weight) = merge.params.get("weight") {
120 lines.push(format!(" Weight: {weight}"));
121 }
122 lines.join("\n")
123 })
124}
125
126pub fn print_detailed_summary(spec: &TrainSpec) {
128 println!();
129 println!("Configuration Summary:");
130 println!("{}", format_model_info(spec));
131 println!();
132 println!("{}", format_data_info(spec));
133 println!();
134 println!("{}", format_optimizer_info(spec));
135 println!();
136 println!("{}", format_training_info(spec));
137
138 if let Some(lora_info) = format_lora_info(spec) {
139 println!();
140 println!("{lora_info}");
141 }
142
143 if let Some(quant_info) = format_quant_info(spec) {
144 println!();
145 println!("{quant_info}");
146 }
147
148 if let Some(merge_info) = format_merge_info(spec) {
149 println!();
150 println!("{merge_info}");
151 }
152}
153
154pub fn run_validate(args: ValidateArgs, level: LogLevel) -> Result<(), String> {
155 log(level, LogLevel::Normal, &format!("Validating config: {}", args.config.display()));
156
157 let spec = load_config(&args.config).map_err(|e| format!("Config error: {e}"))?;
158
159 validate_config(&spec).map_err(|e| format!("Validation failed: {e}"))?;
160
161 log(level, LogLevel::Normal, "Configuration is valid");
162
163 if args.detailed {
164 print_detailed_summary(&spec);
165 }
166
167 Ok(())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        DataConfig, LoRASpec, MergeSpec, ModelMode, ModelRef, OptimSpec, QuantSpec, TrainingMode,
        TrainingParams,
    };
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a fully populated spec (LoRA, quantization and merge all set)
    /// that individual tests tweak as needed.
    fn make_test_spec() -> TrainSpec {
        let model = ModelRef {
            path: PathBuf::from("/model/path"),
            layers: vec!["layer1".to_string()],
            ..Default::default()
        };

        let data = DataConfig {
            train: PathBuf::from("/train.parquet"),
            val: Some(PathBuf::from("/val.parquet")),
            batch_size: 32,
            ..Default::default()
        };

        let optimizer = OptimSpec {
            name: "adam".to_string(),
            lr: 0.001,
            params: HashMap::from([("weight_decay".to_string(), serde_json::json!(0.01))]),
        };

        let training = TrainingParams {
            epochs: 10,
            grad_clip: Some(1.0),
            output_dir: PathBuf::from("/output"),
            ..Default::default()
        };

        let lora = LoRASpec {
            rank: 16,
            alpha: 32.0,
            dropout: 0.1,
            target_modules: vec!["q_proj".to_string()],
            lora_plus_ratio: 1.0,
            double_quantize: false,
            quantize_base: false,
        };

        let quantize = QuantSpec { bits: 4, symmetric: true, per_channel: true };

        let merge = MergeSpec {
            method: "slerp".to_string(),
            params: HashMap::from([("weight".to_string(), serde_json::json!(0.5))]),
        };

        TrainSpec {
            model,
            data,
            optimizer,
            training,
            lora: Some(lora),
            quantize: Some(quantize),
            merge: Some(merge),
            publish: None,
        }
    }

    /// Asserts that `text` contains every fragment in `expected`.
    #[track_caller]
    fn assert_contains_all(text: &str, expected: &[&str]) {
        for &fragment in expected {
            assert!(text.contains(fragment), "expected {fragment:?} in output:\n{text}");
        }
    }

    #[test]
    fn test_format_model_info() {
        let info = format_model_info(&make_test_spec());
        assert_contains_all(&info, &["/model/path", "layer1", "tabular"]);
    }

    #[test]
    fn test_format_model_info_transformer() {
        let mut spec = make_test_spec();
        spec.model.mode = ModelMode::Transformer;
        spec.model.config = Some("qwen2_1_5b".into());

        let info = format_model_info(&spec);
        assert_contains_all(&info, &["transformer", "qwen2_1_5b"]);
    }

    #[test]
    fn test_format_data_info() {
        let info = format_data_info(&make_test_spec());
        assert_contains_all(&info, &["/train.parquet", "/val.parquet", "32"]);
    }

    #[test]
    fn test_format_data_info_no_val() {
        let mut spec = make_test_spec();
        spec.data.val = None;

        let info = format_data_info(&spec);
        assert_contains_all(&info, &["/train.parquet"]);
        assert!(!info.contains("Validation"));
    }

    #[test]
    fn test_format_data_info_llm_fields() {
        let mut spec = make_test_spec();
        spec.data.tokenizer = Some(PathBuf::from("./tokenizer.json"));
        spec.data.seq_len = Some(2048);
        spec.data.input_column = Some("text".into());
        spec.data.output_column = Some("label".into());
        spec.data.max_length = Some(512);

        let info = format_data_info(&spec);
        assert_contains_all(&info, &["tokenizer.json", "2048", "text", "label", "512"]);
    }

    #[test]
    fn test_format_optimizer_info() {
        let info = format_optimizer_info(&make_test_spec());
        assert_contains_all(&info, &["adam", "0.001", "Weight decay"]);
    }

    #[test]
    fn test_format_training_info() {
        let info = format_training_info(&make_test_spec());
        assert_contains_all(&info, &["10", "regression", "/output"]);
    }

    #[test]
    fn test_format_training_info_full() {
        let mut scheduler_params = HashMap::new();
        scheduler_params.insert("t_max".into(), serde_json::json!(1000));

        let mut spec = make_test_spec();
        spec.training.mode = TrainingMode::CausalLm;
        spec.training.lr_scheduler = Some("cosine".into());
        spec.training.warmup_steps = 200;
        spec.training.gradient_accumulation = Some(8);
        spec.training.mixed_precision = Some("bf16".into());
        spec.training.seed = Some(42);
        spec.training.scheduler_params = Some(scheduler_params);

        let info = format_training_info(&spec);
        assert_contains_all(
            &info,
            &["causal", "cosine", "warmup=200", "t_max", "8", "bf16", "42"],
        );
    }

    #[test]
    fn test_format_lora_info() {
        let info = format_lora_info(&make_test_spec()).expect("operation should succeed");
        assert_contains_all(&info, &["16", "32", "0.1"]);
    }

    #[test]
    fn test_format_lora_info_none() {
        let mut spec = make_test_spec();
        spec.lora = None;
        assert!(format_lora_info(&spec).is_none());
    }

    #[test]
    fn test_format_quant_info() {
        let info = format_quant_info(&make_test_spec()).expect("operation should succeed");
        assert_contains_all(&info, &["4", "true"]);
    }

    #[test]
    fn test_format_quant_info_none() {
        let mut spec = make_test_spec();
        spec.quantize = None;
        assert!(format_quant_info(&spec).is_none());
    }

    #[test]
    fn test_format_merge_info() {
        let info = format_merge_info(&make_test_spec()).expect("operation should succeed");
        assert_contains_all(&info, &["slerp", "0.5"]);
    }

    #[test]
    fn test_format_merge_info_none() {
        let mut spec = make_test_spec();
        spec.merge = None;
        assert!(format_merge_info(&spec).is_none());
    }
}