1#![feature(proc_macro_diagnostic)]
2extern crate proc_macro;
3use proc_macro::TokenStream;
4use proc_macro2;
5use quote::quote;
6use syn;
7use syn::spanned::Spanned;
8
9#[proc_macro_attribute]
11pub fn part_app(attr: TokenStream, item: TokenStream) -> TokenStream {
12 let func_item: syn::Item = syn::parse(item).expect("failed to parse input");
13 let attr_options = MacroOptions::from(attr);
14 attr_options.check(&func_item);
15
16 match func_item {
17 syn::Item::Fn(ref func) => {
18 let fn_info = FunctionInformation::from(func);
19 fn_info.check();
20
21 if let Some(w) = &func.sig.generics.where_clause {
23 w.span()
24 .unstable()
25 .error("part_app does not allow where clauses")
26 .emit();
27 }
28
29 let func_struct = main_struct(&fn_info, &attr_options);
30 let generator_func = generator_func(&fn_info, &attr_options);
31 let final_call = final_call(&fn_info, &attr_options);
32 let argument_calls = argument_calls(&fn_info, &attr_options);
33
34 let unit_structs = {
35 let added_unit = fn_info.unit.added;
36 let empty_unit = fn_info.unit.empty;
37 let vis = fn_info.public;
38 quote! {
39 #[allow(non_camel_case_types,non_snake_case)]
40 #vis struct #added_unit;
41 #[allow(non_camel_case_types,non_snake_case)]
42 #vis struct #empty_unit;
43 }
44 };
45
46 let mut out = proc_macro2::TokenStream::new();
48 out.extend(unit_structs);
49 out.extend(func_struct);
50 out.extend(generator_func);
51 out.extend(argument_calls);
52 out.extend(final_call);
53 TokenStream::from(out)
55 }
56 _ => {
57 func_item
58 .span()
59 .unstable()
60 .error(
61 "Only functions can be partially applied, for structs use the builder pattern",
62 )
63 .emit();
64 proc_macro::TokenStream::from(quote! { #func_item })
65 }
66 }
67}
68
69fn impl_signature<'a>(
71 args: &Vec<&syn::PatType>,
72 ret_type: &'a syn::ReturnType,
73 generics: &Vec<&syn::GenericParam>,
74 opts: &MacroOptions,
75) -> proc_macro2::TokenStream {
76 let arg_names = arg_names(&args);
77 let arg_types = arg_types(&args);
78 let augmented_names = if !(opts.impl_poly || opts.by_value) {
79 augmented_argument_names(&arg_names)
80 } else {
81 Vec::new()
82 };
83
84 quote! {
85 #(#generics,)*
86 #(#augmented_names: Fn() -> #arg_types,)*
87 BODYFN: Fn(#(#arg_types,)*) #ret_type
88 }
89}
90
91fn argument_calls<'a>(
94 fn_info: &FunctionInformation,
95 opts: &MacroOptions,
96) -> proc_macro2::TokenStream {
97 let impl_sig = impl_signature(
98 &fn_info.argument_vec,
99 fn_info.ret_type,
100 &fn_info.generics,
101 opts,
102 );
103 let arg_name_vec = arg_names(&fn_info.argument_vec);
104 let aug_arg_names = augmented_argument_names(&arg_name_vec);
105 let arg_types = arg_types(&fn_info.argument_vec);
106 arg_names(&fn_info.argument_vec)
107 .into_iter()
108 .zip(&aug_arg_names)
109 .zip(arg_types)
110 .map(|((n, n_fn), n_type)| {
111 let free_vars: Vec<_> = arg_name_vec.iter().filter(|&x| x != &n).collect();
113 let associated_vals_out: Vec<_> = arg_name_vec
114 .iter()
115 .map(|x| {
116 if &n == x {
117 fn_info.unit.added.clone()
118 } else {
119 x.clone()
120 }
121 })
122 .collect();
123 let val_list_out = if opts.impl_poly || opts.by_value {
124 quote! {#(#associated_vals_out,)*}
125 } else {
126 quote! {#(#associated_vals_out, #aug_arg_names,)*}
127 };
128 let associated_vals_in: Vec<_> = associated_vals_out
129 .iter()
130 .map(|x| {
131 if x == &fn_info.unit.added {
132 &fn_info.unit.empty
133 } else {
134 x
135 }
136 })
137 .collect();
138 let val_list_in = if opts.impl_poly || opts.by_value {
139 quote! {#(#associated_vals_in,)*}
140 } else {
141 quote! {#(#associated_vals_in, #aug_arg_names,)*}
142 };
143 let (transmute, self_type) = if opts.impl_poly || opts.impl_clone {
144 (quote!(transmute), quote!(self))
145 } else {
146 (quote!(transmute_copy), quote!(&self))
148 };
149 let some = if opts.impl_poly {
150 quote! {Some(::std::sync::Arc::from(#n))}
151 } else {
152 quote! {Some(#n)}
154 };
155 let in_type = if opts.impl_poly {
156 quote! { Box<dyn Fn() -> #n_type> }
157 } else if opts.by_value {
158 quote! {#n_type}
159 } else {
160 quote! { #n_fn }
161 };
162 let struct_name = &fn_info.struct_name;
163 let generics = &fn_info.generics;
164 let vis = fn_info.public;
165 quote! {
166 #[allow(non_camel_case_types,non_snake_case)]
167 impl< #impl_sig, #(#free_vars,)* > #struct_name<#(#generics,)* #val_list_in BODYFN> {
170 #vis fn #n (mut self, #n: #in_type) ->
171 #struct_name<#(#generics,)* #val_list_out BODYFN>{
172 self.#n = #some;
173 unsafe {
174 ::std::mem::#transmute::<
175 #struct_name<#(#generics,)* #val_list_in BODYFN>,
176 #struct_name<#(#generics,)* #val_list_out BODYFN>,
177 >(#self_type)
178 }
179 }
180 }
181 }
182 })
183 .collect()
184}
185
186fn final_call<'a>(fn_info: &FunctionInformation, opts: &MacroOptions) -> proc_macro2::TokenStream {
188 let ret_type = fn_info.ret_type;
189 let generics = &fn_info.generics;
190 let unit_added = &fn_info.unit.added;
191 let struct_name = &fn_info.struct_name;
192 let impl_sig = impl_signature(&fn_info.argument_vec, ret_type, generics, opts);
193 let arg_names = arg_names(&fn_info.argument_vec);
194 let aug_args = augmented_argument_names(&arg_names);
195 let vis = fn_info.public;
196 let arg_list: proc_macro2::TokenStream = if opts.impl_poly || opts.by_value {
197 aug_args.iter().map(|_| quote! {#unit_added,}).collect()
198 } else {
199 aug_args.iter().map(|a| quote! {#unit_added, #a,}).collect()
200 };
201 let call = if !opts.by_value {
202 quote! {()}
203 } else {
204 quote! {}
205 };
206 quote! {
207 #[allow(non_camel_case_types,non_snake_case)]
208 impl <#impl_sig> #struct_name<#(#generics,)* #arg_list BODYFN>
211 {
212 #vis fn call(self) #ret_type { (self.body)(#(self.#arg_names.unwrap()#call,)*)
214 }
215 }
216 }
217}
218
219fn generator_func<'a>(
222 fn_info: &FunctionInformation,
223 opts: &MacroOptions,
224) -> proc_macro2::TokenStream {
225 let arg_names = arg_names(&fn_info.argument_vec);
227 let arg_types = arg_types(&fn_info.argument_vec);
228 let marker_names = marker_names(&arg_names);
229 let generics = &fn_info.generics;
230 let empty_unit = &fn_info.unit.empty;
231 let body = fn_info.block;
232 let name = fn_info.fn_name;
233 let struct_name = &fn_info.struct_name;
234 let ret_type = fn_info.ret_type;
235 let vis = fn_info.public;
236
237 let gen_types = if opts.impl_poly || opts.by_value {
238 quote! {#(#generics,)*}
239 } else {
240 quote! {#(#generics,)* #(#arg_names,)*}
241 };
242 let struct_types = if opts.impl_poly || opts.by_value {
243 arg_names.iter().map(|_| quote! {#empty_unit,}).collect()
244 } else {
245 quote! {#(#empty_unit,#arg_names,)*}
246 };
247 let body_fn = if opts.impl_poly || opts.impl_clone {
248 quote! {::std::sync::Arc::new(|#(#arg_names,)*| #body),}
249 } else {
250 quote! {|#(#arg_names,)*| #body,}
251 };
252 let where_clause = if opts.impl_poly || opts.by_value {
253 quote!()
254 } else {
255 quote! {
256 where
257 #(#arg_names: Fn() -> #arg_types,)*
258 }
259 };
260 quote! {
261 #[allow(non_camel_case_types,non_snake_case)]
262 #vis fn #name<#gen_types>() -> #struct_name<#(#generics,)* #struct_types
263 impl Fn(#(#arg_types,)*) #ret_type>
264 #where_clause
265 {
266 #struct_name {
267 #(#arg_names: None,)*
268 #(#marker_names: ::std::marker::PhantomData,)*
269 body: #body_fn
270 }
271 }
272
273 }
274}
275
276fn arg_names<'a>(args: &Vec<&syn::PatType>) -> Vec<syn::Ident> {
278 args.iter()
279 .map(|f| {
280 let f_pat = &f.pat;
281 syn::Ident::new(&format!("{}", quote!(#f_pat)), f.span())
282 })
283 .collect()
284}
285
286fn marker_names(names: &Vec<syn::Ident>) -> Vec<syn::Ident> {
288 names.iter().map(|f| concat_ident(f, "m")).collect()
289}
290
291fn concat_ident<'a>(ident: &'a syn::Ident, end: &str) -> syn::Ident {
293 let name = format!("{}___{}", quote! {#ident}, end);
294 syn::Ident::new(&name, ident.span())
295}
296
297fn argument_vector<'a>(
299 args: &'a syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
300) -> Vec<&syn::PatType> {
301 args.iter()
302 .map(|fn_arg| match fn_arg {
303 syn::FnArg::Receiver(_) => panic!("should filter out reciever arguments"),
304 syn::FnArg::Typed(t) => {
305 if let syn::Type::Reference(r) = t.ty.as_ref() {
306 if r.lifetime.is_none() {
307 t.span()
308 .unstable()
309 .error("part_app does not support lifetime elision")
310 .emit();
311 }
312 }
313
314 t
315 }
316 })
317 .collect()
318}
319
320fn arg_types<'a>(args: &Vec<&'a syn::PatType>) -> Vec<&'a syn::Type> {
322 args.iter().map(|f| f.ty.as_ref()).collect()
323}
324
325fn augmented_argument_names<'a>(arg_names: &Vec<syn::Ident>) -> Vec<syn::Ident> {
327 arg_names.iter().map(|f| concat_ident(f, "FN")).collect()
328}
329
330fn main_struct<'a>(fn_info: &FunctionInformation, opts: &MacroOptions) -> proc_macro2::TokenStream {
333 let arg_types = arg_types(&fn_info.argument_vec);
334
335 let arg_names = arg_names(&fn_info.argument_vec);
336 let arg_augmented = augmented_argument_names(&arg_names);
337 let ret_type = fn_info.ret_type;
338
339 let arg_list: Vec<_> = if !(opts.impl_poly || opts.by_value) {
340 arg_names
341 .iter()
342 .zip(arg_augmented.iter())
343 .flat_map(|(n, a)| vec![n, a])
344 .collect()
345 } else {
346 arg_names.iter().collect()
347 };
348 let bodyfn = if opts.impl_poly || opts.impl_clone {
349 quote! {::std::sync::Arc<BODYFN>}
350 } else {
351 quote! { BODYFN }
352 };
353 let where_clause = if opts.impl_poly || opts.by_value {
354 quote!(BODYFN: Fn(#(#arg_types,)*) #ret_type,)
355 } else {
356 quote! {
357 #(#arg_augmented: Fn() -> #arg_types,)*
358 BODYFN: Fn(#(#arg_types,)*) #ret_type,
359 }
360 };
361 let names_with_m = marker_names(&arg_names);
362 let option_list = if opts.impl_poly {
363 quote! {#(#arg_names: Option<::std::sync::Arc<dyn Fn() -> #arg_types>>,)*}
364 } else if opts.by_value {
365 quote! {#(#arg_names: Option<#arg_types>,)*}
366 } else {
367 quote! {#(#arg_names: Option<#arg_augmented>,)*}
368 };
369 let name = &fn_info.struct_name;
370
371 let clone = if opts.impl_clone {
372 let sig = impl_signature(
373 &fn_info.argument_vec,
374 fn_info.ret_type,
375 &fn_info.generics,
376 opts,
377 );
378 quote! {
379 #[allow(non_camel_case_types,non_snake_case)]
380 impl<#sig, #(#arg_list,)*> ::std::clone::Clone for #name <#(#arg_list,)* BODYFN>
381 where #where_clause
382 {
383 fn clone(&self) -> Self {
384 Self {
385 #(#names_with_m: ::std::marker::PhantomData,)*
386 #(#arg_names: self.#arg_names.clone(),)*
387 body: self.body.clone(),
388 }
389 }
390 }
391 }
392 } else {
393 quote! {}
394 };
395 let generics = &fn_info.generics;
396 let vis = fn_info.public;
397 quote! {
398 #[allow(non_camel_case_types,non_snake_case)]
399 #vis struct #name <#(#generics,)* #(#arg_list,)*BODYFN>
400 where #where_clause
401 {
402 #(#names_with_m: ::std::marker::PhantomData<#arg_names>,)*
405 #option_list
407 body: #bodyfn,
409 }
410
411 #clone
412 }
413}
414
415struct MacroOptions {
417 attr: proc_macro::TokenStream,
418 by_value: bool,
419 impl_clone: bool,
420 impl_poly: bool,
421}
422
423impl MacroOptions {
424 fn new(attr: proc_macro::TokenStream) -> Self {
425 Self {
426 attr,
427 by_value: false,
428 impl_clone: false,
429 impl_poly: false,
430 }
431 }
432 fn from(attr: proc_macro::TokenStream) -> Self {
433 let attributes: Vec<String> = attr
434 .to_string()
435 .split(",")
436 .map(|s| s.trim().to_string())
437 .collect();
438 let mut attr_options = MacroOptions::new(attr);
439 attr_options.impl_poly = attributes.contains(&"poly".to_string());
440 attr_options.by_value = attributes.contains(&"value".to_string());
441 attr_options.impl_clone = attributes.contains(&"Clone".to_string());
442 attr_options
443 }
444 fn check(&self, span: &syn::Item) {
445 if self.impl_poly && self.by_value {
446 span.span()
447 .unstable()
448 .error(r#"Cannot implement "poly" and "value" at the same time"#)
449 .emit()
450 }
451
452 if self.impl_clone && !(self.impl_poly || self.by_value) {
453 span.span()
454 .unstable()
455 .error(r#"Cannot implement "Clone" without "poly" or "value""#)
456 .emit()
457 }
458 if !self.attr.is_empty() && !self.impl_poly && !self.by_value && !self.impl_clone {
459 span.span()
460 .unstable()
461 .error(
462 r#"Unknown attribute. Acceptable attributes are "poly", "Clone" and "value""#,
463 )
464 .emit()
465 }
466 }
467}
468
469struct FunctionInformation<'a> {
471 argument_vec: Vec<&'a syn::PatType>,
472 ret_type: &'a syn::ReturnType,
473 generics: Vec<&'a syn::GenericParam>,
474 fn_name: &'a syn::Ident,
475 struct_name: syn::Ident,
476 unit: Units,
477 block: &'a syn::Block,
478 public: &'a syn::Visibility,
479 orignal_fn: &'a syn::ItemFn,
480}
481
482struct Units {
484 added: syn::Ident,
485 empty: syn::Ident,
486}
487
488impl<'a> FunctionInformation<'a> {
489 fn from(func: &'a syn::ItemFn) -> Self {
490 let func_name = &func.sig.ident;
491 Self {
492 argument_vec: argument_vector(&func.sig.inputs),
493 ret_type: &func.sig.output,
494 generics: func.sig.generics.params.iter().map(|f| f).collect(),
495 fn_name: func_name,
496 struct_name: syn::Ident::new(
497 &format!("__PartialApplication__{}_", func_name),
498 func_name.span(),
499 ),
500 unit: Units {
501 added: concat_ident(func_name, "Added"),
502 empty: concat_ident(func_name, "Empty"),
503 },
504 block: &func.block,
505 public: &func.vis,
506 orignal_fn: func,
507 }
508 }
509 fn check(&self) {
510 if let Some(r) = self.orignal_fn.sig.receiver() {
511 r.span()
512 .unstable()
513 .error("Cannot make methods partially applicable yet")
514 .emit();
515 }
516 }
517}