use clap::{Parser, Subcommand};
use std::path::PathBuf;
use super::types::{AuditType, InspectMode, OutputFormat, ShellType};
// Arguments for `entrenar completion <SHELL>`.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct CompletionArgs {
    // Single required positional, parsed into `ShellType` (e.g. `bash`, `zsh`, `fish`).
    #[arg(value_name = "SHELL", required = true)]
    pub shell: ShellType,
}
// Arguments for `entrenar bench <INPUT>`.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct BenchArgs {
    // Required positional path to the artifact being benchmarked.
    #[arg(value_name = "INPUT", required = true)]
    pub input: PathBuf,
    // `--warmup` count (default 10); presumably un-timed runs before measuring — confirm.
    #[arg(long = "warmup", default_value = "10")]
    pub warmup: usize,
    // `--iterations` count (default 100).
    #[arg(long = "iterations", default_value = "100")]
    pub iterations: usize,
    // Comma-separated batch sizes kept as one raw string (default "1,8,32");
    // splitting/validation is left to the consumer of these args.
    #[arg(long = "batch-sizes", default_value = "1,8,32")]
    pub batch_sizes: String,
    // Output format via `-f` / `--format` (default "text").
    #[arg(short = 'f', long = "format", default_value = "text")]
    pub format: OutputFormat,
}
// Arguments for `entrenar inspect <INPUT>`.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct InspectArgs {
    // Required positional path to the data file to inspect.
    #[arg(value_name = "INPUT", required = true)]
    pub input: PathBuf,
    // Inspection mode via `-m` / `--mode` (default "summary").
    #[arg(short = 'm', long = "mode", default_value = "summary")]
    pub mode: InspectMode,
    // Optional column selection, kept as one raw string (e.g. "col1,col2").
    #[arg(long = "columns")]
    pub columns: Option<String>,
    // `--z-threshold` (default 3.0); presumably the outlier z-score cutoff — confirm.
    #[arg(long = "z-threshold", default_value = "3.0")]
    pub z_threshold: f32,
}
// Arguments for `entrenar audit <INPUT>`.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct AuditArgs {
    // Required positional path to the artifact being audited.
    #[arg(value_name = "INPUT", required = true)]
    pub input: PathBuf,
    // Audit kind via `-a` / `--audit-type` (default "bias").
    #[arg(short = 'a', long = "audit-type", default_value = "bias")]
    pub audit_type: AuditType,
    // Optional protected attribute name (e.g. "gender").
    #[arg(long = "protected-attr")]
    pub protected_attr: Option<String>,
    // `--threshold` (default 0.8); its interpretation depends on the audit kind — confirm.
    #[arg(long = "threshold", default_value = "0.8")]
    pub threshold: f32,
    // Output format via `-f` / `--format` (default "text").
    #[arg(short = 'f', long = "format", default_value = "text")]
    pub format: OutputFormat,
}
// Arguments for `entrenar monitor <INPUT>`.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct MonitorArgs {
    // Required positional path to the artifact being monitored.
    #[arg(value_name = "INPUT", required = true)]
    pub input: PathBuf,
    // Optional baseline file to compare against.
    #[arg(long = "baseline")]
    pub baseline: Option<PathBuf>,
    // `--threshold` (default 0.2).
    #[arg(long = "threshold", default_value = "0.2")]
    pub threshold: f32,
    // `--interval` (default 60); unit presumably seconds — confirm in the handler.
    #[arg(long = "interval", default_value = "60")]
    pub interval: u64,
    // Output format via `-f` / `--format` (default "text").
    #[arg(short = 'f', long = "format", default_value = "text")]
    pub format: OutputFormat,
}
// Arguments for publishing a model directory to a repository.
#[derive(Parser, Debug, Clone, PartialEq)]
pub struct PublishArgs {
    // Positional model directory; optional, defaults to "./output".
    #[arg(value_name = "MODEL_DIR", default_value = "./output")]
    pub model_dir: PathBuf,
    // Destination repository identifier; required (no default).
    #[arg(long)]
    pub repo: String,
    // `--private` switch (off unless given).
    #[arg(long)]
    pub private: bool,
    // Whether to include a model card; on by default.
    //
    // A bare `bool` with `#[arg(long)]` gets clap's `ArgAction::SetTrue`, so pairing it
    // with `default_value_t = true` made the option impossible to turn off (absent ->
    // true, present -> true). Taking an optional `=BOOL` value fixes that while keeping
    // the old spelling working: `--model-card` still means true, and
    // `--model-card=false` now disables it. `require_equals` stops a following
    // positional (e.g. MODEL_DIR) from being consumed as this option's value.
    #[arg(
        long,
        value_name = "BOOL",
        default_value_t = true,
        default_missing_value = "true",
        num_args = 0..=1,
        require_equals = true,
        action = clap::ArgAction::Set
    )]
    pub model_card: bool,
    // `--merge-adapters` switch (off unless given).
    #[arg(long)]
    pub merge_adapters: bool,
    // Optional base model identifier.
    #[arg(long)]
    pub base_model: Option<String>,
    // Free-form output format string (default "safetensors"); not validated here.
    #[arg(long, default_value = "safetensors")]
    pub format: String,
    // `--dry-run` switch (off unless given).
    #[arg(long)]
    pub dry_run: bool,
}
// Thin wrapper: every finetune option lives on the `plan` / `apply` subcommand.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct FinetuneArgs {
    // The finetune step to run.
    #[command(subcommand)]
    pub command: FinetuneCommand,
}
#[derive(Subcommand, Debug, Clone, PartialEq)]
pub enum FinetuneCommand {
Plan {
#[arg(long)]
data: PathBuf,
#[arg(long)]
model_path: Option<PathBuf>,
#[arg(long, default_value = "0.5B")]
model_size: String,
#[arg(long, default_value = "5")]
num_classes: usize,
#[arg(short, long, default_value = "./output")]
output_dir: PathBuf,
#[arg(long, default_value = "tpe")]
strategy: String,
#[arg(long, default_value = "20")]
budget: usize,
#[arg(long)]
scout: bool,
#[arg(long, default_value = "10")]
max_epochs: usize,
#[arg(long)]
lr: Option<f32>,
#[arg(long)]
lora_rank: Option<usize>,
#[arg(long)]
batch_size: Option<usize>,
#[arg(long)]
lora_alpha: Option<f32>,
#[arg(long)]
warmup: Option<f32>,
#[arg(long)]
gradient_clip: Option<f32>,
#[arg(long)]
lr_min_ratio: Option<f32>,
#[arg(long)]
class_weights: Option<String>,
#[arg(long)]
target_modules: Option<String>,
},
Apply {
#[arg(long)]
plan: PathBuf,
#[arg(long)]
model_path: PathBuf,
#[arg(long)]
data: PathBuf,
#[arg(short, long, default_value = "./output")]
output_dir: PathBuf,
},
}
// Experiment-browsing arguments: a subcommand plus two options marked `global`, so
// they may also be given after the subcommand.
#[derive(Debug, Clone, PartialEq, Parser)]
pub struct ExperimentsArgs {
    #[command(subcommand)]
    pub command: ExperimentsCommand,
    // `-p` / `--project`, default ".".
    #[arg(short = 'p', long = "project", global = true, default_value = ".")]
    pub project: PathBuf,
    // `-f` / `--format`, default "text".
    #[arg(short = 'f', long = "format", global = true, default_value = "text")]
    pub format: OutputFormat,
}
#[derive(Subcommand, Debug, Clone, PartialEq)]
pub enum ExperimentsCommand {
List,
Show {
#[arg(value_name = "ID")]
id: String,
},
Runs {
#[arg(value_name = "EXPERIMENT_ID")]
experiment_id: String,
},
Metrics {
#[arg(value_name = "RUN_ID")]
run_id: String,
#[arg(value_name = "KEY")]
key: String,
},
Delete {
#[arg(value_name = "ID")]
id: String,
},
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::cli::{parse_args, Command};

    /// Asserts that a parsed `f32` equals `expected` within a small absolute tolerance,
    /// reporting both values on failure.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_parse_completion_command() {
        let cli = parse_args(["entrenar", "completion", "bash"]).expect("parsing should succeed");
        let Command::Completion(args) = cli.command else {
            panic!("Expected Completion command");
        };
        assert_eq!(args.shell, ShellType::Bash);
    }

    #[test]
    fn test_parse_bench_command() {
        let cli = parse_args(["entrenar", "bench", "model.gguf"]).expect("parsing should succeed");
        let Command::Bench(args) = cli.command else {
            panic!("Expected Bench command");
        };
        assert_eq!(args.input, PathBuf::from("model.gguf"));
        assert_eq!(args.warmup, 10);
        assert_eq!(args.iterations, 100);
        assert_eq!(args.batch_sizes, "1,8,32");
        assert_eq!(args.format, OutputFormat::Text);
    }

    #[test]
    fn test_parse_bench_with_options() {
        let cli = parse_args([
            "entrenar",
            "bench",
            "model.gguf",
            "--warmup",
            "5",
            "--iterations",
            "50",
            "--batch-sizes",
            "1,2,4,8",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        let Command::Bench(args) = cli.command else {
            panic!("Expected Bench command");
        };
        assert_eq!(args.warmup, 5);
        assert_eq!(args.iterations, 50);
        assert_eq!(args.batch_sizes, "1,2,4,8");
        assert_eq!(args.format, OutputFormat::Json);
    }

    #[test]
    fn test_parse_inspect_command() {
        let cli =
            parse_args(["entrenar", "inspect", "data.parquet"]).expect("parsing should succeed");
        let Command::Inspect(args) = cli.command else {
            panic!("Expected Inspect command");
        };
        assert_eq!(args.input, PathBuf::from("data.parquet"));
        assert_eq!(args.mode, InspectMode::Summary);
        assert_eq!(args.columns, None);
        assert_close(args.z_threshold, 3.0);
    }

    #[test]
    fn test_parse_inspect_with_options() {
        let cli = parse_args([
            "entrenar",
            "inspect",
            "data.parquet",
            "--mode",
            "outliers",
            "--columns",
            "col1,col2",
            "--z-threshold",
            "2.5",
        ])
        .expect("parsing should succeed");
        let Command::Inspect(args) = cli.command else {
            panic!("Expected Inspect command");
        };
        assert_eq!(args.mode, InspectMode::Outliers);
        assert_eq!(args.columns, Some("col1,col2".to_string()));
        assert_close(args.z_threshold, 2.5);
    }

    #[test]
    fn test_parse_audit_command() {
        let cli = parse_args(["entrenar", "audit", "model.gguf"]).expect("parsing should succeed");
        let Command::Audit(args) = cli.command else {
            panic!("Expected Audit command");
        };
        assert_eq!(args.input, PathBuf::from("model.gguf"));
        assert_eq!(args.audit_type, AuditType::Bias);
        assert_eq!(args.protected_attr, None);
        assert_close(args.threshold, 0.8);
        assert_eq!(args.format, OutputFormat::Text);
    }

    #[test]
    fn test_parse_audit_with_options() {
        let cli = parse_args([
            "entrenar",
            "audit",
            "model.gguf",
            "--audit-type",
            "fairness",
            "--protected-attr",
            "gender",
            "--threshold",
            "0.9",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        let Command::Audit(args) = cli.command else {
            panic!("Expected Audit command");
        };
        assert_eq!(args.audit_type, AuditType::Fairness);
        assert_eq!(args.protected_attr, Some("gender".to_string()));
        assert_close(args.threshold, 0.9);
        assert_eq!(args.format, OutputFormat::Json);
    }

    #[test]
    fn test_parse_monitor_command() {
        let cli =
            parse_args(["entrenar", "monitor", "model.gguf"]).expect("parsing should succeed");
        let Command::Monitor(args) = cli.command else {
            panic!("Expected Monitor command");
        };
        assert_eq!(args.input, PathBuf::from("model.gguf"));
        assert_eq!(args.baseline, None);
        assert_close(args.threshold, 0.2);
        assert_eq!(args.interval, 60);
        assert_eq!(args.format, OutputFormat::Text);
    }

    #[test]
    fn test_parse_monitor_with_options() {
        let cli = parse_args([
            "entrenar",
            "monitor",
            "model.gguf",
            "--baseline",
            "baseline.json",
            "--threshold",
            "0.3",
            "--interval",
            "120",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        let Command::Monitor(args) = cli.command else {
            panic!("Expected Monitor command");
        };
        assert_eq!(args.baseline, Some(PathBuf::from("baseline.json")));
        assert_close(args.threshold, 0.3);
        assert_eq!(args.interval, 120);
        assert_eq!(args.format, OutputFormat::Json);
    }

    #[test]
    fn test_completion_args_debug_clone() {
        let args = CompletionArgs { shell: ShellType::Bash };
        assert!(format!("{args:?}").contains("CompletionArgs"));
        assert_eq!(args, args.clone());
    }

    #[test]
    fn test_bench_args_debug_clone() {
        let args = BenchArgs {
            input: PathBuf::from("model.bin"),
            warmup: 5,
            iterations: 50,
            batch_sizes: "1,2,4".to_string(),
            format: OutputFormat::Text,
        };
        assert!(format!("{args:?}").contains("BenchArgs"));
        assert_eq!(args, args.clone());
    }

    #[test]
    fn test_inspect_args_debug_clone() {
        let args = InspectArgs {
            input: PathBuf::from("data.csv"),
            mode: InspectMode::Outliers,
            columns: Some("col1".to_string()),
            z_threshold: 2.5,
        };
        assert!(format!("{args:?}").contains("InspectArgs"));
        assert_eq!(args, args.clone());
    }

    #[test]
    fn test_audit_args_debug_clone() {
        let args = AuditArgs {
            input: PathBuf::from("model.bin"),
            audit_type: AuditType::Bias,
            protected_attr: Some("age".to_string()),
            threshold: 0.75,
            format: OutputFormat::Json,
        };
        assert!(format!("{args:?}").contains("AuditArgs"));
        assert_eq!(args, args.clone());
    }

    #[test]
    fn test_monitor_args_debug_clone() {
        let args = MonitorArgs {
            input: PathBuf::from("model.bin"),
            baseline: Some(PathBuf::from("base.json")),
            threshold: 0.25,
            interval: 30,
            format: OutputFormat::Text,
        };
        assert!(format!("{args:?}").contains("MonitorArgs"));
        assert_eq!(args, args.clone());
    }

    #[test]
    fn test_completion_other_shells() {
        for (name, expected) in [("zsh", ShellType::Zsh), ("fish", ShellType::Fish)] {
            let cli =
                parse_args(["entrenar", "completion", name]).expect("parsing should succeed");
            let Command::Completion(args) = cli.command else {
                panic!("Expected Completion command");
            };
            assert_eq!(args.shell, expected, "shell `{name}`");
        }
    }

    #[test]
    fn test_inspect_distribution_mode() {
        let cli = parse_args(["entrenar", "inspect", "data.csv", "--mode", "distribution"])
            .expect("parsing should succeed");
        let Command::Inspect(args) = cli.command else {
            panic!("Expected Inspect command");
        };
        assert_eq!(args.mode, InspectMode::Distribution);
    }

    #[test]
    fn test_audit_privacy_security_types() {
        for (name, expected) in [
            ("privacy", AuditType::Privacy),
            ("security", AuditType::Security),
        ] {
            let cli = parse_args(["entrenar", "audit", "model.bin", "--audit-type", name])
                .expect("parsing should succeed");
            let Command::Audit(args) = cli.command else {
                panic!("Expected Audit command");
            };
            assert_eq!(args.audit_type, expected, "audit type `{name}`");
        }
    }

    // The tests below parse the arg structs directly through `clap::Parser`, so they do
    // not depend on how the top-level `Command` enum wires these structs in.

    #[test]
    fn test_publish_defaults() {
        let args = PublishArgs::try_parse_from(["publish", "--repo", "org/model"])
            .expect("parsing should succeed");
        assert_eq!(args.model_dir, PathBuf::from("./output"));
        assert_eq!(args.repo, "org/model");
        assert!(!args.private);
        assert!(args.model_card);
        assert!(!args.merge_adapters);
        assert_eq!(args.base_model, None);
        assert_eq!(args.format, "safetensors");
        assert!(!args.dry_run);
    }

    #[test]
    fn test_finetune_plan_defaults() {
        let args = FinetuneArgs::try_parse_from(["finetune", "plan", "--data", "train.jsonl"])
            .expect("parsing should succeed");
        let FinetuneCommand::Plan {
            data,
            model_path,
            model_size,
            num_classes,
            output_dir,
            strategy,
            budget,
            scout,
            max_epochs,
            lr,
            ..
        } = args.command
        else {
            panic!("Expected Plan subcommand");
        };
        assert_eq!(data, PathBuf::from("train.jsonl"));
        assert_eq!(model_path, None);
        assert_eq!(model_size, "0.5B");
        assert_eq!(num_classes, 5);
        assert_eq!(output_dir, PathBuf::from("./output"));
        assert_eq!(strategy, "tpe");
        assert_eq!(budget, 20);
        assert!(!scout);
        assert_eq!(max_epochs, 10);
        assert_eq!(lr, None);
    }

    #[test]
    fn test_finetune_apply_command() {
        let args = FinetuneArgs::try_parse_from([
            "finetune",
            "apply",
            "--plan",
            "plan.yaml",
            "--model-path",
            "model",
            "--data",
            "train.jsonl",
        ])
        .expect("parsing should succeed");
        assert_eq!(
            args.command,
            FinetuneCommand::Apply {
                plan: PathBuf::from("plan.yaml"),
                model_path: PathBuf::from("model"),
                data: PathBuf::from("train.jsonl"),
                output_dir: PathBuf::from("./output"),
            }
        );
    }

    #[test]
    fn test_experiments_metrics_with_global_format() {
        let args = ExperimentsArgs::try_parse_from([
            "experiments",
            "metrics",
            "run-1",
            "loss",
            "--format",
            "json",
        ])
        .expect("parsing should succeed");
        assert_eq!(args.project, PathBuf::from("."));
        assert_eq!(args.format, OutputFormat::Json);
        assert_eq!(
            args.command,
            ExperimentsCommand::Metrics {
                run_id: "run-1".to_string(),
                key: "loss".to_string(),
            }
        );
    }
}