1use std::borrow::Cow;
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::Arc;
7
8use anyhow::Context;
9use anyhow::Result;
10use anyhow::anyhow;
11use anyhow::bail;
12use serde::Deserialize;
13use serde::Serialize;
14use tracing::warn;
15use url::Url;
16
17use crate::DockerBackend;
18use crate::LocalBackend;
19use crate::SYSTEM;
20use crate::TaskExecutionBackend;
21use crate::TesBackend;
22use crate::convert_unit_string;
23use crate::path::is_url;
24
/// The largest value accepted for the `task.retries` configuration setting.
pub const MAX_RETRIES: u64 = 100;

/// The name of the default task shell.
///
/// NOTE(review): presumably used when `task.shell` is unset; the consumer is
/// outside this file.
pub const DEFAULT_TASK_SHELL: &str = "bash";

/// The default number of concurrent downloads.
///
/// NOTE(review): presumably used when `http.max_concurrent_downloads` is
/// unset; the consumer is outside this file.
pub const DEFAULT_MAX_CONCURRENT_DOWNLOADS: u64 = 10;

/// The backend name looked up when no backend is explicitly selected and the
/// configuration does not contain fewer than two backends.
pub const DEFAULT_BACKEND_NAME: &str = "default";
36
/// The top-level engine configuration.
///
/// Sections marked `#[serde(default)]` fall back to their `Default` value when
/// omitted; unknown fields are rejected during deserialization.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Config {
    /// HTTP-related settings.
    #[serde(default)]
    pub http: HttpConfig,
    /// Workflow evaluation settings.
    #[serde(default)]
    pub workflow: WorkflowConfig,
    /// Task evaluation settings.
    #[serde(default)]
    pub task: TaskConfig,
    /// The name of the backend to use; when set it must be a key of
    /// `backends`.
    ///
    /// When unset and fewer than two backends are configured, the sole
    /// configured backend (or a default backend configuration when there is
    /// none) is used. Otherwise the backend named by `DEFAULT_BACKEND_NAME` is
    /// looked up.
    pub backend: Option<String>,
    /// The configured backends, keyed by name.
    #[serde(default)]
    pub backends: HashMap<String, BackendConfig>,
    /// Storage settings (Azure, S3 and Google).
    #[serde(default)]
    pub storage: StorageConfig,
}
68
69impl Config {
70 pub fn validate(&self) -> Result<()> {
72 self.http.validate()?;
73 self.workflow.validate()?;
74 self.task.validate()?;
75
76 if self.backend.is_none() && self.backends.len() < 2 {
77 } else {
80 let backend = self.backend.as_deref().unwrap_or(DEFAULT_BACKEND_NAME);
82 if !self.backends.contains_key(backend) {
83 bail!("a backend named `{backend}` is not present in the configuration");
84 }
85 }
86
87 for backend in self.backends.values() {
88 backend.validate()?;
89 }
90
91 self.storage.validate()?;
92 Ok(())
93 }
94
95 pub async fn create_backend(self: &Arc<Self>) -> Result<Arc<dyn TaskExecutionBackend>> {
97 let config = if self.backend.is_none() && self.backends.len() < 2 {
98 if self.backends.len() == 1 {
99 Cow::Borrowed(self.backends.values().next().unwrap())
101 } else {
102 Cow::Owned(BackendConfig::default())
104 }
105 } else {
106 let backend = self.backend.as_deref().unwrap_or(DEFAULT_BACKEND_NAME);
108 Cow::Borrowed(self.backends.get(backend).ok_or_else(|| {
109 anyhow!("a backend named `{backend}` is not present in the configuration")
110 })?)
111 };
112
113 match config.as_ref() {
114 BackendConfig::Local(config) => {
115 warn!(
116 "the engine is configured to use the local backend: tasks will not be run \
117 inside of a container"
118 );
119 Ok(Arc::new(LocalBackend::new(self.clone(), config)?))
120 }
121 BackendConfig::Docker(config) => {
122 Ok(Arc::new(DockerBackend::new(self.clone(), config).await?))
123 }
124 BackendConfig::Tes(config) => {
125 Ok(Arc::new(TesBackend::new(self.clone(), config).await?))
126 }
127 }
128 }
129}
130
/// HTTP-related configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct HttpConfig {
    /// Optional path of the HTTP cache.
    ///
    /// NOTE(review): how an unset value is resolved is decided by the
    /// consumer, outside this file.
    #[serde(default)]
    pub cache: Option<PathBuf>,
    /// Optional limit on the number of concurrent downloads; must not be zero
    /// when set.
    ///
    /// NOTE(review): presumably `DEFAULT_MAX_CONCURRENT_DOWNLOADS` applies
    /// when unset; confirm against the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrent_downloads: Option<u64>,
}
146
147impl HttpConfig {
148 pub fn validate(&self) -> Result<()> {
150 if let Some(limit) = self.max_concurrent_downloads {
151 if limit == 0 {
152 bail!("configuration value `http.max_concurrent_downloads` cannot be zero");
153 }
154 }
155 Ok(())
156 }
157}
158
/// Storage configuration, grouped by provider.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct StorageConfig {
    /// Azure storage settings.
    #[serde(default)]
    pub azure: AzureStorageConfig,
    /// S3 storage settings.
    #[serde(default)]
    pub s3: S3StorageConfig,
    /// Google storage settings.
    #[serde(default)]
    pub google: GoogleStorageConfig,
}
173
174impl StorageConfig {
175 pub fn validate(&self) -> Result<()> {
177 self.azure.validate()?;
178 self.s3.validate()?;
179 self.google.validate()?;
180 Ok(())
181 }
182}
183
/// Azure storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct AzureStorageConfig {
    /// Authentication settings as a two-level string map; omitted from
    /// serialized output when empty.
    ///
    /// NOTE(review): presumably storage account -> container -> credential;
    /// the meaning of the keys is not visible in this file, confirm against
    /// the consumer.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, HashMap<String, String>>,
}
199
impl AzureStorageConfig {
    /// Validates the Azure storage configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
206
/// S3 storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct S3StorageConfig {
    /// Optional region; omitted from serialized output when unset.
    ///
    /// NOTE(review): the region used when this is unset is decided by the
    /// consumer, outside this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,

    /// Authentication settings as a string map; omitted from serialized
    /// output when empty.
    ///
    /// NOTE(review): presumably bucket name -> credential; confirm against
    /// the consumer.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, String>,
}
226
impl S3StorageConfig {
    /// Validates the S3 storage configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
233
/// Google storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct GoogleStorageConfig {
    /// Authentication settings as a string map; omitted from serialized
    /// output when empty.
    ///
    /// NOTE(review): presumably bucket name -> credential; confirm against
    /// the consumer.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, String>,
}
246
impl GoogleStorageConfig {
    /// Validates the Google storage configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
253
/// Workflow evaluation configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct WorkflowConfig {
    /// Settings for scatter statements.
    #[serde(default)]
    pub scatter: ScatterConfig,
}
262
263impl WorkflowConfig {
264 pub fn validate(&self) -> Result<()> {
266 self.scatter.validate()?;
267 Ok(())
268 }
269}
270
/// Scatter statement configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct ScatterConfig {
    /// Optional scatter concurrency; must not be zero when set.
    ///
    /// NOTE(review): the value used when this is unset is decided by the
    /// consumer, outside this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub concurrency: Option<u64>,
}
330
331impl ScatterConfig {
332 pub fn validate(&self) -> Result<()> {
334 if let Some(concurrency) = self.concurrency {
335 if concurrency == 0 {
336 bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
337 }
338 }
339
340 Ok(())
341 }
342}
343
/// Task evaluation configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TaskConfig {
    /// Optional retry count; must not exceed `MAX_RETRIES` when set.
    ///
    /// NOTE(review): presumably the default number of retries for a failed
    /// task; confirm against the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retries: Option<u64>,
    /// Optional container setting.
    ///
    /// NOTE(review): presumably the container used by tasks that do not
    /// specify one; confirm against the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub container: Option<String>,
    /// Optional task shell.
    ///
    /// NOTE(review): presumably `DEFAULT_TASK_SHELL` applies when unset;
    /// confirm against the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shell: Option<String>,
    /// The behavior applied to CPU limits; defaults to
    /// `TaskResourceLimitBehavior::Deny`.
    #[serde(default)]
    pub cpu_limit_behavior: TaskResourceLimitBehavior,
    /// The behavior applied to memory limits; defaults to
    /// `TaskResourceLimitBehavior::Deny`.
    #[serde(default)]
    pub memory_limit_behavior: TaskResourceLimitBehavior,
}
377
378impl TaskConfig {
379 pub fn validate(&self) -> Result<()> {
381 if self.retries.unwrap_or(0) > MAX_RETRIES {
382 bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
383 }
384
385 Ok(())
386 }
387}
388
/// The behavior applied to a task resource limit.
///
/// Serialized as `try_with_max` or `deny`.
///
/// NOTE(review): the precise semantics of each variant are implemented by the
/// backends, outside this file; the variant docs below are inferred from the
/// names and should be confirmed.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub enum TaskResourceLimitBehavior {
    /// Presumably: attempt the task with the maximum available amount of the
    /// resource.
    TryWithMax,
    /// Presumably: refuse the task. This is the default variant.
    #[default]
    Deny,
}
403
/// The configuration of a task execution backend.
///
/// Internally tagged by a `type` field whose value is `local`, `docker` or
/// `tes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum BackendConfig {
    /// Use the local backend.
    Local(LocalBackendConfig),
    /// Use the Docker backend.
    Docker(DockerBackendConfig),
    /// Use the TES backend; boxed, which keeps the enum itself small.
    Tes(Box<TesBackendConfig>),
}
415
416impl Default for BackendConfig {
417 fn default() -> Self {
418 Self::Docker(Default::default())
419 }
420}
421
422impl BackendConfig {
423 pub fn validate(&self) -> Result<()> {
425 match self {
426 Self::Local(config) => config.validate(),
427 Self::Docker(config) => config.validate(),
428 Self::Tes(config) => config.validate(),
429 }
430 }
431
432 pub fn as_local(&self) -> Option<&LocalBackendConfig> {
436 match self {
437 Self::Local(config) => Some(config),
438 _ => None,
439 }
440 }
441
442 pub fn as_docker(&self) -> Option<&DockerBackendConfig> {
446 match self {
447 Self::Docker(config) => Some(config),
448 _ => None,
449 }
450 }
451
452 pub fn as_tes(&self) -> Option<&TesBackendConfig> {
456 match self {
457 Self::Tes(config) => Some(config),
458 _ => None,
459 }
460 }
461}
462
/// Configuration of the local task execution backend.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct LocalBackendConfig {
    /// Optional CPU count; when set it must be non-zero and must not exceed
    /// the number of CPUs reported for the host.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cpu: Option<u64>,

    /// Optional memory amount as a unit string (for example `"1 GiB"`); when
    /// set it must parse, be non-zero and must not exceed the total memory of
    /// the host.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub memory: Option<String>,
}
489
490impl LocalBackendConfig {
491 pub fn validate(&self) -> Result<()> {
493 if let Some(cpu) = self.cpu {
494 if cpu == 0 {
495 bail!("local backend configuration value `cpu` cannot be zero");
496 }
497
498 let total = SYSTEM.cpus().len() as u64;
499 if cpu > total {
500 bail!(
501 "local backend configuration value `cpu` cannot exceed the virtual CPUs \
502 available to the host ({total})"
503 );
504 }
505 }
506
507 if let Some(memory) = &self.memory {
508 let memory = convert_unit_string(memory).with_context(|| {
509 format!("local backend configuration value `memory` has invalid value `{memory}`")
510 })?;
511
512 if memory == 0 {
513 bail!("local backend configuration value `memory` cannot be zero");
514 }
515
516 let total = SYSTEM.total_memory();
517 if memory > total {
518 bail!(
519 "local backend configuration value `memory` cannot exceed the total memory of \
520 the host ({total} bytes)"
521 );
522 }
523 }
524
525 Ok(())
526 }
527}
528
/// Serde default for `DockerBackendConfig::cleanup`: returns `true`.
const fn cleanup_default() -> bool {
    true
}
533
/// Configuration of the Docker task execution backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct DockerBackendConfig {
    /// Cleanup flag; defaults to `true` when omitted.
    ///
    /// NOTE(review): presumably controls whether Docker resources are removed
    /// after a task completes; confirm against the Docker backend.
    #[serde(default = "cleanup_default")]
    pub cleanup: bool,
}
544
impl DockerBackendConfig {
    /// Validates the Docker backend configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
551
552impl Default for DockerBackendConfig {
553 fn default() -> Self {
554 Self { cleanup: true }
555 }
556}
557
/// HTTP basic authentication configuration.
///
/// Both fields are optional at the type level so that `Default` can be
/// derived, but `validate` requires both to be set.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct BasicAuthConfig {
    /// The user name; required by `validate`.
    pub username: Option<String>,
    /// The password; required by `validate`.
    pub password: Option<String>,
}
567
568impl BasicAuthConfig {
569 pub fn validate(&self) -> Result<()> {
571 if self.username.is_none() {
572 bail!("HTTP basic auth configuration value `username` is required");
573 }
574
575 if self.password.is_none() {
576 bail!("HTTP basic auth configuration value `password` is required");
577 }
578
579 Ok(())
580 }
581}
582
/// HTTP bearer token authentication configuration.
///
/// The token is optional at the type level so that `Default` can be derived,
/// but `validate` requires it to be set.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct BearerAuthConfig {
    /// The bearer token; required by `validate`.
    pub token: Option<String>,
}
590
591impl BearerAuthConfig {
592 pub fn validate(&self) -> Result<()> {
594 if self.token.is_none() {
595 bail!("HTTP bearer auth configuration value `token` is required");
596 }
597
598 Ok(())
599 }
600}
601
/// Authentication configuration for the TES backend.
///
/// Internally tagged by a `type` field whose value is `basic` or `bearer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum TesBackendAuthConfig {
    /// Use HTTP basic authentication.
    Basic(BasicAuthConfig),
    /// Use HTTP bearer token authentication.
    Bearer(BearerAuthConfig),
}
611
612impl TesBackendAuthConfig {
613 pub fn validate(&self) -> Result<()> {
615 match self {
616 Self::Basic(config) => config.validate(),
617 Self::Bearer(config) => config.validate(),
618 }
619 }
620}
621
/// Configuration of the TES task execution backend.
///
/// `url`, `inputs` and `outputs` are optional at the type level so that
/// `Default` can be derived, but `validate` requires all three to be set.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TesBackendConfig {
    /// The URL of the TES service; required, and must use the `https` scheme
    /// unless `insecure` is set.
    #[serde(default)]
    pub url: Option<Url>,

    /// Optional authentication settings; validated when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<TesBackendAuthConfig>,

    /// The storage URL for inputs; required, must use a supported scheme and
    /// its path must end with a slash.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inputs: Option<Url>,

    /// The storage URL for outputs; required, must use a supported scheme and
    /// its path must end with a slash.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub outputs: Option<Url>,

    /// Optional interval.
    ///
    /// NOTE(review): presumably the task status polling interval; the unit
    /// and the default are decided by the TES backend, outside this file.
    #[serde(default)]
    pub interval: Option<u64>,

    /// Optional maximum concurrency.
    ///
    /// NOTE(review): presumably the maximum number of concurrent tasks; the
    /// behavior when unset is decided by the TES backend, outside this file.
    #[serde(default)]
    pub max_concurrency: Option<u64>,

    /// When `true`, `validate` accepts a `url` that does not use the `https`
    /// scheme. Defaults to `false`.
    #[serde(default)]
    pub insecure: bool,
}
659
660impl TesBackendConfig {
661 pub fn validate(&self) -> Result<()> {
663 match &self.url {
664 Some(url) => {
665 if !self.insecure && url.scheme() != "https" {
666 bail!(
667 "TES backend configuration value `url` has invalid value `{url}`: URL \
668 must use a HTTPS scheme"
669 );
670 }
671 }
672 None => bail!("TES backend configuration value `url` is required"),
673 }
674
675 if let Some(auth) = &self.auth {
676 auth.validate()?;
677 }
678
679 match &self.inputs {
680 Some(url) => {
681 if !is_url(url.as_str()) {
682 bail!(
683 "TES backend storage configuration value `inputs` has invalid value \
684 `{url}`: URL scheme is not supported"
685 );
686 }
687
688 if !url.path().ends_with('/') {
689 bail!(
690 "TES backend storage configuration value `inputs` has invalid value \
691 `{url}`: URL path must end with a slash"
692 );
693 }
694 }
695 None => bail!("TES backend configuration value `inputs` is required"),
696 }
697
698 match &self.outputs {
699 Some(url) => {
700 if !is_url(url.as_str()) {
701 bail!(
702 "TES backend storage configuration value `outputs` has invalid value \
703 `{url}`: URL scheme is not supported"
704 );
705 }
706
707 if !url.path().ends_with('/') {
708 bail!(
709 "TES backend storage configuration value `outputs` has invalid value \
710 `{url}`: URL path must end with a slash"
711 );
712 }
713 }
714 None => bail!("TES backend storage configuration value `outputs` is required"),
715 }
716
717 Ok(())
718 }
719}
720
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds a configuration whose only backend is registered as `default`.
    fn with_default_backend(backend: BackendConfig) -> Config {
        Config {
            backends: [("default".to_string(), backend)].into(),
            ..Default::default()
        }
    }

    /// Builds a configuration with a single local backend named `default`.
    fn with_local_backend(local: LocalBackendConfig) -> Config {
        with_default_backend(BackendConfig::Local(local))
    }

    /// Builds a configuration with a single TES backend named `default`.
    fn with_tes_backend(tes: TesBackendConfig) -> Config {
        with_default_backend(BackendConfig::Tes(Box::new(tes)))
    }

    /// Validates a configuration that is expected to be rejected and returns
    /// the rendered error message.
    fn rejection(config: &Config) -> String {
        config.validate().unwrap_err().to_string()
    }

    #[test]
    fn test_config_validate() {
        // Retry counts above the maximum are rejected.
        let mut config = Config::default();
        config.task.retries = Some(1000000);
        assert_eq!(
            rejection(&config),
            "configuration value `task.retries` cannot exceed 100"
        );

        // A scatter concurrency of zero is rejected.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(0);
        assert_eq!(
            rejection(&config),
            "configuration value `workflow.scatter.concurrency` cannot be zero"
        );

        // A named backend must exist, whether or not other backends do.
        let config = Config {
            backend: Some("foo".into()),
            ..Default::default()
        };
        assert_eq!(
            rejection(&config),
            "a backend named `foo` is not present in the configuration"
        );
        let config = Config {
            backend: Some("bar".into()),
            backends: [("foo".to_string(), BackendConfig::default())].into(),
            ..Default::default()
        };
        assert_eq!(
            rejection(&config),
            "a backend named `bar` is not present in the configuration"
        );

        // A single backend is accepted regardless of its name.
        let config = Config {
            backends: [("foo".to_string(), BackendConfig::default())].into(),
            ..Default::default()
        };
        config.validate().expect("config should validate");

        // Local backend: CPU must be non-zero and within the host's CPUs.
        let config = with_local_backend(LocalBackendConfig {
            cpu: Some(0),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "local backend configuration value `cpu` cannot be zero"
        );
        let config = with_local_backend(LocalBackendConfig {
            cpu: Some(10000000),
            ..Default::default()
        });
        assert!(rejection(&config).starts_with(
            "local backend configuration value `cpu` cannot exceed the virtual CPUs available to \
             the host"
        ));

        // Local backend: memory must parse, be non-zero and fit on the host.
        let config = with_local_backend(LocalBackendConfig {
            memory: Some("0 GiB".to_string()),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "local backend configuration value `memory` cannot be zero"
        );
        let config = with_local_backend(LocalBackendConfig {
            memory: Some("100 meows".to_string()),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "local backend configuration value `memory` has invalid value `100 meows`"
        );
        let config = with_local_backend(LocalBackendConfig {
            memory: Some("1000 TiB".to_string()),
            ..Default::default()
        });
        assert!(rejection(&config).starts_with(
            "local backend configuration value `memory` cannot exceed the total memory of the host"
        ));

        // TES backend: the service URL is required.
        let config = with_tes_backend(TesBackendConfig::default());
        assert_eq!(
            rejection(&config),
            "TES backend configuration value `url` is required"
        );

        // TES backend: a non-HTTPS URL is rejected unless `insecure` is set.
        let config = with_tes_backend(TesBackendConfig {
            url: Some(Url::parse("http://example.com").unwrap()),
            inputs: Some(Url::parse("http://example.com").unwrap()),
            outputs: Some(Url::parse("http://example.com").unwrap()),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "TES backend configuration value `url` has invalid value `http://example.com/`: URL \
             must use a HTTPS scheme"
        );
        let config = with_tes_backend(TesBackendConfig {
            url: Some(Url::parse("http://example.com").unwrap()),
            inputs: Some(Url::parse("http://example.com").unwrap()),
            outputs: Some(Url::parse("http://example.com").unwrap()),
            insecure: true,
            ..Default::default()
        });
        config.validate().expect("configuration should validate");

        // TES backend: basic auth requires a user name and a password.
        let config = with_tes_backend(TesBackendConfig {
            url: Some(Url::parse("https://example.com").unwrap()),
            auth: Some(TesBackendAuthConfig::Basic(Default::default())),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "HTTP basic auth configuration value `username` is required"
        );
        let config = with_tes_backend(TesBackendConfig {
            url: Some(Url::parse("https://example.com").unwrap()),
            auth: Some(TesBackendAuthConfig::Basic(BasicAuthConfig {
                username: Some("Foo".into()),
                ..Default::default()
            })),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "HTTP basic auth configuration value `password` is required"
        );

        // TES backend: bearer auth requires a token.
        let config = with_tes_backend(TesBackendConfig {
            url: Some(Url::parse("https://example.com").unwrap()),
            auth: Some(TesBackendAuthConfig::Bearer(Default::default())),
            ..Default::default()
        });
        assert_eq!(
            rejection(&config),
            "HTTP bearer auth configuration value `token` is required"
        );

        // HTTP: a download limit of zero is rejected; other values pass.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(0);
        assert_eq!(
            rejection(&config),
            "configuration value `http.max_concurrent_downloads` cannot be zero"
        );

        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(5);
        assert!(
            config.validate().is_ok(),
            "should pass for valid configuration"
        );

        let mut config = Config::default();
        config.http.max_concurrent_downloads = None;
        assert!(config.validate().is_ok(), "should pass for default (None)");
    }
}