1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#![allow(incomplete_features)]
//! # Masala
//!
//! An autocurrying macro for Rust.
//!
//! ## Usage
//!
//! This crate requires a nightly compiler:
//!
//! ```
//! #![feature(type_alias_impl_trait, min_type_alias_impl_trait)]
//! use masala::curry;
//!
//! #[curry]
//! fn add(a: isize, b: isize) -> isize {
//!     a + b
//! }
//!
//! fn main() {
//!     println!("{}", add(33)(42));
//! }
//! ```
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, Block, FnArg, ItemFn, Pat, ReturnType, Type, TypeGenerics};

/// `curry` is a procedural attribute macro. It turns a multi-argument Rust
/// function into one that is applied one argument at a time, for example:
/// ```
/// #![feature(type_alias_impl_trait, min_type_alias_impl_trait)]
/// use masala::curry;
///
/// #[curry]
/// fn add(a: isize, b:isize) -> isize {
///     a + b
/// }
///
/// fn main() {
///     let add_10 = add(10);
///     let x = 23;
///     println!("Ten plus {} is {}", x, add_10(x));
/// }
/// ```
/// This macro supports generics, to an extent. For example this is valid:
/// ```
/// #![feature(type_alias_impl_trait, min_type_alias_impl_trait)]
/// use std::ops::Add;
/// use masala::curry;
///
/// #[curry]
/// fn add<T: Add + Add<Output = T> + Clone>(a: T, b:T) -> T {
///     a.clone() + b
/// }
/// ```
/// The macro expects the function to declare a return type; a function
/// without one will not compile. To take another function as a parameter,
/// follow this pattern:
/// ```
/// #![feature(type_alias_impl_trait, min_type_alias_impl_trait)]
/// use masala::curry;
///
/// #[curry]
/// fn psi<T: Clone>(a: fn(T, T) -> T, b: fn(T) -> T, c: T, d: T) -> T {
///     a(b(c.clone()), b(d))
/// }
/// ```
#[proc_macro_attribute]
pub fn curry(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(item as ItemFn);
    cook_curry(parsed).into()
}

fn cook_curry(parsed: ItemFn) -> proc_macro2::TokenStream {
    let fn_body = parsed.block;
    let sig = parsed.sig;
    let (impl_generics, ty_generics, _) = sig.generics.split_for_impl();
    let vis = parsed.vis;
    let fn_name = sig.ident;
    let fn_args = sig.inputs;
    let fn_return_type = sig.output;

    let arg_idents = extract_arg_idents(fn_args.clone());
    let first_ident = &arg_idents.first().unwrap();

    let curried_body = generate_body(&arg_idents[1..], fn_body.clone());

    let arg_types = extract_arg_types(fn_args.clone());
    let first_type = &arg_types.first().unwrap();
    let type_aliases = generate_type_aliases(
        &arg_types[1..],
        extract_return_type(fn_return_type),
        &fn_name,
        ty_generics.clone(),
    );
    let return_type = format_ident!(
        "{}{}",
        title_case(&fn_name.to_string()),
        format!("T{}", type_aliases.len() - 1)
    );

    quote! {
        #(#type_aliases);* ;
        #vis fn #fn_name #impl_generics (#first_ident: #first_type) -> #return_type #ty_generics {
            #curried_body ;
        }
    }
}

/// Collects the binding pattern of every argument, in declaration order.
///
/// # Panics
///
/// Panics (inside `extract_arg_pat`) on an argument that is not
/// `FnArg::Typed`.
fn extract_arg_idents(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Box<Pat>> {
    let mut patterns = Vec::with_capacity(fn_args.len());
    patterns.extend(fn_args.into_iter().map(extract_arg_pat));
    patterns
}

/// Returns the binding pattern of a typed function argument.
///
/// # Panics
///
/// Panics if `a` is not `FnArg::Typed`.
fn extract_arg_pat(a: FnArg) -> Box<Pat> {
    if let FnArg::Typed(typed) = a {
        return typed.pat;
    }
    panic!("Not supported on types with `self!`")
}

/// Collects the declared type of every argument, in declaration order.
///
/// # Panics
///
/// Panics (inside `extract_type`) on an argument that is not `FnArg::Typed`.
fn extract_arg_types(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Box<Type>> {
    let mut types = Vec::with_capacity(fn_args.len());
    types.extend(fn_args.into_iter().map(extract_type));
    types
}

/// Returns the type written after `->` in a function signature.
///
/// # Panics
///
/// Panics if `a` is not `ReturnType::Type`, i.e. the signature declares no
/// return type.
fn extract_return_type(a: ReturnType) -> Box<Type> {
    if let ReturnType::Type(_, ty) = a {
        return ty;
    }
    panic!("Not supported on functions without return types!")
}

/// Returns the declared type of a typed function argument.
///
/// # Panics
///
/// Panics if `a` is not `FnArg::Typed`.
fn extract_type(a: FnArg) -> Box<Type> {
    if let FnArg::Typed(typed) = a {
        return typed.ty;
    }
    panic!("Not supported on types with `self!`")
}

/// Wraps `body` in one `move` closure per entry of `fn_args` and prefixes the
/// result with `return`.
///
/// The first pattern in `fn_args` becomes the outermost closure. With no
/// patterns the output is simply `return` followed by `body`.
fn generate_body(fn_args: &[Box<Pat>], body: Box<Block>) -> proc_macro2::TokenStream {
    // Build from the inside out: start with the body and wrap it with the
    // last argument first, so the first argument ends up outermost.
    let closures = fn_args
        .iter()
        .rev()
        .fold(quote! { #body }, |inner, arg| quote! { move |#arg| #inner });

    quote! { return #closures }
}

/// Generates the chain of type aliases naming each stage of a curried
/// function.
///
/// Alias names are the title-cased function name followed by `T<level>`.
/// Level 0 aliases `fn_return_type`. Each following level aliases
/// `impl Fn(arg) -> <previous level>`, consuming `fn_arg_types` from last to
/// first. Every alias carries `type_generics`.
///
/// The returned items have no trailing semicolons; the last element is the
/// outermost alias.
fn generate_type_aliases(
    fn_arg_types: &[Box<Type>],
    fn_return_type: Box<Type>,
    fn_name: &syn::Ident,
    type_generics: TypeGenerics,
) -> Vec<proc_macro2::TokenStream> {
    let prefix = title_case(&fn_name.to_string());
    let alias_ident = |level: usize| format_ident!("{}T{}", prefix, level);

    let base = alias_ident(0);
    let mut aliases = Vec::with_capacity(fn_arg_types.len() + 1);
    aliases.push(quote! { type #base #type_generics = #fn_return_type });

    // Walk the arguments back to front so that each alias can refer to the
    // one generated just before it.
    for (level, arg_ty) in fn_arg_types.iter().rev().enumerate() {
        let inner = alias_ident(level);
        let outer = alias_ident(level + 1);
        aliases.push(quote! {
            type #outer #type_generics = impl Fn(#arg_ty) -> #inner #type_generics
        });
    }
    aliases
}

/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// An empty input yields an empty string. Upper-casing follows
/// `char::to_uppercase`, so the first character may expand to several.
fn title_case(s: &str) -> String {
    let mut chars = s.chars();
    let mut out = String::with_capacity(s.len());
    if let Some(first) = chars.next() {
        out.extend(first.to_uppercase());
        // `as_str` yields whatever the iterator has not consumed yet.
        out.push_str(chars.as_str());
    }
    out
}
}