libatk_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input,
6    punctuated::Punctuated,
7    AngleBracketedGenericArguments, DeriveInput, Expr, ExprLit, GenericArgument, Ident, ImplItem,
8    ItemImpl, Lit, MetaNameValue, PathArguments, Result, Token, Type, TypePath,
9};
10
/// Parsed arguments of a `#[command_descriptor(...)]` attribute.
///
/// Each field holds the integer literal supplied for the key of the same name.
struct Attributes {
    /// Value of the `base_offset` key.
    base_offset: usize,
    /// Value of the `report_id` key.
    report_id: u8,
    /// Value of the `cmd_len` key.
    cmd_len: usize,
}
17
18impl Parse for Attributes {
19    fn parse(input: ParseStream) -> Result<Self> {
20        // Parse the comma-separated list of key = value pairs.
21        let args = Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?;
22        let mut base_offset_opt = None;
23        let mut report_id_opt = None;
24        let mut cmd_len_opt = None;
25
26        for arg in args {
27            // Get the key as a string.
28            let key = arg
29                .path
30                .get_ident()
31                .ok_or_else(|| syn::Error::new_spanned(&arg.path, "Expected identifier"))?
32                .to_string();
33
34            // For each arg, we now extract the literal from the expression.
35            let lit_int = if let Expr::Lit(ExprLit {
36                lit: Lit::Int(ref i),
37                ..
38            }) = arg.value
39            {
40                i
41            } else {
42                return Err(syn::Error::new_spanned(
43                    &arg.value,
44                    "Expected integer literal",
45                ));
46            };
47
48            match key.as_str() {
49                "base_offset" => {
50                    base_offset_opt = Some(lit_int.base10_parse()?);
51                }
52                "report_id" => {
53                    report_id_opt = Some(lit_int.base10_parse()?);
54                }
55                "cmd_len" => {
56                    cmd_len_opt = Some(lit_int.base10_parse()?);
57                }
58                _ => return Err(syn::Error::new_spanned(arg, "Unknown attribute key")),
59            }
60        }
61
62        // Ensure all required fields were provided.
63        let base_offset = base_offset_opt
64            .ok_or_else(|| syn::Error::new(input.span(), "Missing `base_offset`"))?;
65        let report_id =
66            report_id_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `report_id`"))?;
67        let cmd_len =
68            cmd_len_opt.ok_or_else(|| syn::Error::new(input.span(), "Missing `cmd_len`"))?;
69
70        Ok(Attributes {
71            base_offset,
72            report_id,
73            cmd_len,
74        })
75    }
76}
77
78#[proc_macro_derive(CommandDescriptor, attributes(command_descriptor))]
79pub fn derive_my_trait(input: TokenStream) -> TokenStream {
80    let ast = parse_macro_input!(input as DeriveInput);
81
82    let mut args_opt = None;
83    for attr in ast.attrs.iter() {
84        if attr.path().is_ident("command_descriptor") {
85            // Instead of parse_meta, we use parse_args to parse the tokens within parentheses.
86            let args: Attributes = attr
87                .parse_args()
88                .expect("Failed to parse command_descriptor arguments");
89            args_opt = Some(args);
90            break;
91        }
92    }
93
94    let args = args_opt.expect("Missing #[command_descriptor(...)] attribute");
95    let base_offset = args.base_offset;
96    let report_id = args.report_id;
97    let cmd_len = args.cmd_len;
98
99    let name = &ast.ident;
100
101    let gen = quote! {
102        impl CommandDescriptor for #name {
103            fn base_offset() -> usize {
104                #base_offset
105            }
106
107            fn report_id() -> u8 {
108                #report_id
109            }
110
111            fn cmd_len() -> usize {
112                #cmd_len
113            }
114        }
115    };
116
117    TokenStream::from(gen)
118}
119
120#[proc_macro_attribute]
121pub fn command_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
122    let input = parse_macro_input!(item as ItemImpl);
123
124    let target_type = &*input.self_ty;
125
126    // Expect target_type looks like Command<Something>
127    let (_, generic_arg_type) = match target_type {
128        Type::Path(TypePath { path, .. }) => {
129            let first_segment = path.segments.first().expect("Expected a path segment");
130            if first_segment.ident != "Command" {
131                return syn::Error::new_spanned(
132                    target_type,
133                    "command_extension only works on impl blocks for Command<T>",
134                )
135                .to_compile_error()
136                .into();
137            }
138            // Extract the single generic argument T from Command<T>.
139            match &first_segment.arguments {
140                PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
141                    if args.len() != 1 {
142                        return syn::Error::new_spanned(
143                            &first_segment.arguments,
144                            "Expected exactly one generic argument",
145                        )
146                        .to_compile_error()
147                        .into();
148                    }
149                    let generic_arg = args.first().unwrap();
150                    match generic_arg {
151                        GenericArgument::Type(ty) => (first_segment.ident.clone(), ty.clone()),
152                        _ => {
153                            return syn::Error::new_spanned(
154                                generic_arg,
155                                "Expected a type as the generic argument",
156                            )
157                            .to_compile_error()
158                            .into();
159                        }
160                    }
161                }
162                _ => {
163                    return syn::Error::new_spanned(
164                        &first_segment.arguments,
165                        "Expected angle bracketed generic arguments",
166                    )
167                    .to_compile_error()
168                    .into();
169                }
170            }
171        }
172        _ => {
173            return syn::Error::new_spanned(
174                target_type,
175                "command_extension can only be applied to impl blocks for Command<T>",
176            )
177            .to_compile_error()
178            .into();
179        }
180    };
181
182    let inner_type_ident = match generic_arg_type {
183        Type::Path(TypePath { ref path, .. }) => path
184            .segments
185            .last()
186            .expect("Expected at least one segment in the generic type")
187            .ident
188            .clone(),
189        _ => {
190            return syn::Error::new_spanned(
191                generic_arg_type,
192                "Expected a simple type for the generic parameter",
193            )
194            .to_compile_error()
195            .into();
196        }
197    };
198
199    let trait_name_str = format!("{}Ext", inner_type_ident);
200    let trait_ident = Ident::new(&trait_name_str, proc_macro2::Span::call_site());
201
202    let mut trait_methods = Vec::new();
203    let mut impl_methods = Vec::new();
204
205    for item in input.items.iter() {
206        if let ImplItem::Fn(method) = item {
207            let sig = &method.sig;
208            let attrs = &method.attrs;
209            let trait_method = quote! {
210                #(#attrs)*
211                #sig;
212            };
213            trait_methods.push(trait_method);
214
215            let block = &method.block;
216            let impl_method = quote! {
217                #(#attrs)*
218                #sig #block
219            };
220            impl_methods.push(impl_method);
221        }
222    }
223
224    let trait_def = quote! {
225        pub trait #trait_ident {
226            #(#trait_methods)*
227        }
228    };
229
230    let impl_block = quote! {
231        impl #trait_ident for #target_type {
232            #(#impl_methods)*
233        }
234    };
235
236    let output = quote! {
237        #trait_def
238        #impl_block
239    };
240
241    output.into()
242}