Skip to main content

entrenar/finetune/
training_plan.rs

1//! Training plan — forjar-style plan/apply for ML training.
2//!
3//! A `TrainingPlan` captures everything needed to execute a training run:
4//! data audit results, model configuration, hyperparameter strategy, resource
5//! estimates, and pre-flight check results. The plan is generated without
6//! touching the GPU, so validation is fast and cheap.
7//!
8//! # Architecture (mirrors forjar plan/apply)
9//!
10//! ```text
11//! PlanConfig → validate data → check model → build HPO → estimate cost → pre-flight → TrainingPlan
12//!                                                                                          │
13//!                                   TrainingPlan → apply (future) → checkpoint + lock
14//! ```
15
16use std::path::PathBuf;
17
18use serde::{Deserialize, Serialize};
19
20use super::classification::{corpus_stats, load_safety_corpus};
21use super::classify_tuner::{default_classify_search_space, extract_trial_params, TuneStrategy};
22
23// ═══════════════════════════════════════════════════════════════════════
24// Plan input configuration
25// ═══════════════════════════════════════════════════════════════════════
26
/// Input configuration for plan generation.
///
/// This is the user's intent — what they want to train. [`plan`] validates
/// it against reality (files on disk, class coverage, model directory) and
/// produces a `TrainingPlan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanConfig {
    /// Task type (currently only "classify"). Copied into the plan as-is;
    /// NOTE(review): not validated during planning.
    pub task: String,
    /// Path to training data (JSONL). Planning returns an error if it does not exist.
    pub data_path: PathBuf,
    /// Optional validation data (JSONL). If absent, split from training data.
    pub val_path: Option<PathBuf>,
    /// Optional test data (JSONL). Used for post-train eval.
    pub test_path: Option<PathBuf>,
    /// Model size hint (e.g. "0.5B", "9B"). Unrecognised hints are estimated
    /// with 0.5B-class dimensions (hidden size 896, 24 layers).
    pub model_size: String,
    /// Path to model weights directory (expected to hold safetensors weights
    /// and `tokenizer.json`).
    pub model_path: Option<PathBuf>,
    /// Number of output classes. A class with zero training samples blocks the plan.
    pub num_classes: usize,
    /// Output directory for checkpoints.
    pub output_dir: PathBuf,
    /// HPO strategy: "tpe", "grid", "random", or "manual". Any other value is
    /// parsed as a tuner strategy and falls back to TPE when parsing fails.
    pub strategy: String,
    /// HPO budget (number of trials). Ignored if strategy is "manual".
    pub budget: usize,
    /// Scout mode: each HPO trial runs a single epoch.
    pub scout: bool,
    /// Maximum epochs per trial.
    pub max_epochs: usize,
    /// Manual learning rate (used when strategy is "manual"; defaults to 1e-4).
    pub manual_lr: Option<f32>,
    /// Manual LoRA rank (defaults to 16; also drives the LoRA parameter estimate).
    pub manual_lora_rank: Option<usize>,
    /// Manual batch size (defaults to 32).
    pub manual_batch_size: Option<usize>,
    /// Manual LoRA alpha (passed through to `ManualConfig` unchanged).
    pub manual_lora_alpha: Option<f32>,
    /// Manual warmup fraction (passed through to `ManualConfig` unchanged).
    pub manual_warmup: Option<f32>,
    /// Manual gradient clip norm (passed through to `ManualConfig` unchanged).
    pub manual_gradient_clip: Option<f32>,
    /// Manual LR min ratio (passed through to `ManualConfig` unchanged).
    pub manual_lr_min_ratio: Option<f32>,
    /// Manual class weight strategy (passed through to `ManualConfig` unchanged).
    pub manual_class_weights: Option<String>,
    /// Manual target modules (passed through to `ManualConfig` unchanged).
    pub manual_target_modules: Option<String>,
}
76
77// ═══════════════════════════════════════════════════════════════════════
78// Training plan output
79// ═══════════════════════════════════════════════════════════════════════
80
/// Complete training plan — the serializable artifact that describes
/// exactly what a training run will do.
///
/// Analogous to forjar's `ExecutionPlan`. Produced by [`plan`] without
/// touching the GPU.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingPlan {
    /// Plan format version ("1.0" as produced by `plan`).
    pub version: String,
    /// Task type, copied from `PlanConfig::task`.
    pub task: String,

    // ── Data audit ─────────────────────────────────────────────────────
    /// Data audit summary.
    pub data: DataAudit,

    // ── Model ──────────────────────────────────────────────────────────
    /// Model configuration summary.
    pub model: ModelInfo,

    // ── Hyperparameters ────────────────────────────────────────────────
    /// Hyperparameter configuration (HPO search or manual values).
    pub hyperparameters: HyperparameterPlan,

    // ── Resource estimates ─────────────────────────────────────────────
    /// Estimated resource usage.
    pub resources: ResourceEstimate,

    // ── Pre-flight checks ──────────────────────────────────────────────
    /// Pre-flight check results, in the order they were run.
    pub pre_flight: Vec<PreFlightCheck>,

    // ── Output config ──────────────────────────────────────────────────
    /// Output directory (display form of `PlanConfig::output_dir`).
    pub output_dir: String,
    /// Whether to auto-diagnose after training (always `true` from `plan`).
    pub auto_diagnose: bool,

    // ── Plan-level verdict ─────────────────────────────────────────────
    /// Overall plan status, derived from `pre_flight` and `issues`.
    pub verdict: PlanVerdict,
    /// Issues found during planning.
    pub issues: Vec<PlanIssue>,
}
124
/// Data audit results, computed from the training corpus on the CPU.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataAudit {
    /// Path to training data (display form).
    pub train_path: String,
    /// Total training samples.
    pub train_samples: usize,
    /// Average input length in characters.
    pub avg_input_len: usize,
    /// Per-class sample counts, indexed by class id.
    pub class_counts: Vec<usize>,
    /// Imbalance ratio (max/min class count). The minimum is clamped to at
    /// least 1, so an empty class yields `max_count` rather than infinity.
    pub imbalance_ratio: f64,
    /// Whether class weighting will be auto-applied (`imbalance_ratio > 2.0`).
    pub auto_class_weights: bool,
    /// Validation samples; `None` if no file was given or it could not be loaded.
    pub val_samples: Option<usize>,
    /// Test samples; `None` if no file was given or it could not be loaded.
    pub test_samples: Option<usize>,
    /// Number of samples whose input exactly repeats an earlier sample's input.
    pub duplicates: usize,
    /// Number of samples whose input starts with a shell preamble
    /// (`#!/`, `#! /`, or `set -`).
    pub preamble_count: usize,
}
149
/// Model configuration summary, resolved from the size hint without loading weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model size label, copied from `PlanConfig::model_size`.
    pub size: String,
    /// Hidden dimension.
    pub hidden_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Architecture name ("unknown" when the size hint is not recognised).
    pub architecture: String,
    /// Whether `model_path` was given and is a directory.
    pub weights_available: bool,
    /// LoRA trainable parameters, estimated as
    /// `2 * rank * hidden_size * 2 * num_layers` (Q and V adapters per layer).
    pub lora_trainable_params: usize,
    /// Classifier head parameters: `hidden_size * num_classes + num_classes`.
    pub classifier_params: usize,
}
168
/// Hyperparameter plan: either an HPO search summary or a manual configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterPlan {
    /// Strategy: "tpe", "grid", "random", or "manual".
    pub strategy: String,
    /// Number of HPO trials (0 if manual).
    pub budget: usize,
    /// Scout mode (always `false` for manual plans).
    pub scout: bool,
    /// Maximum epochs per trial (forced to 1 for HPO in scout mode).
    pub max_epochs: usize,
    /// Search space parameter count (0 if manual).
    pub search_space_params: usize,
    /// Sample trial configurations (up to the first 3 from the searcher).
    pub sample_configs: Vec<TrialPreview>,
    /// Manual config (if strategy is "manual").
    pub manual: Option<ManualConfig>,
    /// Recommendation: should user switch strategy? (Set for manual plans.)
    pub recommendation: Option<String>,
}
189
/// Preview of a single HPO trial configuration, as suggested by the searcher.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialPreview {
    /// Trial index (1-based).
    pub trial: usize,
    /// Learning rate.
    pub learning_rate: f32,
    /// LoRA rank.
    pub lora_rank: usize,
    /// LoRA alpha.
    pub lora_alpha: f32,
    /// Batch size.
    pub batch_size: usize,
    /// Warmup fraction.
    pub warmup: f32,
    /// Gradient clip norm.
    pub gradient_clip: f32,
    /// Class weight strategy name.
    pub class_weights: String,
    /// Target modules.
    pub target_modules: String,
    /// LR min ratio.
    pub lr_min_ratio: f32,
}
214
/// Manual hyperparameter configuration.
///
/// `learning_rate`, `lora_rank` and `batch_size` are already resolved
/// (defaults applied) when the plan is built. The optional fields are passed
/// through from `PlanConfig` unchanged. `None` means "use the default", which
/// is applied by the consumer of this config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManualConfig {
    /// Learning rate.
    pub learning_rate: f32,
    /// LoRA rank.
    pub lora_rank: usize,
    /// Batch size.
    pub batch_size: usize,
    /// LoRA alpha (defaults to rank if absent).
    #[serde(default)]
    pub lora_alpha: Option<f32>,
    /// Warmup fraction (defaults to 0.1 if absent).
    #[serde(default)]
    pub warmup_fraction: Option<f32>,
    /// Gradient clip norm (defaults to 1.0 if absent).
    #[serde(default)]
    pub gradient_clip_norm: Option<f32>,
    /// LR min ratio for cosine decay (defaults to 0.01 if absent).
    #[serde(default)]
    pub lr_min_ratio: Option<f32>,
    /// Class weight strategy: "uniform", "inverse_freq", "sqrt_inverse".
    #[serde(default)]
    pub class_weights: Option<String>,
    /// Target modules: "qv", "qkv", "all_linear".
    #[serde(default)]
    pub target_modules: Option<String>,
}
243
/// Resource usage estimate (heuristic, based on fixed per-model-size figures).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceEstimate {
    /// Estimated VRAM usage in GB (fixed per model size; independent of batch size).
    pub estimated_vram_gb: f64,
    /// Estimated time per epoch in minutes.
    pub estimated_minutes_per_epoch: f64,
    /// Estimated total training time in minutes (per-epoch × epochs × trials).
    pub estimated_total_minutes: f64,
    /// Estimated checkpoint storage in MB: (LoRA + classifier params) × 4 bytes.
    pub estimated_checkpoint_mb: f64,
    /// Steps per epoch: `train_samples / batch_size`, rounded up.
    pub steps_per_epoch: usize,
    /// Detected GPU device name (if available).
    pub gpu_device: Option<String>,
}
260
/// A single pre-flight check result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreFlightCheck {
    /// Check name (machine-readable, e.g. "data_file", "output_dir").
    pub name: String,
    /// Check status.
    pub status: CheckStatus,
    /// Human-readable detail message.
    pub detail: String,
}
271
/// Status of a pre-flight check; also used as the severity of a `PlanIssue`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CheckStatus {
    /// Check passed.
    Pass,
    /// Warning (non-blocking).
    Warn,
    /// Failed (blocks apply).
    Fail,
}
282
/// Overall plan verdict.
///
/// Derived from all pre-flight checks and issues: any `Fail` gives `Blocked`;
/// otherwise any `Warn` gives `WarningsPresent`; otherwise the verdict is `Ready`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlanVerdict {
    /// All checks pass, ready to apply.
    Ready,
    /// Warnings present but can proceed.
    WarningsPresent,
    /// Failures detected, cannot apply.
    Blocked,
}
293
/// An issue found during planning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanIssue {
    /// Issue severity (same scale as pre-flight checks).
    pub severity: CheckStatus,
    /// Issue category (e.g. "Data", "Hyperparameters", "Output").
    pub category: String,
    /// Issue description.
    pub message: String,
    /// Suggested fix, if one is known.
    pub fix: Option<String>,
}
306
307// ═══════════════════════════════════════════════════════════════════════
308// Plan generation
309// ═══════════════════════════════════════════════════════════════════════
310
311/// Generate a training plan from a `PlanConfig`.
312///
313/// This is the pure validation phase — no GPU allocation, no weight loading,
314/// no training. Reads data files, validates schemas, estimates costs.
315pub fn plan(config: &PlanConfig) -> crate::Result<TrainingPlan> {
316    let mut issues: Vec<PlanIssue> = Vec::new();
317    let mut pre_flight: Vec<PreFlightCheck> = Vec::new();
318
319    // ── 1. Data audit ──────────────────────────────────────────────────
320
321    let data = audit_data(config, &mut issues, &mut pre_flight)?;
322
323    // ── 2. Model info ──────────────────────────────────────────────────
324
325    let model = resolve_model(config, &mut pre_flight);
326
327    // ── 3. Hyperparameter plan ─────────────────────────────────────────
328
329    let hyperparameters = build_hpo_plan(config, data.train_samples, &mut issues);
330
331    // ── 4. Resource estimation ─────────────────────────────────────────
332    //
333    // For HPO, use the median batch size from search space (64) to get a
334    // representative estimate. For manual, use the configured batch size.
335    let batch_size = hyperparameters.manual.as_ref().map_or(64, |m| m.batch_size);
336    let resources = estimate_resources(config, &model, &data, batch_size);
337
338    // ── 5. Additional pre-flight checks ────────────────────────────────
339
340    // Output directory
341    if config.output_dir.exists() {
342        let has_checkpoints = config.output_dir.join("metadata.json").exists()
343            || config.output_dir.join("epoch_001").exists();
344        if has_checkpoints {
345            pre_flight.push(PreFlightCheck {
346                name: "output_dir".to_string(),
347                status: CheckStatus::Warn,
348                detail: format!(
349                    "Output directory {} already contains checkpoints — may overwrite",
350                    config.output_dir.display()
351                ),
352            });
353            issues.push(PlanIssue {
354                severity: CheckStatus::Warn,
355                category: "Output".to_string(),
356                message: "Checkpoint directory already contains previous run".to_string(),
357                fix: Some("Use a fresh output directory or rename existing one".to_string()),
358            });
359        } else {
360            pre_flight.push(PreFlightCheck {
361                name: "output_dir".to_string(),
362                status: CheckStatus::Pass,
363                detail: format!("Output directory {} exists", config.output_dir.display()),
364            });
365        }
366    } else {
367        pre_flight.push(PreFlightCheck {
368            name: "output_dir".to_string(),
369            status: CheckStatus::Pass,
370            detail: format!("Output directory {} will be created", config.output_dir.display()),
371        });
372    }
373
374    // Class weights persistence check
375    pre_flight.push(PreFlightCheck {
376        name: "class_weights_persist".to_string(),
377        status: CheckStatus::Pass,
378        detail: "class_weights saved in checkpoint metadata (entrenar ≥0.7.5)".to_string(),
379    });
380
381    // ── 6. Verdict ─────────────────────────────────────────────────────
382
383    let has_fail = pre_flight.iter().any(|c| c.status == CheckStatus::Fail)
384        || issues.iter().any(|i| i.severity == CheckStatus::Fail);
385    let has_warn = pre_flight.iter().any(|c| c.status == CheckStatus::Warn)
386        || issues.iter().any(|i| i.severity == CheckStatus::Warn);
387
388    let verdict = if has_fail {
389        PlanVerdict::Blocked
390    } else if has_warn {
391        PlanVerdict::WarningsPresent
392    } else {
393        PlanVerdict::Ready
394    };
395
396    Ok(TrainingPlan {
397        version: "1.0".to_string(),
398        task: config.task.clone(),
399        data,
400        model,
401        hyperparameters,
402        resources,
403        pre_flight,
404        output_dir: config.output_dir.display().to_string(),
405        auto_diagnose: true,
406        verdict,
407        issues,
408    })
409}
410
411// ═══════════════════════════════════════════════════════════════════════
412// Internal plan builders
413// ═══════════════════════════════════════════════════════════════════════
414
415/// Audit training data without loading into GPU.
416fn audit_data(
417    config: &PlanConfig,
418    issues: &mut Vec<PlanIssue>,
419    pre_flight: &mut Vec<PreFlightCheck>,
420) -> crate::Result<DataAudit> {
421    // Validate data file exists
422    if !config.data_path.exists() {
423        pre_flight.push(PreFlightCheck {
424            name: "data_file".to_string(),
425            status: CheckStatus::Fail,
426            detail: format!("Training data not found: {}", config.data_path.display()),
427        });
428        return Err(crate::Error::Io(format!(
429            "Training data not found: {}",
430            config.data_path.display()
431        )));
432    }
433
434    // Load and validate corpus
435    let corpus = load_safety_corpus(&config.data_path, config.num_classes)?;
436    let stats = corpus_stats(&corpus, config.num_classes);
437
438    pre_flight.push(PreFlightCheck {
439        name: "data_file".to_string(),
440        status: CheckStatus::Pass,
441        detail: format!("{} samples loaded from {}", stats.total, config.data_path.display()),
442    });
443
444    // Check for empty classes
445    let empty_classes: Vec<usize> =
446        stats.class_counts.iter().enumerate().filter(|(_, &c)| c == 0).map(|(i, _)| i).collect();
447    if empty_classes.is_empty() {
448        pre_flight.push(PreFlightCheck {
449            name: "class_coverage".to_string(),
450            status: CheckStatus::Pass,
451            detail: format!("All {} classes have samples", config.num_classes),
452        });
453    } else {
454        pre_flight.push(PreFlightCheck {
455            name: "class_coverage".to_string(),
456            status: CheckStatus::Fail,
457            detail: format!("Classes with zero samples: {empty_classes:?}"),
458        });
459        issues.push(PlanIssue {
460            severity: CheckStatus::Fail,
461            category: "Data".to_string(),
462            message: format!("Classes {empty_classes:?} have zero training samples"),
463            fix: Some("Add samples for missing classes or reduce num_classes".to_string()),
464        });
465    }
466
467    // Imbalance analysis
468    let min_count = stats.class_counts.iter().copied().min().unwrap_or(1).max(1);
469    let max_count = stats.class_counts.iter().copied().max().unwrap_or(1);
470    let imbalance_ratio = max_count as f64 / min_count as f64;
471    let auto_class_weights = imbalance_ratio > 2.0;
472
473    if imbalance_ratio > 5.0 {
474        issues.push(PlanIssue {
475            severity: CheckStatus::Warn,
476            category: "Data".to_string(),
477            message: format!(
478                "Severe class imbalance ({imbalance_ratio:.1}:1) — sqrt-inverse weights will be auto-applied"
479            ),
480            fix: Some("Consider oversampling minority classes: apr data balance --strategy oversample".to_string()),
481        });
482    }
483
484    // Duplicate detection (fast: hash inputs)
485    let mut seen = std::collections::HashSet::new();
486    let mut duplicates = 0usize;
487    for s in &corpus {
488        if !seen.insert(&s.input) {
489            duplicates += 1;
490        }
491    }
492    if duplicates > 0 {
493        issues.push(PlanIssue {
494            severity: CheckStatus::Warn,
495            category: "Data".to_string(),
496            message: format!(
497                "{duplicates} duplicate inputs detected ({:.1}%)",
498                duplicates as f64 / stats.total as f64 * 100.0
499            ),
500            fix: Some("Remove duplicates: apr data dedup".to_string()),
501        });
502    }
503
504    // Preamble detection
505    let preamble_count = corpus
506        .iter()
507        .filter(|s| {
508            s.input.starts_with("#!/")
509                || s.input.starts_with("#! /")
510                || s.input.starts_with("set -")
511        })
512        .count();
513    if preamble_count > stats.total / 10 {
514        issues.push(PlanIssue {
515            severity: CheckStatus::Warn,
516            category: "Data".to_string(),
517            message: format!(
518                "{preamble_count} samples ({:.0}%) have shell preamble",
519                preamble_count as f64 / stats.total as f64 * 100.0
520            ),
521            fix: Some("Strip preambles: use --strip-preamble during data export".to_string()),
522        });
523    }
524
525    // Minimum sample count
526    if stats.total < 100 {
527        issues.push(PlanIssue {
528            severity: CheckStatus::Warn,
529            category: "Data".to_string(),
530            message: format!("Only {} samples — may be insufficient for fine-tuning", stats.total),
531            fix: None,
532        });
533    }
534
535    // Validate val/test if provided
536    let val_samples = count_file_samples(config.val_path.as_ref(), config.num_classes);
537    let test_samples = count_file_samples(config.test_path.as_ref(), config.num_classes);
538
539    Ok(DataAudit {
540        train_path: config.data_path.display().to_string(),
541        train_samples: stats.total,
542        avg_input_len: stats.avg_input_len,
543        class_counts: stats.class_counts,
544        imbalance_ratio,
545        auto_class_weights,
546        val_samples,
547        test_samples,
548        duplicates,
549        preamble_count,
550    })
551}
552
553/// Count samples in an optional JSONL file.
554pub(crate) fn count_file_samples(path: Option<&PathBuf>, num_classes: usize) -> Option<usize> {
555    path.and_then(|p| {
556        if p.exists() {
557            load_safety_corpus(p, num_classes).ok().map(|c| c.len())
558        } else {
559            None
560        }
561    })
562}
563
564/// Resolve model architecture from size hint.
565pub(crate) fn resolve_model(
566    config: &PlanConfig,
567    pre_flight: &mut Vec<PreFlightCheck>,
568) -> ModelInfo {
569    let (hidden_size, num_layers, architecture) = match config.model_size.as_str() {
570        "0.5B" | "500M" | "qwen2-0.5b" => (896, 24, "qwen2"),
571        "9B" | "qwen3.5-9b" => (4096, 48, "qwen3.5"),
572        "7B" | "llama2-7b" => (4096, 32, "llama2"),
573        "13B" | "llama2-13b" => (5120, 40, "llama2"),
574        _ => (896, 24, "unknown"),
575    };
576
577    // Check if model weights are available
578    let weights_available = config.model_path.as_ref().is_some_and(|p| p.is_dir());
579    if let Some(ref path) = config.model_path {
580        if weights_available {
581            // Check for key files
582            let has_safetensors = path.join("model.safetensors").exists()
583                || path.join("model-00001-of-00002.safetensors").exists();
584            let has_tokenizer = path.join("tokenizer.json").exists();
585
586            if has_safetensors && has_tokenizer {
587                pre_flight.push(PreFlightCheck {
588                    name: "model_weights".to_string(),
589                    status: CheckStatus::Pass,
590                    detail: format!("Model weights found at {}", path.display()),
591                });
592            } else {
593                let mut missing = Vec::new();
594                if !has_safetensors {
595                    missing.push("model.safetensors");
596                }
597                if !has_tokenizer {
598                    missing.push("tokenizer.json");
599                }
600                pre_flight.push(PreFlightCheck {
601                    name: "model_weights".to_string(),
602                    status: CheckStatus::Warn,
603                    detail: format!("Model directory exists but missing: {}", missing.join(", ")),
604                });
605            }
606        } else {
607            pre_flight.push(PreFlightCheck {
608                name: "model_weights".to_string(),
609                status: CheckStatus::Fail,
610                detail: format!("Model path not found: {}", path.display()),
611            });
612        }
613    } else {
614        pre_flight.push(PreFlightCheck {
615            name: "model_weights".to_string(),
616            status: CheckStatus::Warn,
617            detail: "No model path specified — will use default model resolution".to_string(),
618        });
619    }
620
621    // Estimate trainable parameters
622    // LoRA: 2 * rank * hidden_size * num_adapters (Q,V per layer = 2 * num_layers)
623    let default_rank = config.manual_lora_rank.unwrap_or(16);
624    let lora_trainable_params = 2 * default_rank * hidden_size * 2 * num_layers;
625    let classifier_params = hidden_size * config.num_classes + config.num_classes;
626
627    ModelInfo {
628        size: config.model_size.clone(),
629        hidden_size,
630        num_layers,
631        architecture: architecture.to_string(),
632        weights_available,
633        lora_trainable_params,
634        classifier_params,
635    }
636}
637
638/// Build HPO plan with search space and sample configs.
639pub(crate) fn build_hpo_plan(
640    config: &PlanConfig,
641    train_samples: usize,
642    issues: &mut Vec<PlanIssue>,
643) -> HyperparameterPlan {
644    let strategy = config.strategy.as_str();
645
646    if strategy == "manual" {
647        let lr = config.manual_lr.unwrap_or(1e-4);
648        let rank = config.manual_lora_rank.unwrap_or(16);
649        let batch = config.manual_batch_size.unwrap_or(32);
650
651        // Warn about manual mode when HPO is available
652        issues.push(PlanIssue {
653            severity: CheckStatus::Warn,
654            category: "Hyperparameters".to_string(),
655            message: "Using manual hyperparameters — HPO (--strategy tpe) searches 9 parameters automatically".to_string(),
656            fix: Some(format!(
657                "apr train plan --strategy tpe --budget 20 --scout --data {}",
658                config.data_path.display()
659            )),
660        });
661
662        return HyperparameterPlan {
663            strategy: "manual".to_string(),
664            budget: 0,
665            scout: false,
666            max_epochs: config.max_epochs,
667            search_space_params: 0,
668            sample_configs: Vec::new(),
669            manual: Some(ManualConfig {
670                learning_rate: lr,
671                lora_rank: rank,
672                batch_size: batch,
673                lora_alpha: config.manual_lora_alpha,
674                warmup_fraction: config.manual_warmup,
675                gradient_clip_norm: config.manual_gradient_clip,
676                lr_min_ratio: config.manual_lr_min_ratio,
677                class_weights: config.manual_class_weights.clone(),
678                target_modules: config.manual_target_modules.clone(),
679            }),
680            recommendation: Some(
681                "Consider using --strategy tpe for automated hyperparameter search".to_string(),
682            ),
683        };
684    }
685
686    // Parse strategy for searcher
687    let tune_strategy: TuneStrategy = strategy.parse().unwrap_or(TuneStrategy::Tpe);
688
689    // Build searcher and sample configs
690    let space = default_classify_search_space();
691    let mut searcher: Box<dyn super::classify_tuner::TuneSearcher> = match tune_strategy {
692        TuneStrategy::Tpe => {
693            let n_startup = (config.budget / 3).max(3);
694            Box::new(super::tune_searchers::TpeSearcher::new(space.clone(), n_startup))
695        }
696        TuneStrategy::Grid => Box::new(super::tune_searchers::GridSearcher::new(space.clone(), 3)),
697        TuneStrategy::Random => Box::new(super::tune_searchers::RandomSearcher::new(space.clone())),
698    };
699
700    let num_previews = config.budget.min(3);
701    let mut sample_configs = Vec::new();
702    for i in 0..num_previews {
703        if let Ok(trial) = searcher.suggest() {
704            let (lr, rank, alpha, batch, warmup, clip, weights, targets, lr_min) =
705                extract_trial_params(&trial.config);
706            sample_configs.push(TrialPreview {
707                trial: i + 1,
708                learning_rate: lr,
709                lora_rank: rank,
710                lora_alpha: alpha,
711                batch_size: batch,
712                warmup,
713                gradient_clip: clip,
714                class_weights: weights,
715                target_modules: targets,
716                lr_min_ratio: lr_min,
717            });
718        }
719    }
720
721    // Budget sanity check
722    if config.budget < 5 && tune_strategy == TuneStrategy::Tpe {
723        issues.push(PlanIssue {
724            severity: CheckStatus::Warn,
725            category: "Hyperparameters".to_string(),
726            message: format!(
727                "TPE budget {} is low — needs ≥5 trials for Bayesian optimization to converge",
728                config.budget
729            ),
730            fix: Some("Use --budget 20 for better results".to_string()),
731        });
732    }
733
734    // Scout recommendation for large datasets
735    if !config.scout && train_samples > 10_000 && config.max_epochs > 1 {
736        issues.push(PlanIssue {
737            severity: CheckStatus::Warn,
738            category: "Hyperparameters".to_string(),
739            message: format!(
740                "Full HPO with {} samples × {} epochs × {} trials = ~{:.0} GPU hours",
741                train_samples,
742                config.max_epochs,
743                config.budget,
744                estimate_gpu_hours(train_samples, config.max_epochs, config.budget)
745            ),
746            fix: Some(
747                "Use --scout for 1-epoch trials first, then --from-scout for full run".to_string(),
748            ),
749        });
750    }
751
752    HyperparameterPlan {
753        strategy: strategy.to_string(),
754        budget: config.budget,
755        scout: config.scout,
756        max_epochs: if config.scout { 1 } else { config.max_epochs },
757        search_space_params: 9,
758        sample_configs,
759        manual: None,
760        recommendation: None,
761    }
762}
763
/// Estimate GPU hours for a full HPO run.
///
/// Assumes a fixed batch size of 64 (the search-space median) and ~58 seconds
/// per optimizer step (observed RTX 4090 throughput). Each epoch is therefore
/// `ceil(train_samples / 64)` steps. The total is scaled by epochs and trials.
///
/// NOTE(review): `estimate_resources` attributes the 58 s/step figure to
/// batch_size=40, so this may underestimate at batch 64 — confirm.
pub(crate) fn estimate_gpu_hours(train_samples: usize, max_epochs: usize, budget: usize) -> f64 {
    const BATCH_SIZE: usize = 64;
    const SECONDS_PER_STEP: f64 = 58.0;
    const SECONDS_PER_HOUR: f64 = 3600.0;

    let epoch_seconds = train_samples.div_ceil(BATCH_SIZE) as f64 * SECONDS_PER_STEP;
    epoch_seconds * max_epochs as f64 * budget as f64 / SECONDS_PER_HOUR
}
773
774/// Estimate resource usage.
775pub(crate) fn estimate_resources(
776    config: &PlanConfig,
777    model: &ModelInfo,
778    data: &DataAudit,
779    batch_size: usize,
780) -> ResourceEstimate {
781    // VRAM estimate: model weights + optimizer state + activations
782    // Qwen2 0.5B: ~1.0 GB weights, ~0.8 GB optimizer, ~0.5 GB activations
783    let base_vram = match model.hidden_size {
784        896 => 2.5,   // 0.5B
785        4096 => 18.0, // 7B/9B
786        5120 => 26.0, // 13B
787        _ => 3.0,
788    };
789
790    let steps_per_epoch = data.train_samples.div_ceil(batch_size);
791
792    // Time estimate based on observed GPU training throughput
793    // RTX 4090 measured: ~58 seconds/step for 0.5B at batch_size=40
794    // (measured over 245 steps in v3 training run)
795    // For larger models, scale by layer count ratio
796    let seconds_per_step = match model.hidden_size {
797        896 => 58.0,   // 0.5B: observed on RTX 4090
798        4096 => 270.0, // 7B/9B: estimated ~4.7x slower
799        5120 => 450.0, // 13B: estimated ~7.8x slower
800        _ => 90.0,
801    };
802    let minutes_per_epoch = (steps_per_epoch as f64 * seconds_per_step) / 60.0;
803
804    let total_epochs = if config.scout { 1 } else { config.max_epochs };
805    let total_trials = if config.strategy == "manual" { 1 } else { config.budget };
806    let total_minutes = minutes_per_epoch * total_epochs as f64 * total_trials as f64;
807
808    // Checkpoint size: LoRA adapters + classifier head
809    let checkpoint_mb =
810        (model.lora_trainable_params + model.classifier_params) as f64 * 4.0 / 1_048_576.0;
811
812    // Try to detect GPU
813    let gpu_device = detect_gpu_device();
814
815    ResourceEstimate {
816        estimated_vram_gb: base_vram,
817        estimated_minutes_per_epoch: minutes_per_epoch,
818        estimated_total_minutes: total_minutes,
819        estimated_checkpoint_mb: checkpoint_mb,
820        steps_per_epoch,
821        gpu_device,
822    }
823}
824
/// Detect the GPU device name (best-effort, no NVML required).
///
/// Resolution order:
/// 1. The `Model:` line from `/proc/driver/nvidia/gpus/*/information`
///    (NVIDIA kernel driver on Linux). Directories are checked in sorted
///    order, so multi-GPU hosts report the same device on every call.
/// 2. If `CUDA_VISIBLE_DEVICES` is set and does not hide every device, a
///    generic `"CUDA device (unknown model)"` placeholder.
///
/// Returns `None` when neither source indicates a usable GPU.
pub(crate) fn detect_gpu_device() -> Option<String> {
    // The NVIDIA driver exposes per-GPU info under procfs (not sysfs).
    if let Ok(entries) = std::fs::read_dir("/proc/driver/nvidia/gpus") {
        // `read_dir` order is platform-dependent; sort for a stable answer.
        let mut gpu_dirs: Vec<std::path::PathBuf> =
            entries.flatten().map(|entry| entry.path()).collect();
        gpu_dirs.sort();

        let model_name = gpu_dirs.iter().find_map(|dir| {
            let info = std::fs::read_to_string(dir.join("information")).ok()?;
            info.lines()
                .find_map(|line| line.strip_prefix("Model:"))
                .map(|name| name.trim().to_string())
        });
        if model_name.is_some() {
            return model_name;
        }
    }

    // Fallback: an explicit CUDA_VISIBLE_DEVICES suggests a CUDA host, unless
    // its value hides every device.
    match std::env::var("CUDA_VISIBLE_DEVICES") {
        Ok(value) if cuda_visible_devices_exposes_gpu(&value) => {
            Some("CUDA device (unknown model)".to_string())
        }
        _ => None,
    }
}

/// Whether a `CUDA_VISIBLE_DEVICES` value leaves at least one device visible.
///
/// CUDA only exposes the devices listed before the first invalid entry.
/// An empty value, a negative first entry (conventionally `-1`), or the
/// scheduler placeholder `NoDevFiles` therefore hides every GPU.
fn cuda_visible_devices_exposes_gpu(value: &str) -> bool {
    let first = value.split(',').next().unwrap_or("").trim();
    !(first.is_empty() || first.starts_with('-') || first.eq_ignore_ascii_case("NoDevFiles"))
}
846
847// ═══════════════════════════════════════════════════════════════════════
848// Plan execution (apply phase)
849// ═══════════════════════════════════════════════════════════════════════
850
851/// Runtime configuration for plan execution.
852///
853/// Supplements the `TrainingPlan` with execution-time parameters that
854/// cannot be determined at plan time (e.g. actual model path resolution).
855#[derive(Debug, Clone)]
856pub struct ApplyConfig {
857    /// Path to model weights directory.
858    pub model_path: PathBuf,
859    /// Path to training data (JSONL).
860    pub data_path: PathBuf,
861    /// Output directory for checkpoints and leaderboard.
862    pub output_dir: PathBuf,
863    /// Callback invoked after each trial completes.
864    /// Arguments: (trial_id, total_budget, summary).
865    pub on_trial_complete: Option<fn(usize, usize, &super::classify_tuner::TrialSummary)>,
866}
867
868/// Execute a training plan (the "apply" phase).
869///
870/// For HPO plans, iterates over trials:
871/// 1. `ClassifyTuner.suggest()` → hyperparameters
872/// 2. Build `ClassifyConfig` → `ClassifyPipeline::from_pretrained()`
873/// 3. Load corpus → `ClassifyTrainer::new()` → `trainer.train()`
874/// 4. Record trial results → check scheduler
875/// 5. Return `TuneResult` with leaderboard
876///
877/// For manual plans, executes a single trial with the specified params.
878///
879/// # Errors
880///
881/// Returns error if model weights cannot be loaded, data is invalid,
882/// or all trials fail.
883pub fn execute_plan(
884    plan: &TrainingPlan,
885    apply: &ApplyConfig,
886) -> crate::Result<super::classify_tuner::TuneResult> {
887    use super::classify_pipeline::ClassifyConfig;
888    use super::classify_tuner::{
889        ClassifyTuner, SchedulerKind, TrialSummary, TuneConfig, TuneStrategy,
890    };
891    use crate::optim::ParameterValue;
892    use crate::transformer::TransformerConfig;
893    use std::collections::HashMap;
894
895    // ── Verify pre-conditions ──────────────────────────────────────────
896    if plan.verdict == PlanVerdict::Blocked {
897        return Err(crate::Error::ConfigError(
898            "Cannot apply a blocked plan — resolve all failures first".to_string(),
899        ));
900    }
901
902    if !apply.model_path.is_dir() {
903        return Err(crate::Error::ConfigError(format!(
904            "Model path does not exist: {}",
905            apply.model_path.display()
906        )));
907    }
908
909    if !apply.data_path.exists() {
910        return Err(crate::Error::Io(format!(
911            "Training data not found: {}",
912            apply.data_path.display()
913        )));
914    }
915
916    // Create output directory
917    std::fs::create_dir_all(&apply.output_dir).map_err(|e| {
918        crate::Error::Io(format!(
919            "Failed to create output directory {}: {e}",
920            apply.output_dir.display()
921        ))
922    })?;
923
924    // ── Open project-local experiment store ────────────────────────────
925    let mut tracker = ExperimentTracker::open(&apply.output_dir, plan);
926
927    // GH-377: Resolve model config — error on unknown instead of silent tiny()
928    let model_config =
929        TransformerConfig::from_size_str(&plan.model.size).map_err(crate::Error::ConfigError)?;
930
931    let total_start = std::time::Instant::now();
932
933    // Auto-enable NF4 quantization for large models.
934    // Full fp32 weights for hidden_size >= 2048 (roughly >= 1B params)
935    // exceed RTX 4090 VRAM (24 GB) after scratch + kernel cache overhead.
936    let auto_nf4 = model_config.hidden_size >= 2048;
937    if auto_nf4 {
938        eprintln!(
939            "[plan] Auto-enabling NF4 quantization (hidden_size={} >= 2048)",
940            model_config.hidden_size
941        );
942    }
943
944    // ── Manual strategy: single trial ──────────────────────────────────
945    if plan.hyperparameters.strategy == "manual" {
946        let manual = plan.hyperparameters.manual.as_ref().ok_or_else(|| {
947            crate::Error::ConfigError(
948                "Manual strategy requires manual hyperparameters in plan".to_string(),
949            )
950        })?;
951
952        let num_classes = plan.data.class_counts.len();
953        let lora_alpha = manual.lora_alpha.unwrap_or(manual.lora_rank as f32);
954        let gradient_clip = manual.gradient_clip_norm.unwrap_or(1.0);
955        let warmup = manual.warmup_fraction.unwrap_or(0.1);
956        let lr_min_ratio = manual.lr_min_ratio.unwrap_or(0.01);
957
958        let class_weights = manual
959            .class_weights
960            .as_deref()
961            .and_then(|s| resolve_class_weights(s, &plan.data.class_counts, num_classes));
962
963        let classify_config = ClassifyConfig {
964            num_classes,
965            lora_rank: manual.lora_rank,
966            lora_alpha,
967            learning_rate: manual.learning_rate,
968            epochs: plan.hyperparameters.max_epochs,
969            batch_size: manual.batch_size,
970            gradient_clip_norm: Some(gradient_clip),
971            class_weights,
972            quantize_nf4: auto_nf4,
973            ..ClassifyConfig::default()
974        };
975
976        let trial_start = std::time::Instant::now();
977        let result = run_single_trial_with_warmup(
978            &apply.model_path,
979            &apply.data_path,
980            &apply.output_dir.join("trial_001"),
981            &model_config,
982            classify_config,
983            plan.hyperparameters.max_epochs,
984            warmup,
985            lr_min_ratio,
986            &plan.model.size,
987        )?;
988
989        let mut config_map = HashMap::new();
990        config_map.insert(
991            "learning_rate".to_string(),
992            ParameterValue::Float(f64::from(manual.learning_rate)),
993        );
994        config_map.insert("lora_rank".to_string(), ParameterValue::Int(manual.lora_rank as i64));
995        config_map.insert(
996            "batch_size".to_string(),
997            ParameterValue::Categorical(manual.batch_size.to_string()),
998        );
999
1000        let summary = TrialSummary {
1001            id: 0,
1002            val_loss: f64::from(result.best_val_loss),
1003            val_accuracy: result
1004                .epoch_metrics
1005                .get(result.best_epoch)
1006                .map_or(0.0, |m| f64::from(m.val_accuracy)),
1007            train_loss: result.epoch_metrics.last().map_or(0.0, |m| f64::from(m.train_loss)),
1008            train_accuracy: result
1009                .epoch_metrics
1010                .last()
1011                .map_or(0.0, |m| f64::from(m.train_accuracy)),
1012            epochs_run: result.epoch_metrics.len(),
1013            time_ms: trial_start.elapsed().as_millis() as u64,
1014            config: config_map,
1015            status: if result.stopped_early {
1016                "stopped_early".to_string()
1017            } else {
1018                "completed".to_string()
1019            },
1020        };
1021
1022        tracker.log_manual_trial(manual, &result);
1023
1024        if let Some(cb) = apply.on_trial_complete {
1025            cb(0, 1, &summary);
1026        }
1027
1028        return Ok(super::classify_tuner::TuneResult {
1029            strategy: "manual".to_string(),
1030            mode: "manual".to_string(),
1031            budget: 1,
1032            trials: vec![summary],
1033            best_trial_id: 0,
1034            total_time_ms: total_start.elapsed().as_millis() as u64,
1035        });
1036    }
1037
1038    // ── HPO strategy: multiple trials ──────────────────────────────────
1039    let strategy: TuneStrategy = plan.hyperparameters.strategy.parse().unwrap_or(TuneStrategy::Tpe);
1040
1041    let num_classes = plan.data.class_counts.len();
1042
1043    let tune_config = TuneConfig {
1044        budget: plan.hyperparameters.budget,
1045        strategy,
1046        scheduler: SchedulerKind::Asha,
1047        scout: plan.hyperparameters.scout,
1048        max_epochs: plan.hyperparameters.max_epochs,
1049        num_classes,
1050        seed: 42,
1051        time_limit_secs: None,
1052    };
1053
1054    let mut tuner = ClassifyTuner::new(tune_config)?;
1055    let mut searcher = tuner.build_searcher();
1056    let scheduler = tuner.build_scheduler();
1057
1058    let budget = plan.hyperparameters.budget;
1059
1060    // Save plan as YAML in output dir for reproducibility
1061    let plan_path = apply.output_dir.join("plan.yaml");
1062    let _ = std::fs::write(&plan_path, plan.to_yaml());
1063
1064    for trial_idx in 0..budget {
1065        // ── Suggest hyperparameters ────────────────────────────────────
1066        let suggestion = match searcher.suggest() {
1067            Ok(s) => s,
1068            Err(e) => {
1069                eprintln!("  Trial {}: searcher exhausted ({e}), stopping", trial_idx + 1);
1070                break;
1071            }
1072        };
1073
1074        let (lr, rank, alpha, batch_size, warmup, clip, weights_strategy, _targets, lr_min_ratio) =
1075            super::classify_tuner::extract_trial_params(&suggestion.config);
1076
1077        // ── Build ClassifyConfig from trial params ─────────────────────
1078        let class_weights =
1079            resolve_class_weights(&weights_strategy, &plan.data.class_counts, num_classes);
1080
1081        let epochs = if plan.hyperparameters.scout { 1 } else { plan.hyperparameters.max_epochs };
1082
1083        let classify_config = ClassifyConfig {
1084            num_classes,
1085            lora_rank: rank,
1086            lora_alpha: alpha,
1087            learning_rate: lr,
1088            epochs,
1089            batch_size,
1090            gradient_clip_norm: Some(clip),
1091            class_weights,
1092            quantize_nf4: auto_nf4,
1093            ..ClassifyConfig::default()
1094        };
1095
1096        let trial_dir = apply.output_dir.join(format!("trial_{:03}", trial_idx + 1));
1097        let trial_start = std::time::Instant::now();
1098
1099        eprintln!(
1100            "  Trial {}/{}: lr={:.2e} rank={} alpha={:.1} batch={} warmup={:.2} clip={:.1} weights={}",
1101            trial_idx + 1, budget, lr, rank, alpha, batch_size, warmup, clip, weights_strategy
1102        );
1103
1104        // ── Execute trial ──────────────────────────────────────────────
1105        let trial_result = run_single_trial_with_warmup(
1106            &apply.model_path,
1107            &apply.data_path,
1108            &trial_dir,
1109            &model_config,
1110            classify_config,
1111            epochs,
1112            warmup,
1113            lr_min_ratio,
1114            &plan.model.size,
1115        );
1116
1117        let trial_time_ms = trial_start.elapsed().as_millis() as u64;
1118
1119        match trial_result {
1120            Ok(result) => {
1121                let val_loss = f64::from(result.best_val_loss);
1122                let val_accuracy = result
1123                    .epoch_metrics
1124                    .get(result.best_epoch)
1125                    .map_or(0.0, |m| f64::from(m.val_accuracy));
1126
1127                // ── Check scheduler for early stopping ─────────────────
1128                let was_pruned = scheduler.should_stop(trial_idx, result.best_epoch, val_loss);
1129
1130                let status = resolve_trial_status(was_pruned, result.stopped_early);
1131
1132                let summary = TrialSummary {
1133                    id: trial_idx,
1134                    val_loss,
1135                    val_accuracy,
1136                    train_loss: result
1137                        .epoch_metrics
1138                        .last()
1139                        .map_or(0.0, |m| f64::from(m.train_loss)),
1140                    train_accuracy: result
1141                        .epoch_metrics
1142                        .last()
1143                        .map_or(0.0, |m| f64::from(m.train_accuracy)),
1144                    epochs_run: result.epoch_metrics.len(),
1145                    time_ms: trial_time_ms,
1146                    config: suggestion.config.clone(),
1147                    status: status.to_string(),
1148                };
1149
1150                eprintln!(
1151                    "    => val_loss={:.4} val_acc={:.1}% epochs={} [{}]",
1152                    val_loss,
1153                    val_accuracy * 100.0,
1154                    result.epoch_metrics.len(),
1155                    status,
1156                );
1157
1158                tracker.log_hpo_trial(&suggestion.config, &result, was_pruned);
1159
1160                // Record for Bayesian learner
1161                searcher.record(suggestion.clone(), val_loss, result.epoch_metrics.len());
1162                tuner.record_trial(summary.clone());
1163
1164                if let Some(cb) = apply.on_trial_complete {
1165                    cb(trial_idx, budget, &summary);
1166                }
1167            }
1168            Err(e) => {
1169                eprintln!("    => FAILED: {e}");
1170                tracker.log_failed_trial();
1171
1172                let summary = TrialSummary {
1173                    id: trial_idx,
1174                    val_loss: f64::INFINITY,
1175                    val_accuracy: 0.0,
1176                    train_loss: f64::INFINITY,
1177                    train_accuracy: 0.0,
1178                    epochs_run: 0,
1179                    time_ms: trial_time_ms,
1180                    config: suggestion.config.clone(),
1181                    status: "failed".to_string(),
1182                };
1183                searcher.record(suggestion, f64::INFINITY, 0);
1184                tuner.record_trial(summary);
1185            }
1186        }
1187    }
1188
1189    let total_time_ms = total_start.elapsed().as_millis() as u64;
1190
1191    // Save leaderboard
1192    let result = tuner.into_result(total_time_ms);
1193    let leaderboard_path = apply.output_dir.join("leaderboard.json");
1194    let _ = std::fs::write(
1195        &leaderboard_path,
1196        serde_json::to_string_pretty(&result).unwrap_or_default(),
1197    );
1198
1199    Ok(result)
1200}
1201
1202/// Execute a single training trial with explicit warmup/LR min parameters.
1203fn run_single_trial_with_warmup(
1204    model_path: &std::path::Path,
1205    data_path: &std::path::Path,
1206    checkpoint_dir: &std::path::Path,
1207    model_config: &crate::transformer::TransformerConfig,
1208    classify_config: super::classify_pipeline::ClassifyConfig,
1209    epochs: usize,
1210    warmup_fraction: f32,
1211    lr_min_ratio: f32,
1212    model_name: &str,
1213) -> crate::Result<super::classify_trainer::TrainResult> {
1214    use super::classify_pipeline::ClassifyPipeline;
1215    use super::classify_trainer::{ClassifyTrainer, TrainingConfig};
1216
1217    // Create checkpoint directory
1218    std::fs::create_dir_all(checkpoint_dir).map_err(|e| {
1219        crate::Error::Io(format!(
1220            "Failed to create checkpoint dir {}: {e}",
1221            checkpoint_dir.display()
1222        ))
1223    })?;
1224
1225    // Load pipeline with pretrained weights
1226    let pipeline = ClassifyPipeline::from_pretrained(model_path, model_config, classify_config)?;
1227
1228    // Load corpus
1229    let samples = pipeline.load_corpus(data_path)?;
1230
1231    let lr_min = pipeline.config.learning_rate * lr_min_ratio;
1232
1233    // Build training config
1234    let training_config = TrainingConfig {
1235        epochs,
1236        val_split: 0.2,
1237        save_every: 1,
1238        early_stopping_patience: 5,
1239        checkpoint_dir: checkpoint_dir.to_path_buf(),
1240        seed: 42,
1241        log_interval: 1,
1242        warmup_fraction,
1243        lr_min,
1244        ..TrainingConfig::default()
1245    };
1246
1247    // Create trainer
1248    let mut trainer = ClassifyTrainer::new(pipeline, samples, training_config)?;
1249
1250    // Attach monitor writer
1251    let experiment_id = format!(
1252        "trial-{}",
1253        std::time::SystemTime::now()
1254            .duration_since(std::time::UNIX_EPOCH)
1255            .map(|d| d.as_secs())
1256            .unwrap_or(0)
1257    );
1258    let writer =
1259        crate::monitor::tui::TrainingStateWriter::new(checkpoint_dir, &experiment_id, model_name);
1260    trainer.set_monitor_writer(writer);
1261
1262    // Run training
1263    Ok(trainer.train())
1264}
1265
1266/// Resolve class weights from strategy name and class counts.
1267pub(crate) fn resolve_class_weights(
1268    strategy: &str,
1269    class_counts: &[usize],
1270    num_classes: usize,
1271) -> Option<Vec<f32>> {
1272    use super::classification::{compute_class_weights, ClassWeightStrategy, SafetyCorpusStats};
1273
1274    match strategy {
1275        "uniform" => None,
1276        "inverse_freq" => {
1277            let stats = SafetyCorpusStats {
1278                total: class_counts.iter().sum(),
1279                class_counts: class_counts.to_vec(),
1280                avg_input_len: 0,
1281            };
1282            Some(compute_class_weights(&stats, ClassWeightStrategy::InverseFreq, num_classes))
1283        }
1284        "sqrt_inverse" => {
1285            let stats = SafetyCorpusStats {
1286                total: class_counts.iter().sum(),
1287                class_counts: class_counts.to_vec(),
1288                avg_input_len: 0,
1289            };
1290            Some(compute_class_weights(&stats, ClassWeightStrategy::SqrtInverse, num_classes))
1291        }
1292        _ => None,
1293    }
1294}
1295
1296// ═══════════════════════════════════════════════════════════════════════
1297// Display helpers (for CLI consumption)
1298// ═══════════════════════════════════════════════════════════════════════
1299
1300impl TrainingPlan {
1301    /// Serialize to pretty JSON.
1302    pub fn to_json(&self) -> String {
1303        serde_json::to_string_pretty(self).unwrap_or_default()
1304    }
1305
1306    /// Serialize to YAML.
1307    pub fn to_yaml(&self) -> String {
1308        serde_yaml::to_string(self).unwrap_or_default()
1309    }
1310
1311    /// Deserialize from a string (auto-detects JSON or YAML).
1312    #[allow(clippy::should_implement_trait)]
1313    pub fn from_str(s: &str) -> crate::Result<Self> {
1314        // Try JSON first (faster, more common for programmatic use)
1315        if let Ok(plan) = serde_json::from_str::<TrainingPlan>(s) {
1316            return Ok(plan);
1317        }
1318        // Fall back to YAML
1319        serde_yaml::from_str::<TrainingPlan>(s).map_err(|e| {
1320            crate::Error::ConfigError(format!("Failed to parse plan as JSON or YAML: {e}"))
1321        })
1322    }
1323
1324    /// Count pre-flight checks by status.
1325    pub fn check_counts(&self) -> (usize, usize, usize) {
1326        let pass = self.pre_flight.iter().filter(|c| c.status == CheckStatus::Pass).count();
1327        let warn = self.pre_flight.iter().filter(|c| c.status == CheckStatus::Warn).count();
1328        let fail = self.pre_flight.iter().filter(|c| c.status == CheckStatus::Fail).count();
1329        (pass, warn, fail)
1330    }
1331}
1332
/// Collapse a trial's pruned / early-stopped flags into its status label.
///
/// Pruning by the scheduler takes precedence over the trainer's own early
/// stop. A trial with neither flag is `"completed"`.
pub(crate) fn resolve_trial_status(was_pruned: bool, stopped_early: bool) -> &'static str {
    match (was_pruned, stopped_early) {
        (true, _) => "pruned",
        (false, true) => "stopped_early",
        (false, false) => "completed",
    }
}
1343
1344// ── Experiment store integration ──────────────────────────────────────────
1345
1346/// Thin wrapper around SqliteBackend for logging training experiments.
1347/// All methods are best-effort (errors silently ignored) so training is
1348/// never blocked by storage failures.
1349pub(crate) struct ExperimentTracker {
1350    pub(crate) store: Option<crate::storage::SqliteBackend>,
1351    pub(crate) exp_id: Option<String>,
1352}
1353
1354impl ExperimentTracker {
1355    pub(crate) fn open(output_dir: &std::path::Path, plan: &TrainingPlan) -> Self {
1356        use crate::storage::{ExperimentStorage, SqliteBackend};
1357
1358        let mut store = SqliteBackend::open_project(output_dir).ok();
1359        let exp_id = store.as_mut().and_then(|s| {
1360            let config_json = serde_json::json!({
1361                "model": &plan.model.architecture,
1362                "size": &plan.model.size,
1363                "strategy": &plan.hyperparameters.strategy,
1364                "budget": plan.hyperparameters.budget,
1365                "num_classes": plan.data.class_counts.len(),
1366            });
1367            s.create_experiment(&plan.model.architecture, Some(config_json)).ok()
1368        });
1369        Self { store, exp_id }
1370    }
1371
1372    fn log_manual_trial(
1373        &mut self,
1374        manual: &ManualConfig,
1375        result: &super::classify_trainer::TrainResult,
1376    ) {
1377        use crate::storage::{ExperimentStorage, ParameterValue as SPV};
1378        let (store, eid) = match (self.store.as_mut(), self.exp_id.as_ref()) {
1379            (Some(s), Some(e)) => (s, e),
1380            _ => return,
1381        };
1382        let run_id = match store.create_run(eid) {
1383            Ok(id) => id,
1384            Err(_) => return,
1385        };
1386        let _ = store.start_run(&run_id);
1387        let _ =
1388            store.log_param(&run_id, "learning_rate", SPV::Float(f64::from(manual.learning_rate)));
1389        let _ = store.log_param(&run_id, "lora_rank", SPV::Int(manual.lora_rank as i64));
1390        let _ = store.log_param(&run_id, "batch_size", SPV::Int(manual.batch_size as i64));
1391        Self::log_epoch_metrics(store, &run_id, &result.epoch_metrics);
1392        let _ = store.complete_run(&run_id, crate::storage::RunStatus::Success);
1393    }
1394
1395    fn log_hpo_trial(
1396        &mut self,
1397        config: &std::collections::HashMap<String, crate::optim::ParameterValue>,
1398        result: &super::classify_trainer::TrainResult,
1399        was_pruned: bool,
1400    ) {
1401        use crate::optim::ParameterValue as OPV;
1402        use crate::storage::{ExperimentStorage, ParameterValue as SPV};
1403        let (store, eid) = match (self.store.as_mut(), self.exp_id.as_ref()) {
1404            (Some(s), Some(e)) => (s, e),
1405            _ => return,
1406        };
1407        let run_id = match store.create_run(eid) {
1408            Ok(id) => id,
1409            Err(_) => return,
1410        };
1411        let _ = store.start_run(&run_id);
1412        for (k, v) in config {
1413            let sv = match v {
1414                OPV::Float(f) => SPV::Float(*f),
1415                OPV::Int(i) => SPV::Int(*i),
1416                OPV::Categorical(s) => SPV::String(s.clone()),
1417            };
1418            let _ = store.log_param(&run_id, k, sv);
1419        }
1420        Self::log_epoch_metrics(store, &run_id, &result.epoch_metrics);
1421        let status = if was_pruned {
1422            crate::storage::RunStatus::Cancelled
1423        } else {
1424            crate::storage::RunStatus::Success
1425        };
1426        let _ = store.complete_run(&run_id, status);
1427    }
1428
1429    pub(crate) fn log_failed_trial(&mut self) {
1430        use crate::storage::ExperimentStorage;
1431        let (store, eid) = match (self.store.as_mut(), self.exp_id.as_ref()) {
1432            (Some(s), Some(e)) => (s, e),
1433            _ => return,
1434        };
1435        if let Ok(run_id) = store.create_run(eid) {
1436            let _ = store.start_run(&run_id);
1437            let _ = store.complete_run(&run_id, crate::storage::RunStatus::Failed);
1438        }
1439    }
1440
1441    fn log_epoch_metrics(
1442        store: &mut crate::storage::SqliteBackend,
1443        run_id: &str,
1444        epochs: &[super::classify_trainer::EpochMetrics],
1445    ) {
1446        use crate::storage::ExperimentStorage;
1447        for (i, epoch) in epochs.iter().enumerate() {
1448            let _ = store.log_metric(run_id, "train_loss", i as u64, f64::from(epoch.train_loss));
1449            let _ = store.log_metric(run_id, "val_loss", i as u64, f64::from(epoch.val_loss));
1450            let _ =
1451                store.log_metric(run_id, "val_accuracy", i as u64, f64::from(epoch.val_accuracy));
1452        }
1453    }
1454}
1455
1456#[cfg(test)]
1457#[allow(clippy::unwrap_used)]
1458#[path = "training_plan_tests.rs"]
1459mod tests;