1use std::borrow::Cow;
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::Arc;
7
8use anyhow::Context;
9use anyhow::Result;
10use anyhow::anyhow;
11use anyhow::bail;
12use serde::Deserialize;
13use serde::Serialize;
14use tracing::warn;
15use url::Url;
16
17use crate::DockerBackend;
18use crate::LocalBackend;
19use crate::SYSTEM;
20use crate::TaskExecutionBackend;
21use crate::TesBackend;
22use crate::convert_unit_string;
23use crate::path::is_url;
24
/// Upper bound accepted for `task.retries` by [`TaskConfig::validate`].
pub const MAX_RETRIES: u64 = 100;

/// Default shell name for task commands — presumably the fallback when
/// `task.shell` is unset (usage not visible in this module; confirm).
pub const DEFAULT_TASK_SHELL: &str = "bash";

/// Default limit on concurrent HTTP downloads — presumably applied when
/// `http.max_concurrent_downloads` is unset (usage not visible here; confirm).
pub const DEFAULT_MAX_CONCURRENT_DOWNLOADS: u64 = 10;

/// Backend name looked up in `backends` when `backend` is unset and more than
/// one backend is configured.
pub const DEFAULT_BACKEND_NAME: &str = "default";
36
37#[derive(Debug, Default, Clone, Serialize, Deserialize)]
39#[serde(rename_all = "snake_case", deny_unknown_fields)]
40pub struct Config {
41 #[serde(default)]
43 pub http: HttpConfig,
44 #[serde(default)]
46 pub workflow: WorkflowConfig,
47 #[serde(default)]
49 pub task: TaskConfig,
50 pub backend: Option<String>,
55 #[serde(default)]
63 pub backends: HashMap<String, BackendConfig>,
64 #[serde(default)]
66 pub storage: StorageConfig,
67 #[serde(default)]
81 pub suppress_env_specific_output: bool,
82}
83
84impl Config {
85 pub fn validate(&self) -> Result<()> {
87 self.http.validate()?;
88 self.workflow.validate()?;
89 self.task.validate()?;
90
91 if self.backend.is_none() && self.backends.len() < 2 {
92 } else {
95 let backend = self.backend.as_deref().unwrap_or(DEFAULT_BACKEND_NAME);
97 if !self.backends.contains_key(backend) {
98 bail!("a backend named `{backend}` is not present in the configuration");
99 }
100 }
101
102 for backend in self.backends.values() {
103 backend.validate()?;
104 }
105
106 self.storage.validate()?;
107 Ok(())
108 }
109
110 pub async fn create_backend(self: &Arc<Self>) -> Result<Arc<dyn TaskExecutionBackend>> {
112 let config = if self.backend.is_none() && self.backends.len() < 2 {
113 if self.backends.len() == 1 {
114 Cow::Borrowed(self.backends.values().next().unwrap())
116 } else {
117 Cow::Owned(BackendConfig::default())
119 }
120 } else {
121 let backend = self.backend.as_deref().unwrap_or(DEFAULT_BACKEND_NAME);
123 Cow::Borrowed(self.backends.get(backend).ok_or_else(|| {
124 anyhow!("a backend named `{backend}` is not present in the configuration")
125 })?)
126 };
127
128 match config.as_ref() {
129 BackendConfig::Local(config) => {
130 warn!(
131 "the engine is configured to use the local backend: tasks will not be run \
132 inside of a container"
133 );
134 Ok(Arc::new(LocalBackend::new(self.clone(), config)?))
135 }
136 BackendConfig::Docker(config) => {
137 Ok(Arc::new(DockerBackend::new(self.clone(), config).await?))
138 }
139 BackendConfig::Tes(config) => {
140 Ok(Arc::new(TesBackend::new(self.clone(), config).await?))
141 }
142 }
143 }
144}
145
146#[derive(Debug, Default, Clone, Serialize, Deserialize)]
148#[serde(rename_all = "snake_case", deny_unknown_fields)]
149pub struct HttpConfig {
150 #[serde(default)]
154 pub cache: Option<PathBuf>,
155 #[serde(default, skip_serializing_if = "Option::is_none")]
159 pub max_concurrent_downloads: Option<u64>,
160}
161
162impl HttpConfig {
163 pub fn validate(&self) -> Result<()> {
165 if let Some(limit) = self.max_concurrent_downloads
166 && limit == 0
167 {
168 bail!("configuration value `http.max_concurrent_downloads` cannot be zero");
169 }
170 Ok(())
171 }
172}
173
174#[derive(Debug, Default, Clone, Serialize, Deserialize)]
176#[serde(rename_all = "snake_case", deny_unknown_fields)]
177pub struct StorageConfig {
178 #[serde(default)]
180 pub azure: AzureStorageConfig,
181 #[serde(default)]
183 pub s3: S3StorageConfig,
184 #[serde(default)]
186 pub google: GoogleStorageConfig,
187}
188
189impl StorageConfig {
190 pub fn validate(&self) -> Result<()> {
192 self.azure.validate()?;
193 self.s3.validate()?;
194 self.google.validate()?;
195 Ok(())
196 }
197}
198
199#[derive(Debug, Default, Clone, Serialize, Deserialize)]
201#[serde(rename_all = "snake_case", deny_unknown_fields)]
202pub struct AzureStorageConfig {
203 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
212 pub auth: HashMap<String, HashMap<String, String>>,
213}
214
215impl AzureStorageConfig {
216 pub fn validate(&self) -> Result<()> {
218 Ok(())
219 }
220}
221
222#[derive(Debug, Default, Clone, Serialize, Deserialize)]
224#[serde(rename_all = "snake_case", deny_unknown_fields)]
225pub struct S3StorageConfig {
226 #[serde(default, skip_serializing_if = "Option::is_none")]
231 pub region: Option<String>,
232
233 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
239 pub auth: HashMap<String, String>,
240}
241
242impl S3StorageConfig {
243 pub fn validate(&self) -> Result<()> {
245 Ok(())
246 }
247}
248
249#[derive(Debug, Default, Clone, Serialize, Deserialize)]
251#[serde(rename_all = "snake_case", deny_unknown_fields)]
252pub struct GoogleStorageConfig {
253 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
259 pub auth: HashMap<String, String>,
260}
261
262impl GoogleStorageConfig {
263 pub fn validate(&self) -> Result<()> {
265 Ok(())
266 }
267}
268
269#[derive(Debug, Default, Clone, Serialize, Deserialize)]
271#[serde(rename_all = "snake_case", deny_unknown_fields)]
272pub struct WorkflowConfig {
273 #[serde(default)]
275 pub scatter: ScatterConfig,
276}
277
278impl WorkflowConfig {
279 pub fn validate(&self) -> Result<()> {
281 self.scatter.validate()?;
282 Ok(())
283 }
284}
285
286#[derive(Debug, Default, Clone, Serialize, Deserialize)]
288#[serde(rename_all = "snake_case", deny_unknown_fields)]
289pub struct ScatterConfig {
290 #[serde(default, skip_serializing_if = "Option::is_none")]
343 pub concurrency: Option<u64>,
344}
345
346impl ScatterConfig {
347 pub fn validate(&self) -> Result<()> {
349 if let Some(concurrency) = self.concurrency
350 && concurrency == 0
351 {
352 bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
353 }
354
355 Ok(())
356 }
357}
358
359#[derive(Debug, Default, Clone, Serialize, Deserialize)]
361#[serde(rename_all = "snake_case", deny_unknown_fields)]
362pub struct TaskConfig {
363 #[serde(default, skip_serializing_if = "Option::is_none")]
369 pub retries: Option<u64>,
370 #[serde(default, skip_serializing_if = "Option::is_none")]
375 pub container: Option<String>,
376 #[serde(default, skip_serializing_if = "Option::is_none")]
384 pub shell: Option<String>,
385 #[serde(default)]
387 pub cpu_limit_behavior: TaskResourceLimitBehavior,
388 #[serde(default)]
390 pub memory_limit_behavior: TaskResourceLimitBehavior,
391}
392
393impl TaskConfig {
394 pub fn validate(&self) -> Result<()> {
396 if self.retries.unwrap_or(0) > MAX_RETRIES {
397 bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
398 }
399
400 Ok(())
401 }
402}
403
404#[derive(Debug, Default, Clone, Serialize, Deserialize)]
407#[serde(rename_all = "snake_case", deny_unknown_fields)]
408pub enum TaskResourceLimitBehavior {
409 TryWithMax,
412 #[default]
416 Deny,
417}
418
419#[derive(Debug, Clone, Serialize, Deserialize)]
421#[serde(rename_all = "snake_case", tag = "type")]
422pub enum BackendConfig {
423 Local(LocalBackendConfig),
425 Docker(DockerBackendConfig),
427 Tes(Box<TesBackendConfig>),
429}
430
431impl Default for BackendConfig {
432 fn default() -> Self {
433 Self::Docker(Default::default())
434 }
435}
436
437impl BackendConfig {
438 pub fn validate(&self) -> Result<()> {
440 match self {
441 Self::Local(config) => config.validate(),
442 Self::Docker(config) => config.validate(),
443 Self::Tes(config) => config.validate(),
444 }
445 }
446
447 pub fn as_local(&self) -> Option<&LocalBackendConfig> {
451 match self {
452 Self::Local(config) => Some(config),
453 _ => None,
454 }
455 }
456
457 pub fn as_docker(&self) -> Option<&DockerBackendConfig> {
461 match self {
462 Self::Docker(config) => Some(config),
463 _ => None,
464 }
465 }
466
467 pub fn as_tes(&self) -> Option<&TesBackendConfig> {
471 match self {
472 Self::Tes(config) => Some(config),
473 _ => None,
474 }
475 }
476}
477
478#[derive(Debug, Default, Clone, Serialize, Deserialize)]
485#[serde(rename_all = "snake_case", deny_unknown_fields)]
486pub struct LocalBackendConfig {
487 #[serde(default, skip_serializing_if = "Option::is_none")]
493 pub cpu: Option<u64>,
494
495 #[serde(default, skip_serializing_if = "Option::is_none")]
502 pub memory: Option<String>,
503}
504
505impl LocalBackendConfig {
506 pub fn validate(&self) -> Result<()> {
508 if let Some(cpu) = self.cpu {
509 if cpu == 0 {
510 bail!("local backend configuration value `cpu` cannot be zero");
511 }
512
513 let total = SYSTEM.cpus().len() as u64;
514 if cpu > total {
515 bail!(
516 "local backend configuration value `cpu` cannot exceed the virtual CPUs \
517 available to the host ({total})"
518 );
519 }
520 }
521
522 if let Some(memory) = &self.memory {
523 let memory = convert_unit_string(memory).with_context(|| {
524 format!("local backend configuration value `memory` has invalid value `{memory}`")
525 })?;
526
527 if memory == 0 {
528 bail!("local backend configuration value `memory` cannot be zero");
529 }
530
531 let total = SYSTEM.total_memory();
532 if memory > total {
533 bail!(
534 "local backend configuration value `memory` cannot exceed the total memory of \
535 the host ({total} bytes)"
536 );
537 }
538 }
539
540 Ok(())
541 }
542}
543
544const fn cleanup_default() -> bool {
546 true
547}
548
549#[derive(Debug, Clone, Serialize, Deserialize)]
551#[serde(rename_all = "snake_case", deny_unknown_fields)]
552pub struct DockerBackendConfig {
553 #[serde(default = "cleanup_default")]
557 pub cleanup: bool,
558}
559
560impl DockerBackendConfig {
561 pub fn validate(&self) -> Result<()> {
563 Ok(())
564 }
565}
566
567impl Default for DockerBackendConfig {
568 fn default() -> Self {
569 Self { cleanup: true }
570 }
571}
572
573#[derive(Debug, Default, Clone, Serialize, Deserialize)]
575#[serde(rename_all = "snake_case", deny_unknown_fields)]
576pub struct BasicAuthConfig {
577 pub username: Option<String>,
579 pub password: Option<String>,
581}
582
583impl BasicAuthConfig {
584 pub fn validate(&self) -> Result<()> {
586 if self.username.is_none() {
587 bail!("HTTP basic auth configuration value `username` is required");
588 }
589
590 if self.password.is_none() {
591 bail!("HTTP basic auth configuration value `password` is required");
592 }
593
594 Ok(())
595 }
596}
597
598#[derive(Debug, Default, Clone, Serialize, Deserialize)]
600#[serde(rename_all = "snake_case", deny_unknown_fields)]
601pub struct BearerAuthConfig {
602 pub token: Option<String>,
604}
605
606impl BearerAuthConfig {
607 pub fn validate(&self) -> Result<()> {
609 if self.token.is_none() {
610 bail!("HTTP bearer auth configuration value `token` is required");
611 }
612
613 Ok(())
614 }
615}
616
617#[derive(Debug, Clone, Serialize, Deserialize)]
619#[serde(rename_all = "snake_case", tag = "type")]
620pub enum TesBackendAuthConfig {
621 Basic(BasicAuthConfig),
623 Bearer(BearerAuthConfig),
625}
626
627impl TesBackendAuthConfig {
628 pub fn validate(&self) -> Result<()> {
630 match self {
631 Self::Basic(config) => config.validate(),
632 Self::Bearer(config) => config.validate(),
633 }
634 }
635}
636
637#[derive(Debug, Default, Clone, Serialize, Deserialize)]
639#[serde(rename_all = "snake_case", deny_unknown_fields)]
640pub struct TesBackendConfig {
641 #[serde(default)]
643 pub url: Option<Url>,
644
645 #[serde(default, skip_serializing_if = "Option::is_none")]
647 pub auth: Option<TesBackendAuthConfig>,
648
649 #[serde(default, skip_serializing_if = "Option::is_none")]
651 pub inputs: Option<Url>,
652
653 #[serde(default, skip_serializing_if = "Option::is_none")]
655 pub outputs: Option<Url>,
656
657 #[serde(default)]
661 pub interval: Option<u64>,
662
663 #[serde(default)]
667 pub max_concurrency: Option<u64>,
668
669 #[serde(default)]
672 pub insecure: bool,
673}
674
675impl TesBackendConfig {
676 pub fn validate(&self) -> Result<()> {
678 match &self.url {
679 Some(url) => {
680 if !self.insecure && url.scheme() != "https" {
681 bail!(
682 "TES backend configuration value `url` has invalid value `{url}`: URL \
683 must use a HTTPS scheme"
684 );
685 }
686 }
687 None => bail!("TES backend configuration value `url` is required"),
688 }
689
690 if let Some(auth) = &self.auth {
691 auth.validate()?;
692 }
693
694 match &self.inputs {
695 Some(url) => {
696 if !is_url(url.as_str()) {
697 bail!(
698 "TES backend storage configuration value `inputs` has invalid value \
699 `{url}`: URL scheme is not supported"
700 );
701 }
702
703 if !url.path().ends_with('/') {
704 bail!(
705 "TES backend storage configuration value `inputs` has invalid value \
706 `{url}`: URL path must end with a slash"
707 );
708 }
709 }
710 None => bail!("TES backend configuration value `inputs` is required"),
711 }
712
713 match &self.outputs {
714 Some(url) => {
715 if !is_url(url.as_str()) {
716 bail!(
717 "TES backend storage configuration value `outputs` has invalid value \
718 `{url}`: URL scheme is not supported"
719 );
720 }
721
722 if !url.path().ends_with('/') {
723 bail!(
724 "TES backend storage configuration value `outputs` has invalid value \
725 `{url}`: URL path must end with a slash"
726 );
727 }
728 }
729 None => bail!("TES backend storage configuration value `outputs` is required"),
730 }
731
732 Ok(())
733 }
734}
735
/// Unit tests for configuration validation.
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Exercises each validation rule, asserting the exact error message
    /// produced (or a prefix where the message embeds host-specific values).
    #[test]
    fn test_config_validate() {
        // Retries above `MAX_RETRIES` are rejected.
        let mut config = Config::default();
        config.task.retries = Some(1000000);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `task.retries` cannot exceed 100"
        );

        // A zero scatter concurrency is rejected.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(0);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `workflow.scatter.concurrency` cannot be zero"
        );

        // An explicitly selected backend must exist in `backends`.
        let config = Config {
            backend: Some("foo".into()),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "a backend named `foo` is not present in the configuration"
        );
        let config = Config {
            backend: Some("bar".into()),
            backends: [("foo".to_string(), BackendConfig::default())].into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "a backend named `bar` is not present in the configuration"
        );

        // With no explicit selection, a single backend of any name is accepted.
        let config = Config {
            backends: [("foo".to_string(), BackendConfig::default())].into(),
            ..Default::default()
        };
        config.validate().expect("config should validate");

        // Local backend: zero CPUs is rejected.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    cpu: Some(0),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `cpu` cannot be zero"
        );
        // Local backend: more CPUs than the host has is rejected (assumes the
        // test host has fewer than 10,000,000 virtual CPUs). Only the prefix is
        // checked because the message includes the host's CPU count.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    cpu: Some(10000000),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert!(config.validate().unwrap_err().to_string().starts_with(
            "local backend configuration value `cpu` cannot exceed the virtual CPUs available to \
             the host"
        ));

        // Local backend: zero memory is rejected.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    memory: Some("0 GiB".to_string()),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `memory` cannot be zero"
        );
        // Local backend: an unparsable memory string reports the outermost
        // context message.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    memory: Some("100 meows".to_string()),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `memory` has invalid value `100 meows`"
        );

        // Local backend: more memory than the host has is rejected (assumes the
        // test host has less than 1000 TiB of memory).
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Local(LocalBackendConfig {
                    memory: Some("1000 TiB".to_string()),
                    ..Default::default()
                }),
            )]
            .into(),
            ..Default::default()
        };
        assert!(config.validate().unwrap_err().to_string().starts_with(
            "local backend configuration value `memory` cannot exceed the total memory of the host"
        ));

        // TES backend: `url` is required.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Default::default()),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "TES backend configuration value `url` is required"
        );

        // TES backend: a plain-HTTP `url` is rejected unless `insecure` is set.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(
                    TesBackendConfig {
                        url: Some("http://example.com".parse().unwrap()),
                        inputs: Some("http://example.com".parse().unwrap()),
                        outputs: Some("http://example.com".parse().unwrap()),
                        ..Default::default()
                    }
                    .into(),
                ),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "TES backend configuration value `url` has invalid value `http://example.com/`: URL \
             must use a HTTPS scheme"
        );

        // TES backend: with `insecure` set, the same configuration validates.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(
                    TesBackendConfig {
                        url: Some("http://example.com".parse().unwrap()),
                        inputs: Some("http://example.com".parse().unwrap()),
                        outputs: Some("http://example.com".parse().unwrap()),
                        insecure: true,
                        ..Default::default()
                    }
                    .into(),
                ),
            )]
            .into(),
            ..Default::default()
        };
        config.validate().expect("configuration should validate");

        // TES basic auth: `username` is checked first, then `password`. Auth is
        // validated before `inputs`/`outputs`, so those may be left unset here.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Box::new(TesBackendConfig {
                    url: Some(Url::parse("https://example.com").unwrap()),
                    auth: Some(TesBackendAuthConfig::Basic(Default::default())),
                    ..Default::default()
                })),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "HTTP basic auth configuration value `username` is required"
        );
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Box::new(TesBackendConfig {
                    url: Some(Url::parse("https://example.com").unwrap()),
                    auth: Some(TesBackendAuthConfig::Basic(BasicAuthConfig {
                        username: Some("Foo".into()),
                        ..Default::default()
                    })),
                    ..Default::default()
                })),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "HTTP basic auth configuration value `password` is required"
        );

        // TES bearer auth: `token` is required.
        let config = Config {
            backends: [(
                "default".to_string(),
                BackendConfig::Tes(Box::new(TesBackendConfig {
                    url: Some(Url::parse("https://example.com").unwrap()),
                    auth: Some(TesBackendAuthConfig::Bearer(Default::default())),
                    ..Default::default()
                })),
            )]
            .into(),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "HTTP bearer auth configuration value `token` is required"
        );

        // HTTP: zero concurrent downloads is rejected; nonzero and unset pass.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(0);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `http.max_concurrent_downloads` cannot be zero"
        );

        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(5);
        assert!(
            config.validate().is_ok(),
            "should pass for valid configuration"
        );

        let mut config = Config::default();
        config.http.max_concurrent_downloads = None;
        assert!(config.validate().is_ok(), "should pass for default (None)");
    }
}