init4_bin_base/utils/
from_env.rs

1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
/// Details about an environment variable. This is used to generate
/// documentation for the environment variables and by the [`FromEnv`] trait to
/// check if necessary environment variables are present.
///
/// Every field is `'static` or `Copy`, so items can be handed out by
/// `&'static` reference, which is the form [`FromEnv::inventory`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EnvItemInfo {
    /// The environment variable name.
    pub var: &'static str,
    /// A description of the environment variable function in the CFG.
    pub description: &'static str,
    /// Whether the environment variable is optional or not. Optional
    /// variables are never reported as missing by
    /// [`FromEnv::check_inventory`].
    pub optional: bool,
}
15
16/// Error type for loading from the environment. See the [`FromEnv`] trait for
17/// more information.
18#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
19pub enum FromEnvErr<Inner> {
20    /// The environment variable is missing.
21    #[error("Error reading variable {0}: {1}")]
22    EnvError(String, VarError),
23    /// The environment variable is empty.
24    #[error("Environment variable {0} is empty")]
25    Empty(String),
26    /// The environment variable is present, but the value could not be parsed.
27    #[error("Failed to parse environment variable {0}")]
28    ParseError(#[from] Inner),
29}
30
31impl FromEnvErr<Infallible> {
32    /// Convert the error into another error type.
33    pub fn infallible_into<T>(self) -> FromEnvErr<T> {
34        match self {
35            Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
36            Self::Empty(s) => FromEnvErr::Empty(s),
37            Self::ParseError(_) => unreachable!(),
38        }
39    }
40}
41
42impl<Inner> FromEnvErr<Inner> {
43    /// Create a new error from another error type.
44    pub fn from<Other>(other: FromEnvErr<Other>) -> Self
45    where
46        Inner: From<Other>,
47    {
48        match other {
49            FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
50            FromEnvErr::Empty(s) => Self::Empty(s),
51            FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
52        }
53    }
54
55    /// Map the error to another type. This is useful for converting the error
56    /// type to a different type, while keeping the other error information
57    /// intact.
58    pub fn map<New>(self, f: impl FnOnce(Inner) -> New) -> FromEnvErr<New> {
59        match self {
60            Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
61            Self::Empty(s) => FromEnvErr::Empty(s),
62            Self::ParseError(e) => FromEnvErr::ParseError(f(e)),
63        }
64    }
65
66    /// Missing env var.
67    pub fn env_err(var: &str, e: VarError) -> Self {
68        Self::EnvError(var.to_string(), e)
69    }
70
71    /// Empty env var.
72    pub fn empty(var: &str) -> Self {
73        Self::Empty(var.to_string())
74    }
75
76    /// Error while parsing.
77    pub const fn parse_error(err: Inner) -> Self {
78        Self::ParseError(err)
79    }
80}
81
82/// Convenience function for parsing a value from the environment, if present
83/// and non-empty.
84pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
85    let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
86
87    if s.is_empty() {
88        Err(FromEnvErr::empty(env_var))
89    } else {
90        s.parse().map_err(Into::into)
91    }
92}
93
94/// Trait for loading from the environment.
95///
96/// This trait is for structs or other complex objects, that need to be loaded
97/// from the environment. It expects that
98///
99/// - The struct is [`Sized`] and `'static`.
100/// - The struct elements can be parsed from strings.
101/// - Struct elements are at fixed env vars, known by the type at compile time.
102///
103/// As such, unless the env is modified, these are essentially static runtime
104/// values.
105pub trait FromEnv: core::fmt::Debug + Sized + 'static {
106    /// Error type produced when loading from the environment.
107    type Error: core::error::Error;
108
109    /// Get the required environment variable names for this type.
110    ///
111    /// ## Note
112    ///
113    /// This MUST include the environment variable names for all fields in the
114    /// struct, including optional vars.
115    fn inventory() -> Vec<&'static EnvItemInfo>;
116
117    /// Get a list of missing environment variables.
118    ///
119    /// This will check all environment variables in the inventory, and return
120    /// a list of those that are non-optional and missing. This is useful for
121    /// reporting missing environment variables.
122    fn check_inventory() -> Result<(), Vec<&'static EnvItemInfo>> {
123        let mut missing = Vec::new();
124        for var in Self::inventory() {
125            if std::env::var(var.var).is_err() && !var.optional {
126                missing.push(var);
127            }
128        }
129        if missing.is_empty() {
130            Ok(())
131        } else {
132            Err(missing)
133        }
134    }
135
136    /// Load from the environment.
137    fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
138}
139
/// Trait for loading primitives from the environment. These are simple types
/// that should correspond to a single environment variable. It has been
/// implemented for common integer types, [`String`], [`bool`], [`url::Url`],
/// [`tracing::Level`], and [`std::time::Duration`] (read as a whole number of
/// milliseconds), as well as for `Option<T>` of any implementing `T`. With
/// the `alloy` feature enabled it is also implemented for several
/// `alloy::primitives` types.
///
/// It aims to make [`FromEnv`] implementations easier to write, by providing a
/// default implementation for common types.
pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
    /// Error type produced when parsing the primitive.
    type Error: core::error::Error;

    /// Load the primitive from the environment at the given variable.
    ///
    /// `env_var` is the name of the environment variable to read, not its
    /// value.
    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
}
154
155impl<T> FromEnvVar for Option<T>
156where
157    T: FromEnvVar,
158{
159    type Error = T::Error;
160
161    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
162        match std::env::var(env_var) {
163            Ok(s) if s.is_empty() => Ok(None),
164            Ok(_) => T::from_env_var(env_var).map(Some),
165            Err(_) => Ok(None),
166        }
167    }
168}
169
170impl FromEnvVar for String {
171    type Error = std::convert::Infallible;
172
173    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
174        std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
175    }
176}
177
178impl FromEnvVar for std::time::Duration {
179    type Error = ParseIntError;
180
181    fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
182        u64::from_env_var(s).map(Self::from_millis)
183    }
184}
185
// Generates a `FromEnvVar` implementation for each listed type by delegating
// to `parse_env_if_present`. Every type passed in must implement `FromStr`;
// its `FromStr::Err` becomes the implementation's error type.
macro_rules! impl_for_parseable {
    ($($parseable:ty),*) => {
        $(
            impl FromEnvVar for $parseable {
                type Error = <$parseable as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present::<Self>(env_var)
                }
            }
        )*
    };
}
199
// `FromEnvVar` for the built-in integer types plus `url::Url` and
// `tracing::Level`. Each is loaded by reading the variable and parsing the
// value with the type's `FromStr` implementation.
impl_for_parseable!(
    u8,
    u16,
    u32,
    u64,
    u128,
    usize,
    i8,
    i16,
    i32,
    i64,
    i128,
    isize,
    url::Url,
    tracing::Level
);
216
// Same `FromStr`-based implementations for alloy primitive types; compiled
// only when the `alloy` feature is enabled.
#[cfg(feature = "alloy")]
impl_for_parseable!(
    alloy::primitives::Address,
    alloy::primitives::Bytes,
    alloy::primitives::U256
);
223
// Written by hand rather than through `impl_for_parseable!` because the
// implementation is generic over the const parameter `N`.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
232
233impl FromEnvVar for bool {
234    type Error = std::str::ParseBoolError;
235
236    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
237        let s: String = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
238        Ok(!s.is_empty())
239    }
240}
241
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Set the environment variable `env` to the string form of `val`.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Load `T` from `env` and assert that loading fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Round-trip `val` through the environment variable `env`.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Store `value` in `env`, then assert that loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Negative values must be exercised with the type named in the
        // variable; the `i16` case gets its own entry below.
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I16-NEG", -42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(42);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);
    }

    #[test]
    fn test_option() {
        // An unset variable loads as `None`.
        std::env::remove_var("OPT_U8");
        assert_eq!(Option::<u8>::from_env_var("OPT_U8").unwrap(), None);

        // So does a variable set to the empty string.
        set("OPT_U8", &"");
        assert_eq!(Option::<u8>::from_env_var("OPT_U8").unwrap(), None);

        // A non-empty value is delegated to the inner type.
        set("OPT_U8", &7u8);
        assert_eq!(Option::<u8>::from_env_var("OPT_U8").unwrap(), Some(7));

        // Errors from the inner type are passed through.
        set("OPT_U8", &"not a number");
        load_expect_err::<Option<u8>>(
            "OPT_U8",
            FromEnvErr::parse_error("not a number".parse::<u8>().unwrap_err()),
        );
    }

    #[test]
    fn test_bool() {
        test("BOOL_TRUE", true);

        // Any non-empty value counts as `true`; the text is not parsed.
        set("BOOL_ANY", &"false");
        assert!(bool::from_env_var("BOOL_ANY").unwrap());

        set("BOOL_EMPTY", &"");
        assert!(!bool::from_env_var("BOOL_EMPTY").unwrap());

        std::env::remove_var("BOOL_UNSET");
        load_expect_err::<bool>(
            "BOOL_UNSET",
            FromEnvErr::env_err("BOOL_UNSET", std::env::VarError::NotPresent),
        );
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));
    }
}