use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse_macro_input, Attribute, DeriveInput, Lit, Meta, MetaNameValue};
#[proc_macro_derive(
EnvVar,
attributes(case, var_name, default, panic_on_invalid, ignore_variant)
)]
pub fn enum_from_env(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let enum_name = &input.ident;
let var_name_to_check_for = match get_var_name(&input.attrs) {
Some(v) => v,
None => enum_name.to_string().to_uppercase(),
};
let variants = match input.data {
syn::Data::Enum(ref variants) => &variants.variants,
_ => panic!("EnvVar can only be derived for enums"),
};
let mut invalid_type: Option<&syn::Ident> = None;
for variant in variants {
if &variant.ident.to_token_stream().to_string() == "Invalid" {
invalid_type = Some(&variant.ident);
};
}
let mut default_value: Option<&syn::Ident> = None;
let panic_on_invalid = input.attrs.iter().any(|attr| {
if let Ok(Meta::Path(path)) = attr.parse_meta() {
path.is_ident("panic_on_invalid")
} else {
false
}
});
let default_case = get_case_conversion(&input.attrs);
let default_case_conversion = match default_case.0 {
CaseConversion::Uppercase => quote! { .to_uppercase() },
CaseConversion::Lowercase => quote! { .to_lowercase() },
CaseConversion::Exact => quote! {},
CaseConversion::Any => quote! { .to_lowercase() },
};
let mut check_variants = Vec::new();
let mut check_variants_result = Vec::new();
for variant in variants {
if let syn::Fields::Unit = variant.fields {
let ignore_variant = get_empty_path_attribute(&variant.attrs, "ignore_variant");
if ignore_variant {
continue;
}
let variant_name = &variant.ident;
let case = get_case_conversion(&variant.attrs);
if default_value.is_none() {
if get_empty_path_attribute(&variant.attrs, "default") {
default_value = Some(variant_name);
}
}
let variant_case_conversion = if case.1 {
match case.0 {
CaseConversion::Uppercase => quote! { .to_uppercase() },
CaseConversion::Lowercase => quote! { .to_lowercase() },
CaseConversion::Exact => quote! {},
CaseConversion::Any => quote! { .to_lowercase() },
}
} else {
default_case_conversion.clone()
};
let var_case_conversion = if let CaseConversion::Any = case.0 {
quote! { .to_lowercase() }
} else {
quote! {}
};
check_variants.push(quote! {
if match std::env::var(#var_name_to_check_for) { Ok(v) => { Some((v)#var_case_conversion) }, Err(..) => None}.as_deref() == Some(&(stringify!(#variant_name)#variant_case_conversion)[..]) {
return #enum_name::#variant_name;
}
});
check_variants_result.push(quote! {
if match std::env::var(#var_name_to_check_for) { Ok(v) => { Some((v)#var_case_conversion) }, Err(..) => None}.as_deref() == Some(&(stringify!(#variant_name)#variant_case_conversion)[..]) {
return Ok(#enum_name::#variant_name);
}
});
}
}
if invalid_type.is_none() && default_value.is_none() && !panic_on_invalid {
panic!("EnvVar Enum must have either an Invalid variant or specify a variant with the #[default] attribute");
}
let invalid_value = if let Some(v) = default_value {
if panic_on_invalid {
quote! { panic!("Invalid environment variable value") }
} else {
quote! { #enum_name::#v }
}
} else {
if panic_on_invalid {
quote! { panic!("Invalid environment variable value") }
} else {
quote! { #enum_name::Invalid }
}
};
let expanded = quote! {
impl #enum_name {
fn get() -> Self {
#(#check_variants)*
#invalid_value
}
fn get_result() -> Result<Self, String> {
#(#check_variants_result)*
Err("Invalid environment variable value".to_string())
}
fn default() -> Self {
#invalid_value
}
}
};
TokenStream::from(expanded)
}
/// How a variant name (and, for `Any`, the environment value) is normalised
/// before the two are compared by the `EnvVar` derive.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CaseConversion {
    /// Compare against the variant name converted to uppercase.
    Uppercase,
    /// Compare against the variant name converted to lowercase.
    Lowercase,
    /// Compare against the variant name exactly as written.
    Exact,
    /// Intended as case-insensitive matching: the variant name and the
    /// environment value are both compared in lowercase.
    Any,
}
fn get_var_name(attr: &[Attribute]) -> Option<String> {
for attr in attr {
if let Ok(Meta::NameValue(meta_value)) = attr.parse_meta() {
if meta_value.path.is_ident("var_name") {
match meta_value.lit {
syn::Lit::Str(ref s) => return Some(s.value()),
_ => panic!("Invalid var_name specified"),
}
}
}
}
None
}
/// Reads `#[case(convert = "...")]` from `attrs`.
///
/// Returns the requested conversion paired with `true`, or
/// `(CaseConversion::Exact, false)` when no string-valued `convert` entry exists.
/// Only the first such entry (in attribute order) is considered.
///
/// # Panics
/// If that first `convert` value is not one of `uppercase`, `lowercase`,
/// `exact` or `any`.
fn get_case_conversion(attrs: &[Attribute]) -> (CaseConversion, bool) {
    // Lazily walk every `convert = "..."` entry of every `#[case(...)]` list.
    let mut requested = attrs
        .iter()
        .filter_map(|attr| match attr.parse_meta() {
            Ok(Meta::List(meta_list)) if meta_list.path.is_ident("case") => {
                Some(meta_list.nested)
            }
            _ => None,
        })
        .flatten()
        .filter_map(|nested| match nested {
            syn::NestedMeta::Meta(Meta::NameValue(MetaNameValue {
                path,
                lit: Lit::Str(value),
                ..
            })) if path.is_ident("convert") => Some(value.value()),
            _ => None,
        });
    let conversion_name = match requested.next() {
        Some(name) => name,
        None => return (CaseConversion::Exact, false),
    };
    let conversion = match conversion_name.as_str() {
        "uppercase" => CaseConversion::Uppercase,
        "lowercase" => CaseConversion::Lowercase,
        "exact" => CaseConversion::Exact,
        "any" => CaseConversion::Any,
        _ => panic!("Invalid case conversion specified"),
    };
    (conversion, true)
}
/// Returns `true` when some attribute in `attrs` is a bare path attribute named
/// `path` (e.g. `#[default]`); attributes with arguments or `= value` do not count.
fn get_empty_path_attribute(attrs: &[Attribute], path: &str) -> bool {
    attrs.iter().any(|attribute| {
        matches!(
            attribute.parse_meta(),
            Ok(Meta::Path(ref meta_path)) if meta_path.is_ident(path)
        )
    })
}
/// Returns the first string literal found inside a `#[default("...")]` attribute
/// list, or `None` if there is none. Non-string entries are skipped.
fn get_default_value(attrs: &[Attribute]) -> Option<String> {
    attrs
        .iter()
        .filter_map(|attr| match attr.parse_meta() {
            Ok(Meta::List(meta_list)) if meta_list.path.is_ident("default") => {
                Some(meta_list.nested)
            }
            _ => None,
        })
        .flatten()
        .find_map(|nested| match nested {
            syn::NestedMeta::Lit(Lit::Str(literal)) => Some(literal.value()),
            _ => None,
        })
}
/// Category of a `ConfigStruct` field type; decides how its environment value is read.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PrimitiveType {
    /// `String`: the raw variable value is used.
    String,
    /// Built-in integer and float types, parsed from the trimmed value.
    Number,
    /// `bool`, parsed from the value.
    Bool,
    /// A field marked `#[enumerated]`; its type must provide `get_result()` and
    /// `default()` (as the `EnvVar` derive generates).
    ImplementedEnum,
}
/// Renders a path type (e.g. `Mode` or `crate::Mode`) as its token string; used to
/// name the enum type of an `#[enumerated]` field.
///
/// # Panics
/// With `"Invalid type"` if `ty` is not a path type.
fn get_implemented_enum_ident(ty: &syn::Type) -> String {
    if let syn::Type::Path(path_type) = ty {
        path_type.to_token_stream().to_string()
    } else {
        panic!("Invalid type")
    }
}
fn get_function_primitive_type(ty: &syn::Type, attributes: &[Attribute]) -> PrimitiveType {
match ty {
syn::Type::Path(type_path) => {
let type_name = match type_path.clone().into_token_stream().to_string() {
s if s == "String" => Some(PrimitiveType::String),
s if s == "i32"
|| s == "u8"
|| s == "u16"
|| s == "u32"
|| s == "u64"
|| s == "u128"
|| s == "usize"
|| s == "i8"
|| s == "i16"
|| s == "i32"
|| s == "i64"
|| s == "i128"
|| s == "isize"
|| s == "f32"
|| s == "f64" =>
{
Some(PrimitiveType::Number)
}
s if s == "bool" => Some(PrimitiveType::Bool),
_ => None,
};
if let Some(t) = type_name {
return t;
} else {
if let Some(segment) = type_path.clone().path.segments.last() {
if segment.arguments.is_empty() {
if let Some(_attr) = attributes.clone().iter().find(|attr| {
if let Ok(meta) = attr.parse_meta() {
if let syn::Meta::Path(path) = meta {
path.is_ident("enumerated")
} else {
false
}
} else {
false
}
}) {
return PrimitiveType::ImplementedEnum;
} else {
panic!("Invalid type")
}
}
}
panic!("Invalid type")
}
}
_ => panic!("Invalid type"),
}
}
/// Derives a `pub fn get() -> Self` constructor that fills every named field of the
/// annotated struct from the environment.
///
/// The variable read for a field is `#[var_name = "..."]` when present, otherwise the
/// uppercased field name. Handling depends on the field type:
/// - `String`: the raw value. Falls back to `#[default("...")]`, or panics if absent.
/// - integer/float types: the trimmed value parsed with `str::parse`. Falls back to
///   the parsed `#[default("...")]` when the value is missing or unparsable, or
///   panics if there is no default.
/// - `bool`: the value parsed with `str::parse`. Falls back to the parsed
///   `#[default("...")]` when given, otherwise `false`.
/// - `#[enumerated]` types: the type's `get_result()` (generated by `EnvVar`), then
///   its `default()`. The field's own `var_name`/`default` attributes are not used
///   for these.
///
/// # Panics
/// At compile time if applied to anything but a struct with named fields, or to a
/// field whose type is unsupported.
#[proc_macro_derive(ConfigStruct, attributes(default, enumerated, var_name))]
pub fn env_for_struct(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = &input.ident;
    let fields = match input.data {
        syn::Data::Struct(s) => s.fields,
        _ => panic!("ConfigStruct only supports structs."),
    };
    let mut check_fields = Vec::new();
    for field in fields {
        let field_type = get_function_primitive_type(&field.ty, &field.attrs);
        let field_ident = field
            .ident
            .expect("ConfigStruct only supports structs with named fields.");
        let var_name_to_check_for = match get_var_name(&field.attrs) {
            Some(v) => v,
            None => field_ident.to_token_stream().to_string().to_uppercase(),
        };
        // Expression used when the variable is missing (or, for numbers and
        // booleans, cannot be parsed).
        let fallback = match (get_default_value(&field.attrs), &field_type) {
            (Some(v), PrimitiveType::String) => quote! { #v.to_string() },
            (Some(v), _) => quote! { #v.parse().unwrap() },
            // Booleans without a default keep resolving to `false`, not panicking.
            (None, PrimitiveType::Bool) => quote! { false },
            (None, _) => quote! {
                panic!("No environment variable or default value found for '{}'", stringify!(#field_ident))
            },
        };
        check_fields.push(match field_type {
            PrimitiveType::String => quote! {
                #field_ident: match std::env::var(#var_name_to_check_for) {
                    Ok(v) => v,
                    Err(..) => #fallback
                },
            },
            PrimitiveType::Number => quote! {
                #field_ident: match std::env::var(#var_name_to_check_for) {
                    Ok(v) => match v.trim().parse() {
                        Ok(v) => v,
                        Err(..) => #fallback
                    },
                    Err(..) => #fallback
                },
            },
            PrimitiveType::Bool => quote! {
                #field_ident: match std::env::var(#var_name_to_check_for) {
                    Ok(v) => match v.parse() {
                        Ok(v) => v,
                        Err(..) => #fallback
                    },
                    Err(..) => #fallback
                },
            },
            PrimitiveType::ImplementedEnum => {
                // Parse as a path (not a bare ident) so `crate::Mode`-style types work.
                let enum_path: syn::Path =
                    syn::parse_str(&get_implemented_enum_ident(&field.ty)).expect("Invalid type");
                quote! {
                    #field_ident: match #enum_path::get_result() {
                        Ok(v) => v,
                        Err(..) => #enum_path::default()
                    },
                }
            }
        });
    }
    let expanded = quote! {
        impl #struct_name {
            pub fn get() -> Self {
                Self {
                    #(#check_fields)*
                }
            }
        }
    };
    expanded.into()
}