env_extract/
lib.rs

1//! # env-extract
2//!
3//! The `env-extract` crate provides convenient methods of extracting environment variables into
4//! different data types.
5//!
6//! The crate includes two procedural macros: `ConfigStruct` and `EnvVar`, which can be used to derive
7//! traits and implement automatic extraction of values from environment variables.
8//!
9//! ## Usage
10//!
11//! To use the `EnvVar` and `ConfigStruct` macros, add `env-extract` as a dependency in your `Cargo.toml` file:
12//!
13//! ```toml
14//! [dependencies]
15//! env-extract = "0.1.22"
16//! ```
17//!
18//! Then, in your Rust code, import the procedural macros by adding the following line:
19//!
20//! ```rust
21//! use env_extract::{EnvVar, ConfigStruct};
22//! ```
23//!
24//! ## ConfigStruct Macro
25//!
//! The `ConfigStruct` macro is applied to structs with named fields and generates an associated
//! `get()` function that builds the struct by reading each field from an environment variable
//! (by default, the field name in uppercase). The following types are valid for fields of a struct:
29//!
30//! - `String`
31//! - `bool`
//! - `u8`, `u16`, `u32`, `u64`, `u128`, `usize`
//! - `i8`, `i16`, `i32`, `i64`, `i128`, `isize`
//! - `f32`, `f64`
//! - An enum that derives `EnvVar` (the field must be marked with `#[enumerated]`)
36//!
//! The `ConfigStruct` macro supports the following attributes on the fields in the struct:
//!
//! - `#[default("value")]`: Sets a default value for the field. It is used when the environment
//!   variable is not set or, for numeric fields, cannot be parsed. Without it, the generated
//!   `get()` panics in that case. `bool` fields are the exception: they fall back to `false`.
//! - `#[var_name = "NAME"]`: Sets the name of the environment variable to use for this field. If
//!   this is not provided, the field name in uppercase is used as the environment variable name.
//! - `#[enumerated]`: Identifies a field whose type is an enum deriving `EnvVar`. The value is read
//!   through that enum's own `get_result()`, so the enum's `#[var_name]` and fallback rules apply.
46//!
47//! ## EnvVar Macro
48//!
//! The `EnvVar` macro is applied to enums and generates the associated functions `get()`,
//! `get_result()` and `default()`. These read the enum's environment variable and map its value to
//! one of the enum's unit variants.
52//!
53//! The `EnvVar` macro requires one of the following conditions to be met for the enum:
54//!
55//! - A variant called "Invalid", which will be returned if the environment variable does not match
56//!   any of the variants.
57//! - A variant marked with `#[default]`, which will be returned if the environment variable does
58//!   not match any of the variants.
59//! - The enum to be marked with `#[panic_on_invalid]`, which will panic if the environment variable
60//!   does not match any of the variants.
61//!
62//! The `EnvVar` macro allows for the following attributes on the enum itself:
63//!
//! - `#[var_name = "FOO"]`: Set a custom environment variable name to search for. Defaults to the
//!   name of the enum in uppercase.
//! - `#[panic_on_invalid]`: Panics if a valid variant is not found.
//! - `#[case(convert = "[uppercase|lowercase|exact|any]")]`: Sets the case conversion applied to
//!   every variant (see below). A variant's own `#[case(...)]` attribute overrides this setting.
70//!
71//! The `EnvVar` macro also supports the following attributes on the enum variants:
72//!
//! - `#[case(convert = "[uppercase|lowercase|exact|any]")]`: Specifies case conversion for the
//!   annotated enum variant.
//!   - `uppercase` and `lowercase` convert the variant name to uppercase or lowercase. The
//!     environment variable value must then match the converted name exactly.
//!   - `exact` compares the environment variable value to the variant name without any case
//!     conversion.
//!   - `any` converts both the environment variable value and the variant name to lowercase before
//!     comparing them, making the match case-insensitive.
79//! - `#[default]`: Specifies the default enum variant.
80//! - `#[ignore_variant]`: Ignores the annotated enum variant when checking for a match.
81//!
82//! ## Example Usage
83//!
84//! ```rust
85//! use env_extract::{ConfigStruct, EnvVar};
86//!
87//! #[derive(Debug, EnvVar)]
88//! #[var_name = "DATABASE_TYPE"]
89//! #[panic_on_invalid]
90//! #[case(convert = "lowercase")]
91//! enum DatabaseType {
92//!     Postgres,
93//!     Mysql,
94//!     Sqlite,
95//! }
96//!
97//! #[derive(ConfigStruct, Debug)]
98//! struct Config {
99//!     db_host: String,
100//!     db_port: u16,
101//!     use_tls: bool,
102//!
103//!     #[enumerated]
104//!     db_type: DatabaseType,
105//! }
106//!
107//! fn main() {
108//!     std::env::set_var("DB_HOST", "localhost");
109//!     std::env::set_var("DB_PORT", "5432");
110//!     std::env::set_var("USE_TLS", "true");
111//!     std::env::set_var("DATABASE_TYPE", "postgres");
112//!
113//!     let config = Config::get();
114//!
115//!     assert_eq!(config.db_host, "localhost");
116//!     assert_eq!(config.db_port, 5432);
117//!     assert_eq!(config.use_tls, true);
118//!     assert!(matches!(config.db_type, DatabaseType::Postgres));
119//! }
120//! ```
121//!
//! In the example above, the `ConfigStruct` macro generates `Config::get()`. It reads each field
//! from the environment variable named after it (`DB_HOST`, `DB_PORT`, `USE_TLS`). The `EnvVar`
//! macro lets the `db_type` field be read from the `DATABASE_TYPE` environment variable. Because of
//! `#[case(convert = "lowercase")]`, each variant name is lowercased before being compared with the
//! value, so `postgres` selects `DatabaseType::Postgres`. Finally, assertions validate the
//! extracted values.
129
130use proc_macro::TokenStream;
131use quote::{quote, ToTokens};
132use syn::{parse_macro_input, Attribute, DeriveInput, Lit, Meta, MetaNameValue};
133
134/// This proc macro is applied to enums and implements the `EnvVar` trait, which provides a `.get()`
135/// method to retrieve a value of type `T` from an environment variable.
136///
137/// The macro parses the environment variable to the enum type and requires one of the following:
138/// - A variant called "Invalid", which will be returned if the environment variable does not match
139///   any of the variants.
140/// - A variant marked with `#[default]`, which will be returned if the environment variable does
141///   not match any of the variants.
142/// - The enum to be marked with `#[panic_on_invalid]`, which will panic if the environment variable
143///   does not match any of the variants.
144///
145/// The macro supports the following attributes on the enum itself:
146/// - `#[env_var = "FOO"]`: Set a custom environment variable name to search for. Defaults to the
147///   name of the enum in uppercase.
148/// - `#[panic_on_invalid]`: Panics if a valid variant is not found.
149/// - `#[case(convert = "[uppercase|lowercase|exact|any]")]`: Converts all environment variable
150///   values to a specific case before comparing them to map the valid variant. This attribute is
151///   overwritten if the variant also contains this attribute.
152///
153/// The macro also supports the following attributes on the enum variants:
154/// - `#[case = "[uppercase|lowercase|exact|any]"]`: Specifies case conversion for the annotated
155///   enum variant. The `uppercase` and `lowercase` options convert the environment variable value
156///   to uppercase or lowercase before comparing it to the variant name. The `exact` option compares
157///   the environment variable value to the variant name without any case conversion. The `any`
158///   option converts both the environment variable value and the variant name to lowercase before
159///   comparing them.
160/// - `#[default]`: Specifies the default enum variant.
161/// - `#[ignore_variant]`: Ignores the annotated enum variant when checking for a match.
162///
163/// Example usage:
164///
165/// ```rust
166/// use env_extract::EnvVar;
167///
168/// #[derive(EnvVar)]
169/// #[var_name = "DATABASE_TYPE"]
170/// #[case(convert = "uppercase")]
171/// enum DatabaseType {
172///     #[case(convert = "lowercase")]
173///     Postgres,
174///     Mysql,
175///
176///     #[default]
177///     Sqlite,
178/// }
179///
180/// fn main() {
181///     std::env::set_var("DATABASE_TYPE", "MYSQL");
182///
183///     let database_type = DatabaseType::get();
184///     assert!(matches!(database_type, DatabaseType::Mysql));
185/// }
186/// ```
187///
188/// In the example above, the `EnvVar` trait is implemented for the `DatabaseType` enum, allowing
189/// the retrieval of a value from the "DATABASE_TYPE" environment variable. The enum variants are
190/// compared to the environment variable value after applying case conversions specified by the
191/// `#[case]` attributes. The `Mysql` variant is matched since the environment variable value is
192/// converted to uppercase and the variant name to lowercase, resulting in a match.
193#[proc_macro_derive(
194    EnvVar,
195    attributes(case, var_name, default, panic_on_invalid, ignore_variant)
196)]
197pub fn enum_from_env(input: TokenStream) -> TokenStream {
198    let input = parse_macro_input!(input as DeriveInput);
199
200    let enum_name = &input.ident;
201
202    let var_name_to_check_for = match get_var_name(&input.attrs) {
203        Some(v) => v,
204        None => enum_name.to_string().to_uppercase(),
205    };
206
207    let variants = match input.data {
208        syn::Data::Enum(ref variants) => &variants.variants,
209        _ => panic!("EnvVar can only be derived for enums"),
210    };
211
212    let mut invalid_type: Option<&syn::Ident> = None;
213
214    for variant in variants {
215        if &variant.ident.to_token_stream().to_string() == "Invalid" {
216            invalid_type = Some(&variant.ident);
217        };
218    }
219
220    let mut default_value: Option<&syn::Ident> = None;
221
222    let panic_on_invalid = input.attrs.iter().any(|attr| {
223        if let Ok(Meta::Path(path)) = attr.parse_meta() {
224            path.is_ident("panic_on_invalid")
225        } else {
226            false
227        }
228    });
229
230    let default_case = get_case_conversion(&input.attrs);
231    let default_case_conversion = match default_case.0 {
232        CaseConversion::Uppercase => quote! { .to_uppercase() },
233        CaseConversion::Lowercase => quote! { .to_lowercase() },
234        CaseConversion::Exact => quote! {},
235        CaseConversion::Any => quote! { .to_lowercase() },
236    };
237
238    let mut check_variants = Vec::new();
239    let mut check_variants_result = Vec::new();
240    for variant in variants {
241        if let syn::Fields::Unit = variant.fields {
242            let ignore_variant = get_empty_path_attribute(&variant.attrs, "ignore_variant");
243
244            if ignore_variant {
245                continue;
246            }
247
248            let variant_name = &variant.ident;
249
250            let case = get_case_conversion(&variant.attrs);
251            if default_value.is_none() {
252                if get_empty_path_attribute(&variant.attrs, "default") {
253                    default_value = Some(variant_name);
254                }
255            }
256
257            let variant_case_conversion = if case.1 {
258                match case.0 {
259                    CaseConversion::Uppercase => quote! { .to_uppercase() },
260                    CaseConversion::Lowercase => quote! { .to_lowercase() },
261                    CaseConversion::Exact => quote! {},
262                    CaseConversion::Any => quote! { .to_lowercase() },
263                }
264            } else {
265                default_case_conversion.clone()
266            };
267
268            let var_case_conversion = if let CaseConversion::Any = case.0 {
269                quote! { .to_lowercase() }
270            } else {
271                quote! {}
272            };
273
274            check_variants.push(quote! {
275                if match std::env::var(#var_name_to_check_for) { Ok(v) => { Some((v)#var_case_conversion) }, Err(..) => None}.as_deref() == Some(&(stringify!(#variant_name)#variant_case_conversion)[..]) {
276                    return #enum_name::#variant_name;
277                }
278            });
279
280            check_variants_result.push(quote! {
281                if match std::env::var(#var_name_to_check_for) { Ok(v) => { Some((v)#var_case_conversion) }, Err(..) => None}.as_deref() == Some(&(stringify!(#variant_name)#variant_case_conversion)[..]) {
282                    return Ok(#enum_name::#variant_name);
283                }
284            });
285        }
286    }
287
288    if invalid_type.is_none() && default_value.is_none() && !panic_on_invalid {
289        panic!("EnvVar Enum must have either an Invalid variant or specify a variant with the #[default] attribute");
290    }
291
292    let invalid_value = if let Some(v) = default_value {
293        if panic_on_invalid {
294            quote! { panic!("Invalid environment variable value") }
295        } else {
296            quote! { #enum_name::#v }
297        }
298    } else {
299        if panic_on_invalid {
300            quote! { panic!("Invalid environment variable value") }
301        } else {
302            quote! { #enum_name::Invalid }
303        }
304    };
305
306    let expanded = quote! {
307        impl #enum_name {
308            fn get() -> Self {
309                #(#check_variants)*
310
311                #invalid_value
312            }
313
314            fn get_result() -> Result<Self, String> {
315                #(#check_variants_result)*
316
317                Err("Invalid environment variable value".to_string())
318            }
319
320            fn default() -> Self {
321                #invalid_value
322            }
323        }
324    };
325
326    TokenStream::from(expanded)
327}
328
/// Case conversion used when matching an environment variable value against enum variant names,
/// as selected by a `case` attribute.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CaseConversion {
    /// Uppercase the variant name before comparing (`"uppercase"`).
    Uppercase,
    /// Lowercase the variant name before comparing (`"lowercase"`).
    Lowercase,
    /// Compare the variant name unchanged (`"exact"`; also used when no attribute is given).
    Exact,
    /// Lowercase both the variant name and the value (`"any"`).
    Any,
}
335
336fn get_var_name(attr: &[Attribute]) -> Option<String> {
337    for attr in attr {
338        if let Ok(Meta::NameValue(meta_value)) = attr.parse_meta() {
339            if meta_value.path.is_ident("var_name") {
340                match meta_value.lit {
341                    syn::Lit::Str(ref s) => return Some(s.value()),
342                    _ => panic!("Invalid var_name specified"),
343                }
344            }
345        }
346    }
347    None
348}
349
350fn get_case_conversion(attrs: &[Attribute]) -> (CaseConversion, bool) {
351    for attr in attrs {
352        if let Ok(Meta::List(meta_list)) = attr.parse_meta() {
353            if meta_list.path.is_ident("case") {
354                for nested_meta in meta_list.nested {
355                    if let syn::NestedMeta::Meta(Meta::NameValue(MetaNameValue {
356                        path,
357                        lit: Lit::Str(value),
358                        ..
359                    })) = nested_meta
360                    {
361                        if path.is_ident("convert") {
362                            match value.value().as_str() {
363                                "uppercase" => return (CaseConversion::Uppercase, true),
364                                "lowercase" => return (CaseConversion::Lowercase, true),
365                                "exact" => return (CaseConversion::Exact, true),
366                                "any" => return (CaseConversion::Any, true),
367                                _ => panic!("Invalid case conversion specified"),
368                            }
369                        }
370                    }
371                }
372            }
373        }
374    }
375
376    (CaseConversion::Exact, false)
377}
378
379fn get_empty_path_attribute(attrs: &[Attribute], path: &str) -> bool {
380    for attr in attrs {
381        if let Ok(Meta::Path(meta_path)) = attr.parse_meta() {
382            if meta_path.is_ident(path) {
383                return true;
384            }
385        }
386    }
387    false
388}
389
390fn get_default_value(attrs: &[Attribute]) -> Option<String> {
391    for attr in attrs {
392        if let Ok(Meta::List(meta_list)) = attr.parse_meta() {
393            if meta_list.path.is_ident("default") {
394                for nested_meta in meta_list.nested {
395                    if let syn::NestedMeta::Lit(Lit::Str(value)) = nested_meta {
396                        return Some(value.value());
397                    }
398                }
399            }
400        }
401    }
402    None
403}
404
/// How a `ConfigStruct` field is read from its environment variable.
#[derive(Debug)]
enum PrimitiveType {
    /// A `String` field: the variable's value is used as-is.
    String,
    /// An integer or floating-point primitive, parsed from the variable's value.
    Number,
    /// A `bool` field, parsed from the variable's value.
    Bool,
    /// A field marked `#[enumerated]` whose enum type derives `EnvVar`.
    ImplementedEnum,
}
413
414fn get_implemented_enum_ident(ty: &syn::Type) -> String {
415    match ty {
416        syn::Type::Path(type_path) => type_path.clone().into_token_stream().to_string(),
417        _ => panic!("Invalid type"),
418    }
419}
420
421fn get_function_primitive_type(ty: &syn::Type, attributes: &[Attribute]) -> PrimitiveType {
422    match ty {
423        syn::Type::Path(type_path) => {
424            let type_name = match type_path.clone().into_token_stream().to_string() {
425                s if s == "String" => Some(PrimitiveType::String),
426                s if s == "i32"
427                    || s == "u8"
428                    || s == "u16"
429                    || s == "u32"
430                    || s == "u64"
431                    || s == "u128"
432                    || s == "usize"
433                    || s == "i8"
434                    || s == "i16"
435                    || s == "i32"
436                    || s == "i64"
437                    || s == "i128"
438                    || s == "isize"
439                    || s == "f32"
440                    || s == "f64" =>
441                {
442                    Some(PrimitiveType::Number)
443                }
444                s if s == "bool" => Some(PrimitiveType::Bool),
445                _ => None,
446            };
447
448            if let Some(t) = type_name {
449                return t;
450            } else {
451                if let Some(segment) = type_path.clone().path.segments.last() {
452                    if segment.arguments.is_empty() {
453                        if let Some(_attr) = attributes.clone().iter().find(|attr| {
454                            if let Ok(meta) = attr.parse_meta() {
455                                if let syn::Meta::Path(path) = meta {
456                                    path.is_ident("enumerated")
457                                } else {
458                                    false
459                                }
460                            } else {
461                                false
462                            }
463                        }) {
464                            return PrimitiveType::ImplementedEnum;
465                        } else {
466                            panic!("Invalid type")
467                        }
468                    }
469                }
470                panic!("Invalid type")
471            }
472        }
473        _ => panic!("Invalid type"),
474    }
475}
476
477/// This proc macro derives the `ConfigStruct` trait for a struct, enabling easy extraction of
478/// fields from environment variables and parsing them to the correct type.
479///
480/// The macro supports the following attributes for struct fields:
481///
482/// - `default`: Sets a default value for the field. If not provided, the macro will panic if the
483///              environment variable is not set.
484/// - `env_var`: Sets the name of the environment variable to use for the field. If not provided,
485///              the macro will use the name of the field in uppercase as the environment variable
486///              name.
487/// - `enumerated`: Identifies an enum that implements the `EnvVar` trait. The macro will parse the
488///                 environment variable to the enum type.
489///
490/// Example usage:
491///
492/// ```rust
493/// use env_extract::ConfigStruct;
494/// use env_extract::EnvVar;
495/// #[derive(Debug, EnvVar)]
496/// #[var_name = "DATABASE_TYPE"]
497/// #[panic_on_invalid]
498/// #[case(convert = "lowercase")]
499/// enum DatabaseType {
500///     Postgres,
501///     Mysql,
502///     Sqlite,
503/// }
504///
505/// #[derive(ConfigStruct, Debug)]
506/// struct Config {
507///     db_host: String,
508///     db_port: u16,
509///     use_tls: bool,
510///
511///     #[enumerated]
512///     db_type: DatabaseType,
513/// }
514///
515/// fn main() {
516///     std::env::set_var("DB_HOST", "localhost");
517///     std::env::set_var("DB_PORT", "5432");
518///     std::env::set_var("USE_TLS", "true");
519///     std::env::set_var("DATABASE_TYPE", "postgres");
520///
521///     let config = Config::get();
522///
523///     assert_eq!(config.db_host, "localhost");
524///     assert_eq!(config.db_port, 5432);
525///     assert_eq!(config.use_tls, true);
526///     assert!(matches!(config.db_type, DatabaseType::Postgres));
527/// }
528/// ```
529///
530/// In the example above, the `ConfigStruct` trait is derived for the `Config` struct, allowing
531/// easy extraction of fields from environment variables. The `db_host`, `db_port`, and `use_tls`
532/// fields are extracted as `String`, `u16`, and `bool` types, respectively. The `db_type` field is
533/// extracted as an enum type `DatabaseType`, which is parsed from the environment variable named
534/// `DATABASE_TYPE` and converted to lowercase.
535#[proc_macro_derive(ConfigStruct, attributes(default, enumerated, var_name))]
536pub fn env_for_struct(input: TokenStream) -> TokenStream {
537    let input = parse_macro_input!(input as DeriveInput);
538
539    let struct_name = &input.ident;
540    let fields = match input.data {
541        syn::Data::Struct(s) => s.fields,
542        _ => panic!("StructVar only supports structs."),
543    };
544
545    let mut check_fields = Vec::new();
546    for field in fields {
547        let field_type = get_function_primitive_type(&field.ty, &field.attrs);
548        let field_ident = field.ident.unwrap();
549
550        let default_value_or_panic = match get_default_value(&field.attrs) {
551            Some(v) => match field_type {
552                PrimitiveType::String => quote! { #v.to_string() },
553                PrimitiveType::Number => quote! { #v.to_string().parse().unwrap() },
554                PrimitiveType::Bool => quote! { #v.to_string().parse().unwrap() },
555                PrimitiveType::ImplementedEnum => quote! {},
556            },
557            None => {
558                quote! { panic!("No environment variable or default value found for '{}'", stringify!(#field_ident)) }
559            }
560        };
561
562        let var_name_to_check_for = match get_var_name(&field.attrs) {
563            Some(v) => v,
564            None => field_ident.to_token_stream().to_string().to_uppercase(),
565        };
566
567        let enum_ident: syn::Ident;
568        match field_type {
569            PrimitiveType::ImplementedEnum => {
570                enum_ident =
571                    syn::parse_str(&get_implemented_enum_ident(&field.ty).as_str()).unwrap()
572            }
573            _ => enum_ident = field_ident.clone(),
574        };
575
576        check_fields.push(match field_type {
577            PrimitiveType::Bool => quote! {
578                 #field_ident: match std::env::var(#var_name_to_check_for) {
579                    Ok(v) => match v.to_string().parse() {
580                        Ok(v) => v,
581                        Err(..) => false
582                    },
583                    Err(..) => false
584                 },
585            },
586            PrimitiveType::String => quote! {
587                 #field_ident: match std::env::var(#var_name_to_check_for) {
588                    Ok(v) => v.to_string(),
589                    Err(..) => #default_value_or_panic
590                 },
591            },
592            PrimitiveType::ImplementedEnum => quote! {
593                #field_ident: match #enum_ident::get_result() {
594                    Ok(v) => v,
595                    Err(e) =>  #enum_ident::default()
596                },
597            },
598            PrimitiveType::Number => quote! {
599                 #field_ident: match std::env::var(#var_name_to_check_for) {
600                    Ok(v) => match v.to_string().trim().parse() {
601                        Ok(v) => v,
602                        Err(..) => #default_value_or_panic
603                    },
604                    Err(..) => #default_value_or_panic
605                 },
606            },
607        });
608    }
609
610    let expanded = quote! {
611        impl #struct_name {
612            pub fn get() -> Self {
613                Self {
614                    #(#check_fields)*
615                }
616            }
617        }
618    };
619
620    expanded.into()
621}