tracel_xtask/
environment.rs

1use std::{
2    collections::HashMap,
3    fmt::{self, Display, Write as _},
4    marker::PhantomData,
5    path::PathBuf,
6};
7
8use strum::{EnumIter, EnumString};
9
10use crate::{group_error, group_info, utils::git};
11
/// Index style that leaves index `1` out of display names.
///
/// With this style, base `dev` renders as `dev` for index 1 and as `dev2` for index 2.
#[derive(Clone, Debug, PartialEq, Default)]
pub struct ImplicitIndex;

/// Index style that always appends the index to display names.
///
/// With this style, base `dev` renders as `dev1` for index 1.
#[derive(Clone, Debug, PartialEq, Default)]
pub struct ExplicitIndex;

/// Strategy deciding how a display name is assembled from a base string and an index.
pub trait IndexStyle {
    /// Build the display name for `base` at `index`.
    fn format(base: &str, index: u8) -> String;
}

impl IndexStyle for ImplicitIndex {
    fn format(base: &str, index: u8) -> String {
        match index {
            // Index 1 is implied, so only the base is shown.
            1 => String::from(base),
            other => format!("{base}{other}"),
        }
    }
}

impl IndexStyle for ExplicitIndex {
    fn format(base: &str, index: u8) -> String {
        let digits = index.to_string();
        let mut name = String::with_capacity(base.len() + digits.len());
        name.push_str(base);
        name.push_str(&digits);
        name
    }
}
40
/// A deployment environment: a named stage plus a numeric index.
///
/// The type parameter `M` selects the [`IndexStyle`] used to render display names
/// (it defaults to [`ImplicitIndex`]). `Default` yields the `Development` stage with
/// index 1.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Environment<M = ImplicitIndex> {
    /// Which stage this environment represents.
    pub name: EnvironmentName,
    /// Numeric index, combined with the name by the style `M` when rendering names.
    pub index: EnvironmentIndex,
    // Zero-sized tag carrying the index display style; stores no data.
    _marker: PhantomData<M>,
}
47
48impl<M> Environment<M> {
49    pub fn new(name: EnvironmentName, index: u8) -> Self {
50        Self {
51            name,
52            index: index.into(),
53            _marker: PhantomData,
54        }
55    }
56
57    pub fn index(&self) -> u8 {
58        self.index.index
59    }
60}
61
62impl Environment<ImplicitIndex> {
63    /// Turn an non explicit environment into an explicit one.
64    /// An explicit environment will always append the index number to its display names.
65    /// Whereas a non-explicit one (default) only append the index if it is different than 1.
66    pub fn into_explicit(self) -> Environment<ExplicitIndex> {
67        Environment {
68            name: self.name.clone(),
69            index: self.index().into(),
70            _marker: PhantomData,
71        }
72    }
73}
74
75impl Environment<ExplicitIndex> {
76    pub fn into_implicit(self) -> Environment<ImplicitIndex> {
77        Environment {
78            name: self.name.clone(),
79            index: self.index().into(),
80            _marker: PhantomData,
81        }
82    }
83}
84
85impl<M: IndexStyle> Display for Environment<M> {
86    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87        write!(f, "{}", self.medium())
88    }
89}
90
91impl<M: IndexStyle> Environment<M> {
92    pub fn long(&self) -> String {
93        M::format(self.name.long(), self.index())
94    }
95
96    pub fn medium(&self) -> String {
97        M::format(self.name.medium(), self.index())
98    }
99
100    pub fn short(&self) -> String {
101        M::format(&self.name.short().to_string(), self.index())
102    }
103
104    pub fn get_dotenv_filename(&self) -> String {
105        format!(".env.{self}")
106    }
107
108    pub fn get_dotenv_secrets_filename(&self) -> String {
109        format!("{}.secrets", self.get_dotenv_filename())
110    }
111
112    pub fn get_env_files(&self) -> [String; 3] {
113        let filename = self.get_dotenv_filename();
114        let secrets_filename = self.get_dotenv_secrets_filename();
115        [
116            ".env".to_owned(),
117            filename.to_owned(),
118            secrets_filename.to_owned(),
119        ]
120    }
121
122    /// Load the .env environment files family.
123    pub fn load(&self, prefix: Option<&str>) -> anyhow::Result<()> {
124        let files = self.get_env_files();
125        files.iter().for_each(|f| {
126            let path = if let Some(p) = prefix {
127                std::path::PathBuf::from(p).join(f)
128            } else {
129                std::path::PathBuf::from(f)
130            };
131            if path.exists() {
132                match dotenvy::from_filename(f) {
133                    Ok(_) => {
134                        group_info!("loading '{}' file...", f);
135                    }
136                    Err(e) => {
137                        group_error!("error while loading '{}' file ({})", f, e);
138                    }
139                }
140            }
141        });
142        Ok(())
143    }
144
145    /// Merge all the .env files of the environment with all variable expanded
146    pub fn merge_env_files(&self) -> anyhow::Result<PathBuf> {
147        let repo_root = git::git_repo_root_or_cwd()?;
148        let files = self.get_env_files();
149        // merged set of env vars, the later files override earlier ones
150        // we sort keys to have a more deterministic merged file result
151        let mut merged: HashMap<String, String> = HashMap::new();
152        for filename in files {
153            let path = repo_root.join(&filename);
154            if !path.exists() {
155                eprintln!(
156                    "⚠️ Warning: environment file '{}' ({}) not found, skipping...",
157                    filename,
158                    path.display()
159                );
160                continue;
161            }
162            for item in dotenvy::from_path_iter(&path)? {
163                let (key, value) = item?;
164                std::env::set_var(&key, &value);
165                merged.insert(key, value);
166            }
167        }
168        let mut keys: Vec<_> = merged.keys().cloned().collect();
169        keys.sort();
170        // write merged file
171        let mut out = String::new();
172        for key in keys {
173            let val = &merged[&key];
174            writeln!(&mut out, "{key}={val}")?;
175        }
176        let tmp_path = std::env::temp_dir().join(format!("merged-env-{}.tmp", std::process::id()));
177        std::fs::write(&tmp_path, out)?;
178        Ok(tmp_path)
179    }
180}
181
/// The deployment stage an [`Environment`] represents.
///
/// Parsed from the command line through `clap::ValueEnum` (which also accepts the
/// aliases declared below). Parsed from strings through strum's `EnumString`, which
/// matches the lowercase variant name (e.g. `"development"`). Defaults to `Development`.
// NOTE(review): the `///` comments on the variants are used as clap help text, so
// editing them changes CLI output. `EnumString` does not know about the clap aliases,
// so `"dev".parse::<EnvironmentName>()` presumably fails — confirm this is intended.
#[derive(EnumString, EnumIter, Default, Clone, Debug, PartialEq, clap::ValueEnum)]
#[strum(serialize_all = "lowercase")]
pub enum EnvironmentName {
    /// Development environment (alias: dev).
    #[default]
    #[clap(alias = "dev")]
    Development,
    /// Staging environment (alias: stag).
    #[clap(alias = "stag")]
    Staging,
    /// Testing environment (alias: test).
    // NOTE(review): `test` already appears to be this variant's value name, so this
    // alias looks redundant.
    #[clap(alias = "test")]
    Test,
    /// Production environment (alias: prod).
    #[clap(alias = "prod")]
    Production,
}
199
200impl Display for EnvironmentName {
201    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
202        write!(f, "{}", self.medium())
203    }
204}
205
206impl EnvironmentName {
207    pub fn long(&self) -> &'static str {
208        match self {
209            EnvironmentName::Development => "development",
210            EnvironmentName::Staging => "staging",
211            EnvironmentName::Test => "test",
212            EnvironmentName::Production => "production",
213        }
214    }
215
216    pub fn medium(&self) -> &'static str {
217        match self {
218            EnvironmentName::Development => "dev",
219            EnvironmentName::Staging => "stag",
220            EnvironmentName::Test => "test",
221            EnvironmentName::Production => "prod",
222        }
223    }
224
225    pub fn short(&self) -> char {
226        match self {
227            EnvironmentName::Development => 'd',
228            EnvironmentName::Staging => 's',
229            EnvironmentName::Test => 't',
230            EnvironmentName::Production => 'p',
231        }
232    }
233}
234
/// Numeric index of an environment; defaults to 1.
#[derive(Clone, Debug, PartialEq)]
pub struct EnvironmentIndex {
    pub index: u8,
}

impl Default for EnvironmentIndex {
    fn default() -> Self {
        // Index 1 is the default, and is the one `ImplicitIndex` omits from names.
        Self::from(1)
    }
}

impl Display for EnvironmentIndex {
    /// Prints the bare number; formatter width/fill options are not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.index.to_string())
    }
}

impl From<u8> for EnvironmentIndex {
    fn from(value: u8) -> Self {
        EnvironmentIndex { index: value }
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serial_test::serial;
    use std::env;

    // For tests we always use the implicit style
    type TestEnv = Environment<ImplicitIndex>;

    /// Variables the fixture dotenv files are expected to set for `environment`, paired
    /// with their expected values.
    ///
    /// Each `FROM_DOTENV*` variable is expected to hold the name of the file defining it.
    /// NOTE(review): the fixture files themselves are not visible here.
    fn expected_vars(environment: &TestEnv) -> Vec<(String, String)> {
        let suffix = match environment.name {
            EnvironmentName::Development => "DEV",
            EnvironmentName::Staging => "STAG",
            EnvironmentName::Test => "TEST",
            EnvironmentName::Production => "PROD",
        };

        vec![
            ("FROM_DOTENV".to_string(), ".env".to_string()),
            (
                format!("FROM_DOTENV_{suffix}"),
                environment.get_dotenv_filename(),
            ),
            (
                format!("FROM_DOTENV_{suffix}_SECRETS"),
                environment.get_dotenv_secrets_filename(),
            ),
        ]
    }

    /// Parse the dotenv file at `path` into a map (later duplicates win).
    /// Panics on any read or parse error.
    fn read_env_file(path: &std::path::Path) -> std::collections::HashMap<String, String> {
        dotenvy::from_path_iter(path)
            .expect("Reading merged env file should succeed")
            .map(|item| item.expect("Parsing key/value from merged env file should succeed"))
            .collect()
    }

    // These tests mutate the process environment, hence `#[serial]`.
    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_load(#[case] environment: TestEnv) {
        // Remove possible prior values
        for (key, _) in expected_vars(&environment) {
            env::remove_var(key);
        }

        // Run the actual function under test
        environment
            .load(Some("../.."))
            .expect("Environment load should succeed");

        // Assert each expected env var is present and has the correct value
        for (key, expected_value) in expected_vars(&environment) {
            let actual_value =
                env::var(&key).unwrap_or_else(|_| panic!("Missing expected env var: {key}"));
            assert_eq!(
                actual_value, expected_value,
                "Environment variable {key} should be set to {expected_value} but was {actual_value}"
            );
        }
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_merge_env_files(#[case] environment: TestEnv) {
        // Make sure we start from a clean state
        for (key, _) in expected_vars(&environment) {
            env::remove_var(key);
        }
        // Generate the merged env file
        let merged_path = environment
            .merge_env_files()
            .expect("merge_env_files should succeed");
        assert!(
            merged_path.exists(),
            "Merged env file should exist at {}",
            merged_path.display()
        );
        // Parse the merged file as a .env file again, then clean it up (best effort)
        let merged_map = read_env_file(&merged_path);
        let _ = std::fs::remove_file(&merged_path);
        // All the vars we expect from the individual files must be present
        for (key, expected_value) in expected_vars(&environment) {
            let actual_value = merged_map
                .get(&key)
                .unwrap_or_else(|| panic!("Missing expected merged env var: {key}"));
            assert_eq!(
                actual_value, &expected_value,
                "Merged env var {key} should be {expected_value} but was {actual_value}"
            );
        }
    }

    #[test]
    #[serial]
    fn test_environment_merge_env_files_expansion() {
        let environment = Environment::<ImplicitIndex>::new(EnvironmentName::Staging, 1);
        // Clean any prior values that could interfere
        env::remove_var("LOG_LEVEL_TEST");
        env::remove_var("RUST_LOG_TEST");
        env::remove_var("RUST_LOG_STAG_TEST");

        let merged_path = environment
            .merge_env_files()
            .expect("merge_env_files should succeed");
        let merged_map = read_env_file(&merged_path);
        let _ = std::fs::remove_file(&merged_path);

        let log_level = merged_map
            .get("LOG_LEVEL_TEST")
            .expect("LOG_LEVEL_TEST should be present in merged env file");
        let rust_log = merged_map
            .get("RUST_LOG_TEST")
            .expect("RUST_LOG_TEST should be present in merged env file");

        // 1) We should not see the raw placeholder anymore
        assert!(
            !rust_log.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log}"
        );
        // 2) The expanded LOG_LEVEL_TEST value should appear in RUST_LOG_TEST
        assert!(
            rust_log.contains(log_level),
            "RUST_LOG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_TEST={rust_log}"
        );
        // Cross-file expansion with RUST_LOG_STAG_TEST that references LOG_LEVEL_TEST from base .env
        let rust_log_stag = merged_map
            .get("RUST_LOG_STAG_TEST")
            .expect("RUST_LOG_STAG_TEST should be present in merged env file");
        // 3) No raw placeholder in the cross-file value either
        assert!(
            !rust_log_stag.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_STAG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log_stag}"
        );
        // 4) The expanded LOG_LEVEL_TEST value should appear in RUST_LOG_STAG_TEST
        assert!(
            rust_log_stag.contains(log_level),
            "RUST_LOG_STAG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_STAG_TEST={rust_log_stag}"
        );
    }
}