mvutils_proc_macro/
lib.rs

1extern crate proc_macro;
2
3use crate::savable::{enumerator, named, unit, unnamed};
4use proc_macro::{TokenStream};
5use std::str::FromStr;
6use proc_macro2::{Ident, Span};
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Expr, ExprClosure, Fields, LitStr, Meta, Path, Token};
9use syn::parse::{ParseBuffer, Parser};
10use syn::punctuated::Punctuated;
11
12mod savable;
13
14#[proc_macro_derive(Savable, attributes(unsaved, custom, varint))]
15pub fn derive_savable(input: TokenStream) -> TokenStream {
16    let input = parse_macro_input!(input as DeriveInput);
17
18    let name = input.ident;
19    let generics = input.generics;
20
21    let varint = input.attrs.iter().any(|attr| {
22        if let Meta::Path(ref p) = attr.meta {
23            p.segments.iter().any(|s| s.ident == "varint")
24        } else {
25            false
26        }
27    });
28
29    match &input.data {
30        Data::Struct(s) => match &s.fields {
31            Fields::Named(fields) => named(fields, name, generics),
32            Fields::Unnamed(fields) => unnamed(fields, name, generics),
33            Fields::Unit => unit(name, generics),
34        },
35        Data::Enum(e) => enumerator(e, name, generics, varint),
36        Data::Union(_) => panic!("Deriving Savable for unions is not supported!"),
37    }
38}
39
40#[proc_macro_derive(TryFromString, attributes(exclude, casing, pattern, custom, inner))]
41pub fn try_from_string(input: TokenStream) -> TokenStream {
42    let input = parse_macro_input!(input as DeriveInput);
43    let name = input.ident.clone();
44
45    #[derive(Clone, Copy)]
46    enum Casing {
47        Lower,
48        Upper,
49        Both,
50    }
51
52    // helper: check #[exclude]
53    fn is_excluded(v: &syn::Variant) -> bool {
54        v.attrs.iter().any(|attr| attr.path().is_ident("exclude"))
55    }
56
57    // helper: check #[casing(...)]
58    fn get_casing(v: &syn::Variant) -> Casing {
59        for attr in &v.attrs {
60            if attr.path().is_ident("casing") {
61                if let Ok(list) = attr.meta.require_list() {
62                    if let Ok(path) = list.parse_args::<Path>() {
63                        let ident = path.get_ident().unwrap().to_string();
64                        return match ident.as_str() {
65                            "Lower" => Casing::Lower,
66                            "Upper" => Casing::Upper,
67                            "Both" => Casing::Both,
68                            other => panic!("Invalid casing: {}", other),
69                        };
70                    }
71                }
72            }
73        }
74        Casing::Both
75    }
76
77    fn get_pattern(v: &syn::Variant) -> Option<String> {
78        for attr in &v.attrs {
79            if attr.path().is_ident("pattern") {
80                if let Ok(list) = attr.meta.require_list() {
81                    let l = list.parse_args::<LitStr>().ok()?;
82                    return Some(l.value());
83                }
84            }
85        }
86        None
87    }
88
89    fn get_custom(v: &syn::Variant) -> Option<Vec<LitStr>> {
90        for attr in &v.attrs {
91            if attr.path().is_ident("custom") {
92                if let Ok(list) = attr.meta.require_list() {
93                    let parser = Punctuated::<LitStr, Token![,]>::parse_terminated;
94                    if let Ok(punctuated) = parser.parse2(list.tokens.clone()) {
95                        return Some(
96                            punctuated
97                                .into_iter()
98                                .collect()
99                        );
100                    }
101                }
102            }
103        }
104        None
105    }
106
107    fn get_inner(v: &syn::Variant) -> Option<Expr> {
108        for attr in &v.attrs {
109            if attr.path().is_ident("inner") {
110                if let Ok(list) = attr.meta.require_list() {
111                    return list.parse_args::<Expr>().ok();
112                }
113            }
114        }
115        None
116    }
117
118    match &input.data {
119        Data::Enum(e) => {
120            let mut statics = quote! {};
121
122            let values: Vec<proc_macro2::TokenStream> = e.variants.iter().filter(|v| !is_excluded(v)).flat_map(|v| {
123                let ident = &v.ident;
124                let name_str = ident.to_string();
125                let casing = get_casing(v);
126                let pattern = get_pattern(v);
127                let custom = get_custom(v);
128                let inner = get_inner(v);
129
130                let constructor = if let Some(inner) = inner {
131                    quote! {{
132                        let e = #inner;
133                        Ok(Self::#ident(e(value).ok_or(())?))
134                    }}
135                } else {
136                    if !v.fields.is_empty() {
137                        panic!("Attention! Inner fields must be provided a valid parse closure using the #[inner()] attribute! The closure takes an &String and returns a Option<T>")
138                    }
139                    quote! {
140                        Ok(Self::#ident)
141                    }
142                };
143
144                if let Some(custom) = custom {
145                    vec![quote! {
146                        s if [#(#custom),*].contains(s) => #constructor
147                    }]
148                } else if let Some(pattern) = pattern {
149                    let regex_name_s = format!("{name}_{name_str}_regex");
150                    let regex_name = Ident::new(&regex_name_s, Span::call_site());
151
152                    statics.extend(quote! {
153                        static #regex_name: Lazy<Regex> = Lazy::new(|| Regex::new(#pattern).unwrap());
154                    });
155
156                    vec![quote! {
157                        _ if #regex_name.is_match() => #constructor
158                    }]
159                } else {
160                    let mut arms = Vec::new();
161                    match casing {
162                        Casing::Lower => {
163                            let lower = name_str.to_lowercase();
164                            arms.push(quote! { #lower => #constructor });
165                        }
166                        Casing::Upper => {
167                            let upper = name_str.to_uppercase();
168                            arms.push(quote! { #upper => #constructor });
169                        }
170                        Casing::Both => {
171                            let lower = name_str.to_lowercase();
172                            let upper = name_str.to_uppercase();
173                            arms.push(quote! { #lower => #constructor });
174                            arms.push(quote! { #upper => #constructor });
175                        }
176                    }
177                    arms
178                }
179            }).collect();
180
181            let expanded = quote! {
182                #statics
183
184                impl core::str::FromStr for #name {
185                    type Err = ();
186
187                    fn from_str(value: &str) -> Result<Self, Self::Err> {
188                        match value {
189                            #(#values,)*
190                            _ => Err(()),
191                        }
192                    }
193                }
194            };
195
196            expanded.into()
197        }
198        _ => panic!("`TryFromString` can only be derived for enums"),
199    }
200}
201
/// Spellings of a variant's name selectable with the `#[casing(...)]` attribute
/// of the `TryFromString` derive. `Both` is the default when no `#[casing]`
/// attribute is given.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Casing {
    /// Accept only the all-lowercase name (`Foo` -> `"foo"`).
    Lower,
    /// Accept only the all-uppercase name (`Foo` -> `"FOO"`).
    Upper,
    /// Accept both the lowercase and the uppercase name.
    Both,
}