init4_bin_base/utils/
from_env.rs

1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
3/// Error type for loading from the environment. See the [`FromEnv`] trait for
4/// more information.
5#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
6pub enum FromEnvErr<Inner> {
7    /// The environment variable is missing.
8    #[error("Error reading variable {0}: {1}")]
9    EnvError(String, VarError),
10    /// The environment variable is empty.
11    #[error("Environment variable {0} is empty")]
12    Empty(String),
13    /// The environment variable is present, but the value could not be parsed.
14    #[error("Failed to parse environment variable {0}")]
15    ParseError(#[from] Inner),
16}
17
18impl FromEnvErr<Infallible> {
19    /// Convert the error into another error type.
20    pub fn infallible_into<T>(self) -> FromEnvErr<T> {
21        match self {
22            Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
23            Self::Empty(s) => FromEnvErr::Empty(s),
24            Self::ParseError(_) => unreachable!(),
25        }
26    }
27}
28
29impl<Inner> FromEnvErr<Inner> {
30    /// Create a new error from another error type.
31    pub fn from<Other>(other: FromEnvErr<Other>) -> Self
32    where
33        Inner: From<Other>,
34    {
35        match other {
36            FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
37            FromEnvErr::Empty(s) => Self::Empty(s),
38            FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
39        }
40    }
41
42    /// Missing env var.
43    pub fn env_err(var: &str, e: VarError) -> Self {
44        Self::EnvError(var.to_string(), e)
45    }
46
47    /// Empty env var.
48    pub fn empty(var: &str) -> Self {
49        Self::Empty(var.to_string())
50    }
51
52    /// Error while parsing.
53    pub const fn parse_error(err: Inner) -> Self {
54        Self::ParseError(err)
55    }
56}
57
58/// Convenience function for parsing a value from the environment, if present
59/// and non-empty.
60pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
61    let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
62
63    if s.is_empty() {
64        Err(FromEnvErr::empty(env_var))
65    } else {
66        s.parse().map_err(Into::into)
67    }
68}
69
70/// Trait for loading from the environment.
71///
72/// This trait is for structs or other complex objects, that need to be loaded
73/// from the environment. It expects that
74///
75/// - The struct is [`Sized`] and `'static`.
76/// - The struct elements can be parsed from strings.
77/// - Struct elements are at fixed env vars, known by the type at compile time.
78///
79/// As such, unless the env is modified, these are essentially static runtime
80/// values.
81pub trait FromEnv: core::fmt::Debug + Sized + 'static {
82    /// Error type produced when loading from the environment.
83    type Error: core::error::Error;
84
85    /// Load from the environment.
86    fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
87}
88
89/// Trait for loading primitives from the environment. These are simple types
90/// that should correspond to a single environment variable. It has been
91/// implemented for common integer types, [`String`], [`url::Url`],
92/// [`tracing::Level`], and [`std::time::Duration`].
93///
94/// It aims to make [`FromEnv`] implementations easier to write, by providing a
95/// default implementation for common types.
96pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
97    /// Error type produced when parsing the primitive.
98    type Error: core::error::Error;
99
100    /// Load the primitive from the environment at the given variable.
101    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
102}
103
104impl<T> FromEnvVar for Option<T>
105where
106    T: FromEnvVar,
107{
108    type Error = T::Error;
109
110    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
111        match std::env::var(env_var) {
112            Ok(s) if s.is_empty() => Ok(None),
113            Ok(_) => T::from_env_var(env_var).map(Some),
114            Err(_) => Ok(None),
115        }
116    }
117}
118
119impl FromEnvVar for String {
120    type Error = std::convert::Infallible;
121
122    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
123        std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
124    }
125}
126
127impl FromEnvVar for std::time::Duration {
128    type Error = ParseIntError;
129
130    fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
131        u64::from_env_var(s).map(Self::from_millis)
132    }
133}
134
// Implement `FromEnvVar` for each listed type by delegating to
// `parse_env_if_present`, i.e. through the type's `FromStr` impl. The error
// type is the type's `FromStr::Err`. Takes a comma-separated list of types; a
// trailing comma is accepted.
macro_rules! impl_for_parseable {
    ($($t:ty),* $(,)?) => {
        $(
            impl FromEnvVar for $t {
                type Error = <$t as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present(env_var)
                }
            }
        )*
    };
}
148
149impl_for_parseable!(
150    u8,
151    u16,
152    u32,
153    u64,
154    u128,
155    usize,
156    i8,
157    i16,
158    i32,
159    i64,
160    i128,
161    isize,
162    url::Url,
163    tracing::Level
164);
165
// Alloy primitives, parsed through their `FromStr` impls.
#[cfg(feature = "alloy")]
impl_for_parseable!(alloy::primitives::Address, alloy::primitives::Bytes, alloy::primitives::U256);
172
/// Fixed-size byte arrays of any length `N`, parsed through their
/// [`FromStr`] impl.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
181
#[cfg(test)]
mod test {
    // The environment is process-global and test functions may run in
    // parallel, so each test uses its own variable names.
    use std::time::Duration;

    use super::*;

    /// Set the environment variable `env` to the string form of `val`.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Load `T` from `env` and assert that it fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Store `val` in `env`, load it back as `T`, and assert the round trip
    /// yields the same value.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Set `env` to `value`, then assert that loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Negative value for the 8-bit signed type (previously this exercised
        // `i16` by mistake, leaving negative `i8` untested).
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(42);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);
    }

    #[test]
    fn test_option() {
        // A variable that is never set loads as `None`.
        assert_eq!(Option::<u8>::from_env_var("OPTION_U8_UNSET").unwrap(), None);

        // An empty variable also loads as `None`.
        set("OPTION_U8_EMPTY", &"");
        assert_eq!(Option::<u8>::from_env_var("OPTION_U8_EMPTY").unwrap(), None);

        // A present, valid variable loads as `Some`.
        set("OPTION_U8_SET", &7u8);
        assert_eq!(Option::<u8>::from_env_var("OPTION_U8_SET").unwrap(), Some(7));
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));
    }
}
265}