1use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{FinetuneArgs, FinetuneCommand};
6
7pub fn run_finetune(args: FinetuneArgs, level: LogLevel) -> Result<(), String> {
8 match args.command {
9 FinetuneCommand::Plan {
10 data,
11 model_path,
12 model_size,
13 num_classes,
14 output_dir,
15 strategy,
16 budget,
17 scout,
18 max_epochs,
19 lr,
20 lora_rank,
21 batch_size,
22 lora_alpha,
23 warmup,
24 gradient_clip,
25 lr_min_ratio,
26 class_weights,
27 target_modules,
28 } => run_plan(
29 data,
30 model_path,
31 model_size,
32 num_classes,
33 output_dir,
34 strategy,
35 budget,
36 scout,
37 max_epochs,
38 lr,
39 lora_rank,
40 batch_size,
41 lora_alpha,
42 warmup,
43 gradient_clip,
44 lr_min_ratio,
45 class_weights,
46 target_modules,
47 level,
48 ),
49 FinetuneCommand::Apply { plan, model_path, data, output_dir } => {
50 run_apply(plan, model_path, data, output_dir, level)
51 }
52 }
53}
54
55#[allow(clippy::too_many_arguments)]
56fn run_plan(
57 data: std::path::PathBuf,
58 model_path: Option<std::path::PathBuf>,
59 model_size: String,
60 num_classes: usize,
61 output_dir: std::path::PathBuf,
62 strategy: String,
63 budget: usize,
64 scout: bool,
65 max_epochs: usize,
66 manual_lr: Option<f32>,
67 manual_lora_rank: Option<usize>,
68 manual_batch_size: Option<usize>,
69 manual_lora_alpha: Option<f32>,
70 manual_warmup: Option<f32>,
71 manual_gradient_clip: Option<f32>,
72 manual_lr_min_ratio: Option<f32>,
73 manual_class_weights: Option<String>,
74 manual_target_modules: Option<String>,
75 level: LogLevel,
76) -> Result<(), String> {
77 use crate::finetune::training_plan::{plan, PlanConfig};
78
79 log(level, LogLevel::Normal, "Generating training plan...");
80
81 let config = PlanConfig {
82 task: "classify".to_string(),
83 data_path: data,
84 val_path: None,
85 test_path: None,
86 model_size,
87 model_path,
88 num_classes,
89 output_dir: output_dir.clone(),
90 strategy,
91 budget,
92 scout,
93 max_epochs,
94 manual_lr,
95 manual_lora_rank,
96 manual_batch_size,
97 manual_lora_alpha,
98 manual_warmup,
99 manual_gradient_clip,
100 manual_lr_min_ratio,
101 manual_class_weights,
102 manual_target_modules,
103 };
104
105 let training_plan = plan(&config).map_err(|e| format!("Plan generation failed: {e}"))?;
106
107 print_plan_summary(&training_plan, level);
109
110 std::fs::create_dir_all(&output_dir)
112 .map_err(|e| format!("Failed to create output dir: {e}"))?;
113 let plan_path = output_dir.join("plan.yaml");
114 std::fs::write(&plan_path, training_plan.to_yaml())
115 .map_err(|e| format!("Failed to write plan: {e}"))?;
116
117 log(level, LogLevel::Normal, &format!("Plan saved to: {}", plan_path.display()));
118 log(
119 level,
120 LogLevel::Normal,
121 &format!(
122 "\nTo execute: apr finetune apply --plan {} --model-path <MODEL_DIR> --data <DATA.jsonl> -o {}",
123 plan_path.display(),
124 output_dir.display()
125 ),
126 );
127
128 Ok(())
129}
130
131fn run_apply(
132 plan_path: std::path::PathBuf,
133 model_path: std::path::PathBuf,
134 data_path: std::path::PathBuf,
135 output_dir: std::path::PathBuf,
136 level: LogLevel,
137) -> Result<(), String> {
138 use crate::finetune::training_plan::{execute_plan, ApplyConfig, TrainingPlan};
139
140 log(level, LogLevel::Normal, &format!("Loading plan from: {}", plan_path.display()));
141
142 let plan_str =
143 std::fs::read_to_string(&plan_path).map_err(|e| format!("Failed to read plan: {e}"))?;
144 let plan =
145 TrainingPlan::from_str(&plan_str).map_err(|e| format!("Failed to parse plan: {e}"))?;
146
147 print_plan_summary(&plan, level);
148
149 log(level, LogLevel::Normal, &format!("Model: {}", model_path.display()));
150 log(level, LogLevel::Normal, &format!("Data: {}", data_path.display()));
151 log(level, LogLevel::Normal, &format!("Output: {}", output_dir.display()));
152 log(level, LogLevel::Normal, "");
153 log(level, LogLevel::Normal, "Starting training...");
154
155 let apply = ApplyConfig {
156 model_path,
157 data_path,
158 output_dir,
159 on_trial_complete: Some(|trial_id, total, summary| {
160 eprintln!(
161 " [{}/{}] val_loss={:.4} val_acc={:.1}% [{}]",
162 trial_id + 1,
163 total,
164 summary.val_loss,
165 summary.val_accuracy * 100.0,
166 summary.status,
167 );
168 }),
169 };
170
171 let result = execute_plan(&plan, &apply).map_err(|e| format!("Training failed: {e}"))?;
172
173 log(level, LogLevel::Normal, "");
175 log(level, LogLevel::Normal, "Training complete!");
176 log(
177 level,
178 LogLevel::Normal,
179 &format!(
180 " Strategy: {} | Trials: {} | Time: {:.1}s",
181 result.strategy,
182 result.trials.len(),
183 result.total_time_ms as f64 / 1000.0,
184 ),
185 );
186
187 if let Some(best) = result.trials.get(result.best_trial_id) {
188 log(level, LogLevel::Normal, &format!(" Best trial #{}", result.best_trial_id + 1));
189 log(
190 level,
191 LogLevel::Normal,
192 &format!(
193 " val_loss={:.4} val_acc={:.1}% epochs={}",
194 best.val_loss,
195 best.val_accuracy * 100.0,
196 best.epochs_run,
197 ),
198 );
199 }
200
201 Ok(())
202}
203
204fn print_plan_summary(plan: &crate::finetune::training_plan::TrainingPlan, level: LogLevel) {
205 let (pass, warn, fail) = plan.check_counts();
206
207 log(level, LogLevel::Normal, &format!("Plan: {} v{}", plan.task, plan.version));
208 log(
209 level,
210 LogLevel::Normal,
211 &format!(
212 " Data: {} samples, {} classes",
213 plan.data.train_samples,
214 plan.data.class_counts.len(),
215 ),
216 );
217 if plan.data.imbalance_ratio > 2.0 {
218 log(
219 level,
220 LogLevel::Normal,
221 &format!(
222 " Imbalance: {:.1}x (auto class weights: {})",
223 plan.data.imbalance_ratio, plan.data.auto_class_weights,
224 ),
225 );
226 }
227 log(
228 level,
229 LogLevel::Normal,
230 &format!(
231 " Model: {} ({}, {} layers, hidden={})",
232 plan.model.architecture, plan.model.size, plan.model.num_layers, plan.model.hidden_size,
233 ),
234 );
235 log(
236 level,
237 LogLevel::Normal,
238 &format!(
239 " HPO: {} (budget={}, scout={}, max_epochs={})",
240 plan.hyperparameters.strategy,
241 plan.hyperparameters.budget,
242 plan.hyperparameters.scout,
243 plan.hyperparameters.max_epochs,
244 ),
245 );
246 log(
247 level,
248 LogLevel::Normal,
249 &format!(
250 " Resources: {:.1} GB VRAM, {:.0} min/epoch, {:.0} min total",
251 plan.resources.estimated_vram_gb,
252 plan.resources.estimated_minutes_per_epoch,
253 plan.resources.estimated_total_minutes,
254 ),
255 );
256 log(level, LogLevel::Normal, &format!(" Pre-flight: {pass} pass, {warn} warn, {fail} fail"));
257 log(level, LogLevel::Normal, &format!(" Verdict: {:?}", plan.verdict));
258}
259
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    // Smoke tests for the summary printer, `check_counts` tallies, plan
    // (de)serialisation, and the error paths of the `plan`/`apply` handlers.
    use super::*;
    use crate::finetune::training_plan::{
        CheckStatus, DataAudit, HyperparameterPlan, ModelInfo, PlanVerdict, PreFlightCheck,
        ResourceEstimate, TrainingPlan,
    };

    /// Path inside the OS temp directory, prefixed with this process id.
    ///
    /// Replaces hard-coded `/tmp/...` paths so the tests run on any platform.
    /// The pid prefix keeps concurrent test processes from sharing files.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{name}", std::process::id()))
    }

    /// A temp file that is deleted when dropped, so cleanup still happens if
    /// an assertion panics part-way through a test.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Writes `contents` to `temp_path(name)` and owns the resulting file.
        fn with_contents(name: &str, contents: &str) -> Self {
            let path = temp_path(name);
            std::fs::write(&path, contents).unwrap();
            TempFile(path)
        }

        /// Location of the owned file.
        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: a file that is already gone is not a cleanup error.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Baseline plan used by most tests. It has:
    /// - 1000 samples over 2 classes (4x imbalance);
    /// - one passing and one warning pre-flight check;
    /// - verdict `WarningsPresent`.
    fn make_plan() -> TrainingPlan {
        TrainingPlan {
            version: "1.0".to_string(),
            task: "classify".to_string(),
            data: DataAudit {
                train_path: "/data/train.jsonl".to_string(),
                train_samples: 1000,
                avg_input_len: 50,
                class_counts: vec![800, 200],
                imbalance_ratio: 4.0,
                auto_class_weights: true,
                val_samples: Some(100),
                test_samples: None,
                duplicates: 0,
                preamble_count: 0,
            },
            model: ModelInfo {
                size: "0.5B".to_string(),
                hidden_size: 896,
                num_layers: 24,
                architecture: "Qwen2".to_string(),
                weights_available: true,
                lora_trainable_params: 1_000_000,
                classifier_params: 1792,
            },
            hyperparameters: HyperparameterPlan {
                strategy: "tpe".to_string(),
                budget: 10,
                scout: false,
                max_epochs: 5,
                search_space_params: 6,
                sample_configs: vec![],
                manual: None,
                recommendation: None,
            },
            resources: ResourceEstimate {
                estimated_vram_gb: 6.5,
                estimated_minutes_per_epoch: 2.0,
                estimated_total_minutes: 100.0,
                estimated_checkpoint_mb: 50.0,
                steps_per_epoch: 32,
                gpu_device: Some("RTX 4090".to_string()),
            },
            pre_flight: vec![
                PreFlightCheck {
                    name: "data_exists".to_string(),
                    status: CheckStatus::Pass,
                    detail: "ok".to_string(),
                },
                PreFlightCheck {
                    name: "vram_check".to_string(),
                    status: CheckStatus::Warn,
                    detail: "tight".to_string(),
                },
            ],
            output_dir: "/tmp/output".to_string(),
            auto_diagnose: true,
            verdict: PlanVerdict::WarningsPresent,
            issues: vec![],
        }
    }

    // --- print_plan_summary smoke tests: each must simply not panic. ---

    #[test]
    fn test_print_plan_summary_normal() {
        let plan = make_plan();
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_verbose() {
        let plan = make_plan();
        print_plan_summary(&plan, LogLevel::Verbose);
    }

    #[test]
    fn test_print_plan_summary_quiet() {
        let plan = make_plan();
        print_plan_summary(&plan, LogLevel::Quiet);
    }

    #[test]
    fn test_print_plan_summary_no_imbalance() {
        let mut plan = make_plan();
        plan.data.imbalance_ratio = 1.0;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_ready() {
        let mut plan = make_plan();
        plan.verdict = PlanVerdict::Ready;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_blocked() {
        let mut plan = make_plan();
        plan.verdict = PlanVerdict::Blocked;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    // --- check_counts tallies ---

    #[test]
    fn test_check_counts_all_pass() {
        let mut plan = make_plan();
        plan.pre_flight = vec![
            PreFlightCheck { name: "a".into(), status: CheckStatus::Pass, detail: "ok".into() },
            PreFlightCheck { name: "b".into(), status: CheckStatus::Pass, detail: "ok".into() },
        ];
        let (p, w, f) = plan.check_counts();
        assert_eq!(p, 2);
        assert_eq!(w, 0);
        assert_eq!(f, 0);
    }

    #[test]
    fn test_check_counts_mixed() {
        let plan = make_plan();
        let (p, w, f) = plan.check_counts();
        assert_eq!(p, 1);
        assert_eq!(w, 1);
        assert_eq!(f, 0);
    }

    #[test]
    fn test_check_counts_with_fail() {
        let mut plan = make_plan();
        plan.pre_flight.push(PreFlightCheck {
            name: "c".into(),
            status: CheckStatus::Fail,
            detail: "bad".into(),
        });
        let (p, w, f) = plan.check_counts();
        assert_eq!(p, 1);
        assert_eq!(w, 1);
        assert_eq!(f, 1);
    }

    #[test]
    fn test_check_counts_empty() {
        let mut plan = make_plan();
        plan.pre_flight = vec![];
        let (p, w, f) = plan.check_counts();
        assert_eq!(p, 0);
        assert_eq!(w, 0);
        assert_eq!(f, 0);
    }

    // --- serialisation round-trips ---

    #[test]
    fn test_plan_yaml_roundtrip() {
        let plan = make_plan();
        let yaml = plan.to_yaml();
        assert!(!yaml.is_empty());
        let parsed = crate::finetune::training_plan::TrainingPlan::from_str(&yaml).unwrap();
        assert_eq!(parsed.task, "classify");
        assert_eq!(parsed.version, "1.0");
        assert_eq!(parsed.data.train_samples, 1000);
    }

    #[test]
    fn test_plan_json_roundtrip() {
        let plan = make_plan();
        let json = plan.to_json();
        assert!(!json.is_empty());
        let parsed = crate::finetune::training_plan::TrainingPlan::from_str(&json).unwrap();
        assert_eq!(parsed.task, "classify");
    }

    // --- handler error paths ---

    #[test]
    fn test_run_finetune_plan_missing_data() {
        let args = FinetuneArgs {
            command: FinetuneCommand::Plan {
                data: std::path::PathBuf::from("/nonexistent/data.jsonl"),
                model_path: None,
                model_size: "0.5B".to_string(),
                num_classes: 2,
                output_dir: temp_path("ft_test_out"),
                strategy: "manual".to_string(),
                budget: 1,
                scout: false,
                max_epochs: 1,
                lr: Some(1e-4),
                lora_rank: Some(8),
                batch_size: Some(32),
                lora_alpha: None,
                warmup: None,
                gradient_clip: None,
                lr_min_ratio: None,
                class_weights: None,
                target_modules: None,
            },
        };
        let err = run_finetune(args, LogLevel::Quiet).unwrap_err();
        // The output dir is creatable, so the failure must come from plan
        // generation itself, before anything is written.
        assert!(err.contains("Plan generation failed"), "unexpected error: {err}");
    }

    #[test]
    fn test_run_finetune_apply_missing_plan() {
        let args = FinetuneArgs {
            command: FinetuneCommand::Apply {
                plan: std::path::PathBuf::from("/nonexistent/plan.yaml"),
                model_path: std::path::PathBuf::from("/nonexistent/model"),
                data: std::path::PathBuf::from("/nonexistent/data.jsonl"),
                output_dir: temp_path("ft_test_out"),
            },
        };
        let result = run_finetune(args, LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to read plan"));
    }

    #[test]
    fn test_print_plan_summary_large_data() {
        let mut plan = make_plan();
        plan.data.train_samples = 1_000_000;
        plan.data.class_counts = vec![500_000, 250_000, 200_000, 50_000];
        plan.data.imbalance_ratio = 10.0;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_many_checks() {
        let mut plan = make_plan();
        plan.pre_flight = vec![
            PreFlightCheck { name: "a".into(), status: CheckStatus::Pass, detail: "ok".into() },
            PreFlightCheck { name: "b".into(), status: CheckStatus::Pass, detail: "ok".into() },
            PreFlightCheck { name: "c".into(), status: CheckStatus::Warn, detail: "meh".into() },
            PreFlightCheck { name: "d".into(), status: CheckStatus::Fail, detail: "bad".into() },
            PreFlightCheck { name: "e".into(), status: CheckStatus::Fail, detail: "worse".into() },
        ];
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_run_finetune_plan_with_all_manual_params() {
        let args = FinetuneArgs {
            command: FinetuneCommand::Plan {
                data: std::path::PathBuf::from("/nonexistent/data.jsonl"),
                model_path: Some(std::path::PathBuf::from("/nonexistent/model")),
                model_size: "9B".to_string(),
                num_classes: 5,
                output_dir: temp_path("ft_test_full"),
                strategy: "manual".to_string(),
                budget: 1,
                scout: false,
                max_epochs: 3,
                lr: Some(2e-5),
                lora_rank: Some(8),
                batch_size: Some(16),
                lora_alpha: Some(16.0),
                warmup: Some(0.05),
                gradient_clip: Some(0.5),
                lr_min_ratio: Some(0.001),
                class_weights: Some("sqrt_inverse".to_string()),
                target_modules: Some("qkv".to_string()),
            },
        };
        let err = run_finetune(args, LogLevel::Quiet).unwrap_err();
        assert!(err.contains("Plan generation failed"), "unexpected error: {err}");
    }

    #[test]
    fn test_run_finetune_plan_tpe_strategy() {
        let args = FinetuneArgs {
            command: FinetuneCommand::Plan {
                data: std::path::PathBuf::from("/nonexistent/data.jsonl"),
                model_path: None,
                model_size: "0.5B".to_string(),
                num_classes: 2,
                output_dir: temp_path("ft_test_tpe"),
                strategy: "tpe".to_string(),
                budget: 20,
                scout: true,
                max_epochs: 5,
                lr: None,
                lora_rank: None,
                batch_size: None,
                lora_alpha: None,
                warmup: None,
                gradient_clip: None,
                lr_min_ratio: None,
                class_weights: None,
                target_modules: None,
            },
        };
        let err = run_finetune(args, LogLevel::Quiet).unwrap_err();
        assert!(err.contains("Plan generation failed"), "unexpected error: {err}");
    }

    #[test]
    fn test_run_finetune_apply_invalid_plan_content() {
        // Removed on drop, even if an assertion below fails.
        let plan_file = TempFile::with_contents("ent_ft_bad_plan.yaml", "not valid plan content {{{{");
        let args = FinetuneArgs {
            command: FinetuneCommand::Apply {
                plan: plan_file.path().to_path_buf(),
                model_path: std::path::PathBuf::from("/nonexistent/model"),
                data: std::path::PathBuf::from("/nonexistent/data.jsonl"),
                output_dir: temp_path("ft_test_bad"),
            },
        };
        let result = run_finetune(args, LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to parse plan"));
    }

    #[test]
    fn test_print_plan_summary_zero_resources() {
        let mut plan = make_plan();
        plan.resources = ResourceEstimate {
            estimated_vram_gb: 0.0,
            estimated_minutes_per_epoch: 0.0,
            estimated_total_minutes: 0.0,
            estimated_checkpoint_mb: 0.0,
            steps_per_epoch: 0,
            gpu_device: None,
        };
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_single_class() {
        let mut plan = make_plan();
        plan.data.class_counts = vec![1000];
        plan.data.imbalance_ratio = 1.0;
        print_plan_summary(&plan, LogLevel::Verbose);
    }

    #[test]
    fn test_print_plan_summary_many_classes() {
        let mut plan = make_plan();
        plan.data.class_counts = vec![100, 200, 150, 300, 50, 80, 120, 90, 60, 70];
        plan.data.imbalance_ratio = 6.0;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_high_imbalance() {
        let mut plan = make_plan();
        plan.data.imbalance_ratio = 100.0;
        plan.data.auto_class_weights = true;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_boundary_imbalance() {
        // Exactly at the 2.0 threshold: the imbalance line is not printed.
        let mut plan = make_plan();
        plan.data.imbalance_ratio = 2.0;
        plan.data.auto_class_weights = false;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_just_above_imbalance_threshold() {
        let mut plan = make_plan();
        plan.data.imbalance_ratio = 2.01;
        plan.data.auto_class_weights = true;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_print_plan_summary_with_hpo_tpe() {
        let mut plan = make_plan();
        plan.hyperparameters.strategy = "tpe".to_string();
        plan.hyperparameters.budget = 20;
        plan.hyperparameters.scout = true;
        plan.hyperparameters.max_epochs = 1;
        plan.hyperparameters.search_space_params = 9;
        print_plan_summary(&plan, LogLevel::Normal);
    }

    #[test]
    fn test_check_counts_all_fail() {
        let mut plan = make_plan();
        plan.pre_flight = vec![
            PreFlightCheck { name: "a".into(), status: CheckStatus::Fail, detail: "bad".into() },
            PreFlightCheck { name: "b".into(), status: CheckStatus::Fail, detail: "bad".into() },
        ];
        let (p, w, f) = plan.check_counts();
        assert_eq!(p, 0);
        assert_eq!(w, 0);
        assert_eq!(f, 2);
    }

    #[test]
    fn test_check_counts_all_warn() {
        let mut plan = make_plan();
        plan.pre_flight = vec![
            PreFlightCheck { name: "a".into(), status: CheckStatus::Warn, detail: "eh".into() },
            PreFlightCheck { name: "b".into(), status: CheckStatus::Warn, detail: "eh".into() },
            PreFlightCheck { name: "c".into(), status: CheckStatus::Warn, detail: "eh".into() },
        ];
        let (p, w, f) = plan.check_counts();
        assert_eq!(p, 0);
        assert_eq!(w, 3);
        assert_eq!(f, 0);
    }

    #[test]
    fn test_make_plan_defaults() {
        let plan = make_plan();
        assert_eq!(plan.version, "1.0");
        assert_eq!(plan.task, "classify");
        assert_eq!(plan.data.train_samples, 1000);
        assert_eq!(plan.model.hidden_size, 896);
        assert_eq!(plan.verdict, PlanVerdict::WarningsPresent);
        assert!(plan.auto_diagnose);
    }

    #[test]
    fn test_plan_to_yaml_and_back() {
        let plan = make_plan();
        let yaml = plan.to_yaml();
        assert!(yaml.contains("classify"));
        assert!(yaml.contains("1.0"));
        let parsed = crate::finetune::training_plan::TrainingPlan::from_str(&yaml).unwrap();
        assert_eq!(parsed.data.train_samples, 1000);
    }

    #[test]
    fn test_plan_to_json_and_back() {
        let plan = make_plan();
        let json = plan.to_json();
        assert!(json.contains("classify"));
        let parsed = crate::finetune::training_plan::TrainingPlan::from_str(&json).unwrap();
        assert_eq!(parsed.model.architecture, "Qwen2");
    }

    #[test]
    fn test_run_finetune_apply_valid_plan_file_missing_model() {
        let plan = make_plan();
        // Removed on drop, even if an assertion below fails.
        let plan_file = TempFile::with_contents("ent_ft_valid_plan.yaml", &plan.to_yaml());
        let args = FinetuneArgs {
            command: FinetuneCommand::Apply {
                plan: plan_file.path().to_path_buf(),
                model_path: std::path::PathBuf::from("/nonexistent/model"),
                data: std::path::PathBuf::from("/nonexistent/data.jsonl"),
                output_dir: temp_path("ft_test_valid"),
            },
        };
        let err = run_finetune(args, LogLevel::Quiet).unwrap_err();
        // The plan file is valid, so reading and parsing succeed. The failure
        // must come from executing the plan against the missing model.
        assert!(err.starts_with("Training failed"), "unexpected error: {err}");
    }
}