Skip to main content

entrenar/cli/commands/
train.rs

1//! Train command implementation
2
3use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{apply_overrides, load_config, train_from_yaml, TrainArgs, TrainSpec};
6
7pub fn run_train(args: TrainArgs, level: LogLevel) -> Result<(), String> {
8    log(level, LogLevel::Normal, &format!("Entrenar: Training from {}", args.config.display()));
9
10    // Load and validate config
11    let mut spec = load_config(&args.config).map_err(|e| format!("Config error: {e}"))?;
12
13    // Apply command-line overrides
14    apply_overrides(&mut spec, &args);
15
16    if args.dry_run {
17        log_dry_run_summary(&spec, level);
18        return Ok(());
19    }
20
21    // Run training
22    train_from_yaml(&args.config).map_err(|e| format!("Training error: {e}"))?;
23
24    log(level, LogLevel::Normal, "Training complete!");
25    Ok(())
26}
27
28/// Log a summary of the training configuration for dry-run mode
29fn log_dry_run_summary(spec: &TrainSpec, level: LogLevel) {
30    log(level, LogLevel::Normal, "Dry run - config validated successfully");
31
32    let mode_str = format!("{:?}", spec.model.mode).to_lowercase();
33    log(level, LogLevel::Normal, &format!("  Model: {} ({})", spec.model.path.display(), mode_str));
34
35    let training_mode = format!("{:?}", spec.training.mode).to_lowercase();
36    log(level, LogLevel::Normal, &format!("  Training mode: {training_mode}"));
37
38    log(
39        level,
40        LogLevel::Normal,
41        &format!("  Optimizer: {} (lr={})", spec.optimizer.name, spec.optimizer.lr),
42    );
43
44    log_scheduler_info(spec, level);
45
46    log(level, LogLevel::Normal, &format!("  Epochs: {}", spec.training.epochs));
47    log(level, LogLevel::Normal, &format!("  Batch size: {}", spec.data.batch_size));
48
49    log_optional_features(spec, level);
50
51    log(level, LogLevel::Normal, &format!("  Output: {}", spec.training.output_dir.display()));
52}
53
54/// Log scheduler information if present
55fn log_scheduler_info(spec: &TrainSpec, level: LogLevel) {
56    if let Some(ref sched) = spec.training.lr_scheduler {
57        let warmup = if spec.training.warmup_steps > 0 {
58            format!(" (warmup={} steps)", spec.training.warmup_steps)
59        } else {
60            String::new()
61        };
62        log(level, LogLevel::Normal, &format!("  Scheduler: {sched}{warmup}"));
63    }
64}
65
66/// Log optional training features (gradient accumulation, mixed precision, LoRA, quantization)
67fn log_optional_features(spec: &TrainSpec, level: LogLevel) {
68    if let Some(ga) = spec.training.gradient_accumulation {
69        let effective = spec.data.batch_size * ga;
70        log(
71            level,
72            LogLevel::Normal,
73            &format!("  Gradient accumulation: {ga} (effective batch={effective})"),
74        );
75    }
76
77    if let Some(ref mp) = spec.training.mixed_precision {
78        log(level, LogLevel::Normal, &format!("  Mixed precision: {mp}"));
79    }
80
81    if let Some(ref lora) = spec.lora {
82        log(
83            level,
84            LogLevel::Normal,
85            &format!(
86                "  LoRA: rank={}, alpha={}, modules={:?}",
87                lora.rank, lora.alpha, lora.target_modules
88            ),
89        );
90    }
91
92    if let Some(ref quant) = spec.quantize {
93        let scheme = if quant.symmetric { "symmetric" } else { "asymmetric" };
94        let gran = if quant.per_channel { "per-channel" } else { "per-tensor" };
95        log(
96            level,
97            LogLevel::Normal,
98            &format!("  Quantization: {}-bit {} {}", quant.bits, scheme, gran),
99        );
100    }
101}
102
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::path::{Path, PathBuf};

    /// Smallest config accepted by `load_config` in these tests.
    ///
    /// `training:` is deliberately the last section, so `minimal_with` can
    /// extend it with indented keys or start new top-level sections.
    const MINIMAL_CONFIG: &str = r"
model:
  path: /tmp/test_model.gguf
data:
  train_path: /tmp/train.json
  batch_size: 8
optimizer:
  name: adam
  lr: 0.001
training:
  epochs: 1
";

    /// `MINIMAL_CONFIG` followed verbatim by `extra`.
    fn minimal_with(extra: &str) -> String {
        format!("{MINIMAL_CONFIG}{extra}")
    }

    /// A YAML config written to the OS temp directory and removed on drop.
    ///
    /// `std::env::temp_dir()` keeps the tests portable, unlike a hard-coded
    /// `/tmp`. The PID prefix keeps concurrent test processes from clobbering
    /// each other's files. `Drop` cleans up even when an assertion fails
    /// mid-test.
    struct TempConfig {
        path: PathBuf,
    }

    impl TempConfig {
        fn new(name: &str, content: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("entrenar_train_{}_{name}.yaml", std::process::id()));
            std::fs::write(&path, content).expect("file write should succeed");
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempConfig {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover temp file must not fail the test.
            std::fs::remove_file(&self.path).ok();
        }
    }

    /// `TrainArgs` pointing at `config_path` with every optional flag unset.
    fn make_args(config_path: impl AsRef<Path>, dry_run: bool) -> TrainArgs {
        TrainArgs {
            config: config_path.as_ref().to_path_buf(),
            output_dir: None,
            resume: None,
            epochs: None,
            batch_size: None,
            lr: None,
            dry_run,
            save_every: None,
            log_every: None,
            seed: None,
        }
    }

    /// Write `content` to a temp config and run a dry-run train at `level`.
    fn dry_run(name: &str, content: &str, level: LogLevel) -> Result<(), String> {
        let cfg = TempConfig::new(name, content);
        run_train(make_args(cfg.path(), true), level)
    }

    /// Load the spec stored in `cfg`, panicking if it is invalid.
    fn load_spec(cfg: &TempConfig) -> TrainSpec {
        load_config(cfg.path()).unwrap()
    }

    #[test]
    fn test_train_dry_run_valid_config() {
        let result = dry_run("valid", MINIMAL_CONFIG, LogLevel::Quiet);
        assert!(result.is_ok(), "Dry run should succeed: {result:?}");
    }

    #[test]
    fn test_train_invalid_config_path() {
        let args = make_args("/nonexistent/config.yaml", false);
        let result = run_train(args, LogLevel::Quiet);
        assert!(result.is_err(), "Should fail with invalid config path");
    }

    #[test]
    fn test_train_dry_run_logs_correctly() {
        let config = r"
model:
  path: /tmp/test_model.gguf
data:
  train_path: /tmp/train.json
  batch_size: 16
optimizer:
  name: sgd
  lr: 0.01
training:
  epochs: 5
";
        // Verbose logging covers the log branches.
        let result = dry_run("logs", config, LogLevel::Verbose);
        assert!(result.is_ok(), "{result:?}");
    }

    // ── dry run with lr scheduler ───────────────────────────────────────

    #[test]
    fn test_train_dry_run_with_lr_scheduler() {
        let config = minimal_with("  lr_scheduler: cosine\n  warmup_steps: 100\n");
        let result = dry_run("sched", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    #[test]
    fn test_train_dry_run_with_scheduler_no_warmup() {
        let config = minimal_with("  lr_scheduler: step\n  warmup_steps: 0\n");
        let result = dry_run("sched_nowarmup", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    // ── dry run with optional features ──────────────────────────────────

    #[test]
    fn test_train_dry_run_with_gradient_accumulation() {
        let config = minimal_with("  gradient_accumulation: 4\n");
        let result = dry_run("grad_acc", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    #[test]
    fn test_train_dry_run_with_mixed_precision() {
        let config = minimal_with("  mixed_precision: bf16\n");
        let result = dry_run("mp", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    #[test]
    fn test_train_dry_run_with_lora() {
        let config = minimal_with(
            "lora:\n  rank: 16\n  alpha: 32.0\n  target_modules:\n    - q_proj\n    - v_proj\n",
        );
        let result = dry_run("lora", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    #[test]
    fn test_train_dry_run_with_quantization() {
        let config = minimal_with("quantize:\n  bits: 4\n  symmetric: true\n  per_channel: true\n");
        let result = dry_run("quant", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    #[test]
    fn test_train_dry_run_with_asymmetric_quantization() {
        let config =
            minimal_with("quantize:\n  bits: 8\n  symmetric: false\n  per_channel: false\n");
        let result = dry_run("quant_asym", &config, LogLevel::Normal);
        assert!(result.is_ok(), "{result:?}");
    }

    #[test]
    fn test_train_dry_run_all_features() {
        let config = r"
model:
  path: /tmp/test_model.gguf
data:
  train_path: /tmp/train.json
  batch_size: 32
optimizer:
  name: adam
  lr: 0.0001
training:
  epochs: 20
  lr_scheduler: cosine
  warmup_steps: 500
  gradient_accumulation: 8
  mixed_precision: fp16
lora:
  rank: 8
  alpha: 16.0
  target_modules:
    - q_proj
    - k_proj
    - v_proj
quantize:
  bits: 4
  symmetric: true
  per_channel: true
";
        let result = dry_run("all", config, LogLevel::Verbose);
        assert!(result.is_ok(), "{result:?}");
    }

    // ── apply_overrides tests ───────────────────────────────────────────

    #[test]
    fn test_apply_overrides_output_dir() {
        let cfg = TempConfig::new("override_out", MINIMAL_CONFIG);
        let mut spec = load_spec(&cfg);
        let args = TrainArgs {
            output_dir: Some(PathBuf::from("/tmp/override_output")),
            ..make_args(cfg.path(), true)
        };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.output_dir, PathBuf::from("/tmp/override_output"));
    }

    #[test]
    fn test_apply_overrides_epochs() {
        let cfg = TempConfig::new("override_epochs", MINIMAL_CONFIG);
        let mut spec = load_spec(&cfg);
        let args = TrainArgs { epochs: Some(99), ..make_args(cfg.path(), true) };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.epochs, 99);
    }

    #[test]
    fn test_apply_overrides_batch_size() {
        let cfg = TempConfig::new("override_batch", MINIMAL_CONFIG);
        let mut spec = load_spec(&cfg);
        let args = TrainArgs { batch_size: Some(128), ..make_args(cfg.path(), true) };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.data.batch_size, 128);
    }

    #[test]
    fn test_apply_overrides_learning_rate() {
        let cfg = TempConfig::new("override_lr", MINIMAL_CONFIG);
        let mut spec = load_spec(&cfg);
        let args = TrainArgs { lr: Some(0.042), ..make_args(cfg.path(), true) };
        apply_overrides(&mut spec, &args);
        assert!((spec.optimizer.lr - 0.042).abs() < 1e-6);
    }

    #[test]
    fn test_apply_overrides_save_every() {
        let cfg = TempConfig::new("override_save", MINIMAL_CONFIG);
        let mut spec = load_spec(&cfg);
        let args = TrainArgs { save_every: Some(5), ..make_args(cfg.path(), true) };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.save_interval, 5);
    }

    #[test]
    fn test_apply_overrides_all_at_once() {
        let cfg = TempConfig::new("override_all", MINIMAL_CONFIG);
        let mut spec = load_spec(&cfg);
        let args = TrainArgs {
            output_dir: Some(PathBuf::from("/tmp/all_override")),
            epochs: Some(50),
            batch_size: Some(64),
            lr: Some(0.01),
            save_every: Some(10),
            ..make_args(cfg.path(), true)
        };
        apply_overrides(&mut spec, &args);
        assert_eq!(spec.training.output_dir, PathBuf::from("/tmp/all_override"));
        assert_eq!(spec.training.epochs, 50);
        assert_eq!(spec.data.batch_size, 64);
        assert!((spec.optimizer.lr - 0.01).abs() < 1e-6);
        assert_eq!(spec.training.save_interval, 10);
    }

    #[test]
    fn test_apply_overrides_none_leaves_original() {
        let cfg =
            TempConfig::new("override_none", &MINIMAL_CONFIG.replace("epochs: 1", "epochs: 3"));
        let mut spec = load_spec(&cfg);
        let original_epochs = spec.training.epochs;
        let original_batch = spec.data.batch_size;
        let original_lr = spec.optimizer.lr;
        apply_overrides(&mut spec, &make_args(cfg.path(), true));
        assert_eq!(spec.training.epochs, original_epochs);
        assert_eq!(spec.data.batch_size, original_batch);
        assert!((spec.optimizer.lr - original_lr).abs() < 1e-6);
    }

    // ── invalid YAML content ────────────────────────────────────────────

    #[test]
    fn test_train_invalid_yaml() {
        let result = dry_run("invalid", "{{invalid yaml content}}", LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Config error"));
    }

    // ── log_dry_run_summary and helpers direct tests ────────────────────

    #[test]
    fn test_log_dry_run_summary_quiet() {
        let cfg = TempConfig::new("log_quiet", MINIMAL_CONFIG);
        let spec = load_spec(&cfg);
        // Should not panic even in quiet mode
        log_dry_run_summary(&spec, LogLevel::Quiet);
    }

    #[test]
    fn test_log_scheduler_info_none() {
        let cfg = TempConfig::new("log_sched_none", MINIMAL_CONFIG);
        let spec = load_spec(&cfg);
        // lr_scheduler should be None — no-op branch
        log_scheduler_info(&spec, LogLevel::Normal);
    }

    #[test]
    fn test_log_optional_features_none() {
        let cfg = TempConfig::new("log_opt_none", MINIMAL_CONFIG);
        let spec = load_spec(&cfg);
        // No optional features — all branches should be no-ops
        log_optional_features(&spec, LogLevel::Normal);
    }

    // ── make_args helper verify ─────────────────────────────────────────

    #[test]
    fn test_make_args_dry_run_true() {
        let args = make_args("/tmp/cfg.yaml", true);
        assert!(args.dry_run);
        assert_eq!(args.config, PathBuf::from("/tmp/cfg.yaml"));
        assert!(args.output_dir.is_none());
        assert!(args.resume.is_none());
        assert!(args.epochs.is_none());
        assert!(args.batch_size.is_none());
        assert!(args.lr.is_none());
        assert!(args.save_every.is_none());
        assert!(args.log_every.is_none());
        assert!(args.seed.is_none());
    }

    #[test]
    fn test_make_args_dry_run_false() {
        let args = make_args("/path/to/config.yaml", false);
        assert!(!args.dry_run);
        assert_eq!(args.config, PathBuf::from("/path/to/config.yaml"));
    }

    // ── causal_lm mode dry run ──────────────────────────────────────────

    #[test]
    fn test_train_dry_run_causal_lm_mode() {
        let config = r"
model:
  path: /tmp/test_model.gguf
  mode: causal_lm
data:
  train_path: /tmp/train.json
  batch_size: 4
optimizer:
  name: adam
  lr: 0.0001
training:
  mode: causal_lm
  epochs: 2
";
        // dry_run may succeed or fail depending on config parsing;
        // the important thing is it doesn't panic
        let _result = dry_run("causal", config, LogLevel::Normal);
    }
}