Skip to main content

cmdkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    FnArg, GenericArgument, ItemFn, PatType, PathArguments, ReturnType, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
8pub fn strategy(attr: TokenStream, item: TokenStream) -> TokenStream {
9    let attr_tokens: proc_macro2::TokenStream = attr.into();
10
11    if !attr_tokens.is_empty() {
12        return syn::Error::new_spanned(
13            attr_tokens,
14            "strategy attribute does not take any arguments",
15        )
16        .into_compile_error()
17        .into();
18    }
19
20    let input_fn = parse_macro_input!(item as ItemFn);
21
22    if input_fn.sig.asyncness.is_some() {
23        return syn::Error::new_spanned(&input_fn.sig, "async functions are not supported")
24            .into_compile_error()
25            .into();
26    }
27
28    let mut inputs = input_fn.sig.inputs.iter();
29    if let Some(FnArg::Receiver(_)) = inputs.next() {
30        return syn::Error::new_spanned(
31            &input_fn.sig,
32            "cli strategy functions must be plain free functions; remove the &self receiver and keep options, arguments, and subcommands arguments",
33        )
34        .into_compile_error()
35        .into();
36    }
37
38    let mut inputs = input_fn.sig.inputs.iter();
39
40    let ctx_pat = match inputs.next() {
41        Some(FnArg::Typed(PatType { pat, ty, .. })) => {
42            if !matches_execution_context(ty.as_ref()) {
43                return syn::Error::new_spanned(
44                    ty,
45                    "strategy annotated functions must accept an cmdkit::ExecutionContext",
46                )
47                .into_compile_error()
48                .into();
49            }
50
51            pat
52        }
53        _ => {
54            return syn::Error::new_spanned(
55                &input_fn.sig,
56                "strategy functions must accept an cmdkit::ExecutionContext argument",
57            )
58            .into_compile_error()
59            .into();
60        }
61    };
62
63    let arguments_pat = match inputs.next() {
64        Some(FnArg::Typed(PatType { pat, ty, .. })) => {
65            if !matches_path_segments(ty.as_ref(), &[stringify!(InvocationArgs)])
66                && !matches_path_segments(ty.as_ref(), &["cmdkit", stringify!(InvocationArgs)])
67            {
68                return syn::Error::new_spanned(
69                    ty,
70                    "strategy annotated functions must accept an cmdkit::InvocationArgs arguments argument",
71                )
72                .into_compile_error()
73                .into();
74            }
75
76            pat
77        }
78        _ => {
79            return syn::Error::new_spanned(
80                &input_fn.sig,
81                "strategy annotated functions must accept an cmdkit::InvocationArgs arguments argument",
82            )
83            .into_compile_error()
84            .into();
85        }
86    };
87
88    match &input_fn.sig.output {
89        ReturnType::Type(_, ty) => match ty.as_ref() {
90            Type::Path(path)
91                if path.path.segments.len() == 1
92                    && path.path.segments[0].ident == "Result"
93                    && matches_result_type(&path.path) => {}
94            _ => {
95                return syn::Error::new_spanned(
96                    ty,
97                    "strategy annotated functions must return Result<(), cmdkit::StrategyError>",
98                )
99                .into_compile_error()
100                .into();
101            }
102        },
103        ReturnType::Default => {
104            return syn::Error::new_spanned(
105                &input_fn.sig,
106                "strategy annotated functions must return Result<(), cmdkit::StrategyError>",
107            )
108            .into_compile_error()
109            .into();
110        }
111    }
112
113    let fn_ident = &input_fn.sig.ident;
114    let vis = &input_fn.vis;
115    let strategy_ident = format_ident!("{}", to_pascal(&fn_ident.to_string()));
116    let factory_ident = format_ident!("{}_strategy", fn_ident);
117    let attrs = &input_fn.attrs;
118    let body = &input_fn.block;
119
120    let expanded = quote! {
121        #(#attrs)*
122        #vis struct #strategy_ident;
123
124        impl #strategy_ident {
125            #vis fn new() -> Self {
126                Self
127            }
128        }
129
130        impl ::cmdkit::CommandStrategy for #strategy_ident {
131            fn execute(
132                &self,
133                #ctx_pat: &::cmdkit::ExecutionContext,
134                #arguments_pat: ::cmdkit::InvocationArgs,
135            ) -> Result<(), ::cmdkit::StrategyError> {
136                #body
137            }
138        }
139
140        #vis fn #factory_ident() -> #strategy_ident {
141            #strategy_ident::new()
142        }
143    };
144
145    expanded.into()
146}
147
/// Converts a `snake_case` name into `PascalCase`, e.g. `deploy_app` -> `DeployApp`.
///
/// Underscores act purely as separators, so leading, trailing and repeated
/// underscores produce no output. Each word's first character is upper-cased
/// via [`char::to_uppercase`], which can yield more than one character
/// (e.g. `ß` -> `SS`). The remaining characters are copied unchanged.
///
/// The result is not guaranteed to be a valid Rust identifier. It may be
/// empty, or it may start with a digit (e.g. `_1` -> `1`).
fn to_pascal(s: &str) -> String {
    s.split('_')
        .filter_map(|word| {
            let mut rest = word.chars();
            // Empty words (from `__`, or a leading or trailing `_`) yield `None`
            // here and are skipped.
            rest.next().map(|head| head.to_uppercase().chain(rest))
        })
        .flatten()
        .collect()
}
162
163fn matches_execution_context(ty: &Type) -> bool {
164    matches_path_segments(ty, &[stringify!(ExecutionContext)])
165        || matches_path_segments(ty, &["cmdkit", stringify!(ExecutionContext)])
166        || matches!(
167            ty,
168            Type::Reference(reference)
169                if matches_path_segments(reference.elem.as_ref(), &[stringify!(ExecutionContext)])
170                    || matches_path_segments(
171                        reference.elem.as_ref(),
172                        &["cmdkit", stringify!(ExecutionContext)]
173                    )
174        )
175}
176
177fn matches_result_type(path: &syn::Path) -> bool {
178    let Some(last_segment) = path.segments.last() else {
179        return false;
180    };
181
182    let PathArguments::AngleBracketed(arguments) = &last_segment.arguments else {
183        return false;
184    };
185
186    let mut args = arguments.args.iter();
187
188    matches!(args.next(), Some(GenericArgument::Type(Type::Tuple(tuple))) if tuple.elems.is_empty())
189        && matches!(
190            args.next(),
191            Some(GenericArgument::Type(inner_type)) if matches_path_segments(inner_type, &["StrategyError"])
192                || matches_path_segments(inner_type, &["cmdkit", "StrategyError"])
193        )
194        && args.next().is_none()
195}
196
197fn matches_path_segments(ty: &Type, expected_segments: &[&str]) -> bool {
198    let Type::Path(path) = ty else {
199        return false;
200    };
201
202    let actual_segments: Vec<_> = path
203        .path
204        .segments
205        .iter()
206        .map(|segment| segment.ident.to_string())
207        .collect();
208
209    actual_segments == expected_segments
210}