1use clap::{Parser, Subcommand};
4use std::path::PathBuf;
5
6use super::types::{AuditType, InspectMode, OutputFormat, ShellType};
7
8#[derive(Parser, Debug, Clone, PartialEq)]
10pub struct CompletionArgs {
11 #[arg(value_name = "SHELL")]
13 pub shell: ShellType,
14}
15
16#[derive(Parser, Debug, Clone, PartialEq)]
18pub struct BenchArgs {
19 #[arg(value_name = "INPUT")]
21 pub input: PathBuf,
22
23 #[arg(long, default_value = "10")]
25 pub warmup: usize,
26
27 #[arg(long, default_value = "100")]
29 pub iterations: usize,
30
31 #[arg(long, default_value = "1,8,32")]
33 pub batch_sizes: String,
34
35 #[arg(short, long, default_value = "text")]
37 pub format: OutputFormat,
38}
39
40#[derive(Parser, Debug, Clone, PartialEq)]
42pub struct InspectArgs {
43 #[arg(value_name = "INPUT")]
45 pub input: PathBuf,
46
47 #[arg(short, long, default_value = "summary")]
49 pub mode: InspectMode,
50
51 #[arg(long)]
53 pub columns: Option<String>,
54
55 #[arg(long, default_value = "3.0")]
57 pub z_threshold: f32,
58}
59
60#[derive(Parser, Debug, Clone, PartialEq)]
62pub struct AuditArgs {
63 #[arg(value_name = "INPUT")]
65 pub input: PathBuf,
66
67 #[arg(short, long, default_value = "bias")]
69 pub audit_type: AuditType,
70
71 #[arg(long)]
73 pub protected_attr: Option<String>,
74
75 #[arg(long, default_value = "0.8")]
77 pub threshold: f32,
78
79 #[arg(short, long, default_value = "text")]
81 pub format: OutputFormat,
82}
83
84#[derive(Parser, Debug, Clone, PartialEq)]
86pub struct MonitorArgs {
87 #[arg(value_name = "INPUT")]
89 pub input: PathBuf,
90
91 #[arg(long)]
93 pub baseline: Option<PathBuf>,
94
95 #[arg(long, default_value = "0.2")]
97 pub threshold: f32,
98
99 #[arg(long, default_value = "60")]
101 pub interval: u64,
102
103 #[arg(short, long, default_value = "text")]
105 pub format: OutputFormat,
106}
107
108#[allow(clippy::struct_excessive_bools)]
110#[derive(Parser, Debug, Clone, PartialEq)]
111pub struct PublishArgs {
112 #[arg(value_name = "MODEL_DIR", default_value = "./output")]
114 pub model_dir: PathBuf,
115
116 #[arg(long)]
118 pub repo: String,
119
120 #[arg(long)]
122 pub private: bool,
123
124 #[arg(long, default_value_t = true)]
126 pub model_card: bool,
127
128 #[arg(long)]
130 pub merge_adapters: bool,
131
132 #[arg(long)]
134 pub base_model: Option<String>,
135
136 #[arg(long, default_value = "safetensors")]
138 pub format: String,
139
140 #[arg(long)]
142 pub dry_run: bool,
143}
144
145#[derive(Parser, Debug, Clone, PartialEq)]
147pub struct FinetuneArgs {
148 #[command(subcommand)]
150 pub command: FinetuneCommand,
151}
152
153#[derive(Subcommand, Debug, Clone, PartialEq)]
155pub enum FinetuneCommand {
156 Plan {
158 #[arg(long)]
160 data: PathBuf,
161
162 #[arg(long)]
164 model_path: Option<PathBuf>,
165
166 #[arg(long, default_value = "0.5B")]
168 model_size: String,
169
170 #[arg(long, default_value = "5")]
172 num_classes: usize,
173
174 #[arg(short, long, default_value = "./output")]
176 output_dir: PathBuf,
177
178 #[arg(long, default_value = "tpe")]
180 strategy: String,
181
182 #[arg(long, default_value = "20")]
184 budget: usize,
185
186 #[arg(long)]
188 scout: bool,
189
190 #[arg(long, default_value = "10")]
192 max_epochs: usize,
193
194 #[arg(long)]
196 lr: Option<f32>,
197
198 #[arg(long)]
200 lora_rank: Option<usize>,
201
202 #[arg(long)]
204 batch_size: Option<usize>,
205
206 #[arg(long)]
208 lora_alpha: Option<f32>,
209
210 #[arg(long)]
212 warmup: Option<f32>,
213
214 #[arg(long)]
216 gradient_clip: Option<f32>,
217
218 #[arg(long)]
220 lr_min_ratio: Option<f32>,
221
222 #[arg(long)]
224 class_weights: Option<String>,
225
226 #[arg(long)]
228 target_modules: Option<String>,
229 },
230
231 Apply {
233 #[arg(long)]
235 plan: PathBuf,
236
237 #[arg(long)]
239 model_path: PathBuf,
240
241 #[arg(long)]
243 data: PathBuf,
244
245 #[arg(short, long, default_value = "./output")]
247 output_dir: PathBuf,
248 },
249}
250
251#[derive(Parser, Debug, Clone, PartialEq)]
253pub struct ExperimentsArgs {
254 #[command(subcommand)]
256 pub command: ExperimentsCommand,
257
258 #[arg(short, long, global = true, default_value = ".")]
260 pub project: PathBuf,
261
262 #[arg(short, long, global = true, default_value = "text")]
264 pub format: OutputFormat,
265}
266
267#[derive(Subcommand, Debug, Clone, PartialEq)]
269pub enum ExperimentsCommand {
270 List,
272
273 Show {
275 #[arg(value_name = "ID")]
277 id: String,
278 },
279
280 Runs {
282 #[arg(value_name = "EXPERIMENT_ID")]
284 experiment_id: String,
285 },
286
287 Metrics {
289 #[arg(value_name = "RUN_ID")]
291 run_id: String,
292
293 #[arg(value_name = "KEY")]
295 key: String,
296 },
297
298 Delete {
300 #[arg(value_name = "ID")]
302 id: String,
303 },
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    // Imported once instead of spelling out `crate::config::cli::Command` in
    // every match arm.
    use crate::config::cli::{parse_args, Command};

    #[test]
    fn test_parse_completion_command() {
        let cli = parse_args(["entrenar", "completion", "bash"]).expect("parsing should succeed");
        match cli.command {
            Command::Completion(args) => {
                assert_eq!(args.shell, ShellType::Bash);
            }
            _ => panic!("Expected Completion command"),
        }
    }

    #[test]
    fn test_parse_bench_command() {
        let cli = parse_args(["entrenar", "bench", "model.gguf"]).expect("parsing should succeed");
        match cli.command {
            Command::Bench(args) => {
                assert_eq!(args.input, PathBuf::from("model.gguf"));
                assert_eq!(args.warmup, 10);
                assert_eq!(args.iterations, 100);
                assert_eq!(args.batch_sizes, "1,8,32");
                // Pin the `default_value = "text"` default as well.
                assert_eq!(args.format, OutputFormat::Text);
            }
            _ => panic!("Expected Bench command"),
        }
    }

    #[test]
    fn test_parse_bench_with_options() {
        let cli = parse_args([
            "entrenar",
            "bench",
            "model.gguf",
            "--warmup",
            "5",
            "--iterations",
            "50",
            "--batch-sizes",
            "1,2,4,8",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        match cli.command {
            Command::Bench(args) => {
                assert_eq!(args.warmup, 5);
                assert_eq!(args.iterations, 50);
                assert_eq!(args.batch_sizes, "1,2,4,8");
                assert_eq!(args.format, OutputFormat::Json);
            }
            _ => panic!("Expected Bench command"),
        }
    }

    #[test]
    fn test_bench_requires_input() {
        // INPUT is a positional with no default, so omitting it must fail.
        assert!(parse_args(["entrenar", "bench"]).is_err());
    }

    #[test]
    fn test_parse_inspect_command() {
        let cli =
            parse_args(["entrenar", "inspect", "data.parquet"]).expect("parsing should succeed");
        match cli.command {
            Command::Inspect(args) => {
                assert_eq!(args.input, PathBuf::from("data.parquet"));
                assert_eq!(args.mode, InspectMode::Summary);
                assert!(args.columns.is_none());
                assert!((args.z_threshold - 3.0).abs() < 1e-6);
            }
            _ => panic!("Expected Inspect command"),
        }
    }

    #[test]
    fn test_parse_inspect_with_options() {
        let cli = parse_args([
            "entrenar",
            "inspect",
            "data.parquet",
            "--mode",
            "outliers",
            "--columns",
            "col1,col2",
            "--z-threshold",
            "2.5",
        ])
        .expect("parsing should succeed");
        match cli.command {
            Command::Inspect(args) => {
                assert_eq!(args.mode, InspectMode::Outliers);
                assert_eq!(args.columns, Some("col1,col2".to_string()));
                assert!((args.z_threshold - 2.5).abs() < 1e-6);
            }
            _ => panic!("Expected Inspect command"),
        }
    }

    #[test]
    fn test_parse_audit_command() {
        let cli = parse_args(["entrenar", "audit", "model.gguf"]).expect("parsing should succeed");
        match cli.command {
            Command::Audit(args) => {
                assert_eq!(args.input, PathBuf::from("model.gguf"));
                assert_eq!(args.audit_type, AuditType::Bias);
                assert!(args.protected_attr.is_none());
                assert!((args.threshold - 0.8).abs() < 1e-6);
                assert_eq!(args.format, OutputFormat::Text);
            }
            _ => panic!("Expected Audit command"),
        }
    }

    #[test]
    fn test_parse_audit_with_options() {
        let cli = parse_args([
            "entrenar",
            "audit",
            "model.gguf",
            "--audit-type",
            "fairness",
            "--protected-attr",
            "gender",
            "--threshold",
            "0.9",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        match cli.command {
            Command::Audit(args) => {
                assert_eq!(args.audit_type, AuditType::Fairness);
                assert_eq!(args.protected_attr, Some("gender".to_string()));
                assert!((args.threshold - 0.9).abs() < 1e-6);
                assert_eq!(args.format, OutputFormat::Json);
            }
            _ => panic!("Expected Audit command"),
        }
    }

    #[test]
    fn test_audit_rejects_unknown_type() {
        assert!(
            parse_args(["entrenar", "audit", "model.bin", "--audit-type", "bogus"]).is_err()
        );
    }

    #[test]
    fn test_parse_monitor_command() {
        let cli =
            parse_args(["entrenar", "monitor", "model.gguf"]).expect("parsing should succeed");
        match cli.command {
            Command::Monitor(args) => {
                assert_eq!(args.input, PathBuf::from("model.gguf"));
                assert!(args.baseline.is_none());
                assert!((args.threshold - 0.2).abs() < 1e-6);
                assert_eq!(args.interval, 60);
                assert_eq!(args.format, OutputFormat::Text);
            }
            _ => panic!("Expected Monitor command"),
        }
    }

    #[test]
    fn test_parse_monitor_with_options() {
        let cli = parse_args([
            "entrenar",
            "monitor",
            "model.gguf",
            "--baseline",
            "baseline.json",
            "--threshold",
            "0.3",
            "--interval",
            "120",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        match cli.command {
            Command::Monitor(args) => {
                assert_eq!(args.baseline, Some(PathBuf::from("baseline.json")));
                assert!((args.threshold - 0.3).abs() < 1e-6);
                assert_eq!(args.interval, 120);
                assert_eq!(args.format, OutputFormat::Json);
            }
            _ => panic!("Expected Monitor command"),
        }
    }

    #[test]
    fn test_completion_args_debug_clone() {
        let args = CompletionArgs { shell: ShellType::Bash };
        let debug = format!("{args:?}");
        assert!(debug.contains("CompletionArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_bench_args_debug_clone() {
        let args = BenchArgs {
            input: PathBuf::from("model.bin"),
            warmup: 5,
            iterations: 50,
            batch_sizes: "1,2,4".to_string(),
            format: OutputFormat::Text,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("BenchArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_inspect_args_debug_clone() {
        let args = InspectArgs {
            input: PathBuf::from("data.csv"),
            mode: InspectMode::Outliers,
            columns: Some("col1".to_string()),
            z_threshold: 2.5,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("InspectArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_audit_args_debug_clone() {
        let args = AuditArgs {
            input: PathBuf::from("model.bin"),
            audit_type: AuditType::Bias,
            protected_attr: Some("age".to_string()),
            threshold: 0.75,
            format: OutputFormat::Json,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("AuditArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_monitor_args_debug_clone() {
        let args = MonitorArgs {
            input: PathBuf::from("model.bin"),
            baseline: Some(PathBuf::from("base.json")),
            threshold: 0.25,
            interval: 30,
            format: OutputFormat::Text,
        };
        let debug = format!("{args:?}");
        assert!(debug.contains("MonitorArgs"));

        let cloned = args.clone();
        assert_eq!(args, cloned);
    }

    #[test]
    fn test_completion_other_shells() {
        let cli = parse_args(["entrenar", "completion", "zsh"]).expect("parsing should succeed");
        match cli.command {
            Command::Completion(args) => {
                assert_eq!(args.shell, ShellType::Zsh);
            }
            _ => panic!("Expected Completion command"),
        }

        let cli = parse_args(["entrenar", "completion", "fish"]).expect("parsing should succeed");
        match cli.command {
            Command::Completion(args) => {
                assert_eq!(args.shell, ShellType::Fish);
            }
            _ => panic!("Expected Completion command"),
        }
    }

    #[test]
    fn test_inspect_distribution_mode() {
        let cli = parse_args(["entrenar", "inspect", "data.csv", "--mode", "distribution"])
            .expect("parsing should succeed");
        match cli.command {
            Command::Inspect(args) => {
                assert_eq!(args.mode, InspectMode::Distribution);
            }
            _ => panic!("Expected Inspect command"),
        }
    }

    #[test]
    fn test_audit_privacy_security_types() {
        let cli = parse_args(["entrenar", "audit", "model.bin", "--audit-type", "privacy"])
            .expect("parsing should succeed");
        match cli.command {
            Command::Audit(args) => {
                assert_eq!(args.audit_type, AuditType::Privacy);
            }
            _ => panic!("Expected Audit command"),
        }

        let cli = parse_args(["entrenar", "audit", "model.bin", "--audit-type", "security"])
            .expect("parsing should succeed");
        match cli.command {
            Command::Audit(args) => {
                assert_eq!(args.audit_type, AuditType::Security);
            }
            _ => panic!("Expected Audit command"),
        }
    }
}
609}