1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
3#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
6pub enum FromEnvErr<Inner> {
7 #[error("Error reading variable {0}: {1}")]
9 EnvError(String, VarError),
10 #[error("Environment variable {0} is empty")]
12 Empty(String),
13 #[error("Failed to parse environment variable {0}")]
15 ParseError(#[from] Inner),
16}
17
18impl FromEnvErr<Infallible> {
19 pub fn infallible_into<T>(self) -> FromEnvErr<T> {
21 match self {
22 Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
23 Self::Empty(s) => FromEnvErr::Empty(s),
24 Self::ParseError(_) => unreachable!(),
25 }
26 }
27}
28
29impl<Inner> FromEnvErr<Inner> {
30 pub fn from<Other>(other: FromEnvErr<Other>) -> Self
32 where
33 Inner: From<Other>,
34 {
35 match other {
36 FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
37 FromEnvErr::Empty(s) => Self::Empty(s),
38 FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
39 }
40 }
41
42 pub fn env_err(var: &str, e: VarError) -> Self {
44 Self::EnvError(var.to_string(), e)
45 }
46
47 pub fn empty(var: &str) -> Self {
49 Self::Empty(var.to_string())
50 }
51
52 pub const fn parse_error(err: Inner) -> Self {
54 Self::ParseError(err)
55 }
56}
57
58pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
61 let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
62
63 if s.is_empty() {
64 Err(FromEnvErr::empty(env_var))
65 } else {
66 s.parse().map_err(Into::into)
67 }
68}
69
70pub trait FromEnv: core::fmt::Debug + Sized + 'static {
82 type Error: core::error::Error;
84
85 fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
87}
88
89pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
97 type Error: core::error::Error;
99
100 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
102}
103
104impl<T> FromEnvVar for Option<T>
105where
106 T: FromEnvVar,
107{
108 type Error = T::Error;
109
110 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
111 match std::env::var(env_var) {
112 Ok(s) if s.is_empty() => Ok(None),
113 Ok(_) => T::from_env_var(env_var).map(Some),
114 Err(_) => Ok(None),
115 }
116 }
117}
118
119impl FromEnvVar for String {
120 type Error = std::convert::Infallible;
121
122 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
123 std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
124 }
125}
126
127impl FromEnvVar for std::time::Duration {
128 type Error = ParseIntError;
129
130 fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
131 u64::from_env_var(s).map(Self::from_millis)
132 }
133}
134
/// Implements [`FromEnvVar`] for each listed type by delegating to
/// [`parse_env_if_present`], so the type's [`FromStr`] impl does the parsing
/// and `<T as FromStr>::Err` becomes the inner error type.
///
/// Takes a comma-separated list of types; a trailing comma is accepted.
macro_rules! impl_for_parseable {
    ($($t:ty),* $(,)?) => {
        $(
            impl FromEnvVar for $t {
                type Error = <$t as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present(env_var)
                }
            }
        )*
    };
}
148
149impl_for_parseable!(
150 u8,
151 u16,
152 u32,
153 u64,
154 u128,
155 usize,
156 i8,
157 i16,
158 i32,
159 i64,
160 i128,
161 isize,
162 url::Url,
163 tracing::Level
164);
165
// `FromEnvVar` impls for `FromStr`-parseable alloy primitives; compiled only
// when the `alloy` feature is enabled.
#[cfg(feature = "alloy")]
impl_for_parseable!(alloy::primitives::Address, alloy::primitives::Bytes, alloy::primitives::U256);
172
// `FixedBytes<N>` is generic over its length. `impl_for_parseable!` only takes
// plain `ty` fragments and cannot introduce the `const N: usize` parameter, so
// this impl is written by hand with the same `parse_env_if_present` delegation.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
181
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Sets the environment variable `env` to the `Display` form of `val`.
    // NOTE(review): libtest runs tests on parallel threads by default and
    // `set_var` mutates process-global state. Each test uses distinct variable
    // names, but concurrent env access can still race on some platforms.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Loads `T` from `env` and asserts it fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Round-trips `val` through the environment variable `env`.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Sets `env` to `value`, then asserts loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Previously `-42i16`, which exercised `i16` under the `I8` label.
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(amnt);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);

        test_expect_err::<Duration, _>(
            "FROM_ENV_TEST_DURATION_BAD",
            "soon",
            FromEnvErr::parse_error("soon".parse::<u64>().unwrap_err()),
        );
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));
    }

    #[test]
    fn test_missing_var() {
        let name = "FROM_ENV_TEST_MISSING_U8";
        std::env::remove_var(name);
        load_expect_err::<u8>(
            name,
            FromEnvErr::env_err(name, std::env::VarError::NotPresent),
        );
    }

    #[test]
    fn test_option() {
        let missing = "FROM_ENV_TEST_OPTION_MISSING";
        std::env::remove_var(missing);
        assert_eq!(Option::<u8>::from_env_var(missing).unwrap(), None);

        let empty = "FROM_ENV_TEST_OPTION_EMPTY";
        set(empty, &"");
        assert_eq!(Option::<u8>::from_env_var(empty).unwrap(), None);

        let present = "FROM_ENV_TEST_OPTION_SET";
        set(present, &7u8);
        assert_eq!(Option::<u8>::from_env_var(present).unwrap(), Some(7));
    }
}