1use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{apply_overrides, load_config, train_from_yaml, TrainArgs, TrainSpec};
6
7pub fn run_train(args: TrainArgs, level: LogLevel) -> Result<(), String> {
8 log(level, LogLevel::Normal, &format!("Entrenar: Training from {}", args.config.display()));
9
10 let mut spec = load_config(&args.config).map_err(|e| format!("Config error: {e}"))?;
12
13 apply_overrides(&mut spec, &args);
15
16 if args.dry_run {
17 log_dry_run_summary(&spec, level);
18 return Ok(());
19 }
20
21 train_from_yaml(&args.config).map_err(|e| format!("Training error: {e}"))?;
23
24 log(level, LogLevel::Normal, "Training complete!");
25 Ok(())
26}
27
28fn log_dry_run_summary(spec: &TrainSpec, level: LogLevel) {
30 log(level, LogLevel::Normal, "Dry run - config validated successfully");
31
32 let mode_str = format!("{:?}", spec.model.mode).to_lowercase();
33 log(level, LogLevel::Normal, &format!(" Model: {} ({})", spec.model.path.display(), mode_str));
34
35 let training_mode = format!("{:?}", spec.training.mode).to_lowercase();
36 log(level, LogLevel::Normal, &format!(" Training mode: {training_mode}"));
37
38 log(
39 level,
40 LogLevel::Normal,
41 &format!(" Optimizer: {} (lr={})", spec.optimizer.name, spec.optimizer.lr),
42 );
43
44 log_scheduler_info(spec, level);
45
46 log(level, LogLevel::Normal, &format!(" Epochs: {}", spec.training.epochs));
47 log(level, LogLevel::Normal, &format!(" Batch size: {}", spec.data.batch_size));
48
49 log_optional_features(spec, level);
50
51 log(level, LogLevel::Normal, &format!(" Output: {}", spec.training.output_dir.display()));
52}
53
54fn log_scheduler_info(spec: &TrainSpec, level: LogLevel) {
56 if let Some(ref sched) = spec.training.lr_scheduler {
57 let warmup = if spec.training.warmup_steps > 0 {
58 format!(" (warmup={} steps)", spec.training.warmup_steps)
59 } else {
60 String::new()
61 };
62 log(level, LogLevel::Normal, &format!(" Scheduler: {sched}{warmup}"));
63 }
64}
65
66fn log_optional_features(spec: &TrainSpec, level: LogLevel) {
68 if let Some(ga) = spec.training.gradient_accumulation {
69 let effective = spec.data.batch_size * ga;
70 log(
71 level,
72 LogLevel::Normal,
73 &format!(" Gradient accumulation: {ga} (effective batch={effective})"),
74 );
75 }
76
77 if let Some(ref mp) = spec.training.mixed_precision {
78 log(level, LogLevel::Normal, &format!(" Mixed precision: {mp}"));
79 }
80
81 if let Some(ref lora) = spec.lora {
82 log(
83 level,
84 LogLevel::Normal,
85 &format!(
86 " LoRA: rank={}, alpha={}, modules={:?}",
87 lora.rank, lora.alpha, lora.target_modules
88 ),
89 );
90 }
91
92 if let Some(ref quant) = spec.quantize {
93 let scheme = if quant.symmetric { "symmetric" } else { "asymmetric" };
94 let gran = if quant.per_channel { "per-channel" } else { "per-tensor" };
95 log(
96 level,
97 LogLevel::Normal,
98 &format!(" Quantization: {}-bit {} {}", quant.bits, scheme, gran),
99 );
100 }
101}
102
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::path::{Path, PathBuf};

    /// Minimal valid config shared by most tests.
    ///
    /// It ends inside the `training:` section, so appending two-space-indented lines adds
    /// training keys, while unindented lines start new top-level sections.
    const BASE_CONFIG: &str = r"
model:
  path: /tmp/test_model.gguf
data:
  train_path: /tmp/train.json
  batch_size: 8
optimizer:
  name: adam
  lr: 0.001
training:
  epochs: 1
";

    /// Config file in the system temp directory that is removed on drop.
    ///
    /// Cleanup therefore also happens when an assertion fails part-way through a test.
    struct TempConfig {
        path: PathBuf,
    }

    impl TempConfig {
        /// Writes `contents` to `entrenar_<name>_<pid>.yaml` in the temp directory.
        ///
        /// The process id keeps concurrent test runs from clobbering each other's files.
        fn new(name: &str, contents: &str) -> Self {
            let file_name = format!("entrenar_{name}_{}.yaml", std::process::id());
            let path = std::env::temp_dir().join(file_name);
            std::fs::write(&path, contents).expect("file write should succeed");
            Self { path }
        }
    }

    impl Drop for TempConfig {
        fn drop(&mut self) {
            // Best-effort: a leftover temp file must not turn into a test failure.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Builds `TrainArgs` pointing at `config_path`, with every override unset.
    fn make_args(config_path: impl AsRef<Path>, dry_run: bool) -> TrainArgs {
        TrainArgs {
            config: config_path.as_ref().to_path_buf(),
            output_dir: None,
            resume: None,
            epochs: None,
            batch_size: None,
            lr: None,
            dry_run,
            save_every: None,
            log_every: None,
            seed: None,
        }
    }

    /// Returns `BASE_CONFIG` with `extra` appended verbatim.
    fn base_with(extra: &str) -> String {
        format!("{BASE_CONFIG}{extra}")
    }

    /// Writes `yaml` to a temp config and runs `run_train` on it in dry-run mode.
    fn run_dry(name: &str, yaml: &str, level: LogLevel) -> Result<(), String> {
        let cfg = TempConfig::new(name, yaml);
        run_train(make_args(&cfg.path, true), level)
    }

    /// Asserts that a dry run over `yaml` succeeds.
    fn assert_dry_run_ok(name: &str, yaml: &str, level: LogLevel) {
        let result = run_dry(name, yaml, level);
        assert!(result.is_ok(), "Dry run should succeed: {result:?}");
    }

    #[test]
    fn test_train_dry_run_valid_config() {
        assert_dry_run_ok("dry_run_valid", BASE_CONFIG, LogLevel::Quiet);
    }

    #[test]
    fn test_train_invalid_config_path() {
        let args = make_args("/nonexistent/config.yaml", false);
        let result = run_train(args, LogLevel::Quiet);
        assert!(result.is_err(), "Should fail with invalid config path");
    }

    #[test]
    fn test_train_dry_run_logs_correctly() {
        let yaml = r"
model:
  path: /tmp/test_model.gguf
data:
  train_path: /tmp/train.json
  batch_size: 16
optimizer:
  name: sgd
  lr: 0.01
training:
  epochs: 5
";
        assert_dry_run_ok("logs", yaml, LogLevel::Verbose);
    }

    #[test]
    fn test_train_dry_run_with_lr_scheduler() {
        let yaml = base_with("  lr_scheduler: cosine\n  warmup_steps: 100\n");
        assert_dry_run_ok("sched", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_with_scheduler_no_warmup() {
        let yaml = base_with("  lr_scheduler: step\n  warmup_steps: 0\n");
        assert_dry_run_ok("sched_nowarmup", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_with_gradient_accumulation() {
        let yaml = base_with("  gradient_accumulation: 4\n");
        assert_dry_run_ok("grad_acc", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_with_mixed_precision() {
        let yaml = base_with("  mixed_precision: bf16\n");
        assert_dry_run_ok("mp", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_with_lora() {
        let yaml = base_with(
            "lora:\n  rank: 16\n  alpha: 32.0\n  target_modules:\n    - q_proj\n    - v_proj\n",
        );
        assert_dry_run_ok("lora", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_with_quantization() {
        let yaml = base_with("quantize:\n  bits: 4\n  symmetric: true\n  per_channel: true\n");
        assert_dry_run_ok("quant", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_with_asymmetric_quantization() {
        let yaml = base_with("quantize:\n  bits: 8\n  symmetric: false\n  per_channel: false\n");
        assert_dry_run_ok("quant_asym", &yaml, LogLevel::Normal);
    }

    #[test]
    fn test_train_dry_run_all_features() {
        let yaml = r"
model:
  path: /tmp/test_model.gguf
data:
  train_path: /tmp/train.json
  batch_size: 32
optimizer:
  name: adam
  lr: 0.0001
training:
  epochs: 20
  lr_scheduler: cosine
  warmup_steps: 500
  gradient_accumulation: 8
  mixed_precision: fp16
lora:
  rank: 8
  alpha: 16.0
  target_modules:
    - q_proj
    - k_proj
    - v_proj
quantize:
  bits: 4
  symmetric: true
  per_channel: true
";
        assert_dry_run_ok("all", yaml, LogLevel::Verbose);
    }

    #[test]
    fn test_apply_overrides_output_dir() {
        let cfg = TempConfig::new("override_out", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let args = TrainArgs {
            output_dir: Some(PathBuf::from("/tmp/override_output")),
            ..make_args(&cfg.path, true)
        };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.output_dir, PathBuf::from("/tmp/override_output"));
    }

    #[test]
    fn test_apply_overrides_epochs() {
        let cfg = TempConfig::new("override_epochs", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let args = TrainArgs { epochs: Some(99), ..make_args(&cfg.path, true) };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.epochs, 99);
    }

    #[test]
    fn test_apply_overrides_batch_size() {
        let cfg = TempConfig::new("override_batch", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let args = TrainArgs { batch_size: Some(128), ..make_args(&cfg.path, true) };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.data.batch_size, 128);
    }

    #[test]
    fn test_apply_overrides_learning_rate() {
        let cfg = TempConfig::new("override_lr", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let args = TrainArgs { lr: Some(0.042), ..make_args(&cfg.path, true) };
        apply_overrides(&mut spec, &args);
        assert!((spec.optimizer.lr - 0.042).abs() < 1e-6);
    }

    #[test]
    fn test_apply_overrides_save_every() {
        let cfg = TempConfig::new("override_save", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let args = TrainArgs { save_every: Some(5), ..make_args(&cfg.path, true) };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.save_interval, 5);
    }

    #[test]
    fn test_apply_overrides_all_at_once() {
        let cfg = TempConfig::new("override_all", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let args = TrainArgs {
            output_dir: Some(PathBuf::from("/tmp/all_override")),
            epochs: Some(50),
            batch_size: Some(64),
            lr: Some(0.01),
            save_every: Some(10),
            ..make_args(&cfg.path, true)
        };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.output_dir, PathBuf::from("/tmp/all_override"));
        assert_eq!(spec.training.epochs, 50);
        assert_eq!(spec.data.batch_size, 64);
        assert!((spec.optimizer.lr - 0.01).abs() < 1e-6);
        assert_eq!(spec.training.save_interval, 10);
    }

    #[test]
    fn test_apply_overrides_none_leaves_original() {
        let cfg = TempConfig::new("override_none", BASE_CONFIG);
        let mut spec = load_config(&cfg.path).unwrap();
        let original_epochs = spec.training.epochs;
        let original_batch = spec.data.batch_size;
        let original_lr = spec.optimizer.lr;
        let original_output = spec.training.output_dir.clone();
        let original_save = spec.training.save_interval;

        apply_overrides(&mut spec, &make_args(&cfg.path, true));

        assert_eq!(spec.training.epochs, original_epochs);
        assert_eq!(spec.data.batch_size, original_batch);
        assert!((spec.optimizer.lr - original_lr).abs() < 1e-6);
        assert_eq!(spec.training.output_dir, original_output);
        assert_eq!(spec.training.save_interval, original_save);
    }

    #[test]
    fn test_train_invalid_yaml() {
        let result = run_dry("invalid", "{{invalid yaml content}}", LogLevel::Quiet);
        let err = result.unwrap_err();
        assert!(err.contains("Config error"), "unexpected error: {err}");
    }

    #[test]
    fn test_log_dry_run_summary_quiet() {
        let cfg = TempConfig::new("log_quiet", BASE_CONFIG);
        let spec = load_config(&cfg.path).unwrap();
        // Smoke test: must not panic.
        log_dry_run_summary(&spec, LogLevel::Quiet);
    }

    #[test]
    fn test_log_scheduler_info_none() {
        let cfg = TempConfig::new("log_sched_none", BASE_CONFIG);
        let spec = load_config(&cfg.path).unwrap();
        // Smoke test: no scheduler configured, must not panic.
        log_scheduler_info(&spec, LogLevel::Normal);
    }

    #[test]
    fn test_log_optional_features_none() {
        let cfg = TempConfig::new("log_opt_none", BASE_CONFIG);
        let spec = load_config(&cfg.path).unwrap();
        // Smoke test: no optional features configured, must not panic.
        log_optional_features(&spec, LogLevel::Normal);
    }

    #[test]
    fn test_make_args_dry_run_true() {
        let args = make_args("/tmp/cfg.yaml", true);
        assert!(args.dry_run);
        assert_eq!(args.config, PathBuf::from("/tmp/cfg.yaml"));
        assert!(args.output_dir.is_none());
        assert!(args.resume.is_none());
        assert!(args.epochs.is_none());
        assert!(args.batch_size.is_none());
        assert!(args.lr.is_none());
        assert!(args.save_every.is_none());
        assert!(args.log_every.is_none());
        assert!(args.seed.is_none());
    }

    #[test]
    fn test_make_args_dry_run_false() {
        let args = make_args("/path/to/config.yaml", false);
        assert!(!args.dry_run);
        assert_eq!(args.config, PathBuf::from("/path/to/config.yaml"));
    }

    #[test]
    fn test_train_dry_run_causal_lm_mode() {
        let yaml = r"
model:
  path: /tmp/test_model.gguf
  mode: causal_lm
data:
  train_path: /tmp/train.json
  batch_size: 4
optimizer:
  name: adam
  lr: 0.0001
training:
  mode: causal_lm
  epochs: 2
";
        // Only checks that this path does not panic; the result is not asserted.
        // NOTE(review): assert `is_ok()` once `causal_lm` is confirmed to parse in both places.
        let _result = run_dry("causal", yaml, LogLevel::Normal);
    }
}