use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use bytesize::ByteSize;
use indexmap::IndexMap;
use secrecy::ExposeSecret;
use serde::Deserialize;
use serde::Serialize;
use tokio::process::Command;
use tracing::error;
use tracing::warn;
use url::Url;
use crate::CancellationContext;
use crate::Events;
use crate::SYSTEM;
use crate::Value;
use crate::backend::TaskExecutionBackend;
use crate::convert_unit_string;
use crate::path::is_supported_url;
/// Maximum number of task retries accepted by `task.retries`.
pub(crate) const MAX_RETRIES: u64 = 100;
/// Shell used to run task commands when none is configured.
pub(crate) const DEFAULT_TASK_SHELL: &str = "bash";
/// Container image used for tasks that do not specify one.
pub(crate) const DEFAULT_TASK_CONTAINER: &str = "ubuntu:latest";
/// Name of the implicit backend used when no backends are configured.
const DEFAULT_BACKEND_NAME: &str = "default";
/// Maximum length, in bytes, of a user-supplied LSF job name prefix.
const MAX_LSF_JOB_NAME_PREFIX: usize = 100;
/// Placeholder emitted when serializing a redacted [`SecretString`].
const REDACTED: &str = "<REDACTED>";
/// Sentinel for `cache_dir` options meaning "use the system cache directory".
const CACHE_DIR_SENTINEL: &str = "system";
/// Returns the engine's root cache directory: `<user cache dir>/sprocket`.
///
/// # Errors
///
/// Fails if the user cache directory cannot be determined.
pub(crate) fn cache_dir() -> Result<PathBuf> {
    const CACHE_DIR_ROOT: &str = "sprocket";
    let base = dirs::cache_dir().context("failed to determine user cache directory")?;
    Ok(base.join(CACHE_DIR_ROOT))
}
/// Returns whether `shell` is the default task shell.
///
/// Used by serde to skip serializing the default shell value.
fn is_default_shell(shell: &str) -> bool {
    DEFAULT_TASK_SHELL == shell
}

/// Serde default for the task shell.
fn get_default_shell() -> String {
    DEFAULT_TASK_SHELL.to_owned()
}

/// Serde default for the task container image.
fn get_default_container() -> String {
    DEFAULT_TASK_CONTAINER.to_owned()
}

/// Serde default for the backend name.
fn get_default_backend_name() -> String {
    DEFAULT_BACKEND_NAME.to_owned()
}

/// Serde default for cache directory options (the `"system"` sentinel).
fn get_sentinel_cache_dir() -> String {
    CACHE_DIR_SENTINEL.to_owned()
}
/// A secret string that serializes as a redaction placeholder by default.
#[derive(Debug, Clone)]
pub struct SecretString {
    // The wrapped secret; only accessible via `inner()` / `expose_secret()`.
    inner: secrecy::SecretString,
    // When `true` (the default), serialization emits `REDACTED` instead of the value.
    redacted: bool,
}
impl SecretString {
    /// Marks the secret as redacted: serialization emits the placeholder.
    pub fn redact(&mut self) {
        self.redacted = true;
    }

    /// Marks the secret as unredacted: serialization exposes the value.
    pub fn unredact(&mut self) {
        self.redacted = false;
    }

    /// Returns the inner [`secrecy::SecretString`].
    pub fn inner(&self) -> &secrecy::SecretString {
        &self.inner
    }
}
impl From<String> for SecretString {
fn from(s: String) -> Self {
Self {
inner: s.into(),
redacted: true,
}
}
}
impl From<&str> for SecretString {
fn from(s: &str) -> Self {
Self {
inner: s.into(),
redacted: true,
}
}
}
impl Default for SecretString {
fn default() -> Self {
Self {
inner: Default::default(),
redacted: true,
}
}
}
impl serde::Serialize for SecretString {
    /// Serializes the redaction placeholder when redacted; otherwise exposes
    /// the secret value.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Note: `secrecy::ExposeSecret` is already imported at the top of this
        // file, so the previous function-local `use` was redundant.
        if self.redacted {
            serializer.serialize_str(REDACTED)
        } else {
            serializer.serialize_str(self.inner.expose_secret())
        }
    }
}
impl<'de> serde::Deserialize<'de> for SecretString {
    /// Deserializes the secret value; the result is redacted by default.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        secrecy::SecretString::deserialize(deserializer).map(|inner| Self {
            inner,
            redacted: true,
        })
    }
}
/// Declares a config wrapper type `$name(Option<$inner>)`.
///
/// `None` represents "defer to default behavior" and round-trips through serde
/// as the sentinel string `$sentinel` (or null). `Some` values must satisfy
/// `$validation` (described by `$expected` in error messages). `$default` is
/// the payload used by the generated `Default` impl.
#[macro_export]
macro_rules! nullable_config_type {
    (
        $name:ident,
        $inner:ty,
        $sentinel:literal,
        $value:ident,
        $validation:expr,
        $expected:literal,
        $default:expr
    ) => {
        #[doc = concat!("Configuration for [`", stringify!($name), "`].")]
        #[derive(Clone, Debug)]
        pub struct $name(Option<$inner>);

        impl $name {
            #[doc = concat!("Get the inner [`", stringify!($inner), "`].")]
            pub fn inner(&self) -> Option<&$inner> {
                self.0.as_ref()
            }

            #[doc = concat!("Try to create a new `", stringify!($name), "` from a `", stringify!($inner), "`.")]
            pub fn try_new(val: Option<$inner>) -> std::result::Result<Self, anyhow::Error> {
                // `None` always passes; `Some` must satisfy the validation predicate.
                match val {
                    None => Ok(Self(None)),
                    Some($value) if $validation => Ok(Self(Some($value))),
                    Some($value) => Err(anyhow::anyhow!(format!(
                        "expected {}, got `{}`",
                        $expected, $value
                    ))),
                }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self($default)
            }
        }

        impl Serialize for $name {
            fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                // `None` serializes as the sentinel string.
                match self {
                    $name(None) => $sentinel.serialize(serializer),
                    $name(Some(i)) => i.serialize(serializer),
                }
            }
        }

        impl<'de> Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Accept the inner type, the sentinel string, or null.
                #[derive(Deserialize)]
                #[serde(untagged)]
                enum Value {
                    Inner($inner),
                    Str(String),
                    Null,
                }
                match Value::deserialize(deserializer)? {
                    Value::Inner(i) => $name::try_new(Some(i)).map_err(serde::de::Error::custom),
                    Value::Str(s) if s == $sentinel => Ok($name(None)),
                    Value::Str($value) => Err(serde::de::Error::custom(format!(
                        "expected {} or `{}`, got `{}`",
                        $expected, $sentinel, $value
                    ))),
                    Value::Null => Ok($name(None)),
                }
            }
        }
    };
}
/// Failure handling mode for a run; configured via the `fail` key on [`Config`].
#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum FailureMode {
    /// The default mode.
    /// NOTE(review): presumably "fail slow" (in-flight work completes) — confirm at usage sites.
    #[default]
    Slow,
    /// NOTE(review): presumably "fail fast" — confirm at usage sites.
    Fast,
}
/// Top-level engine configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Config {
    /// HTTP download configuration.
    #[serde(default)]
    pub http: HttpConfig,
    /// Workflow evaluation configuration.
    #[serde(default)]
    pub workflow: WorkflowConfig,
    /// Task execution configuration.
    #[serde(default)]
    pub task: TaskConfig,
    /// Name of the backend (key into `backends`) to use; defaults to "default".
    #[serde(default = "get_default_backend_name")]
    pub backend: String,
    /// Named backend configurations (insertion order preserved).
    #[serde(default)]
    pub backends: IndexMap<String, BackendConfig>,
    /// Cloud storage (Azure/S3/Google) configuration.
    #[serde(default)]
    pub storage: StorageConfig,
    /// Experimental: suppress environment-specific output.
    /// Requires `experimental_features_enabled` (see `validate`).
    #[serde(default)]
    pub suppress_env_specific_output: bool,
    /// Gates experimental features (e.g. the LSF/Slurm Apptainer backends).
    #[serde(default)]
    pub experimental_features_enabled: bool,
    /// Failure handling mode; serialized under the `fail` key.
    #[serde(default, rename = "fail")]
    pub failure_mode: FailureMode,
}
impl Default for Config {
fn default() -> Self {
Self {
http: Default::default(),
workflow: Default::default(),
task: Default::default(),
backend: get_default_backend_name(),
backends: Default::default(),
storage: Default::default(),
suppress_env_specific_output: Default::default(),
experimental_features_enabled: Default::default(),
failure_mode: Default::default(),
}
}
}
impl Config {
    /// Validates the entire configuration: each section, the selected backend
    /// name, every configured backend, and feature gating.
    ///
    /// # Errors
    ///
    /// Returns the first validation failure encountered.
    pub async fn validate(&self) -> Result<()> {
        self.http.validate()?;
        self.workflow.validate()?;
        self.task.validate()?;
        // The named backend must exist in `backends` unless no backends are
        // configured and the name is the implicit default (in which case
        // `backend()` falls back to a default Docker backend).
        //
        // The original code expressed this as `if ok {} else { check }`; the
        // condition is inverted here to remove the empty branch.
        if !self.backends.is_empty() || self.backend != DEFAULT_BACKEND_NAME {
            let backend = &self.backend;
            if !self.backends.contains_key(backend) {
                bail!("a backend named `{backend}` is not present in the configuration");
            }
        }
        for backend in self.backends.values() {
            backend.validate(self).await?;
        }
        self.storage.validate()?;
        if self.suppress_env_specific_output && !self.experimental_features_enabled {
            bail!("`suppress_env_specific_output` requires enabling experimental features");
        }
        Ok(())
    }

    /// Redacts every secret in the configuration (backend and storage auth).
    pub fn redact(&mut self) {
        for backend in self.backends.values_mut() {
            backend.redact();
        }
        if let Some(auth) = &mut self.storage.azure.auth {
            auth.redact();
        }
        if let Some(auth) = &mut self.storage.s3.auth {
            auth.redact();
        }
        if let Some(auth) = &mut self.storage.google.auth {
            auth.redact();
        }
    }

    /// Unredacts every secret in the configuration (backend and storage auth).
    pub fn unredact(&mut self) {
        for backend in self.backends.values_mut() {
            backend.unredact();
        }
        if let Some(auth) = &mut self.storage.azure.auth {
            auth.unredact();
        }
        if let Some(auth) = &mut self.storage.s3.auth {
            auth.unredact();
        }
        if let Some(auth) = &mut self.storage.google.auth {
            auth.unredact();
        }
    }

    /// Returns the selected backend configuration.
    ///
    /// When no backends are configured, returns an owned default
    /// ([`BackendConfig::Docker`]).
    ///
    /// # Errors
    ///
    /// Fails if backends are configured but the selected name is absent.
    pub fn backend(&self) -> Result<Cow<'_, BackendConfig>> {
        if !self.backends.is_empty() {
            let backend = &self.backend;
            return Ok(Cow::Borrowed(self.backends.get(backend).ok_or_else(
                || anyhow!("a backend named `{backend}` is not present in the configuration"),
            )?));
        }
        Ok(Cow::Owned(BackendConfig::default()))
    }

    /// Instantiates the task execution backend selected by this configuration.
    pub(crate) async fn create_backend(
        self: &Arc<Self>,
        run_root_dir: &Path,
        events: Events,
        cancellation: CancellationContext,
    ) -> Result<Arc<dyn TaskExecutionBackend>> {
        use crate::backend::*;
        match self.backend()?.as_ref() {
            BackendConfig::Local(_) => {
                // The local backend runs tasks directly on the host.
                warn!(
                    "the engine is configured to use the local backend: tasks will not be run \
                     inside of a container"
                );
                Ok(Arc::new(LocalBackend::new(
                    self.clone(),
                    events,
                    cancellation,
                )?))
            }
            BackendConfig::Docker(_) => Ok(Arc::new(
                DockerBackend::new(self.clone(), events, cancellation).await?,
            )),
            BackendConfig::Tes(_) => Ok(Arc::new(
                TesBackend::new(self.clone(), events, cancellation).await?,
            )),
            BackendConfig::LsfApptainer(_) => Ok(Arc::new(LsfApptainerBackend::new(
                self.clone(),
                run_root_dir,
                events,
                cancellation,
            )?)),
            BackendConfig::SlurmApptainer(_) => Ok(Arc::new(SlurmApptainerBackend::new(
                self.clone(),
                run_root_dir,
                events,
                cancellation,
            )?)),
        }
    }
}
/// HTTP download configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct HttpConfig {
    /// Download cache directory; the `"system"` sentinel selects the user cache dir.
    #[serde(default = "get_sentinel_cache_dir")]
    pub cache_dir: String,
    /// Number of download retries (defaults to 5).
    /// NOTE(review): no `#[serde(default)]`, so this key is required whenever
    /// an `http` section is present — confirm that is intended.
    pub retries: usize,
    /// Download parallelism; `None` serializes as the `"available"` sentinel.
    pub parallelism: Parallelism,
}
// HTTP parallelism wrapper: a positive number, or the sentinel string
// "available" / null for `None` (defer to default behavior).
nullable_config_type!(
    Parallelism,
    usize,
    "available",
    value,
    value > 0,
    "a positive number",
    None
);
impl Default for HttpConfig {
fn default() -> Self {
Self {
cache_dir: get_sentinel_cache_dir(),
retries: 5, parallelism: Default::default(),
}
}
}
impl HttpConfig {
    /// Validates the HTTP configuration.
    ///
    /// Defensive: `Parallelism::try_new` already rejects zero, but a directly
    /// constructed value is re-checked here.
    pub fn validate(&self) -> Result<()> {
        match self.parallelism.inner() {
            Some(&0) => bail!("configuration value `http.parallelism` cannot be zero"),
            _ => Ok(()),
        }
    }

    /// Returns the directory used for caching downloads.
    pub fn cache_dir(&self) -> Result<PathBuf> {
        const DOWNLOADS_CACHE_SUBDIR: &str = "downloads";
        if self.using_system_cache_dir() {
            Ok(cache_dir()?.join(DOWNLOADS_CACHE_SUBDIR))
        } else {
            Ok(PathBuf::from(&self.cache_dir))
        }
    }

    /// Whether the configured cache directory is the `"system"` sentinel.
    pub fn using_system_cache_dir(&self) -> bool {
        self.cache_dir == CACHE_DIR_SENTINEL
    }
}
/// Cloud storage configuration for all supported providers.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct StorageConfig {
    /// Azure Blob Storage configuration.
    #[serde(default)]
    pub azure: AzureStorageConfig,
    /// AWS S3 storage configuration.
    #[serde(default)]
    pub s3: S3StorageConfig,
    /// Google Cloud Storage configuration.
    #[serde(default)]
    pub google: GoogleStorageConfig,
}
impl StorageConfig {
    /// Validates every provider's storage configuration.
    pub fn validate(&self) -> Result<()> {
        self.azure.validate()?;
        self.s3.validate()?;
        self.google.validate()
    }
}
/// Azure Blob Storage authentication.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct AzureStorageAuthConfig {
    /// Storage account name; must be non-empty.
    pub account_name: String,
    /// Storage access key; redacted on serialization by default.
    pub access_key: SecretString,
}
impl AzureStorageAuthConfig {
pub fn validate(&self) -> Result<()> {
if self.account_name.is_empty() {
bail!("configuration value `storage.azure.auth.account_name` is required");
}
if self.access_key.inner.expose_secret().is_empty() {
bail!("configuration value `storage.azure.auth.access_key` is required");
}
Ok(())
}
pub fn redact(&mut self) {
self.access_key.redact();
}
pub fn unredact(&mut self) {
self.access_key.unredact();
}
}
/// Azure Blob Storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct AzureStorageConfig {
    /// Optional authentication; omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<AzureStorageAuthConfig>,
}
impl AzureStorageConfig {
    /// Validates the authentication configuration, if present.
    pub fn validate(&self) -> Result<()> {
        match &self.auth {
            Some(auth) => auth.validate(),
            None => Ok(()),
        }
    }
}
/// AWS S3 authentication.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct S3StorageAuthConfig {
    /// Access key ID; must be non-empty.
    pub access_key_id: String,
    /// Secret access key; redacted on serialization by default.
    pub secret_access_key: SecretString,
}
impl S3StorageAuthConfig {
pub fn validate(&self) -> Result<()> {
if self.access_key_id.is_empty() {
bail!("configuration value `storage.s3.auth.access_key_id` is required");
}
if self.secret_access_key.inner.expose_secret().is_empty() {
bail!("configuration value `storage.s3.auth.secret_access_key` is required");
}
Ok(())
}
pub fn redact(&mut self) {
self.secret_access_key.redact();
}
pub fn unredact(&mut self) {
self.secret_access_key.unredact();
}
}
/// AWS S3 storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct S3StorageConfig {
    /// Optional AWS region.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// Optional authentication; omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<S3StorageAuthConfig>,
}
impl S3StorageConfig {
    /// Validates the authentication configuration, if present.
    pub fn validate(&self) -> Result<()> {
        match &self.auth {
            Some(auth) => auth.validate(),
            None => Ok(()),
        }
    }
}
/// Google Cloud Storage authentication.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct GoogleStorageAuthConfig {
    /// Access key; must be non-empty.
    pub access_key: String,
    /// Secret; redacted on serialization by default.
    pub secret: SecretString,
}
impl GoogleStorageAuthConfig {
pub fn validate(&self) -> Result<()> {
if self.access_key.is_empty() {
bail!("configuration value `storage.google.auth.access_key` is required");
}
if self.secret.inner.expose_secret().is_empty() {
bail!("configuration value `storage.google.auth.secret` is required");
}
Ok(())
}
pub fn redact(&mut self) {
self.secret.redact();
}
pub fn unredact(&mut self) {
self.secret.unredact();
}
}
/// Google Cloud Storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct GoogleStorageConfig {
    /// Optional authentication; omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<GoogleStorageAuthConfig>,
}
impl GoogleStorageConfig {
    /// Validates the authentication configuration, if present.
    pub fn validate(&self) -> Result<()> {
        match &self.auth {
            Some(auth) => auth.validate(),
            None => Ok(()),
        }
    }
}
/// Workflow evaluation configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct WorkflowConfig {
    /// Scatter evaluation configuration.
    #[serde(default)]
    pub scatter: ScatterConfig,
}
impl WorkflowConfig {
    /// Validates the workflow configuration.
    pub fn validate(&self) -> Result<()> {
        self.scatter.validate()
    }
}
/// Default scatter concurrency limit.
const DEFAULT_SCATTER_CONCURRENCY: u64 = 1000;

/// Configuration for workflow scatter evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct ScatterConfig {
    /// Scatter concurrency limit; must be nonzero (see `validate`).
    pub concurrency: u64,
}
impl Default for ScatterConfig {
    /// Defaults to a concurrency of 1000.
    fn default() -> Self {
        Self {
            concurrency: DEFAULT_SCATTER_CONCURRENCY,
        }
    }
}
impl ScatterConfig {
pub fn validate(&self) -> Result<()> {
if self.concurrency == 0 {
bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
}
Ok(())
}
}
/// Call caching mode for task execution.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CallCachingMode {
    /// Call caching disabled (the default).
    #[default]
    Off,
    /// Call caching enabled.
    On,
    /// NOTE(review): semantics not evident from this file — presumably
    /// caching only for explicitly opted-in calls; confirm at usage sites.
    Explicit,
}
/// Content digest mode used for call caching.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContentDigestMode {
    /// NOTE(review): presumably full-content hashing — confirm at usage sites.
    Strong,
    /// The default mode.
    /// NOTE(review): presumably metadata-based hashing — confirm at usage sites.
    #[default]
    Weak,
}
/// Task execution configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TaskConfig {
    /// Retry count; `None` serializes as the `"default"` sentinel.
    pub retries: Retries,
    /// Default container image (defaults to `ubuntu:latest`).
    #[serde(default = "get_default_container")]
    pub container: String,
    /// Shell used to run task commands; omitted from serialization when it is
    /// the default (`bash`).
    #[serde(
        default = "get_default_shell",
        skip_serializing_if = "is_default_shell"
    )]
    pub shell: String,
    /// Behavior when a task exceeds the backend's CPU limit.
    pub cpu_limit_behavior: TaskResourceLimitBehavior,
    /// Behavior when a task exceeds the backend's memory limit.
    pub memory_limit_behavior: TaskResourceLimitBehavior,
    /// Call cache directory; the `"system"` sentinel means "no explicit dir"
    /// (see `TaskConfig::cache_dir`).
    #[serde(default = "get_sentinel_cache_dir")]
    pub cache_dir: String,
    /// Call caching mode.
    pub cache: CallCachingMode,
    /// Content digest mode used for call caching.
    pub digests: ContentDigestMode,
    /// Requirement keys excluded from call cache keys.
    #[serde(default)]
    pub excluded_cache_requirements: HashSet<String>,
    /// Hint keys excluded from call cache keys.
    #[serde(default)]
    pub excluded_cache_hints: HashSet<String>,
    /// Input names excluded from call cache keys.
    #[serde(default)]
    pub excluded_cache_inputs: HashSet<String>,
}
// Task retry wrapper: a number no greater than MAX_RETRIES, or the sentinel
// string "default" / null for `None` (defer to default behavior).
// NOTE(review): the expected-value text hard-codes 100 rather than deriving
// it from MAX_RETRIES — keep them in sync.
nullable_config_type!(
    Retries,
    u64,
    "default",
    value,
    value <= MAX_RETRIES,
    "a number less than or equal to 100",
    None
);
impl Default for TaskConfig {
fn default() -> Self {
Self {
retries: Default::default(),
container: get_default_container(),
shell: get_default_shell(),
cpu_limit_behavior: Default::default(),
memory_limit_behavior: Default::default(),
cache_dir: get_sentinel_cache_dir(),
cache: Default::default(),
digests: Default::default(),
excluded_cache_requirements: Default::default(),
excluded_cache_hints: Default::default(),
excluded_cache_inputs: Default::default(),
}
}
}
impl TaskConfig {
    /// Validates the task configuration.
    ///
    /// Defensive: `Retries::try_new` already rejects values above
    /// `MAX_RETRIES`, but a directly constructed value is re-checked here.
    pub fn validate(&self) -> Result<()> {
        // `copied()` instead of `cloned()`: the inner value is a Copy `u64`.
        if self.retries.inner().copied().unwrap_or(0) > MAX_RETRIES {
            bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
        }
        Ok(())
    }

    /// Returns the configured call cache directory, or `None` when the
    /// `"system"` sentinel is in use.
    pub fn cache_dir(&self) -> Option<PathBuf> {
        if self.cache_dir == CACHE_DIR_SENTINEL {
            None
        } else {
            Some(PathBuf::from(&self.cache_dir))
        }
    }
}
/// Behavior when a task requests more of a resource than the backend allows.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub enum TaskResourceLimitBehavior {
    /// NOTE(review): presumably clamp the request to the backend maximum and
    /// try anyway — confirm at usage sites.
    TryWithMax,
    /// Reject the task (the default).
    #[default]
    Deny,
}
/// Configuration for a task execution backend, tagged by `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum BackendConfig {
    /// Run tasks directly on the host (no container).
    Local(LocalBackendConfig),
    /// Run tasks in Docker containers.
    Docker(DockerBackendConfig),
    /// Submit tasks to a GA4GH TES service.
    Tes(TesBackendConfig),
    /// Submit tasks to LSF, running them under Apptainer (experimental).
    LsfApptainer(LsfApptainerBackendConfig),
    /// Submit tasks to Slurm, running them under Apptainer (experimental).
    SlurmApptainer(SlurmApptainerBackendConfig),
}
impl Default for BackendConfig {
fn default() -> Self {
Self::Docker(Default::default())
}
}
impl BackendConfig {
    /// Validates this backend configuration; the engine configuration is
    /// needed by backends gated on experimental features.
    pub async fn validate(&self, engine_config: &Config) -> Result<()> {
        match self {
            Self::Local(config) => config.validate(),
            Self::Docker(config) => config.validate(),
            Self::Tes(config) => config.validate(),
            Self::LsfApptainer(config) => config.validate(engine_config).await,
            Self::SlurmApptainer(config) => config.validate(engine_config).await,
        }
    }

    /// Returns the local backend configuration, if this is a local backend.
    pub fn as_local(&self) -> Option<&LocalBackendConfig> {
        if let Self::Local(config) = self {
            Some(config)
        } else {
            None
        }
    }

    /// Returns the Docker backend configuration, if this is a Docker backend.
    pub fn as_docker(&self) -> Option<&DockerBackendConfig> {
        if let Self::Docker(config) = self {
            Some(config)
        } else {
            None
        }
    }

    /// Returns the TES backend configuration, if this is a TES backend.
    pub fn as_tes(&self) -> Option<&TesBackendConfig> {
        if let Self::Tes(config) = self {
            Some(config)
        } else {
            None
        }
    }

    /// Returns the LSF + Apptainer backend configuration, if applicable.
    pub fn as_lsf_apptainer(&self) -> Option<&LsfApptainerBackendConfig> {
        if let Self::LsfApptainer(config) = self {
            Some(config)
        } else {
            None
        }
    }

    /// Returns the Slurm + Apptainer backend configuration, if applicable.
    pub fn as_slurm_apptainer(&self) -> Option<&SlurmApptainerBackendConfig> {
        if let Self::SlurmApptainer(config) = self {
            Some(config)
        } else {
            None
        }
    }

    /// Redacts secrets; only the TES backend carries secrets.
    pub fn redact(&mut self) {
        if let Self::Tes(config) = self {
            config.redact();
        }
    }

    /// Unredacts secrets; only the TES backend carries secrets.
    pub fn unredact(&mut self) {
        if let Self::Tes(config) = self {
            config.unredact();
        }
    }
}
/// Configuration for the local (host) backend.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct LocalBackendConfig {
    /// Virtual CPU limit; must be nonzero and no more than the host provides.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cpu: Option<u64>,
    /// Memory limit as a unit string (e.g. "2 GiB"); must parse, be nonzero,
    /// and not exceed the host's total memory.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub memory: Option<String>,
}
impl LocalBackendConfig {
    /// Validates the local backend configuration.
    ///
    /// `cpu` (if set) must be nonzero and no greater than the host's virtual
    /// CPU count; `memory` (if set) must parse as a unit string, be nonzero,
    /// and not exceed the host's total memory.
    pub fn validate(&self) -> Result<()> {
        if let Some(cpu) = self.cpu {
            if cpu == 0 {
                bail!("local backend configuration value `cpu` cannot be zero");
            }
            // Upper bound: virtual CPUs reported for the host.
            let total = SYSTEM.cpus().len() as u64;
            if cpu > total {
                bail!(
                    "local backend configuration value `cpu` cannot exceed the virtual CPUs \
                     available to the host ({total})"
                );
            }
        }
        if let Some(memory) = &self.memory {
            // Parse e.g. "2 GiB" into a byte count; rejects malformed strings.
            let memory = convert_unit_string(memory).with_context(|| {
                format!("local backend configuration value `memory` has invalid value `{memory}`")
            })?;
            if memory == 0 {
                bail!("local backend configuration value `memory` cannot be zero");
            }
            let total = SYSTEM.total_memory();
            if memory > total {
                bail!(
                    "local backend configuration value `memory` cannot exceed the total memory of \
                     the host ({total} bytes)"
                );
            }
        }
        Ok(())
    }
}
/// Serde default for [`DockerBackendConfig::cleanup`]: cleanup is on by default.
const fn cleanup_default() -> bool {
    true
}
/// Configuration for the Docker backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct DockerBackendConfig {
    /// Whether to clean up containers after execution (defaults to `true`).
    #[serde(default = "cleanup_default")]
    pub cleanup: bool,
}
impl DockerBackendConfig {
    /// Validates the Docker backend configuration; currently nothing to check.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
impl Default for DockerBackendConfig {
fn default() -> Self {
Self { cleanup: true }
}
}
/// HTTP basic authentication credentials.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct BasicAuthConfig {
    /// User name.
    #[serde(default)]
    pub username: String,
    /// Password; redacted on serialization by default.
    #[serde(default)]
    pub password: SecretString,
}
impl BasicAuthConfig {
    /// Validates the credentials; currently nothing to check (empty values
    /// are permitted).
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Redacts the password.
    pub fn redact(&mut self) {
        self.password.redact();
    }

    /// Unredacts the password.
    pub fn unredact(&mut self) {
        self.password.unredact();
    }
}
/// HTTP bearer token authentication.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct BearerAuthConfig {
    /// Bearer token; redacted on serialization by default.
    #[serde(default)]
    pub token: SecretString,
}
impl BearerAuthConfig {
    /// Validates the token; currently nothing to check (empty tokens are
    /// permitted).
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Redacts the token.
    pub fn redact(&mut self) {
        self.token.redact();
    }

    /// Unredacts the token.
    pub fn unredact(&mut self) {
        self.token.unredact();
    }
}
/// Authentication for the TES backend, tagged by `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum TesBackendAuthConfig {
    /// HTTP basic authentication.
    Basic(BasicAuthConfig),
    /// HTTP bearer token authentication.
    Bearer(BearerAuthConfig),
}
impl TesBackendAuthConfig {
    /// Validates the inner authentication configuration.
    pub fn validate(&self) -> Result<()> {
        match self {
            Self::Basic(config) => config.validate(),
            Self::Bearer(config) => config.validate(),
        }
    }

    /// Redacts the inner secret (password or token).
    pub fn redact(&mut self) {
        match self {
            Self::Basic(auth) => auth.redact(),
            Self::Bearer(auth) => auth.redact(),
        }
    }

    /// Unredacts the inner secret (password or token).
    pub fn unredact(&mut self) {
        match self {
            Self::Basic(auth) => auth.unredact(),
            Self::Bearer(auth) => auth.unredact(),
        }
    }
}
/// Configuration for the TES (Task Execution Service) backend.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TesBackendConfig {
    /// Base URL of the TES service; required, and must use HTTPS unless
    /// `insecure` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub url: Option<Url>,
    /// Optional basic/bearer authentication.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<TesBackendAuthConfig>,
    /// Required storage URL for task inputs; its path must end with `/`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inputs: Option<Url>,
    /// Required storage URL for task outputs; its path must end with `/`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub outputs: Option<Url>,
    /// Polling interval. NOTE(review): units not evident from this file — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub interval: Option<u64>,
    /// Retry count.
    /// NOTE(review): unlike sibling fields this lacks `skip_serializing_if`,
    /// so it serializes as `null` when unset — confirm whether intended.
    pub retries: Option<u32>,
    /// Maximum concurrent tasks; must be nonzero when set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrency: Option<u32>,
    /// Permit non-HTTPS `url` values.
    #[serde(default)]
    pub insecure: bool,
}
impl TesBackendConfig {
pub fn validate(&self) -> Result<()> {
match &self.url {
Some(url) => {
if !self.insecure && url.scheme() != "https" {
bail!(
"TES backend configuration value `url` has invalid value `{url}`: URL \
must use a HTTPS scheme"
);
}
}
None => bail!("TES backend configuration value `url` is required"),
}
if let Some(auth) = &self.auth {
auth.validate()?;
}
if let Some(max_concurrency) = self.max_concurrency
&& max_concurrency == 0
{
bail!("TES backend configuration value `max_concurrency` cannot be zero");
}
match &self.inputs {
Some(url) => {
if !is_supported_url(url.as_str()) {
bail!(
"TES backend storage configuration value `inputs` has invalid value \
`{url}`: URL scheme is not supported"
);
}
if !url.path().ends_with('/') {
bail!(
"TES backend storage configuration value `inputs` has invalid value \
`{url}`: URL path must end with a slash"
);
}
}
None => bail!("TES backend configuration value `inputs` is required"),
}
match &self.outputs {
Some(url) => {
if !is_supported_url(url.as_str()) {
bail!(
"TES backend storage configuration value `outputs` has invalid value \
`{url}`: URL scheme is not supported"
);
}
if !url.path().ends_with('/') {
bail!(
"TES backend storage configuration value `outputs` has invalid value \
`{url}`: URL path must end with a slash"
);
}
}
None => bail!("TES backend storage configuration value `outputs` is required"),
}
Ok(())
}
pub fn redact(&mut self) {
if let Some(auth) = &mut self.auth {
auth.redact();
}
}
pub fn unredact(&mut self) {
if let Some(auth) = &mut self.auth {
auth.unredact();
}
}
}
/// Apptainer settings shared by the LSF and Slurm backends (flattened into them).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct ApptainerConfig {
    /// Apptainer executable to invoke (defaults to `apptainer`).
    #[serde(default = "default_apptainer_executable")]
    pub executable: String,
    /// Optional directory for caching container images.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image_cache_dir: Option<PathBuf>,
    /// Extra arguments appended to `apptainer exec` invocations.
    pub extra_apptainer_exec_args: Option<Vec<String>>,
}
/// Default name of the Apptainer executable looked up on `PATH`.
const DEFAULT_APPTAINER_EXECUTABLE: &str = "apptainer";

/// Serde default for [`ApptainerConfig::executable`].
fn default_apptainer_executable() -> String {
    DEFAULT_APPTAINER_EXECUTABLE.to_owned()
}
impl Default for ApptainerConfig {
    /// Defaults: `apptainer` executable, no image cache, no extra exec args.
    fn default() -> Self {
        Self {
            executable: default_apptainer_executable(),
            image_cache_dir: None,
            extra_apptainer_exec_args: None,
        }
    }
}
impl ApptainerConfig {
    /// Validates the Apptainer configuration; currently a no-op placeholder
    /// kept async to match the other backend validators.
    pub async fn validate(&self) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
/// Configuration for a single LSF queue.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct LsfQueueConfig {
    /// Queue name; must be non-empty and known to `bqueues`.
    pub name: String,
    /// Optional per-task CPU ceiling; must be nonzero when set.
    pub max_cpu_per_task: Option<u64>,
    /// Optional per-task memory ceiling; must be nonzero when set.
    pub max_memory_per_task: Option<ByteSize>,
}
impl LsfQueueConfig {
    /// Validates this queue configuration; `name` is the config-key prefix
    /// (e.g. `default`, `gpu`) used in error messages.
    ///
    /// Checks limits are nonzero when present, then runs `bqueues <queue>`
    /// with a 10 second timeout to confirm the queue exists.
    pub async fn validate(&self, name: &str) -> Result<(), anyhow::Error> {
        let queue = &self.name;
        ensure!(!queue.is_empty(), "{name}_lsf_queue name cannot be empty");
        if let Some(max_cpu_per_task) = self.max_cpu_per_task {
            ensure!(
                max_cpu_per_task > 0,
                "{name}_lsf_queue `{queue}` must allow at least 1 CPU to be provisioned"
            );
        }
        if let Some(max_memory_per_task) = self.max_memory_per_task {
            ensure!(
                max_memory_per_task.as_u64() > 0,
                "{name}_lsf_queue `{queue}` must allow at least some memory to be provisioned"
            );
        }
        // Ask LSF about the queue; a non-zero exit or a timeout fails validation.
        match tokio::time::timeout(
            std::time::Duration::from_secs(10),
            Command::new("bqueues").arg(queue).output(),
        )
        .await
        {
            Ok(output) => {
                let output = output.context("validating LSF queue")?;
                if !output.status.success() {
                    let stdout = String::from_utf8_lossy(&output.stdout);
                    let stderr = String::from_utf8_lossy(&output.stderr);
                    error!(%stdout, %stderr, %queue, "failed to validate {name}_lsf_queue");
                    Err(anyhow!("failed to validate {name}_lsf_queue `{queue}`"))
                } else {
                    Ok(())
                }
            }
            Err(_) => Err(anyhow!(
                "timed out trying to validate {name}_lsf_queue `{queue}`"
            )),
        }
    }
}
/// Configuration for the experimental LSF + Apptainer backend.
///
/// NOTE(review): serde documents `deny_unknown_fields` as unsupported in
/// combination with `flatten` (used for `apptainer_config`) — confirm
/// deserialization behaves as expected.
#[derive(Debug, Default, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct LsfApptainerBackendConfig {
    /// Polling interval. NOTE(review): units not evident from this file — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub interval: Option<u64>,
    /// Maximum concurrent tasks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrency: Option<u32>,
    /// Fallback queue when no more specific queue matches.
    pub default_lsf_queue: Option<LsfQueueConfig>,
    /// Queue for tasks hinted as short-running.
    pub short_task_lsf_queue: Option<LsfQueueConfig>,
    /// Queue for tasks requiring GPUs.
    pub gpu_lsf_queue: Option<LsfQueueConfig>,
    /// Queue for tasks requiring FPGAs.
    pub fpga_lsf_queue: Option<LsfQueueConfig>,
    /// Extra arguments appended to `bsub` invocations.
    pub extra_bsub_args: Option<Vec<String>>,
    /// Optional job name prefix; at most `MAX_LSF_JOB_NAME_PREFIX` bytes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub job_name_prefix: Option<String>,
    /// Shared Apptainer settings, flattened into this config.
    #[serde(default)]
    #[serde(flatten)]
    pub apptainer_config: ApptainerConfig,
}
impl LsfApptainerBackendConfig {
    /// Validates the LSF + Apptainer backend configuration.
    ///
    /// Requires a unix platform and experimental features; validates each
    /// configured queue (via `bqueues`), the job name prefix length, and the
    /// Apptainer settings.
    pub async fn validate(&self, engine_config: &Config) -> Result<(), anyhow::Error> {
        if cfg!(not(unix)) {
            bail!("LSF + Apptainer backend is not supported on non-unix platforms");
        }
        if !engine_config.experimental_features_enabled {
            bail!("LSF + Apptainer backend requires enabling experimental features");
        }
        if let Some(queue) = &self.default_lsf_queue {
            queue.validate("default").await?;
        }
        if let Some(queue) = &self.short_task_lsf_queue {
            queue.validate("short_task").await?;
        }
        if let Some(queue) = &self.gpu_lsf_queue {
            queue.validate("gpu").await?;
        }
        if let Some(queue) = &self.fpga_lsf_queue {
            queue.validate("fpga").await?;
        }
        if let Some(prefix) = &self.job_name_prefix
            && prefix.len() > MAX_LSF_JOB_NAME_PREFIX
        {
            bail!(
                "LSF job name prefix `{prefix}` exceeds the maximum {MAX_LSF_JOB_NAME_PREFIX} \
                 bytes"
            );
        }
        self.apptainer_config.validate().await?;
        Ok(())
    }

    /// Selects the queue for a task from its requirements/hints.
    ///
    /// Precedence: FPGA requirement, then GPU requirement, then short-task
    /// hint, then the default queue (if configured).
    pub(crate) fn lsf_queue_for_task(
        &self,
        requirements: &HashMap<String, Value>,
        hints: &HashMap<String, Value>,
    ) -> Option<&LsfQueueConfig> {
        if let Some(queue) = self.fpga_lsf_queue.as_ref()
            && let Some(true) = requirements
                .get(wdl_ast::v1::TASK_REQUIREMENT_FPGA)
                .and_then(Value::as_boolean)
        {
            return Some(queue);
        }
        if let Some(queue) = self.gpu_lsf_queue.as_ref()
            && let Some(true) = requirements
                .get(wdl_ast::v1::TASK_REQUIREMENT_GPU)
                .and_then(Value::as_boolean)
        {
            return Some(queue);
        }
        if let Some(queue) = self.short_task_lsf_queue.as_ref()
            && let Some(true) = hints
                .get(wdl_ast::v1::TASK_HINT_SHORT_TASK)
                .and_then(Value::as_boolean)
        {
            return Some(queue);
        }
        self.default_lsf_queue.as_ref()
    }
}
/// Configuration for a single Slurm partition.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct SlurmPartitionConfig {
    /// Partition name; must be non-empty and known to `scontrol`.
    pub name: String,
    /// Optional per-task CPU ceiling; must be nonzero when set.
    pub max_cpu_per_task: Option<u64>,
    /// Optional per-task memory ceiling; must be nonzero when set.
    pub max_memory_per_task: Option<ByteSize>,
}
impl SlurmPartitionConfig {
    /// Validates this partition configuration; `name` is the config-key
    /// prefix (e.g. `default`, `gpu`) used in error messages.
    ///
    /// Checks limits are nonzero when present, then runs
    /// `scontrol show partition <partition>` with a 10 second timeout to
    /// confirm the partition exists.
    pub async fn validate(&self, name: &str) -> Result<(), anyhow::Error> {
        let partition = &self.name;
        ensure!(
            !partition.is_empty(),
            "{name}_slurm_partition name cannot be empty"
        );
        if let Some(max_cpu_per_task) = self.max_cpu_per_task {
            ensure!(
                max_cpu_per_task > 0,
                "{name}_slurm_partition `{partition}` must allow at least 1 CPU to be provisioned"
            );
        }
        if let Some(max_memory_per_task) = self.max_memory_per_task {
            ensure!(
                max_memory_per_task.as_u64() > 0,
                "{name}_slurm_partition `{partition}` must allow at least some memory to be \
                 provisioned"
            );
        }
        // Ask Slurm about the partition; non-zero exit or timeout fails validation.
        match tokio::time::timeout(
            std::time::Duration::from_secs(10),
            Command::new("scontrol")
                .arg("show")
                .arg("partition")
                .arg(partition)
                .output(),
        )
        .await
        {
            Ok(output) => {
                let output = output.context("validating Slurm partition")?;
                if !output.status.success() {
                    let stdout = String::from_utf8_lossy(&output.stdout);
                    let stderr = String::from_utf8_lossy(&output.stderr);
                    error!(%stdout, %stderr, %partition, "failed to validate {name}_slurm_partition");
                    Err(anyhow!(
                        "failed to validate {name}_slurm_partition `{partition}`"
                    ))
                } else {
                    Ok(())
                }
            }
            Err(_) => Err(anyhow!(
                "timed out trying to validate {name}_slurm_partition `{partition}`"
            )),
        }
    }
}
/// Configuration for the experimental Slurm + Apptainer backend.
///
/// NOTE(review): serde documents `deny_unknown_fields` as unsupported in
/// combination with `flatten` (used for `apptainer_config`) — confirm
/// deserialization behaves as expected.
#[derive(Debug, Default, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct SlurmApptainerBackendConfig {
    /// Polling interval. NOTE(review): units not evident from this file — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub interval: Option<u64>,
    /// Maximum concurrent tasks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrency: Option<u32>,
    /// Fallback partition when no more specific partition matches.
    pub default_slurm_partition: Option<SlurmPartitionConfig>,
    /// Partition for tasks hinted as short-running.
    pub short_task_slurm_partition: Option<SlurmPartitionConfig>,
    /// Partition for tasks requiring GPUs.
    pub gpu_slurm_partition: Option<SlurmPartitionConfig>,
    /// Partition for tasks requiring FPGAs.
    pub fpga_slurm_partition: Option<SlurmPartitionConfig>,
    /// Extra arguments appended to `sbatch` invocations.
    pub extra_sbatch_args: Option<Vec<String>>,
    /// Optional job name prefix.
    /// NOTE(review): unlike the LSF backend, no length limit is enforced — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub job_name_prefix: Option<String>,
    /// Shared Apptainer settings, flattened into this config.
    #[serde(default)]
    #[serde(flatten)]
    pub apptainer_config: ApptainerConfig,
}
impl SlurmApptainerBackendConfig {
    /// Validates the Slurm + Apptainer backend configuration.
    ///
    /// Requires a unix platform and experimental features; validates each
    /// configured partition (via `scontrol`) and the Apptainer settings.
    pub async fn validate(&self, engine_config: &Config) -> Result<(), anyhow::Error> {
        if cfg!(not(unix)) {
            bail!("Slurm + Apptainer backend is not supported on non-unix platforms");
        }
        if !engine_config.experimental_features_enabled {
            bail!("Slurm + Apptainer backend requires enabling experimental features");
        }
        if let Some(partition) = &self.default_slurm_partition {
            partition.validate("default").await?;
        }
        if let Some(partition) = &self.short_task_slurm_partition {
            partition.validate("short_task").await?;
        }
        if let Some(partition) = &self.gpu_slurm_partition {
            partition.validate("gpu").await?;
        }
        if let Some(partition) = &self.fpga_slurm_partition {
            partition.validate("fpga").await?;
        }
        self.apptainer_config.validate().await?;
        Ok(())
    }

    /// Selects the partition for a task from its requirements/hints.
    ///
    /// Precedence: FPGA requirement, then GPU requirement, then short-task
    /// hint, then the default partition (if configured).
    pub(crate) fn slurm_partition_for_task(
        &self,
        requirements: &HashMap<String, Value>,
        hints: &HashMap<String, Value>,
    ) -> Option<&SlurmPartitionConfig> {
        if let Some(partition) = self.fpga_slurm_partition.as_ref()
            && let Some(true) = requirements
                .get(wdl_ast::v1::TASK_REQUIREMENT_FPGA)
                .and_then(Value::as_boolean)
        {
            return Some(partition);
        }
        if let Some(partition) = self.gpu_slurm_partition.as_ref()
            && let Some(true) = requirements
                .get(wdl_ast::v1::TASK_REQUIREMENT_GPU)
                .and_then(Value::as_boolean)
        {
            return Some(partition);
        }
        if let Some(partition) = self.short_task_slurm_partition.as_ref()
            && let Some(true) = hints
                .get(wdl_ast::v1::TASK_HINT_SHORT_TASK)
                .and_then(Value::as_boolean)
        {
            return Some(partition);
        }
        self.default_slurm_partition.as_ref()
    }
}
#[cfg(test)]
mod test {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn redacted_secret() {
    // Fresh secrets serialize as the redaction placeholder by default.
    let mut secret: SecretString = "secret".into();
    assert_eq!(
        serde_json::to_string(&secret).unwrap(),
        format!(r#""{REDACTED}""#)
    );
    // Unredacting exposes the value; redacting hides it again.
    secret.unredact();
    assert_eq!(serde_json::to_string(&secret).unwrap(), r#""secret""#);
    secret.redact();
    assert_eq!(
        serde_json::to_string(&secret).unwrap(),
        format!(r#""{REDACTED}""#)
    );
}
#[test]
fn redacted_config() {
    // A config whose every secret has the value "secret"; secrets are
    // redacted by default, so none of them may appear in serialized output.
    let config = Config {
        backends: [
            (
                "first".to_string(),
                BackendConfig::Tes(TesBackendConfig {
                    auth: Some(TesBackendAuthConfig::Basic(BasicAuthConfig {
                        username: "foo".into(),
                        password: "secret".into(),
                    })),
                    ..Default::default()
                }),
            ),
            (
                "second".to_string(),
                BackendConfig::Tes(TesBackendConfig {
                    auth: Some(TesBackendAuthConfig::Bearer(BearerAuthConfig {
                        token: "secret".into(),
                    })),
                    ..Default::default()
                }),
            ),
        ]
        .into(),
        storage: StorageConfig {
            azure: AzureStorageConfig {
                auth: Some(AzureStorageAuthConfig {
                    account_name: "foo".into(),
                    access_key: "secret".into(),
                }),
            },
            s3: S3StorageConfig {
                auth: Some(S3StorageAuthConfig {
                    access_key_id: "foo".into(),
                    secret_access_key: "secret".into(),
                }),
                ..Default::default()
            },
            google: GoogleStorageConfig {
                auth: Some(GoogleStorageAuthConfig {
                    access_key: "foo".into(),
                    secret: "secret".into(),
                }),
            },
        },
        ..Default::default()
    };
    let json = serde_json::to_string_pretty(&config).unwrap();
    // The original assertion was inverted (`json.contains("secret")`) and only
    // passed because a *field* is named `secret`; match `: "secret"` so only
    // leaked secret values are detected, and check the placeholder is present.
    assert!(
        !json.contains(r#": "secret""#),
        "`{json}` contains a secret"
    );
    assert!(json.contains(REDACTED), "`{json}` lacks the redaction placeholder");
}
// Exercises `Config::validate` across every class of rejected configuration,
// pinning the exact error message (or its prefix) for each, plus the happy
// paths that must validate cleanly.
#[tokio::test]
async fn test_config_validate() {
// `task.retries` is capped at `MAX_RETRIES` (100).
let mut config = Config::default();
config.task.retries = Retries(Some(255));
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"configuration value `task.retries` cannot exceed 100"
);
// Scatter concurrency of zero is rejected.
let mut config = Config::default();
config.workflow.scatter.concurrency = 0;
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"configuration value `workflow.scatter.concurrency` cannot be zero"
);
// The selected backend name must exist in `backends` — here the map is empty.
let config = Config {
backend: "foo".into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"a backend named `foo` is not present in the configuration"
);
// Still an error when `backends` is non-empty but lacks the selected name.
let config = Config {
backend: "bar".into(),
backends: [("foo".to_string(), BackendConfig::default())].into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"a backend named `bar` is not present in the configuration"
);
// Happy path: selected backend name matches a configured backend.
let config = Config {
backend: "foo".to_string(),
backends: [("foo".to_string(), BackendConfig::default())].into(),
..Default::default()
};
config.validate().await.expect("config should validate");
// Local backend: `cpu` may not be zero.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Local(LocalBackendConfig {
cpu: Some(0),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"local backend configuration value `cpu` cannot be zero"
);
// Local backend: `cpu` may not exceed the host's vCPU count. Only the
// message prefix is checked because the suffix includes the host count.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Local(LocalBackendConfig {
cpu: Some(10000000),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert!(
config
.validate()
.await
.unwrap_err()
.to_string()
.starts_with(
"local backend configuration value `cpu` cannot exceed the virtual CPUs \
available to the host"
)
);
// Local backend: `memory` may not parse to zero bytes.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Local(LocalBackendConfig {
memory: Some("0 GiB".to_string()),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"local backend configuration value `memory` cannot be zero"
);
// Local backend: an unparseable unit string is rejected.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Local(LocalBackendConfig {
memory: Some("100 meows".to_string()),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"local backend configuration value `memory` has invalid value `100 meows`"
);
// Local backend: `memory` may not exceed host memory (prefix check — the
// suffix includes the host total).
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Local(LocalBackendConfig {
memory: Some("1000 TiB".to_string()),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert!(
config
.validate()
.await
.unwrap_err()
.to_string()
.starts_with(
"local backend configuration value `memory` cannot exceed the total memory of \
the host"
)
);
// TES backend: `url` is mandatory.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Tes(Default::default()),
)]
.into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"TES backend configuration value `url` is required"
);
// TES backend: `max_concurrency` may not be zero.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Tes(TesBackendConfig {
url: Some("https://example.com".parse().unwrap()),
max_concurrency: Some(0),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"TES backend configuration value `max_concurrency` cannot be zero"
);
// TES backend: plain HTTP URLs are rejected unless `insecure` is set.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Tes(TesBackendConfig {
url: Some("http://example.com".parse().unwrap()),
inputs: Some("http://example.com".parse().unwrap()),
outputs: Some("http://example.com".parse().unwrap()),
..Default::default()
}),
)]
.into(),
..Default::default()
};
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"TES backend configuration value `url` has invalid value `http://example.com/`: URL \
must use a HTTPS scheme"
);
// TES backend: same HTTP URLs validate once `insecure: true` opts in.
let config = Config {
backends: [(
"default".to_string(),
BackendConfig::Tes(TesBackendConfig {
url: Some("http://example.com".parse().unwrap()),
inputs: Some("http://example.com".parse().unwrap()),
outputs: Some("http://example.com".parse().unwrap()),
insecure: true,
..Default::default()
}),
)]
.into(),
..Default::default()
};
config
.validate()
.await
.expect("configuration should validate");
// `http.parallelism`: zero is rejected; a positive value or the default
// (`None`) both validate.
let mut config = Config::default();
config.http.parallelism = Parallelism(Some(0));
assert_eq!(
config.validate().await.unwrap_err().to_string(),
"configuration value `http.parallelism` cannot be zero"
);
let mut config = Config::default();
config.http.parallelism = Parallelism(Some(5));
assert!(
config.validate().await.is_ok(),
"should pass for valid configuration"
);
let mut config = Config::default();
config.http.parallelism = Parallelism(None);
assert!(
config.validate().await.is_ok(),
"should pass for default (None)"
);
// LSF/Apptainer backend (unix-only, experimental): the job name prefix is
// limited to `MAX_LSF_JOB_NAME_PREFIX` (100) bytes.
#[cfg(unix)]
{
let job_name_prefix = "A".repeat(MAX_LSF_JOB_NAME_PREFIX * 2);
let mut config = Config {
experimental_features_enabled: true,
..Default::default()
};
config.backends.insert(
"default".to_string(),
BackendConfig::LsfApptainer(LsfApptainerBackendConfig {
job_name_prefix: Some(job_name_prefix.clone()),
..Default::default()
}),
);
assert_eq!(
config.validate().await.unwrap_err().to_string(),
format!("LSF job name prefix `{job_name_prefix}` exceeds the maximum 100 bytes")
);
}
}
}