Skip to main content

entrenar/config/cli/
extended.rs

//! Extended command types - Completion, Bench, Inspect, Audit, Monitor, Publish, Finetune, Experiments
2
3use clap::{Parser, Subcommand};
4use std::path::PathBuf;
5
6use super::types::{AuditType, InspectMode, OutputFormat, ShellType};
7
8/// Arguments for completion command
9#[derive(Parser, Debug, Clone, PartialEq)]
10pub struct CompletionArgs {
11    /// Shell to generate completions for
12    #[arg(value_name = "SHELL")]
13    pub shell: ShellType,
14}
15
16/// Arguments for bench command
17#[derive(Parser, Debug, Clone, PartialEq)]
18pub struct BenchArgs {
19    /// Path to model or config file
20    #[arg(value_name = "INPUT")]
21    pub input: PathBuf,
22
23    /// Number of warmup iterations
24    #[arg(long, default_value = "10")]
25    pub warmup: usize,
26
27    /// Number of benchmark iterations
28    #[arg(long, default_value = "100")]
29    pub iterations: usize,
30
31    /// Batch sizes to test (comma-separated)
32    #[arg(long, default_value = "1,8,32")]
33    pub batch_sizes: String,
34
35    /// Output format (text, json)
36    #[arg(short, long, default_value = "text")]
37    pub format: OutputFormat,
38}
39
40/// Arguments for inspect command
41#[derive(Parser, Debug, Clone, PartialEq)]
42pub struct InspectArgs {
43    /// Path to data file or model
44    #[arg(value_name = "INPUT")]
45    pub input: PathBuf,
46
47    /// Inspection mode
48    #[arg(short, long, default_value = "summary")]
49    pub mode: InspectMode,
50
51    /// Columns to inspect (comma-separated)
52    #[arg(long)]
53    pub columns: Option<String>,
54
55    /// Z-score threshold for outlier detection
56    #[arg(long, default_value = "3.0")]
57    pub z_threshold: f32,
58}
59
60/// Arguments for audit command
61#[derive(Parser, Debug, Clone, PartialEq)]
62pub struct AuditArgs {
63    /// Path to model or config file
64    #[arg(value_name = "INPUT")]
65    pub input: PathBuf,
66
67    /// Audit type
68    #[arg(short, long, default_value = "bias")]
69    pub audit_type: AuditType,
70
71    /// Protected attribute column
72    #[arg(long)]
73    pub protected_attr: Option<String>,
74
75    /// Fairness threshold (0.0-1.0)
76    #[arg(long, default_value = "0.8")]
77    pub threshold: f32,
78
79    /// Output format (text, json, html)
80    #[arg(short, long, default_value = "text")]
81    pub format: OutputFormat,
82}
83
84/// Arguments for monitor command
85#[derive(Parser, Debug, Clone, PartialEq)]
86pub struct MonitorArgs {
87    /// Path to model or config file
88    #[arg(value_name = "INPUT")]
89    pub input: PathBuf,
90
91    /// Baseline statistics file
92    #[arg(long)]
93    pub baseline: Option<PathBuf>,
94
95    /// Drift detection threshold (PSI)
96    #[arg(long, default_value = "0.2")]
97    pub threshold: f32,
98
99    /// Monitoring interval in seconds
100    #[arg(long, default_value = "60")]
101    pub interval: u64,
102
103    /// Output format (text, json)
104    #[arg(short, long, default_value = "text")]
105    pub format: OutputFormat,
106}
107
108/// Arguments for the publish command
109#[allow(clippy::struct_excessive_bools)]
110#[derive(Parser, Debug, Clone, PartialEq)]
111pub struct PublishArgs {
112    /// Path to trained model output directory
113    #[arg(value_name = "MODEL_DIR", default_value = "./output")]
114    pub model_dir: PathBuf,
115
116    /// HuggingFace repo ID (e.g., myuser/my-model)
117    #[arg(long)]
118    pub repo: String,
119
120    /// Make the repository private
121    #[arg(long)]
122    pub private: bool,
123
124    /// Generate and upload a model card
125    #[arg(long, default_value_t = true)]
126    pub model_card: bool,
127
128    /// Merge LoRA adapters into base weights before publishing
129    #[arg(long)]
130    pub merge_adapters: bool,
131
132    /// Base model HF repo ID (for model card metadata)
133    #[arg(long)]
134    pub base_model: Option<String>,
135
136    /// Export format (safetensors or gguf)
137    #[arg(long, default_value = "safetensors")]
138    pub format: String,
139
140    /// Dry run (validate but don't upload)
141    #[arg(long)]
142    pub dry_run: bool,
143}
144
145/// Arguments for the finetune command (plan/apply classification training)
146#[derive(Parser, Debug, Clone, PartialEq)]
147pub struct FinetuneArgs {
148    /// Subcommand to execute
149    #[command(subcommand)]
150    pub command: FinetuneCommand,
151}
152
153/// Finetune subcommands (forjar-style plan/apply)
154#[derive(Subcommand, Debug, Clone, PartialEq)]
155pub enum FinetuneCommand {
156    /// Generate a training plan (validate data, estimate resources, build HPO config)
157    Plan {
158        /// Path to training data (JSONL with {"input": ..., "label": N})
159        #[arg(long)]
160        data: PathBuf,
161
162        /// Path to model weights directory (e.g., Qwen2.5-Coder-0.5B)
163        #[arg(long)]
164        model_path: Option<PathBuf>,
165
166        /// Model size hint (e.g., "0.5B", "9B")
167        #[arg(long, default_value = "0.5B")]
168        model_size: String,
169
170        /// Number of output classes
171        #[arg(long, default_value = "5")]
172        num_classes: usize,
173
174        /// Output directory for plan and checkpoints
175        #[arg(short, long, default_value = "./output")]
176        output_dir: PathBuf,
177
178        /// HPO strategy: tpe, grid, random, or manual
179        #[arg(long, default_value = "tpe")]
180        strategy: String,
181
182        /// HPO budget (number of trials)
183        #[arg(long, default_value = "20")]
184        budget: usize,
185
186        /// Scout mode (1 epoch per trial for fast HPO)
187        #[arg(long)]
188        scout: bool,
189
190        /// Maximum epochs per trial
191        #[arg(long, default_value = "10")]
192        max_epochs: usize,
193
194        /// Manual learning rate (for strategy=manual)
195        #[arg(long)]
196        lr: Option<f32>,
197
198        /// Manual LoRA rank (for strategy=manual)
199        #[arg(long)]
200        lora_rank: Option<usize>,
201
202        /// Manual batch size (for strategy=manual)
203        #[arg(long)]
204        batch_size: Option<usize>,
205
206        /// LoRA alpha (for strategy=manual; defaults to lora_rank)
207        #[arg(long)]
208        lora_alpha: Option<f32>,
209
210        /// Warmup fraction (for strategy=manual; default 0.1)
211        #[arg(long)]
212        warmup: Option<f32>,
213
214        /// Gradient clip norm (for strategy=manual; default 1.0)
215        #[arg(long)]
216        gradient_clip: Option<f32>,
217
218        /// LR min ratio for cosine decay (for strategy=manual; default 0.01)
219        #[arg(long)]
220        lr_min_ratio: Option<f32>,
221
222        /// Class weight strategy: uniform, inverse_freq, sqrt_inverse
223        #[arg(long)]
224        class_weights: Option<String>,
225
226        /// Target modules: qv, qkv, all_linear
227        #[arg(long)]
228        target_modules: Option<String>,
229    },
230
231    /// Execute a training plan (load model, run trials, save checkpoints)
232    Apply {
233        /// Path to plan YAML/JSON (generated by `finetune plan`)
234        #[arg(long)]
235        plan: PathBuf,
236
237        /// Path to model weights directory
238        #[arg(long)]
239        model_path: PathBuf,
240
241        /// Path to training data (JSONL)
242        #[arg(long)]
243        data: PathBuf,
244
245        /// Output directory for checkpoints and leaderboard
246        #[arg(short, long, default_value = "./output")]
247        output_dir: PathBuf,
248    },
249}
250
251/// Arguments for the experiments command
252#[derive(Parser, Debug, Clone, PartialEq)]
253pub struct ExperimentsArgs {
254    /// Subcommand to execute
255    #[command(subcommand)]
256    pub command: ExperimentsCommand,
257
258    /// Project directory (defaults to current directory)
259    #[arg(short, long, global = true, default_value = ".")]
260    pub project: PathBuf,
261
262    /// Output format (text, json)
263    #[arg(short, long, global = true, default_value = "text")]
264    pub format: OutputFormat,
265}
266
267/// Experiment store subcommands
268#[derive(Subcommand, Debug, Clone, PartialEq)]
269pub enum ExperimentsCommand {
270    /// List all experiments
271    List,
272
273    /// Show details of a specific experiment
274    Show {
275        /// Experiment ID
276        #[arg(value_name = "ID")]
277        id: String,
278    },
279
280    /// List runs for an experiment
281    Runs {
282        /// Experiment ID
283        #[arg(value_name = "EXPERIMENT_ID")]
284        experiment_id: String,
285    },
286
287    /// Show metrics for a run
288    Metrics {
289        /// Run ID
290        #[arg(value_name = "RUN_ID")]
291        run_id: String,
292
293        /// Metric key (e.g., "loss", "accuracy")
294        #[arg(value_name = "KEY")]
295        key: String,
296    },
297
298    /// Delete an experiment and all its runs
299    Delete {
300        /// Experiment ID
301        #[arg(value_name = "ID")]
302        id: String,
303    },
304}
305
#[cfg(test)]
mod tests {
    // Parser-level tests for every command type in this file. Types whose
    // top-level `Command` variant is exercised elsewhere go through
    // `parse_args`. The publish/finetune/experiments types are parsed
    // directly with `Parser::try_parse_from` on their own args struct.
    use super::*;
    use crate::config::cli::parse_args;
    use clap::Parser;

    #[test]
    fn test_parse_completion_command() {
        let cli = parse_args(["entrenar", "completion", "bash"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Completion(args) => {
                assert_eq!(args.shell, ShellType::Bash);
            }
            _ => panic!("Expected Completion command"),
        }
    }

    #[test]
    fn test_parse_bench_command() {
        let cli = parse_args(["entrenar", "bench", "model.gguf"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Bench(args) => {
                assert_eq!(args.input, PathBuf::from("model.gguf"));
                assert_eq!(args.warmup, 10);
                assert_eq!(args.iterations, 100);
                assert_eq!(args.batch_sizes, "1,8,32");
            }
            _ => panic!("Expected Bench command"),
        }
    }

    #[test]
    fn test_parse_bench_with_options() {
        let cli = parse_args([
            "entrenar",
            "bench",
            "model.gguf",
            "--warmup",
            "5",
            "--iterations",
            "50",
            "--batch-sizes",
            "1,2,4,8",
            "--format",
            "json",
        ])
        .expect("operation should succeed");
        match cli.command {
            crate::config::cli::Command::Bench(args) => {
                assert_eq!(args.warmup, 5);
                assert_eq!(args.iterations, 50);
                assert_eq!(args.batch_sizes, "1,2,4,8");
                assert_eq!(args.format, OutputFormat::Json);
            }
            _ => panic!("Expected Bench command"),
        }
    }

    #[test]
    fn test_parse_inspect_command() {
        let cli =
            parse_args(["entrenar", "inspect", "data.parquet"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Inspect(args) => {
                assert_eq!(args.input, PathBuf::from("data.parquet"));
                assert_eq!(args.mode, InspectMode::Summary);
                assert!((args.z_threshold - 3.0).abs() < 1e-6);
            }
            _ => panic!("Expected Inspect command"),
        }
    }

    #[test]
    fn test_parse_inspect_with_options() {
        let cli = parse_args([
            "entrenar",
            "inspect",
            "data.parquet",
            "--mode",
            "outliers",
            "--columns",
            "col1,col2",
            "--z-threshold",
            "2.5",
        ])
        .expect("operation should succeed");
        match cli.command {
            crate::config::cli::Command::Inspect(args) => {
                assert_eq!(args.mode, InspectMode::Outliers);
                assert_eq!(args.columns, Some("col1,col2".to_string()));
                assert!((args.z_threshold - 2.5).abs() < 1e-6);
            }
            _ => panic!("Expected Inspect command"),
        }
    }

    #[test]
    fn test_parse_audit_command() {
        let cli = parse_args(["entrenar", "audit", "model.gguf"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Audit(args) => {
                assert_eq!(args.input, PathBuf::from("model.gguf"));
                assert_eq!(args.audit_type, AuditType::Bias);
                assert!((args.threshold - 0.8).abs() < 1e-6);
            }
            _ => panic!("Expected Audit command"),
        }
    }

    #[test]
    fn test_parse_audit_with_options() {
        let cli = parse_args([
            "entrenar",
            "audit",
            "model.gguf",
            "--audit-type",
            "fairness",
            "--protected-attr",
            "gender",
            "--threshold",
            "0.9",
            "--format",
            "json",
        ])
        .expect("operation should succeed");
        match cli.command {
            crate::config::cli::Command::Audit(args) => {
                assert_eq!(args.audit_type, AuditType::Fairness);
                assert_eq!(args.protected_attr, Some("gender".to_string()));
                assert!((args.threshold - 0.9).abs() < 1e-6);
                assert_eq!(args.format, OutputFormat::Json);
            }
            _ => panic!("Expected Audit command"),
        }
    }

    #[test]
    fn test_parse_monitor_command() {
        let cli =
            parse_args(["entrenar", "monitor", "model.gguf"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Monitor(args) => {
                assert_eq!(args.input, PathBuf::from("model.gguf"));
                assert!((args.threshold - 0.2).abs() < 1e-6);
                assert_eq!(args.interval, 60);
            }
            _ => panic!("Expected Monitor command"),
        }
    }

    #[test]
    fn test_parse_monitor_with_options() {
        let cli = parse_args([
            "entrenar",
            "monitor",
            "model.gguf",
            "--baseline",
            "baseline.json",
            "--threshold",
            "0.3",
            "--interval",
            "120",
            "--format",
            "json",
        ])
        .expect("operation should succeed");
        match cli.command {
            crate::config::cli::Command::Monitor(args) => {
                assert_eq!(args.baseline, Some(PathBuf::from("baseline.json")));
                assert!((args.threshold - 0.3).abs() < 1e-6);
                assert_eq!(args.interval, 120);
                assert_eq!(args.format, OutputFormat::Json);
            }
            _ => panic!("Expected Monitor command"),
        }
    }

    // Additional coverage tests for derive traits

    #[test]
    fn test_completion_args_debug_clone() {
        let args = CompletionArgs { shell: ShellType::Bash };
        let debug = format!("{args:?}");
        assert!(debug.contains("CompletionArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_bench_args_debug_clone() {
        let args = BenchArgs {
            input: PathBuf::from("model.bin"),
            warmup: 5,
            iterations: 50,
            batch_sizes: "1,2,4".to_string(),
            format: OutputFormat::Text,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("BenchArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_inspect_args_debug_clone() {
        let args = InspectArgs {
            input: PathBuf::from("data.csv"),
            mode: InspectMode::Outliers,
            columns: Some("col1".to_string()),
            z_threshold: 2.5,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("InspectArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_audit_args_debug_clone() {
        let args = AuditArgs {
            input: PathBuf::from("model.bin"),
            audit_type: AuditType::Bias,
            protected_attr: Some("age".to_string()),
            threshold: 0.75,
            format: OutputFormat::Json,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("AuditArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_monitor_args_debug_clone() {
        let args = MonitorArgs {
            input: PathBuf::from("model.bin"),
            baseline: Some(PathBuf::from("base.json")),
            threshold: 0.25,
            interval: 30,
            format: OutputFormat::Text,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("MonitorArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_completion_other_shells() {
        // Test other shell types for coverage
        let cli = parse_args(["entrenar", "completion", "zsh"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Completion(args) => {
                assert_eq!(args.shell, ShellType::Zsh);
            }
            _ => panic!("Expected Completion command"),
        }

        let cli = parse_args(["entrenar", "completion", "fish"]).expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Completion(args) => {
                assert_eq!(args.shell, ShellType::Fish);
            }
            _ => panic!("Expected Completion command"),
        }
    }

    #[test]
    fn test_inspect_distribution_mode() {
        let cli = parse_args(["entrenar", "inspect", "data.csv", "--mode", "distribution"])
            .expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Inspect(args) => {
                assert_eq!(args.mode, InspectMode::Distribution);
            }
            _ => panic!("Expected Inspect command"),
        }
    }

    #[test]
    fn test_audit_privacy_security_types() {
        let cli = parse_args(["entrenar", "audit", "model.bin", "--audit-type", "privacy"])
            .expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Audit(args) => {
                assert_eq!(args.audit_type, AuditType::Privacy);
            }
            _ => panic!("Expected Audit command"),
        }

        let cli = parse_args(["entrenar", "audit", "model.bin", "--audit-type", "security"])
            .expect("parsing should succeed");
        match cli.command {
            crate::config::cli::Command::Audit(args) => {
                assert_eq!(args.audit_type, AuditType::Security);
            }
            _ => panic!("Expected Audit command"),
        }
    }

    // Publish / finetune / experiments coverage (parsed via their own args
    // structs, independent of how the top-level `Command` wraps them).

    #[test]
    fn test_parse_publish_defaults() {
        let args = PublishArgs::try_parse_from(["publish", "--repo", "myuser/my-model"])
            .expect("parsing should succeed");
        assert_eq!(args.model_dir, PathBuf::from("./output"));
        assert_eq!(args.repo, "myuser/my-model");
        assert!(!args.private);
        assert!(args.model_card);
        assert!(!args.merge_adapters);
        assert_eq!(args.base_model, None);
        assert_eq!(args.format, "safetensors");
        assert!(!args.dry_run);
    }

    #[test]
    fn test_parse_publish_with_options() {
        let args = PublishArgs::try_parse_from([
            "publish",
            "runs/best",
            "--repo",
            "myuser/my-model",
            "--private",
            "--merge-adapters",
            "--base-model",
            "Qwen/Qwen2.5-Coder-0.5B",
            "--format",
            "gguf",
            "--dry-run",
        ])
        .expect("parsing should succeed");
        assert_eq!(args.model_dir, PathBuf::from("runs/best"));
        assert_eq!(args.repo, "myuser/my-model");
        assert!(args.private);
        assert!(args.model_card);
        assert!(args.merge_adapters);
        assert_eq!(args.base_model.as_deref(), Some("Qwen/Qwen2.5-Coder-0.5B"));
        assert_eq!(args.format, "gguf");
        assert!(args.dry_run);
    }

    #[test]
    fn test_parse_publish_bare_model_card_flag() {
        // A bare `--model-card` must not swallow the positional MODEL_DIR.
        let args = PublishArgs::try_parse_from(["publish", "--model-card", "out", "--repo", "a/b"])
            .expect("parsing should succeed");
        assert!(args.model_card);
        assert_eq!(args.model_dir, PathBuf::from("out"));
    }

    #[test]
    fn test_parse_finetune_plan_defaults() {
        let args = FinetuneArgs::try_parse_from(["finetune", "plan", "--data", "train.jsonl"])
            .expect("parsing should succeed");
        match args.command {
            FinetuneCommand::Plan {
                data,
                model_path,
                model_size,
                num_classes,
                output_dir,
                strategy,
                budget,
                scout,
                max_epochs,
                lr,
                lora_rank,
                batch_size,
                ..
            } => {
                assert_eq!(data, PathBuf::from("train.jsonl"));
                assert_eq!(model_path, None);
                assert_eq!(model_size, "0.5B");
                assert_eq!(num_classes, 5);
                assert_eq!(output_dir, PathBuf::from("./output"));
                assert_eq!(strategy, "tpe");
                assert_eq!(budget, 20);
                assert!(!scout);
                assert_eq!(max_epochs, 10);
                assert_eq!(lr, None);
                assert_eq!(lora_rank, None);
                assert_eq!(batch_size, None);
            }
            FinetuneCommand::Apply { .. } => panic!("Expected Plan subcommand"),
        }
    }

    #[test]
    fn test_parse_finetune_plan_manual_overrides() {
        let args = FinetuneArgs::try_parse_from([
            "finetune",
            "plan",
            "--data",
            "train.jsonl",
            "--strategy",
            "manual",
            "--lr",
            "0.0001",
            "--lora-rank",
            "16",
            "--batch-size",
            "8",
            "--scout",
            "--class-weights",
            "inverse_freq",
            "--target-modules",
            "qkv",
        ])
        .expect("parsing should succeed");
        match args.command {
            FinetuneCommand::Plan {
                strategy,
                lr,
                lora_rank,
                batch_size,
                scout,
                class_weights,
                target_modules,
                ..
            } => {
                assert_eq!(strategy, "manual");
                let lr = lr.expect("lr should be set");
                assert!((lr - 1e-4).abs() < 1e-8);
                assert_eq!(lora_rank, Some(16));
                assert_eq!(batch_size, Some(8));
                assert!(scout);
                assert_eq!(class_weights.as_deref(), Some("inverse_freq"));
                assert_eq!(target_modules.as_deref(), Some("qkv"));
            }
            FinetuneCommand::Apply { .. } => panic!("Expected Plan subcommand"),
        }
    }

    #[test]
    fn test_parse_finetune_apply() {
        let args = FinetuneArgs::try_parse_from([
            "finetune",
            "apply",
            "--plan",
            "plan.yaml",
            "--model-path",
            "models/qwen",
            "--data",
            "train.jsonl",
            "-o",
            "runs",
        ])
        .expect("parsing should succeed");
        assert_eq!(
            args.command,
            FinetuneCommand::Apply {
                plan: PathBuf::from("plan.yaml"),
                model_path: PathBuf::from("models/qwen"),
                data: PathBuf::from("train.jsonl"),
                output_dir: PathBuf::from("runs"),
            }
        );
    }

    #[test]
    fn test_parse_experiments_defaults() {
        let args = ExperimentsArgs::try_parse_from(["experiments", "list"])
            .expect("parsing should succeed");
        assert_eq!(args.command, ExperimentsCommand::List);
        assert_eq!(args.project, PathBuf::from("."));
        assert_eq!(args.format, OutputFormat::Text);
    }

    #[test]
    fn test_parse_experiments_global_args_after_subcommand() {
        // `--project` and `--format` are global, so they are accepted after
        // the subcommand and still land on the parent struct.
        let args = ExperimentsArgs::try_parse_from([
            "experiments",
            "runs",
            "exp-1",
            "--project",
            "proj",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        assert_eq!(
            args.command,
            ExperimentsCommand::Runs {
                experiment_id: "exp-1".to_string(),
            }
        );
        assert_eq!(args.project, PathBuf::from("proj"));
        assert_eq!(args.format, OutputFormat::Json);
    }

    #[test]
    fn test_parse_experiments_metrics_positionals() {
        let args = ExperimentsArgs::try_parse_from(["experiments", "metrics", "run-1", "loss"])
            .expect("parsing should succeed");
        assert_eq!(
            args.command,
            ExperimentsCommand::Metrics {
                run_id: "run-1".to_string(),
                key: "loss".to_string(),
            }
        );
    }
}