Skip to main content

argot_cmd_derive/
lib.rs

1#![forbid(unsafe_code)]
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields, LitChar, LitStr};
6
7/// Derive macro that implements [`argot_cmd::ArgotCommand`] for a struct.
8///
9/// Annotate a struct with `#[derive(ArgotCommand)]` to automatically implement
10/// `ArgotCommand::command()` using `#[argot(...)]` attributes on the struct
11/// and its fields.
12///
/// The generated `command()` implementation calls [`argot_cmd::Command::builder`]
/// and chains builder methods derived from the attributes, then calls
/// `.build().unwrap()`. The `unwrap` will **panic** at *call time* (not at
/// compile time) if the canonical name is empty. The default name (the
/// kebab-case struct name) is never empty, so in practice this can only
/// happen when the name is overridden explicitly with `canonical = ""`.
/// The argument or flag generated for each annotated field is likewise
/// built with `.build().unwrap()`.
19///
20/// ## Struct-level attributes (`#[argot(...)]`)
21///
22/// | Key | Type | Description |
23/// |-----|------|-------------|
24/// | `canonical = "name"` | string | Override the canonical command name. Default: struct name converted to kebab-case (e.g. `DeployApp` → `deploy-app`). |
25/// | `summary = "text"` | string | One-line summary. |
26/// | `description = "text"` | string | Long prose description. |
27/// | `alias = "a"` | string | Add an alias (repeat the attribute to add more). |
28/// | `best_practice = "text"` | string | Add a best-practice tip (repeatable). |
29/// | `anti_pattern = "text"` | string | Add an anti-pattern warning (repeatable). |
30///
31/// ## Field-level attributes (`#[argot(...)]`)
32///
33/// Fields **without** an `#[argot(...)]` attribute are skipped entirely.
34/// Every annotated field must include either `positional` or `flag`.
35///
36/// | Key | Description |
37/// |-----|-------------|
38/// | `positional` | Treat as a positional [`argot_cmd::Argument`]. |
39/// | `flag` | Treat as a named [`argot_cmd::Flag`]. |
40/// | `required` | Mark the argument or flag as required. |
41/// | `short = 'c'` | Short character for a flag (e.g. `short = 'v'`). |
42/// | `takes_value` | Flag consumes the next token as its value. |
43/// | `description = "text"` | Human-readable description. |
44/// | `default = "value"` | Default value string. |
45///
46/// ## Example
47///
48/// ```rust,ignore
49/// use argot_cmd::ArgotCommand;
50///
51/// #[derive(ArgotCommand)]
52/// #[argot(
53///     summary = "Deploy the application",
54///     alias = "d",
55///     best_practice = "always dry-run first"
56/// )]
57/// struct Deploy {
58///     #[argot(positional, required, description = "Target environment")]
59///     env: String,
60///
61///     #[argot(flag, short = 'n', description = "Simulate without changes")]
62///     dry_run: bool,
63/// }
64///
65/// let cmd = Deploy::command();
66/// assert_eq!(cmd.canonical, "deploy");
67/// assert_eq!(cmd.aliases, vec!["d"]);
68/// ```
69#[proc_macro_derive(ArgotCommand, attributes(argot))]
70pub fn derive_argot_command(input: TokenStream) -> TokenStream {
71    let input = parse_macro_input!(input as DeriveInput);
72    derive_impl(input)
73        .unwrap_or_else(|e| e.into_compile_error())
74        .into()
75}
76
77// ---------------------------------------------------------------------------
78// Attribute data structures
79// ---------------------------------------------------------------------------
80
81#[derive(Default)]
82struct StructAttrs {
83    canonical: Option<String>,
84    summary: Option<String>,
85    description: Option<String>,
86    aliases: Vec<String>,
87    best_practices: Vec<String>,
88    anti_patterns: Vec<String>,
89}
90
91#[derive(Default)]
92struct FieldAttrs {
93    positional: bool,
94    flag: bool,
95    required: bool,
96    short: Option<char>,
97    takes_value: bool,
98    description: Option<String>,
99    default: Option<String>,
100}
101
102// ---------------------------------------------------------------------------
103// Attribute parsers
104// ---------------------------------------------------------------------------
105
106fn parse_struct_attrs(attrs: &[syn::Attribute]) -> syn::Result<StructAttrs> {
107    let mut out = StructAttrs::default();
108    for attr in attrs {
109        if !attr.path().is_ident("argot") {
110            continue;
111        }
112        attr.parse_nested_meta(|meta| {
113            if meta.path.is_ident("canonical") {
114                let val: LitStr = meta.value()?.parse()?;
115                out.canonical = Some(val.value());
116            } else if meta.path.is_ident("summary") {
117                let val: LitStr = meta.value()?.parse()?;
118                out.summary = Some(val.value());
119            } else if meta.path.is_ident("description") {
120                let val: LitStr = meta.value()?.parse()?;
121                out.description = Some(val.value());
122            } else if meta.path.is_ident("alias") {
123                let val: LitStr = meta.value()?.parse()?;
124                out.aliases.push(val.value());
125            } else if meta.path.is_ident("best_practice") {
126                let val: LitStr = meta.value()?.parse()?;
127                out.best_practices.push(val.value());
128            } else if meta.path.is_ident("anti_pattern") {
129                let val: LitStr = meta.value()?.parse()?;
130                out.anti_patterns.push(val.value());
131            } else {
132                return Err(meta.error(format!(
133                    "unknown struct-level argot attribute `{}` — valid keys are: canonical, summary, description, alias, best_practice, anti_pattern",
134                    meta.path
135                        .get_ident()
136                        .map(|i| i.to_string())
137                        .unwrap_or_default()
138                )));
139            }
140            Ok(())
141        })?;
142    }
143    Ok(out)
144}
145
146fn parse_field_attrs(attrs: &[syn::Attribute]) -> syn::Result<Option<FieldAttrs>> {
147    let mut found = false;
148    let mut out = FieldAttrs::default();
149    for attr in attrs {
150        if !attr.path().is_ident("argot") {
151            continue;
152        }
153        found = true;
154        attr.parse_nested_meta(|meta| {
155            if meta.path.is_ident("positional") {
156                out.positional = true;
157            } else if meta.path.is_ident("flag") {
158                out.flag = true;
159            } else if meta.path.is_ident("required") {
160                out.required = true;
161            } else if meta.path.is_ident("takes_value") {
162                out.takes_value = true;
163            } else if meta.path.is_ident("short") {
164                let val: LitChar = meta.value()?.parse()?;
165                out.short = Some(val.value());
166            } else if meta.path.is_ident("description") {
167                let val: LitStr = meta.value()?.parse()?;
168                out.description = Some(val.value());
169            } else if meta.path.is_ident("default") {
170                let val: LitStr = meta.value()?.parse()?;
171                out.default = Some(val.value());
172            } else {
173                return Err(meta.error(format!(
174                    "unknown field-level argot attribute `{}` — valid keys are: positional, flag, required, short, takes_value, description, default",
175                    meta.path
176                        .get_ident()
177                        .map(|i| i.to_string())
178                        .unwrap_or_default()
179                )));
180            }
181            Ok(())
182        })?;
183    }
184    if found {
185        Ok(Some(out))
186    } else {
187        Ok(None)
188    }
189}
190
191// ---------------------------------------------------------------------------
192// Name conversion helpers
193// ---------------------------------------------------------------------------
194
195/// Convert `CamelCase` → `kebab-case`.
196///
197/// Inserts `-` before each uppercase letter that follows a lowercase letter,
198/// then lowercases everything.
199fn camel_to_kebab(name: &str) -> String {
200    let mut out = String::with_capacity(name.len() + 4);
201    let chars: Vec<char> = name.chars().collect();
202    for (i, &c) in chars.iter().enumerate() {
203        if c.is_uppercase() && i > 0 && chars[i - 1].is_lowercase() {
204            out.push('-');
205        }
206        out.push(c.to_ascii_lowercase());
207    }
208    out
209}
210
211/// Convert a Rust field name (`snake_case`) to a CLI name (`kebab-case`).
212fn snake_to_kebab(name: &str) -> String {
213    name.replace('_', "-")
214}
215
216// ---------------------------------------------------------------------------
217// Core derive implementation
218// ---------------------------------------------------------------------------
219
220fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
221    let fields = match &input.data {
222        Data::Struct(s) => &s.fields,
223        Data::Enum(_) => {
224            return Err(syn::Error::new_spanned(
225                &input.ident,
226                format!(
227                    "`#[derive(ArgotCommand)]` cannot be used on enum `{}` — only structs are supported",
228                    input.ident
229                ),
230            ));
231        }
232        Data::Union(_) => {
233            return Err(syn::Error::new_spanned(
234                &input.ident,
235                format!(
236                    "`#[derive(ArgotCommand)]` cannot be used on union `{}` — only structs are supported",
237                    input.ident
238                ),
239            ));
240        }
241    };
242
243    let named = match fields {
244        Fields::Named(n) => &n.named,
245        Fields::Unit => &syn::punctuated::Punctuated::new(),
246        Fields::Unnamed(_) => {
247            return Err(syn::Error::new_spanned(
248                &input.ident,
249                format!(
250                    "`{}` uses tuple fields — `#[derive(ArgotCommand)]` requires named fields (e.g., `struct Foo {{ name: String }}`)",
251                    input.ident
252                ),
253            ));
254        }
255    };
256
257    let struct_attrs = parse_struct_attrs(&input.attrs)?;
258
259    let canonical = struct_attrs
260        .canonical
261        .clone()
262        .unwrap_or_else(|| camel_to_kebab(&input.ident.to_string()));
263
264    let mut builder_tokens = quote! {
265        ::argot_cmd::Command::builder(#canonical)
266    };
267
268    if let Some(ref s) = struct_attrs.summary {
269        builder_tokens = quote! { #builder_tokens .summary(#s) };
270    }
271    if let Some(ref d) = struct_attrs.description {
272        builder_tokens = quote! { #builder_tokens .description(#d) };
273    }
274    for alias in &struct_attrs.aliases {
275        builder_tokens = quote! { #builder_tokens .alias(#alias) };
276    }
277    for bp in &struct_attrs.best_practices {
278        builder_tokens = quote! { #builder_tokens .best_practice(#bp) };
279    }
280    for ap in &struct_attrs.anti_patterns {
281        builder_tokens = quote! { #builder_tokens .anti_pattern(#ap) };
282    }
283
284    for field in named.iter() {
285        let field_ident = field.ident.as_ref().expect("named field has ident");
286        let fa = match parse_field_attrs(&field.attrs)? {
287            None => continue,
288            Some(fa) => fa,
289        };
290
291        if fa.positional && fa.flag {
292            return Err(syn::Error::new_spanned(
293                field_ident,
294                "a field cannot be both `positional` and `flag` — choose one",
295            ));
296        }
297
298        if fa.positional {
299            let arg_name = snake_to_kebab(&field_ident.to_string());
300            let mut arg_builder = quote! { ::argot_cmd::Argument::builder(#arg_name) };
301            if fa.required {
302                arg_builder = quote! { #arg_builder .required() };
303            }
304            if let Some(ref desc) = fa.description {
305                arg_builder = quote! { #arg_builder .description(#desc) };
306            }
307            if let Some(ref def) = fa.default {
308                arg_builder = quote! { #arg_builder .default_value(#def) };
309            }
310            builder_tokens = quote! { #builder_tokens .argument(#arg_builder .build().unwrap()) };
311        } else if fa.flag {
312            let flag_name = snake_to_kebab(&field_ident.to_string());
313            let mut flag_builder = quote! { ::argot_cmd::Flag::builder(#flag_name) };
314            if let Some(c) = fa.short {
315                flag_builder = quote! { #flag_builder .short(#c) };
316            }
317            if fa.required {
318                flag_builder = quote! { #flag_builder .required() };
319            }
320            if fa.takes_value {
321                flag_builder = quote! { #flag_builder .takes_value() };
322            }
323            if let Some(ref desc) = fa.description {
324                flag_builder = quote! { #flag_builder .description(#desc) };
325            }
326            if let Some(ref def) = fa.default {
327                flag_builder = quote! { #flag_builder .default_value(#def) };
328            }
329            builder_tokens = quote! { #builder_tokens .flag(#flag_builder .build().unwrap()) };
330        } else {
331            return Err(syn::Error::new_spanned(
332                field_ident,
333                format!(
334                    "field `{}` has `#[argot(...)]` but is missing a kind — add `positional` or `flag`",
335                    field_ident
336                ),
337            ));
338        }
339    }
340
341    builder_tokens = quote! { #builder_tokens .build().unwrap() };
342
343    let ident = &input.ident;
344    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
345
346    Ok(quote! {
347        impl #impl_generics ::argot_cmd::ArgotCommand for #ident #ty_generics #where_clause {
348            fn command() -> ::argot_cmd::Command {
349                #builder_tokens
350            }
351        }
352    })
353}
354
355// ---------------------------------------------------------------------------
356// Unit tests for name conversion helpers
357// ---------------------------------------------------------------------------
358
359#[cfg(test)]
360mod tests {
361    use super::*;
362
363    #[test]
364    fn test_camel_to_kebab() {
365        assert_eq!(camel_to_kebab("Deploy"), "deploy");
366        assert_eq!(camel_to_kebab("DeployCommand"), "deploy-command");
367        assert_eq!(camel_to_kebab("RemoteAdd"), "remote-add");
368        assert_eq!(camel_to_kebab("SomeOtherCommand"), "some-other-command");
369    }
370
371    #[test]
372    fn test_snake_to_kebab() {
373        assert_eq!(snake_to_kebab("dry_run"), "dry-run");
374        assert_eq!(snake_to_kebab("output"), "output");
375        assert_eq!(snake_to_kebab("env"), "env");
376    }
377
378    #[test]
379    fn test_camel_to_kebab_single_word() {
380        assert_eq!(camel_to_kebab("Deploy"), "deploy");
381    }
382
383    #[test]
384    fn test_snake_to_kebab_multi_word() {
385        assert_eq!(snake_to_kebab("dry_run_mode"), "dry-run-mode");
386    }
387}