//! `apr train` subcommands (apr_cli/train_commands.rs).
//!
//! Plan/apply-style training pipeline CLI, plus watch/sweep/archive and
//! cluster (GPU-SHARE) helpers.
2/// Training pipeline subcommands (forjar-style plan/apply).
3///
4/// Thin CLI wrappers around entrenar's training plan/apply infrastructure.
5#[derive(Subcommand, Debug)]
6pub enum TrainCommands {
7    /// Generate a training plan without touching the GPU.
8    ///
9    /// Validates data quality, checks model compatibility, builds HPO search space,
10    /// estimates resource usage, and runs pre-flight checks. Outputs a serializable
11    /// plan manifest (text, JSON, or YAML).
12    ///
13    /// Analogous to `forjar plan` — shows what will happen before committing GPU time.
14    Plan {
15        /// Path to training data (JSONL) — required for --task classify
16        #[arg(long, value_name = "FILE")]
17        data: Option<PathBuf>,
18        /// Model size: "0.5B", "9B", "7B", "13B"
19        #[arg(long, default_value = "0.5B")]
20        model_size: String,
21        /// Path to model weights directory
22        #[arg(long, value_name = "DIR")]
23        model_path: Option<PathBuf>,
24        /// Number of output classes
25        #[arg(long, default_value = "5")]
26        num_classes: usize,
27        /// Task type: classify, pretrain
28        #[arg(long, default_value = "classify")]
29        task: String,
30        /// YAML training config (for --task pretrain)
31        #[arg(long, value_name = "FILE")]
32        config: Option<PathBuf>,
33        /// Output directory for checkpoints
34        #[arg(short, long, default_value = "/tmp/training-output")]
35        output: PathBuf,
36        /// HPO strategy: tpe, grid, random, manual
37        #[arg(long, default_value = "tpe")]
38        strategy: String,
39        /// HPO budget (number of trials)
40        #[arg(long, default_value = "20")]
41        budget: usize,
42        /// Scout mode: 1 epoch per trial for fast exploration
43        #[arg(long)]
44        scout: bool,
45        /// Maximum epochs per trial
46        #[arg(long, default_value = "3")]
47        max_epochs: usize,
48        /// Manual learning rate (only used with --strategy manual)
49        #[arg(long)]
50        learning_rate: Option<f32>,
51        /// Manual LoRA rank (only used with --strategy manual)
52        #[arg(long)]
53        lora_rank: Option<usize>,
54        /// Manual batch size (only used with --strategy manual)
55        #[arg(long)]
56        batch_size: Option<usize>,
57        /// Validation data file (JSONL)
58        #[arg(long, value_name = "FILE")]
59        val_data: Option<PathBuf>,
60        /// Test data file (JSONL)
61        #[arg(long, value_name = "FILE")]
62        test_data: Option<PathBuf>,
63        /// Output format: text, json, yaml
64        #[arg(long, default_value = "text")]
65        format: String,
66    },
67
68    /// Execute a training plan (allocate GPU, run trials).
69    ///
70    /// Reads a previously generated plan (YAML/JSON) and executes it:
71    /// - Manual strategy: single training run with specified hyperparameters
72    /// - HPO strategy: multiple trials with automatic hyperparameter tuning
73    ///
74    /// Analogous to `forjar apply` — commits resources and executes the plan.
75    Apply {
76        /// Path to a saved plan file (YAML or JSON from `apr train plan`)
77        #[arg(long, value_name = "FILE")]
78        plan: Option<PathBuf>,
79
80        /// YAML training config (for --task pretrain)
81        #[arg(long, value_name = "FILE")]
82        config: Option<PathBuf>,
83
84        /// Task type: classify, pretrain
85        #[arg(long, default_value = "classify")]
86        task: String,
87
88        // ── Inline plan params (used when no --plan file is given) ─────
89        /// Path to training data (JSONL)
90        #[arg(long, value_name = "FILE")]
91        data: Option<PathBuf>,
92        /// Model size: "0.5B", "9B", "7B", "13B"
93        #[arg(long, default_value = "0.5B")]
94        model_size: String,
95        /// Path to model weights directory
96        #[arg(long, value_name = "DIR")]
97        model_path: Option<PathBuf>,
98        /// Number of output classes
99        #[arg(long, default_value = "5")]
100        num_classes: usize,
101        /// Output directory for checkpoints and leaderboard
102        #[arg(short, long, default_value = "/tmp/training-output")]
103        output: PathBuf,
104        /// HPO strategy: tpe, grid, random, manual
105        #[arg(long, default_value = "tpe")]
106        strategy: String,
107        /// HPO budget (number of trials)
108        #[arg(long, default_value = "20")]
109        budget: usize,
110        /// Scout mode: 1 epoch per trial
111        #[arg(long)]
112        scout: bool,
113        /// Maximum epochs per trial
114        #[arg(long, default_value = "3")]
115        max_epochs: usize,
116        /// Manual learning rate (only used with --strategy manual)
117        #[arg(long)]
118        learning_rate: Option<f32>,
119        /// Manual LoRA rank (only used with --strategy manual)
120        #[arg(long)]
121        lora_rank: Option<usize>,
122        /// Manual batch size (only used with --strategy manual)
123        #[arg(long)]
124        batch_size: Option<usize>,
125
126        // ── Distributed training params (tickets #131-#140, aprender #393) ──
127        /// Enable distributed data-parallel training
128        #[arg(long)]
129        distributed: bool,
130        /// Total number of workers (default: auto-detect GPUs)
131        #[arg(long, value_name = "N")]
132        world_size: Option<usize>,
133        /// This worker's global rank (default: 0 = coordinator)
134        #[arg(long, value_name = "N")]
135        rank: Option<usize>,
136        /// Coordinator address for distributed training (default: 0.0.0.0:9000)
137        #[arg(long, value_name = "HOST:PORT")]
138        coordinator_addr: Option<String>,
139
140        // ── Reproducibility params (R-084 C-DETERM-001) ──
141        /// Enable bitwise deterministic training (CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic)
142        #[arg(long)]
143        deterministic: bool,
144        /// Random seed for reproducibility (default: from YAML or 42)
145        #[arg(long, value_name = "N")]
146        seed: Option<u64>,
147    },
148
149    /// Watch a training run with automatic restart on crash and hang detection.
150    ///
151    /// Monitors a running or to-be-started training process:
152    /// - Detects crashes (SIGABRT, SIGSEGV, OOM) and restarts with backoff
153    /// - Detects hangs via heartbeat/training_state.json staleness
154    /// - Captures GPU state and crash diagnostics
155    /// - Auto-enables CUDA_LAUNCH_BLOCKING on async crash pattern
156    ///
157    /// Sovereign Rust replacement for train-guard.sh.
158    Watch {
159        /// YAML training config to run and watch
160        #[arg(long, value_name = "FILE")]
161        config: PathBuf,
162
163        /// Maximum number of restart attempts
164        #[arg(long, default_value = "5")]
165        max_restarts: usize,
166
167        /// Heartbeat staleness threshold in seconds
168        #[arg(long, default_value = "300")]
169        heartbeat_timeout: u64,
170
171        /// Initial backoff delay in seconds
172        #[arg(long, default_value = "30")]
173        backoff_initial: u64,
174
175        /// Maximum backoff delay in seconds
176        #[arg(long, default_value = "600")]
177        backoff_max: u64,
178    },
179
180    /// Generate hyperparameter sweep configs from a base YAML.
181    ///
182    /// Creates N training configs with varied hyperparameters using grid
183    /// or random search. Each config is a complete YAML that can be
184    /// passed to `apr train apply --task pretrain --config <file>`.
185    ///
186    /// Sovereign Rust replacement for hyperparam-sweep.py.
187    Sweep {
188        /// Base YAML training config to sweep from
189        #[arg(long, value_name = "FILE")]
190        config: PathBuf,
191
192        /// Search strategy: grid or random
193        #[arg(long, default_value = "random")]
194        strategy: String,
195
196        /// Number of configs to generate (random) or max combinations (grid)
197        #[arg(long, default_value = "10")]
198        num_configs: usize,
199
200        /// Output directory for generated configs
201        #[arg(long, default_value = "sweeps/")]
202        output_dir: PathBuf,
203
204        /// Seed for random search reproducibility
205        #[arg(long, default_value = "42")]
206        seed: u64,
207    },
208
209    /// Archive a checkpoint into a release bundle.
210    ///
211    /// Packages model weights, config, training state, and metadata
212    /// into a self-contained directory with integrity manifest.
213    Archive {
214        /// Path to checkpoint directory
215        #[arg(value_name = "CHECKPOINT_DIR")]
216        checkpoint_dir: PathBuf,
217
218        /// Output archive directory
219        #[arg(short, long, value_name = "DIR")]
220        output: PathBuf,
221
222        /// Release version tag (e.g., "v1.0")
223        #[arg(long)]
224        version: Option<String>,
225
226        /// Release notes
227        #[arg(long)]
228        notes: Option<String>,
229    },
230
231    /// Submit multi-adapter training jobs to a cluster (GPU-SHARE Phase 3).
232    ///
233    /// Reads a cluster.yaml config, places adapter jobs across nodes using
234    /// the greedy placement algorithm, and generates launch commands.
235    Submit {
236        /// Path to cluster config YAML
237        #[arg(long, value_name = "FILE")]
238        cluster: PathBuf,
239
240        /// Model checkpoint path (.apr)
241        #[arg(long, value_name = "FILE")]
242        model: PathBuf,
243
244        /// Adapter specs: DATA:CHECKPOINT pairs (one per adapter)
245        #[arg(long = "adapter", value_name = "DATA:CHECKPOINT")]
246        adapters: Vec<String>,
247
248        /// LoRA rank
249        #[arg(long, default_value = "16")]
250        rank: u32,
251
252        /// Number of training epochs
253        #[arg(long, default_value = "3")]
254        epochs: u32,
255
256        /// Estimated VRAM budget per adapter (MB)
257        #[arg(long, default_value = "6000")]
258        budget_mb: u64,
259
260        /// Dry run: show placement and commands without executing
261        #[arg(long)]
262        dry_run: bool,
263    },
264
265    /// Show cluster status: nodes, GPUs, adapter capacity (GPU-SHARE Phase 3).
266    ///
267    /// Reads a cluster.yaml config and displays node health, VRAM availability,
268    /// and adapter placement capacity.
269    ClusterStatus {
270        /// Path to cluster config YAML
271        #[arg(long, value_name = "FILE")]
272        cluster: PathBuf,
273    },
274}