1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use heck::ToSnakeCase;
use proc_macro::{self, TokenStream};
use proc_macro2::Span;
use quote::{quote, quote_spanned};
use syn::{
    parse_macro_input, Attribute, DataEnum, DataStruct, DeriveInput, Fields, Ident, Lit, Meta,
    Variant,
};

/// Finds the first `#[iden = ...]` or `#[name = ...]` attribute and returns its literal value.
///
/// `iden` is accepted alongside `name` so that items already annotated for
/// `sea_query_derive`'s `Iden` derive keep working. Attributes that fail to parse as a
/// meta item, or that are not in `path = literal` form, are ignored.
fn get_iden_attr(attrs: &[Attribute]) -> Option<syn::Lit> {
    attrs.iter().find_map(|attr| match attr.parse_meta() {
        // interoperate with sea_query_derive Iden
        Ok(Meta::NameValue(pair)) if pair.path.is_ident("iden") || pair.path.is_ident("name") => {
            Some(pair.lit)
        }
        _ => None,
    })
}

/// Finds the first `#[catch = ...]` attribute and returns its literal value.
///
/// Attributes that fail to parse as a meta item, or that are not in `path = literal`
/// form (e.g. a bare `#[catch]`), are ignored.
fn get_catch_attr(attrs: &[Attribute]) -> Option<syn::Lit> {
    attrs.iter().find_map(|attr| match attr.parse_meta() {
        Ok(Meta::NameValue(pair)) if pair.path.is_ident("catch") => Some(pair.lit),
        _ => None,
    })
}

#[proc_macro_derive(Name, attributes(iden, name, catch))]
pub fn derive_iden(input: TokenStream) -> TokenStream {
    let DeriveInput {
        ident, data, attrs, ..
    } = parse_macro_input!(input);

    let table_name = match get_iden_attr(&attrs) {
        Some(lit) => quote! { #lit },
        None => {
            let normalized = ident.to_string().to_snake_case();
            quote! { #normalized }
        }
    };

    let catch = match get_catch_attr(&attrs) {
        Some(lit) => {
            let name: String = match lit {
                Lit::Str(name) => name.value(),
                _ => panic!("expected string for `catch`"),
            };
            let method = Ident::new(name.as_str(), Span::call_site());

            quote! { #ident::#method(string) }
        }
        None => {
            quote! { None }
        }
    };

    // Currently we only support enums and unit structs
    let variants =
        match data {
            syn::Data::Enum(DataEnum { variants, .. }) => variants,
            syn::Data::Struct(DataStruct {
                fields: Fields::Unit,
                ..
            }) => {
                return quote! {
                    impl sea_schema::Name for #ident {
                        fn from_str(string: &str) -> Option<Self> {
                            if string == #table_name {
                                Some(Self)
                            } else {
                                None
                            }
                        }
                    }
                }
                .into()
            }
            _ => return quote_spanned! {
                ident.span() => compile_error!("you can only derive Name on enums or unit structs");
            }
            .into(),
        };

    if variants.is_empty() {
        return TokenStream::new();
    }

    let variant = variants
        .iter()
        .filter(|v| get_catch_attr(&v.attrs).is_none() && matches!(v.fields, Fields::Unit))
        .map(|Variant { ident, fields, .. }| match fields {
            Fields::Unit => quote! { #ident },
            _ => panic!(),
        });

    let name = variants.iter().map(|v| {
        if let Some(lit) = get_iden_attr(&v.attrs) {
            // If the user supplied a name, just use it
            quote! { #lit }
        } else if v.ident == "Table" {
            table_name.clone()
        } else {
            let ident = v.ident.to_string().to_snake_case();
            quote! { #ident }
        }
    });

    let output = quote! {
        impl sea_schema::Name for #ident {
            fn from_str(string: &str) -> Option<Self> {
                let result = match string {
                    #(#name => Some(Self::#variant),)*
                    _ => None,
                };
                if result.is_some() {
                    result
                } else {
                    #catch
                }
            }
        }
    };

    output.into()
}