#![allow(clippy::collapsible_if, clippy::collapsible_match)]
use proc_macro2::TokenStream;
use quote::{ToTokens, quote, quote_spanned};
use syn::{
Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsUnnamed,
GenericArgument, Ident, Lit, LitStr, Meta, PathArguments, Result, Type, Variant, parse2,
spanned::Spanned,
};
pub fn derive(input: TokenStream) -> Result<TokenStream> {
let input: DeriveInput = parse2(input)?;
match &input.data {
Data::Struct(data) => derive_struct(&input, data),
Data::Enum(data) => derive_enum(&input, data),
Data::Union(_) => Err(syn::Error::new_spanned(
&input,
"#[derive(Command)] is not supported on unions",
)),
}
}
fn derive_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream> {
let struct_ident = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let cmd_attrs = CommandAttrs::from_attrs(&input.attrs)?;
let struct_doc = collect_doc(&input.attrs);
let name = cmd_attrs.name.unwrap_or_else(|| struct_ident.to_string());
let description = cmd_attrs.description.or(struct_doc).unwrap_or_default();
let fields = match &data.fields {
Fields::Named(named) => &named.named,
Fields::Unit => {
return Ok(emit_struct_impl(
input,
struct_ident,
impl_generics,
ty_generics,
where_clause,
&name,
&description,
&[],
true,
));
}
Fields::Unnamed(_) => {
return Err(syn::Error::new_spanned(
&data.fields,
"#[derive(Command)] on structs requires named fields",
));
}
};
let mut collected: Vec<FieldSpec> = Vec::with_capacity(fields.len());
for field in fields {
collected.push(FieldSpec::from_field(field)?);
}
check_unique_option_aliases(&collected)?;
Ok(emit_struct_impl(
input,
struct_ident,
impl_generics,
ty_generics,
where_clause,
&name,
&description,
&collected,
false,
))
}
/// Renders the `::runi_cli::Command` implementation for a struct.
///
/// The generated `schema()` starts from `CommandSchema::new(name, description)`
/// and chains one builder call per entry in `fields`. The generated
/// `from_parsed()` builds `Self` from the parse result: a bare `Self` when
/// `is_unit` is true, otherwise a struct literal with one initializer per
/// field.
///
/// `input` is not read by the current implementation; the parameter is kept so
/// the signature stays unchanged for callers.
#[allow(clippy::too_many_arguments)]
fn emit_struct_impl(
    input: &DeriveInput,
    struct_ident: &Ident,
    impl_generics: syn::ImplGenerics,
    ty_generics: syn::TypeGenerics,
    where_clause: Option<&syn::WhereClause>,
    name: &str,
    description: &str,
    fields: &[FieldSpec],
    is_unit: bool,
) -> TokenStream {
    let _ = input;
    let schema_stmts = fields.iter().map(FieldSpec::schema_stmt);
    let ctor_stmts = fields.iter().map(FieldSpec::ctor_stmt);
    let ctor_expr = if is_unit {
        quote! { ::core::result::Result::Ok(Self) }
    } else {
        quote! {
            ::core::result::Result::Ok(Self {
                #( #ctor_stmts, )*
            })
        }
    };
    quote! {
        impl #impl_generics ::runi_cli::Command for #struct_ident #ty_generics #where_clause {
            fn schema() -> ::runi_cli::CommandSchema {
                let schema = ::runi_cli::CommandSchema::new(#name, #description);
                #( let schema = schema #schema_stmts; )*
                schema
            }
            fn from_parsed(
                p: &::runi_cli::ParseResult,
            ) -> ::runi_cli::Result<Self> {
                // Unit and field-less structs never read `p`; touching it here
                // keeps the expansion free of `unused_variables` warnings.
                let _ = p;
                #ctor_expr
            }
        }
    }
}
/// Expands the derive for an enum of subcommands.
///
/// Generates an inherent `register_on<__RegG>` method that takes a
/// `::runi_cli::Launcher<__RegG>` and chains one `.command::<Inner>(name)` or
/// `.command_with_description::<Inner>(name, desc)` call per variant, where
/// `Inner` is the single type wrapped by the variant.
///
/// # Errors
/// Fails when the enum-level `#[command(...)]` attribute is malformed, when
/// the enum has no variants, when a variant does not wrap exactly one type, or
/// when two variants resolve to the same subcommand name.
fn derive_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {
    let enum_ident = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Parsed for validation only: a malformed enum-level `#[command(...)]`
    // still errors, but its `name` / `description` values are not used by the
    // generated code.
    let _ = CommandAttrs::from_attrs(&input.attrs)?;

    if data.variants.is_empty() {
        return Err(syn::Error::new_spanned(
            &input.ident,
            "#[derive(Command)] on enums requires at least one variant",
        ));
    }

    let variants = data
        .variants
        .iter()
        .map(VariantSpec::from_variant)
        .collect::<Result<Vec<_>>>()?;
    check_unique_variant_names(&variants)?;

    // One builder call per variant, chained onto `launcher`.
    let registration_body = variants.iter().map(|v| {
        let name_lit = &v.name;
        let inner = &v.inner_ty;
        match &v.description {
            Some(desc) => {
                quote! { .command_with_description::<#inner>(#name_lit, #desc) }
            }
            None => quote! { .command::<#inner>(#name_lit) },
        }
    });

    // One where-clause predicate per wrapped type, spanned at that type so
    // unmet bounds are reported on the variant's payload.
    let parent_bounds = variants.iter().map(|v| {
        let inner = &v.inner_ty;
        quote_spanned! { inner.span() =>
            #inner: ::runi_cli::Command + ::runi_cli::SubCommandOf<__RegG> + 'static,
        }
    });

    Ok(quote! {
        impl #impl_generics #enum_ident #ty_generics #where_clause {
            pub fn register_on<__RegG>(
                launcher: ::runi_cli::Launcher<__RegG>,
            ) -> ::runi_cli::LauncherWithSubs<__RegG>
            where
                __RegG: ::runi_cli::Command + 'static,
                #( #parent_bounds )*
            {
                launcher #( #registration_body )*
            }
        }
    })
}
struct VariantSpec {
name: LitStr,
inner_ty: Type,
description: Option<String>,
}
impl VariantSpec {
fn from_variant(variant: &Variant) -> Result<Self> {
let inner_ty = match &variant.fields {
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
unnamed.first().unwrap().ty.clone()
}
_ => {
return Err(syn::Error::new_spanned(
&variant.fields,
"each variant must wrap exactly one struct — e.g. `Clone(CloneCmd)`",
));
}
};
let attrs = CommandAttrs::from_attrs(&variant.attrs)?;
let name = attrs
.name
.map(|s| LitStr::new(&s, variant.ident.span()))
.unwrap_or_else(|| {
LitStr::new(
&variant.ident.to_string().to_lowercase(),
variant.ident.span(),
)
});
let description = attrs.description.or_else(|| collect_doc(&variant.attrs));
Ok(Self {
name,
inner_ty,
description,
})
}
}
/// Values gathered from `#[command(...)]` attributes on an item.
struct CommandAttrs {
    /// Value of `name = "..."`, if present.
    name: Option<String>,
    /// Value of `description = "..."`, if present.
    description: Option<String>,
}

impl CommandAttrs {
    /// Scans `attrs` for `#[command(...)]` and records its `name` and
    /// `description` keys; other attributes are skipped. When a key appears
    /// more than once, the last value wins.
    ///
    /// # Errors
    /// Fails on any other key, or when a value is not a string literal.
    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
        let mut name: Option<String> = None;
        let mut description: Option<String> = None;

        let command_attrs = attrs.iter().filter(|a| a.path().is_ident("command"));
        for attr in command_attrs {
            attr.parse_nested_meta(|meta| {
                // Pick the destination first, then parse the value once.
                let slot = if meta.path.is_ident("name") {
                    &mut name
                } else if meta.path.is_ident("description") {
                    &mut description
                } else {
                    return Err(meta.error("unknown key; expected `name` or `description`"));
                };
                let value: Expr = meta.value()?.parse()?;
                *slot = Some(lit_string(&value)?);
                Ok(())
            })?;
        }

        Ok(Self { name, description })
    }
}
fn lit_string(expr: &Expr) -> Result<String> {
if let Expr::Lit(syn::ExprLit {
lit: Lit::Str(s), ..
}) = expr
{
Ok(s.value())
} else {
Err(syn::Error::new_spanned(expr, "expected a string literal"))
}
}
/// Joins the `#[doc = "..."]` attributes in `attrs` into one line.
///
/// Each doc string is trimmed, the pieces are joined with single spaces, and
/// the result is trimmed again. Returns `None` when nothing is left.
fn collect_doc(attrs: &[Attribute]) -> Option<String> {
    let lines: Vec<String> = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| match &attr.meta {
            Meta::NameValue(nv) => match &nv.value {
                Expr::Lit(syn::ExprLit {
                    lit: Lit::Str(text),
                    ..
                }) => Some(text.value().trim().to_string()),
                _ => None,
            },
            _ => None,
        })
        .collect();

    let text = lines.join(" ").trim().to_string();
    Some(text).filter(|t| !t.is_empty())
}
/// How a struct field is bound, as declared by its field attribute.
enum FieldKind {
/// Declared with `#[option("...")]`; `prefix` is the string literal holding
/// the comma-separated short and/or long forms (e.g. `"-v, --verbose"`).
Option { prefix: LitStr },
/// Declared with `#[argument]`.
Argument,
}
/// Classification of a field's declared type (see `classify_shape`).
enum FieldShape {
/// The type is `bool`.
Flag,
/// The type is `Option<..>` with exactly one generic argument.
Optional,
/// The type is `Vec<..>` with exactly one generic argument.
Vec,
/// Any other type.
Required,
}
struct FieldSpec {
ident: Ident,
ty: Type,
inner_ty: Type, shape: FieldShape,
kind: FieldKind,
description: String,
}
impl FieldSpec {
fn from_field(field: &Field) -> Result<Self> {
let ident = field
.ident
.clone()
.ok_or_else(|| syn::Error::new_spanned(field, "named fields are required"))?;
let ty = field.ty.clone();
let shape = classify_shape(&ty);
let inner_ty = inner_type(&ty, &shape);
let (kind, explicit_desc) = parse_field_attrs(&field.attrs)?;
if matches!(kind, FieldKind::Argument) && matches!(shape, FieldShape::Flag) {
return Err(syn::Error::new_spanned(
&field.ty,
"#[argument] on a bool is not meaningful; use #[option(...)] for flags",
));
}
if matches!(kind, FieldKind::Argument) && matches!(shape, FieldShape::Vec) {
return Err(syn::Error::new_spanned(
&field.ty,
"repeatable positional arguments are not supported; declare a repeatable \
option with #[option(\"...\")] on a Vec<T> field instead",
));
}
let doc = collect_doc(&field.attrs);
let description = explicit_desc.or(doc).unwrap_or_default();
Ok(Self {
ident,
ty,
inner_ty,
shape,
kind,
description,
})
}
fn argument_name(&self) -> String {
self.ident.to_string()
}
fn lookup_name(&self) -> LitStr {
match &self.kind {
FieldKind::Option { prefix } => {
let (short, long) = split_prefix(&prefix.value());
let chosen = long.or(short).unwrap_or_default();
LitStr::new(&chosen, prefix.span())
}
FieldKind::Argument => LitStr::new(&self.argument_name(), self.ident.span()),
}
}
fn schema_stmt(&self) -> TokenStream {
let desc = &self.description;
match (&self.kind, &self.shape) {
(FieldKind::Option { prefix }, FieldShape::Flag) => {
quote_spanned! { prefix.span() => .flag(#prefix, #desc) }
}
(FieldKind::Option { prefix }, _) => {
quote_spanned! { prefix.span() => .option(#prefix, #desc) }
}
(FieldKind::Argument, FieldShape::Optional) => {
let name = self.argument_name();
quote_spanned! { self.ident.span() => .optional_argument(#name, #desc) }
}
(FieldKind::Argument, FieldShape::Vec) => {
let msg = "repeatable positional arguments are not supported; use #[option(...)]";
quote_spanned! { self.ty.span() => .argument(#msg, #desc) }
}
(FieldKind::Argument, _) => {
let name = self.argument_name();
quote_spanned! { self.ident.span() => .argument(#name, #desc) }
}
}
}
fn ctor_stmt(&self) -> TokenStream {
let ident = &self.ident;
let lookup = self.lookup_name();
let inner = &self.inner_ty;
match (&self.kind, &self.shape) {
(FieldKind::Option { .. }, FieldShape::Flag) => {
quote_spanned! { ident.span() => #ident: p.flag(#lookup) }
}
(FieldKind::Option { .. }, FieldShape::Optional) => {
quote_spanned! { ident.span() => #ident: p.get::<#inner>(#lookup)? }
}
(FieldKind::Option { .. }, FieldShape::Vec) => {
quote_spanned! { ident.span() => #ident: p.all::<#inner>(#lookup)? }
}
(FieldKind::Option { .. }, FieldShape::Required) => {
quote_spanned! { ident.span() => #ident: p.require::<#inner>(#lookup)? }
}
(FieldKind::Argument, FieldShape::Optional) => {
quote_spanned! { ident.span() => #ident: p.get::<#inner>(#lookup)? }
}
(FieldKind::Argument, FieldShape::Vec) => {
quote_spanned! { ident.span() => #ident: ::core::default::Default::default() }
}
(FieldKind::Argument, FieldShape::Required) => {
quote_spanned! { ident.span() => #ident: p.require::<#inner>(#lookup)? }
}
(FieldKind::Argument, FieldShape::Flag) => {
quote_spanned! { ident.span() => #ident: ::core::default::Default::default() }
}
}
}
}
fn parse_field_attrs(attrs: &[Attribute]) -> Result<(FieldKind, Option<String>)> {
let mut found: Option<(FieldKind, Option<String>)> = None;
for attr in attrs {
let is_option = attr.path().is_ident("option");
let is_argument = attr.path().is_ident("argument");
if !is_option && !is_argument {
continue;
}
if found.is_some() {
return Err(syn::Error::new_spanned(
attr,
"field has multiple #[option] / #[argument] attributes",
));
}
let parser =
|stream: syn::parse::ParseStream<'_>| -> Result<(Option<LitStr>, Option<String>)> {
let mut prefix: Option<LitStr> = None;
let mut description: Option<String> = None;
if stream.peek(LitStr) {
prefix = Some(stream.parse()?);
if stream.peek(syn::Token![,]) {
let _: syn::Token![,] = stream.parse()?;
}
}
while !stream.is_empty() {
let key: Ident = stream.parse()?;
let _: syn::Token![=] = stream.parse()?;
let val: LitStr = stream.parse()?;
if key == "description" {
description = Some(val.value());
}
if stream.peek(syn::Token![,]) {
let _: syn::Token![,] = stream.parse()?;
}
}
Ok((prefix, description))
};
let (prefix, description) = if matches!(attr.meta, Meta::List(_)) {
attr.parse_args_with(parser)?
} else {
(None, None)
};
if is_option {
let prefix = prefix.ok_or_else(|| {
syn::Error::new_spanned(attr, "#[option(...)] requires a prefix string literal")
})?;
validate_option_prefix(&prefix)?;
found = Some((FieldKind::Option { prefix }, description));
} else {
if prefix.is_some() {
return Err(syn::Error::new_spanned(
attr,
"#[argument] does not take a prefix string",
));
}
found = Some((FieldKind::Argument, description));
}
}
found.ok_or_else(|| {
syn::Error::new_spanned(
attrs
.first()
.map(|a| a.to_token_stream())
.unwrap_or_else(|| quote! {}),
"field is missing #[option(\"...\")] or #[argument] — runi-cli cannot infer intent",
)
})
}
/// Classifies a field type: `bool` is a flag, `Option<..>` is optional,
/// `Vec<..>` is repeatable, and everything else is required.
fn classify_shape(ty: &Type) -> FieldShape {
    let wraps = |container: &str| outer_generic_ident(ty, container).is_some();

    if is_path_ident(ty, "bool") {
        FieldShape::Flag
    } else if wraps("Option") {
        FieldShape::Optional
    } else if wraps("Vec") {
        FieldShape::Vec
    } else {
        FieldShape::Required
    }
}
/// Returns the value type behind `ty` for the given shape.
///
/// For optional and repeatable shapes this is the first generic argument when
/// it is a type; in every other case it is a clone of `ty` itself.
fn inner_type(ty: &Type, shape: &FieldShape) -> Type {
    let unwrapped = match shape {
        FieldShape::Optional | FieldShape::Vec => first_generic_arg(ty),
        FieldShape::Flag | FieldShape::Required => None,
    };
    unwrapped.unwrap_or_else(|| ty.clone())
}
/// True when `ty` is a path type without a qualified self whose last segment
/// is `target` and carries no generic or parenthesized arguments.
fn is_path_ident(ty: &Type, target: &str) -> bool {
    let path = match ty {
        Type::Path(tp) if tp.qself.is_none() => &tp.path,
        _ => return false,
    };
    path.segments.last().map_or(false, |segment| {
        segment.ident == target && matches!(segment.arguments, PathArguments::None)
    })
}
/// Returns `Some(())` when `ty` is a path type without a qualified self whose
/// last segment is `target<..>` with exactly one angle-bracketed argument.
fn outer_generic_ident(ty: &Type, target: &str) -> Option<()> {
    let type_path = match ty {
        Type::Path(tp) if tp.qself.is_none() => tp,
        _ => return None,
    };
    let segment = type_path.path.segments.last()?;
    if segment.ident != target {
        return None;
    }
    match &segment.arguments {
        PathArguments::AngleBracketed(generics) if generics.args.len() == 1 => Some(()),
        _ => None,
    }
}
/// Returns a clone of the first angle-bracketed generic argument on the last
/// path segment of `ty`, provided that argument is a type.
fn first_generic_arg(ty: &Type) -> Option<Type> {
    let segment = match ty {
        Type::Path(tp) => tp.path.segments.last()?,
        _ => return None,
    };
    let generics = match &segment.arguments {
        PathArguments::AngleBracketed(generics) => generics,
        _ => return None,
    };
    match generics.args.first() {
        Some(GenericArgument::Type(inner)) => Some(inner.clone()),
        _ => None,
    }
}
fn validate_option_prefix(prefix: &LitStr) -> Result<()> {
let value = prefix.value();
let mut short_seen: Option<String> = None;
let mut long_seen: Option<String> = None;
for part in value.split(',').map(str::trim).filter(|s| !s.is_empty()) {
if let Some(rest) = part.strip_prefix("--") {
if rest.starts_with('-') {
return Err(syn::Error::new_spanned(
prefix,
format!("option prefix '{part}' has too many leading dashes"),
));
}
let first = rest.chars().next();
if !first.map(|c| c.is_ascii_alphabetic()).unwrap_or(false) {
return Err(syn::Error::new_spanned(
prefix,
format!("long option '{part}' must start with a letter"),
));
}
if let Some(prev) = &long_seen {
return Err(syn::Error::new_spanned(
prefix,
format!(
"option prefix has multiple long forms ('{prev}' and '{part}'); only one is supported",
),
));
}
long_seen = Some(part.to_string());
} else if let Some(rest) = part.strip_prefix('-') {
let first = rest.chars().next();
if !first.map(|c| c.is_ascii_alphabetic()).unwrap_or(false) {
return Err(syn::Error::new_spanned(
prefix,
format!("short option '{part}' must start with a letter"),
));
}
if let Some(prev) = &short_seen {
return Err(syn::Error::new_spanned(
prefix,
format!(
"option prefix has multiple short forms ('{prev}' and '{part}'); only one is supported",
),
));
}
short_seen = Some(part.to_string());
} else {
return Err(syn::Error::new_spanned(
prefix,
format!("option prefix '{part}' must start with - or --"),
));
}
}
if short_seen.is_none() && long_seen.is_none() {
return Err(syn::Error::new_spanned(
prefix,
"option prefix must contain at least one of -<short> or --<long>",
));
}
Ok(())
}
fn check_unique_option_aliases(fields: &[FieldSpec]) -> Result<()> {
use std::collections::HashMap;
let mut seen: HashMap<String, proc_macro2::Span> = HashMap::new();
for field in fields {
if let FieldKind::Option { prefix } = &field.kind {
let (short, long) = split_prefix(&prefix.value());
for alias in [short, long].into_iter().flatten() {
if let Some(prev) = seen.insert(alias.clone(), prefix.span()) {
let mut err =
syn::Error::new(prefix.span(), format!("duplicate option alias: {alias}"));
err.combine(syn::Error::new(prev, "previously declared here"));
return Err(err);
}
}
}
}
Ok(())
}
fn check_unique_variant_names(variants: &[VariantSpec]) -> Result<()> {
use std::collections::HashMap;
let mut seen: HashMap<String, proc_macro2::Span> = HashMap::new();
for v in variants {
let name = v.name.value();
if let Some(prev) = seen.insert(name.clone(), v.name.span()) {
let mut err =
syn::Error::new(v.name.span(), format!("duplicate subcommand name: {name}"));
err.combine(syn::Error::new(prev, "previously declared here"));
return Err(err);
}
}
Ok(())
}
/// Splits an option prefix such as `"-v, --verbose"` into its
/// `(short, long)` forms, dashes included.
///
/// Parts are comma-separated and trimmed; empty parts and parts that do not
/// start with `-` are skipped. When a form appears more than once, the last
/// occurrence is returned.
fn split_prefix(prefix: &str) -> (Option<String>, Option<String>) {
    prefix
        .split(',')
        .map(str::trim)
        .filter(|piece| !piece.is_empty())
        .fold((None, None), |(short, long), piece| {
            if piece.starts_with("--") {
                (short, Some(piece.to_owned()))
            } else if piece.starts_with('-') {
                (Some(piece.to_owned()), long)
            } else {
                (short, long)
            }
        })
}