1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{FnArg, ItemFn, PatType, ReturnType, Type, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn strategy(attr: TokenStream, item: TokenStream) -> TokenStream {
7 let attr_tokens: proc_macro2::TokenStream = attr.into();
8
9 if !attr_tokens.is_empty() {
10 return syn::Error::new_spanned(
11 attr_tokens,
12 "strategy attribute does not take any arguments",
13 )
14 .into_compile_error()
15 .into();
16 }
17
18 let input_fn = parse_macro_input!(item as ItemFn);
19
20 if input_fn.sig.asyncness.is_some() {
21 return syn::Error::new_spanned(&input_fn.sig, "async functions are not supported")
22 .into_compile_error()
23 .into();
24 }
25
26 let mut inputs = input_fn.sig.inputs.iter();
27 if let Some(FnArg::Receiver(_)) = inputs.next() {
28 return syn::Error::new_spanned(
29 &input_fn.sig,
30 "cli strategy functions must be plain free functions; remove the &self receiver and keep options, arguments, and subcommands arguments",
31 )
32 .into_compile_error()
33 .into();
34 }
35
36 let mut inputs = input_fn.sig.inputs.iter();
37
38 let options_pat = match inputs.next() {
39 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
40 match ty.as_ref() {
41 Type::Path(path)
42 if path
43 .path
44 .segments
45 .last()
46 .is_some_and(|segment| segment.ident == "Vec") => {}
47 _ => {
48 return syn::Error::new_spanned(
49 ty,
50 "cli strategy functions must accept a Vec<Switch> options argument",
51 )
52 .into_compile_error()
53 .into();
54 }
55 }
56
57 pat
58 }
59 _ => {
60 return syn::Error::new_spanned(
61 &input_fn.sig,
62 "cli strategy functions must accept an options Vec<Switch> argument",
63 )
64 .into_compile_error()
65 .into();
66 }
67 };
68
69 let arguments_pat = match inputs.next() {
70 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
71 match ty.as_ref() {
72 Type::Path(path)
73 if path
74 .path
75 .segments
76 .last()
77 .is_some_and(|segment| segment.ident == "Vec") => {}
78 _ => {
79 return syn::Error::new_spanned(
80 ty,
81 "cli strategy functions must accept a Vec<Argument> arguments argument",
82 )
83 .into_compile_error()
84 .into();
85 }
86 }
87
88 pat
89 }
90 _ => {
91 return syn::Error::new_spanned(
92 &input_fn.sig,
93 "cli strategy functions must accept an arguments Vec<Argument> argument",
94 )
95 .into_compile_error()
96 .into();
97 }
98 };
99
100 let subcommands_pat = match inputs.next() {
101 Some(FnArg::Typed(PatType { pat, ty, .. })) => {
102 if inputs.next().is_some() {
103 return syn::Error::new_spanned(
104 &input_fn.sig,
105 "cli strategy functions must accept exactly three parsed invocation arguments",
106 )
107 .into_compile_error()
108 .into();
109 }
110
111 match ty.as_ref() {
112 Type::Path(path)
113 if path
114 .path
115 .segments
116 .last()
117 .is_some_and(|segment| segment.ident == "Vec") => {}
118 _ => {
119 return syn::Error::new_spanned(
120 ty,
121 "cli strategy functions must accept a Vec<String> subcommands argument",
122 )
123 .into_compile_error()
124 .into();
125 }
126 }
127
128 pat
129 }
130 _ => {
131 return syn::Error::new_spanned(
132 &input_fn.sig,
133 "cli strategy functions must accept a subcommands Vec<String> argument",
134 )
135 .into_compile_error()
136 .into();
137 }
138 };
139
140 match &input_fn.sig.output {
141 ReturnType::Type(_, ty) => match ty.as_ref() {
142 Type::Path(path)
143 if path.path.segments.len() == 1 && path.path.segments[0].ident == "Result" => {}
144 _ => {
145 return syn::Error::new_spanned(
146 ty,
147 "cli strategy functions must return Result<(), cmdkit::StrategyError>",
148 )
149 .into_compile_error()
150 .into();
151 }
152 },
153 ReturnType::Default => {
154 return syn::Error::new_spanned(
155 &input_fn.sig,
156 "cli strategy functions must return Result<(), cmdkit::StrategyError>",
157 )
158 .into_compile_error()
159 .into();
160 }
161 }
162
163 let fn_ident = &input_fn.sig.ident;
164 let vis = &input_fn.vis;
165 let strategy_ident = format_ident!("{}", to_pascal(&fn_ident.to_string()));
166 let factory_ident = format_ident!("{}_strategy", fn_ident);
167 let attrs = &input_fn.attrs;
168 let body = &input_fn.block;
169
170 let expanded = quote! {
171 #(#attrs)*
172 #vis struct #strategy_ident;
173
174 impl #strategy_ident {
175 #vis fn new() -> Self {
176 Self
177 }
178 }
179
180 impl ::cmdkit::CommandStrategy for #strategy_ident {
181 fn execute(
182 &self,
183 #options_pat: Vec<::cmdkit::Switch>,
184 #arguments_pat: Vec<::cmdkit::Argument>,
185 #subcommands_pat: Vec<String>,
186 ) -> Result<(), ::cmdkit::StrategyError> {
187 #body
188 }
189 }
190
191 #vis fn #factory_ident() -> #strategy_ident {
192 #strategy_ident::new()
193 }
194 };
195
196 expanded.into()
197}
198
/// Converts a `snake_case` name into `PascalCase`.
///
/// The input is split on `_`, and empty pieces (from leading, trailing or
/// repeated underscores) are dropped. Each remaining piece has its first
/// character upper-cased with `char::to_uppercase`, which may expand to several
/// characters (e.g. `ß` -> `SS`). The rest of the piece is copied unchanged.
fn to_pascal(s: &str) -> String {
    s.split('_')
        .filter(|word| !word.is_empty())
        .flat_map(|word| {
            let mut rest = word.chars();
            let head = rest.next().map(char::to_uppercase);
            head.into_iter().flatten().chain(rest)
        })
        .collect()
}