use clap::Parser;
use crate::application::dto::OutputFormat;
use crate::i18n::Locale;
use crate::sbom_generation::domain::vulnerability::Severity;
// Command-line arguments for the `uv-sbom` binary, parsed with clap's derive API.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into `--help` text, and the help output here is meant to be driven
// only by the explicit attributes below. Field order is significant, because it
// determines the order of options in `--help`.
#[derive(Parser, Debug)]
#[command(
    name = "uv-sbom",
    version,
    about = "Generate SBOMs for Python projects managed by uv",
    long_about = None
)]
pub struct Args {
    // -f / --format; falls back to "json" when omitted.
    #[arg(short = 'f', long = "format", default_value = "json")]
    pub format: OutputFormat,

    // -p / --path; optional.
    #[arg(short = 'p', long = "path")]
    pub path: Option<String>,

    // -o / --output; optional.
    #[arg(short = 'o', long = "output")]
    pub output: Option<String>,

    // -e / --exclude PATTERN; may be given repeatedly (collected into a Vec).
    #[arg(short = 'e', long = "exclude", value_name = "PATTERN")]
    pub exclude: Vec<String>,

    #[arg(long = "dry-run")]
    pub dry_run: bool,

    // NOTE(review): no conflict is declared between --check-cve and
    // --no-check-cve, so both can be passed together. Confirm which one wins
    // downstream.
    #[arg(long = "check-cve")]
    pub check_cve: bool,

    #[arg(long = "no-check-cve")]
    pub no_check_cve: bool,

    // Shares the "threshold" arg group with --cvss-threshold. Clap arg groups
    // are mutually exclusive by default. Rejected together with --no-check-cve.
    #[arg(
        long = "severity-threshold",
        value_parser = parse_severity_threshold,
        group = "threshold",
        conflicts_with = "no_check_cve"
    )]
    pub severity_threshold: Option<Severity>,

    // Validated to lie in 0.0..=10.0 by `parse_cvss_threshold`. Same group and
    // conflict rules as --severity-threshold.
    #[arg(
        long = "cvss-threshold",
        value_parser = parse_cvss_threshold,
        group = "threshold",
        conflicts_with = "no_check_cve"
    )]
    pub cvss_threshold: Option<f32>,

    #[arg(long = "suggest-fix", conflicts_with = "no_check_cve")]
    pub suggest_fix: bool,

    #[arg(long = "verify-links")]
    pub verify_links: bool,

    // -c / --config PATH; optional.
    #[arg(short = 'c', long = "config", value_name = "PATH")]
    pub config: Option<String>,

    // -i / --ignore-cve CVE_ID; may be given repeatedly.
    #[arg(short = 'i', long = "ignore-cve", value_name = "CVE_ID")]
    pub ignore_cve: Vec<String>,

    #[arg(long = "check-license")]
    pub check_license: bool,

    // Comma-separated list; only accepted together with --check-license.
    #[arg(long = "license-allow", value_delimiter = ',', requires = "check_license")]
    pub license_allow: Vec<String>,

    // Comma-separated list; only accepted together with --check-license.
    #[arg(long = "license-deny", value_delimiter = ',', requires = "check_license")]
    pub license_deny: Vec<String>,

    #[arg(long = "init")]
    pub init: bool,

    // Parsed by `parse_lang`; defaults to "en".
    #[arg(long = "lang", default_value = "en", value_parser = parse_lang)]
    pub lang: Locale,
}
/// Parses the `--lang` command-line value into a [`Locale`].
///
/// Recognition is delegated entirely to `Locale::from_str`. Per the unit
/// tests, matching is case-sensitive (`"EN"` is rejected).
///
/// # Errors
///
/// Returns a user-facing message naming the rejected value when
/// `Locale::from_str` does not recognise it.
fn parse_lang(s: &str) -> Result<Locale, String> {
    match Locale::from_str(s) {
        Some(locale) => Ok(locale),
        None => Err(format!(
            "Invalid language: '{}'. Supported languages: en, ja",
            s
        )),
    }
}
/// Parses the `--severity-threshold` value into a [`Severity`].
///
/// The input is lowercased before matching, so `"HIGH"`, `"High"` and
/// `"high"` are all accepted.
///
/// # Errors
///
/// Returns a message listing the valid values when the input is not one of
/// `low`, `medium`, `high` or `critical`.
fn parse_severity_threshold(s: &str) -> Result<Severity, String> {
    let normalized = s.to_lowercase();
    let severity = match normalized.as_str() {
        "low" => Severity::Low,
        "medium" => Severity::Medium,
        "high" => Severity::High,
        "critical" => Severity::Critical,
        _ => {
            // Echo the caller's original spelling, not the lowercased copy.
            return Err(format!(
                "Invalid severity: {}. Valid values: low, medium, high, critical",
                s
            ));
        }
    };
    Ok(severity)
}
/// Parses the `--cvss-threshold` value as an `f32` CVSS score.
///
/// # Errors
///
/// - `"CVSS threshold must be a number"` when `f32` parsing fails.
/// - `"CVSS threshold must be between 0.0 and 10.0"` when the parsed value lies
///   outside the inclusive range `0.0..=10.0`. This also covers `NaN` and the
///   infinities, which `f32` parsing itself accepts.
fn parse_cvss_threshold(s: &str) -> Result<f32, String> {
    s.parse::<f32>()
        .map_err(|_| String::from("CVSS threshold must be a number"))
        .and_then(|score| {
            if (0.0..=10.0).contains(&score) {
                Ok(score)
            } else {
                Err(String::from("CVSS threshold must be between 0.0 and 10.0"))
            }
        })
}
// Unit tests for the value parsers and for the clap wiring of `Args`
// (groups, conflicts, requirements, delimiters, defaults).
#[cfg(test)]
mod tests {
    use super::*;
    // Explicit imports take precedence over the glob above. `CommandFactory`
    // is needed for `Args::command()`.
    use clap::{CommandFactory, Parser};

    #[test]
    fn test_parse_lang_valid() {
        assert_eq!(parse_lang("en").unwrap(), Locale::En);
        assert_eq!(parse_lang("ja").unwrap(), Locale::Ja);
    }

    #[test]
    fn test_parse_lang_invalid() {
        let result = parse_lang("fr");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("Invalid language: 'fr'. Supported languages: en, ja"));
        let result = parse_lang("EN");
        assert!(result.is_err());
        let result = parse_lang("");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_severity_threshold_valid() {
        assert_eq!(parse_severity_threshold("low").unwrap(), Severity::Low);
        assert_eq!(
            parse_severity_threshold("medium").unwrap(),
            Severity::Medium
        );
        assert_eq!(parse_severity_threshold("high").unwrap(), Severity::High);
        assert_eq!(
            parse_severity_threshold("critical").unwrap(),
            Severity::Critical
        );
    }

    #[test]
    fn test_parse_severity_threshold_case_insensitive() {
        assert_eq!(parse_severity_threshold("LOW").unwrap(), Severity::Low);
        assert_eq!(
            parse_severity_threshold("Medium").unwrap(),
            Severity::Medium
        );
        assert_eq!(parse_severity_threshold("HIGH").unwrap(), Severity::High);
        assert_eq!(
            parse_severity_threshold("CRITICAL").unwrap(),
            Severity::Critical
        );
    }

    #[test]
    fn test_parse_severity_threshold_invalid() {
        let result = parse_severity_threshold("none");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid severity"));
        let result = parse_severity_threshold("unknown");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_cvss_threshold_valid() {
        assert_eq!(parse_cvss_threshold("0.0").unwrap(), 0.0);
        assert_eq!(parse_cvss_threshold("5.5").unwrap(), 5.5);
        assert_eq!(parse_cvss_threshold("10.0").unwrap(), 10.0);
        assert_eq!(parse_cvss_threshold("7").unwrap(), 7.0);
    }

    #[test]
    fn test_parse_cvss_threshold_invalid_range() {
        let result = parse_cvss_threshold("-1.0");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("between 0.0 and 10.0"));
        let result = parse_cvss_threshold("11.0");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("between 0.0 and 10.0"));
    }

    #[test]
    fn test_parse_cvss_threshold_invalid_format() {
        let result = parse_cvss_threshold("abc");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("must be a number"));
        let result = parse_cvss_threshold("");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("must be a number"));
    }

    #[test]
    fn test_parse_cvss_threshold_rejects_non_finite() {
        // `f32::from_str` accepts these spellings, so they must be caught by
        // the range check rather than by the number parse.
        for input in ["nan", "NaN", "inf", "-inf"] {
            let err = parse_cvss_threshold(input).unwrap_err();
            assert!(err.contains("between 0.0 and 10.0"), "{}: {}", input, err);
        }
    }

    #[test]
    fn test_cli_definition_is_valid() {
        // Clap's own consistency checks: unknown ids in conflicts_with/requires,
        // duplicate shorts/longs, invalid default values, etc.
        Args::command().debug_assert();
    }

    #[test]
    fn test_cli_defaults() {
        let args = Args::try_parse_from(["uv-sbom"]).unwrap();
        assert_eq!(args.lang, Locale::En);
        assert!(args.exclude.is_empty());
        assert!(args.ignore_cve.is_empty());
        assert!(args.severity_threshold.is_none());
        assert!(args.cvss_threshold.is_none());
    }

    #[test]
    fn test_cli_threshold_flags_are_mutually_exclusive() {
        let result = Args::try_parse_from([
            "uv-sbom",
            "--severity-threshold",
            "high",
            "--cvss-threshold",
            "7.0",
        ]);
        assert!(result.is_err());
    }

    #[test]
    fn test_cli_cve_options_conflict_with_no_check_cve() {
        assert!(Args::try_parse_from([
            "uv-sbom",
            "--no-check-cve",
            "--severity-threshold",
            "low"
        ])
        .is_err());
        assert!(
            Args::try_parse_from(["uv-sbom", "--no-check-cve", "--cvss-threshold", "5"]).is_err()
        );
        assert!(Args::try_parse_from(["uv-sbom", "--no-check-cve", "--suggest-fix"]).is_err());
    }

    #[test]
    fn test_cli_license_lists_require_check_license() {
        assert!(Args::try_parse_from(["uv-sbom", "--license-allow", "MIT"]).is_err());
        assert!(Args::try_parse_from(["uv-sbom", "--license-deny", "GPL-3.0"]).is_err());

        let args = Args::try_parse_from([
            "uv-sbom",
            "--check-license",
            "--license-allow",
            "MIT,Apache-2.0",
            "--license-deny",
            "GPL-3.0",
        ])
        .unwrap();
        assert_eq!(args.license_allow, vec!["MIT", "Apache-2.0"]);
        assert_eq!(args.license_deny, vec!["GPL-3.0"]);
    }

    #[test]
    fn test_cli_repeatable_options_accumulate() {
        let args = Args::try_parse_from([
            "uv-sbom",
            "-i",
            "CVE-2024-0001",
            "--ignore-cve",
            "CVE-2024-0002",
            "-e",
            "foo*",
            "--exclude",
            "bar*",
        ])
        .unwrap();
        assert_eq!(args.ignore_cve, vec!["CVE-2024-0001", "CVE-2024-0002"]);
        assert_eq!(args.exclude, vec!["foo*", "bar*"]);
    }
}