1use std::path::PathBuf;
17
18use serde::{Deserialize, Serialize};
19
20use super::classification::{corpus_stats, load_safety_corpus};
21use super::classify_tuner::{default_classify_search_space, extract_trial_params, TuneStrategy};
22
/// User-supplied inputs for [`plan`]: data locations, model choice and the
/// hyperparameter-search settings to preview.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanConfig {
    /// Task name, copied verbatim into the resulting plan.
    pub task: String,
    /// Training corpus; planning returns an error if this file does not exist.
    pub data_path: PathBuf,
    /// Optional validation split; its sample count is reported in the data audit.
    pub val_path: Option<PathBuf>,
    /// Optional test split; its sample count is reported in the data audit.
    pub test_path: Option<PathBuf>,
    /// Model size identifier (e.g. `0.5B`, `9B`, `7B`, `13B`); unrecognised
    /// values fall back to 0.5B geometry for the estimates.
    pub model_size: String,
    /// Directory expected to contain the model weights and `tokenizer.json`.
    pub model_path: Option<PathBuf>,
    /// Number of output classes expected in the corpus.
    pub num_classes: usize,
    /// Output directory; checked for leftovers from previous runs.
    pub output_dir: PathBuf,
    /// `"manual"`, or a search-strategy name parsed as `TuneStrategy`
    /// (names that fail to parse fall back to TPE).
    pub strategy: String,
    /// Number of HPO trials (not used by the manual strategy).
    pub budget: usize,
    /// Scout mode: HPO trials run a single epoch.
    pub scout: bool,
    /// Maximum number of epochs per trial.
    pub max_epochs: usize,
    /// Manual-strategy learning rate (defaults to `1e-4` when unset).
    pub manual_lr: Option<f32>,
    /// Manual-strategy LoRA rank (defaults to 16); also feeds the LoRA
    /// parameter-count estimate for every strategy.
    pub manual_lora_rank: Option<usize>,
    /// Manual-strategy batch size (defaults to 32).
    pub manual_batch_size: Option<usize>,
    /// Manual-strategy LoRA alpha; left unset to let execution choose a default.
    pub manual_lora_alpha: Option<f32>,
    /// Manual-strategy warmup fraction.
    pub manual_warmup: Option<f32>,
    /// Manual-strategy gradient-clipping norm.
    pub manual_gradient_clip: Option<f32>,
    /// Manual-strategy ratio of the minimum learning rate to the peak learning rate.
    pub manual_lr_min_ratio: Option<f32>,
    /// Manual-strategy class-weighting strategy name (e.g. `uniform`,
    /// `inverse_freq`, `sqrt_inverse`).
    pub manual_class_weights: Option<String>,
    /// Manual-strategy LoRA target-module selection, stored as given.
    pub manual_target_modules: Option<String>,
}
76
/// Complete, serializable output of [`plan`]: data audit, model summary,
/// hyperparameter plan, resource estimate, pre-flight checks and verdict.
///
/// Can be rendered with `to_json` / `to_yaml` and read back with `from_str`;
/// [`execute_plan`] consumes it to run training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingPlan {
    /// Plan format version (`"1.0"` for plans produced by [`plan`]).
    pub version: String,
    /// Task name taken from the planning configuration.
    pub task: String,

    /// Training-data audit results.
    pub data: DataAudit,

    /// Resolved model geometry and parameter estimates.
    pub model: ModelInfo,

    /// Strategy, budget, previews and (for manual runs) fixed hyperparameters.
    pub hyperparameters: HyperparameterPlan,

    /// Heuristic VRAM / time / checkpoint-size estimates.
    pub resources: ResourceEstimate,

    /// Pre-flight check results, in the order they were performed.
    pub pre_flight: Vec<PreFlightCheck>,

    /// Output directory, rendered as a display string.
    pub output_dir: String,
    /// Always `true` in plans built by [`plan`].
    /// NOTE(review): the consumer of this flag is not visible here.
    pub auto_diagnose: bool,

    /// Overall readiness; a `Blocked` plan is refused by [`execute_plan`].
    pub verdict: PlanVerdict,
    /// Problems found during planning, each with an optional suggested fix.
    pub issues: Vec<PlanIssue>,
}
124
/// Summary statistics and quality signals for the training corpus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataAudit {
    /// Training data path, rendered as a display string.
    pub train_path: String,
    /// Number of training samples loaded.
    pub train_samples: usize,
    /// Average input length as reported by `corpus_stats`.
    pub avg_input_len: usize,
    /// Sample count per class index.
    pub class_counts: Vec<usize>,
    /// Largest class count divided by the smallest; the smallest is clamped
    /// to at least 1 so empty classes do not divide by zero.
    pub imbalance_ratio: f64,
    /// `true` when `imbalance_ratio` exceeds 2.0.
    pub auto_class_weights: bool,
    /// Validation sample count, when a validation file was configured and loaded.
    pub val_samples: Option<usize>,
    /// Test sample count, when a test file was configured and loaded.
    pub test_samples: Option<usize>,
    /// Number of samples whose input text exactly repeats an earlier sample.
    pub duplicates: usize,
    /// Number of inputs starting with a shell shebang (`#!/`, `#! /`) or `set -`.
    pub preamble_count: usize,
}
149
/// Model geometry and trainable-parameter estimates resolved from the configured size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Size identifier exactly as configured.
    pub size: String,
    /// Hidden dimension taken from the built-in size table.
    pub hidden_size: usize,
    /// Number of transformer layers taken from the built-in size table.
    pub num_layers: usize,
    /// Architecture family (`qwen2`, `qwen3.5`, `llama2`, or `unknown`).
    pub architecture: String,
    /// `true` when the configured model path is an existing directory; the
    /// presence of weight/tokenizer files is reported separately in pre-flight.
    pub weights_available: bool,
    /// Estimated LoRA parameter count: `2 * rank * hidden_size * 2 * num_layers`.
    pub lora_trainable_params: usize,
    /// Classification-head parameter count: `hidden_size * num_classes + num_classes`.
    pub classifier_params: usize,
}
168
/// Hyperparameter strategy for the run: either a fixed manual configuration
/// or an HPO search with a few previewed trial configurations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterPlan {
    /// Strategy name as configured (`"manual"` selects the fixed configuration).
    pub strategy: String,
    /// Number of HPO trials (0 for manual).
    pub budget: usize,
    /// Whether trials run in scout mode (always `false` for manual).
    pub scout: bool,
    /// Epochs per trial; forced to 1 for HPO plans in scout mode.
    pub max_epochs: usize,
    /// Number of searched hyperparameters (9 for HPO, 0 for manual).
    pub search_space_params: usize,
    /// Up to three configurations drawn from the searcher as a preview.
    pub sample_configs: Vec<TrialPreview>,
    /// Fixed hyperparameters; `Some` only for the manual strategy.
    pub manual: Option<ManualConfig>,
    /// Free-text advice; set only for the manual strategy.
    pub recommendation: Option<String>,
}
189
/// One previewed HPO trial configuration, as unpacked by `extract_trial_params`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialPreview {
    /// 1-based trial number within the preview.
    pub trial: usize,
    /// Sampled learning rate.
    pub learning_rate: f32,
    /// Sampled LoRA rank.
    pub lora_rank: usize,
    /// Sampled LoRA alpha.
    pub lora_alpha: f32,
    /// Sampled batch size.
    pub batch_size: usize,
    /// Sampled warmup value (passed to training as the warmup fraction).
    pub warmup: f32,
    /// Sampled gradient-clipping norm.
    pub gradient_clip: f32,
    /// Sampled class-weighting strategy name.
    pub class_weights: String,
    /// Sampled LoRA target-module selection.
    pub target_modules: String,
    /// Sampled ratio of minimum to peak learning rate.
    pub lr_min_ratio: f32,
}
214
/// Fixed hyperparameters used by the manual strategy.
///
/// Optional fields may be absent in serialized plans; at execution time
/// `lora_alpha` defaults to `lora_rank`, `warmup_fraction` to 0.1,
/// `gradient_clip_norm` to 1.0 and `lr_min_ratio` to 0.01.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManualConfig {
    /// Peak learning rate.
    pub learning_rate: f32,
    /// LoRA rank.
    pub lora_rank: usize,
    /// Training batch size.
    pub batch_size: usize,
    /// LoRA alpha (execution falls back to `lora_rank`).
    #[serde(default)]
    pub lora_alpha: Option<f32>,
    /// Fraction of training spent in warmup (execution falls back to 0.1).
    #[serde(default)]
    pub warmup_fraction: Option<f32>,
    /// Gradient-clipping norm (execution falls back to 1.0).
    #[serde(default)]
    pub gradient_clip_norm: Option<f32>,
    /// Minimum-to-peak learning-rate ratio (execution falls back to 0.01).
    #[serde(default)]
    pub lr_min_ratio: Option<f32>,
    /// Class-weighting strategy name (`uniform`, `inverse_freq`, `sqrt_inverse`);
    /// other names, or `None`, mean no class weights.
    #[serde(default)]
    pub class_weights: Option<String>,
    /// LoRA target-module selection, stored as configured.
    #[serde(default)]
    pub target_modules: Option<String>,
}
243
/// Heuristic compute and storage estimates for the planned run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceEstimate {
    /// Base VRAM figure (GB) from a table keyed by hidden size; independent of batch size.
    pub estimated_vram_gb: f64,
    /// `steps_per_epoch` times a per-step time heuristic, in minutes.
    pub estimated_minutes_per_epoch: f64,
    /// Per-epoch minutes × epochs (1 in scout mode) × trials (1 for manual).
    pub estimated_total_minutes: f64,
    /// (LoRA + classifier parameters) × 4 bytes, expressed in MiB.
    pub estimated_checkpoint_mb: f64,
    /// Training samples divided by batch size, rounded up.
    pub steps_per_epoch: usize,
    /// Best-effort GPU description from `detect_gpu_device`.
    pub gpu_device: Option<String>,
}
260
/// Result of a single named pre-flight check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreFlightCheck {
    /// Stable check identifier (e.g. `data_file`, `model_weights`, `output_dir`).
    pub name: String,
    /// Outcome of the check.
    pub status: CheckStatus,
    /// Human-readable explanation of the outcome.
    pub detail: String,
}
271
/// Outcome level shared by pre-flight checks and plan-issue severities.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CheckStatus {
    /// Check succeeded.
    Pass,
    /// Non-blocking concern; yields `PlanVerdict::WarningsPresent` if nothing failed.
    Warn,
    /// Blocking problem; yields `PlanVerdict::Blocked`.
    Fail,
}
282
/// Overall readiness of a plan, derived from the worst check/issue status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlanVerdict {
    /// No warnings or failures.
    Ready,
    /// At least one warning and no failures.
    WarningsPresent,
    /// At least one failure; `execute_plan` refuses blocked plans.
    Blocked,
}
293
/// A problem detected while planning, with an optional remediation hint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanIssue {
    /// Severity; `Fail` blocks the plan.
    pub severity: CheckStatus,
    /// Area the issue belongs to (e.g. `Data`, `Hyperparameters`, `Output`).
    pub category: String,
    /// Description of the problem.
    pub message: String,
    /// Suggested fix, often a command line.
    pub fix: Option<String>,
}
306
307pub fn plan(config: &PlanConfig) -> crate::Result<TrainingPlan> {
316 let mut issues: Vec<PlanIssue> = Vec::new();
317 let mut pre_flight: Vec<PreFlightCheck> = Vec::new();
318
319 let data = audit_data(config, &mut issues, &mut pre_flight)?;
322
323 let model = resolve_model(config, &mut pre_flight);
326
327 let hyperparameters = build_hpo_plan(config, data.train_samples, &mut issues);
330
331 let batch_size = hyperparameters.manual.as_ref().map_or(64, |m| m.batch_size);
336 let resources = estimate_resources(config, &model, &data, batch_size);
337
338 if config.output_dir.exists() {
342 let has_checkpoints = config.output_dir.join("metadata.json").exists()
343 || config.output_dir.join("epoch_001").exists();
344 if has_checkpoints {
345 pre_flight.push(PreFlightCheck {
346 name: "output_dir".to_string(),
347 status: CheckStatus::Warn,
348 detail: format!(
349 "Output directory {} already contains checkpoints — may overwrite",
350 config.output_dir.display()
351 ),
352 });
353 issues.push(PlanIssue {
354 severity: CheckStatus::Warn,
355 category: "Output".to_string(),
356 message: "Checkpoint directory already contains previous run".to_string(),
357 fix: Some("Use a fresh output directory or rename existing one".to_string()),
358 });
359 } else {
360 pre_flight.push(PreFlightCheck {
361 name: "output_dir".to_string(),
362 status: CheckStatus::Pass,
363 detail: format!("Output directory {} exists", config.output_dir.display()),
364 });
365 }
366 } else {
367 pre_flight.push(PreFlightCheck {
368 name: "output_dir".to_string(),
369 status: CheckStatus::Pass,
370 detail: format!("Output directory {} will be created", config.output_dir.display()),
371 });
372 }
373
374 pre_flight.push(PreFlightCheck {
376 name: "class_weights_persist".to_string(),
377 status: CheckStatus::Pass,
378 detail: "class_weights saved in checkpoint metadata (entrenar ≥0.7.5)".to_string(),
379 });
380
381 let has_fail = pre_flight.iter().any(|c| c.status == CheckStatus::Fail)
384 || issues.iter().any(|i| i.severity == CheckStatus::Fail);
385 let has_warn = pre_flight.iter().any(|c| c.status == CheckStatus::Warn)
386 || issues.iter().any(|i| i.severity == CheckStatus::Warn);
387
388 let verdict = if has_fail {
389 PlanVerdict::Blocked
390 } else if has_warn {
391 PlanVerdict::WarningsPresent
392 } else {
393 PlanVerdict::Ready
394 };
395
396 Ok(TrainingPlan {
397 version: "1.0".to_string(),
398 task: config.task.clone(),
399 data,
400 model,
401 hyperparameters,
402 resources,
403 pre_flight,
404 output_dir: config.output_dir.display().to_string(),
405 auto_diagnose: true,
406 verdict,
407 issues,
408 })
409}
410
411fn audit_data(
417 config: &PlanConfig,
418 issues: &mut Vec<PlanIssue>,
419 pre_flight: &mut Vec<PreFlightCheck>,
420) -> crate::Result<DataAudit> {
421 if !config.data_path.exists() {
423 pre_flight.push(PreFlightCheck {
424 name: "data_file".to_string(),
425 status: CheckStatus::Fail,
426 detail: format!("Training data not found: {}", config.data_path.display()),
427 });
428 return Err(crate::Error::Io(format!(
429 "Training data not found: {}",
430 config.data_path.display()
431 )));
432 }
433
434 let corpus = load_safety_corpus(&config.data_path, config.num_classes)?;
436 let stats = corpus_stats(&corpus, config.num_classes);
437
438 pre_flight.push(PreFlightCheck {
439 name: "data_file".to_string(),
440 status: CheckStatus::Pass,
441 detail: format!("{} samples loaded from {}", stats.total, config.data_path.display()),
442 });
443
444 let empty_classes: Vec<usize> =
446 stats.class_counts.iter().enumerate().filter(|(_, &c)| c == 0).map(|(i, _)| i).collect();
447 if empty_classes.is_empty() {
448 pre_flight.push(PreFlightCheck {
449 name: "class_coverage".to_string(),
450 status: CheckStatus::Pass,
451 detail: format!("All {} classes have samples", config.num_classes),
452 });
453 } else {
454 pre_flight.push(PreFlightCheck {
455 name: "class_coverage".to_string(),
456 status: CheckStatus::Fail,
457 detail: format!("Classes with zero samples: {empty_classes:?}"),
458 });
459 issues.push(PlanIssue {
460 severity: CheckStatus::Fail,
461 category: "Data".to_string(),
462 message: format!("Classes {empty_classes:?} have zero training samples"),
463 fix: Some("Add samples for missing classes or reduce num_classes".to_string()),
464 });
465 }
466
467 let min_count = stats.class_counts.iter().copied().min().unwrap_or(1).max(1);
469 let max_count = stats.class_counts.iter().copied().max().unwrap_or(1);
470 let imbalance_ratio = max_count as f64 / min_count as f64;
471 let auto_class_weights = imbalance_ratio > 2.0;
472
473 if imbalance_ratio > 5.0 {
474 issues.push(PlanIssue {
475 severity: CheckStatus::Warn,
476 category: "Data".to_string(),
477 message: format!(
478 "Severe class imbalance ({imbalance_ratio:.1}:1) — sqrt-inverse weights will be auto-applied"
479 ),
480 fix: Some("Consider oversampling minority classes: apr data balance --strategy oversample".to_string()),
481 });
482 }
483
484 let mut seen = std::collections::HashSet::new();
486 let mut duplicates = 0usize;
487 for s in &corpus {
488 if !seen.insert(&s.input) {
489 duplicates += 1;
490 }
491 }
492 if duplicates > 0 {
493 issues.push(PlanIssue {
494 severity: CheckStatus::Warn,
495 category: "Data".to_string(),
496 message: format!(
497 "{duplicates} duplicate inputs detected ({:.1}%)",
498 duplicates as f64 / stats.total as f64 * 100.0
499 ),
500 fix: Some("Remove duplicates: apr data dedup".to_string()),
501 });
502 }
503
504 let preamble_count = corpus
506 .iter()
507 .filter(|s| {
508 s.input.starts_with("#!/")
509 || s.input.starts_with("#! /")
510 || s.input.starts_with("set -")
511 })
512 .count();
513 if preamble_count > stats.total / 10 {
514 issues.push(PlanIssue {
515 severity: CheckStatus::Warn,
516 category: "Data".to_string(),
517 message: format!(
518 "{preamble_count} samples ({:.0}%) have shell preamble",
519 preamble_count as f64 / stats.total as f64 * 100.0
520 ),
521 fix: Some("Strip preambles: use --strip-preamble during data export".to_string()),
522 });
523 }
524
525 if stats.total < 100 {
527 issues.push(PlanIssue {
528 severity: CheckStatus::Warn,
529 category: "Data".to_string(),
530 message: format!("Only {} samples — may be insufficient for fine-tuning", stats.total),
531 fix: None,
532 });
533 }
534
535 let val_samples = count_file_samples(config.val_path.as_ref(), config.num_classes);
537 let test_samples = count_file_samples(config.test_path.as_ref(), config.num_classes);
538
539 Ok(DataAudit {
540 train_path: config.data_path.display().to_string(),
541 train_samples: stats.total,
542 avg_input_len: stats.avg_input_len,
543 class_counts: stats.class_counts,
544 imbalance_ratio,
545 auto_class_weights,
546 val_samples,
547 test_samples,
548 duplicates,
549 preamble_count,
550 })
551}
552
553pub(crate) fn count_file_samples(path: Option<&PathBuf>, num_classes: usize) -> Option<usize> {
555 path.and_then(|p| {
556 if p.exists() {
557 load_safety_corpus(p, num_classes).ok().map(|c| c.len())
558 } else {
559 None
560 }
561 })
562}
563
564pub(crate) fn resolve_model(
566 config: &PlanConfig,
567 pre_flight: &mut Vec<PreFlightCheck>,
568) -> ModelInfo {
569 let (hidden_size, num_layers, architecture) = match config.model_size.as_str() {
570 "0.5B" | "500M" | "qwen2-0.5b" => (896, 24, "qwen2"),
571 "9B" | "qwen3.5-9b" => (4096, 48, "qwen3.5"),
572 "7B" | "llama2-7b" => (4096, 32, "llama2"),
573 "13B" | "llama2-13b" => (5120, 40, "llama2"),
574 _ => (896, 24, "unknown"),
575 };
576
577 let weights_available = config.model_path.as_ref().is_some_and(|p| p.is_dir());
579 if let Some(ref path) = config.model_path {
580 if weights_available {
581 let has_safetensors = path.join("model.safetensors").exists()
583 || path.join("model-00001-of-00002.safetensors").exists();
584 let has_tokenizer = path.join("tokenizer.json").exists();
585
586 if has_safetensors && has_tokenizer {
587 pre_flight.push(PreFlightCheck {
588 name: "model_weights".to_string(),
589 status: CheckStatus::Pass,
590 detail: format!("Model weights found at {}", path.display()),
591 });
592 } else {
593 let mut missing = Vec::new();
594 if !has_safetensors {
595 missing.push("model.safetensors");
596 }
597 if !has_tokenizer {
598 missing.push("tokenizer.json");
599 }
600 pre_flight.push(PreFlightCheck {
601 name: "model_weights".to_string(),
602 status: CheckStatus::Warn,
603 detail: format!("Model directory exists but missing: {}", missing.join(", ")),
604 });
605 }
606 } else {
607 pre_flight.push(PreFlightCheck {
608 name: "model_weights".to_string(),
609 status: CheckStatus::Fail,
610 detail: format!("Model path not found: {}", path.display()),
611 });
612 }
613 } else {
614 pre_flight.push(PreFlightCheck {
615 name: "model_weights".to_string(),
616 status: CheckStatus::Warn,
617 detail: "No model path specified — will use default model resolution".to_string(),
618 });
619 }
620
621 let default_rank = config.manual_lora_rank.unwrap_or(16);
624 let lora_trainable_params = 2 * default_rank * hidden_size * 2 * num_layers;
625 let classifier_params = hidden_size * config.num_classes + config.num_classes;
626
627 ModelInfo {
628 size: config.model_size.clone(),
629 hidden_size,
630 num_layers,
631 architecture: architecture.to_string(),
632 weights_available,
633 lora_trainable_params,
634 classifier_params,
635 }
636}
637
638pub(crate) fn build_hpo_plan(
640 config: &PlanConfig,
641 train_samples: usize,
642 issues: &mut Vec<PlanIssue>,
643) -> HyperparameterPlan {
644 let strategy = config.strategy.as_str();
645
646 if strategy == "manual" {
647 let lr = config.manual_lr.unwrap_or(1e-4);
648 let rank = config.manual_lora_rank.unwrap_or(16);
649 let batch = config.manual_batch_size.unwrap_or(32);
650
651 issues.push(PlanIssue {
653 severity: CheckStatus::Warn,
654 category: "Hyperparameters".to_string(),
655 message: "Using manual hyperparameters — HPO (--strategy tpe) searches 9 parameters automatically".to_string(),
656 fix: Some(format!(
657 "apr train plan --strategy tpe --budget 20 --scout --data {}",
658 config.data_path.display()
659 )),
660 });
661
662 return HyperparameterPlan {
663 strategy: "manual".to_string(),
664 budget: 0,
665 scout: false,
666 max_epochs: config.max_epochs,
667 search_space_params: 0,
668 sample_configs: Vec::new(),
669 manual: Some(ManualConfig {
670 learning_rate: lr,
671 lora_rank: rank,
672 batch_size: batch,
673 lora_alpha: config.manual_lora_alpha,
674 warmup_fraction: config.manual_warmup,
675 gradient_clip_norm: config.manual_gradient_clip,
676 lr_min_ratio: config.manual_lr_min_ratio,
677 class_weights: config.manual_class_weights.clone(),
678 target_modules: config.manual_target_modules.clone(),
679 }),
680 recommendation: Some(
681 "Consider using --strategy tpe for automated hyperparameter search".to_string(),
682 ),
683 };
684 }
685
686 let tune_strategy: TuneStrategy = strategy.parse().unwrap_or(TuneStrategy::Tpe);
688
689 let space = default_classify_search_space();
691 let mut searcher: Box<dyn super::classify_tuner::TuneSearcher> = match tune_strategy {
692 TuneStrategy::Tpe => {
693 let n_startup = (config.budget / 3).max(3);
694 Box::new(super::tune_searchers::TpeSearcher::new(space.clone(), n_startup))
695 }
696 TuneStrategy::Grid => Box::new(super::tune_searchers::GridSearcher::new(space.clone(), 3)),
697 TuneStrategy::Random => Box::new(super::tune_searchers::RandomSearcher::new(space.clone())),
698 };
699
700 let num_previews = config.budget.min(3);
701 let mut sample_configs = Vec::new();
702 for i in 0..num_previews {
703 if let Ok(trial) = searcher.suggest() {
704 let (lr, rank, alpha, batch, warmup, clip, weights, targets, lr_min) =
705 extract_trial_params(&trial.config);
706 sample_configs.push(TrialPreview {
707 trial: i + 1,
708 learning_rate: lr,
709 lora_rank: rank,
710 lora_alpha: alpha,
711 batch_size: batch,
712 warmup,
713 gradient_clip: clip,
714 class_weights: weights,
715 target_modules: targets,
716 lr_min_ratio: lr_min,
717 });
718 }
719 }
720
721 if config.budget < 5 && tune_strategy == TuneStrategy::Tpe {
723 issues.push(PlanIssue {
724 severity: CheckStatus::Warn,
725 category: "Hyperparameters".to_string(),
726 message: format!(
727 "TPE budget {} is low — needs ≥5 trials for Bayesian optimization to converge",
728 config.budget
729 ),
730 fix: Some("Use --budget 20 for better results".to_string()),
731 });
732 }
733
734 if !config.scout && train_samples > 10_000 && config.max_epochs > 1 {
736 issues.push(PlanIssue {
737 severity: CheckStatus::Warn,
738 category: "Hyperparameters".to_string(),
739 message: format!(
740 "Full HPO with {} samples × {} epochs × {} trials = ~{:.0} GPU hours",
741 train_samples,
742 config.max_epochs,
743 config.budget,
744 estimate_gpu_hours(train_samples, config.max_epochs, config.budget)
745 ),
746 fix: Some(
747 "Use --scout for 1-epoch trials first, then --from-scout for full run".to_string(),
748 ),
749 });
750 }
751
752 HyperparameterPlan {
753 strategy: strategy.to_string(),
754 budget: config.budget,
755 scout: config.scout,
756 max_epochs: if config.scout { 1 } else { config.max_epochs },
757 search_space_params: 9,
758 sample_configs,
759 manual: None,
760 recommendation: None,
761 }
762}
763
/// Rough wall-clock hours for a full HPO sweep.
///
/// Assumes batches of 64 samples and 58 seconds per optimizer step,
/// multiplied over every epoch of every trial.
pub(crate) fn estimate_gpu_hours(train_samples: usize, max_epochs: usize, budget: usize) -> f64 {
    // Fixed heuristics: representative batch size and per-step cost.
    const ASSUMED_BATCH: usize = 64;
    const SECS_PER_STEP: f64 = 58.0;
    const SECS_PER_HOUR: f64 = 3600.0;

    let epoch_secs = train_samples.div_ceil(ASSUMED_BATCH) as f64 * SECS_PER_STEP;
    epoch_secs * max_epochs as f64 * budget as f64 / SECS_PER_HOUR
}
773
774pub(crate) fn estimate_resources(
776 config: &PlanConfig,
777 model: &ModelInfo,
778 data: &DataAudit,
779 batch_size: usize,
780) -> ResourceEstimate {
781 let base_vram = match model.hidden_size {
784 896 => 2.5, 4096 => 18.0, 5120 => 26.0, _ => 3.0,
788 };
789
790 let steps_per_epoch = data.train_samples.div_ceil(batch_size);
791
792 let seconds_per_step = match model.hidden_size {
797 896 => 58.0, 4096 => 270.0, 5120 => 450.0, _ => 90.0,
801 };
802 let minutes_per_epoch = (steps_per_epoch as f64 * seconds_per_step) / 60.0;
803
804 let total_epochs = if config.scout { 1 } else { config.max_epochs };
805 let total_trials = if config.strategy == "manual" { 1 } else { config.budget };
806 let total_minutes = minutes_per_epoch * total_epochs as f64 * total_trials as f64;
807
808 let checkpoint_mb =
810 (model.lora_trainable_params + model.classifier_params) as f64 * 4.0 / 1_048_576.0;
811
812 let gpu_device = detect_gpu_device();
814
815 ResourceEstimate {
816 estimated_vram_gb: base_vram,
817 estimated_minutes_per_epoch: minutes_per_epoch,
818 estimated_total_minutes: total_minutes,
819 estimated_checkpoint_mb: checkpoint_mb,
820 steps_per_epoch,
821 gpu_device,
822 }
823}
824
/// Best-effort name of the GPU available for training, for display in the plan.
///
/// First scans the NVIDIA driver's procfs entries
/// (`/proc/driver/nvidia/gpus/*/information`) in sorted path order, so hosts
/// with several GPUs report a stable device. Returns the first `Model:` value
/// found.
///
/// Otherwise, if `CUDA_VISIBLE_DEVICES` lists at least one device, returns a
/// generic placeholder. An empty value or a leading negative index
/// (e.g. `-1`) hides every device from CUDA and therefore yields `None`.
pub(crate) fn detect_gpu_device() -> Option<String> {
    if let Some(model) = nvidia_proc_gpu_model() {
        return Some(model);
    }
    match std::env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) if cuda_devices_visible(&devices) => {
            Some("CUDA device (unknown model)".to_string())
        }
        _ => None,
    }
}

/// Reads the first `Model:` line from the NVIDIA procfs GPU entries, scanning
/// entries in sorted order for deterministic output.
fn nvidia_proc_gpu_model() -> Option<String> {
    let entries = std::fs::read_dir("/proc/driver/nvidia/gpus").ok()?;
    let mut gpu_dirs: Vec<std::path::PathBuf> = entries.flatten().map(|e| e.path()).collect();
    gpu_dirs.sort();
    gpu_dirs.iter().find_map(|dir| {
        let info = std::fs::read_to_string(dir.join("information")).ok()?;
        info.lines()
            .find_map(|line| line.strip_prefix("Model:").map(|name| name.trim().to_string()))
    })
}

/// Returns `false` when a `CUDA_VISIBLE_DEVICES` value hides all devices
/// (empty, or a first entry that is a negative index such as `-1`).
fn cuda_devices_visible(value: &str) -> bool {
    let first = value.split(',').next().unwrap_or("").trim();
    !first.is_empty() && !first.parse::<i64>().is_ok_and(|idx| idx < 0)
}
846
/// Runtime inputs for [`execute_plan`]: where to find the model and data,
/// where to write trial outputs, and an optional progress callback.
#[derive(Debug, Clone)]
pub struct ApplyConfig {
    /// Model directory; `execute_plan` fails if this is not a directory.
    pub model_path: PathBuf,
    /// Training corpus; `execute_plan` fails if it does not exist.
    pub data_path: PathBuf,
    /// Root directory for per-trial subdirectories and run artifacts (created if needed).
    pub output_dir: PathBuf,
    /// Progress callback invoked with `(trial_index, budget, summary)`;
    /// the manual strategy reports `(0, 1, summary)`.
    pub on_trial_complete: Option<fn(usize, usize, &super::classify_tuner::TrialSummary)>,
}
867
868pub fn execute_plan(
884 plan: &TrainingPlan,
885 apply: &ApplyConfig,
886) -> crate::Result<super::classify_tuner::TuneResult> {
887 use super::classify_pipeline::ClassifyConfig;
888 use super::classify_tuner::{
889 ClassifyTuner, SchedulerKind, TrialSummary, TuneConfig, TuneStrategy,
890 };
891 use crate::optim::ParameterValue;
892 use crate::transformer::TransformerConfig;
893 use std::collections::HashMap;
894
895 if plan.verdict == PlanVerdict::Blocked {
897 return Err(crate::Error::ConfigError(
898 "Cannot apply a blocked plan — resolve all failures first".to_string(),
899 ));
900 }
901
902 if !apply.model_path.is_dir() {
903 return Err(crate::Error::ConfigError(format!(
904 "Model path does not exist: {}",
905 apply.model_path.display()
906 )));
907 }
908
909 if !apply.data_path.exists() {
910 return Err(crate::Error::Io(format!(
911 "Training data not found: {}",
912 apply.data_path.display()
913 )));
914 }
915
916 std::fs::create_dir_all(&apply.output_dir).map_err(|e| {
918 crate::Error::Io(format!(
919 "Failed to create output directory {}: {e}",
920 apply.output_dir.display()
921 ))
922 })?;
923
924 let mut tracker = ExperimentTracker::open(&apply.output_dir, plan);
926
927 let model_config =
929 TransformerConfig::from_size_str(&plan.model.size).map_err(crate::Error::ConfigError)?;
930
931 let total_start = std::time::Instant::now();
932
933 let auto_nf4 = model_config.hidden_size >= 2048;
937 if auto_nf4 {
938 eprintln!(
939 "[plan] Auto-enabling NF4 quantization (hidden_size={} >= 2048)",
940 model_config.hidden_size
941 );
942 }
943
944 if plan.hyperparameters.strategy == "manual" {
946 let manual = plan.hyperparameters.manual.as_ref().ok_or_else(|| {
947 crate::Error::ConfigError(
948 "Manual strategy requires manual hyperparameters in plan".to_string(),
949 )
950 })?;
951
952 let num_classes = plan.data.class_counts.len();
953 let lora_alpha = manual.lora_alpha.unwrap_or(manual.lora_rank as f32);
954 let gradient_clip = manual.gradient_clip_norm.unwrap_or(1.0);
955 let warmup = manual.warmup_fraction.unwrap_or(0.1);
956 let lr_min_ratio = manual.lr_min_ratio.unwrap_or(0.01);
957
958 let class_weights = manual
959 .class_weights
960 .as_deref()
961 .and_then(|s| resolve_class_weights(s, &plan.data.class_counts, num_classes));
962
963 let classify_config = ClassifyConfig {
964 num_classes,
965 lora_rank: manual.lora_rank,
966 lora_alpha,
967 learning_rate: manual.learning_rate,
968 epochs: plan.hyperparameters.max_epochs,
969 batch_size: manual.batch_size,
970 gradient_clip_norm: Some(gradient_clip),
971 class_weights,
972 quantize_nf4: auto_nf4,
973 ..ClassifyConfig::default()
974 };
975
976 let trial_start = std::time::Instant::now();
977 let result = run_single_trial_with_warmup(
978 &apply.model_path,
979 &apply.data_path,
980 &apply.output_dir.join("trial_001"),
981 &model_config,
982 classify_config,
983 plan.hyperparameters.max_epochs,
984 warmup,
985 lr_min_ratio,
986 &plan.model.size,
987 )?;
988
989 let mut config_map = HashMap::new();
990 config_map.insert(
991 "learning_rate".to_string(),
992 ParameterValue::Float(f64::from(manual.learning_rate)),
993 );
994 config_map.insert("lora_rank".to_string(), ParameterValue::Int(manual.lora_rank as i64));
995 config_map.insert(
996 "batch_size".to_string(),
997 ParameterValue::Categorical(manual.batch_size.to_string()),
998 );
999
1000 let summary = TrialSummary {
1001 id: 0,
1002 val_loss: f64::from(result.best_val_loss),
1003 val_accuracy: result
1004 .epoch_metrics
1005 .get(result.best_epoch)
1006 .map_or(0.0, |m| f64::from(m.val_accuracy)),
1007 train_loss: result.epoch_metrics.last().map_or(0.0, |m| f64::from(m.train_loss)),
1008 train_accuracy: result
1009 .epoch_metrics
1010 .last()
1011 .map_or(0.0, |m| f64::from(m.train_accuracy)),
1012 epochs_run: result.epoch_metrics.len(),
1013 time_ms: trial_start.elapsed().as_millis() as u64,
1014 config: config_map,
1015 status: if result.stopped_early {
1016 "stopped_early".to_string()
1017 } else {
1018 "completed".to_string()
1019 },
1020 };
1021
1022 tracker.log_manual_trial(manual, &result);
1023
1024 if let Some(cb) = apply.on_trial_complete {
1025 cb(0, 1, &summary);
1026 }
1027
1028 return Ok(super::classify_tuner::TuneResult {
1029 strategy: "manual".to_string(),
1030 mode: "manual".to_string(),
1031 budget: 1,
1032 trials: vec![summary],
1033 best_trial_id: 0,
1034 total_time_ms: total_start.elapsed().as_millis() as u64,
1035 });
1036 }
1037
1038 let strategy: TuneStrategy = plan.hyperparameters.strategy.parse().unwrap_or(TuneStrategy::Tpe);
1040
1041 let num_classes = plan.data.class_counts.len();
1042
1043 let tune_config = TuneConfig {
1044 budget: plan.hyperparameters.budget,
1045 strategy,
1046 scheduler: SchedulerKind::Asha,
1047 scout: plan.hyperparameters.scout,
1048 max_epochs: plan.hyperparameters.max_epochs,
1049 num_classes,
1050 seed: 42,
1051 time_limit_secs: None,
1052 };
1053
1054 let mut tuner = ClassifyTuner::new(tune_config)?;
1055 let mut searcher = tuner.build_searcher();
1056 let scheduler = tuner.build_scheduler();
1057
1058 let budget = plan.hyperparameters.budget;
1059
1060 let plan_path = apply.output_dir.join("plan.yaml");
1062 let _ = std::fs::write(&plan_path, plan.to_yaml());
1063
1064 for trial_idx in 0..budget {
1065 let suggestion = match searcher.suggest() {
1067 Ok(s) => s,
1068 Err(e) => {
1069 eprintln!(" Trial {}: searcher exhausted ({e}), stopping", trial_idx + 1);
1070 break;
1071 }
1072 };
1073
1074 let (lr, rank, alpha, batch_size, warmup, clip, weights_strategy, _targets, lr_min_ratio) =
1075 super::classify_tuner::extract_trial_params(&suggestion.config);
1076
1077 let class_weights =
1079 resolve_class_weights(&weights_strategy, &plan.data.class_counts, num_classes);
1080
1081 let epochs = if plan.hyperparameters.scout { 1 } else { plan.hyperparameters.max_epochs };
1082
1083 let classify_config = ClassifyConfig {
1084 num_classes,
1085 lora_rank: rank,
1086 lora_alpha: alpha,
1087 learning_rate: lr,
1088 epochs,
1089 batch_size,
1090 gradient_clip_norm: Some(clip),
1091 class_weights,
1092 quantize_nf4: auto_nf4,
1093 ..ClassifyConfig::default()
1094 };
1095
1096 let trial_dir = apply.output_dir.join(format!("trial_{:03}", trial_idx + 1));
1097 let trial_start = std::time::Instant::now();
1098
1099 eprintln!(
1100 " Trial {}/{}: lr={:.2e} rank={} alpha={:.1} batch={} warmup={:.2} clip={:.1} weights={}",
1101 trial_idx + 1, budget, lr, rank, alpha, batch_size, warmup, clip, weights_strategy
1102 );
1103
1104 let trial_result = run_single_trial_with_warmup(
1106 &apply.model_path,
1107 &apply.data_path,
1108 &trial_dir,
1109 &model_config,
1110 classify_config,
1111 epochs,
1112 warmup,
1113 lr_min_ratio,
1114 &plan.model.size,
1115 );
1116
1117 let trial_time_ms = trial_start.elapsed().as_millis() as u64;
1118
1119 match trial_result {
1120 Ok(result) => {
1121 let val_loss = f64::from(result.best_val_loss);
1122 let val_accuracy = result
1123 .epoch_metrics
1124 .get(result.best_epoch)
1125 .map_or(0.0, |m| f64::from(m.val_accuracy));
1126
1127 let was_pruned = scheduler.should_stop(trial_idx, result.best_epoch, val_loss);
1129
1130 let status = resolve_trial_status(was_pruned, result.stopped_early);
1131
1132 let summary = TrialSummary {
1133 id: trial_idx,
1134 val_loss,
1135 val_accuracy,
1136 train_loss: result
1137 .epoch_metrics
1138 .last()
1139 .map_or(0.0, |m| f64::from(m.train_loss)),
1140 train_accuracy: result
1141 .epoch_metrics
1142 .last()
1143 .map_or(0.0, |m| f64::from(m.train_accuracy)),
1144 epochs_run: result.epoch_metrics.len(),
1145 time_ms: trial_time_ms,
1146 config: suggestion.config.clone(),
1147 status: status.to_string(),
1148 };
1149
1150 eprintln!(
1151 " => val_loss={:.4} val_acc={:.1}% epochs={} [{}]",
1152 val_loss,
1153 val_accuracy * 100.0,
1154 result.epoch_metrics.len(),
1155 status,
1156 );
1157
1158 tracker.log_hpo_trial(&suggestion.config, &result, was_pruned);
1159
1160 searcher.record(suggestion.clone(), val_loss, result.epoch_metrics.len());
1162 tuner.record_trial(summary.clone());
1163
1164 if let Some(cb) = apply.on_trial_complete {
1165 cb(trial_idx, budget, &summary);
1166 }
1167 }
1168 Err(e) => {
1169 eprintln!(" => FAILED: {e}");
1170 tracker.log_failed_trial();
1171
1172 let summary = TrialSummary {
1173 id: trial_idx,
1174 val_loss: f64::INFINITY,
1175 val_accuracy: 0.0,
1176 train_loss: f64::INFINITY,
1177 train_accuracy: 0.0,
1178 epochs_run: 0,
1179 time_ms: trial_time_ms,
1180 config: suggestion.config.clone(),
1181 status: "failed".to_string(),
1182 };
1183 searcher.record(suggestion, f64::INFINITY, 0);
1184 tuner.record_trial(summary);
1185 }
1186 }
1187 }
1188
1189 let total_time_ms = total_start.elapsed().as_millis() as u64;
1190
1191 let result = tuner.into_result(total_time_ms);
1193 let leaderboard_path = apply.output_dir.join("leaderboard.json");
1194 let _ = std::fs::write(
1195 &leaderboard_path,
1196 serde_json::to_string_pretty(&result).unwrap_or_default(),
1197 );
1198
1199 Ok(result)
1200}
1201
1202fn run_single_trial_with_warmup(
1204 model_path: &std::path::Path,
1205 data_path: &std::path::Path,
1206 checkpoint_dir: &std::path::Path,
1207 model_config: &crate::transformer::TransformerConfig,
1208 classify_config: super::classify_pipeline::ClassifyConfig,
1209 epochs: usize,
1210 warmup_fraction: f32,
1211 lr_min_ratio: f32,
1212 model_name: &str,
1213) -> crate::Result<super::classify_trainer::TrainResult> {
1214 use super::classify_pipeline::ClassifyPipeline;
1215 use super::classify_trainer::{ClassifyTrainer, TrainingConfig};
1216
1217 std::fs::create_dir_all(checkpoint_dir).map_err(|e| {
1219 crate::Error::Io(format!(
1220 "Failed to create checkpoint dir {}: {e}",
1221 checkpoint_dir.display()
1222 ))
1223 })?;
1224
1225 let pipeline = ClassifyPipeline::from_pretrained(model_path, model_config, classify_config)?;
1227
1228 let samples = pipeline.load_corpus(data_path)?;
1230
1231 let lr_min = pipeline.config.learning_rate * lr_min_ratio;
1232
1233 let training_config = TrainingConfig {
1235 epochs,
1236 val_split: 0.2,
1237 save_every: 1,
1238 early_stopping_patience: 5,
1239 checkpoint_dir: checkpoint_dir.to_path_buf(),
1240 seed: 42,
1241 log_interval: 1,
1242 warmup_fraction,
1243 lr_min,
1244 ..TrainingConfig::default()
1245 };
1246
1247 let mut trainer = ClassifyTrainer::new(pipeline, samples, training_config)?;
1249
1250 let experiment_id = format!(
1252 "trial-{}",
1253 std::time::SystemTime::now()
1254 .duration_since(std::time::UNIX_EPOCH)
1255 .map(|d| d.as_secs())
1256 .unwrap_or(0)
1257 );
1258 let writer =
1259 crate::monitor::tui::TrainingStateWriter::new(checkpoint_dir, &experiment_id, model_name);
1260 trainer.set_monitor_writer(writer);
1261
1262 Ok(trainer.train())
1264}
1265
1266pub(crate) fn resolve_class_weights(
1268 strategy: &str,
1269 class_counts: &[usize],
1270 num_classes: usize,
1271) -> Option<Vec<f32>> {
1272 use super::classification::{compute_class_weights, ClassWeightStrategy, SafetyCorpusStats};
1273
1274 match strategy {
1275 "uniform" => None,
1276 "inverse_freq" => {
1277 let stats = SafetyCorpusStats {
1278 total: class_counts.iter().sum(),
1279 class_counts: class_counts.to_vec(),
1280 avg_input_len: 0,
1281 };
1282 Some(compute_class_weights(&stats, ClassWeightStrategy::InverseFreq, num_classes))
1283 }
1284 "sqrt_inverse" => {
1285 let stats = SafetyCorpusStats {
1286 total: class_counts.iter().sum(),
1287 class_counts: class_counts.to_vec(),
1288 avg_input_len: 0,
1289 };
1290 Some(compute_class_weights(&stats, ClassWeightStrategy::SqrtInverse, num_classes))
1291 }
1292 _ => None,
1293 }
1294}
1295
1296impl TrainingPlan {
1301 pub fn to_json(&self) -> String {
1303 serde_json::to_string_pretty(self).unwrap_or_default()
1304 }
1305
1306 pub fn to_yaml(&self) -> String {
1308 serde_yaml::to_string(self).unwrap_or_default()
1309 }
1310
1311 #[allow(clippy::should_implement_trait)]
1313 pub fn from_str(s: &str) -> crate::Result<Self> {
1314 if let Ok(plan) = serde_json::from_str::<TrainingPlan>(s) {
1316 return Ok(plan);
1317 }
1318 serde_yaml::from_str::<TrainingPlan>(s).map_err(|e| {
1320 crate::Error::ConfigError(format!("Failed to parse plan as JSON or YAML: {e}"))
1321 })
1322 }
1323
1324 pub fn check_counts(&self) -> (usize, usize, usize) {
1326 let pass = self.pre_flight.iter().filter(|c| c.status == CheckStatus::Pass).count();
1327 let warn = self.pre_flight.iter().filter(|c| c.status == CheckStatus::Warn).count();
1328 let fail = self.pre_flight.iter().filter(|c| c.status == CheckStatus::Fail).count();
1329 (pass, warn, fail)
1330 }
1331}
1332
/// Maps a trial's outcome flags to the status label recorded for it.
///
/// Pruning wins over early stopping. A trial flagged as both is reported as
/// `"pruned"`; otherwise the result is `"stopped_early"` or `"completed"`.
pub(crate) fn resolve_trial_status(was_pruned: bool, stopped_early: bool) -> &'static str {
    match (was_pruned, stopped_early) {
        (true, _) => "pruned",
        (false, true) => "stopped_early",
        (false, false) => "completed",
    }
}
1343
/// Best-effort experiment tracker backed by the project's SQLite store.
///
/// When either field is `None`, the tracker's logging methods return without
/// recording anything.
pub(crate) struct ExperimentTracker {
    /// SQLite backend; `ExperimentTracker::open` leaves this `None` when the project
    /// store could not be opened.
    pub(crate) store: Option<crate::storage::SqliteBackend>,
    /// Experiment id; `ExperimentTracker::open` leaves this `None` when there is no
    /// store or experiment creation failed.
    pub(crate) exp_id: Option<String>,
}
1353
1354impl ExperimentTracker {
1355 pub(crate) fn open(output_dir: &std::path::Path, plan: &TrainingPlan) -> Self {
1356 use crate::storage::{ExperimentStorage, SqliteBackend};
1357
1358 let mut store = SqliteBackend::open_project(output_dir).ok();
1359 let exp_id = store.as_mut().and_then(|s| {
1360 let config_json = serde_json::json!({
1361 "model": &plan.model.architecture,
1362 "size": &plan.model.size,
1363 "strategy": &plan.hyperparameters.strategy,
1364 "budget": plan.hyperparameters.budget,
1365 "num_classes": plan.data.class_counts.len(),
1366 });
1367 s.create_experiment(&plan.model.architecture, Some(config_json)).ok()
1368 });
1369 Self { store, exp_id }
1370 }
1371
1372 fn log_manual_trial(
1373 &mut self,
1374 manual: &ManualConfig,
1375 result: &super::classify_trainer::TrainResult,
1376 ) {
1377 use crate::storage::{ExperimentStorage, ParameterValue as SPV};
1378 let (store, eid) = match (self.store.as_mut(), self.exp_id.as_ref()) {
1379 (Some(s), Some(e)) => (s, e),
1380 _ => return,
1381 };
1382 let run_id = match store.create_run(eid) {
1383 Ok(id) => id,
1384 Err(_) => return,
1385 };
1386 let _ = store.start_run(&run_id);
1387 let _ =
1388 store.log_param(&run_id, "learning_rate", SPV::Float(f64::from(manual.learning_rate)));
1389 let _ = store.log_param(&run_id, "lora_rank", SPV::Int(manual.lora_rank as i64));
1390 let _ = store.log_param(&run_id, "batch_size", SPV::Int(manual.batch_size as i64));
1391 Self::log_epoch_metrics(store, &run_id, &result.epoch_metrics);
1392 let _ = store.complete_run(&run_id, crate::storage::RunStatus::Success);
1393 }
1394
1395 fn log_hpo_trial(
1396 &mut self,
1397 config: &std::collections::HashMap<String, crate::optim::ParameterValue>,
1398 result: &super::classify_trainer::TrainResult,
1399 was_pruned: bool,
1400 ) {
1401 use crate::optim::ParameterValue as OPV;
1402 use crate::storage::{ExperimentStorage, ParameterValue as SPV};
1403 let (store, eid) = match (self.store.as_mut(), self.exp_id.as_ref()) {
1404 (Some(s), Some(e)) => (s, e),
1405 _ => return,
1406 };
1407 let run_id = match store.create_run(eid) {
1408 Ok(id) => id,
1409 Err(_) => return,
1410 };
1411 let _ = store.start_run(&run_id);
1412 for (k, v) in config {
1413 let sv = match v {
1414 OPV::Float(f) => SPV::Float(*f),
1415 OPV::Int(i) => SPV::Int(*i),
1416 OPV::Categorical(s) => SPV::String(s.clone()),
1417 };
1418 let _ = store.log_param(&run_id, k, sv);
1419 }
1420 Self::log_epoch_metrics(store, &run_id, &result.epoch_metrics);
1421 let status = if was_pruned {
1422 crate::storage::RunStatus::Cancelled
1423 } else {
1424 crate::storage::RunStatus::Success
1425 };
1426 let _ = store.complete_run(&run_id, status);
1427 }
1428
1429 pub(crate) fn log_failed_trial(&mut self) {
1430 use crate::storage::ExperimentStorage;
1431 let (store, eid) = match (self.store.as_mut(), self.exp_id.as_ref()) {
1432 (Some(s), Some(e)) => (s, e),
1433 _ => return,
1434 };
1435 if let Ok(run_id) = store.create_run(eid) {
1436 let _ = store.start_run(&run_id);
1437 let _ = store.complete_run(&run_id, crate::storage::RunStatus::Failed);
1438 }
1439 }
1440
1441 fn log_epoch_metrics(
1442 store: &mut crate::storage::SqliteBackend,
1443 run_id: &str,
1444 epochs: &[super::classify_trainer::EpochMetrics],
1445 ) {
1446 use crate::storage::ExperimentStorage;
1447 for (i, epoch) in epochs.iter().enumerate() {
1448 let _ = store.log_metric(run_id, "train_loss", i as u64, f64::from(epoch.train_loss));
1449 let _ = store.log_metric(run_id, "val_loss", i as u64, f64::from(epoch.val_loss));
1450 let _ =
1451 store.log_metric(run_id, "val_accuracy", i as u64, f64::from(epoch.val_accuracy));
1452 }
1453 }
1454}
1455
1456#[cfg(test)]
1457#[allow(clippy::unwrap_used)]
1458#[path = "training_plan_tests.rs"]
1459mod tests;