1use std::borrow::Cow;
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::Arc;
7
8use anyhow::Context;
9use anyhow::Result;
10use anyhow::anyhow;
11use anyhow::bail;
12use serde::Deserialize;
13use serde::Serialize;
14use tracing::warn;
15use url::Url;
16
17use crate::DockerBackend;
18use crate::LocalBackend;
19use crate::SYSTEM;
20use crate::TaskExecutionBackend;
21use crate::TesBackend;
22use crate::convert_unit_string;
23use crate::path::is_url;
24
/// The maximum number of retries a task may be configured with.
///
/// `TaskConfig::validate` rejects a `task.retries` value greater than this.
pub const MAX_RETRIES: u64 = 100;

/// The default shell name for tasks.
///
/// NOTE(review): presumably used to run task commands when `task.shell` is
/// unset — confirm against the task execution code.
pub const DEFAULT_TASK_SHELL: &str = "bash";

/// The default limit on concurrent downloads.
///
/// NOTE(review): presumably applied when `http.max_concurrent_downloads` is
/// unset — confirm against the download code.
pub const DEFAULT_MAX_CONCURRENT_DOWNLOADS: u64 = 10;

/// The name of the backend that is selected when `backend` is unset and two
/// or more backends are configured.
pub const DEFAULT_BACKEND_NAME: &str = "default";

37#[derive(Debug, Default, Clone, Serialize, Deserialize)]
39#[serde(rename_all = "snake_case", deny_unknown_fields)]
40pub struct Config {
41 #[serde(default)]
43 pub http: HttpConfig,
44 #[serde(default)]
46 pub workflow: WorkflowConfig,
47 #[serde(default)]
49 pub task: TaskConfig,
50 pub backend: Option<String>,
55 #[serde(default)]
63 pub backends: HashMap<String, BackendConfig>,
64 #[serde(default)]
66 pub storage: StorageConfig,
67}
68
69impl Config {
70 pub fn validate(&self) -> Result<()> {
72 self.http.validate()?;
73 self.workflow.validate()?;
74 self.task.validate()?;
75
76 if self.backend.is_none() && self.backends.len() < 2 {
77 } else {
80 let backend = self.backend.as_deref().unwrap_or(DEFAULT_BACKEND_NAME);
82 if !self.backends.contains_key(backend) {
83 bail!("a backend named `{backend}` is not present in the configuration");
84 }
85 }
86
87 for backend in self.backends.values() {
88 backend.validate()?;
89 }
90
91 self.storage.validate()?;
92 Ok(())
93 }
94
95 pub async fn create_backend(self: &Arc<Self>) -> Result<Arc<dyn TaskExecutionBackend>> {
97 let config = if self.backend.is_none() && self.backends.len() < 2 {
98 if self.backends.len() == 1 {
99 Cow::Borrowed(self.backends.values().next().unwrap())
101 } else {
102 Cow::Owned(BackendConfig::default())
104 }
105 } else {
106 let backend = self.backend.as_deref().unwrap_or(DEFAULT_BACKEND_NAME);
108 Cow::Borrowed(self.backends.get(backend).ok_or_else(|| {
109 anyhow!("a backend named `{backend}` is not present in the configuration")
110 })?)
111 };
112
113 match config.as_ref() {
114 BackendConfig::Local(config) => {
115 warn!(
116 "the engine is configured to use the local backend: tasks will not be run \
117 inside of a container"
118 );
119 Ok(Arc::new(LocalBackend::new(self.clone(), config)?))
120 }
121 BackendConfig::Docker(config) => {
122 Ok(Arc::new(DockerBackend::new(self.clone(), config).await?))
123 }
124 BackendConfig::Tes(config) => {
125 Ok(Arc::new(TesBackend::new(self.clone(), config).await?))
126 }
127 }
128 }
129}
130
131#[derive(Debug, Default, Clone, Serialize, Deserialize)]
133#[serde(rename_all = "snake_case", deny_unknown_fields)]
134pub struct HttpConfig {
135 #[serde(default)]
139 pub cache: Option<PathBuf>,
140 #[serde(default, skip_serializing_if = "Option::is_none")]
144 pub max_concurrent_downloads: Option<u64>,
145}
146
147impl HttpConfig {
148 pub fn validate(&self) -> Result<()> {
150 if let Some(limit) = self.max_concurrent_downloads {
151 if limit == 0 {
152 bail!("configuration value `http.max_concurrent_downloads` cannot be zero");
153 }
154 }
155 Ok(())
156 }
157}
158
159#[derive(Debug, Default, Clone, Serialize, Deserialize)]
161#[serde(rename_all = "snake_case", deny_unknown_fields)]
162pub struct StorageConfig {
163 #[serde(default)]
165 pub azure: AzureStorageConfig,
166 #[serde(default)]
168 pub s3: S3StorageConfig,
169 #[serde(default)]
171 pub google: GoogleStorageConfig,
172}
173
174impl StorageConfig {
175 pub fn validate(&self) -> Result<()> {
177 self.azure.validate()?;
178 self.s3.validate()?;
179 self.google.validate()?;
180 Ok(())
181 }
182}
183
184#[derive(Debug, Default, Clone, Serialize, Deserialize)]
186#[serde(rename_all = "snake_case", deny_unknown_fields)]
187pub struct AzureStorageConfig {
188 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
197 pub auth: HashMap<String, HashMap<String, String>>,
198}
199
200impl AzureStorageConfig {
201 pub fn validate(&self) -> Result<()> {
203 Ok(())
204 }
205}
206
207#[derive(Debug, Default, Clone, Serialize, Deserialize)]
209#[serde(rename_all = "snake_case", deny_unknown_fields)]
210pub struct S3StorageConfig {
211 #[serde(default, skip_serializing_if = "Option::is_none")]
216 pub region: Option<String>,
217
218 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
224 pub auth: HashMap<String, String>,
225}
226
227impl S3StorageConfig {
228 pub fn validate(&self) -> Result<()> {
230 Ok(())
231 }
232}
233
234#[derive(Debug, Default, Clone, Serialize, Deserialize)]
236#[serde(rename_all = "snake_case", deny_unknown_fields)]
237pub struct GoogleStorageConfig {
238 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
244 pub auth: HashMap<String, String>,
245}
246
247impl GoogleStorageConfig {
248 pub fn validate(&self) -> Result<()> {
250 Ok(())
251 }
252}
253
254#[derive(Debug, Default, Clone, Serialize, Deserialize)]
256#[serde(rename_all = "snake_case", deny_unknown_fields)]
257pub struct WorkflowConfig {
258 #[serde(default)]
260 pub scatter: ScatterConfig,
261}
262
263impl WorkflowConfig {
264 pub fn validate(&self) -> Result<()> {
266 self.scatter.validate()?;
267 Ok(())
268 }
269}
270
271#[derive(Debug, Default, Clone, Serialize, Deserialize)]
273#[serde(rename_all = "snake_case", deny_unknown_fields)]
274pub struct ScatterConfig {
275 #[serde(default, skip_serializing_if = "Option::is_none")]
328 pub concurrency: Option<u64>,
329}
330
331impl ScatterConfig {
332 pub fn validate(&self) -> Result<()> {
334 if let Some(concurrency) = self.concurrency {
335 if concurrency == 0 {
336 bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
337 }
338 }
339
340 Ok(())
341 }
342}
343
344#[derive(Debug, Default, Clone, Serialize, Deserialize)]
346#[serde(rename_all = "snake_case", deny_unknown_fields)]
347pub struct TaskConfig {
348 #[serde(default, skip_serializing_if = "Option::is_none")]
354 pub retries: Option<u64>,
355 #[serde(default, skip_serializing_if = "Option::is_none")]
360 pub container: Option<String>,
361 #[serde(default, skip_serializing_if = "Option::is_none")]
369 pub shell: Option<String>,
370}
371
372impl TaskConfig {
373 pub fn validate(&self) -> Result<()> {
375 if self.retries.unwrap_or(0) > MAX_RETRIES {
376 bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
377 }
378
379 Ok(())
380 }
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize)]
385#[serde(rename_all = "snake_case", tag = "type")]
386pub enum BackendConfig {
387 Local(LocalBackendConfig),
389 Docker(DockerBackendConfig),
391 Tes(Box<TesBackendConfig>),
393}
394
395impl Default for BackendConfig {
396 fn default() -> Self {
397 Self::Docker(Default::default())
398 }
399}
400
401impl BackendConfig {
402 pub fn validate(&self) -> Result<()> {
404 match self {
405 Self::Local(config) => config.validate(),
406 Self::Docker(config) => config.validate(),
407 Self::Tes(config) => config.validate(),
408 }
409 }
410
411 pub fn as_local(&self) -> Option<&LocalBackendConfig> {
415 match self {
416 Self::Local(config) => Some(config),
417 _ => None,
418 }
419 }
420
421 pub fn as_docker(&self) -> Option<&DockerBackendConfig> {
425 match self {
426 Self::Docker(config) => Some(config),
427 _ => None,
428 }
429 }
430
431 pub fn as_tes(&self) -> Option<&TesBackendConfig> {
435 match self {
436 Self::Tes(config) => Some(config),
437 _ => None,
438 }
439 }
440}
441
442#[derive(Debug, Default, Clone, Serialize, Deserialize)]
449#[serde(rename_all = "snake_case", deny_unknown_fields)]
450pub struct LocalBackendConfig {
451 #[serde(default, skip_serializing_if = "Option::is_none")]
457 pub cpu: Option<u64>,
458
459 #[serde(default, skip_serializing_if = "Option::is_none")]
466 pub memory: Option<String>,
467}
468
469impl LocalBackendConfig {
470 pub fn validate(&self) -> Result<()> {
472 if let Some(cpu) = self.cpu {
473 if cpu == 0 {
474 bail!("local backend configuration value `cpu` cannot be zero");
475 }
476
477 let total = SYSTEM.cpus().len() as u64;
478 if cpu > total {
479 bail!(
480 "local backend configuration value `cpu` cannot exceed the virtual CPUs \
481 available to the host ({total})"
482 );
483 }
484 }
485
486 if let Some(memory) = &self.memory {
487 let memory = convert_unit_string(memory).with_context(|| {
488 format!("local backend configuration value `memory` has invalid value `{memory}`")
489 })?;
490
491 if memory == 0 {
492 bail!("local backend configuration value `memory` cannot be zero");
493 }
494
495 let total = SYSTEM.total_memory();
496 if memory > total {
497 bail!(
498 "local backend configuration value `memory` cannot exceed the total memory of \
499 the host ({total} bytes)"
500 );
501 }
502 }
503
504 Ok(())
505 }
506}
507
508const fn cleanup_default() -> bool {
510 true
511}
512
513#[derive(Debug, Clone, Serialize, Deserialize)]
515#[serde(rename_all = "snake_case", deny_unknown_fields)]
516pub struct DockerBackendConfig {
517 #[serde(default = "cleanup_default")]
521 pub cleanup: bool,
522}
523
524impl DockerBackendConfig {
525 pub fn validate(&self) -> Result<()> {
527 Ok(())
528 }
529}
530
531impl Default for DockerBackendConfig {
532 fn default() -> Self {
533 Self { cleanup: true }
534 }
535}
536
537#[derive(Debug, Default, Clone, Serialize, Deserialize)]
539#[serde(rename_all = "snake_case", deny_unknown_fields)]
540pub struct BasicAuthConfig {
541 pub username: Option<String>,
543 pub password: Option<String>,
545}
546
547impl BasicAuthConfig {
548 pub fn validate(&self) -> Result<()> {
550 if self.username.is_none() {
551 bail!("HTTP basic auth configuration value `username` is required");
552 }
553
554 if self.password.is_none() {
555 bail!("HTTP basic auth configuration value `password` is required");
556 }
557
558 Ok(())
559 }
560}
561
562#[derive(Debug, Default, Clone, Serialize, Deserialize)]
564#[serde(rename_all = "snake_case", deny_unknown_fields)]
565pub struct BearerAuthConfig {
566 pub token: Option<String>,
568}
569
570impl BearerAuthConfig {
571 pub fn validate(&self) -> Result<()> {
573 if self.token.is_none() {
574 bail!("HTTP bearer auth configuration value `token` is required");
575 }
576
577 Ok(())
578 }
579}
580
581#[derive(Debug, Clone, Serialize, Deserialize)]
583#[serde(rename_all = "snake_case", tag = "type")]
584pub enum TesBackendAuthConfig {
585 Basic(BasicAuthConfig),
587 Bearer(BearerAuthConfig),
589}
590
591impl TesBackendAuthConfig {
592 pub fn validate(&self) -> Result<()> {
594 match self {
595 Self::Basic(config) => config.validate(),
596 Self::Bearer(config) => config.validate(),
597 }
598 }
599}
600
601#[derive(Debug, Default, Clone, Serialize, Deserialize)]
603#[serde(rename_all = "snake_case", deny_unknown_fields)]
604pub struct TesBackendConfig {
605 #[serde(default)]
607 pub url: Option<Url>,
608
609 #[serde(default, skip_serializing_if = "Option::is_none")]
611 pub auth: Option<TesBackendAuthConfig>,
612
613 #[serde(default, skip_serializing_if = "Option::is_none")]
615 pub inputs: Option<Url>,
616
617 #[serde(default, skip_serializing_if = "Option::is_none")]
619 pub outputs: Option<Url>,
620
621 #[serde(default)]
625 pub interval: Option<u64>,
626
627 #[serde(default)]
631 pub max_concurrency: Option<u64>,
632
633 #[serde(default)]
636 pub insecure: bool,
637}
638
639impl TesBackendConfig {
640 pub fn validate(&self) -> Result<()> {
642 match &self.url {
643 Some(url) => {
644 if !self.insecure && url.scheme() != "https" {
645 bail!(
646 "TES backend configuration value `url` has invalid value `{url}`: URL \
647 must use a HTTPS scheme"
648 );
649 }
650 }
651 None => bail!("TES backend configuration value `url` is required"),
652 }
653
654 if let Some(auth) = &self.auth {
655 auth.validate()?;
656 }
657
658 match &self.inputs {
659 Some(url) => {
660 if !is_url(url.as_str()) {
661 bail!(
662 "TES backend storage configuration value `inputs` has invalid value \
663 `{url}`: URL scheme is not supported"
664 );
665 }
666
667 if !url.path().ends_with('/') {
668 bail!(
669 "TES backend storage configuration value `inputs` has invalid value \
670 `{url}`: URL path must end with a slash"
671 );
672 }
673 }
674 None => bail!("TES backend configuration value `inputs` is required"),
675 }
676
677 match &self.outputs {
678 Some(url) => {
679 if !is_url(url.as_str()) {
680 bail!(
681 "TES backend storage configuration value `outputs` has invalid value \
682 `{url}`: URL scheme is not supported"
683 );
684 }
685
686 if !url.path().ends_with('/') {
687 bail!(
688 "TES backend storage configuration value `outputs` has invalid value \
689 `{url}`: URL path must end with a slash"
690 );
691 }
692 }
693 None => bail!("TES backend storage configuration value `outputs` is required"),
694 }
695
696 Ok(())
697 }
698}
699
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Exercises `Config::validate` against valid and invalid configurations.
    /// For each invalid case it asserts the exact error message, or only its
    /// prefix where the message includes host-specific values.
    #[test]
    fn test_config_validate() {
        // Task retries above `MAX_RETRIES` are rejected.
        let mut config = Config::default();
        config.task.retries = Some(1000000);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `task.retries` cannot exceed 100"
        );

        // A scatter concurrency of zero is rejected.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(0);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `workflow.scatter.concurrency` cannot be zero"
        );

        // An explicitly selected backend must exist in `backends`.
        let config = Config {
            backend: Some("foo".into()),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "a backend named `foo` is not present in the configuration"
        );
        let config = Config {
            backend: Some("bar".into()),
            backends: [("foo".to_string(), BackendConfig::default())].into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "a backend named `bar` is not present in the configuration"
        );

        // With no explicit selection, a single configured backend is accepted
        // regardless of its name.
        let config = Config {
            backends: [("foo".to_string(), BackendConfig::default())].into(),
            ..Default::default()
        };
        config.validate().expect("config should validate");

        // Local backend: `cpu` must be non-zero.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    cpu: Some(0),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `cpu` cannot be zero"
        );
        // Local backend: `cpu` must not exceed the host's CPU count. The count
        // varies by host, so only the message prefix is checked.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    cpu: Some(10000000),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert!(config.validate().unwrap_err().to_string().starts_with(
            "local backend configuration value `cpu` cannot exceed the virtual CPUs available to \
             the host"
        ));

        // Local backend: `memory` must be non-zero.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    memory: Some("0 GiB".to_string()),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `memory` cannot be zero"
        );
        // Local backend: `memory` must be a parsable unit string.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    memory: Some("100 meows".to_string()),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `memory` has invalid value `100 meows`"
        );

        // Local backend: `memory` must not exceed the host's total memory. The
        // total varies by host, so only the message prefix is checked.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    memory: Some("1000 TiB".to_string()),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert!(config.validate().unwrap_err().to_string().starts_with(
            "local backend configuration value `memory` cannot exceed the total memory of the host"
        ));

        // TES backend: `url` is required.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Default::default()),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "TES backend configuration value `url` is required"
        );

        // TES backend: a non-HTTPS `url` is rejected when `insecure` is unset.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(
                    TesBackendConfig {
                        url: Some("http://example.com".parse().unwrap()),
                        inputs: Some("http://example.com".parse().unwrap()),
                        outputs: Some("http://example.com".parse().unwrap()),
                        ..Default::default()
                    }
                    .into(),
                ),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "TES backend configuration value `url` has invalid value `http://example.com/`: URL \
             must use a HTTPS scheme"
        );

        // TES backend: the same configuration is accepted with `insecure` set.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(
                    TesBackendConfig {
                        url: Some("http://example.com".parse().unwrap()),
                        inputs: Some("http://example.com".parse().unwrap()),
                        outputs: Some("http://example.com".parse().unwrap()),
                        insecure: true,
                        ..Default::default()
                    }
                    .into(),
                ),
            )]
            .into(),
            ..Default::default()
        };
        config.validate().expect("configuration should validate");

        // TES backend: basic auth requires `username`. Auth is validated
        // before `inputs`/`outputs`, so leaving those unset is fine here.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Box::new(TesBackendConfig {
                    url: Some(Url::parse("https://example.com").unwrap()),
                    auth: Some(TesBackendAuthConfig::Basic(Default::default())),
                    ..Default::default()
                })),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "HTTP basic auth configuration value `username` is required"
        );
        // TES backend: basic auth also requires `password`.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Box::new(TesBackendConfig {
                    url: Some(Url::parse("https://example.com").unwrap()),
                    auth: Some(TesBackendAuthConfig::Basic(BasicAuthConfig {
                        username: Some("Foo".into()),
                        ..Default::default()
                    })),
                    ..Default::default()
                })),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "HTTP basic auth configuration value `password` is required"
        );

        // TES backend: bearer auth requires `token`.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Box::new(TesBackendConfig {
                    url: Some(Url::parse("https://example.com").unwrap()),
                    auth: Some(TesBackendAuthConfig::Bearer(Default::default())),
                    ..Default::default()
                })),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "HTTP bearer auth configuration value `token` is required"
        );

        // HTTP: `max_concurrent_downloads` of zero is rejected.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(0);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `http.max_concurrent_downloads` cannot be zero"
        );

        // HTTP: a non-zero `max_concurrent_downloads` is accepted.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(5);
        assert!(
            config.validate().is_ok(),
            "should pass for valid configuration"
        );

        // HTTP: an unset `max_concurrent_downloads` is accepted.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = None;
        assert!(config.validate().is_ok(), "should pass for default (None)");
    }
}