partial_application_rs/
lib.rs

1#![feature(proc_macro_diagnostic)]
2extern crate proc_macro;
3use proc_macro::TokenStream;
4use proc_macro2;
5use quote::quote;
6use syn;
7use syn::spanned::Spanned;
8
9/// Turns function into partially applicable functions.
10#[proc_macro_attribute]
11pub fn part_app(attr: TokenStream, item: TokenStream) -> TokenStream {
12    let func_item: syn::Item = syn::parse(item).expect("failed to parse input");
13    let attr_options = MacroOptions::from(attr);
14    attr_options.check(&func_item);
15
16    match func_item {
17        syn::Item::Fn(ref func) => {
18            let fn_info = FunctionInformation::from(func);
19            fn_info.check();
20
21            // disallow where clauses
22            if let Some(w) = &func.sig.generics.where_clause {
23                w.span()
24                    .unstable()
25                    .error("part_app does not allow where clauses")
26                    .emit();
27            }
28
29            let func_struct = main_struct(&fn_info, &attr_options);
30            let generator_func = generator_func(&fn_info, &attr_options);
31            let final_call = final_call(&fn_info, &attr_options);
32            let argument_calls = argument_calls(&fn_info, &attr_options);
33
34            let unit_structs = {
35                let added_unit = fn_info.unit.added;
36                let empty_unit = fn_info.unit.empty;
37                let vis = fn_info.public;
38                quote! {
39                    #[allow(non_camel_case_types,non_snake_case)]
40                    #vis struct #added_unit;
41                    #[allow(non_camel_case_types,non_snake_case)]
42                    #vis struct #empty_unit;
43                }
44            };
45
46            // assemble output
47            let mut out = proc_macro2::TokenStream::new();
48            out.extend(unit_structs);
49            out.extend(func_struct);
50            out.extend(generator_func);
51            out.extend(argument_calls);
52            out.extend(final_call);
53            // println!("{}", out);
54            TokenStream::from(out)
55        }
56        _ => {
57            func_item
58                .span()
59                .unstable()
60                .error(
61                    "Only functions can be partially applied, for structs use the builder pattern",
62                )
63                .emit();
64            proc_macro::TokenStream::from(quote! { #func_item })
65        }
66    }
67}
68
69/// The portion of the signature necessary for each impl block
70fn impl_signature<'a>(
71    args: &Vec<&syn::PatType>,
72    ret_type: &'a syn::ReturnType,
73    generics: &Vec<&syn::GenericParam>,
74    opts: &MacroOptions,
75) -> proc_macro2::TokenStream {
76    let arg_names = arg_names(&args);
77    let arg_types = arg_types(&args);
78    let augmented_names = if !(opts.impl_poly || opts.by_value) {
79        augmented_argument_names(&arg_names)
80    } else {
81        Vec::new()
82    };
83
84    quote! {
85        #(#generics,)*
86        #(#augmented_names: Fn() -> #arg_types,)*
87        BODYFN: Fn(#(#arg_types,)*) #ret_type
88    }
89}
90
91/// Generates the methods used to add argument values to a partially applied function. One method is generate for each
92/// argument and each method is contained in it's own impl block.
93fn argument_calls<'a>(
94    fn_info: &FunctionInformation,
95    opts: &MacroOptions,
96) -> proc_macro2::TokenStream {
97    let impl_sig = impl_signature(
98        &fn_info.argument_vec,
99        fn_info.ret_type,
100        &fn_info.generics,
101        opts,
102    );
103    let arg_name_vec = arg_names(&fn_info.argument_vec);
104    let aug_arg_names = augmented_argument_names(&arg_name_vec);
105    let arg_types = arg_types(&fn_info.argument_vec);
106    arg_names(&fn_info.argument_vec)
107        .into_iter()
108        .zip(&aug_arg_names)
109        .zip(arg_types)
110        .map(|((n, n_fn), n_type)| {
111            // All variable names except the name of this function
112            let free_vars: Vec<_> = arg_name_vec.iter().filter(|&x| x != &n).collect();
113            let associated_vals_out: Vec<_> = arg_name_vec
114                .iter()
115                .map(|x| {
116                    if &n == x {
117                        fn_info.unit.added.clone()
118                    } else {
119                        x.clone()
120                    }
121                })
122                .collect();
123            let val_list_out = if opts.impl_poly || opts.by_value {
124                quote! {#(#associated_vals_out,)*}
125            } else {
126                quote! {#(#associated_vals_out, #aug_arg_names,)*}
127            };
128            let associated_vals_in: Vec<_> = associated_vals_out
129                .iter()
130                .map(|x| {
131                    if x == &fn_info.unit.added {
132                        &fn_info.unit.empty
133                    } else {
134                        x
135                    }
136                })
137                .collect();
138            let val_list_in = if opts.impl_poly || opts.by_value {
139                quote! {#(#associated_vals_in,)*}
140            } else {
141                quote! {#(#associated_vals_in, #aug_arg_names,)*}
142            };
143            let (transmute, self_type) = if opts.impl_poly || opts.impl_clone {
144                (quote!(transmute), quote!(self))
145            } else {
146                // if by_value
147                (quote!(transmute_copy), quote!(&self))
148            };
149            let some = if opts.impl_poly {
150                quote! {Some(::std::sync::Arc::from(#n))}
151            } else {
152                // || by_value
153                quote! {Some(#n)}
154            };
155            let in_type = if opts.impl_poly {
156                quote! { Box<dyn Fn() -> #n_type> }
157            } else if opts.by_value {
158                quote! {#n_type}
159            } else {
160                quote! { #n_fn }
161            };
162            let struct_name = &fn_info.struct_name;
163            let generics = &fn_info.generics;
164            let vis = fn_info.public;
165            quote! {
166                #[allow(non_camel_case_types,non_snake_case)]
167                impl< #impl_sig, #(#free_vars,)* > // The impl signature
168                    #struct_name<#(#generics,)* #val_list_in BODYFN> // The struct signature
169                {
170                    #vis fn #n (mut self, #n: #in_type) ->
171                        #struct_name<#(#generics,)* #val_list_out BODYFN>{
172                        self.#n = #some;
173                        unsafe {
174                            ::std::mem::#transmute::<
175                                #struct_name<#(#generics,)* #val_list_in BODYFN>,
176                            #struct_name<#(#generics,)* #val_list_out BODYFN>,
177                            >(#self_type)
178                        }
179                    }
180                }
181            }
182        })
183        .collect()
184}
185
186/// Generates the call function, which executes a fully filled out partially applicable struct.
187fn final_call<'a>(fn_info: &FunctionInformation, opts: &MacroOptions) -> proc_macro2::TokenStream {
188    let ret_type = fn_info.ret_type;
189    let generics = &fn_info.generics;
190    let unit_added = &fn_info.unit.added;
191    let struct_name = &fn_info.struct_name;
192    let impl_sig = impl_signature(&fn_info.argument_vec, ret_type, generics, opts);
193    let arg_names = arg_names(&fn_info.argument_vec);
194    let aug_args = augmented_argument_names(&arg_names);
195    let vis = fn_info.public;
196    let arg_list: proc_macro2::TokenStream = if opts.impl_poly || opts.by_value {
197        aug_args.iter().map(|_| quote! {#unit_added,}).collect()
198    } else {
199        aug_args.iter().map(|a| quote! {#unit_added, #a,}).collect()
200    };
201    let call = if !opts.by_value {
202        quote! {()}
203    } else {
204        quote! {}
205    };
206    quote! {
207        #[allow(non_camel_case_types,non_snake_case)]
208        impl <#impl_sig> // impl signature
209            // struct signature
210            #struct_name<#(#generics,)* #arg_list BODYFN>
211        {
212            #vis fn call(self) #ret_type { // final call
213                (self.body)(#(self.#arg_names.unwrap()#call,)*)
214            }
215        }
216    }
217}
218
219/// The function called by the user to create an instance of a partially applicable function. This function always has
220/// the name of the original function the macro is called on.
221fn generator_func<'a>(
222    fn_info: &FunctionInformation,
223    opts: &MacroOptions,
224) -> proc_macro2::TokenStream {
225    // because the quote! macro cannot expand fields
226    let arg_names = arg_names(&fn_info.argument_vec);
227    let arg_types = arg_types(&fn_info.argument_vec);
228    let marker_names = marker_names(&arg_names);
229    let generics = &fn_info.generics;
230    let empty_unit = &fn_info.unit.empty;
231    let body = fn_info.block;
232    let name = fn_info.fn_name;
233    let struct_name = &fn_info.struct_name;
234    let ret_type = fn_info.ret_type;
235    let vis = fn_info.public;
236
237    let gen_types = if opts.impl_poly || opts.by_value {
238        quote! {#(#generics,)*}
239    } else {
240        quote! {#(#generics,)* #(#arg_names,)*}
241    };
242    let struct_types = if opts.impl_poly || opts.by_value {
243        arg_names.iter().map(|_| quote! {#empty_unit,}).collect()
244    } else {
245        quote! {#(#empty_unit,#arg_names,)*}
246    };
247    let body_fn = if opts.impl_poly || opts.impl_clone {
248        quote! {::std::sync::Arc::new(|#(#arg_names,)*| #body),}
249    } else {
250        quote! {|#(#arg_names,)*| #body,}
251    };
252    let where_clause = if opts.impl_poly || opts.by_value {
253        quote!()
254    } else {
255        quote! {
256            where
257                #(#arg_names: Fn() -> #arg_types,)*
258        }
259    };
260    quote! {
261        #[allow(non_camel_case_types,non_snake_case)]
262        #vis fn #name<#gen_types>() -> #struct_name<#(#generics,)* #struct_types
263        impl Fn(#(#arg_types,)*) #ret_type>
264            #where_clause
265        {
266            #struct_name {
267                #(#arg_names: None,)*
268                #(#marker_names: ::std::marker::PhantomData,)*
269                body: #body_fn
270            }
271        }
272
273    }
274}
275
276/// A vector of all argument names. Simple parsing.
277fn arg_names<'a>(args: &Vec<&syn::PatType>) -> Vec<syn::Ident> {
278    args.iter()
279        .map(|f| {
280            let f_pat = &f.pat;
281            syn::Ident::new(&format!("{}", quote!(#f_pat)), f.span())
282        })
283        .collect()
284}
285
286/// The vector of names used to hold PhantomData.
287fn marker_names(names: &Vec<syn::Ident>) -> Vec<syn::Ident> {
288    names.iter().map(|f| concat_ident(f, "m")).collect()
289}
290
291/// Concatenates a identity with a string, returning a new identity with the same span.
292fn concat_ident<'a>(ident: &'a syn::Ident, end: &str) -> syn::Ident {
293    let name = format!("{}___{}", quote! {#ident}, end);
294    syn::Ident::new(&name, ident.span())
295}
296
297/// Filter an argument list to a pattern type
298fn argument_vector<'a>(
299    args: &'a syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
300) -> Vec<&syn::PatType> {
301    args.iter()
302        .map(|fn_arg| match fn_arg {
303            syn::FnArg::Receiver(_) => panic!("should filter out reciever arguments"),
304            syn::FnArg::Typed(t) => {
305                if let syn::Type::Reference(r) = t.ty.as_ref() {
306                    if r.lifetime.is_none() {
307                        t.span()
308                            .unstable()
309                            .error("part_app does not support lifetime elision")
310                            .emit();
311                    }
312                }
313
314                t
315            }
316        })
317        .collect()
318}
319
320/// Retrieves the identities of an the argument list
321fn arg_types<'a>(args: &Vec<&'a syn::PatType>) -> Vec<&'a syn::Type> {
322    args.iter().map(|f| f.ty.as_ref()).collect()
323}
324
325/// Names to hold function types
326fn augmented_argument_names<'a>(arg_names: &Vec<syn::Ident>) -> Vec<syn::Ident> {
327    arg_names.iter().map(|f| concat_ident(f, "FN")).collect()
328}
329
330/// Generates the main struct for the partially applicable function.
331/// All other functions are methods on this struct.
332fn main_struct<'a>(fn_info: &FunctionInformation, opts: &MacroOptions) -> proc_macro2::TokenStream {
333    let arg_types = arg_types(&fn_info.argument_vec);
334
335    let arg_names = arg_names(&fn_info.argument_vec);
336    let arg_augmented = augmented_argument_names(&arg_names);
337    let ret_type = fn_info.ret_type;
338
339    let arg_list: Vec<_> = if !(opts.impl_poly || opts.by_value) {
340        arg_names
341            .iter()
342            .zip(arg_augmented.iter())
343            .flat_map(|(n, a)| vec![n, a])
344            .collect()
345    } else {
346        arg_names.iter().collect()
347    };
348    let bodyfn = if opts.impl_poly || opts.impl_clone {
349        quote! {::std::sync::Arc<BODYFN>}
350    } else {
351        quote! { BODYFN }
352    };
353    let where_clause = if opts.impl_poly || opts.by_value {
354        quote!(BODYFN: Fn(#(#arg_types,)*) #ret_type,)
355    } else {
356        quote! {
357            #(#arg_augmented: Fn() -> #arg_types,)*
358            BODYFN: Fn(#(#arg_types,)*) #ret_type,
359        }
360    };
361    let names_with_m = marker_names(&arg_names);
362    let option_list = if opts.impl_poly {
363        quote! {#(#arg_names: Option<::std::sync::Arc<dyn Fn() -> #arg_types>>,)*}
364    } else if opts.by_value {
365        quote! {#(#arg_names: Option<#arg_types>,)*}
366    } else {
367        quote! {#(#arg_names: Option<#arg_augmented>,)*}
368    };
369    let name = &fn_info.struct_name;
370
371    let clone = if opts.impl_clone {
372        let sig = impl_signature(
373            &fn_info.argument_vec,
374            fn_info.ret_type,
375            &fn_info.generics,
376            opts,
377        );
378        quote! {
379            #[allow(non_camel_case_types,non_snake_case)]
380            impl<#sig, #(#arg_list,)*> ::std::clone::Clone for #name <#(#arg_list,)* BODYFN>
381            where #where_clause
382            {
383                fn clone(&self) -> Self {
384                    Self {
385                        #(#names_with_m: ::std::marker::PhantomData,)*
386                        #(#arg_names: self.#arg_names.clone(),)*
387                        body: self.body.clone(),
388                    }
389                }
390            }
391        }
392    } else {
393        quote! {}
394    };
395    let generics = &fn_info.generics;
396    let vis = fn_info.public;
397    quote! {
398        #[allow(non_camel_case_types,non_snake_case)]
399        #vis struct #name <#(#generics,)* #(#arg_list,)*BODYFN>
400        where #where_clause
401        {
402            // These hold the (phantom) types which represent if a field has
403            // been filled
404            #(#names_with_m: ::std::marker::PhantomData<#arg_names>,)*
405            // These hold the closures representing each argument
406            #option_list
407            // This holds the executable function
408            body: #bodyfn,
409        }
410
411        #clone
412    }
413}
414
415/// Contains options used to customize output
416struct MacroOptions {
417    attr: proc_macro::TokenStream,
418    by_value: bool,
419    impl_clone: bool,
420    impl_poly: bool,
421}
422
423impl MacroOptions {
424    fn new(attr: proc_macro::TokenStream) -> Self {
425        Self {
426            attr,
427            by_value: false,
428            impl_clone: false,
429            impl_poly: false,
430        }
431    }
432    fn from(attr: proc_macro::TokenStream) -> Self {
433        let attributes: Vec<String> = attr
434            .to_string()
435            .split(",")
436            .map(|s| s.trim().to_string())
437            .collect();
438        let mut attr_options = MacroOptions::new(attr);
439        attr_options.impl_poly = attributes.contains(&"poly".to_string());
440        attr_options.by_value = attributes.contains(&"value".to_string());
441        attr_options.impl_clone = attributes.contains(&"Clone".to_string());
442        attr_options
443    }
444    fn check(&self, span: &syn::Item) {
445        if self.impl_poly && self.by_value {
446            span.span()
447                .unstable()
448                .error(r#"Cannot implement "poly" and "value" at the same time"#)
449                .emit()
450        }
451
452        if self.impl_clone && !(self.impl_poly || self.by_value) {
453            span.span()
454                .unstable()
455                .error(r#"Cannot implement "Clone" without "poly" or "value""#)
456                .emit()
457        }
458        if !self.attr.is_empty() && !self.impl_poly && !self.by_value && !self.impl_clone {
459            span.span()
460                .unstable()
461                .error(
462                    r#"Unknown attribute. Acceptable attributes are "poly", "Clone" and "value""#,
463                )
464                .emit()
465        }
466    }
467}
468
469/// Contains prepossesses information about
470struct FunctionInformation<'a> {
471    argument_vec: Vec<&'a syn::PatType>,
472    ret_type: &'a syn::ReturnType,
473    generics: Vec<&'a syn::GenericParam>,
474    fn_name: &'a syn::Ident,
475    struct_name: syn::Ident,
476    unit: Units,
477    block: &'a syn::Block,
478    public: &'a syn::Visibility,
479    orignal_fn: &'a syn::ItemFn,
480}
481
482/// Contains Idents for the unit structs
483struct Units {
484    added: syn::Ident,
485    empty: syn::Ident,
486}
487
488impl<'a> FunctionInformation<'a> {
489    fn from(func: &'a syn::ItemFn) -> Self {
490        let func_name = &func.sig.ident;
491        Self {
492            argument_vec: argument_vector(&func.sig.inputs),
493            ret_type: &func.sig.output,
494            generics: func.sig.generics.params.iter().map(|f| f).collect(),
495            fn_name: func_name,
496            struct_name: syn::Ident::new(
497                &format!("__PartialApplication__{}_", func_name),
498                func_name.span(),
499            ),
500            unit: Units {
501                added: concat_ident(func_name, "Added"),
502                empty: concat_ident(func_name, "Empty"),
503            },
504            block: &func.block,
505            public: &func.vis,
506            orignal_fn: func,
507        }
508    }
509    fn check(&self) {
510        if let Some(r) = self.orignal_fn.sig.receiver() {
511            r.span()
512                .unstable()
513                .error("Cannot make methods partially applicable yet")
514                .emit();
515        }
516    }
517}