shuttle_env_vars/
lib.rs

1use serde::{Deserialize, Serialize};
2use shuttle_runtime::async_trait;
3use shuttle_service::{error::CustomError, Factory, ResourceBuilder, Type};
4use shuttle_static_folder::{Paths, StaticFolder};
5use std::path::PathBuf;
6
/// Folder made available at runtime when the builder does not override it.
const DEFAULT_FOLDER: &str = ".env";
/// File name, inside the folder, loaded in production when the builder does not override it.
const DEFAULT_ENV_PROD: &str = ".env";
9
10#[derive(Serialize)]
11pub struct EnvVars<'a> {
12    /// The folder to reach at runtime. Defaults to `.env`.
13    folder: &'a str,
14    /// The name of the file to use in production. Defaults to `.env`.
15    env_prod: &'a str,
16    /// The name of the file to use in local.
17    env_local: Option<&'a str>,
18    /// The static provider to use.
19    static_provider: Option<shuttle_static_folder::StaticFolder<'a>>,
20}
21
22#[derive(Debug)]
23pub struct EnvError(dotenvy::Error);
24
25impl<'a> EnvVars<'a> {
26    #[must_use]
27    pub fn folder(mut self, folder: &'a str) -> Self {
28        self.folder = folder;
29        self.static_provider = self.static_provider.map(|p| p.folder(folder));
30        self
31    }
32
33    #[must_use]
34    pub const fn env_prod(mut self, env_prod: &'a str) -> Self {
35        self.env_prod = env_prod;
36        self
37    }
38
39    #[must_use]
40    pub const fn env_local(mut self, env_local: &'a str) -> Self {
41        self.env_local = Some(env_local);
42        self
43    }
44
45    pub fn env_file_path(&self, output_dir: Option<&PathBuf>) -> PathBuf {
46        output_dir.map_or_else(
47            || self.env_local.unwrap_or("").into(),
48            |dir| dir.join(self.env_prod),
49        )
50    }
51
52    pub fn load_env_vars(env_file_path: &PathBuf) -> Result<PathBuf, EnvError> {
53        if env_file_path.as_os_str().is_empty() {
54            tracing::info!(?env_file_path, "Is empty!");
55            return Ok("".into());
56        }
57
58        tracing::info!(?env_file_path, "Loading env vars from file");
59
60        dotenvy::from_filename(env_file_path).map_err(|e| {
61            tracing::error!(?e, "Failed to load env vars");
62            EnvError(e)
63        })
64    }
65}
66
67#[derive(Serialize, Deserialize)]
68pub struct ResourceOutput {
69    env_prod: String,
70    env_local: String,
71    paths: Option<Paths>,
72}
73
74impl ResourceOutput {
75    pub fn new(paths: Option<Paths>, env_local: Option<&str>, env_prod: &str) -> Self {
76        Self {
77            paths,
78            env_local: env_local.unwrap_or("").to_string(),
79            env_prod: env_prod.to_string(),
80        }
81    }
82
83    pub fn env_file_path(&self, output_dir: Option<&PathBuf>) -> PathBuf {
84        output_dir.map_or_else(
85            || self.env_local.clone().into(),
86            |dir| dir.join(self.env_prod.clone()),
87        )
88    }
89}
90
91#[async_trait]
92impl<'a> ResourceBuilder<PathBuf> for EnvVars<'a> {
93    const TYPE: Type = Type::StaticFolder;
94    type Config = &'a str;
95    type Output = ResourceOutput;
96
97    fn new() -> Self {
98        let static_provider = shuttle_static_folder::StaticFolder::new().folder(DEFAULT_FOLDER);
99        Self {
100            folder: DEFAULT_FOLDER,
101            env_prod: DEFAULT_ENV_PROD,
102            env_local: None,
103            static_provider: Some(static_provider),
104        }
105    }
106
107    fn config(&self) -> &&'a str {
108        &self.folder
109    }
110
111    async fn output(
112        mut self,
113        factory: &mut dyn Factory,
114    ) -> Result<Self::Output, shuttle_service::Error> {
115        tracing::info!("Calling output function");
116
117        // is production?
118        let env = factory.get_environment();
119        let is_production = match env {
120            shuttle_service::Environment::Production => true,
121            shuttle_service::Environment::Local => false,
122        };
123
124        tracing::debug!(?is_production, "Is production?");
125
126        if !is_production {
127            tracing::info!("Not in production, loading env vars from file");
128            let resource = ResourceOutput::new(None, self.env_local, self.env_prod);
129            return Ok(resource);
130        }
131
132        tracing::trace!("Calling Static provider");
133        let static_provider = self
134            .static_provider
135            .take()
136            .expect("Static Provider is missing");
137
138        tracing::trace!("Getting paths");
139        let paths = static_provider.output(factory).await?;
140        tracing::info!("Static provider returned");
141
142        let resource = ResourceOutput::new(Some(paths), self.env_local, self.env_prod);
143        Ok(resource)
144    }
145
146    async fn build(build_data: &Self::Output) -> Result<PathBuf, shuttle_service::Error> {
147        if let Some(paths) = build_data.paths.as_ref() {
148            // production environment
149            tracing::info!("build method called for production");
150            let output_dir = StaticFolder::build(paths).await?;
151            tracing::info!("Got output_dir from StaticFolder::build {:?}", output_dir);
152            let env_file_path = build_data.env_file_path(Some(&output_dir));
153            Self::load_env_vars(&env_file_path)?;
154            Ok(output_dir)
155        } else {
156            // development environment
157            tracing::info!("build method called for development");
158            let env_file_path = build_data.env_file_path(None);
159            Self::load_env_vars(&env_file_path)?;
160            Ok(env_file_path)
161        }
162    }
163}
164
165impl From<EnvError> for shuttle_service::Error {
166    fn from(error: EnvError) -> Self {
167        let msg = format!("Cannot load env vars: {error:?}");
168        Self::Custom(CustomError::msg(msg))
169    }
170}
171
#[cfg(test)]
mod tests {
    use std::fs;
    use std::path::PathBuf;

    use shuttle_runtime::async_trait;
    use shuttle_service::{DatabaseReadyInfo, Factory, ResourceBuilder};
    use tempfile::{Builder, TempDir};

    use super::*;

    /// Test `Factory` whose build and storage paths live in a throwaway temporary directory.
    struct MockFactory {
        temp_dir: TempDir,
        is_production: bool,
    }

    // With the default folder and file names, the production tests work on this tree
    // (tests using custom names create the same shape with their own names, and the local
    // tests additionally create sibling folders such as `other`):
    // .
    // ├── build
    // │   └── .env
    // │       └── .env
    // ├── storage
    // │   └── .env
    // │       └── .env
    // └── escape
    //     └── passwd
    impl MockFactory {
        fn new(is_production: bool) -> Self {
            Self {
                temp_dir: Builder::new().prefix("env_folder").tempdir().unwrap(),
                is_production,
            }
        }

        fn build_path(&self) -> PathBuf {
            self.get_path("build")
        }

        fn storage_path(&self) -> PathBuf {
            self.get_path("storage")
        }

        fn escape_path(&self) -> PathBuf {
            self.get_path("escape")
        }

        /// Returns `<temp_dir>/<folder>`, creating it on first use.
        fn get_path(&self, folder: &str) -> PathBuf {
            let path = self.temp_dir.path().join(folder);

            if !path.exists() {
                fs::create_dir(&path).unwrap();
            }

            path
        }
    }

    #[async_trait]
    impl Factory for MockFactory {
        async fn get_db_connection(
            &mut self,
            _db_type: shuttle_service::database::Type,
        ) -> Result<DatabaseReadyInfo, shuttle_service::Error> {
            panic!("no env folder test should try to get a db connection string")
        }

        async fn get_secrets(
            &mut self,
        ) -> Result<std::collections::BTreeMap<String, String>, shuttle_service::Error> {
            panic!("no env folder test should try to get secrets")
        }

        fn get_service_name(&self) -> shuttle_service::ServiceName {
            panic!("no env folder test should try to get the service name")
        }

        fn get_environment(&self) -> shuttle_service::Environment {
            if self.is_production {
                shuttle_service::Environment::Production
            } else {
                shuttle_service::Environment::Local
            }
        }

        fn get_build_path(&self) -> Result<std::path::PathBuf, shuttle_service::Error> {
            Ok(self.build_path())
        }

        fn get_storage_path(&self) -> Result<std::path::PathBuf, shuttle_service::Error> {
            Ok(self.storage_path())
        }
    }

    #[tokio::test]
    async fn copies_folder_if_production() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR0=1";

        let input_file_path = factory
            .build_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);
        fs::create_dir_all(input_file_path.parent().unwrap()).unwrap();
        fs::write(input_file_path, CONTENT).unwrap();

        let expected_file = factory
            .storage_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);

        assert!(!expected_file.exists(), "input file should not exist yet");

        // Call plugin
        let env_folder = EnvVars::new();
        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder,
            factory.storage_path().join(DEFAULT_FOLDER),
            "expect path to the env folder to be in the storage folder"
        );
        assert!(
            expected_file.exists(),
            "expected input file to be created in storage folder"
        );
        assert_eq!(
            fs::read_to_string(expected_file).unwrap(),
            CONTENT,
            "expected file content to match"
        );
    }

    #[tokio::test]
    async fn copies_folder_if_production_with_custom_folder_and_prod_file() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR1=1";
        const ENV_FOLDER: &str = "custom_env_folder";
        const ENV_PROD_FILE: &str = ".env-prod";

        let input_file_path = factory.build_path().join(ENV_FOLDER).join(ENV_PROD_FILE);
        fs::create_dir_all(input_file_path.parent().unwrap()).unwrap();
        fs::write(input_file_path, CONTENT).unwrap();

        let expected_file = factory.storage_path().join(ENV_FOLDER).join(ENV_PROD_FILE);

        assert!(!expected_file.exists(), "input file should not exist yet");

        // Call plugin
        let env_folder = EnvVars::new().folder(ENV_FOLDER).env_prod(ENV_PROD_FILE);
        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder,
            factory.storage_path().join(ENV_FOLDER),
            "expect path to the env folder to be in the storage folder"
        );
        assert!(
            expected_file.exists(),
            "expected input file to be created in storage folder"
        );
        assert_eq!(
            fs::read_to_string(expected_file).unwrap(),
            CONTENT,
            "expected file content to match"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot use an absolute path for a static folder")]
    async fn cannot_use_absolute_path() {
        let mut factory = MockFactory::new(true);
        let env_folder = EnvVars::new();

        let _ = env_folder
            .folder("/etc")
            .output(&mut factory)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn can_use_absolute_path_if_local() {
        let mut factory = MockFactory::new(false);
        let env_folder = EnvVars::new();

        let resource_output = env_folder
            .folder("/etc")
            .output(&mut factory)
            .await
            .unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert!(
            output_folder.as_os_str().is_empty(),
            "should return empty path"
        );
    }

    #[tokio::test]
    async fn folder_is_ignored_if_local_and_local_file_absolute() {
        let mut factory = MockFactory::new(false);

        const CONTENT: &str = "MY_VAR2=1";
        const ENV_FOLDER: &str = "../other";
        const ENV_LOCAL_FILE: &str = ".env-dev";

        let local_env_path = factory.build_path().join(ENV_FOLDER).join(ENV_LOCAL_FILE);
        fs::create_dir_all(local_env_path.parent().unwrap()).unwrap();
        fs::write(&local_env_path, CONTENT).unwrap();

        // Call plugin
        let env_folder = EnvVars::new()
            .folder("/etc")
            .env_local(local_env_path.to_str().unwrap());

        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, local_env_path,
            "should return local env path"
        );
        assert_eq!(
            std::env::var("MY_VAR2").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot traverse out of crate for a static folder")]
    async fn cannot_traverse_up() {
        let mut factory = MockFactory::new(true);

        let password_file_path = factory.escape_path().join("passwd");
        fs::create_dir_all(password_file_path.parent().unwrap()).unwrap();
        fs::write(password_file_path, "qwerty").unwrap();

        // Call plugin
        let env_folder = EnvVars::new();

        let _ = env_folder
            .folder("../escape")
            .output(&mut factory)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn can_traverse_up_if_local_and_no_local_file() {
        let mut factory = MockFactory::new(false);

        let password_file_path = factory.escape_path().join("passwd");
        fs::create_dir_all(password_file_path.parent().unwrap()).unwrap();
        fs::write(password_file_path, "qwerty").unwrap();

        // Call plugin
        let env_folder = EnvVars::new();

        let resource_output = env_folder
            .folder("../escape")
            .output(&mut factory)
            .await
            .unwrap();

        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert!(
            output_folder.as_os_str().is_empty(),
            "should return empty path"
        );
    }

    #[tokio::test]
    async fn folder_is_ignored_if_local_and_local_file() {
        let mut factory = MockFactory::new(false);

        const CONTENT: &str = "MY_VAR3=1";
        const ENV_FOLDER: &str = "../other";
        const ENV_LOCAL_FILE: &str = ".env-dev";

        let password_file_path = factory.escape_path().join("passwd");
        fs::create_dir_all(password_file_path.parent().unwrap()).unwrap();
        fs::write(password_file_path, "qwerty").unwrap();

        let local_env_path = factory.build_path().join(ENV_FOLDER).join(ENV_LOCAL_FILE);
        fs::create_dir_all(local_env_path.parent().unwrap()).unwrap();
        fs::write(&local_env_path, CONTENT).unwrap();

        // Call plugin
        let env_folder = EnvVars::new()
            .folder("../escape")
            .env_local(local_env_path.to_str().unwrap());

        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, local_env_path,
            "should return local env path"
        );
        assert_eq!(
            std::env::var("MY_VAR3").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot load env vars")]
    async fn panics_if_local_and_local_file_is_not_correct() {
        let mut factory = MockFactory::new(false);

        const CONTENT: &str = "MY_VAR4=1";
        const ENV_FOLDER: &str = "../other";
        const ENV_LOCAL_FILE: &str = ".env-dev";

        let local_env_path = factory.build_path().join(ENV_FOLDER).join(ENV_LOCAL_FILE);
        fs::create_dir_all(local_env_path.parent().unwrap()).unwrap();
        fs::write(&local_env_path, CONTENT).unwrap();

        // Call plugin with a local path that does not point at the file written above
        let env_folder = EnvVars::new().folder("random").env_local("random/.env-dev");

        let output = env_folder.output(&mut factory).await.unwrap();
        let _ = EnvVars::build(&output).await.unwrap();
    }

    #[tokio::test]
    async fn works_if_folder_and_prod_file_custom() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR5=1";
        const ENV_FOLDER: &str = "other";
        const ENV_PROD_FILE: &str = ".env-prod";

        let env_path = factory.build_path().join(ENV_FOLDER).join(ENV_PROD_FILE);
        fs::create_dir_all(env_path.parent().unwrap()).unwrap();
        fs::write(&env_path, CONTENT).unwrap();

        // Call plugin
        let env_folder = EnvVars::new().folder(ENV_FOLDER).env_prod(ENV_PROD_FILE);

        let resource_output = env_folder.output(&mut factory).await.unwrap();

        let expected_output_folder = factory.storage_path().join(ENV_FOLDER);
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, expected_output_folder,
            "should return storage folder"
        );
        assert_eq!(
            std::env::var("MY_VAR5").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    async fn works_if_folder_and_prod_file_default() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR6=1";

        let env_path = factory
            .build_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);
        fs::create_dir_all(env_path.parent().unwrap()).unwrap();
        fs::write(&env_path, CONTENT).unwrap();

        // Call plugin
        let env_folder = EnvVars::new()
            .folder(DEFAULT_FOLDER)
            .env_prod(DEFAULT_ENV_PROD);

        let resource_output = env_folder.output(&mut factory).await.unwrap();

        let expected_output_folder = factory.storage_path().join(DEFAULT_FOLDER);
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, expected_output_folder,
            "should return storage folder"
        );
        assert_eq!(
            std::env::var("MY_VAR6").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot load env vars")]
    async fn panics_if_folder_and_prod_file_default_not_present() {
        let mut factory = MockFactory::new(true);

        let env_path = factory
            .build_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);
        fs::create_dir_all(env_path.parent().unwrap()).unwrap();

        // Call plugin; the folder exists but the production env file was never written
        let env_folder = EnvVars::new()
            .folder(DEFAULT_FOLDER)
            .env_prod(DEFAULT_ENV_PROD);

        let output = env_folder.output(&mut factory).await.unwrap();
        let _ = EnvVars::build(&output).await.unwrap();
    }
}