1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
3#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
6pub enum FromEnvErr<Inner> {
7 #[error("Error reading variable {0}: {1}")]
9 EnvError(String, VarError),
10 #[error("Environment variable {0} is empty")]
12 Empty(String),
13 #[error("Failed to parse environment variable {0}")]
15 ParseError(#[from] Inner),
16}
17
18impl FromEnvErr<Infallible> {
19 pub fn infallible_into<T>(self) -> FromEnvErr<T> {
21 match self {
22 Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
23 Self::Empty(s) => FromEnvErr::Empty(s),
24 Self::ParseError(_) => unreachable!(),
25 }
26 }
27}
28
29impl<Inner> FromEnvErr<Inner> {
30 pub fn from<Other>(other: FromEnvErr<Other>) -> Self
32 where
33 Inner: From<Other>,
34 {
35 match other {
36 FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
37 FromEnvErr::Empty(s) => Self::Empty(s),
38 FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
39 }
40 }
41
42 pub fn env_err(var: &str, e: VarError) -> Self {
44 Self::EnvError(var.to_string(), e)
45 }
46
47 pub fn empty(var: &str) -> Self {
49 Self::Empty(var.to_string())
50 }
51
52 pub const fn parse_error(err: Inner) -> Self {
54 Self::ParseError(err)
55 }
56}
57
58pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
61 let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
62
63 if s.is_empty() {
64 Err(FromEnvErr::empty(env_var))
65 } else {
66 s.parse().map_err(Into::into)
67 }
68}
69
70pub trait FromEnv: core::fmt::Debug + Sized + 'static {
82 type Error: core::error::Error;
84
85 fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
87}
88
89pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
97 type Error: core::error::Error;
99
100 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
102}
103
104impl<T> FromEnvVar for Option<T>
105where
106 T: FromEnvVar,
107{
108 type Error = T::Error;
109
110 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
111 match std::env::var(env_var) {
112 Ok(s) if s.is_empty() => Ok(None),
113 Ok(_) => T::from_env_var(env_var).map(Some),
114 Err(_) => Ok(None),
115 }
116 }
117}
118
119impl FromEnvVar for String {
120 type Error = std::convert::Infallible;
121
122 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
123 std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
124 }
125}
126
127impl FromEnvVar for std::time::Duration {
128 type Error = ParseIntError;
129
130 fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
131 u64::from_env_var(s).map(Self::from_millis)
132 }
133}
134
// Implements `FromEnvVar` for each listed type by delegating to
// `parse_env_if_present`, using the type's `FromStr::Err` as the parse error.
// Accepts a comma-separated list of types; a trailing comma is allowed.
macro_rules! impl_for_parseable {
    ($($t:ty),* $(,)?) => {
        $(
            impl FromEnvVar for $t {
                type Error = <$t as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present(env_var)
                }
            }
        )*
    };
}
148
149impl_for_parseable!(
150 u8,
151 u16,
152 u32,
153 u64,
154 u128,
155 usize,
156 i8,
157 i16,
158 i32,
159 i64,
160 i128,
161 isize,
162 url::Url,
163 tracing::Level
164);
165
// Alloy primitive types, parsed through their `FromStr` impls. Only compiled
// when the `alloy` feature is enabled.
#[cfg(feature = "alloy")]
impl_for_parseable!(
    alloy::primitives::Address, alloy::primitives::Bytes, alloy::primitives::U256
);
172
// Written by hand rather than through `impl_for_parseable!` because the impl
// is generic over the const length `N`, which a `$t:ty` list cannot express.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
181
182impl FromEnvVar for bool {
183 type Error = std::str::ParseBoolError;
184
185 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
186 let s: String = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
187 Ok(!s.is_empty())
188 }
189}
190
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Sets `env` to the `to_string()` form of `val`.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Loads `T` from `env` and asserts that it fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Sets `env` to `val`, loads it back as `T`, and asserts round-trip equality.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Sets `env` to `value`, then asserts that loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Exercise `i8` (not a wider type) with a negative value.
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(42);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));
    }

    #[test]
    fn test_missing_is_env_error() {
        std::env::remove_var("U8_UNSET");
        load_expect_err::<u8>(
            "U8_UNSET",
            FromEnvErr::env_err("U8_UNSET", std::env::VarError::NotPresent),
        );
    }

    #[test]
    fn test_option() {
        std::env::remove_var("OPTION_U8_UNSET");
        assert_eq!(Option::<u8>::from_env_var("OPTION_U8_UNSET").unwrap(), None);

        set("OPTION_U8_EMPTY", &"");
        assert_eq!(Option::<u8>::from_env_var("OPTION_U8_EMPTY").unwrap(), None);

        set("OPTION_U8_SET", &7u8);
        assert_eq!(Option::<u8>::from_env_var("OPTION_U8_SET").unwrap(), Some(7));
    }
}