ephemeral_env/
lib.rs

1//! A testing utility for creating ephemeral environments which are reset once they go out of
2//! scope.
3//!
4//! ## Examples:
5//!
//! Once an [EphemeralEnv] drops out of scope, any environment variables that were created
//! while it was in scope are removed.
8//!
9//! ```
10//! use ephemeral_env::EphemeralEnv;
11//!
12//! assert!(std::env::var("MY_ENVIRONMENT_VARIABLE").is_err());
13//! #[cfg(feature = "sync")]
14//! {
15//!     let mut ephemeral = EphemeralEnv::from_env_sync().unwrap();
16//!     ephemeral.set_var("MY_ENVIRONMENT_VARIABLE", "test");
17//!     assert_eq!(std::env::var("MY_ENVIRONMENT_VARIABLE").unwrap(), "test");
18//! }
19//! assert!(std::env::var("MY_ENVIRONMENT_VARIABLE").is_err());
20//! ```
21//!
22//! Similarly, once an EphemeralEnv drops out of scope, modified env vars will revert to
23//! the value they held when the EphemeralEnv was created.
24//!
25//! ```
26//! use ephemeral_env::EphemeralEnv;
27//!
28//! unsafe {
29//!     std::env::set_var("MY_ENVIRONMENT_VARIABLE", "Spanners! Shh!");
30//! }
//! #[cfg(feature = "sync")]
//! {
//!     let mut ephemeral = EphemeralEnv::from_env_sync().unwrap();
//!     ephemeral.set_var("MY_ENVIRONMENT_VARIABLE", "test");
//!     assert_eq!(std::env::var("MY_ENVIRONMENT_VARIABLE").unwrap(), "test");
//! }
//! assert_eq!(std::env::var("MY_ENVIRONMENT_VARIABLE").unwrap(), "Spanners! Shh!");
//! ```
38//!
39//! ## Async interface
40//!
//! If you're testing across await points, avoid [EphemeralEnv::from_env_sync()]: it holds a
//! [std::sync::MutexGuard], which isn't [Send], for as long as the environment is alive.
//!
//! Instead, `ephemeral_env` provides [EphemeralEnv::from_env_async()], which uses Tokio's
//! async-compatible [Mutex](tokio::sync::Mutex):
46//!
//! ```
//! use ephemeral_env::EphemeralEnv;
//!
//! # #[cfg(feature = "async")]
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! assert!(std::env::var("MY_ENVIRONMENT_VARIABLE").is_err());
//! {
//!     let mut ephemeral = EphemeralEnv::from_env_async().await.unwrap();
//!     ephemeral.set_var("MY_ENVIRONMENT_VARIABLE", "test");
//!     // Awaiting while the environment is held is fine here, because the async interface
//!     // holds a Tokio `MutexGuard`. Doing the same with `from_env_sync()` would hold a
//!     // `std::sync::MutexGuard` across this await point.
//!     tokio::time::sleep(std::time::Duration::from_millis(500)).await;
//!     assert_eq!(std::env::var("MY_ENVIRONMENT_VARIABLE").unwrap(), "test");
//! }
//! assert!(std::env::var("MY_ENVIRONMENT_VARIABLE").is_err());
//! # }
//! # #[cfg(not(feature = "async"))]
//! # fn main() {}
//! ```
66#[cfg(feature = "async")]
67use once_cell::sync::Lazy;
68#[cfg(feature = "sync")]
69use std::sync::MutexGuard;
70use std::sync::PoisonError;
71use std::{collections::HashMap, env};
72#[cfg(feature = "async")]
73use tokio::sync::MutexGuard as AsyncMutexGuard;
74
/// Process-wide lock taken by `EphemeralEnv::from_env_sync`; the guard lives inside the
/// returned `EphemeralEnv`, so synchronous environments are created one at a time.
#[cfg(feature = "sync")]
static SYNC_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Process-wide lock taken by `EphemeralEnv::from_env_async`. It is independent of
/// `SYNC_ENV_LOCK`. It is wrapped in `Lazy` because `tokio::sync::Mutex::new` is not a
/// `const fn`.
#[cfg(feature = "async")]
static ASYNC_ENV_LOCK: Lazy<tokio::sync::Mutex<()>> = Lazy::new(|| tokio::sync::Mutex::new(()));
79
80pub struct EphemeralEnv {
81    initial_vars: HashMap<String, String>,
82    #[cfg(feature = "sync")]
83    _sync_guard: Option<MutexGuard<'static, ()>>,
84    #[cfg(feature = "async")]
85    _async_guard: Option<AsyncMutexGuard<'static, ()>>,
86}
87
88impl EphemeralEnv {
89    /// Create a new synchronous EphemeralEnv.
90    ///
91    /// Takes a copy of the current environment variables, and prevents other EphemeralEnvs from
92    /// being started by holding a [MutexGuard](std::sync::MutexGuard).
93    ///
94    /// **Do not use this in tests of async code**. [std::sync::MutexGuard] is not [Send], and so
95    /// holding onto one across an await point can result in deadlocks.
96    #[cfg(feature = "sync")]
97    pub fn from_env_sync() -> Result<Self, PoisonError<MutexGuard<'static, ()>>> {
98        let guard = SYNC_ENV_LOCK.lock()?;
99        Ok(Self {
100            initial_vars: HashMap::from_iter(env::vars()),
101            _sync_guard: Some(guard),
102            #[cfg(feature = "async")]
103            _async_guard: None,
104        })
105    }
106
107    /// Create a new asynchronous EphemeralEnv.
108    ///
109    /// Takes a copy of the current environment variables, and prevents other EphemeralEnvs from
110    /// being started by holding a [MutexGuard](tokio::sync::MutexGuard).
111    #[cfg(feature = "async")]
112    pub async fn from_env_async() -> Result<Self, PoisonError<AsyncMutexGuard<'static, ()>>> {
113        let guard = ASYNC_ENV_LOCK.lock().await;
114        Ok(Self {
115            initial_vars: HashMap::from_iter(env::vars()),
116            #[cfg(feature = "async")]
117            _async_guard: Some(guard),
118            #[cfg(feature = "sync")]
119            _sync_guard: None,
120        })
121    }
122
123    /// Set one or more environment variables from a Vec of string tuples.
124    pub fn with_vars(&mut self, vars: Vec<(String, String)>) {
125        for (key, value) in vars.iter() {
126            unsafe {
127                std::env::set_var(key, value);
128            }
129        }
130    }
131
132    /// Set a single env var.
133    ///
134    /// This is a convenience function to spare the caller from adding unsafe {}
135    /// blocks hither and yon.
136    pub fn set_var(&mut self, key: impl Into<String>, value: impl Into<String>) {
137        unsafe {
138            std::env::set_var(key.into(), value.into());
139        }
140    }
141}
142
143impl Drop for EphemeralEnv {
144    fn drop(&mut self) {
145        for (key, value) in std::env::vars() {
146            unsafe {
147                match self.initial_vars.get(&key) {
148                    Some(initial) => {
149                        if initial != &value {
150                            std::env::set_var(key, initial);
151                        }
152                    }
153                    None => std::env::remove_var(key),
154                }
155            }
156        }
157    }
158}
159
#[cfg(test)]
mod test {
    use super::*;
    use once_cell::sync::Lazy;
    use tokio::sync::Mutex;

    /// Serialises the tests in this module. The sync and async EphemeralEnvs use separate
    /// locks, so without this a sync test and an async test could mutate the environment at
    /// the same time.
    static TEST_LOCK: Lazy<tokio::sync::Mutex<()>> = Lazy::new(|| Mutex::new(()));

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_env_locking_does_not_block_the_event_loop() {
        let _guard = TEST_LOCK.lock().await;
        let key = "ASYNC_TEST_KEY";
        unsafe {
            env::set_var(key, "original");
        }

        {
            let mut ephemeral = EphemeralEnv::from_env_async().await.unwrap();
            ephemeral.set_var(key, "temporary");

            // We can safely await other things while holding the lock
            // without breaking the executor or the compiler.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

            assert_eq!(env::var(key).unwrap(), "temporary");
        } // `ephemeral` drops here, restoring the environment and releasing the lock.

        assert_eq!(env::var(key).unwrap(), "original");
        unsafe {
            env::remove_var(key);
        }
    }

    #[cfg(feature = "sync")]
    #[tokio::test]
    async fn test_restores_modified_variable_on_drop() {
        let _guard = TEST_LOCK.lock().await;
        let key = "TEST_ENV_AUTO_LOCK";
        unsafe {
            std::env::set_var(key, "original");
        }

        {
            let _ephemeral = EphemeralEnv::from_env_sync().unwrap();
            unsafe {
                std::env::set_var(key, "temp");
            }
            assert_eq!(std::env::var(key).unwrap(), "temp");
        }

        assert_eq!(std::env::var(key).unwrap(), "original");
        unsafe {
            std::env::remove_var(key);
        }
    }

    #[cfg(feature = "sync")]
    #[tokio::test]
    async fn test_sync_ephemeral_env_is_threadsafe() {
        let _guard = TEST_LOCK.lock().await;
        // Spawn some threads that would normally race. If the lock works, they will run
        // sequentially without error.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                std::thread::spawn(move || {
                    let key = "TEST_RACE_CONDITION";
                    let _ephemeral = EphemeralEnv::from_env_sync().unwrap();

                    // If race conditions were present, another thread might
                    // change this value while we are sleeping.
                    let val = format!("val_{}", i);
                    unsafe {
                        std::env::set_var(key, &val);
                    }

                    // Simulate some work
                    std::thread::sleep(std::time::Duration::from_millis(10));
                    assert_eq!(std::env::var(key).unwrap(), val);
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_ephemeral_env_is_threadsafe() {
        let _guard = TEST_LOCK.lock().await;
        // Spawn tasks that would race if the lock did not serialise them. Each task awaits
        // between writing and reading the variable, which gives the runtime the chance to
        // run another task in between.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                tokio::spawn(async move {
                    let key = "TEST_RACE_CONDITION";
                    let _ephemeral = EphemeralEnv::from_env_async().await.unwrap();

                    let val = format!("val_{}", i);
                    unsafe {
                        std::env::set_var(key, &val);
                    }

                    // Yield to the runtime. A blocking `std::thread::sleep` here would stall
                    // the single worker thread of `#[tokio::test]`'s default current-thread
                    // runtime. No other task could then interleave, and the test would pass
                    // even without any lock.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    assert_eq!(std::env::var(key).unwrap(), val);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
    }
}