1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input,
6 punctuated::Punctuated,
7 AngleBracketedGenericArguments, DeriveInput, Expr, ExprLit, GenericArgument, Ident, ImplItem,
8 ItemImpl, Lit, MetaNameValue, PathArguments, Result, Token, Type, TypePath,
9};
10
/// Values parsed out of a `#[command_descriptor(...)]` attribute.
///
/// Each field corresponds to the key of the same name and becomes the return
/// value of the matching `CommandDescriptor` method in the derived impl.
struct Attributes {
    /// Value of `base_offset = N`, returned by `CommandDescriptor::base_offset()`.
    base_offset: usize,
    /// Value of `report_id = N`, returned by `CommandDescriptor::report_id()`.
    report_id: u8,
    /// Value of `cmd_len = N`, returned by `CommandDescriptor::cmd_len()`.
    cmd_len: usize,
}
17
18impl Parse for Attributes {
19 fn parse(input: ParseStream) -> Result<Self> {
20 let args = Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?;
22 let mut base_offset_opt = None;
23 let mut report_id_opt = None;
24 let mut cmd_len_opt = None;
25
26 for arg in args {
27 let key = arg
29 .path
30 .get_ident()
31 .ok_or_else(|| syn::Error::new_spanned(&arg.path, "Expected identifier"))?
32 .to_string();
33
34 let lit_int = if let Expr::Lit(ExprLit {
36 lit: Lit::Int(ref i),
37 ..
38 }) = arg.value
39 {
40 i
41 } else {
42 return Err(syn::Error::new_spanned(
43 &arg.value,
44 "Expected integer literal",
45 ));
46 };
47
48 match key.as_str() {
49 "base_offset" => {
50 base_offset_opt = Some(lit_int.base10_parse()?);
51 }
52 "report_id" => {
53 report_id_opt = Some(lit_int.base10_parse()?);
54 }
55 "cmd_len" => {
56 cmd_len_opt = Some(lit_int.base10_parse()?);
57 }
58 _ => return Err(syn::Error::new_spanned(arg, "Unknown attribute key")),
59 }
60 }
61
62 let base_offset = base_offset_opt
64 .ok_or_else(|| syn::Error::new(input.span(), "Missing `base_offset`"))?;
65 let report_id =
66 report_id_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `report_id`"))?;
67 let cmd_len =
68 cmd_len_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `cmd_len`"))?;
69
70 Ok(Attributes {
71 base_offset,
72 report_id,
73 cmd_len,
74 })
75 }
76}
77
78#[proc_macro_derive(CommandDescriptor, attributes(command_descriptor))]
79pub fn derive_my_trait(input: TokenStream) -> TokenStream {
80 let ast = parse_macro_input!(input as DeriveInput);
81
82 let mut args_opt = None;
83 for attr in ast.attrs.iter() {
84 if attr.path().is_ident("command_descriptor") {
85 let args: Attributes = attr
87 .parse_args()
88 .expect("Failed to parse command_descriptor arguments");
89 args_opt = Some(args);
90 break;
91 }
92 }
93
94 let args = args_opt.expect("Missing #[command_descriptor(...)] attribute");
95 let base_offset = args.base_offset;
96 let report_id = args.report_id;
97 let cmd_len = args.cmd_len;
98
99 let name = &ast.ident;
100
101 let gen = quote! {
102 impl CommandDescriptor for #name {
103 fn base_offset() -> usize {
104 #base_offset
105 }
106
107 fn report_id() -> u8 {
108 #report_id
109 }
110
111 fn cmd_len() -> usize {
112 #cmd_len
113 }
114 }
115 };
116
117 TokenStream::from(gen)
118}
119
120#[proc_macro_attribute]
121pub fn command_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
122 let input = parse_macro_input!(item as ItemImpl);
123
124 let target_type = &*input.self_ty;
125
126 let (_, generic_arg_type) = match target_type {
128 Type::Path(TypePath { path, .. }) => {
129 let first_segment = path.segments.first().expect("Expected a path segment");
130 if first_segment.ident != "Command" {
131 return syn::Error::new_spanned(
132 target_type,
133 "command_extension only works on impl blocks for Command<T>",
134 )
135 .to_compile_error()
136 .into();
137 }
138 match &first_segment.arguments {
140 PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
141 if args.len() != 1 {
142 return syn::Error::new_spanned(
143 &first_segment.arguments,
144 "Expected exactly one generic argument",
145 )
146 .to_compile_error()
147 .into();
148 }
149 let generic_arg = args.first().unwrap();
150 match generic_arg {
151 GenericArgument::Type(ty) => (first_segment.ident.clone(), ty.clone()),
152 _ => {
153 return syn::Error::new_spanned(
154 generic_arg,
155 "Expected a type as the generic argument",
156 )
157 .to_compile_error()
158 .into();
159 }
160 }
161 }
162 _ => {
163 return syn::Error::new_spanned(
164 &first_segment.arguments,
165 "Expected angle bracketed generic arguments",
166 )
167 .to_compile_error()
168 .into();
169 }
170 }
171 }
172 _ => {
173 return syn::Error::new_spanned(
174 target_type,
175 "command_extension can only be applied to impl blocks for Command<T>",
176 )
177 .to_compile_error()
178 .into();
179 }
180 };
181
182 let inner_type_ident = match generic_arg_type {
183 Type::Path(TypePath { ref path, .. }) => path
184 .segments
185 .last()
186 .expect("Expected at least one segment in the generic type")
187 .ident
188 .clone(),
189 _ => {
190 return syn::Error::new_spanned(
191 generic_arg_type,
192 "Expected a simple type for the generic parameter",
193 )
194 .to_compile_error()
195 .into();
196 }
197 };
198
199 let trait_name_str = format!("{}Ext", inner_type_ident);
200 let trait_ident = Ident::new(&trait_name_str, proc_macro2::Span::call_site());
201
202 let mut trait_methods = Vec::new();
203 let mut impl_methods = Vec::new();
204
205 for item in input.items.iter() {
206 if let ImplItem::Fn(method) = item {
207 let sig = &method.sig;
208 let attrs = &method.attrs;
209 let trait_method = quote! {
210 #(#attrs)*
211 #sig;
212 };
213 trait_methods.push(trait_method);
214
215 let block = &method.block;
216 let impl_method = quote! {
217 #(#attrs)*
218 #sig #block
219 };
220 impl_methods.push(impl_method);
221 }
222 }
223
224 let trait_def = quote! {
225 pub trait #trait_ident {
226 #(#trait_methods)*
227 }
228 };
229
230 let impl_block = quote! {
231 impl #trait_ident for #target_type {
232 #(#impl_methods)*
233 }
234 };
235
236 let output = quote! {
237 #trait_def
238 #impl_block
239 };
240
241 output.into()
242}