use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::str::FromStr;
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use serde::Deserialize;
use serde::Serialize;
use tracing::warn;
use wdl_analysis::types::PrimitiveType;
use wdl_ast::v1::TASK_HINT_DISKS;
use wdl_ast::v1::TASK_HINT_GPU;
use wdl_ast::v1::TASK_REQUIREMENT_CONTAINER;
use wdl_ast::v1::TASK_REQUIREMENT_CONTAINER_ALIAS;
use wdl_ast::v1::TASK_REQUIREMENT_CPU;
use wdl_ast::v1::TASK_REQUIREMENT_DISKS;
use wdl_ast::v1::TASK_REQUIREMENT_GPU;
use wdl_ast::v1::TASK_REQUIREMENT_MAX_RETRIES;
use wdl_ast::v1::TASK_REQUIREMENT_MAX_RETRIES_ALIAS;
use wdl_ast::v1::TASK_REQUIREMENT_MEMORY;
use crate::Coercible;
use crate::ONE_GIBIBYTE;
use crate::TaskInputs;
use crate::Value;
use crate::config::Config;
use crate::units::StorageUnit;
use crate::v1::DEFAULT_DISK_MOUNT_POINT;
use crate::v1::task::DEFAULT_GPU_COUNT;
use crate::v1::task::DEFAULT_TASK_REQUIREMENT_CPU;
use crate::v1::task::DEFAULT_TASK_REQUIREMENT_MAX_RETRIES;
use crate::v1::task::DEFAULT_TASK_REQUIREMENT_MEMORY;
use crate::v1::task::find_key_value;
use crate::v1::task::parse_storage_value;
use crate::v1::validators::SettingSource;
use crate::v1::validators::ensure_non_negative_i64;
use crate::v1::validators::invalid_numeric_value_message;
// Scheme prefixes recognized when classifying a container reference.

/// Prefix of references classified as `ContainerSource::Docker`.
const DOCKER_PROTOCOL: &str = "docker://";
/// Prefix of references classified as `ContainerSource::Library`.
const LIBRARY_PROTOCOL: &str = "library://";
/// Prefix of references classified as `ContainerSource::Oras`.
const ORAS_PROTOCOL: &str = "oras://";
/// Prefix of local file references; only paths ending in [`SIF_EXTENSION`]
/// become `ContainerSource::SifFile`.
const FILE_PROTOCOL: &str = "file://";

/// File extension (without the leading dot) of recognized image files.
const SIF_EXTENSION: &str = "sif";

/// Container value that is replaced with the configured default container.
const WILDCARD_CONTAINER: &str = "*";
/// Where a task's container image comes from, as classified by its reference
/// string.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContainerSource {
    /// A `docker://` reference, or any reference without a scheme.
    Docker(String),
    /// A `library://` reference.
    Library(String),
    /// An `oras://` reference.
    Oras(String),
    /// A `file://` reference whose path has the `sif` extension.
    SifFile(PathBuf),
    /// Any other reference, kept verbatim including its scheme.
    Unknown(String),
}

impl ContainerSource {
    /// Returns the scheme name for this source, without the `://` suffix.
    ///
    /// Returns `None` for [`ContainerSource::Unknown`].
    pub fn scheme(&self) -> Option<&'static str> {
        let scheme = match self {
            Self::Unknown(_) => return None,
            Self::Docker(_) => "docker",
            Self::Library(_) => "library",
            Self::Oras(_) => "oras",
            Self::SifFile(_) => "file",
        };
        Some(scheme)
    }

    /// Returns the image name for registry-style sources (`Docker`, `Library`,
    /// `Oras`).
    ///
    /// Returns `None` for `SifFile` and `Unknown`.
    pub fn name(&self) -> Option<&str> {
        match self {
            Self::SifFile(_) | Self::Unknown(_) => None,
            Self::Docker(image) | Self::Library(image) | Self::Oras(image) => {
                Some(image.as_str())
            }
        }
    }
}
impl FromStr for ContainerSource {
    type Err = std::convert::Infallible;

    /// Classifies a container reference by its scheme prefix.
    ///
    /// `file://` references become `SifFile` only when the path has the `sif`
    /// extension; otherwise they, like any other unrecognized `scheme://`
    /// reference, are kept verbatim as `Unknown`. A reference without a scheme
    /// is treated as a Docker image. Parsing never fails.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some(rest) = s.strip_prefix(FILE_PROTOCOL) {
            let path = PathBuf::from(rest);
            let is_sif = path.extension().and_then(|e| e.to_str()) == Some(SIF_EXTENSION);
            return Ok(if is_sif {
                Self::SifFile(path)
            } else {
                Self::Unknown(s.to_string())
            });
        }

        let source = if let Some(image) = s.strip_prefix(DOCKER_PROTOCOL) {
            Self::Docker(image.to_string())
        } else if let Some(image) = s.strip_prefix(LIBRARY_PROTOCOL) {
            Self::Library(image.to_string())
        } else if let Some(image) = s.strip_prefix(ORAS_PROTOCOL) {
            Self::Oras(image.to_string())
        } else if s.contains("://") {
            // Some scheme we do not know how to handle.
            Self::Unknown(s.to_string())
        } else {
            // No scheme at all: a plain image reference such as `ubuntu:22.04`.
            Self::Docker(s.to_string())
        };

        Ok(source)
    }
}
impl std::fmt::Display for ContainerSource {
    /// Writes the source's reference.
    ///
    /// The plain form (`{}`) writes only the image name or file path. The
    /// alternate form (`{:#}`) prefixes it with its scheme. `Unknown` sources are
    /// written verbatim in both forms.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let with_scheme = f.alternate();
        match self {
            Self::Docker(s) if with_scheme => write!(f, "docker://{s}"),
            Self::Library(s) if with_scheme => write!(f, "library://{s}"),
            Self::Oras(s) if with_scheme => write!(f, "oras://{s}"),
            Self::SifFile(p) if with_scheme => write!(f, "file://{}", p.display()),
            Self::SifFile(p) => write!(f, "{}", p.display()),
            Self::Docker(s) | Self::Library(s) | Self::Oras(s) | Self::Unknown(s) => {
                write!(f, "{s}")
            }
        }
    }
}
/// Reports whether a `container` requirement (or its alias) is present.
///
/// Both the task input overrides and the task's own `requirements` section are
/// consulted.
pub(crate) fn has_container_requirement(
    inputs: &TaskInputs,
    requirements: &HashMap<String, Value>,
) -> bool {
    let keys = [TASK_REQUIREMENT_CONTAINER, TASK_REQUIREMENT_CONTAINER_ALIAS];
    let found = find_key_value(&keys, |key| {
        inputs.requirement(key).or_else(|| requirements.get(key))
    });
    found.is_some()
}
/// Resolves the container sources requested by a task.
///
/// Values from task input overrides take precedence over the task's own
/// `requirements` section. When no `container` requirement (or alias) is set,
/// the result is `default` alone. Otherwise the requirement is either an array
/// of references or a single value coerced to `String`. In both cases a `*`
/// entry is replaced with `default`.
///
/// # Panics
///
/// Panics if an array element is not a `String`, or if a scalar value cannot
/// be coerced to `String`.
pub(crate) fn container(
    inputs: &TaskInputs,
    requirements: &HashMap<String, Value>,
    default: &str,
) -> Vec<ContainerSource> {
    // Turns one requested reference into a source, substituting the default for
    // the wildcard. Parsing a `ContainerSource` is infallible.
    let resolve = |reference: &str| -> ContainerSource {
        let effective = if reference == WILDCARD_CONTAINER {
            default
        } else {
            reference
        };
        effective.parse().unwrap()
    };

    let Some((_, value)) = find_key_value(
        &[TASK_REQUIREMENT_CONTAINER, TASK_REQUIREMENT_CONTAINER_ALIAS],
        |key| inputs.requirement(key).or_else(|| requirements.get(key)),
    ) else {
        return vec![default.parse().unwrap()];
    };

    match value.as_array() {
        Some(array) => array
            .as_slice()
            .iter()
            .map(|element| {
                let reference = element
                    .as_string()
                    .expect("container array element should be a `String`");
                resolve(reference.as_ref())
            })
            .collect(),
        None => {
            let coerced: Cow<'_, str> = value
                .coerce(None, &PrimitiveType::String.into())
                .expect("container value should be coercible to `String`")
                .unwrap_string()
                .as_ref()
                .clone()
                .into();
            vec![resolve(&*coerced)]
        }
    }
}
/// Evaluates the `cpu` requirement as a number of CPUs.
///
/// Task input overrides take precedence over the task's `requirements`. Falls
/// back to `DEFAULT_TASK_REQUIREMENT_CPU` when unset.
///
/// # Panics
///
/// Panics if the value cannot be coerced to `Float`.
pub(crate) fn cpu(inputs: &TaskInputs, requirements: &HashMap<String, Value>) -> f64 {
    let Some((_, value)) = find_key_value(&[TASK_REQUIREMENT_CPU], |key| {
        inputs.requirement(key).or_else(|| requirements.get(key))
    }) else {
        return DEFAULT_TASK_REQUIREMENT_CPU;
    };

    value
        .coerce(None, &PrimitiveType::Float.into())
        .expect("type should coerce")
        .unwrap_float()
}
/// Evaluates the `memory` requirement, in bytes.
///
/// Task input overrides take precedence over the task's `requirements`. Falls
/// back to `DEFAULT_TASK_REQUIREMENT_MEMORY` when unset.
///
/// # Errors
///
/// Returns an error if the value cannot be parsed as a storage amount, or if
/// it is negative.
pub(crate) fn memory(inputs: &TaskInputs, requirements: &HashMap<String, Value>) -> Result<i64> {
    let Some((key, value)) = find_key_value(&[TASK_REQUIREMENT_MEMORY], |key| {
        inputs.requirement(key).or_else(|| requirements.get(key))
    }) else {
        return Ok(DEFAULT_TASK_REQUIREMENT_MEMORY);
    };

    let bytes = parse_storage_value(value, |raw| {
        invalid_numeric_value_message(SettingSource::Requirement, key, raw)
    })?;
    ensure_non_negative_i64(SettingSource::Requirement, key, bytes)
}
/// Determines how many GPUs to request for a task.
///
/// Returns `None` unless the `gpu` requirement is `true`. When it is, the count
/// comes from the `gpu` hint:
/// - no hint: `DEFAULT_GPU_COUNT`;
/// - a string hint: a warning is logged and `DEFAULT_GPU_COUNT` is used;
/// - an integer hint `>= 1`: that count;
/// - an integer hint `< 1`: a warning is logged and `None` is returned.
///
/// # Panics
///
/// Panics if the hint is neither an integer nor a string.
pub(crate) fn gpu(
    inputs: &TaskInputs,
    requirements: &HashMap<String, Value>,
    hints: &HashMap<String, Value>,
) -> Option<u64> {
    let requested = find_key_value(&[TASK_REQUIREMENT_GPU], |key| {
        inputs.requirement(key).or_else(|| requirements.get(key))
    })
    .and_then(|(_, v)| v.as_boolean())
    .unwrap_or(false);

    if !requested {
        return None;
    }

    let hint = match find_key_value(&[TASK_HINT_GPU], |key| {
        inputs.hint(key).or_else(|| hints.get(key))
    }) {
        Some((_, hint)) => hint,
        None => return Some(DEFAULT_GPU_COUNT),
    };

    if let Some(hint) = hint.as_string() {
        warn!(
            %hint,
            "hint `{TASK_HINT_GPU}` cannot be a string: falling back to {DEFAULT_GPU_COUNT} GPU(s)"
        );
        return Some(DEFAULT_GPU_COUNT);
    }

    let Some(count) = hint.as_integer() else {
        unreachable!("`{TASK_HINT_GPU}` hint must be an integer or string");
    };

    if count < 1 {
        warn!(
            %count,
            "`{TASK_HINT_GPU}` hint specified {count} GPU(s); no GPUs will be requested for execution"
        );
        return None;
    }

    // `count` is at least 1 here, so the cast cannot wrap.
    Some(count as u64)
}
/// The kind of disk that can be requested for a mount point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Variant names deliberately mirror the exact `SSD`/`HDD` spellings accepted
// by `FromStr`.
#[allow(clippy::upper_case_acronyms)]
pub(crate) enum DiskType {
    /// Parsed from the string `SSD`.
    SSD,
    /// Parsed from the string `HDD`.
    HDD,
}

impl FromStr for DiskType {
    type Err = ();

    /// Parses exactly `SSD` or `HDD` (case-sensitive, no surrounding
    /// whitespace).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        [("SSD", Self::SSD), ("HDD", Self::HDD)]
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, ty)| ty)
            .ok_or(())
    }
}

/// A disk requested for a single mount point.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct DiskRequirement {
    /// Requested size, in GiB.
    pub size: i64,
    /// Requested disk type, if one was given via hints.
    pub ty: Option<DiskType>,
}
/// Evaluates the `disks` requirement of a task.
///
/// The requirement may be an integer (a size in GiB at the default mount
/// point), a single disk specification string, or an array of specification
/// strings. Each specification takes one of these forms:
///
/// - `<size>`: size in GiB at the default mount point;
/// - `<size> <unit>`: size converted from `unit` to whole GiB (rounded down);
/// - `<mount-point> <size>`: size in GiB at an absolute mount point;
/// - `<mount-point> <size> <unit>`: size converted from `unit` to whole GiB.
///
/// Task input overrides take precedence over the task's `requirements`. The
/// disk type for each entry is looked up from the `disks` hint.
///
/// Returns a map from mount point to the requested disk. The map is empty when
/// the requirement is not set.
///
/// # Errors
///
/// Returns an error if:
/// - a specification is malformed;
/// - any size, in any form, is negative;
/// - the same mount point is requested more than once.
///
/// # Panics
///
/// Panics if the value is not an integer, a string, or an array of strings.
pub(crate) fn disks<'a>(
    inputs: &'a TaskInputs,
    requirements: &'a HashMap<String, Value>,
    hints: &HashMap<String, Value>,
) -> Result<HashMap<&'a str, DiskRequirement>> {
    /// Looks up the requested disk type from the `disks` hint.
    ///
    /// A string hint applies to every disk. A map hint is keyed by mount point
    /// and is consulted only when `mount_point` is `Some`. Unrecognized type
    /// strings yield `None`.
    fn lookup_type(
        mount_point: Option<&str>,
        hints: &HashMap<String, Value>,
        inputs: &TaskInputs,
    ) -> Option<DiskType> {
        // Hints from task inputs take precedence over the task's own hints.
        find_key_value(&[TASK_HINT_DISKS], |key| {
            inputs.hint(key).or_else(|| hints.get(key))
        })
        .and_then(|(_, v)| {
            if let Some(ty) = v.as_string() {
                return ty.parse().ok();
            }

            if let Some(map) = v.as_map() {
                // NOTE(review): a disk without an explicit mount point never
                // matches a map entry, not even one keyed by the default mount
                // point. Confirm this is intended.
                if let Some((_, v)) = map.iter().find(|(k, _)| match (k, mount_point) {
                    (_, None) => false,
                    (k, Some(mount_point)) => k
                        .as_string()
                        .map(|k| k.as_str() == mount_point)
                        .unwrap_or(false),
                }) {
                    return v.as_string().and_then(|ty| ty.parse().ok());
                }
            }

            None
        })
    }

    /// Parses one disk specification into `(size_in_gib, mount_point)`.
    ///
    /// Returns `None` if the specification:
    /// - is empty or has more than three whitespace-separated parts;
    /// - names a mount point that is not absolute;
    /// - has a size or unit that does not parse.
    ///
    /// The sign of the size is not checked here.
    fn parse_disk_spec(spec: &str) -> Option<(i64, Option<&str>)> {
        let parts: Vec<&str> = spec.split_whitespace().collect();
        match parts.as_slice() {
            // `<size>`, already in GiB.
            [size] => Some((size.parse().ok()?, None)),
            [first, second] => {
                // `<size> <unit>`: convert to whole GiB (integer division truncates).
                if let Ok(size) = first.parse() {
                    let unit: StorageUnit = second.parse().ok()?;
                    let size = unit.bytes(size)? / (ONE_GIBIBYTE as u64);
                    return Some((size.try_into().ok()?, None));
                }

                // `<mount-point> <size>`, size in GiB. The mount point must be absolute.
                if !first.starts_with('/') {
                    return None;
                }

                Some((second.parse().ok()?, Some(*first)))
            }
            // `<mount-point> <size> <unit>`.
            [mount_point, size, unit] => {
                let unit: StorageUnit = unit.parse().ok()?;
                let size = unit.bytes(size.parse().ok()?)? / (ONE_GIBIBYTE as u64);

                // The mount point must be absolute.
                if !mount_point.starts_with('/') {
                    return None;
                }

                Some((size.try_into().ok()?, Some(*mount_point)))
            }
            // Empty, or more than three parts.
            _ => None,
        }
    }

    /// Records a disk of `size` GiB at `mount_point`, or at the default mount
    /// point when it is `None`, and resolves its type from the hints.
    ///
    /// Fails if `size` is negative or the mount point was already recorded.
    fn add_disk<'a>(
        key: &str,
        size: i64,
        mount_point: Option<&'a str>,
        hints: &HashMap<String, Value>,
        inputs: &TaskInputs,
        disks: &mut HashMap<&'a str, DiskRequirement>,
    ) -> Result<()> {
        // Apply the same non-negativity rule to every form of the requirement,
        // not just the integer form. Otherwise specs like "-5" or "/mnt -5"
        // would slip through.
        if size < 0 {
            bail!("task requirement `{key}` cannot be less than zero");
        }

        let mp = mount_point.unwrap_or(DEFAULT_DISK_MOUNT_POINT);
        let prev = disks.insert(
            mp,
            DiskRequirement {
                size,
                ty: lookup_type(mount_point, hints, inputs),
            },
        );

        if prev.is_some() {
            bail!("duplicate mount point `{mp}` specified in `disks` requirement");
        }

        Ok(())
    }

    /// Parses the specification string `spec` and records the resulting disk.
    fn insert_disk<'a>(
        key: &str,
        spec: &'a str,
        hints: &HashMap<String, Value>,
        inputs: &TaskInputs,
        disks: &mut HashMap<&'a str, DiskRequirement>,
    ) -> Result<()> {
        let (size, mount_point) = parse_disk_spec(spec)
            .with_context(|| format!("invalid disk specification `{spec}`"))?;
        add_disk(key, size, mount_point, hints, inputs, disks)
    }

    let mut disks = HashMap::new();
    if let Some((key, v)) = find_key_value(&[TASK_REQUIREMENT_DISKS], |key| {
        inputs.requirement(key).or_else(|| requirements.get(key))
    }) {
        if let Some(size) = v.as_integer() {
            // A bare integer is a size at the default mount point.
            add_disk(key, size, None, hints, inputs, &mut disks)?;
        } else if let Some(spec) = v.as_string() {
            insert_disk(key, spec, hints, inputs, &mut disks)?;
        } else if let Some(specs) = v.as_array() {
            for spec in specs.as_slice() {
                insert_disk(
                    key,
                    spec.as_string().expect("spec should be a string"),
                    hints,
                    inputs,
                    &mut disks,
                )?;
            }
        } else {
            unreachable!("value should be an integer, string, or array");
        }
    }

    Ok(disks)
}
/// Evaluates the `max_retries` requirement (or its alias).
///
/// Task input overrides take precedence over the task's `requirements`. When
/// neither sets it, the configured `task.retries` value is used, falling back
/// to `DEFAULT_TASK_REQUIREMENT_MAX_RETRIES`.
///
/// # Errors
///
/// Returns an error if the requirement is negative.
///
/// # Panics
///
/// Panics if the requirement value is not an integer.
pub(crate) fn max_retries(
    inputs: &TaskInputs,
    requirements: &HashMap<String, Value>,
    config: &Config,
) -> Result<u64> {
    let found = find_key_value(
        &[
            TASK_REQUIREMENT_MAX_RETRIES,
            TASK_REQUIREMENT_MAX_RETRIES_ALIAS,
        ],
        |key| inputs.requirement(key).or_else(|| requirements.get(key)),
    );

    match found {
        Some((key, value)) => {
            let retries = value
                .as_integer()
                .expect("`max_retries` requirement should be an integer");
            // Validated as non-negative first, so the cast cannot wrap.
            ensure_non_negative_i64(SettingSource::Requirement, key, retries)
                .map(|checked| checked as u64)
        }
        None => Ok(config
            .task
            .retries
            .inner()
            .cloned()
            .unwrap_or(DEFAULT_TASK_REQUIREMENT_MAX_RETRIES)),
    }
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use wdl_analysis::types::ArrayType;

    use super::ContainerSource;
    use super::*;
    use crate::PrimitiveValue;
    use crate::config::DEFAULT_TASK_CONTAINER;

    /// Builds a map holding a single `key` -> `value` entry.
    fn map_with_value(key: &str, value: Value) -> HashMap<String, Value> {
        HashMap::from([(key.to_string(), value)])
    }

    /// Builds requirements whose `container` is the string `image`.
    fn container_string(image: &str) -> HashMap<String, Value> {
        map_with_value(
            TASK_REQUIREMENT_CONTAINER,
            PrimitiveValue::new_string(image).into(),
        )
    }

    /// Builds requirements whose `container` is an `Array[String]` of `images`.
    fn container_array(images: &[&str]) -> HashMap<String, Value> {
        let elements: Vec<Value> = images
            .iter()
            .map(|image| PrimitiveValue::new_string(*image).into())
            .collect();
        let array = crate::Array::new_unchecked(
            ArrayType::new(wdl_analysis::types::PrimitiveType::String),
            elements,
        );
        map_with_value(
            TASK_REQUIREMENT_CONTAINER,
            Value::Compound(crate::CompoundValue::Array(array)),
        )
    }

    /// Resolves `container` for `requirements` using the built-in default.
    fn resolve_with_builtin_default(requirements: &HashMap<String, Value>) -> Vec<ContainerSource> {
        container(&TaskInputs::default(), requirements, DEFAULT_TASK_CONTAINER)
    }

    /// Shorthand for a Docker container source.
    fn docker(image: &str) -> ContainerSource {
        ContainerSource::Docker(image.to_string())
    }

    /// Parses `input`, then checks the variant and both display forms.
    fn check_parse(input: &str, expected: ContainerSource, plain: &str, alternate: &str) {
        let source: ContainerSource = input.parse().unwrap();
        assert_eq!(source, expected);
        assert_eq!(source.to_string(), plain);
        assert_eq!(format!("{source:#}"), alternate);
    }

    #[test]
    fn memory_disallows_negative_values() {
        let requirements = map_with_value(TASK_REQUIREMENT_MEMORY, Value::from(-1));
        let message = memory(&TaskInputs::default(), &requirements)
            .expect_err("`memory` should reject negatives")
            .to_string();
        assert!(message.contains("task requirement `memory` cannot be less than zero"));
    }

    #[test]
    fn max_retries_disallows_negative_values() {
        let requirements = map_with_value(TASK_REQUIREMENT_MAX_RETRIES, Value::from(-2));
        let message = max_retries(&TaskInputs::default(), &requirements, &Config::default())
            .expect_err("`max_retries` should reject negatives")
            .to_string();
        assert!(message.contains("task requirement `max_retries` cannot be less than zero"));
    }

    #[test]
    fn parses_bare_docker_image() {
        check_parse(
            "ubuntu:22.04",
            docker("ubuntu:22.04"),
            "ubuntu:22.04",
            "docker://ubuntu:22.04",
        );
    }

    #[test]
    fn parses_docker_protocol() {
        check_parse(
            "docker://ubuntu:latest",
            docker("ubuntu:latest"),
            "ubuntu:latest",
            "docker://ubuntu:latest",
        );
    }

    #[test]
    fn parses_library_protocol() {
        check_parse(
            "library://sylabs/default/alpine:3.18",
            ContainerSource::Library("sylabs/default/alpine:3.18".to_string()),
            "sylabs/default/alpine:3.18",
            "library://sylabs/default/alpine:3.18",
        );
    }

    #[test]
    fn parses_oras_protocol() {
        check_parse(
            "oras://ghcr.io/org/image:tag",
            ContainerSource::Oras("ghcr.io/org/image:tag".to_string()),
            "ghcr.io/org/image:tag",
            "oras://ghcr.io/org/image:tag",
        );
    }

    #[test]
    fn parses_file_protocol_sif() {
        check_parse(
            "file:///path/to/image.sif",
            ContainerSource::SifFile(PathBuf::from("/path/to/image.sif")),
            "/path/to/image.sif",
            "file:///path/to/image.sif",
        );
    }

    #[test]
    fn parses_file_protocol_unknown_extension() {
        check_parse(
            "file:///path/to/image.tar",
            ContainerSource::Unknown("file:///path/to/image.tar".to_string()),
            "file:///path/to/image.tar",
            "file:///path/to/image.tar",
        );
    }

    #[test]
    fn parses_unknown_protocol() {
        check_parse(
            "ftp://example.com/image",
            ContainerSource::Unknown("ftp://example.com/image".to_string()),
            "ftp://example.com/image",
            "ftp://example.com/image",
        );
    }

    #[test]
    fn parses_complex_docker_image() {
        let source: ContainerSource = "ghcr.io/stjude/sprocket:v1.0.0".parse().unwrap();
        assert_eq!(source, docker("ghcr.io/stjude/sprocket:v1.0.0"));
    }

    #[test]
    fn parses_docker_image_with_digest() {
        let source: ContainerSource = "ubuntu@sha256:abcdef1234567890".parse().unwrap();
        assert_eq!(source, docker("ubuntu@sha256:abcdef1234567890"));
    }

    #[test]
    fn container_returns_default_when_unset() {
        let resolved = resolve_with_builtin_default(&HashMap::new());
        assert_eq!(resolved, vec![docker(DEFAULT_TASK_CONTAINER)]);
    }

    #[test]
    fn container_returns_custom_default_when_unset() {
        let resolved = container(&TaskInputs::default(), &HashMap::new(), "alpine:3.18");
        assert_eq!(resolved, vec![docker("alpine:3.18")]);
    }

    #[test]
    fn container_returns_single_image() {
        let resolved = resolve_with_builtin_default(&container_string("foo:bar"));
        assert_eq!(resolved, vec![docker("foo:bar")]);
    }

    #[test]
    fn container_resolves_single_wildcard_to_default() {
        let resolved = resolve_with_builtin_default(&container_string("*"));
        assert_eq!(resolved, vec![docker(DEFAULT_TASK_CONTAINER)]);
    }

    #[test]
    fn container_resolves_single_wildcard_to_custom_default() {
        let resolved = container(&TaskInputs::default(), &container_string("*"), "debian:12");
        assert_eq!(resolved, vec![docker("debian:12")]);
    }

    #[test]
    fn container_returns_array_of_images() {
        let requirements = container_array(&["foo:1.0", "bar:2.0", "baz:3.0"]);
        let resolved = resolve_with_builtin_default(&requirements);
        assert_eq!(
            resolved,
            vec![docker("foo:1.0"), docker("bar:2.0"), docker("baz:3.0")]
        );
    }

    #[test]
    fn container_resolves_wildcard_in_array() {
        let requirements = container_array(&["foo:1.0", "*", "bar:2.0"]);
        let resolved = resolve_with_builtin_default(&requirements);
        assert_eq!(
            resolved,
            vec![
                docker("foo:1.0"),
                docker(DEFAULT_TASK_CONTAINER),
                docker("bar:2.0"),
            ]
        );
    }

    #[test]
    fn container_resolves_wildcard_in_array_with_custom_default() {
        let requirements = container_array(&["foo:1.0", "*"]);
        let resolved = container(&TaskInputs::default(), &requirements, "alpine:3.18");
        assert_eq!(resolved, vec![docker("foo:1.0"), docker("alpine:3.18")]);
    }

    #[test]
    fn respect_inputs_over_requirements() {
        let mut inputs = TaskInputs::default();
        inputs.override_requirement("container", PrimitiveValue::new_string("foo:bar"));
        inputs.override_requirement("cpu", 1234);
        inputs.override_requirement("gpu", true);
        inputs.override_requirement("memory", 1234);
        inputs.override_requirement("disks", 1234);
        inputs.override_requirement("max_retries", 1234);
        inputs.override_hint("gpu", 10);

        // Task-level values that the input overrides above must shadow.
        let mut requirements: HashMap<String, Value> = HashMap::new();
        for (name, value) in [
            ("container", PrimitiveValue::new_string("baz:qux")),
            ("cpu", PrimitiveValue::from(1.0)),
            ("memory", PrimitiveValue::from(1)),
            ("disks", PrimitiveValue::new_string("1 GiB")),
            ("max_retries", PrimitiveValue::from(1)),
        ] {
            requirements.insert(name.to_string(), value.into());
        }

        let selected = container(&inputs, &requirements, DEFAULT_TASK_CONTAINER);
        assert_eq!(selected.first().unwrap().to_string(), "foo:bar");
        assert_eq!(cpu(&inputs, &requirements), 1234.0);
        assert_eq!(gpu(&inputs, &requirements, &Default::default()), Some(10));
        assert_eq!(
            disks(&inputs, &requirements, &Default::default()).unwrap(),
            HashMap::from([(
                "/",
                DiskRequirement {
                    size: 1234,
                    ty: None
                }
            )])
        );
        assert_eq!(
            max_retries(&inputs, &requirements, &Default::default()).unwrap(),
            1234
        );
    }
}