1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 FnArg, GenericArgument, ItemFn, PatType, PathArguments, ReturnType, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
8pub fn strategy(attr: TokenStream, item: TokenStream) -> TokenStream {
9 let attr_tokens: proc_macro2::TokenStream = attr.into();
10
11 if !attr_tokens.is_empty() {
12 return syn::Error::new_spanned(
13 attr_tokens,
14 "strategy attribute does not take any arguments",
15 )
16 .into_compile_error()
17 .into();
18 }
19
20 let input_fn = parse_macro_input!(item as ItemFn);
21
22 if input_fn.sig.asyncness.is_some() {
23 return syn::Error::new_spanned(&input_fn.sig, "async functions are not supported")
24 .into_compile_error()
25 .into();
26 }
27
28 let mut inputs = input_fn.sig.inputs.iter();
29 if let Some(FnArg::Receiver(_)) = inputs.next() {
30 return syn::Error::new_spanned(
31 &input_fn.sig,
32 "cli strategy functions must be plain free functions; remove the &self receiver and keep options, arguments, and subcommands arguments",
33 )
34 .into_compile_error()
35 .into();
36 }
37
38 let mut inputs = input_fn.sig.inputs.iter();
39
40 let options_pat = match inputs.next() {
41 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
42 if !matches_vec_of_path(ty.as_ref(), &["Switch"])
43 && !matches_vec_of_path(ty.as_ref(), &["cmdkit", "Switch"])
44 {
45 return syn::Error::new_spanned(
46 ty,
47 "cli strategy functions must accept a Vec<Switch> options argument",
48 )
49 .into_compile_error()
50 .into();
51 }
52
53 pat
54 }
55 _ => {
56 return syn::Error::new_spanned(
57 &input_fn.sig,
58 "cli strategy functions must accept an options Vec<Switch> argument",
59 )
60 .into_compile_error()
61 .into();
62 }
63 };
64
65 let arguments_pat = match inputs.next() {
66 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
67 if !matches_vec_of_path(ty.as_ref(), &["Argument"])
68 && !matches_vec_of_path(ty.as_ref(), &["cmdkit", "Argument"])
69 {
70 return syn::Error::new_spanned(
71 ty,
72 "cli strategy functions must accept a Vec<Argument> arguments argument",
73 )
74 .into_compile_error()
75 .into();
76 }
77
78 pat
79 }
80 _ => {
81 return syn::Error::new_spanned(
82 &input_fn.sig,
83 "cli strategy functions must accept an arguments Vec<Argument> argument",
84 )
85 .into_compile_error()
86 .into();
87 }
88 };
89
90 let subcommands_pat = match inputs.next() {
91 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
92 if inputs.next().is_some() {
93 return syn::Error::new_spanned(
94 &input_fn.sig,
95 "cli strategy functions must accept exactly three parsed invocation arguments",
96 )
97 .into_compile_error()
98 .into();
99 }
100
101 if !matches_vec_of_path(ty.as_ref(), &["String"])
102 && !matches_vec_of_path(ty.as_ref(), &["std", "string", "String"])
103 && !matches_vec_of_path(ty.as_ref(), &["alloc", "string", "String"])
104 {
105 return syn::Error::new_spanned(
106 ty,
107 "cli strategy functions must accept a Vec<String> subcommands argument",
108 )
109 .into_compile_error()
110 .into();
111 }
112
113 pat
114 }
115 _ => {
116 return syn::Error::new_spanned(
117 &input_fn.sig,
118 "cli strategy functions must accept a subcommands Vec<String> argument",
119 )
120 .into_compile_error()
121 .into();
122 }
123 };
124
125 match &input_fn.sig.output {
126 ReturnType::Type(_, ty) => match ty.as_ref() {
127 Type::Path(path)
128 if path.path.segments.len() == 1
129 && path.path.segments[0].ident == "Result"
130 && matches_result_type(&path.path) => {}
131 _ => {
132 return syn::Error::new_spanned(
133 ty,
134 "cli strategy functions must return Result<(), cmdkit::StrategyError>",
135 )
136 .into_compile_error()
137 .into();
138 }
139 },
140 ReturnType::Default => {
141 return syn::Error::new_spanned(
142 &input_fn.sig,
143 "cli strategy functions must return Result<(), cmdkit::StrategyError>",
144 )
145 .into_compile_error()
146 .into();
147 }
148 }
149
150 let fn_ident = &input_fn.sig.ident;
151 let vis = &input_fn.vis;
152 let strategy_ident = format_ident!("{}", to_pascal(&fn_ident.to_string()));
153 let factory_ident = format_ident!("{}_strategy", fn_ident);
154 let attrs = &input_fn.attrs;
155 let body = &input_fn.block;
156
157 let expanded = quote! {
158 #(#attrs)*
159 #vis struct #strategy_ident;
160
161 impl #strategy_ident {
162 #vis fn new() -> Self {
163 Self
164 }
165 }
166
167 impl ::cmdkit::CommandStrategy for #strategy_ident {
168 fn execute(
169 &self,
170 #options_pat: Vec<::cmdkit::Switch>,
171 #arguments_pat: Vec<::cmdkit::Argument>,
172 #subcommands_pat: Vec<String>,
173 ) -> Result<(), ::cmdkit::StrategyError> {
174 #body
175 }
176 }
177
178 #vis fn #factory_ident() -> #strategy_ident {
179 #strategy_ident::new()
180 }
181 };
182
183 expanded.into()
184}
185
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`, and empty pieces (from leading, trailing or
/// repeated underscores) are skipped. Each remaining piece gets its first
/// character upper-cased (which may expand to several characters, e.g. `ß` ->
/// `SS`); the rest of the piece is kept unchanged. The pieces are then
/// concatenated.
fn to_pascal(s: &str) -> String {
    s.split('_')
        .filter_map(|word| {
            let mut chars = word.chars();
            // `?` skips empty pieces, which have no first character.
            let head = chars.next()?;
            Some(head.to_uppercase().chain(chars))
        })
        .flatten()
        .collect()
}
200
201fn matches_vec_of_path(ty: &Type, expected_segments: &[&str]) -> bool {
202 let Type::Path(path) = ty else {
203 return false;
204 };
205
206 let Some(last_segment) = path.path.segments.last() else {
207 return false;
208 };
209
210 if last_segment.ident != "Vec" {
211 return false;
212 }
213
214 let PathArguments::AngleBracketed(arguments) = &last_segment.arguments else {
215 return false;
216 };
217
218 let Some(GenericArgument::Type(inner_type)) = arguments.args.first() else {
219 return false;
220 };
221
222 matches_path_segments(inner_type, expected_segments)
223}
224
225fn matches_result_type(path: &syn::Path) -> bool {
226 let Some(last_segment) = path.segments.last() else {
227 return false;
228 };
229
230 let PathArguments::AngleBracketed(arguments) = &last_segment.arguments else {
231 return false;
232 };
233
234 let mut args = arguments.args.iter();
235
236 matches!(args.next(), Some(GenericArgument::Type(Type::Tuple(tuple))) if tuple.elems.is_empty())
237 && matches!(
238 args.next(),
239 Some(GenericArgument::Type(inner_type)) if matches_path_segments(inner_type, &["StrategyError"])
240 || matches_path_segments(inner_type, &["cmdkit", "StrategyError"])
241 )
242 && args.next().is_none()
243}
244
245fn matches_path_segments(ty: &Type, expected_segments: &[&str]) -> bool {
246 let Type::Path(path) = ty else {
247 return false;
248 };
249
250 let actual_segments: Vec<_> = path
251 .path
252 .segments
253 .iter()
254 .map(|segment| segment.ident.to_string())
255 .collect();
256
257 actual_segments == expected_segments
258}