Skip to main content

cmdkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    FnArg, GenericArgument, ItemFn, PatType, PathArguments, ReturnType, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
8pub fn strategy(attr: TokenStream, item: TokenStream) -> TokenStream {
9    let attr_tokens: proc_macro2::TokenStream = attr.into();
10
11    if !attr_tokens.is_empty() {
12        return syn::Error::new_spanned(
13            attr_tokens,
14            "strategy attribute does not take any arguments",
15        )
16        .into_compile_error()
17        .into();
18    }
19
20    let input_fn = parse_macro_input!(item as ItemFn);
21
22    if input_fn.sig.asyncness.is_some() {
23        return syn::Error::new_spanned(&input_fn.sig, "async functions are not supported")
24            .into_compile_error()
25            .into();
26    }
27
28    let mut inputs = input_fn.sig.inputs.iter();
29    if let Some(FnArg::Receiver(_)) = inputs.next() {
30        return syn::Error::new_spanned(
31            &input_fn.sig,
32            "cli strategy functions must be plain free functions; remove the &self receiver and keep options, arguments, and subcommands arguments",
33        )
34        .into_compile_error()
35        .into();
36    }
37
38    let mut inputs = input_fn.sig.inputs.iter();
39
40    let options_pat = match inputs.next() {
41        Some(FnArg::Typed(PatType { pat, ty, .. })) => {
42            if !matches_vec_of_path(ty.as_ref(), &["Switch"])
43                && !matches_vec_of_path(ty.as_ref(), &["cmdkit", "Switch"])
44            {
45                return syn::Error::new_spanned(
46                    ty,
47                    "cli strategy functions must accept a Vec<Switch> options argument",
48                )
49                .into_compile_error()
50                .into();
51            }
52
53            pat
54        }
55        _ => {
56            return syn::Error::new_spanned(
57                &input_fn.sig,
58                "cli strategy functions must accept an options Vec<Switch> argument",
59            )
60            .into_compile_error()
61            .into();
62        }
63    };
64
65    let arguments_pat = match inputs.next() {
66        Some(FnArg::Typed(PatType { pat, ty, .. })) => {
67            if !matches_vec_of_path(ty.as_ref(), &["Argument"])
68                && !matches_vec_of_path(ty.as_ref(), &["cmdkit", "Argument"])
69            {
70                return syn::Error::new_spanned(
71                    ty,
72                    "cli strategy functions must accept a Vec<Argument> arguments argument",
73                )
74                .into_compile_error()
75                .into();
76            }
77
78            pat
79        }
80        _ => {
81            return syn::Error::new_spanned(
82                &input_fn.sig,
83                "cli strategy functions must accept an arguments Vec<Argument> argument",
84            )
85            .into_compile_error()
86            .into();
87        }
88    };
89
90    let subcommands_pat = match inputs.next() {
91        Some(FnArg::Typed(PatType { pat, ty, .. })) => {
92            if inputs.next().is_some() {
93                return syn::Error::new_spanned(
94                    &input_fn.sig,
95                    "cli strategy functions must accept exactly three parsed invocation arguments",
96                )
97                .into_compile_error()
98                .into();
99            }
100
101            if !matches_vec_of_path(ty.as_ref(), &["String"])
102                && !matches_vec_of_path(ty.as_ref(), &["std", "string", "String"])
103                && !matches_vec_of_path(ty.as_ref(), &["alloc", "string", "String"])
104            {
105                return syn::Error::new_spanned(
106                    ty,
107                    "cli strategy functions must accept a Vec<String> subcommands argument",
108                )
109                .into_compile_error()
110                .into();
111            }
112
113            pat
114        }
115        _ => {
116            return syn::Error::new_spanned(
117                &input_fn.sig,
118                "cli strategy functions must accept a subcommands Vec<String> argument",
119            )
120            .into_compile_error()
121            .into();
122        }
123    };
124
125    match &input_fn.sig.output {
126        ReturnType::Type(_, ty) => match ty.as_ref() {
127            Type::Path(path)
128                if path.path.segments.len() == 1
129                    && path.path.segments[0].ident == "Result"
130                    && matches_result_type(&path.path) => {}
131            _ => {
132                return syn::Error::new_spanned(
133                    ty,
134                    "cli strategy functions must return Result<(), cmdkit::StrategyError>",
135                )
136                .into_compile_error()
137                .into();
138            }
139        },
140        ReturnType::Default => {
141            return syn::Error::new_spanned(
142                &input_fn.sig,
143                "cli strategy functions must return Result<(), cmdkit::StrategyError>",
144            )
145            .into_compile_error()
146            .into();
147        }
148    }
149
150    let fn_ident = &input_fn.sig.ident;
151    let vis = &input_fn.vis;
152    let strategy_ident = format_ident!("{}", to_pascal(&fn_ident.to_string()));
153    let factory_ident = format_ident!("{}_strategy", fn_ident);
154    let attrs = &input_fn.attrs;
155    let body = &input_fn.block;
156
157    let expanded = quote! {
158        #(#attrs)*
159        #vis struct #strategy_ident;
160
161        impl #strategy_ident {
162            #vis fn new() -> Self {
163                Self
164            }
165        }
166
167        impl ::cmdkit::CommandStrategy for #strategy_ident {
168            fn execute(
169                &self,
170                #options_pat: Vec<::cmdkit::Switch>,
171                #arguments_pat: Vec<::cmdkit::Argument>,
172                #subcommands_pat: Vec<String>,
173            ) -> Result<(), ::cmdkit::StrategyError> {
174                #body
175            }
176        }
177
178        #vis fn #factory_ident() -> #strategy_ident {
179            #strategy_ident::new()
180        }
181    };
182
183    expanded.into()
184}
185
/// Converts a `snake_case` name into `PascalCase`.
///
/// The input is split on `_`; empty pieces (from leading, trailing or doubled underscores) are
/// dropped. The first character of every remaining piece is upper-cased and the rest of the
/// piece is appended unchanged.
fn to_pascal(s: &str) -> String {
    s.split('_')
        .filter(|word| !word.is_empty())
        .fold(String::new(), |mut pascal, word| {
            let mut rest = word.chars();
            if let Some(initial) = rest.next() {
                // `to_uppercase` may expand to several characters, hence `extend`.
                pascal.extend(initial.to_uppercase());
                pascal.push_str(rest.as_str());
            }
            pascal
        })
}
200
201fn matches_vec_of_path(ty: &Type, expected_segments: &[&str]) -> bool {
202    let Type::Path(path) = ty else {
203        return false;
204    };
205
206    let Some(last_segment) = path.path.segments.last() else {
207        return false;
208    };
209
210    if last_segment.ident != "Vec" {
211        return false;
212    }
213
214    let PathArguments::AngleBracketed(arguments) = &last_segment.arguments else {
215        return false;
216    };
217
218    let Some(GenericArgument::Type(inner_type)) = arguments.args.first() else {
219        return false;
220    };
221
222    matches_path_segments(inner_type, expected_segments)
223}
224
225fn matches_result_type(path: &syn::Path) -> bool {
226    let Some(last_segment) = path.segments.last() else {
227        return false;
228    };
229
230    let PathArguments::AngleBracketed(arguments) = &last_segment.arguments else {
231        return false;
232    };
233
234    let mut args = arguments.args.iter();
235
236    matches!(args.next(), Some(GenericArgument::Type(Type::Tuple(tuple))) if tuple.elems.is_empty())
237        && matches!(
238            args.next(),
239            Some(GenericArgument::Type(inner_type)) if matches_path_segments(inner_type, &["StrategyError"])
240                || matches_path_segments(inner_type, &["cmdkit", "StrategyError"])
241        )
242        && args.next().is_none()
243}
244
245fn matches_path_segments(ty: &Type, expected_segments: &[&str]) -> bool {
246    let Type::Path(path) = ty else {
247        return false;
248    };
249
250    let actual_segments: Vec<_> = path
251        .path
252        .segments
253        .iter()
254        .map(|segment| segment.ident.to_string())
255        .collect();
256
257    actual_segments == expected_segments
258}