1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input,
6 punctuated::Punctuated,
7 AngleBracketedGenericArguments, DeriveInput, Expr, ExprLit, FnArg, GenericArgument, Ident,
8 ImplItem, ItemImpl, Lit, MetaNameValue, Pat, PathArguments, Result, Token, Type, TypePath,
9};
10
/// Parsed arguments of the `#[command_descriptor(...)]` helper attribute.
///
/// Both values are required; they are emitted verbatim as the bodies of the
/// generated `report_id()` and `cmd_len()` functions.
struct Attributes {
    /// Value returned by the generated `report_id()` (from `report_id = <int>`).
    report_id: u8,
    /// Value returned by the generated `cmd_len()` (from `cmd_len = <int>`).
    cmd_len: usize,
}
16
17impl Parse for Attributes {
18 fn parse(input: ParseStream) -> Result<Self> {
19 let args = Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?;
21 let mut report_id_opt = None;
22 let mut cmd_len_opt = None;
23
24 for arg in args {
25 let key = arg
27 .path
28 .get_ident()
29 .ok_or_else(|| syn::Error::new_spanned(&arg.path, "Expected identifier"))?
30 .to_string();
31
32 let lit_int = if let Expr::Lit(ExprLit {
34 lit: Lit::Int(ref i),
35 ..
36 }) = arg.value
37 {
38 i
39 } else {
40 return Err(syn::Error::new_spanned(
41 &arg.value,
42 "Expected integer literal",
43 ));
44 };
45
46 match key.as_str() {
47 "report_id" => {
48 report_id_opt = Some(lit_int.base10_parse()?);
49 }
50 "cmd_len" => {
51 cmd_len_opt = Some(lit_int.base10_parse()?);
52 }
53 _ => return Err(syn::Error::new_spanned(arg, "Unknown attribute key")),
54 }
55 }
56
57 let report_id =
59 report_id_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `report_id`"))?;
60 let cmd_len =
61 cmd_len_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `cmd_len`"))?;
62
63 Ok(Attributes { report_id, cmd_len })
64 }
65}
66
67#[proc_macro_derive(CommandDescriptor, attributes(command_descriptor))]
68pub fn derive_my_trait(input: TokenStream) -> TokenStream {
69 let ast = parse_macro_input!(input as DeriveInput);
70
71 let mut args_opt = None;
72 for attr in ast.attrs.iter() {
73 if attr.path().is_ident("command_descriptor") {
74 let args: Attributes = attr
76 .parse_args()
77 .expect("Failed to parse command_descriptor arguments");
78 args_opt = Some(args);
79 break;
80 }
81 }
82
83 let args = args_opt.expect("Missing #[command_descriptor(...)] attribute");
84 let report_id = args.report_id;
85 let cmd_len = args.cmd_len;
86
87 let name = &ast.ident;
88
89 let gen = quote! {
90 impl CommandDescriptor for #name {
91 fn report_id() -> u8 {
92 #report_id
93 }
94
95 fn cmd_len() -> usize {
96 #cmd_len
97 }
98 }
99 };
100
101 TokenStream::from(gen)
102}
103
104#[proc_macro_attribute]
105pub fn command_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
106 let input = parse_macro_input!(item as ItemImpl);
107
108 let target_type = &*input.self_ty;
109
110 let (_, generic_arg_type) = match target_type {
112 Type::Path(TypePath { path, .. }) => {
113 let first_segment = path.segments.first().expect("Expected a path segment");
114 if first_segment.ident != "Command" {
115 return syn::Error::new_spanned(
116 target_type,
117 "command_extension only works on impl blocks for Command<T>",
118 )
119 .to_compile_error()
120 .into();
121 }
122 match &first_segment.arguments {
124 PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
125 if args.len() != 1 {
126 return syn::Error::new_spanned(
127 &first_segment.arguments,
128 "Expected exactly one generic argument",
129 )
130 .to_compile_error()
131 .into();
132 }
133 let generic_arg = args.first().unwrap();
134 match generic_arg {
135 GenericArgument::Type(ty) => (first_segment.ident.clone(), ty.clone()),
136 _ => {
137 return syn::Error::new_spanned(
138 generic_arg,
139 "Expected a type as the generic argument",
140 )
141 .to_compile_error()
142 .into();
143 }
144 }
145 }
146 _ => {
147 return syn::Error::new_spanned(
148 &first_segment.arguments,
149 "Expected angle bracketed generic arguments",
150 )
151 .to_compile_error()
152 .into();
153 }
154 }
155 }
156 _ => {
157 return syn::Error::new_spanned(
158 target_type,
159 "command_extension can only be applied to impl blocks for Command<T>",
160 )
161 .to_compile_error()
162 .into();
163 }
164 };
165
166 let inner_type_ident = match generic_arg_type {
167 Type::Path(TypePath { ref path, .. }) => path
168 .segments
169 .last()
170 .expect("Expected at least one segment in the generic type")
171 .ident
172 .clone(),
173 _ => {
174 return syn::Error::new_spanned(
175 generic_arg_type,
176 "Expected a simple type for the generic parameter",
177 )
178 .to_compile_error()
179 .into();
180 }
181 };
182
183 let trait_name_str = format!("{}Ext", inner_type_ident);
186 let trait_ident = Ident::new(&trait_name_str, proc_macro2::Span::call_site());
187 let builder_trait_name_str = format!("{}BuilderExt", inner_type_ident);
189 let builder_trait_ident = Ident::new(&builder_trait_name_str, proc_macro2::Span::call_site());
190
191 let mut cmd_trait_methods = Vec::new();
193 let mut cmd_impl_methods = Vec::new();
194
195 let mut builder_trait_methods = Vec::new();
196 let mut builder_impl_methods = Vec::new();
197
198 for item in input.items.iter() {
200 if let ImplItem::Fn(method) = item {
201 let sig = &method.sig;
202 let attrs = &method.attrs;
203
204 let trait_method = quote! {
206 #(#attrs)*
207 #sig;
208 };
209 cmd_trait_methods.push(trait_method);
210
211 let block = &method.block;
212 let impl_method = quote! {
213 #(#attrs)*
214 #sig #block
215 };
216 cmd_impl_methods.push(impl_method);
217
218 let method_name = sig.ident.to_string();
220 if let Some(stripped) = method_name.strip_prefix("set_") {
221 let builder_method_ident = Ident::new(stripped, sig.ident.span());
223
224 let mut builder_inputs = Vec::new();
229 let mut arg_idents = Vec::new();
230 for input in sig.inputs.iter().skip(1) {
232 builder_inputs.push(input);
233 if let FnArg::Typed(pat_type) = input {
235 if let Pat::Ident(pat_ident) = *pat_type.pat.clone() {
237 arg_idents.push(pat_ident.ident);
238 }
239 }
240 }
241
242 let builder_sig = quote! {
245 fn #builder_method_ident(self, #(#builder_inputs),* ) -> Self;
246 };
247
248 builder_trait_methods.push(builder_sig);
249
250 let setter_ident = sig.ident.clone();
252
253 let builder_impl = quote! {
255 fn #builder_method_ident(mut self, #(#builder_inputs),* ) -> Self {
256 self.command.#setter_ident( #(#arg_idents),* );
257 self
258 }
259 };
260
261 builder_impl_methods.push(builder_impl);
262 }
263 }
264 }
265
266 let cmd_trait_def = quote! {
268 pub trait #trait_ident {
269 #(#cmd_trait_methods)*
270 }
271 };
272
273 let cmd_impl_block = quote! {
274 impl #trait_ident for #target_type {
275 #(#cmd_impl_methods)*
276 }
277 };
278
279 let builder_trait_def = quote! {
281 pub trait #builder_trait_ident {
282 #(#builder_trait_methods)*
283 }
284 };
285
286 let builder_target = quote! { CommandBuilder<#generic_arg_type> };
288
289 let builder_impl_block = quote! {
290 impl #builder_trait_ident for #builder_target {
291 #(#builder_impl_methods)*
292 }
293 };
294
295 let output = quote! {
299 #cmd_trait_def
300 #cmd_impl_block
301
302 #builder_trait_def
303 #builder_impl_block
304 };
305
306 output.into()
307}