1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{
4 parse_macro_input, DeriveInput, FnArg, GenericArgument, ImplItem, ItemImpl, Pat, PathArguments,
5 Type, TypePath,
6};
7
8#[proc_macro_derive(Command)]
9pub fn command_trait(input: TokenStream) -> TokenStream {
10 let ast = parse_macro_input!(input as DeriveInput);
11
12 let name = &ast.ident;
13 quote! {
14 impl CommandDescriptor for #name {}
15 }
16 .into()
17}
18
19fn get_inner_type(ty: &syn::Type) -> syn::Result<syn::Type> {
20 match ty {
21 Type::Path(TypePath { path, .. }) => {
22 let segment = path.segments.last().unwrap();
23 if segment.ident != "Command" {
24 return Err(syn::Error::new_spanned(
25 ty.to_token_stream(),
26 "#[command] only works on impl blocks for Command<T>",
27 ));
28 }
29 match &segment.arguments {
30 PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
31 args, ..
32 }) => {
33 if args.len() != 1 {
34 return Err(syn::Error::new_spanned(
35 &segment.arguments,
36 "Expected exactly one generic argument",
37 ));
38 }
39 let arg = args.first().unwrap();
40 return match arg {
41 GenericArgument::Type(t) => Ok(t.clone()),
42 _ => Err(syn::Error::new_spanned(
43 arg,
44 "Expected a type as the generic argument",
45 )),
46 };
47 }
48 _ => {
49 return Err(syn::Error::new_spanned(
50 &segment.arguments,
51 "Expected angle bracketed generic arguments",
52 ))
53 }
54 }
55 }
56 _ => Err(syn::Error::new_spanned(
57 &ty.to_token_stream(),
58 "#[command] only works on impl blocks for Command<T>",
59 )),
60 }
61}
62
63fn get_ident_from_type(ty: &syn::Type) -> syn::Result<syn::Ident> {
64 match ty {
65 Type::Path(TypePath { path, .. }) => {
66 let segment = path.segments.last();
67 match segment {
68 Some(seg) => Ok(seg.ident.clone()),
69 None => Err(syn::Error::new_spanned(
70 ty,
71 "Expected at least one segment in the generic type",
72 )),
73 }
74 }
75 _ => Err(syn::Error::new_spanned(
76 ty,
77 "Expected a simple type for the generic parameter",
78 )),
79 }
80}
81
82#[proc_macro_attribute]
83pub fn command_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
84 let input = parse_macro_input!(item as ItemImpl);
85 let impl_ty = input.self_ty;
86 let inner_ty = match get_inner_type(&impl_ty) {
87 Ok(ty) => ty,
88 Err(e) => return e.to_compile_error().into(),
89 };
90
91 let inner_ident = match get_ident_from_type(&inner_ty) {
92 Ok(ident) => ident,
93 Err(e) => return e.to_compile_error().into(),
94 };
95
96 let extension_trait_ident = syn::Ident::new(&format!("{}Ext", inner_ident), inner_ident.span());
97 let mut extension_trait_methods = Vec::new();
98 let mut extension_impl_methods = Vec::new();
99
100 let builder_trait_ident =
101 syn::Ident::new(&format!("{}BuilderExt", inner_ident), inner_ident.span());
102 let mut builder_trait_methods = Vec::new();
103 let mut builder_trait_impl = Vec::new();
104
105 for item in input.items.iter() {
106 if let ImplItem::Fn(method) = item {
107 let sig = &method.sig;
108 let attrs = &method.attrs;
109 let block = &method.block;
110
111 extension_trait_methods.push(quote! {
112 #(#attrs)*
113 #sig;
114 });
115 extension_impl_methods.push(quote! {
116 #(#attrs)*
117 #sig #block
118 });
119
120 let method_name = sig.ident.to_string();
121 if method_name.starts_with("set_") {
122 if sig.inputs.len() < 2 {
123 return syn::Error::new_spanned(
124 sig,
125 "Expected at least one argument for setter method",
126 )
127 .to_compile_error()
128 .into();
129 }
130
131 let new_method_name =
132 syn::Ident::new(method_name.strip_prefix("set_").unwrap(), sig.ident.span());
133 let mut builder_inputs = Vec::new();
134 let mut arg_idents = Vec::new();
135 for input in sig.inputs.iter().skip(1) {
136 builder_inputs.push(input);
137 if let FnArg::Typed(pat_type) = input {
138 if let Pat::Ident(pat_ident) = *pat_type.pat.clone() {
139 arg_idents.push(pat_ident.ident);
140 }
141 }
142 }
143
144 let builder_sig = quote! {
145 fn #new_method_name(self, #(#builder_inputs),* ) -> Self;
146 };
147 builder_trait_methods.push(builder_sig);
148 let setter_ident = &sig.ident;
149 builder_trait_impl.push(quote! {
150 fn #new_method_name(mut self, #(#builder_inputs),* ) -> Self {
151 self.command.#setter_ident( #(#arg_idents),* );
152 self
153 }
154 });
155 }
156 }
157 }
158
159 let mut out = quote! {
160 pub trait #extension_trait_ident {
161 #(#extension_trait_methods)*
162 }
163
164 impl #extension_trait_ident for #impl_ty {
165 #(#extension_impl_methods)*
166 }
167 };
168
169 if builder_trait_methods.len() > 0 {
170 out.extend(quote! {
171 pub trait #builder_trait_ident {
172 #(#builder_trait_methods)*
173 }
174
175 impl #builder_trait_ident for CommandBuilder<#inner_ident> {
176 #(#builder_trait_impl)*
177 }
178 });
179 }
180 out.into()
181}