1use serde::{Deserialize, Serialize};
2use shuttle_runtime::async_trait;
3use shuttle_service::{error::CustomError, Factory, ResourceBuilder, Type};
4use shuttle_static_folder::{Paths, StaticFolder};
5use std::path::PathBuf;
6
/// Default env folder, resolved relative to the crate's build path in production.
const DEFAULT_FOLDER: &str = ".env";
/// Default file name of the production env file inside the env folder.
const DEFAULT_ENV_PROD: &str = ".env";
9
/// Resource builder that loads environment variables from an env file.
///
/// In production the env `folder` is handled by a static-folder provider and
/// the `env_prod` file inside its output directory is loaded; when running
/// locally only the `env_local` file (if any) is loaded.
#[derive(Serialize)]
pub struct EnvVars<'a> {
    /// Folder containing the production env file (handed to the static-folder provider).
    folder: &'a str,
    /// File name of the production env file within `folder`.
    env_prod: &'a str,
    /// Path of the env file loaded locally; `None` means nothing is loaded.
    env_local: Option<&'a str>,
    /// Static-folder provider used in production; moved out by `output`.
    static_provider: Option<shuttle_static_folder::StaticFolder<'a>>,
}
21
22#[derive(Debug)]
23pub struct EnvError(dotenvy::Error);
24
25impl<'a> EnvVars<'a> {
26 #[must_use]
27 pub fn folder(mut self, folder: &'a str) -> Self {
28 self.folder = folder;
29 self.static_provider = self.static_provider.map(|p| p.folder(folder));
30 self
31 }
32
33 #[must_use]
34 pub const fn env_prod(mut self, env_prod: &'a str) -> Self {
35 self.env_prod = env_prod;
36 self
37 }
38
39 #[must_use]
40 pub const fn env_local(mut self, env_local: &'a str) -> Self {
41 self.env_local = Some(env_local);
42 self
43 }
44
45 pub fn env_file_path(&self, output_dir: Option<&PathBuf>) -> PathBuf {
46 output_dir.map_or_else(
47 || self.env_local.unwrap_or("").into(),
48 |dir| dir.join(self.env_prod),
49 )
50 }
51
52 pub fn load_env_vars(env_file_path: &PathBuf) -> Result<PathBuf, EnvError> {
53 if env_file_path.as_os_str().is_empty() {
54 tracing::info!(?env_file_path, "Is empty!");
55 return Ok("".into());
56 }
57
58 tracing::info!(?env_file_path, "Loading env vars from file");
59
60 dotenvy::from_filename(env_file_path).map_err(|e| {
61 tracing::error!(?e, "Failed to load env vars");
62 EnvError(e)
63 })
64 }
65}
66
/// Serializable result of the `output` step of [`EnvVars`], consumed by `build`.
#[derive(Serialize, Deserialize)]
pub struct ResourceOutput {
    /// File name of the production env file, joined onto the static folder's output dir.
    env_prod: String,
    /// Local env file path; an empty string when none was configured.
    env_local: String,
    /// Static-folder paths; `Some` only when running in production.
    paths: Option<Paths>,
}
73
74impl ResourceOutput {
75 pub fn new(paths: Option<Paths>, env_local: Option<&str>, env_prod: &str) -> Self {
76 Self {
77 paths,
78 env_local: env_local.unwrap_or("").to_string(),
79 env_prod: env_prod.to_string(),
80 }
81 }
82
83 pub fn env_file_path(&self, output_dir: Option<&PathBuf>) -> PathBuf {
84 output_dir.map_or_else(
85 || self.env_local.clone().into(),
86 |dir| dir.join(self.env_prod.clone()),
87 )
88 }
89}
90
91#[async_trait]
92impl<'a> ResourceBuilder<PathBuf> for EnvVars<'a> {
93 const TYPE: Type = Type::StaticFolder;
94 type Config = &'a str;
95 type Output = ResourceOutput;
96
97 fn new() -> Self {
98 let static_provider = shuttle_static_folder::StaticFolder::new().folder(DEFAULT_FOLDER);
99 Self {
100 folder: DEFAULT_FOLDER,
101 env_prod: DEFAULT_ENV_PROD,
102 env_local: None,
103 static_provider: Some(static_provider),
104 }
105 }
106
107 fn config(&self) -> &&'a str {
108 &self.folder
109 }
110
111 async fn output(
112 mut self,
113 factory: &mut dyn Factory,
114 ) -> Result<Self::Output, shuttle_service::Error> {
115 tracing::info!("Calling output function");
116
117 let env = factory.get_environment();
119 let is_production = match env {
120 shuttle_service::Environment::Production => true,
121 shuttle_service::Environment::Local => false,
122 };
123
124 tracing::debug!(?is_production, "Is production?");
125
126 if !is_production {
127 tracing::info!("Not in production, loading env vars from file");
128 let resource = ResourceOutput::new(None, self.env_local, self.env_prod);
129 return Ok(resource);
130 }
131
132 tracing::trace!("Calling Static provider");
133 let static_provider = self
134 .static_provider
135 .take()
136 .expect("Static Provider is missing");
137
138 tracing::trace!("Getting paths");
139 let paths = static_provider.output(factory).await?;
140 tracing::info!("Static provider returned");
141
142 let resource = ResourceOutput::new(Some(paths), self.env_local, self.env_prod);
143 Ok(resource)
144 }
145
146 async fn build(build_data: &Self::Output) -> Result<PathBuf, shuttle_service::Error> {
147 if let Some(paths) = build_data.paths.as_ref() {
148 tracing::info!("build method called for production");
150 let output_dir = StaticFolder::build(paths).await?;
151 tracing::info!("Got output_dir from StaticFolder::build {:?}", output_dir);
152 let env_file_path = build_data.env_file_path(Some(&output_dir));
153 Self::load_env_vars(&env_file_path)?;
154 Ok(output_dir)
155 } else {
156 tracing::info!("build method called for development");
158 let env_file_path = build_data.env_file_path(None);
159 Self::load_env_vars(&env_file_path)?;
160 Ok(env_file_path)
161 }
162 }
163}
164
165impl From<EnvError> for shuttle_service::Error {
166 fn from(error: EnvError) -> Self {
167 let msg = format!("Cannot load env vars: {error:?}");
168 Self::Custom(CustomError::msg(msg))
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::fs;
    use std::path::PathBuf;

    use shuttle_runtime::async_trait;
    use shuttle_service::{DatabaseReadyInfo, Factory, ResourceBuilder};
    use tempfile::{Builder, TempDir};

    use super::*;

    /// Test double for [`Factory`] rooted in a temporary directory.
    ///
    /// The `build`, `storage` and `escape` folders are sub-folders of the temp
    /// dir, created lazily the first time they are requested.
    struct MockFactory {
        temp_dir: TempDir,
        is_production: bool,
    }

    impl MockFactory {
        fn new(is_production: bool) -> Self {
            Self {
                temp_dir: Builder::new().prefix("env_folder").tempdir().unwrap(),
                is_production,
            }
        }

        fn build_path(&self) -> PathBuf {
            self.get_path("build")
        }

        fn storage_path(&self) -> PathBuf {
            self.get_path("storage")
        }

        fn escape_path(&self) -> PathBuf {
            self.get_path("escape")
        }

        /// Returns `<temp_dir>/<folder>`, creating the folder if it does not exist.
        fn get_path(&self, folder: &str) -> PathBuf {
            let path = self.temp_dir.path().join(folder);

            if !path.exists() {
                fs::create_dir(&path).unwrap();
            }

            path
        }
    }

    #[async_trait]
    impl Factory for MockFactory {
        async fn get_db_connection(
            &mut self,
            _db_type: shuttle_service::database::Type,
        ) -> Result<DatabaseReadyInfo, shuttle_service::Error> {
            panic!("no env folder test should try to get a db connection string")
        }

        async fn get_secrets(
            &mut self,
        ) -> Result<std::collections::BTreeMap<String, String>, shuttle_service::Error> {
            panic!("no env folder test should try to get secrets")
        }

        fn get_service_name(&self) -> shuttle_service::ServiceName {
            panic!("no env folder test should try to get the service name")
        }

        fn get_environment(&self) -> shuttle_service::Environment {
            if self.is_production {
                shuttle_service::Environment::Production
            } else {
                shuttle_service::Environment::Local
            }
        }

        fn get_build_path(&self) -> Result<std::path::PathBuf, shuttle_service::Error> {
            Ok(self.build_path())
        }

        fn get_storage_path(&self) -> Result<std::path::PathBuf, shuttle_service::Error> {
            Ok(self.storage_path())
        }
    }

    #[tokio::test]
    async fn copies_folder_if_production() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR0=1";

        let input_file_path = factory
            .build_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);
        fs::create_dir_all(input_file_path.parent().unwrap()).unwrap();
        fs::write(input_file_path, CONTENT).unwrap();

        let expected_file = factory
            .storage_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);

        assert!(!expected_file.exists(), "output file should not exist yet");

        let env_folder = EnvVars::new();
        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder,
            factory.storage_path().join(DEFAULT_FOLDER),
            "expect path to the env folder to be in the storage folder"
        );
        assert!(
            expected_file.exists(),
            "expected input file to be created in storage folder"
        );
        assert_eq!(
            fs::read_to_string(expected_file).unwrap(),
            CONTENT,
            "expected file content to match"
        );
    }

    #[tokio::test]
    async fn copies_folder_if_production_with_custom_folder_and_prod_file() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR1=1";
        const ENV_FOLDER: &str = "custom_env_folder";
        const ENV_PROD_FILE: &str = ".env-prod";

        let input_file_path = factory.build_path().join(ENV_FOLDER).join(ENV_PROD_FILE);
        fs::create_dir_all(input_file_path.parent().unwrap()).unwrap();
        fs::write(input_file_path, CONTENT).unwrap();

        let expected_file = factory.storage_path().join(ENV_FOLDER).join(ENV_PROD_FILE);

        assert!(!expected_file.exists(), "output file should not exist yet");

        let env_folder = EnvVars::new().folder(ENV_FOLDER).env_prod(ENV_PROD_FILE);
        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder,
            factory.storage_path().join(ENV_FOLDER),
            "expect path to the env folder to be in the storage folder"
        );
        assert!(
            expected_file.exists(),
            "expected input file to be created in storage folder"
        );
        assert_eq!(
            fs::read_to_string(expected_file).unwrap(),
            CONTENT,
            "expected file content to match"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot use an absolute path for a static folder")]
    async fn cannot_use_absolute_path() {
        let mut factory = MockFactory::new(true);
        let env_folder = EnvVars::new();

        let _ = env_folder
            .folder("/etc")
            .output(&mut factory)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn can_use_absolute_path_if_local() {
        let mut factory = MockFactory::new(false);
        let env_folder = EnvVars::new();

        let resource_output = env_folder
            .folder("/etc")
            .output(&mut factory)
            .await
            .unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert!(
            output_folder.as_os_str().is_empty(),
            "should return empty path"
        );
    }

    #[tokio::test]
    async fn folder_is_ignored_if_local_and_local_file_absolute() {
        let mut factory = MockFactory::new(false);

        const CONTENT: &str = "MY_VAR2=1";
        const ENV_FOLDER: &str = "../other";
        const ENV_LOCAL_FILE: &str = ".env-dev";

        let local_env_path = factory.build_path().join(ENV_FOLDER).join(ENV_LOCAL_FILE);
        fs::create_dir_all(local_env_path.parent().unwrap()).unwrap();
        fs::write(&local_env_path, CONTENT).unwrap();

        let env_folder = EnvVars::new()
            .folder("/etc")
            .env_local(local_env_path.to_str().unwrap());

        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, local_env_path,
            "should return local env path"
        );
        assert_eq!(
            std::env::var("MY_VAR2").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot traverse out of crate for a static folder")]
    async fn cannot_traverse_up() {
        let mut factory = MockFactory::new(true);

        let password_file_path = factory.escape_path().join("passwd");
        fs::create_dir_all(password_file_path.parent().unwrap()).unwrap();
        fs::write(password_file_path, "qwerty").unwrap();

        let env_folder = EnvVars::new();

        let _ = env_folder
            .folder("../escape")
            .output(&mut factory)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn can_traverse_up_if_local_and_no_local_file() {
        let mut factory = MockFactory::new(false);

        let password_file_path = factory.escape_path().join("passwd");
        fs::create_dir_all(password_file_path.parent().unwrap()).unwrap();
        fs::write(password_file_path, "qwerty").unwrap();

        let env_folder = EnvVars::new();

        let resource_output = env_folder
            .folder("../escape")
            .output(&mut factory)
            .await
            .unwrap();

        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert!(
            output_folder.as_os_str().is_empty(),
            "should return empty path"
        );
    }

    #[tokio::test]
    async fn folder_is_ignored_if_local_and_local_file() {
        let mut factory = MockFactory::new(false);

        const CONTENT: &str = "MY_VAR3=1";
        const ENV_FOLDER: &str = "../other";
        const ENV_LOCAL_FILE: &str = ".env-dev";

        let password_file_path = factory.escape_path().join("passwd");
        fs::create_dir_all(password_file_path.parent().unwrap()).unwrap();
        fs::write(password_file_path, "qwerty").unwrap();

        let local_env_path = factory.build_path().join(ENV_FOLDER).join(ENV_LOCAL_FILE);
        fs::create_dir_all(local_env_path.parent().unwrap()).unwrap();
        fs::write(&local_env_path, CONTENT).unwrap();

        let env_folder = EnvVars::new()
            .folder("../escape")
            .env_local(local_env_path.to_str().unwrap());

        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, local_env_path,
            "should return local env path"
        );
        assert_eq!(
            std::env::var("MY_VAR3").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot load env vars")]
    async fn panics_if_local_and_local_file_is_not_correct() {
        let mut factory = MockFactory::new(false);

        const CONTENT: &str = "MY_VAR4=1";
        const ENV_FOLDER: &str = "../other";
        const ENV_LOCAL_FILE: &str = ".env-dev";

        // A valid env file exists, but the builder is pointed at a different path.
        let local_env_path = factory.build_path().join(ENV_FOLDER).join(ENV_LOCAL_FILE);
        fs::create_dir_all(local_env_path.parent().unwrap()).unwrap();
        fs::write(&local_env_path, CONTENT).unwrap();

        let env_folder = EnvVars::new().folder("random").env_local("random/.env-dev");

        let output = env_folder.output(&mut factory).await.unwrap();
        let _ = EnvVars::build(&output).await.unwrap();
    }

    #[tokio::test]
    async fn works_if_folder_and_prod_file_custom() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR5=1";
        const ENV_FOLDER: &str = "other";
        const ENV_PROD_FILE: &str = ".env-prod";

        let env_path = factory.build_path().join(ENV_FOLDER).join(ENV_PROD_FILE);
        fs::create_dir_all(env_path.parent().unwrap()).unwrap();
        fs::write(env_path, CONTENT).unwrap();

        // Resolve (and thereby create) the storage folder up front, like the copy tests do.
        let expected_output_folder = factory.storage_path().join(ENV_FOLDER);

        let env_folder = EnvVars::new().folder(ENV_FOLDER).env_prod(ENV_PROD_FILE);
        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, expected_output_folder,
            "should return storage folder"
        );
        assert_eq!(
            std::env::var("MY_VAR5").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    async fn works_if_folder_and_prod_file_default() {
        let mut factory = MockFactory::new(true);

        const CONTENT: &str = "MY_VAR6=1";

        let env_path = factory
            .build_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);
        fs::create_dir_all(env_path.parent().unwrap()).unwrap();
        fs::write(env_path, CONTENT).unwrap();

        // Resolve (and thereby create) the storage folder up front, like the copy tests do.
        let expected_output_folder = factory.storage_path().join(DEFAULT_FOLDER);

        let env_folder = EnvVars::new()
            .folder(DEFAULT_FOLDER)
            .env_prod(DEFAULT_ENV_PROD);
        let resource_output = env_folder.output(&mut factory).await.unwrap();
        let output_folder = EnvVars::build(&resource_output).await.unwrap();

        assert_eq!(
            output_folder, expected_output_folder,
            "should return storage folder"
        );
        assert_eq!(
            std::env::var("MY_VAR6").unwrap(),
            "1",
            "should load env var"
        );
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot load env vars")]
    async fn panics_if_folder_and_prod_file_default_not_present() {
        let mut factory = MockFactory::new(true);

        // Only the folder exists; the production env file inside it is missing.
        let env_path = factory
            .build_path()
            .join(DEFAULT_FOLDER)
            .join(DEFAULT_ENV_PROD);
        fs::create_dir_all(env_path.parent().unwrap()).unwrap();

        let env_folder = EnvVars::new()
            .folder(DEFAULT_FOLDER)
            .env_prod(DEFAULT_ENV_PROD);

        let output = env_folder.output(&mut factory).await.unwrap();
        let _ = EnvVars::build(&output).await.unwrap();
    }
}