Skip to main content

ohos_enum_derive/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{
5    parse::Parse, parse::ParseStream, parse_macro_input, Data, DeriveInput, Ident, LitStr, Token,
6    Type,
7};
8
9struct ConfigArgs {
10    target_type: Ident,
11    prefix: LitStr,
12    extra_types: Vec<Type>,
13}
14
15impl Parse for ConfigArgs {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let target_type: Ident = input.parse()?;
18        input.parse::<Token![,]>()?;
19        let prefix: LitStr = input.parse()?;
20        let mut extra_types = Vec::new();
21        while !input.is_empty() {
22            if input.peek(Token![,]) {
23                input.parse::<Token![,]>()?;
24            }
25            if input.is_empty() {
26                break;
27            }
28            extra_types.push(input.parse()?);
29        }
30        Ok(ConfigArgs {
31            target_type,
32            prefix,
33            extra_types,
34        })
35    }
36}
37
38fn to_upper_snake_case(s: &str) -> String {
39    s.to_case(Case::UpperSnake)
40}
41
42fn get_variant_prefix(variant: &syn::Variant, default_prefix: &str) -> String {
43    for attr in &variant.attrs {
44        if attr.path().is_ident("prefix") {
45            if let Ok(lit_str) = attr.parse_args::<LitStr>() {
46                return lit_str.value();
47            }
48        }
49    }
50    default_prefix.to_string()
51}
52
53fn get_target_variant(variant: &syn::Variant, default_prefix: &str) -> Ident {
54    for attr in &variant.attrs {
55        if attr.path().is_ident("alias") {
56            if let Ok(lit_str) = attr.parse_args::<LitStr>() {
57                return format_ident!("{}", lit_str.value());
58            }
59        }
60    }
61
62    let variant_prefix = get_variant_prefix(variant, default_prefix);
63    for attr in &variant.attrs {
64        if attr.path().is_ident("suffix") {
65            if let Ok(lit_str) = attr.parse_args::<LitStr>() {
66                return format_ident!("{}{}", variant_prefix, lit_str.value());
67            }
68        }
69    }
70
71    let upper_snake_variant = to_upper_snake_case(&variant.ident.to_string());
72    format_ident!("{}{}", variant_prefix, upper_snake_variant)
73}
74
75#[proc_macro_derive(EnumFrom, attributes(config, prefix, suffix, alias))]
76pub fn enum_from(input: TokenStream) -> TokenStream {
77    let input = parse_macro_input!(input as DeriveInput);
78    let name = &input.ident;
79
80    let args = input
81        .attrs
82        .iter()
83        .find(|attr| attr.path().is_ident("config"))
84        .map(|attr| attr.parse_args::<ConfigArgs>())
85        .expect("config attribute is required")
86        .expect("Failed to parse config attribute");
87
88    let target_type = args.target_type;
89    let default_prefix = args.prefix.value();
90    let extra_types = args.extra_types;
91
92    let variants = match &input.data {
93        Data::Enum(data_enum) => &data_enum.variants,
94        _ => panic!("EnumFrom can only be derived for enums"),
95    };
96
97    let from_attribute_type_arms: Vec<_> = variants
98        .iter()
99        .map(|v| {
100            let variant = &v.ident;
101            let target_variant = get_target_variant(v, &default_prefix);
102            quote! {
103                #name::#variant => #target_variant,
104            }
105        })
106        .collect();
107
108    let from_target_type_arms: Vec<_> = variants
109        .iter()
110        .map(|v| {
111            let variant = &v.ident;
112            let target_variant = get_target_variant(v, &default_prefix);
113            quote! {
114                #target_variant => #name::#variant,
115            }
116        })
117        .collect();
118
119    let try_from_target_type_arms: Vec<_> = variants
120        .iter()
121        .map(|v| {
122            let variant = &v.ident;
123            let target_variant = get_target_variant(v, &default_prefix);
124            quote! {
125                #target_variant => ::std::option::Option::Some(#name::#variant),
126            }
127        })
128        .collect();
129
130    let extra_from_attribute_type_impls: Vec<_> = extra_types
131        .iter()
132        .map(|extra_type| {
133            quote! {
134                impl From<#name> for #extra_type {
135                    fn from(attr: #name) -> Self {
136                        let raw: #target_type = attr.into();
137                        raw as #extra_type
138                    }
139                }
140            }
141        })
142        .collect();
143
144    let extra_from_target_type_impls: Vec<_> = extra_types
145        .iter()
146        .map(|extra_type| {
147            quote! {
148                impl From<#extra_type> for #name {
149                    fn from(attr: #extra_type) -> Self {
150                        let raw = attr as #target_type;
151                        raw.into()
152                    }
153                }
154            }
155        })
156        .collect();
157
158    let expanded = quote! {
159        impl From<#name> for #target_type {
160            fn from(attr: #name) -> Self {
161                match attr {
162                    #(#from_attribute_type_arms)*
163                    _ => unreachable!("Invalid attribute value"),
164                }
165            }
166        }
167
168        impl From<#target_type> for #name {
169            fn from(attr: #target_type) -> Self {
170                match attr {
171                    #(#from_target_type_arms)*
172                    _ => unreachable!("Invalid attribute value"),
173                }
174            }
175        }
176
177        impl #name {
178            pub fn try_from_raw(attr: #target_type) -> ::std::option::Option<Self> {
179                match attr {
180                    #(#try_from_target_type_arms)*
181                    _ => ::std::option::Option::None,
182                }
183            }
184        }
185
186        #(#extra_from_attribute_type_impls)*
187
188        #(#extra_from_target_type_impls)*
189    };
190
191    TokenStream::from(expanded)
192}