
apr_cli/train_commands.rs

use std::path::PathBuf;

use clap::Subcommand;

/// Training pipeline subcommands (forjar-style plan/apply).
///
/// Thin CLI wrappers around entrenar's training plan/apply infrastructure.
#[derive(Subcommand, Debug)]
pub enum TrainCommands {
    /// Generate a training plan without touching the GPU.
    ///
    /// Validates data quality, checks model compatibility, builds HPO search space,
    /// estimates resource usage, and runs pre-flight checks. Outputs a serializable
    /// plan manifest (text, JSON, or YAML).
    ///
    /// Analogous to `forjar plan` — shows what will happen before committing GPU time.
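    // Illustrative invocation (assumes clap's default kebab-case flag names
    // derived from the fields below):
    //
    //   apr train plan --data train.jsonl --model-size 0.5B --num-classes 5 \
    //       --strategy tpe --budget 20 --format yaml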
    Plan {
        /// Path to training data (JSONL) — required for --task classify
        #[arg(long, value_name = "FILE")]
        data: Option<PathBuf>,
        /// Model size: "0.5B", "9B", "7B", "13B"
        #[arg(long, default_value = "0.5B")]
        model_size: String,
        /// Path to model weights directory
        #[arg(long, value_name = "DIR")]
        model_path: Option<PathBuf>,
        /// Number of output classes
        #[arg(long, default_value = "5")]
        num_classes: usize,
        /// Task type: classify, pretrain
        #[arg(long, default_value = "classify")]
        task: String,
        /// YAML training config (for --task pretrain)
        #[arg(long, value_name = "FILE")]
        config: Option<PathBuf>,
        /// Output directory for checkpoints
        #[arg(short, long, default_value = "/tmp/training-output")]
        output: PathBuf,
        /// HPO strategy: tpe, grid, random, manual
        #[arg(long, default_value = "tpe")]
        strategy: String,
        /// HPO budget (number of trials)
        #[arg(long, default_value = "20")]
        budget: usize,
        /// Scout mode: 1 epoch per trial for fast exploration
        #[arg(long)]
        scout: bool,
        /// Maximum epochs per trial
        #[arg(long, default_value = "3")]
        max_epochs: usize,
        /// Manual learning rate (only used with --strategy manual)
        #[arg(long)]
        learning_rate: Option<f32>,
        /// Manual LoRA rank (only used with --strategy manual)
        #[arg(long)]
        lora_rank: Option<usize>,
        /// Manual batch size (only used with --strategy manual)
        #[arg(long)]
        batch_size: Option<usize>,
        /// Validation data file (JSONL)
        #[arg(long, value_name = "FILE")]
        val_data: Option<PathBuf>,
        /// Test data file (JSONL)
        #[arg(long, value_name = "FILE")]
        test_data: Option<PathBuf>,
        /// Output format: text, json, yaml
        #[arg(long, default_value = "text")]
        format: String,
    },

    /// Execute a training plan (allocate GPU, run trials).
    ///
    /// Reads a previously generated plan (YAML/JSON) and executes it:
    /// - Manual strategy: single training run with specified hyperparameters
    /// - HPO strategy: multiple trials with automatic hyperparameter tuning
    ///
    /// Analogous to `forjar apply` — commits resources and executes the plan.
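    // Illustrative invocations (flag names assume clap's default kebab-case
    // mapping of the fields below); one plan-file run, one inline manual run:
    //
    //   apr train apply --plan plan.yaml --output /tmp/training-output
    //   apr train apply --data train.jsonl --strategy manual \
    //       --learning-rate 1e-4 --lora-rank 16 --batch-size 8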
    Apply {
        /// Path to a saved plan file (YAML or JSON from `apr train plan`)
        #[arg(long, value_name = "FILE")]
        plan: Option<PathBuf>,

        /// YAML training config (for --task pretrain)
        #[arg(long, value_name = "FILE")]
        config: Option<PathBuf>,

        /// Task type: classify, pretrain
        #[arg(long, default_value = "classify")]
        task: String,

        // ── Inline plan params (used when no --plan file is given) ─────
        /// Path to training data (JSONL)
        #[arg(long, value_name = "FILE")]
        data: Option<PathBuf>,
        /// Model size: "0.5B", "9B", "7B", "13B"
        #[arg(long, default_value = "0.5B")]
        model_size: String,
        /// Path to model weights directory
        #[arg(long, value_name = "DIR")]
        model_path: Option<PathBuf>,
        /// Number of output classes
        #[arg(long, default_value = "5")]
        num_classes: usize,
        /// Output directory for checkpoints and leaderboard
        #[arg(short, long, default_value = "/tmp/training-output")]
        output: PathBuf,
        /// HPO strategy: tpe, grid, random, manual
        #[arg(long, default_value = "tpe")]
        strategy: String,
        /// HPO budget (number of trials)
        #[arg(long, default_value = "20")]
        budget: usize,
        /// Scout mode: 1 epoch per trial
        #[arg(long)]
        scout: bool,
        /// Maximum epochs per trial
        #[arg(long, default_value = "3")]
        max_epochs: usize,
        /// Manual learning rate (only used with --strategy manual)
        #[arg(long)]
        learning_rate: Option<f32>,
        /// Manual LoRA rank (only used with --strategy manual)
        #[arg(long)]
        lora_rank: Option<usize>,
        /// Manual batch size (only used with --strategy manual)
        #[arg(long)]
        batch_size: Option<usize>,

        // ── Distributed training params (tickets #131-#140, aprender #393) ──
        /// Enable distributed data-parallel training
        #[arg(long)]
        distributed: bool,
        /// Total number of workers (default: auto-detect GPUs)
        #[arg(long, value_name = "N")]
        world_size: Option<usize>,
        /// This worker's global rank (default: 0 = coordinator)
        #[arg(long, value_name = "N")]
        rank: Option<usize>,
        /// Coordinator address for distributed training (default: 0.0.0.0:9000)
        #[arg(long, value_name = "HOST:PORT")]
        coordinator_addr: Option<String>,

        // ── Reproducibility params (R-084 C-DETERM-001) ──
        /// Enable bitwise deterministic training (CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic)
        #[arg(long)]
        deterministic: bool,
        /// Random seed for reproducibility (default: from YAML or 42)
        #[arg(long, value_name = "N")]
        seed: Option<u64>,

        // ── Profiling params (PMAT-486) ──
        /// Enable StepProfiler for per-phase wall-clock timing (KAIZEN-047)
        #[arg(long)]
        profile: bool,
        /// StepProfiler report interval (every N steps, default: 50)
        #[arg(long, value_name = "N", default_value = "50")]
        profile_interval: usize,
    },

    /// Watch a training run with automatic restart on crash and hang detection.
    ///
    /// Monitors a running or to-be-started training process:
    /// - Detects crashes (SIGABRT, SIGSEGV, OOM) and restarts with backoff
    /// - Detects hangs via heartbeat/training_state.json staleness
    /// - Captures GPU state and crash diagnostics
    /// - Auto-enables CUDA_LAUNCH_BLOCKING on async crash pattern
    ///
    /// Sovereign Rust replacement for train-guard.sh.
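    // Illustrative invocation (assumes clap's default kebab-case flag names):
    //
    //   apr train watch --config train.yaml --max-restarts 5 --heartbeat-timeout 300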
    Watch {
        /// YAML training config to run and watch
        #[arg(long, value_name = "FILE")]
        config: PathBuf,

        /// Maximum number of restart attempts
        #[arg(long, default_value = "5")]
        max_restarts: usize,

        /// Heartbeat staleness threshold in seconds
        #[arg(long, default_value = "300")]
        heartbeat_timeout: u64,

        /// Initial backoff delay in seconds
        #[arg(long, default_value = "30")]
        backoff_initial: u64,

        /// Maximum backoff delay in seconds
        #[arg(long, default_value = "600")]
        backoff_max: u64,
    },

    /// Generate hyperparameter sweep configs from a base YAML.
    ///
    /// Creates N training configs with varied hyperparameters using grid
    /// or random search. Each config is a complete YAML that can be
    /// passed to `apr train apply --task pretrain --config <file>`.
    ///
    /// Sovereign Rust replacement for hyperparam-sweep.py.
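    // Illustrative invocation (assumes clap's default kebab-case flag names):
    //
    //   apr train sweep --config base.yaml --strategy random --num-configs 10 \
    //       --output-dir sweeps/ --seed 42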
    Sweep {
        /// Base YAML training config to sweep from
        #[arg(long, value_name = "FILE")]
        config: PathBuf,

        /// Search strategy: grid or random
        #[arg(long, default_value = "random")]
        strategy: String,

        /// Number of configs to generate (random) or max combinations (grid)
        #[arg(long, default_value = "10")]
        num_configs: usize,

        /// Output directory for generated configs
        #[arg(long, default_value = "sweeps/")]
        output_dir: PathBuf,

        /// Seed for random search reproducibility
        #[arg(long, default_value = "42")]
        seed: u64,
    },

    /// Run successive halving HPO on sweep configs (C-HPO-001).
    ///
    /// Takes a directory of sweep configs (from `apr train sweep`), runs each
    /// for `--steps-per-round` steps, kills the worst half by val_ppl, doubles
    /// steps, and repeats for `--rounds` rounds. Reports the winner with
    /// μTransfer-scaled LR for the target model width.
    ///
    /// References: Hyperband (Li et al. 2018, arXiv:1603.06560),
    /// μTransfer (Yang et al. 2022, arXiv:2203.03466).
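    // Illustrative invocation (assumes clap's default kebab-case flag names):
    //
    //   apr train halving --sweep-dir sweeps/ --rounds 3 --steps-per-round 500 \
    //       --source-width 512 --target-width 1024 --output sweeps/hpo-results.json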
    Halving {
        /// Directory containing sweep-*.yaml configs (from `apr train sweep`)
        #[arg(long, value_name = "DIR")]
        sweep_dir: PathBuf,

        /// Number of halving rounds (default: 3)
        #[arg(long, default_value = "3")]
        rounds: usize,

        /// Training steps in first round (doubles each round)
        #[arg(long, default_value = "500")]
        steps_per_round: usize,

        /// Proxy model hidden_size (for μTransfer scaling)
        #[arg(long, default_value = "512")]
        source_width: usize,

        /// Target model hidden_size (for μTransfer scaling)
        #[arg(long, default_value = "1024")]
        target_width: usize,

        /// Output JSON file for results
        #[arg(long, default_value = "sweeps/hpo-results.json")]
        output: PathBuf,
    },

    /// Archive a checkpoint into a release bundle.
    ///
    /// Packages model weights, config, training state, and metadata
    /// into a self-contained directory with integrity manifest.
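    // Illustrative invocation (the checkpoint path is the positional
    // CHECKPOINT_DIR argument; the paths shown are placeholders):
    //
    //   apr train archive <CHECKPOINT_DIR> --output releases/ \
    //       --release-version v1.0 --notes "first release"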
    Archive {
        /// Path to checkpoint directory
        #[arg(value_name = "CHECKPOINT_DIR")]
        checkpoint_dir: PathBuf,

        /// Output archive directory
        #[arg(short, long, value_name = "DIR")]
        output: PathBuf,

        /// Release version tag (e.g., "v1.0")
        #[arg(long = "release-version")]
        release_version: Option<String>,

        /// Release notes
        #[arg(long)]
        notes: Option<String>,
    },

    /// Submit multi-adapter training jobs to a cluster (GPU-SHARE Phase 3).
    ///
    /// Reads a cluster.yaml config, places adapter jobs across nodes using
    /// the greedy placement algorithm, and generates launch commands.
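    // Illustrative invocation (assumes clap's default kebab-case flag names;
    // the DATA:CHECKPOINT values are placeholders):
    //
    //   apr train submit --cluster cluster.yaml --model model.apr \
    //       --adapter data-a.jsonl:adapter-a.apr --adapter data-b.jsonl:adapter-b.apr \
    //       --rank 16 --epochs 3 --dry-run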
    Submit {
        /// Path to cluster config YAML
        #[arg(long, value_name = "FILE")]
        cluster: PathBuf,

        /// Model checkpoint path (.apr)
        #[arg(long, value_name = "FILE")]
        model: PathBuf,

        /// Adapter specs: DATA:CHECKPOINT pairs (one per adapter)
        #[arg(long = "adapter", value_name = "DATA:CHECKPOINT")]
        adapters: Vec<String>,

        /// LoRA rank
        #[arg(long, default_value = "16")]
        rank: u32,

        /// Number of training epochs
        #[arg(long, default_value = "3")]
        epochs: u32,

        /// Estimated VRAM budget per adapter (MB)
        #[arg(long, default_value = "6000")]
        budget_mb: u64,

        /// Dry run: show placement and commands without executing
        #[arg(long)]
        dry_run: bool,
    },

    /// Show cluster status: nodes, GPUs, adapter capacity (GPU-SHARE Phase 3).
    ///
    /// Reads a cluster.yaml config and displays node health, VRAM availability,
    /// and adapter placement capacity.
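    // Illustrative invocation (assumes clap's default kebab-case subcommand name):
    //
    //   apr train cluster-status --cluster cluster.yaml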
    ClusterStatus {
        /// Path to cluster config YAML
        #[arg(long, value_name = "FILE")]
        cluster: PathBuf,
    },
}