1#![warn(missing_docs)]
3#![allow(clippy::style)]
4
5use proc_macro::TokenStream;
6
7use quote::quote;
8
9struct TypeInfo {
10 ident: syn::Ident,
11 generics: Option<syn::AngleBracketedGenericArguments>,
12 reference: Option<syn::Lifetime>,
13 mutability: Option<syn::token::Mut>
14}
15
16fn generate_self_trait_bound(generic_name: syn::Ident, trait_name: &syn::Ident) -> syn::GenericArgument {
17 let mut segments = syn::punctuated::Punctuated::new();
18 segments.push(syn::PathSegment {
19 ident: trait_name.clone(),
20 arguments: syn::PathArguments::None,
21 });
22
23 let mut bounds = syn::punctuated::Punctuated::new();
24 bounds.push(syn::TypeParamBound::Trait(syn::TraitBound {
25 paren_token: None,
26 modifier: syn::TraitBoundModifier::None,
27 lifetimes: None,
28 path: syn::Path {
29 leading_colon: None,
30 segments
31 }
32 }));
33 syn::GenericArgument::Constraint(syn::Constraint {
34 ident: generic_name,
35 generics: None,
36 colon_token: syn::Token),
37 bounds
38 })
39}
40
41fn extract_type(typ: &mut syn::Type, trait_name: &syn::Ident, deref_type: &mut Option<syn::Ident>) -> Result<TypeInfo, TokenStream> {
42 match typ {
43 syn::Type::Path(ref mut typ) => {
44 let ident = match typ.path.segments.first() {
45 Some(path) => path.ident.clone(),
46 None => return Err(syn::Error::new_spanned(typ, "Type has no path segments").to_compile_error().into()),
47 };
48
49 match typ.path.segments.last_mut().expect("To have at least on type path segment").arguments {
50 syn::PathArguments::AngleBracketed(ref mut args) => {
51 let result = args.clone();
52
53 for arg in args.args.iter_mut() {
54 if let syn::GenericArgument::Constraint(constraint) = arg {
55
56 for param in constraint.bounds.iter() {
57 if let syn::TypeParamBound::Trait(bound) = param {
58 if bound.path.is_ident(trait_name) {
59 if let Some(ident) = deref_type.replace(constraint.ident.clone()) {
60 return Err(syn::Error::new_spanned(ident, "Multiple bounds to trait, can be problematic so how about no?").to_compile_error().into());
61 }
62 }
63 }
64 }
65
66 let mut segments = syn::punctuated::Punctuated::new();
67 segments.push(syn::PathSegment {
68 ident: constraint.ident.clone(),
69 arguments: syn::PathArguments::None
70 });
71
72 *arg = syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
73 qself: None,
74 path: syn::Path {
75 leading_colon: None,
76 segments
77 },
78 }));
79 }
80 }
81
82 Ok(TypeInfo {
87 ident,
88 generics: Some(result),
89 reference: None,
90 mutability: None,
91 })
92 },
93 syn::PathArguments::None => Ok(TypeInfo {
94 ident,
95 generics: None,
96 reference: None,
97 mutability: None,
98 }),
99 syn::PathArguments::Parenthesized(ref args) => Err(syn::Error::new_spanned(args, "Unsupported type arguments").to_compile_error().into()),
100 }
101 },
102 syn::Type::Reference(reference) => match extract_type(&mut reference.elem, trait_name, deref_type) {
103 Ok(mut result) => {
104 result.mutability = reference.mutability;
105 result.reference = reference.lifetime.clone();
106 Ok(result)
107 },
108 Err(error) => Err(error),
109 }
110 other => Err(syn::Error::new_spanned(other, "Unsupported type").to_compile_error().into()),
111 }
112}
113
114#[proc_macro_attribute]
212pub fn auto_trait(args: TokenStream, input: TokenStream) -> TokenStream {
213 let mut input = syn::parse_macro_input!(input as syn::ItemTrait);
214 let args: syn::Type = match syn::parse(args) {
215 Ok(args) => args,
216 Err(error) => {
217 return syn::Error::new(error.span(), "Argument is required and must be a type").to_compile_error().into()
218 }
219 };
220
221 let mut args = vec![args];
222
223 let mut remaining_attrs = Vec::new();
225 for attr in input.attrs.drain(..) {
226 if attr.path().is_ident("auto_trait") {
227 match attr.parse_args() {
228 Ok(arg) => match arg {
229 syn::Type::Paren(arg) => args.push(*arg.elem),
230 arg => args.push(arg),
231 },
232 Err(error) => {
233 return syn::Error::new(error.span(), "Argument is required and must be a type").to_compile_error().into()
234 }
235 }
236 } else {
237 remaining_attrs.push(attr)
238 }
239 }
240 input.attrs = remaining_attrs;
241
242 let mut impls = Vec::new();
243
244 for mut args in args.drain(..) {
245 let trait_name = input.ident.clone();
246 let mut deref_type = None;
247 let type_info = match extract_type(&mut args, &trait_name, &mut deref_type) {
248 Ok(type_info) => type_info,
249 Err(error) => return error,
250 };
251
252 let deref_name = deref_type.unwrap_or_else(|| trait_name.clone());
253
254 let mut methods = Vec::new();
255
256 for item in input.items.iter() {
257 match item {
258 syn::TraitItem::Fn(ref method) => {
259 let method_name = method.sig.ident.clone();
260 let mut method_args = Vec::new();
261 for arg in method.sig.inputs.iter() {
262 match arg {
263 syn::FnArg::Receiver(arg) => {
264 if arg.reference.is_some() {
265 if arg.mutability.is_some() {
266 if type_info.reference.is_some() {
267 method_args.push(quote! {
268 &mut **self
269 })
270 } else {
271 method_args.push(quote! {
272 core::ops::DerefMut::deref_mut(self)
273 })
274 }
275 } else {
276 if type_info.reference.is_some() {
277 method_args.push(quote! {
278 &**self
279 })
280 } else {
281 method_args.push(quote! {
282 core::ops::Deref::deref(self)
283 })
284 }
285 }
286 } else {
287 method_args.push(quote! {
288 self.into()
289 })
290 }
291 },
292 syn::FnArg::Typed(arg) => {
293 let name = &arg.pat;
294 method_args.push(quote! {
295 #name
296 })
297 },
298 }
299 }
300
301 let deref_block: syn::Block = syn::parse2(quote! {
302 {
303 #deref_name::#method_name(#(#method_args,)*)
304 }
305 }).unwrap();
306
307 let mut method = method.clone();
308 method.default = Some(deref_block);
309 method.semi_token = None;
310
311 methods.push(method);
312 },
313 unsupported => return syn::Error::new_spanned(unsupported, "Trait contains non-method definitions which is unsupported").to_compile_error().into(),
314
315 }
316 }
317
318 let type_generics = if let Some(lifetime) = type_info.reference {
319 match type_info.generics {
320 Some(mut generics) => {
321 let mut new_args = syn::punctuated::Punctuated::new();
322 new_args.insert(0, generate_self_trait_bound(type_info.ident, &trait_name));
323 new_args.insert(0, syn::GenericArgument::Lifetime(lifetime));
324 while let Some(arg) = generics.args.pop() {
325 new_args.push(arg.into_tuple().0);
326 }
327 generics.args = new_args;
328 Some(generics)
329 },
330 None => {
331 let mut args = syn::punctuated::Punctuated::new();
332 args.push(syn::GenericArgument::Lifetime(lifetime));
333 args.push(generate_self_trait_bound(type_info.ident, &trait_name));
334
335 Some(syn::AngleBracketedGenericArguments {
336 colon2_token: None,
337 lt_token: syn::Token),
338 args,
339 gt_token: syn::Token),
340 })
341 }
342 }
343 } else {
344 type_info.generics
345 };
346
347 impls.push(quote! {
348 impl#type_generics #trait_name for #args {
349 #(
350 #methods
351 )*
352 }
353 });
354 }
355
356 let mut result = quote! {
357 #input
358 };
359 result.extend(impls.drain(..));
360
361 result.into()
362}