1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#[macro_use]
extern crate quote;

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro_error::{proc_macro_error, abort};
use syn::spanned::Spanned;
use venial::{parse_item, Fields, Item};

#[proc_macro_error]
#[proc_macro_derive(UserData)]
/// Derives `mlua::UserData` for a struct or enum by delegating to the type's
/// `mlua_extras::typed::TypedUserData` implementation.
///
/// The generated `add_fields` / `add_methods` wrap mlua's registrar in a
/// `mlua_extras::typed::WrappedGenerator` and forward it to
/// `TypedUserData::add_fields` / `TypedUserData::add_methods`. The deriving
/// type must therefore also implement `TypedUserData`.
///
/// Any other item kind, or input that fails to parse, aborts compilation with
/// an error at the offending span.
///
/// NOTE(review): only the type name is used, so generic parameters and where
/// clauses are not forwarded to the generated `impl`.
pub fn derive_user_data(input: TokenStream) -> TokenStream {
    let input = TokenStream2::from(input);
    let name = match parse_item(input.clone()) {
        Ok(Item::Struct(struct_type)) => struct_type.name,
        Ok(Item::Enum(enum_type)) => enum_type.name,
        Err(err) => abort!(err.span(), "{}", err),
        _ => abort!(input.span(), "only `struct` and `enum` types are supported for TypedUserData"),
    };

    // Every trait path is fully qualified so the expansion does not depend on
    // what the user's module happens to import.
    quote!(
        impl mlua::UserData for #name {
            fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) {
                let mut wrapper = mlua_extras::typed::WrappedGenerator::new(fields);
                <#name as mlua_extras::typed::TypedUserData>::add_fields(&mut wrapper);
            }

            fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) {
                let mut wrapper = mlua_extras::typed::WrappedGenerator::new(methods);
                <#name as mlua_extras::typed::TypedUserData>::add_methods(&mut wrapper);
            }
        }
    ).into()
}

#[proc_macro_error]
#[proc_macro_derive(Typed, attributes(typed))]
/// Derives `mlua_extras::typed::Typed` for a struct or enum.
///
/// * A `struct` is described by its own name via `Type::single`.
/// * An `enum` is described via `Type::r#enum`, named after the enum, with one
///   type per variant:
///   - unit variant `V` -> `Type::single("\"V\"")`, the variant name wrapped in
///     double quotes (presumably a Lua string-literal type — confirm);
///   - tuple variant with exactly one field -> that field's `Typed::ty()`;
///   - any other tuple variant -> `Type::Tuple` of each field's `Typed::ty()`;
///   - struct-like variant -> `Type::Struct` mapping field name to `Typed::ty()`.
///
/// Raw identifiers (`r#name`) are emitted without their `r#` prefix.
///
/// Any other item kind, or input that fails to parse, aborts compilation with
/// an error at the offending span.
///
/// NOTE(review): only the type name is used, so generic parameters and where
/// clauses are not forwarded to the generated `impl`.
pub fn derive_typed(input: TokenStream) -> TokenStream {
    let input = TokenStream2::from(input);
    let expanded = match parse_item(input.clone()) {
        Ok(Item::Struct(struct_type)) => {
            let name = struct_type.name;
            let value = syn::LitStr::new(lua_name(&name).as_str(), Span::call_site());
            quote!(
                impl mlua_extras::typed::Typed for #name {
                    fn ty() -> mlua_extras::typed::Type {
                        mlua_extras::typed::Type::single(#value)
                    }
                }
            )
        }
        Ok(Item::Enum(enum_type)) => {
            let variants = enum_type
                .variants
                .iter()
                .map(|(variant, _punc)| match &variant.fields {
                    Fields::Unit => {
                        // The generated literal contains the quotes themselves,
                        // e.g. `"\"Variant\""`.
                        let literal = format!("\"{}\"", lua_name(&variant.name));
                        quote! { mlua_extras::typed::Type::single(#literal) }
                    }
                    Fields::Tuple(tf) => {
                        let mut tuple_values = tf
                            .fields
                            .iter()
                            .map(|(field, _)| {
                                let ty = &field.ty;
                                quote! { <#ty as mlua_extras::typed::Typed>::ty() }
                            })
                            .collect::<Vec<_>>();

                        if tuple_values.len() == 1 {
                            // A single-field (newtype-like) variant is described
                            // by its inner type directly rather than a 1-tuple.
                            tuple_values.remove(0)
                        } else {
                            quote! { mlua_extras::typed::Type::Tuple(Vec::from([
                                #(#tuple_values,)*
                            ])) }
                        }
                    }
                    Fields::Named(named) => {
                        let entries = named
                            .fields
                            .iter()
                            .map(|(field, _)| {
                                let key = lua_name(&field.name);
                                let ty = &field.ty;
                                quote! { (#key, <#ty as mlua_extras::typed::Typed>::ty()) }
                            })
                            .collect::<Vec<_>>();
                        quote! { mlua_extras::typed::Type::Struct(std::collections::HashMap::from([
                            #(#entries,)*
                        ])) }
                    }
                })
                .collect::<Vec<_>>();

            // TODO: This should be a union alias
            let name = enum_type.name;
            let value = lua_name(&name);
            quote!(
                impl mlua_extras::typed::Typed for #name {
                    fn ty() -> mlua_extras::typed::Type {
                        mlua_extras::typed::Type::r#enum(
                            #value,
                            [ #(#variants,)* ]
                        )
                    }
                }
            )
        }
        Err(err) => abort!(err.span(), "{}", err),
        _ => abort!(input.span(), "only `struct` and `enum` types are supported for Typed"),
    };
    expanded.into()
}

/// Returns the name a Rust identifier should carry in the generated Lua type.
///
/// Stringifying a raw identifier keeps its `r#` prefix (a field declared as
/// `r#type` becomes `"r#type"`). That prefix is Rust-only syntax, so it is
/// removed here; all other identifiers are returned unchanged.
fn lua_name(ident: &impl ToString) -> String {
    let text = ident.to_string();
    match text.strip_prefix("r#") {
        Some(unprefixed) => unprefixed.to_string(),
        None => text,
    }
}