as_method/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use syn::visit_mut::VisitMut;
4
5/// Call function with the method syntax!
6#[proc_macro_attribute]
7pub fn as_method(
8    attr: proc_macro::TokenStream,
9    item: proc_macro::TokenStream,
10) -> proc_macro::TokenStream {
11    if !attr.is_empty() {
12        return syn::Error::new(proc_macro2::Span::call_site(), "unexpected attr(s)")
13            .into_compile_error()
14            .into();
15    }
16
17    let mut func: syn::ItemFn = match syn::parse(item) {
18        Ok(func) => func,
19        Err(err) => return err.into_compile_error().into(),
20    };
21
22    let mut visitor = ImplTraitReplace(Vec::new());
23    let Some(syn::FnArg::Typed(self_ty)) = func.sig.inputs.first_mut() else {
24        return syn::Error::new(
25            func.sig.paren_token.span.open(),
26            "expected at least one parameter",
27        )
28        .into_compile_error()
29        .into();
30    };
31    visitor.visit_type_mut(&mut *self_ty.ty);
32
33    for type_param in visitor.0 {
34        func.sig
35            .generics
36            .params
37            .push(syn::GenericParam::Type(type_param));
38    }
39
40    let self_ty = self_ty.ty.clone();
41
42    let vis = &func.vis;
43    let name = &func.sig.ident;
44    let ret_ty = &func.sig.output;
45
46    let mut arg_tys = Vec::new();
47    for input in func.sig.inputs.iter().skip(1) {
48        match input {
49            syn::FnArg::Typed(pat_type) => arg_tys.push(&*pat_type.ty),
50            syn::FnArg::Receiver(receiver) => {
51                return syn::Error::new(receiver.self_token.span, "unexpected self receiver")
52                    .into_compile_error()
53                    .into()
54            }
55        }
56    }
57
58    let args = (1..=arg_tys.len())
59        .map(|i| quote::format_ident!("x{}", i))
60        .collect::<Vec<_>>();
61
62    let (impl_generics, ty_generics, where_clause) = func.sig.generics.split_for_impl();
63
64    quote::quote! {
65        #func
66
67        #[allow(non_camel_case_types)]
68        #vis trait #name #ty_generics #where_clause {
69            fn #name(self, #(#args: #arg_tys),*) #ret_ty;
70        }
71
72        impl #impl_generics #name #ty_generics for #self_ty #where_clause {
73            fn #name(self, #(#args: #arg_tys),*) #ret_ty {
74                #name(self, #(#args),*)
75            }
76        }
77    }
78    .into()
79}
80
81struct ImplTraitReplace(Vec<syn::TypeParam>);
82
83impl VisitMut for ImplTraitReplace {
84    fn visit_type_mut(&mut self, node: &mut syn::Type) {
85        if let syn::Type::ImplTrait(type_impl_trait) = node {
86            let ident = quote::format_ident!("AS_METHOD_SELF_T{}", self.0.len());
87            self.0.push(syn::TypeParam {
88                attrs: Vec::new(),
89                ident: ident.clone(),
90                colon_token: None,
91                bounds: type_impl_trait.bounds.clone(),
92                eq_token: None,
93                default: None,
94            });
95
96            let mut segments = syn::punctuated::Punctuated::new();
97            segments.push(syn::PathSegment {
98                ident,
99                arguments: syn::PathArguments::None,
100            });
101
102            *node = syn::Type::Path(syn::TypePath {
103                qself: None,
104                path: syn::Path {
105                    leading_colon: None,
106                    segments,
107                },
108            });
109        } else {
110            syn::visit_mut::visit_type_mut(self, node);
111        }
112    }
113}