1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
3#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
6pub enum FromEnvErr<Inner> {
7 #[error("Error reading variable {0}: {1}")]
9 EnvError(String, VarError),
10 #[error("Environment variable {0} is empty")]
12 Empty(String),
13 #[error("Failed to parse environment variable {0}")]
15 ParseError(#[from] Inner),
16}
17
18impl FromEnvErr<Infallible> {
19 pub fn infallible_into<T>(self) -> FromEnvErr<T> {
21 match self {
22 Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
23 Self::Empty(s) => FromEnvErr::Empty(s),
24 Self::ParseError(_) => unreachable!(),
25 }
26 }
27}
28
29impl<Inner> FromEnvErr<Inner> {
30 pub fn from<Other>(other: FromEnvErr<Other>) -> Self
32 where
33 Inner: From<Other>,
34 {
35 match other {
36 FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
37 FromEnvErr::Empty(s) => Self::Empty(s),
38 FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
39 }
40 }
41
42 pub fn map<New>(self, f: impl FnOnce(Inner) -> New) -> FromEnvErr<New> {
46 match self {
47 Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
48 Self::Empty(s) => FromEnvErr::Empty(s),
49 Self::ParseError(e) => FromEnvErr::ParseError(f(e)),
50 }
51 }
52
53 pub fn env_err(var: &str, e: VarError) -> Self {
55 Self::EnvError(var.to_string(), e)
56 }
57
58 pub fn empty(var: &str) -> Self {
60 Self::Empty(var.to_string())
61 }
62
63 pub const fn parse_error(err: Inner) -> Self {
65 Self::ParseError(err)
66 }
67}
68
69pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
72 let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
73
74 if s.is_empty() {
75 Err(FromEnvErr::empty(env_var))
76 } else {
77 s.parse().map_err(Into::into)
78 }
79}
80
81pub trait FromEnv: core::fmt::Debug + Sized + 'static {
93 type Error: core::error::Error;
95
96 fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
98}
99
100pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
108 type Error: core::error::Error;
110
111 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
113}
114
115impl<T> FromEnvVar for Option<T>
116where
117 T: FromEnvVar,
118{
119 type Error = T::Error;
120
121 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
122 match std::env::var(env_var) {
123 Ok(s) if s.is_empty() => Ok(None),
124 Ok(_) => T::from_env_var(env_var).map(Some),
125 Err(_) => Ok(None),
126 }
127 }
128}
129
130impl FromEnvVar for String {
131 type Error = std::convert::Infallible;
132
133 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
134 std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
135 }
136}
137
138impl FromEnvVar for std::time::Duration {
139 type Error = ParseIntError;
140
141 fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
142 u64::from_env_var(s).map(Self::from_millis)
143 }
144}
145
/// Implements [`FromEnvVar`] for each listed type by delegating to
/// [`parse_env_if_present`]. The type's [`FromStr::Err`] is used as the
/// associated `Error`.
///
/// Takes a comma-separated list of types. A trailing comma is accepted, so
/// multi-line invocations can use rustfmt's usual trailing-comma style.
macro_rules! impl_for_parseable {
    ($($t:ty),* $(,)?) => {
        $(
            impl FromEnvVar for $t {
                type Error = <$t as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present(env_var)
                }
            }
        )*
    };
}
159
160impl_for_parseable!(
161 u8,
162 u16,
163 u32,
164 u64,
165 u128,
166 usize,
167 i8,
168 i16,
169 i32,
170 i64,
171 i128,
172 isize,
173 url::Url,
174 tracing::Level
175);
176
// `alloy` primitive types with `FromStr` impls; only compiled when the
// `alloy` feature is enabled.
#[cfg(feature = "alloy")]
impl_for_parseable! {
    alloy::primitives::Address,
    alloy::primitives::Bytes,
    alloy::primitives::U256
}
183
/// Fixed-size byte arrays of any length `N`, parsed via their `FromStr` impl.
/// Written by hand because the const generic `N` cannot go through
/// `impl_for_parseable!`.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
192
193impl FromEnvVar for bool {
194 type Error = std::str::ParseBoolError;
195
196 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
197 let s: String = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
198 Ok(!s.is_empty())
199 }
200}
201
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Sets `env` to the `to_string()` form of `val` for this process.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Asserts that loading `T` from `env` fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Round-trips `val` through the environment variable `env`.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Sets `env` to `value`, then asserts that loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Previously used `-42i16`, which never exercised the `i8` impl.
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(42);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));
    }

    #[test]
    fn test_option() {
        std::env::remove_var("OPT_MISSING");
        assert_eq!(Option::<u8>::from_env_var("OPT_MISSING").unwrap(), None);

        set("OPT_EMPTY", &"");
        assert_eq!(Option::<u8>::from_env_var("OPT_EMPTY").unwrap(), None);

        set("OPT_SET", &7u8);
        assert_eq!(Option::<u8>::from_env_var("OPT_SET").unwrap(), Some(7));
    }

    #[test]
    fn test_bool() {
        set("BOOL_SET", &"yes");
        assert!(bool::from_env_var("BOOL_SET").unwrap());

        set("BOOL_EMPTY", &"");
        assert!(!bool::from_env_var("BOOL_EMPTY").unwrap());
    }

    #[test]
    fn test_error_conversions() {
        let err: FromEnvErr<Infallible> = FromEnvErr::empty("X");
        assert_eq!(err.infallible_into::<u8>(), FromEnvErr::empty("X"));

        let err: FromEnvErr<u8> = FromEnvErr::parse_error(3);
        assert_eq!(err.map(|x| x * 2), FromEnvErr::parse_error(6u8));

        let widened = FromEnvErr::<u16>::from(FromEnvErr::<u8>::parse_error(3));
        assert_eq!(widened, FromEnvErr::parse_error(3u16));
    }
}