libatk_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input,
6    punctuated::Punctuated,
7    AngleBracketedGenericArguments, DeriveInput, Expr, ExprLit, FnArg, GenericArgument, Ident,
8    ImplItem, ItemImpl, Lit, MetaNameValue, Pat, PathArguments, Result, Token, Type, TypePath,
9};
10
/// Arguments of a `#[command_descriptor(report_id = ..., cmd_len = ...)]` attribute.
///
/// Both fields are required; parsing fails if either key is absent.
struct Attributes {
    /// Value returned by the generated `report_id()` method.
    report_id: u8,
    /// Value returned by the generated `cmd_len()` method.
    cmd_len: usize,
}
16
17impl Parse for Attributes {
18    fn parse(input: ParseStream) -> Result<Self> {
19        // Parse the comma-separated list of key = value pairs.
20        let args = Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?;
21        let mut report_id_opt = None;
22        let mut cmd_len_opt = None;
23
24        for arg in args {
25            // Get the key as a string.
26            let key = arg
27                .path
28                .get_ident()
29                .ok_or_else(|| syn::Error::new_spanned(&arg.path, "Expected identifier"))?
30                .to_string();
31
32            // For each arg, we now extract the literal from the expression.
33            let lit_int = if let Expr::Lit(ExprLit {
34                lit: Lit::Int(ref i),
35                ..
36            }) = arg.value
37            {
38                i
39            } else {
40                return Err(syn::Error::new_spanned(
41                    &arg.value,
42                    "Expected integer literal",
43                ));
44            };
45
46            match key.as_str() {
47                "report_id" => {
48                    report_id_opt = Some(lit_int.base10_parse()?);
49                }
50                "cmd_len" => {
51                    cmd_len_opt = Some(lit_int.base10_parse()?);
52                }
53                _ => return Err(syn::Error::new_spanned(arg, "Unknown attribute key")),
54            }
55        }
56
57        // Ensure all required fields were provided.
58        let report_id =
59            report_id_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `report_id`"))?;
60        let cmd_len =
61            cmd_len_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `cmd_len`"))?;
62
63        Ok(Attributes { report_id, cmd_len })
64    }
65}
66
67#[proc_macro_derive(CommandDescriptor, attributes(command_descriptor))]
68pub fn derive_my_trait(input: TokenStream) -> TokenStream {
69    let ast = parse_macro_input!(input as DeriveInput);
70
71    let mut args_opt = None;
72    for attr in ast.attrs.iter() {
73        if attr.path().is_ident("command_descriptor") {
74            // Instead of parse_meta, we use parse_args to parse the tokens within parentheses.
75            let args: Attributes = attr
76                .parse_args()
77                .expect("Failed to parse command_descriptor arguments");
78            args_opt = Some(args);
79            break;
80        }
81    }
82
83    let args = args_opt.expect("Missing #[command_descriptor(...)] attribute");
84    let report_id = args.report_id;
85    let cmd_len = args.cmd_len;
86
87    let name = &ast.ident;
88
89    let gen = quote! {
90        impl CommandDescriptor for #name {
91            fn report_id() -> u8 {
92                #report_id
93            }
94
95            fn cmd_len() -> usize {
96                #cmd_len
97            }
98        }
99    };
100
101    TokenStream::from(gen)
102}
103
104#[proc_macro_attribute]
105pub fn command_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
106    let input = parse_macro_input!(item as ItemImpl);
107
108    let target_type = &*input.self_ty;
109
110    // Expect target_type looks like Command<Something>
111    let (_, generic_arg_type) = match target_type {
112        Type::Path(TypePath { path, .. }) => {
113            let first_segment = path.segments.first().expect("Expected a path segment");
114            if first_segment.ident != "Command" {
115                return syn::Error::new_spanned(
116                    target_type,
117                    "command_extension only works on impl blocks for Command<T>",
118                )
119                .to_compile_error()
120                .into();
121            }
122            // Extract the single generic argument T from Command<T>.
123            match &first_segment.arguments {
124                PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
125                    if args.len() != 1 {
126                        return syn::Error::new_spanned(
127                            &first_segment.arguments,
128                            "Expected exactly one generic argument",
129                        )
130                        .to_compile_error()
131                        .into();
132                    }
133                    let generic_arg = args.first().unwrap();
134                    match generic_arg {
135                        GenericArgument::Type(ty) => (first_segment.ident.clone(), ty.clone()),
136                        _ => {
137                            return syn::Error::new_spanned(
138                                generic_arg,
139                                "Expected a type as the generic argument",
140                            )
141                            .to_compile_error()
142                            .into();
143                        }
144                    }
145                }
146                _ => {
147                    return syn::Error::new_spanned(
148                        &first_segment.arguments,
149                        "Expected angle bracketed generic arguments",
150                    )
151                    .to_compile_error()
152                    .into();
153                }
154            }
155        }
156        _ => {
157            return syn::Error::new_spanned(
158                target_type,
159                "command_extension can only be applied to impl blocks for Command<T>",
160            )
161            .to_compile_error()
162            .into();
163        }
164    };
165
166    let inner_type_ident = match generic_arg_type {
167        Type::Path(TypePath { ref path, .. }) => path
168            .segments
169            .last()
170            .expect("Expected at least one segment in the generic type")
171            .ident
172            .clone(),
173        _ => {
174            return syn::Error::new_spanned(
175                generic_arg_type,
176                "Expected a simple type for the generic parameter",
177            )
178            .to_compile_error()
179            .into();
180        }
181    };
182
183    // Build names for the two traits:
184    // For Command<T> extension trait (the original full set, including both getters and setters).
185    let trait_name_str = format!("{}Ext", inner_type_ident);
186    let trait_ident = Ident::new(&trait_name_str, proc_macro2::Span::call_site());
187    // For the builder extension trait (only for setters). We’ll call it CommandNameBuilderExt.
188    let builder_trait_name_str = format!("{}BuilderExt", inner_type_ident);
189    let builder_trait_ident = Ident::new(&builder_trait_name_str, proc_macro2::Span::call_site());
190
191    // Prepare to collect the methods for the two generated impls.
192    let mut cmd_trait_methods = Vec::new();
193    let mut cmd_impl_methods = Vec::new();
194
195    let mut builder_trait_methods = Vec::new();
196    let mut builder_impl_methods = Vec::new();
197
198    // For each function in the input impl block…
199    for item in input.items.iter() {
200        if let ImplItem::Fn(method) = item {
201            let sig = &method.sig;
202            let attrs = &method.attrs;
203
204            // Add the method to the command trait (the “full” trait) as-is.
205            let trait_method = quote! {
206                #(#attrs)*
207                #sig;
208            };
209            cmd_trait_methods.push(trait_method);
210
211            let block = &method.block;
212            let impl_method = quote! {
213                #(#attrs)*
214                #sig #block
215            };
216            cmd_impl_methods.push(impl_method);
217
218            // Now if the method name begins with "set_", generate the corresponding builder method.
219            let method_name = sig.ident.to_string();
220            if let Some(stripped) = method_name.strip_prefix("set_") {
221                // New (builder) method name: drop the set_ prefix.
222                let builder_method_ident = Ident::new(stripped, sig.ident.span());
223
224                // The builder method’s signature becomes:
225                //     fn <builder_method_ident>(mut self, <args from the original> ) -> Self;
226                // The original setter is assumed to have a receiver (e.g. &mut self) plus one or more parameters.
227                // We remove the original receiver and use "mut self" instead.
228                let mut builder_inputs = Vec::new();
229                let mut arg_idents = Vec::new();
230                // iterate over inputs skipping the receiver.
231                for input in sig.inputs.iter().skip(1) {
232                    builder_inputs.push(input);
233                    // Also extract the identifier from each argument so we can pass it on.
234                    if let FnArg::Typed(pat_type) = input {
235                        // We expect the pat to be a simple identifier.
236                        if let Pat::Ident(pat_ident) = *pat_type.pat.clone() {
237                            arg_idents.push(pat_ident.ident);
238                        }
239                    }
240                }
241
242                // Build the builder method signature. We want something like:
243                //    fn rgb_lighting_effects(mut self, <params>) -> Self;
244                let builder_sig = quote! {
245                    fn #builder_method_ident(self, #(#builder_inputs),* ) -> Self;
246                };
247
248                builder_trait_methods.push(builder_sig);
249
250                // Inside the loop over methods, before generating the builder impl:
251                let setter_ident = sig.ident.clone();
252
253                // Then generate the builder impl using the captured setter_ident:
254                let builder_impl = quote! {
255                    fn #builder_method_ident(mut self, #(#builder_inputs),* ) -> Self {
256                        self.command.#setter_ident( #(#arg_idents),* );
257                        self
258                    }
259                };
260
261                builder_impl_methods.push(builder_impl);
262            }
263        }
264    }
265
266    // Build the command extension trait definition and its impl block.
267    let cmd_trait_def = quote! {
268        pub trait #trait_ident {
269            #(#cmd_trait_methods)*
270        }
271    };
272
273    let cmd_impl_block = quote! {
274        impl #trait_ident for #target_type {
275            #(#cmd_impl_methods)*
276        }
277    };
278
279    // Build the builder extension trait definition.
280    let builder_trait_def = quote! {
281        pub trait #builder_trait_ident {
282            #(#builder_trait_methods)*
283        }
284    };
285
286    // Build the builder target type: CommandBuilder<T>
287    let builder_target = quote! { CommandBuilder<#generic_arg_type> };
288
289    let builder_impl_block = quote! {
290        impl #builder_trait_ident for #builder_target {
291            #(#builder_impl_methods)*
292        }
293    };
294
295    // Finally, put everything together. The output contains:
296    //  - the original command extension trait and impl for Command<T>
297    //  - the builder extension trait and impl for CommandBuilder<T>
298    let output = quote! {
299        #cmd_trait_def
300        #cmd_impl_block
301
302        #builder_trait_def
303        #builder_impl_block
304    };
305
306    output.into()
307}