libatk_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{
4    parse_macro_input, DeriveInput, FnArg, GenericArgument, ImplItem, ItemImpl, Pat, PathArguments,
5    Type, TypePath,
6};
7
8#[proc_macro_derive(Command)]
9pub fn command_trait(input: TokenStream) -> TokenStream {
10    let ast = parse_macro_input!(input as DeriveInput);
11
12    let name = &ast.ident;
13    quote! {
14        impl CommandDescriptor for #name {}
15    }
16    .into()
17}
18
19fn get_inner_type(ty: &syn::Type) -> syn::Result<syn::Type> {
20    match ty {
21        Type::Path(TypePath { path, .. }) => {
22            let segment = path.segments.last().unwrap();
23            if segment.ident != "Command" {
24                return Err(syn::Error::new_spanned(
25                    ty.to_token_stream(),
26                    "#[command] only works on impl blocks for Command<T>",
27                ));
28            }
29            match &segment.arguments {
30                PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
31                    args, ..
32                }) => {
33                    if args.len() != 1 {
34                        return Err(syn::Error::new_spanned(
35                            &segment.arguments,
36                            "Expected exactly one generic argument",
37                        ));
38                    }
39                    let arg = args.first().unwrap();
40                    return match arg {
41                        GenericArgument::Type(t) => Ok(t.clone()),
42                        _ => Err(syn::Error::new_spanned(
43                            arg,
44                            "Expected a type as the generic argument",
45                        )),
46                    };
47                }
48                _ => {
49                    return Err(syn::Error::new_spanned(
50                        &segment.arguments,
51                        "Expected angle bracketed generic arguments",
52                    ))
53                }
54            }
55        }
56        _ => Err(syn::Error::new_spanned(
57            &ty.to_token_stream(),
58            "#[command] only works on impl blocks for Command<T>",
59        )),
60    }
61}
62
63fn get_ident_from_type(ty: &syn::Type) -> syn::Result<syn::Ident> {
64    match ty {
65        Type::Path(TypePath { path, .. }) => {
66            let segment = path.segments.last();
67            match segment {
68                Some(seg) => Ok(seg.ident.clone()),
69                None => Err(syn::Error::new_spanned(
70                    ty,
71                    "Expected at least one segment in the generic type",
72                )),
73            }
74        }
75        _ => Err(syn::Error::new_spanned(
76            ty,
77            "Expected a simple type for the generic parameter",
78        )),
79    }
80}
81
82#[proc_macro_attribute]
83pub fn command_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
84    let input = parse_macro_input!(item as ItemImpl);
85    let impl_ty = input.self_ty;
86    let inner_ty = match get_inner_type(&impl_ty) {
87        Ok(ty) => ty,
88        Err(e) => return e.to_compile_error().into(),
89    };
90
91    let inner_ident = match get_ident_from_type(&inner_ty) {
92        Ok(ident) => ident,
93        Err(e) => return e.to_compile_error().into(),
94    };
95
96    let extension_trait_ident = syn::Ident::new(&format!("{}Ext", inner_ident), inner_ident.span());
97    let mut extension_trait_methods = Vec::new();
98    let mut extension_impl_methods = Vec::new();
99
100    let builder_trait_ident =
101        syn::Ident::new(&format!("{}BuilderExt", inner_ident), inner_ident.span());
102    let mut builder_trait_methods = Vec::new();
103    let mut builder_trait_impl = Vec::new();
104
105    for item in input.items.iter() {
106        if let ImplItem::Fn(method) = item {
107            let sig = &method.sig;
108            let attrs = &method.attrs;
109            let block = &method.block;
110
111            extension_trait_methods.push(quote! {
112                #(#attrs)*
113                #sig;
114            });
115            extension_impl_methods.push(quote! {
116                #(#attrs)*
117                #sig #block
118            });
119
120            let method_name = sig.ident.to_string();
121            if method_name.starts_with("set_") {
122                if sig.inputs.len() < 2 {
123                    return syn::Error::new_spanned(
124                        sig,
125                        "Expected at least one argument for setter method",
126                    )
127                    .to_compile_error()
128                    .into();
129                }
130
131                let new_method_name =
132                    syn::Ident::new(method_name.strip_prefix("set_").unwrap(), sig.ident.span());
133                let mut builder_inputs = Vec::new();
134                let mut arg_idents = Vec::new();
135                for input in sig.inputs.iter().skip(1) {
136                    builder_inputs.push(input);
137                    if let FnArg::Typed(pat_type) = input {
138                        if let Pat::Ident(pat_ident) = *pat_type.pat.clone() {
139                            arg_idents.push(pat_ident.ident);
140                        }
141                    }
142                }
143
144                let builder_sig = quote! {
145                    fn #new_method_name(self, #(#builder_inputs),* ) -> Self;
146                };
147                builder_trait_methods.push(builder_sig);
148                let setter_ident = &sig.ident;
149                builder_trait_impl.push(quote! {
150                    fn #new_method_name(mut self, #(#builder_inputs),* ) -> Self {
151                        self.command.#setter_ident( #(#arg_idents),* );
152                        self
153                    }
154                });
155            }
156        }
157    }
158
159    let mut out = quote! {
160        pub trait #extension_trait_ident {
161            #(#extension_trait_methods)*
162        }
163
164        impl #extension_trait_ident for #impl_ty {
165            #(#extension_impl_methods)*
166        }
167    };
168
169    if builder_trait_methods.len() > 0 {
170        out.extend(quote! {
171            pub trait #builder_trait_ident {
172                #(#builder_trait_methods)*
173            }
174
175            impl #builder_trait_ident for CommandBuilder<#inner_ident> {
176                #(#builder_trait_impl)*
177            }
178        });
179    }
180    out.into()
181}