init4_bin_base/utils/from_env.rs

1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
3/// Error type for loading from the environment. See the [`FromEnv`] trait for
4/// more information.
5#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
6pub enum FromEnvErr<Inner> {
7    /// The environment variable is missing.
8    #[error("Error reading variable {0}: {1}")]
9    EnvError(String, VarError),
10    /// The environment variable is empty.
11    #[error("Environment variable {0} is empty")]
12    Empty(String),
13    /// The environment variable is present, but the value could not be parsed.
14    #[error("Failed to parse environment variable {0}")]
15    ParseError(#[from] Inner),
16}
17
18impl FromEnvErr<Infallible> {
19    /// Convert the error into another error type.
20    pub fn infallible_into<T>(self) -> FromEnvErr<T> {
21        match self {
22            Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
23            Self::Empty(s) => FromEnvErr::Empty(s),
24            Self::ParseError(_) => unreachable!(),
25        }
26    }
27}
28
29impl<Inner> FromEnvErr<Inner> {
30    /// Create a new error from another error type.
31    pub fn from<Other>(other: FromEnvErr<Other>) -> Self
32    where
33        Inner: From<Other>,
34    {
35        match other {
36            FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
37            FromEnvErr::Empty(s) => Self::Empty(s),
38            FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
39        }
40    }
41
42    /// Map the error to another type. This is useful for converting the error
43    /// type to a different type, while keeping the other error information
44    /// intact.
45    pub fn map<New>(self, f: impl FnOnce(Inner) -> New) -> FromEnvErr<New> {
46        match self {
47            Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
48            Self::Empty(s) => FromEnvErr::Empty(s),
49            Self::ParseError(e) => FromEnvErr::ParseError(f(e)),
50        }
51    }
52
53    /// Missing env var.
54    pub fn env_err(var: &str, e: VarError) -> Self {
55        Self::EnvError(var.to_string(), e)
56    }
57
58    /// Empty env var.
59    pub fn empty(var: &str) -> Self {
60        Self::Empty(var.to_string())
61    }
62
63    /// Error while parsing.
64    pub const fn parse_error(err: Inner) -> Self {
65        Self::ParseError(err)
66    }
67}
68
69/// Convenience function for parsing a value from the environment, if present
70/// and non-empty.
71pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
72    let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
73
74    if s.is_empty() {
75        Err(FromEnvErr::empty(env_var))
76    } else {
77        s.parse().map_err(Into::into)
78    }
79}
80
81/// Trait for loading from the environment.
82///
83/// This trait is for structs or other complex objects, that need to be loaded
84/// from the environment. It expects that
85///
86/// - The struct is [`Sized`] and `'static`.
87/// - The struct elements can be parsed from strings.
88/// - Struct elements are at fixed env vars, known by the type at compile time.
89///
90/// As such, unless the env is modified, these are essentially static runtime
91/// values.
92pub trait FromEnv: core::fmt::Debug + Sized + 'static {
93    /// Error type produced when loading from the environment.
94    type Error: core::error::Error;
95
96    /// Load from the environment.
97    fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
98}
99
100/// Trait for loading primitives from the environment. These are simple types
101/// that should correspond to a single environment variable. It has been
102/// implemented for common integer types, [`String`], [`url::Url`],
103/// [`tracing::Level`], and [`std::time::Duration`].
104///
105/// It aims to make [`FromEnv`] implementations easier to write, by providing a
106/// default implementation for common types.
107pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
108    /// Error type produced when parsing the primitive.
109    type Error: core::error::Error;
110
111    /// Load the primitive from the environment at the given variable.
112    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
113}
114
115impl<T> FromEnvVar for Option<T>
116where
117    T: FromEnvVar,
118{
119    type Error = T::Error;
120
121    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
122        match std::env::var(env_var) {
123            Ok(s) if s.is_empty() => Ok(None),
124            Ok(_) => T::from_env_var(env_var).map(Some),
125            Err(_) => Ok(None),
126        }
127    }
128}
129
130impl FromEnvVar for String {
131    type Error = std::convert::Infallible;
132
133    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
134        std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
135    }
136}
137
138impl FromEnvVar for std::time::Duration {
139    type Error = ParseIntError;
140
141    fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
142        u64::from_env_var(s).map(Self::from_millis)
143    }
144}
145
/// Implement [`FromEnvVar`] for each listed type by delegating to
/// [`parse_env_if_present`]. The type's [`FromStr`] error is used as the error
/// type.
///
/// Accepts a comma-separated list of concrete (non-generic) types. A trailing
/// comma is allowed, so multi-line invocations can end every line with a comma.
macro_rules! impl_for_parseable {
    ($($t:ty),* $(,)?) => {
        $(
            impl FromEnvVar for $t {
                type Error = <$t as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present(env_var)
                }
            }
        )*
    };
}
159
160impl_for_parseable!(
161    u8,
162    u16,
163    u32,
164    u64,
165    u128,
166    usize,
167    i8,
168    i16,
169    i32,
170    i64,
171    i128,
172    isize,
173    url::Url,
174    tracing::Level
175);
176
// `alloy` primitive types, compiled only with the `alloy` feature. Each one is
// loaded through its `FromStr` implementation.
#[cfg(feature = "alloy")]
impl_for_parseable!(alloy::primitives::Address);
#[cfg(feature = "alloy")]
impl_for_parseable!(alloy::primitives::Bytes);
#[cfg(feature = "alloy")]
impl_for_parseable!(alloy::primitives::U256);
183
/// Fixed-size byte arrays of any length `N`, parsed with their `FromStr`
/// implementation.
///
/// This impl is written by hand because `impl_for_parseable!` only emits impls
/// for concrete types, not const-generic ones.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
192
193impl FromEnvVar for bool {
194    type Error = std::str::ParseBoolError;
195
196    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
197        let s: String = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
198        Ok(!s.is_empty())
199    }
200}
201
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Set the environment variable `env` to the string form of `val`.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Load `T` from `env` and assert that loading fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Round-trip `val` through `env` and assert the loaded value equals it.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Store `value` in `env`, then assert that loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Negative value for the signed 8-bit type (previously tested i16).
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(42);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));
    }

    #[test]
    fn test_option() {
        std::env::remove_var("OPTION_MISSING");
        assert_eq!(Option::<u8>::from_env_var("OPTION_MISSING").unwrap(), None);

        set("OPTION_EMPTY", &"");
        assert_eq!(Option::<u8>::from_env_var("OPTION_EMPTY").unwrap(), None);

        set("OPTION_SET", &7u8);
        assert_eq!(Option::<u8>::from_env_var("OPTION_SET").unwrap(), Some(7));
    }

    #[test]
    fn test_bool() {
        set("BOOL_SET", &"1");
        assert!(bool::from_env_var("BOOL_SET").unwrap());

        set("BOOL_EMPTY", &"");
        assert!(!bool::from_env_var("BOOL_EMPTY").unwrap());
    }
}