1use std::{convert::Infallible, env::VarError, num::ParseIntError, str::FromStr};
2
/// Metadata describing one environment variable read by a [`FromEnv`] type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct EnvItemInfo {
    /// Name of the environment variable.
    pub var: &'static str,
    /// Human-readable description of what the variable configures.
    pub description: &'static str,
    /// Whether the variable may be left unset; optional items are skipped by
    /// [`FromEnv::check_inventory`].
    pub optional: bool,
}
15
16#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
19pub enum FromEnvErr<Inner> {
20 #[error("Error reading variable {0}: {1}")]
22 EnvError(String, VarError),
23 #[error("Environment variable {0} is empty")]
25 Empty(String),
26 #[error("Failed to parse environment variable {0}")]
28 ParseError(#[from] Inner),
29}
30
31impl FromEnvErr<Infallible> {
32 pub fn infallible_into<T>(self) -> FromEnvErr<T> {
34 match self {
35 Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
36 Self::Empty(s) => FromEnvErr::Empty(s),
37 Self::ParseError(_) => unreachable!(),
38 }
39 }
40}
41
42impl<Inner> FromEnvErr<Inner> {
43 pub fn from<Other>(other: FromEnvErr<Other>) -> Self
45 where
46 Inner: From<Other>,
47 {
48 match other {
49 FromEnvErr::EnvError(s, e) => Self::EnvError(s, e),
50 FromEnvErr::Empty(s) => Self::Empty(s),
51 FromEnvErr::ParseError(e) => Self::ParseError(Inner::from(e)),
52 }
53 }
54
55 pub fn map<New>(self, f: impl FnOnce(Inner) -> New) -> FromEnvErr<New> {
59 match self {
60 Self::EnvError(s, e) => FromEnvErr::EnvError(s, e),
61 Self::Empty(s) => FromEnvErr::Empty(s),
62 Self::ParseError(e) => FromEnvErr::ParseError(f(e)),
63 }
64 }
65
66 pub fn env_err(var: &str, e: VarError) -> Self {
68 Self::EnvError(var.to_string(), e)
69 }
70
71 pub fn empty(var: &str) -> Self {
73 Self::Empty(var.to_string())
74 }
75
76 pub const fn parse_error(err: Inner) -> Self {
78 Self::ParseError(err)
79 }
80}
81
82pub fn parse_env_if_present<T: FromStr>(env_var: &str) -> Result<T, FromEnvErr<T::Err>> {
85 let s = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
86
87 if s.is_empty() {
88 Err(FromEnvErr::empty(env_var))
89 } else {
90 s.parse().map_err(Into::into)
91 }
92}
93
94pub trait FromEnv: core::fmt::Debug + Sized + 'static {
106 type Error: core::error::Error;
108
109 fn inventory() -> Vec<&'static EnvItemInfo>;
116
117 fn check_inventory() -> Result<(), Vec<&'static EnvItemInfo>> {
123 let mut missing = Vec::new();
124 for var in Self::inventory() {
125 if std::env::var(var.var).is_err() && !var.optional {
126 missing.push(var);
127 }
128 }
129 if missing.is_empty() {
130 Ok(())
131 } else {
132 Err(missing)
133 }
134 }
135
136 fn from_env() -> Result<Self, FromEnvErr<Self::Error>>;
138}
139
140pub trait FromEnvVar: core::fmt::Debug + Sized + 'static {
148 type Error: core::error::Error;
150
151 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>>;
153}
154
155impl<T> FromEnvVar for Option<T>
156where
157 T: FromEnvVar,
158{
159 type Error = T::Error;
160
161 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
162 match std::env::var(env_var) {
163 Ok(s) if s.is_empty() => Ok(None),
164 Ok(_) => T::from_env_var(env_var).map(Some),
165 Err(_) => Ok(None),
166 }
167 }
168}
169
170impl FromEnvVar for String {
171 type Error = std::convert::Infallible;
172
173 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
174 std::env::var(env_var).map_err(|_| FromEnvErr::empty(env_var))
175 }
176}
177
178impl FromEnvVar for std::time::Duration {
179 type Error = ParseIntError;
180
181 fn from_env_var(s: &str) -> Result<Self, FromEnvErr<Self::Error>> {
182 u64::from_env_var(s).map(Self::from_millis)
183 }
184}
185
/// Implement [`FromEnvVar`] for each listed type by parsing the variable with
/// the type's [`FromStr`] impl (via `parse_env_if_present`).
///
/// Accepts a comma-separated list of types, with an optional trailing comma.
macro_rules! impl_for_parseable {
    ($($ty:ty),* $(,)?) => {
        $(
            impl FromEnvVar for $ty {
                type Error = <$ty as FromStr>::Err;

                fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
                    parse_env_if_present(env_var)
                }
            }
        )*
    };
}
199
200impl_for_parseable!(
201 u8,
202 u16,
203 u32,
204 u64,
205 u128,
206 usize,
207 i8,
208 i16,
209 i32,
210 i64,
211 i128,
212 isize,
213 url::Url,
214 tracing::Level
215);
216
// `alloy` primitives loaded through their `FromStr` implementations; compiled
// only when the `alloy` feature is enabled.
#[cfg(feature = "alloy")]
impl_for_parseable!(
    alloy::primitives::U256,
    alloy::primitives::Address,
    alloy::primitives::Bytes
);
223
/// `FixedBytes<N>` is generic over its length, and `impl_for_parseable!` has no
/// slot for impl generics, so it is implemented directly here.
#[cfg(feature = "alloy")]
impl<const N: usize> FromEnvVar for alloy::primitives::FixedBytes<N> {
    type Error = <Self as FromStr>::Err;

    fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
        parse_env_if_present::<Self>(env_var)
    }
}
232
233impl FromEnvVar for bool {
234 type Error = std::str::ParseBoolError;
235
236 fn from_env_var(env_var: &str) -> Result<Self, FromEnvErr<Self::Error>> {
237 let s: String = std::env::var(env_var).map_err(|e| FromEnvErr::env_err(env_var, e))?;
238 Ok(!s.is_empty())
239 }
240}
241
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Set `env` to the string form of `val`.
    fn set<T>(env: &str, val: &T)
    where
        T: ToString,
    {
        std::env::set_var(env, val.to_string());
    }

    /// Assert that loading `T` from `env` fails with exactly `err`.
    fn load_expect_err<T>(env: &str, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        T::Error: PartialEq,
    {
        let res = T::from_env_var(env).unwrap_err();
        assert_eq!(res, err);
    }

    /// Round-trip `val` through the environment variable `env`.
    fn test<T>(env: &str, val: T)
    where
        T: ToString + FromEnvVar + PartialEq + std::fmt::Debug,
    {
        set(env, &val);

        let res = T::from_env_var(env).unwrap();
        assert_eq!(res, val);
    }

    /// Set `env` to `value`, then assert that loading `T` fails with `err`.
    fn test_expect_err<T, U>(env: &str, value: U, err: FromEnvErr<T::Error>)
    where
        T: FromEnvVar,
        U: ToString,
        T::Error: PartialEq,
    {
        set(env, &value);
        load_expect_err::<T>(env, err);
    }

    #[test]
    fn test_primitives() {
        test("U8", 42u8);
        test("U16", 42u16);
        test("U32", 42u32);
        test("U64", 42u64);
        test("U128", 42u128);
        test("Usize", 42usize);
        test("I8", 42i8);
        // Exercise a negative value for the type the label names.
        test("I8-NEG", -42i8);
        test("I16", 42i16);
        test("I32", 42i32);
        test("I64", 42i64);
        test("I128", 42i128);
        test("Isize", 42isize);
        test("String", "hello".to_string());
        test("Url", url::Url::parse("http://example.com").unwrap());
        test("Level", tracing::Level::INFO);
    }

    #[test]
    fn test_duration() {
        let amnt = 42;
        let val = Duration::from_millis(42);

        set("Duration", &amnt);
        let res = Duration::from_env_var("Duration").unwrap();

        assert_eq!(res, val);
    }

    #[test]
    fn test_a_few_errors() {
        test_expect_err::<u8, _>(
            "U8_",
            30000u16,
            FromEnvErr::parse_error("30000".parse::<u8>().unwrap_err()),
        );

        test_expect_err::<u8, _>("U8_", "", FromEnvErr::empty("U8_"));

        // A variable that is never set reports the underlying read error.
        load_expect_err::<u8>(
            "FROM_ENV_TEST_U8_MISSING",
            FromEnvErr::env_err("FROM_ENV_TEST_U8_MISSING", std::env::VarError::NotPresent),
        );
    }

    #[test]
    fn test_option() {
        assert_eq!(Option::<u16>::from_env_var("FROM_ENV_TEST_OPT_UNSET").unwrap(), None);

        set("FROM_ENV_TEST_OPT_EMPTY", &"");
        assert_eq!(Option::<u16>::from_env_var("FROM_ENV_TEST_OPT_EMPTY").unwrap(), None);

        set("FROM_ENV_TEST_OPT_SET", &7u16);
        assert_eq!(Option::<u16>::from_env_var("FROM_ENV_TEST_OPT_SET").unwrap(), Some(7));
    }
}