1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 FnArg, GenericArgument, ItemFn, PatType, PathArguments, ReturnType, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
8pub fn strategy(attr: TokenStream, item: TokenStream) -> TokenStream {
9 let attr_tokens: proc_macro2::TokenStream = attr.into();
10
11 if !attr_tokens.is_empty() {
12 return syn::Error::new_spanned(
13 attr_tokens,
14 "strategy attribute does not take any arguments",
15 )
16 .into_compile_error()
17 .into();
18 }
19
20 let input_fn = parse_macro_input!(item as ItemFn);
21
22 if input_fn.sig.asyncness.is_some() {
23 return syn::Error::new_spanned(&input_fn.sig, "async functions are not supported")
24 .into_compile_error()
25 .into();
26 }
27
28 let mut inputs = input_fn.sig.inputs.iter();
29 if let Some(FnArg::Receiver(_)) = inputs.next() {
30 return syn::Error::new_spanned(
31 &input_fn.sig,
32 "cli strategy functions must be plain free functions; remove the &self receiver and keep options, arguments, and subcommands arguments",
33 )
34 .into_compile_error()
35 .into();
36 }
37
38 let mut inputs = input_fn.sig.inputs.iter();
39
40 let ctx_pat = match inputs.next() {
41 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
42 if !matches_execution_context(ty.as_ref()) {
43 return syn::Error::new_spanned(
44 ty,
45 "strategy annotated functions must accept an cmdkit::ExecutionContext",
46 )
47 .into_compile_error()
48 .into();
49 }
50
51 pat
52 }
53 _ => {
54 return syn::Error::new_spanned(
55 &input_fn.sig,
56 "strategy functions must accept an cmdkit::ExecutionContext argument",
57 )
58 .into_compile_error()
59 .into();
60 }
61 };
62
63 let arguments_pat = match inputs.next() {
64 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
65 if !matches_path_segments(ty.as_ref(), &[stringify!(InvocationArgs)])
66 && !matches_path_segments(ty.as_ref(), &["cmdkit", stringify!(InvocationArgs)])
67 {
68 return syn::Error::new_spanned(
69 ty,
70 "strategy annotated functions must accept an cmdkit::InvocationArgs arguments argument",
71 )
72 .into_compile_error()
73 .into();
74 }
75
76 pat
77 }
78 _ => {
79 return syn::Error::new_spanned(
80 &input_fn.sig,
81 "strategy annotated functions must accept an cmdkit::InvocationArgs arguments argument",
82 )
83 .into_compile_error()
84 .into();
85 }
86 };
87
88 match &input_fn.sig.output {
89 ReturnType::Type(_, ty) => match ty.as_ref() {
90 Type::Path(path)
91 if path.path.segments.len() == 1
92 && path.path.segments[0].ident == "Result"
93 && matches_result_type(&path.path) => {}
94 _ => {
95 return syn::Error::new_spanned(
96 ty,
97 "strategy annotated functions must return Result<(), cmdkit::StrategyError>",
98 )
99 .into_compile_error()
100 .into();
101 }
102 },
103 ReturnType::Default => {
104 return syn::Error::new_spanned(
105 &input_fn.sig,
106 "strategy annotated functions must return Result<(), cmdkit::StrategyError>",
107 )
108 .into_compile_error()
109 .into();
110 }
111 }
112
113 let fn_ident = &input_fn.sig.ident;
114 let vis = &input_fn.vis;
115 let strategy_ident = format_ident!("{}", to_pascal(&fn_ident.to_string()));
116 let factory_ident = format_ident!("{}_strategy", fn_ident);
117 let attrs = &input_fn.attrs;
118 let body = &input_fn.block;
119
120 let expanded = quote! {
121 #(#attrs)*
122 #vis struct #strategy_ident;
123
124 impl #strategy_ident {
125 #vis fn new() -> Self {
126 Self
127 }
128 }
129
130 impl ::cmdkit::CommandStrategy for #strategy_ident {
131 fn execute(
132 &self,
133 #ctx_pat: &::cmdkit::ExecutionContext,
134 #arguments_pat: ::cmdkit::InvocationArgs,
135 ) -> Result<(), ::cmdkit::StrategyError> {
136 #body
137 }
138 }
139
140 #vis fn #factory_ident() -> #strategy_ident {
141 #strategy_ident::new()
142 }
143 };
144
145 expanded.into()
146}
147
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`, and empty pieces are dropped; these come from leading,
/// trailing or repeated underscores. The first character of each remaining piece is
/// upper-cased, which may expand to several characters (e.g. `ß` -> `SS`). The rest of the
/// piece is kept verbatim.
fn to_pascal(s: &str) -> String {
    s.split('_')
        .filter(|piece| !piece.is_empty())
        .flat_map(|piece| {
            let mut chars = piece.chars();
            let head = chars.next();
            head.into_iter().flat_map(char::to_uppercase).chain(chars)
        })
        .collect()
}
162
163fn matches_execution_context(ty: &Type) -> bool {
164 matches_path_segments(ty, &[stringify!(ExecutionContext)])
165 || matches_path_segments(ty, &["cmdkit", stringify!(ExecutionContext)])
166 || matches!(
167 ty,
168 Type::Reference(reference)
169 if matches_path_segments(reference.elem.as_ref(), &[stringify!(ExecutionContext)])
170 || matches_path_segments(
171 reference.elem.as_ref(),
172 &["cmdkit", stringify!(ExecutionContext)]
173 )
174 )
175}
176
177fn matches_result_type(path: &syn::Path) -> bool {
178 let Some(last_segment) = path.segments.last() else {
179 return false;
180 };
181
182 let PathArguments::AngleBracketed(arguments) = &last_segment.arguments else {
183 return false;
184 };
185
186 let mut args = arguments.args.iter();
187
188 matches!(args.next(), Some(GenericArgument::Type(Type::Tuple(tuple))) if tuple.elems.is_empty())
189 && matches!(
190 args.next(),
191 Some(GenericArgument::Type(inner_type)) if matches_path_segments(inner_type, &["StrategyError"])
192 || matches_path_segments(inner_type, &["cmdkit", "StrategyError"])
193 )
194 && args.next().is_none()
195}
196
197fn matches_path_segments(ty: &Type, expected_segments: &[&str]) -> bool {
198 let Type::Path(path) = ty else {
199 return false;
200 };
201
202 let actual_segments: Vec<_> = path
203 .path
204 .segments
205 .iter()
206 .map(|segment| segment.ident.to_string())
207 .collect();
208
209 actual_segments == expected_segments
210}