use crate::completion::{CompletionFunc, CompletionResult};
use crate::error::{Error, Result};
use std::collections::HashSet;
/// A parsed flag value, tagged with the kind of data it carries.
///
/// `Flag::parse_value` produces `String` for string, choice, file and
/// directory flags, `Int` for integer and range flags, and `StringSlice`
/// for both slice and array flags.
// `Eq` is intentionally not derived: the `Float` variant holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum FlagValue {
    /// Textual value.
    String(String),
    /// Boolean switch value.
    Bool(bool),
    /// Signed 64-bit integer value.
    Int(i64),
    /// 64-bit floating point value.
    Float(f64),
    /// List of textual values.
    StringSlice(Vec<String>),
}
impl FlagValue {
pub fn as_string(&self) -> Result<&String> {
match self {
Self::String(s) => Ok(s),
_ => Err(Error::flag_parsing("Flag value is not a string")),
}
}
pub fn as_bool(&self) -> Result<bool> {
match self {
Self::Bool(b) => Ok(*b),
_ => Err(Error::flag_parsing("Flag value is not a boolean")),
}
}
pub fn as_int(&self) -> Result<i64> {
match self {
Self::Int(i) => Ok(*i),
_ => Err(Error::flag_parsing("Flag value is not an integer")),
}
}
pub fn as_float(&self) -> Result<f64> {
match self {
Self::Float(f) => Ok(*f),
_ => Err(Error::flag_parsing("Flag value is not a float")),
}
}
pub fn as_string_slice(&self) -> Result<&Vec<String>> {
match self {
Self::StringSlice(v) => Ok(v),
_ => Err(Error::flag_parsing("Flag value is not a string slice")),
}
}
}
/// A relationship between one flag and other flags, checked by
/// `Flag::validate_constraints` against the set of provided flag names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FlagConstraint {
    /// The owning flag must be provided whenever the named flag is provided.
    RequiredIf(String),
    /// The owning flag may not be provided together with any of the named flags.
    ConflictsWith(Vec<String>),
    /// When the owning flag is provided, every named flag must be provided too.
    Requires(Vec<String>),
}
/// Definition of a single command-line flag: its name, value type, default,
/// constraints and optional completion callback.
///
/// Instances are normally assembled with the builder-style methods on `Flag`
/// (`Flag::new(..).short(..).usage(..)` and so on).
pub struct Flag {
/// Name of the flag; attached to the errors produced by `parse_value`.
pub name: String,
/// Optional single-character short form.
/// NOTE(review): presumably matched as `-c` by the argument parser, which is
/// not part of this file.
pub short: Option<char>,
/// Usage text for the flag; empty unless set through `usage`.
pub usage: String,
/// Default value, if any.
/// NOTE(review): nothing in this file checks that it matches `value_type`.
pub default: Option<FlagValue>,
/// Whether the flag was marked as required via `required()`.
/// NOTE(review): enforcement happens outside this file.
pub required: bool,
/// Determines how `parse_value` interprets raw input for this flag.
pub value_type: FlagType,
/// Cross-flag rules evaluated by `validate_constraints`, in insertion order.
pub constraints: Vec<FlagConstraint>,
/// Optional completion callback. It is NOT carried over by `Clone`: the
/// manual `Clone` impl resets this field to `None`.
pub completion: Option<CompletionFunc>,
}
/// The kind of value a flag accepts; drives `Flag::parse_value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FlagType {
    /// Input is taken verbatim as text.
    String,
    /// Accepts `true/t/1/yes/y` and `false/f/0/no/n`, compared after lowercasing.
    Bool,
    /// Input must parse as an `i64`.
    Int,
    /// Input must parse as an `f64`.
    Float,
    /// Each parsed input becomes a one-element string list.
    StringSlice,
    /// Parsed exactly like [`FlagType::StringSlice`].
    StringArray,
    /// Input must exactly equal one of the listed choices.
    Choice(Vec<String>),
    /// Input must parse as an `i64` within `(min, max)`, both bounds inclusive.
    Range(i64, i64),
    /// Input must name an existing regular file.
    File,
    /// Input must name an existing directory.
    Directory,
}
impl Flag {
#[must_use]
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
short: None,
usage: String::new(),
default: None,
required: false,
value_type: FlagType::String,
constraints: Vec::new(),
completion: None,
}
}
#[must_use]
pub fn bool(name: impl Into<String>) -> Self {
Self::new(name).value_type(FlagType::Bool)
}
#[must_use]
pub fn int(name: impl Into<String>) -> Self {
Self::new(name).value_type(FlagType::Int)
}
#[must_use]
pub fn float(name: impl Into<String>) -> Self {
Self::new(name).value_type(FlagType::Float)
}
#[must_use]
pub fn string(name: impl Into<String>) -> Self {
Self::new(name) }
#[must_use]
pub fn string_slice(name: impl Into<String>) -> Self {
Self::new(name).value_type(FlagType::StringSlice)
}
#[must_use]
pub fn choice(name: impl Into<String>, choices: &[&str]) -> Self {
let choices: Vec<String> = choices.iter().map(|&s| s.to_string()).collect();
Self::new(name).value_type(FlagType::Choice(choices))
}
#[must_use]
pub fn range(name: impl Into<String>, min: i64, max: i64) -> Self {
Self::new(name).value_type(FlagType::Range(min, max))
}
#[must_use]
pub fn file(name: impl Into<String>) -> Self {
Self::new(name).value_type(FlagType::File)
}
#[must_use]
pub fn directory(name: impl Into<String>) -> Self {
Self::new(name).value_type(FlagType::Directory)
}
#[must_use]
pub const fn short(mut self, short: char) -> Self {
self.short = Some(short);
self
}
#[must_use]
pub fn usage(mut self, usage: impl Into<String>) -> Self {
self.usage = usage.into();
self
}
#[must_use]
pub fn default(mut self, value: FlagValue) -> Self {
self.default = Some(value);
self
}
#[must_use]
pub fn default_bool(self, value: bool) -> Self {
self.default(FlagValue::Bool(value))
}
#[must_use]
pub fn default_str(self, value: &str) -> Self {
self.default(FlagValue::String(value.to_string()))
}
#[must_use]
pub fn default_int(self, value: i64) -> Self {
self.default(FlagValue::Int(value))
}
#[must_use]
pub fn default_float(self, value: f64) -> Self {
self.default(FlagValue::Float(value))
}
#[must_use]
pub const fn required(mut self) -> Self {
self.required = true;
self
}
#[must_use]
pub fn value_type(mut self, value_type: FlagType) -> Self {
self.value_type = value_type;
self
}
#[must_use]
pub fn constraint(mut self, constraint: FlagConstraint) -> Self {
self.constraints.push(constraint);
self
}
#[must_use]
pub fn completion<F>(mut self, completion: F) -> Self
where
F: Fn(&crate::Context, &str) -> Result<CompletionResult> + Send + Sync + 'static,
{
self.completion = Some(Box::new(completion));
self
}
pub fn parse_value(&self, input: &str) -> Result<FlagValue> {
match &self.value_type {
FlagType::String => Ok(FlagValue::String(input.to_string())),
FlagType::Bool => match input.to_lowercase().as_str() {
"true" | "t" | "1" | "yes" | "y" => Ok(FlagValue::Bool(true)),
"false" | "f" | "0" | "no" | "n" => Ok(FlagValue::Bool(false)),
_ => Err(Error::flag_parsing_with_suggestions(
format!("Invalid boolean value: '{input}'"),
self.name.clone(),
vec![
"true, false".to_string(),
"yes, no".to_string(),
"1, 0".to_string(),
],
)),
},
FlagType::Int => input.parse::<i64>().map(FlagValue::Int).map_err(|_| {
Error::flag_parsing_with_suggestions(
format!("Invalid integer value: '{input}'"),
self.name.clone(),
vec!["a whole number (e.g., 42, -10, 0)".to_string()],
)
}),
FlagType::Float => input.parse::<f64>().map(FlagValue::Float).map_err(|_| {
Error::flag_parsing_with_suggestions(
format!("Invalid float value: '{input}'"),
self.name.clone(),
vec!["a decimal number (e.g., 3.14, -0.5, 1e10)".to_string()],
)
}),
FlagType::StringSlice | FlagType::StringArray => {
Ok(FlagValue::StringSlice(vec![input.to_string()]))
}
FlagType::Choice(choices) => {
if choices.contains(&input.to_string()) {
Ok(FlagValue::String(input.to_string()))
} else {
Err(Error::flag_parsing_with_suggestions(
format!("Invalid choice: '{input}'"),
self.name.clone(),
choices.clone(),
))
}
}
FlagType::Range(min, max) => {
let value = input.parse::<i64>().map_err(|_| {
Error::flag_parsing_with_suggestions(
format!("Invalid integer value: '{input}'"),
self.name.clone(),
vec![format!("a number between {min} and {max}")],
)
})?;
if value >= *min && value <= *max {
Ok(FlagValue::Int(value))
} else {
Err(Error::flag_parsing_with_suggestions(
format!("Value {value} is out of range"),
self.name.clone(),
vec![format!("a number between {min} and {max} (inclusive)")],
))
}
}
FlagType::File => Self::parse_path(input, &self.name, PathKind::File),
FlagType::Directory => Self::parse_path(input, &self.name, PathKind::Directory),
}
}
fn parse_path(input: &str, name: &str, kind: PathKind) -> Result<FlagValue> {
use std::path::Path;
let path = Path::new(input);
if path.exists() && kind.matches(path) {
Ok(FlagValue::String(input.to_string()))
} else if !path.exists() {
Err(Error::flag_parsing_with_suggestions(
format!("{} not found: '{input}'", kind.capitalized()),
name.to_string(),
vec![kind.not_found_suggestion().to_string()],
))
} else {
Err(Error::flag_parsing_with_suggestions(
format!("Path exists but is not a {}: '{input}'", kind.lowercase()),
name.to_string(),
vec![kind.wrong_kind_suggestion().to_string()],
))
}
}
pub fn validate_constraints(
&self,
flag_name: &str,
provided_flags: &HashSet<String>,
) -> Result<()> {
for constraint in &self.constraints {
match constraint {
FlagConstraint::RequiredIf(other_flag) => {
if provided_flags.contains(other_flag) && !provided_flags.contains(flag_name) {
return Err(Error::flag_parsing_with_suggestions(
format!(
"Flag '--{flag_name}' is required when '--{other_flag}' is set"
),
flag_name.to_string(),
vec![format!("add --{flag_name} <value>")],
));
}
}
FlagConstraint::ConflictsWith(conflicting_flags) => {
if provided_flags.contains(flag_name) {
for conflict in conflicting_flags {
if provided_flags.contains(conflict) {
return Err(Error::flag_parsing_with_suggestions(
format!("Flag '--{flag_name}' conflicts with '--{conflict}'"),
flag_name.to_string(),
vec![format!(
"use either --{flag_name} or --{conflict}, not both"
)],
));
}
}
}
}
FlagConstraint::Requires(required_flags) => {
if provided_flags.contains(flag_name) {
for required in required_flags {
if !provided_flags.contains(required) {
return Err(Error::flag_parsing_with_suggestions(
format!(
"Flag '--{flag_name}' requires '--{required}' to be set"
),
flag_name.to_string(),
vec![format!("add --{required} <value>")],
));
}
}
}
}
}
}
Ok(())
}
}
impl Clone for Flag {
    /// Copies every field except the completion callback, which is reset to
    /// `None` in the clone.
    fn clone(&self) -> Self {
        // Destructure so that adding a field to `Flag` forces this impl to be
        // revisited instead of silently skipping the new field.
        let Self {
            name,
            short,
            usage,
            default,
            required,
            value_type,
            constraints,
            completion: _,
        } = self;

        Self {
            name: name.clone(),
            short: *short,
            usage: usage.clone(),
            default: default.clone(),
            required: *required,
            value_type: value_type.clone(),
            constraints: constraints.clone(),
            completion: None,
        }
    }
}
/// The kind of filesystem entry a path-valued flag expects.
#[derive(Copy, Clone)]
enum PathKind {
    File,
    Directory,
}

impl PathKind {
    /// Reports whether `path` currently refers to an entry of this kind.
    fn matches(self, path: &std::path::Path) -> bool {
        let probe: fn(&std::path::Path) -> bool = match self {
            PathKind::File => std::path::Path::is_file,
            PathKind::Directory => std::path::Path::is_dir,
        };
        probe(path)
    }

    /// Kind name with a leading capital, for the start of an error message.
    const fn capitalized(self) -> &'static str {
        match self {
            PathKind::Directory => "Directory",
            PathKind::File => "File",
        }
    }

    /// Kind name in lowercase, for use mid-sentence.
    const fn lowercase(self) -> &'static str {
        match self {
            PathKind::Directory => "directory",
            PathKind::File => "file",
        }
    }

    /// Suggestion shown when the path does not exist at all.
    const fn not_found_suggestion(self) -> &'static str {
        match self {
            PathKind::Directory => "path to an existing directory",
            PathKind::File => "path to an existing file",
        }
    }

    /// Suggestion shown when the path exists but is of the other kind.
    const fn wrong_kind_suggestion(self) -> &'static str {
        match self {
            PathKind::Directory => "path to a directory (not a file)",
            PathKind::File => "path to a regular file (not a directory)",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Deliberate test datum: the value expected from parsing the literal
    // "3.14", not an approximation of π — hence the lint is silenced.
    #[allow(clippy::approx_constant)]
    const PI: f64 = 3.14;

    /// Each accessor succeeds on its own variant and fails on another one.
    #[test]
    fn test_flag_value_conversions() {
        let string_val = FlagValue::String("hello".to_string());
        assert_eq!(string_val.as_string().unwrap(), "hello");
        assert!(string_val.as_bool().is_err());

        let bool_val = FlagValue::Bool(true);
        assert!(bool_val.as_bool().unwrap());
        assert!(bool_val.as_string().is_err());

        let int_val = FlagValue::Int(42);
        assert_eq!(int_val.as_int().unwrap(), 42);
        assert!(int_val.as_float().is_err());

        let float_val = FlagValue::Float(PI);
        assert!((float_val.as_float().unwrap() - PI).abs() < f64::EPSILON);
        assert!(float_val.as_int().is_err());

        let slice_val = FlagValue::StringSlice(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(
            slice_val.as_string_slice().unwrap(),
            &vec!["a".to_string(), "b".to_string()]
        );
        assert!(slice_val.as_string().is_err());
    }

    /// String, bool, int and float flags parse valid input and reject garbage.
    #[test]
    fn test_flag_parsing() {
        let string_flag = Flag::new("name").value_type(FlagType::String);
        assert_eq!(
            string_flag.parse_value("test").unwrap(),
            FlagValue::String("test".to_string())
        );

        let bool_flag = Flag::new("verbose").value_type(FlagType::Bool);
        assert_eq!(
            bool_flag.parse_value("true").unwrap(),
            FlagValue::Bool(true)
        );
        assert_eq!(
            bool_flag.parse_value("false").unwrap(),
            FlagValue::Bool(false)
        );
        assert_eq!(bool_flag.parse_value("1").unwrap(), FlagValue::Bool(true));
        assert_eq!(bool_flag.parse_value("0").unwrap(), FlagValue::Bool(false));
        assert_eq!(bool_flag.parse_value("yes").unwrap(), FlagValue::Bool(true));
        assert_eq!(bool_flag.parse_value("no").unwrap(), FlagValue::Bool(false));
        assert!(bool_flag.parse_value("invalid").is_err());

        let int_flag = Flag::new("count").value_type(FlagType::Int);
        assert_eq!(int_flag.parse_value("42").unwrap(), FlagValue::Int(42));
        assert_eq!(int_flag.parse_value("-10").unwrap(), FlagValue::Int(-10));
        assert!(int_flag.parse_value("not_a_number").is_err());

        let float_flag = Flag::new("ratio").value_type(FlagType::Float);
        assert_eq!(
            float_flag.parse_value("3.14").unwrap(),
            FlagValue::Float(PI)
        );
        assert_eq!(
            float_flag.parse_value("-2.5").unwrap(),
            FlagValue::Float(-2.5)
        );
        assert!(float_flag.parse_value("not_a_float").is_err());
    }

    /// Builder methods store exactly what they were given.
    #[test]
    fn test_flag_builder() {
        let flag = Flag::new("verbose")
            .short('v')
            .usage("Enable verbose output")
            .default(FlagValue::Bool(false))
            .value_type(FlagType::Bool);

        assert_eq!(flag.name, "verbose");
        assert_eq!(flag.short, Some('v'));
        assert_eq!(flag.usage, "Enable verbose output");
        assert_eq!(flag.default, Some(FlagValue::Bool(false)));
        assert!(!flag.required);
    }

    /// Choice flags accept listed values only.
    #[test]
    fn test_choice_flag() {
        let choice_flag = Flag::new("environment").value_type(FlagType::Choice(vec![
            "dev".to_string(),
            "staging".to_string(),
            "prod".to_string(),
        ]));

        assert_eq!(
            choice_flag.parse_value("dev").unwrap(),
            FlagValue::String("dev".to_string())
        );
        assert_eq!(
            choice_flag.parse_value("staging").unwrap(),
            FlagValue::String("staging".to_string())
        );
        assert!(choice_flag.parse_value("test").is_err());
    }

    /// Range flags accept both bounds and reject values outside them.
    #[test]
    fn test_range_flag() {
        let range_flag = Flag::new("port").value_type(FlagType::Range(1024, 65535));

        assert_eq!(
            range_flag.parse_value("8080").unwrap(),
            FlagValue::Int(8080)
        );
        assert_eq!(
            range_flag.parse_value("1024").unwrap(),
            FlagValue::Int(1024)
        );
        assert_eq!(
            range_flag.parse_value("65535").unwrap(),
            FlagValue::Int(65535)
        );
        assert!(range_flag.parse_value("80").is_err());
        assert!(range_flag.parse_value("70000").is_err());
        assert!(range_flag.parse_value("not_a_number").is_err());
    }

    /// File flags accept an existing file and reject a missing path.
    #[test]
    fn test_file_flag() {
        // Keep the fixture out of the working tree and make its name unique
        // per process so concurrent test runs cannot collide.
        let temp_path =
            std::env::temp_dir().join(format!("test_file_flag_{}.tmp", std::process::id()));
        std::fs::write(&temp_path, "test\n").unwrap();
        let temp_file = temp_path
            .to_str()
            .expect("temporary path should be valid UTF-8");

        let file_flag = Flag::new("config").value_type(FlagType::File);
        let parsed = file_flag.parse_value(temp_file);
        let missing = file_flag.parse_value("nonexistent.file");

        // Remove the fixture before asserting so a failed assertion cannot
        // leave it behind.
        std::fs::remove_file(&temp_path).unwrap();

        assert_eq!(parsed.unwrap(), FlagValue::String(temp_file.to_string()));
        assert!(missing.is_err());
    }

    /// Directory flags accept existing directories and reject a missing path.
    #[test]
    fn test_directory_flag() {
        let dir_flag = Flag::new("output").value_type(FlagType::Directory);

        assert_eq!(
            dir_flag.parse_value(".").unwrap(),
            FlagValue::String(".".to_string())
        );
        // NOTE(review): relies on the tests running from a directory that
        // contains `src` (presumably the crate root) — confirm.
        assert_eq!(
            dir_flag.parse_value("src").unwrap(),
            FlagValue::String("src".to_string())
        );
        assert!(dir_flag.parse_value("nonexistent_directory").is_err());
    }

    /// String-array flags parse one occurrence into a one-element list.
    #[test]
    fn test_string_array_flag() {
        let array_flag = Flag::new("tags").value_type(FlagType::StringArray);

        assert_eq!(
            array_flag.parse_value("tag1").unwrap(),
            FlagValue::StringSlice(vec!["tag1".to_string()])
        );
    }

    /// Exercises all three constraint kinds in both passing and failing states.
    #[test]
    fn test_flag_constraints() {
        let mut provided_flags = HashSet::new();

        // RequiredIf: only violated once the other flag is present and this
        // one is not.
        let ssl_flag = Flag::new("ssl").constraint(FlagConstraint::RequiredIf("port".to_string()));
        assert!(
            ssl_flag
                .validate_constraints("ssl", &provided_flags)
                .is_ok()
        );
        provided_flags.insert("port".to_string());
        assert!(
            ssl_flag
                .validate_constraints("ssl", &provided_flags)
                .is_err()
        );
        provided_flags.insert("ssl".to_string());
        assert!(
            ssl_flag
                .validate_constraints("ssl", &provided_flags)
                .is_ok()
        );

        // ConflictsWith: violated when both flags are present.
        let encrypt_flag = Flag::new("encrypt").constraint(FlagConstraint::ConflictsWith(vec![
            "no-encrypt".to_string(),
        ]));
        provided_flags.clear();
        provided_flags.insert("encrypt".to_string());
        assert!(
            encrypt_flag
                .validate_constraints("encrypt", &provided_flags)
                .is_ok()
        );
        provided_flags.insert("no-encrypt".to_string());
        assert!(
            encrypt_flag
                .validate_constraints("encrypt", &provided_flags)
                .is_err()
        );

        // Requires: violated until the dependency is present too.
        let output_flag =
            Flag::new("output").constraint(FlagConstraint::Requires(vec!["format".to_string()]));
        provided_flags.clear();
        provided_flags.insert("output".to_string());
        assert!(
            output_flag
                .validate_constraints("output", &provided_flags)
                .is_err()
        );
        provided_flags.insert("format".to_string());
        assert!(
            output_flag
                .validate_constraints("output", &provided_flags)
                .is_ok()
        );
    }
}