Skip to main content

env_required/
lib.rs

1#![forbid(unsafe_code)]
2#![warn(rust_2018_idioms)]
3
4#![doc = include_str!("../README.md")]
5
/// Read and validate required environment variables.
///
/// Every key must be a string literal; any other input is rejected at compile
/// time with a usage hint.
///
/// ## Single env var (String)
///
/// ```rust,no_run
/// use env_required::required;
///
/// let database_url = required!("DATABASE_URL");
/// ```
///
/// ## Parse via `FromStr`
///
/// ```rust,no_run
/// use env_required::required;
///
/// let port: u16 = required!("PORT" => u16);
/// let workers: usize = required!("WORKERS" => _); // `_` lets the compiler infer the type.
/// ```
///
/// ## Validate multiple variables (no boilerplate)
///
/// ```rust,no_run
/// use env_required::required;
///
/// required!(["DATABASE_URL", "PORT", "RUST_LOG"]);
/// ```
///
/// ## Custom message
///
/// ```rust,no_run
/// use env_required::required;
///
/// let token = required!("API_TOKEN", "API_TOKEN is required to call Example API");
/// required!(["DATABASE_URL", "PORT"], "missing configuration for my-service");
/// ```
///
/// ## Panics
///
/// Panics with a multi-line diagnostic when a variable is not set, is set but
/// empty (unless the `allow-empty` feature is enabled), or is not valid UTF-8.
/// The `=> Type` forms additionally panic when the value fails to parse as that
/// type. The `[...]` form checks every listed key first and reports all problems
/// in a single panic. A custom message, when given, is printed first.
#[macro_export]
macro_rules! required {
    // `Option` variants are spelled as full paths so that a caller with its own
    // `Some`/`None` in scope (e.g. a glob-imported enum) cannot change what the
    // expansion refers to.

    // --- Single variable (String)
    ($key:literal $(,)?) => {
        $crate::__private::required_string($key, ::core::option::Option::None)
    };
    ($key:literal, $msg:expr $(,)?) => {
        $crate::__private::required_string(
            $key,
            ::core::option::Option::Some(($msg).to_string()),
        )
    };

    // --- Single variable (parse via FromStr)
    ($key:literal => $t:ty $(,)?) => {
        $crate::__private::required_parse::<$t>($key, ::core::option::Option::None)
    };
    ($key:literal => $t:ty, $msg:expr $(,)?) => {
        $crate::__private::required_parse::<$t>(
            $key,
            ::core::option::Option::Some(($msg).to_string()),
        )
    };

    // --- Validate many (no values returned)
    ([$($key:literal),+ $(,)?] $(,)?) => {
        $crate::__private::validate_required(&[$($key),+], ::core::option::Option::None)
    };
    ([$($key:literal),+ $(,)?], $msg:expr $(,)?) => {
        $crate::__private::validate_required(
            &[$($key),+],
            ::core::option::Option::Some(($msg).to_string()),
        )
    };

    // --- Misuse help (compile-time errors); must stay after every valid arm.
    () => {
        ::core::compile_error!(
            "env-required: expected input. Example: required!(\"PORT\") or required!([\"A\", \"B\"])"
        )
    };
    ($($anything:tt)+) => {
        ::core::compile_error!(
            "env-required: invalid syntax. Use string literals, e.g. required!(\"PORT\"), required!(\"PORT\" => u16), required!([\"A\", \"B\"])."
        )
    };
}
79
#[doc(hidden)]
pub mod __private {
    //! Runtime support for the `required!` macro.
    //!
    //! This module is public only so the macro expansion can reach it; it is not
    //! part of the crate's stable API.
    #![allow(missing_docs)]

    use core::fmt::{self, Write};
    use core::str::FromStr;

    /// Why a required environment variable could not be used.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum VarIssueKind {
        /// The variable is not present in the environment.
        Missing,
        /// The variable is present but its value is empty. Only reported when
        /// the `allow-empty` feature is disabled.
        Empty,
        /// The variable is present but its value is not valid UTF-8.
        NotUnicode,
    }

    /// A rejected variable and the reason it was rejected.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct VarIssue {
        pub key: &'static str,
        pub kind: VarIssueKind,
    }

    /// Returns the value of `key`, or panics with a diagnostic naming `key`.
    ///
    /// `msg` is an optional caller-supplied message printed at the top of the
    /// panic text. `#[track_caller]` (here and on the panic helpers) makes the
    /// reported panic location the `required!` call site instead of this file.
    #[track_caller]
    pub fn required_string(key: &'static str, msg: Option<String>) -> String {
        // A plain `match` (not a closure) keeps the panic inside this
        // `#[track_caller]` frame so the caller location propagates.
        match raw_required(key) {
            Ok(value) => value,
            Err(issue) => panic_issues(&[issue], msg.as_deref()),
        }
    }

    /// Reads `key` like [`required_string`] and parses it with `FromStr`.
    ///
    /// On parse failure the panic text names the key, the target type and the
    /// parse error.
    #[track_caller]
    pub fn required_parse<T>(key: &'static str, msg: Option<String>) -> T
    where
        T: FromStr,
        T::Err: fmt::Display,
    {
        let raw = match raw_required(key) {
            Ok(value) => value,
            Err(issue) => panic_issues(&[issue], msg.as_deref()),
        };

        match raw.parse::<T>() {
            Ok(parsed) => parsed,
            Err(err) => panic_parse::<T>(key, &err, msg.as_deref()),
        }
    }

    /// Checks every key in `keys` and, if any are unusable, panics once listing
    /// all of them in the order given.
    #[track_caller]
    pub fn validate_required(keys: &[&'static str], msg: Option<String>) {
        let issues: Vec<VarIssue> = keys
            .iter()
            .filter_map(|&key| raw_required(key).err())
            .collect();

        if !issues.is_empty() {
            panic_issues(&issues, msg.as_deref());
        }
    }

    /// Reads `key` from the process environment and classifies why it is
    /// unusable, if it is. Empty values are rejected unless the `allow-empty`
    /// feature is enabled.
    fn raw_required(key: &'static str) -> Result<String, VarIssue> {
        let issue = |kind| VarIssue { key, kind };

        let value = std::env::var_os(key)
            .ok_or_else(|| issue(VarIssueKind::Missing))?
            .into_string()
            .map_err(|_| issue(VarIssueKind::NotUnicode))?;

        #[cfg(not(feature = "allow-empty"))]
        {
            if value.is_empty() {
                return Err(issue(VarIssueKind::Empty));
            }
        }

        Ok(value)
    }

    /// Formats the optional caller-supplied message as the first paragraph of a
    /// panic text; returns an empty string when there is none.
    fn custom_message_prefix(msg: Option<&str>) -> String {
        msg.map(|m| format!("env-required: {m}\n\n"))
            .unwrap_or_default()
    }

    /// Panics with a diagnostic for a value that was present but failed to
    /// parse as `T`.
    #[cold]
    #[track_caller]
    fn panic_parse<T>(key: &'static str, err: &T::Err, msg: Option<&str>) -> !
    where
        T: FromStr,
        T::Err: fmt::Display,
    {
        panic!(
            "{prefix}env-required: failed to parse required environment variable\n\
             \n\
             Key: {key}\n\
             Expected type: {ty}\n\
             Parse error: {err}\n\
             \n\
             How to fix:\n\
             - Ensure the value matches the expected type (see the message above).\n\
             - Tip: print the env var before parsing to inspect its contents.\n",
            prefix = custom_message_prefix(msg),
            ty = core::any::type_name::<T>(),
        )
    }

    /// Panics with a diagnostic listing every rejected variable in `issues`.
    ///
    /// Callers pass a non-empty slice; a single issue gets a "Key/Problem"
    /// layout, several get a bulleted list.
    #[cold]
    #[track_caller]
    fn panic_issues(issues: &[VarIssue], msg: Option<&str>) -> ! {
        let mut out = custom_message_prefix(msg);

        // Writing into a `String` cannot fail, so the `fmt::Result`s are ignored.
        if let [issue] = issues {
            let _ = write!(
                out,
                "env-required: missing required environment variable\n\n\
                 Key: {}\n\
                 Problem: {}\n\n",
                issue.key,
                kind_human(issue.kind),
            );
        } else {
            let _ = write!(
                out,
                "env-required: missing required environment variables\n\n\
                 Missing count: {}\n\n",
                issues.len(),
            );
            for issue in issues {
                let _ = writeln!(out, "- {}: {}", issue.key, kind_human(issue.kind));
            }
            out.push('\n');
        }

        out.push_str(
            "How to fix:\n\
             - Set the env var(s) before running this program.\n\
             - Example (bash/zsh): export KEY=\"value\"\n\
             - Example (PowerShell): $Env:KEY = \"value\"\n",
        );

        #[cfg(not(feature = "allow-empty"))]
        {
            // The two spaces before the trailing `\` are kept: they indent the
            // continuation line of this note.
            out.push_str(
                "- Note: empty strings (KEY=\"\") are treated as missing by default.\n  \
                 Enable feature `allow-empty` if you want to accept empty values.\n",
            );
        }

        panic!("{}", out)
    }

    /// Short human-readable description of an issue kind, used in panic text.
    fn kind_human(kind: VarIssueKind) -> &'static str {
        match kind {
            VarIssueKind::Missing => "not set",
            VarIssueKind::Empty => "set but empty",
            VarIssueKind::NotUnicode => "set but not valid UTF-8",
        }
    }
}
252
#[cfg(test)]
mod tests {
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch the process environment, which is shared by
    /// every test thread.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so that one failed test does not
    /// make every later test fail on the lock.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// Because removal happens in `Drop`, it also runs when a failing assertion
    /// unwinds the test. Declare it after the `lock_env()` guard so it is
    /// dropped (and the variable removed) before the lock is released.
    struct TempVar {
        key: &'static str,
    }

    impl TempVar {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            TempVar { key }
        }
    }

    impl Drop for TempVar {
        fn drop(&mut self) {
            std::env::remove_var(self.key);
        }
    }

    #[test]
    fn reads_string_env_var() {
        let _guard = lock_env();
        let _url = TempVar::set("ENV_REQUIRED_TEST_URL", "postgres://localhost/db");

        let v = required!("ENV_REQUIRED_TEST_URL");
        assert_eq!(v, "postgres://localhost/db");
    }

    #[test]
    fn parses_fromstr_type() {
        let _guard = lock_env();
        let _port = TempVar::set("ENV_REQUIRED_TEST_PORT", "5432");

        let port: u16 = required!("ENV_REQUIRED_TEST_PORT" => u16);
        assert_eq!(port, 5432);
    }

    #[test]
    fn validates_many() {
        let _guard = lock_env();
        let _a = TempVar::set("ENV_REQUIRED_TEST_A", "a");
        let _b = TempVar::set("ENV_REQUIRED_TEST_B", "b");

        required!(["ENV_REQUIRED_TEST_A", "ENV_REQUIRED_TEST_B"]);
    }

    #[test]
    fn missing_env_panics_with_key_name() {
        let _guard = lock_env();
        std::env::remove_var("ENV_REQUIRED_TEST_MISSING");

        let payload = std::panic::catch_unwind(|| {
            let _ = required!("ENV_REQUIRED_TEST_MISSING");
        })
        .expect_err("expected panic");

        let msg = panic_to_string(payload);
        assert!(msg.contains("ENV_REQUIRED_TEST_MISSING"));
        assert!(msg.contains("missing required environment variable"));
    }

    #[test]
    fn parse_error_includes_type_name() {
        let _guard = lock_env();
        let _bad = TempVar::set("ENV_REQUIRED_TEST_BAD_U16", "not-a-number");

        let payload = std::panic::catch_unwind(|| {
            let _: u16 = required!("ENV_REQUIRED_TEST_BAD_U16" => u16);
        })
        .expect_err("expected panic");

        let msg = panic_to_string(payload);
        assert!(msg.contains("ENV_REQUIRED_TEST_BAD_U16"));
        assert!(msg.contains("Expected type"));
        assert!(msg.contains("u16"));
    }

    // Unix-only: an empty value is stored as-is there, so this exercises the
    // "set but empty" path rather than "not set".
    #[cfg(all(unix, not(feature = "allow-empty")))]
    #[test]
    fn empty_value_is_rejected_by_default() {
        let _guard = lock_env();
        let _empty = TempVar::set("ENV_REQUIRED_TEST_EMPTY", "");

        let payload = std::panic::catch_unwind(|| {
            let _ = required!("ENV_REQUIRED_TEST_EMPTY");
        })
        .expect_err("expected panic");

        let msg = panic_to_string(payload);
        assert!(msg.contains("ENV_REQUIRED_TEST_EMPTY"));
        assert!(msg.contains("set but empty"));
    }

    #[test]
    fn validate_many_reports_all_missing_keys() {
        let _guard = lock_env();
        let _present = TempVar::set("ENV_REQUIRED_TEST_PRESENT", "x");
        std::env::remove_var("ENV_REQUIRED_TEST_MISSING_1");
        std::env::remove_var("ENV_REQUIRED_TEST_MISSING_2");

        let payload = std::panic::catch_unwind(|| {
            required!([
                "ENV_REQUIRED_TEST_PRESENT",
                "ENV_REQUIRED_TEST_MISSING_1",
                "ENV_REQUIRED_TEST_MISSING_2",
            ]);
        })
        .expect_err("expected panic");

        let msg = panic_to_string(payload);
        assert!(msg.contains("ENV_REQUIRED_TEST_MISSING_1"));
        assert!(msg.contains("ENV_REQUIRED_TEST_MISSING_2"));
    }

    /// Extracts the text of a panic payload; `panic!` produces either a
    /// `&'static str` or a `String`.
    fn panic_to_string(p: Box<dyn std::any::Any + Send>) -> String {
        if let Some(s) = p.downcast_ref::<&'static str>() {
            s.to_string()
        } else if let Some(s) = p.downcast_ref::<String>() {
            s.clone()
        } else {
            "<non-string panic payload>".to_string()
        }
    }
}