apr_cli/model_ops_commands.rs
use clap::Subcommand;
use std::path::PathBuf;

#[derive(Subcommand, Debug)]
pub enum ModelOpsCommands {
    /// Fine-tune model with LoRA/QLoRA (GH-244)
    #[cfg(feature = "training")]
    Finetune {
        /// Input model file
        #[arg(value_name = "FILE")]
        file: Option<PathBuf>,
        /// Fine-tuning method: auto, full, lora, qlora
        #[arg(long, short = 'm', default_value = "auto")]
        method: String,
        /// LoRA rank (default: auto-selected)
        #[arg(long, short = 'r')]
        rank: Option<u32>,
        /// Available VRAM in GB
        #[arg(long, default_value = "16.0")]
        vram: f64,
        /// Plan mode (estimate only)
        #[arg(long)]
        plan: bool,
        /// Training data file (JSONL format)
        #[arg(long, short = 'd', value_name = "FILE")]
        data: Option<PathBuf>,
        /// Output path (adapter dir or merged model)
        #[arg(short, long)]
        output: Option<PathBuf>,
        /// Adapter path for merge mode
        #[arg(long)]
        adapter: Option<PathBuf>,
        /// Merge adapter into base model
        #[arg(long)]
        merge: bool,
        /// Training epochs
        #[arg(long, default_value = "3")]
        epochs: u32,
        /// Learning rate
        #[arg(long, default_value = "0.0002")]
        learning_rate: f64,
        /// Model size for planning (e.g., "7B", "1.5B")
        #[arg(long, value_name = "SIZE")]
        model_size: Option<String>,
        /// Fine-tuning task: classify (sequence classification)
        #[arg(long)]
        task: Option<String>,
        /// Number of classes for classification task
        #[arg(long, default_value = "5")]
        num_classes: usize,
        /// Output format for checkpoints: apr, safetensors, or both (comma-separated)
        #[arg(long, value_name = "FORMAT", default_value = "apr,safetensors")]
        checkpoint_format: String,
        /// Oversample minority classes to match majority (for imbalanced datasets)
        #[arg(long)]
        oversample: bool,
        /// Maximum sequence length for GPU buffer allocation (lower = less VRAM)
        #[arg(long, value_name = "LEN")]
        max_seq_len: Option<usize>,
        /// Quantize frozen weights to NF4 (4-bit) for QLoRA training (~8x VRAM savings)
        #[arg(long)]
        quantize_nf4: bool,
        /// GPU indices for data-parallel training (e.g., "0,1" for dual GPU)
        #[arg(long, value_name = "INDICES")]
        gpus: Option<String>,
        /// GPU backend selection: auto, cuda, wgpu
        #[arg(long, default_value = "auto")]
        gpu_backend: String,
        /// Distributed training role: coordinator or worker
        #[arg(long, value_name = "ROLE")]
        role: Option<String>,
        /// Address to bind (coordinator) or connect to (worker)
        #[arg(long, value_name = "ADDR")]
        bind: Option<String>,
        /// Coordinator address for worker nodes (e.g., "intel:9000")
        #[arg(long, value_name = "ADDR")]
        coordinator: Option<String>,
        /// Expected number of workers (coordinator only)
        #[arg(long, value_name = "N")]
        expect_workers: Option<usize>,
        /// Wait for VRAM availability before training (timeout in seconds, 0 = no wait)
        #[arg(long, value_name = "SECS", default_value = "0")]
        wait_gpu: u64,
        /// Multi-adapter training: data:checkpoint pairs (GPU-SHARE Phase 2)
        /// Format: --adapters data/corpus-a.jsonl:checkpoints/adapter-a
        /// Can be specified multiple times for concurrent adapter training.
        #[arg(long, value_name = "DATA:CHECKPOINT")]
        adapters: Vec<String>,

        /// Multi-adapter config file: TOML with [[adapter]] entries (GPU-SHARE §2.4)
        #[arg(long, value_name = "FILE")]
        adapters_config: Option<PathBuf>,

        /// Enable experimental CUDA MPS for concurrent GPU sharing (GPU-SHARE §1.5).
        /// WARNING: A GPU fault in any MPS client will crash ALL clients on that GPU.
        #[arg(long)]
        experimental_mps: bool,

        /// MPS thread percentage (1-100). Controls SM allocation per process.
        /// Only effective with --experimental-mps. Default: 50.
        #[arg(long, value_name = "PCT", default_value = "50")]
        gpu_share: u32,

        /// PMAT-486: Enable StepProfiler for per-phase wall-clock timing
        #[arg(long)]
        profile: bool,
    },
    /// Prune model (structured/unstructured pruning) (GH-247)
    Prune {
        /// Input model file
        #[arg(value_name = "FILE")]
        file: PathBuf,
        /// Pruning method: magnitude, structured, depth, width, wanda, sparsegpt
        #[arg(long, short = 'm', default_value = "magnitude")]
        method: String,
        /// Target pruning ratio (0-1)
        #[arg(long, default_value = "0.5")]
        target_ratio: f32,
        /// Sparsity level (0-1)
        #[arg(long, default_value = "0.0")]
        sparsity: f32,
        /// Output file path
        #[arg(short, long)]
        output: Option<PathBuf>,
        /// Layers to remove for depth pruning (e.g., "20-24")
        #[arg(long)]
        remove_layers: Option<String>,
        /// Analyze mode (identify pruning opportunities)
        #[arg(long)]
        analyze: bool,
        /// Plan mode (estimate only)
        #[arg(long)]
        plan: bool,
        /// Calibration data file
        #[arg(long, value_name = "FILE")]
        calibration: Option<PathBuf>,
    },
    /// Knowledge distillation (teacher -> student) (GH-247, ALB-011)
    Distill {
        /// Teacher model file (positional, for file-based mode)
        #[arg(value_name = "TEACHER")]
        teacher: Option<PathBuf>,
        /// Student model file
        #[arg(long, value_name = "FILE")]
        student: Option<PathBuf>,
        /// Training data file
        #[arg(long, short = 'd', value_name = "FILE")]
        data: Option<PathBuf>,
        /// Output file path
        #[arg(short, long)]
        output: Option<PathBuf>,
        /// Distillation strategy: standard, progressive, ensemble
        #[arg(long, default_value = "standard")]
        strategy: String,
        /// Temperature for softmax scaling
        #[arg(long, default_value = "3.0")]
        temperature: f64,
        /// Alpha weight for KL vs task loss
        #[arg(long, default_value = "0.7")]
        alpha: f64,
        /// Training epochs
        #[arg(long, default_value = "3")]
        epochs: u32,
        /// Plan mode (estimate only)
        #[arg(long)]
        plan: bool,
        /// YAML config file for two-stage distillation (ALB-011)
        #[arg(long, value_name = "FILE")]
        config: Option<PathBuf>,
        /// Distillation stage: precompute, train (logit KD), or generate (text-based, GH-455)
        #[arg(long, value_name = "STAGE")]
        stage: Option<String>,
    },
}
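
// A minimal sketch (not part of the production CLI): how this subcommand enum can be
// exercised through clap's derive API. The `ModelOpsCli` wrapper, the test module name,
// and the "apr" binary name below are assumptions for illustration only; the real
// top-level parser that embeds `ModelOpsCommands` lives elsewhere in apr_cli.
#[cfg(test)]
mod subcommand_parse_sketch {
    use super::*;
    use clap::Parser;

    /// Hypothetical wrapper so `ModelOpsCommands` can be parsed in isolation.
    #[derive(Parser, Debug)]
    struct ModelOpsCli {
        #[command(subcommand)]
        command: ModelOpsCommands,
    }

    #[test]
    fn prune_parses_positional_file_and_defaults() {
        // `prune` requires only the FILE positional; flags fall back to their declared
        // defaults (method = "magnitude", target_ratio = 0.5).
        let cli = ModelOpsCli::try_parse_from(["apr", "prune", "model.apr"])
            .expect("prune with a file argument should parse");
        match cli.command {
            ModelOpsCommands::Prune { file, method, target_ratio, .. } => {
                assert_eq!(file.to_str(), Some("model.apr"));
                assert_eq!(method, "magnitude");
                assert!((target_ratio - 0.5).abs() < f32::EPSILON);
            }
            _ => panic!("expected the Prune variant"),
        }
    }
}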