wdl_engine/backend/
tes.rs

1//! Implementation of the TES backend.
2
3use std::collections::HashMap;
4use std::fs;
5use std::sync::Arc;
6use std::sync::Mutex;
7
8use anyhow::Context;
9use anyhow::Result;
10use anyhow::bail;
11use cloud_copy::UrlExt;
12use crankshaft::config::backend;
13use crankshaft::config::backend::tes::http::HttpAuthConfig;
14use crankshaft::engine::Task;
15use crankshaft::engine::service::name::GeneratorIterator;
16use crankshaft::engine::service::name::UniqueAlphanumeric;
17use crankshaft::engine::service::runner::Backend;
18use crankshaft::engine::service::runner::backend::TaskRunError;
19use crankshaft::engine::service::runner::backend::tes;
20use crankshaft::engine::task::Execution;
21use crankshaft::engine::task::Input;
22use crankshaft::engine::task::Output;
23use crankshaft::engine::task::Resources;
24use crankshaft::engine::task::input::Contents;
25use crankshaft::engine::task::input::Type as InputType;
26use crankshaft::engine::task::output::Type as OutputType;
27use crankshaft::events::Event;
28use nonempty::NonEmpty;
29use secrecy::ExposeSecret;
30use tokio::sync::broadcast;
31use tokio::sync::oneshot;
32use tokio::sync::oneshot::Receiver;
33use tokio::task::JoinSet;
34use tokio_util::sync::CancellationToken;
35use tracing::info;
36use wdl_ast::v1::TASK_REQUIREMENT_DISKS;
37
38use super::TaskExecutionBackend;
39use super::TaskExecutionConstraints;
40use super::TaskExecutionResult;
41use super::TaskManager;
42use super::TaskManagerRequest;
43use super::TaskSpawnRequest;
44use crate::ONE_GIBIBYTE;
45use crate::PrimitiveValue;
46use crate::Value;
47use crate::backend::COMMAND_FILE_NAME;
48use crate::backend::INITIAL_EXPECTED_NAMES;
49use crate::backend::STDERR_FILE_NAME;
50use crate::backend::STDOUT_FILE_NAME;
51use crate::backend::WORK_DIR_NAME;
52use crate::config::Config;
53use crate::config::DEFAULT_TASK_SHELL;
54use crate::config::TesBackendAuthConfig;
55use crate::config::TesBackendConfig;
56use crate::digest::UrlDigestExt;
57use crate::digest::calculate_local_digest;
58use crate::path::EvaluationPath;
59use crate::v1::DEFAULT_TASK_REQUIREMENT_DISKS;
60use crate::v1::container;
61use crate::v1::cpu;
62use crate::v1::disks;
63use crate::v1::max_cpu;
64use crate::v1::max_memory;
65use crate::v1::memory;
66use crate::v1::preemptible;
67
/// The root guest path for inputs.
///
/// This is the value reported by the backend's `guest_inputs_dir`.
const GUEST_INPUTS_DIR: &str = "/mnt/task/inputs/";

/// The guest working directory.
///
/// Used both as the working directory of the task's execution and as the
/// guest path of the directory output that is saved after the task runs.
const GUEST_WORK_DIR: &str = "/mnt/task/work";

/// The guest path for the command file.
///
/// The locally written command file is provided as a read-only input at this
/// path and the path is passed as the sole argument to the task shell.
const GUEST_COMMAND_PATH: &str = "/mnt/task/command";

/// The path to the container's stdout.
///
/// The execution's stdout is written here and the file is saved as an output.
const GUEST_STDOUT_PATH: &str = "/mnt/task/stdout";

/// The path to the container's stderr.
///
/// The execution's stderr is written here and the file is saved as an output.
const GUEST_STDERR_PATH: &str = "/mnt/task/stderr";

/// The default poll interval, in seconds, for the TES backend.
///
/// Used when the backend configuration does not specify an `interval`.
const DEFAULT_TES_INTERVAL: u64 = 1;
85
/// Represents a TES task request.
///
/// This request bundles the engine and backend configuration, the Crankshaft
/// backend to run on, and the values evaluated for the task (name, container,
/// CPU and memory reservations and limits, preemptible retries) together with
/// the cancellation token for the request.
///
/// The channel used to report the result is not part of this type; it is
/// handed to the task manager alongside the request when it is sent.
#[derive(Debug)]
struct TesTaskRequest {
    /// The engine configuration.
    config: Arc<Config>,
    /// The backend configuration.
    backend_config: Arc<TesBackendConfig>,
    /// The inner task spawn request.
    inner: TaskSpawnRequest,
    /// The Crankshaft TES backend to use.
    backend: Arc<tes::Backend>,
    /// The name of the task.
    name: String,
    /// The requested container for the task.
    container: String,
    /// The requested CPU reservation for the task.
    cpu: f64,
    /// The requested memory reservation for the task, in bytes.
    memory: u64,
    /// The requested maximum CPU limit for the task.
    max_cpu: Option<f64>,
    /// The requested maximum memory limit for the task, in bytes.
    max_memory: Option<u64>,
    /// The number of preemptible task retries to do before using a
    /// non-preemptible task.
    ///
    /// If this value is 0, no preemptible tasks are requested from the TES
    /// server.
    preemptible: i64,
    /// The cancellation token for the request.
    token: CancellationToken,
}
121
122impl TesTaskRequest {
123    /// Gets the TES disk resource for the request.
124    fn disk_resource(&self) -> Result<f64> {
125        let disks = disks(self.inner.requirements(), self.inner.hints())?;
126        if disks.len() > 1 {
127            bail!(
128                "TES backend does not support more than one disk specification for the \
129                 `{TASK_REQUIREMENT_DISKS}` task requirement"
130            );
131        }
132
133        if let Some(mount_point) = disks.keys().next()
134            && *mount_point != "/"
135        {
136            bail!(
137                "TES backend does not support a disk mount point other than `/` for the \
138                 `{TASK_REQUIREMENT_DISKS}` task requirement"
139            );
140        }
141
142        Ok(disks
143            .values()
144            .next()
145            .map(|d| d.size as f64)
146            .unwrap_or(DEFAULT_TASK_REQUIREMENT_DISKS))
147    }
148}
149
150impl TaskManagerRequest for TesTaskRequest {
151    fn cpu(&self) -> f64 {
152        self.cpu
153    }
154
155    fn memory(&self) -> u64 {
156        self.memory
157    }
158
159    async fn run(self) -> Result<TaskExecutionResult> {
160        // Create the attempt directory
161        let attempt_dir = self.inner.attempt_dir();
162        fs::create_dir_all(attempt_dir).with_context(|| {
163            format!(
164                "failed to create directory `{path}`",
165                path = attempt_dir.display()
166            )
167        })?;
168
169        // Write the evaluated command to disk
170        // This is done even for remote execution so that a copy exists locally
171        let command_path = attempt_dir.join(COMMAND_FILE_NAME);
172        fs::write(&command_path, self.inner.command()).with_context(|| {
173            format!(
174                "failed to write command contents to `{path}`",
175                path = command_path.display()
176            )
177        })?;
178
179        // SAFETY: currently `inputs` is required by configuration validation, so it
180        // should always unwrap
181        let inputs_url = Arc::new(
182            self.backend_config
183                .inputs
184                .clone()
185                .expect("should have inputs URL"),
186        );
187
188        // Start with the command file as an input
189        let mut inputs = vec![
190            Input::builder()
191                .path(GUEST_COMMAND_PATH)
192                .contents(Contents::Path(command_path.to_path_buf()))
193                .ty(InputType::File)
194                .read_only(true)
195                .build(),
196        ];
197
198        // Spawn upload tasks for inputs available locally, and apply authentication to
199        // the URLs for remote inputs.
200        let mut uploads = JoinSet::new();
201        for (i, input) in self.inner.inputs().iter().enumerate() {
202            match input.path() {
203                EvaluationPath::Local(path) => {
204                    // Input is local, spawn an upload of it
205                    let kind = input.kind();
206                    let path = path.to_path_buf();
207                    let transferer = self.inner.transferer().clone();
208                    let inputs_url = inputs_url.clone();
209                    uploads.spawn(async move {
210                        let url = inputs_url.join_digest(
211                            calculate_local_digest(&path, kind).await.with_context(|| {
212                                format!(
213                                    "failed to calculate digest of `{path}`",
214                                    path = path.display()
215                                )
216                            })?,
217                        );
218                        transferer
219                            .upload(&path, &url)
220                            .await
221                            .with_context(|| {
222                                format!(
223                                    "failed to upload `{path}` to `{url}`",
224                                    path = path.display(),
225                                    url = url.display()
226                                )
227                            })
228                            .map(|_| (i, url))
229                    });
230                }
231                EvaluationPath::Remote(url) => {
232                    // Input is already remote, add it to the Crankshaft inputs list
233                    inputs.push(
234                        Input::builder()
235                            .path(
236                                input
237                                    .guest_path()
238                                    .expect("input should have guest path")
239                                    .as_str(),
240                            )
241                            .contents(Contents::Url(url.clone()))
242                            .ty(input.kind())
243                            .read_only(true)
244                            .build(),
245                    );
246                }
247            }
248        }
249
250        // Wait for any uploads to complete
251        while let Some(result) = uploads.join_next().await {
252            let (i, url) = result.context("upload task")??;
253            let input = &self.inner.inputs()[i];
254            inputs.push(
255                Input::builder()
256                    .path(
257                        input
258                            .guest_path()
259                            .expect("input should have guest path")
260                            .as_str(),
261                    )
262                    .contents(Contents::Url(url))
263                    .ty(input.kind())
264                    .read_only(true)
265                    .build(),
266            );
267        }
268
269        let output_dir = format!(
270            "{name}-{timestamp}/",
271            name = self.name,
272            timestamp = chrono::Utc::now().format("%Y%m%d-%H%M%S")
273        );
274
275        // SAFETY: currently `outputs` is required by configuration validation, so it
276        // should always unwrap
277        let outputs_url = self
278            .backend_config
279            .outputs
280            .as_ref()
281            .expect("should have outputs URL")
282            .join(&output_dir)
283            .expect("should join");
284
285        let mut work_dir_url = outputs_url.join(WORK_DIR_NAME).expect("should join");
286        let stdout_url = outputs_url.join(STDOUT_FILE_NAME).expect("should join");
287        let stderr_url = outputs_url.join(STDERR_FILE_NAME).expect("should join");
288
289        // The TES backend will output three things: the working directory contents,
290        // stdout, and stderr.
291        let outputs = vec![
292            Output::builder()
293                .path(GUEST_WORK_DIR)
294                .url(work_dir_url.clone())
295                .ty(OutputType::Directory)
296                .build(),
297            Output::builder()
298                .path(GUEST_STDOUT_PATH)
299                .url(stdout_url.clone())
300                .ty(OutputType::File)
301                .build(),
302            Output::builder()
303                .path(GUEST_STDERR_PATH)
304                .url(stderr_url.clone())
305                .ty(OutputType::File)
306                .build(),
307        ];
308
309        let mut preemptible = self.preemptible;
310        loop {
311            let task = Task::builder()
312                .name(&self.name)
313                .executions(NonEmpty::new(
314                    Execution::builder()
315                        .image(&self.container)
316                        .program(
317                            self.config
318                                .task
319                                .shell
320                                .as_deref()
321                                .unwrap_or(DEFAULT_TASK_SHELL),
322                        )
323                        .args([GUEST_COMMAND_PATH.to_string()])
324                        .work_dir(GUEST_WORK_DIR)
325                        .env(self.inner.env().clone())
326                        .stdout(GUEST_STDOUT_PATH)
327                        .stderr(GUEST_STDERR_PATH)
328                        .build(),
329                ))
330                .inputs(inputs.clone())
331                .outputs(outputs.clone())
332                .resources(
333                    Resources::builder()
334                        .cpu(self.cpu)
335                        .maybe_cpu_limit(self.max_cpu)
336                        .ram(self.memory as f64 / ONE_GIBIBYTE)
337                        .disk(self.disk_resource()?)
338                        .maybe_ram_limit(self.max_memory.map(|m| m as f64 / ONE_GIBIBYTE))
339                        .preemptible(preemptible > 0)
340                        .build(),
341                )
342                .build();
343
344            let statuses = match self.backend.run(task, self.token.clone())?.await {
345                Ok(statuses) => statuses,
346                Err(TaskRunError::Preempted) if preemptible > 0 => {
347                    // Decrement the preemptible count and retry
348                    preemptible -= 1;
349                    continue;
350                }
351                Err(e) => {
352                    return Err(e.into());
353                }
354            };
355
356            assert_eq!(statuses.len(), 1, "there should only be one output");
357            let status = statuses.first();
358
359            // Push an empty path segment so that future joins of the work directory URL
360            // treat it as a directory
361            work_dir_url.path_segments_mut().unwrap().push("");
362
363            return Ok(TaskExecutionResult {
364                exit_code: status.code().expect("should have exit code"),
365                work_dir: EvaluationPath::Remote(work_dir_url),
366                stdout: PrimitiveValue::new_file(stdout_url).into(),
367                stderr: PrimitiveValue::new_file(stderr_url).into(),
368            });
369        }
370    }
371}
372
/// Represents the Task Execution Service (TES) backend.
///
/// Tasks spawned on this backend are queued on an internal task manager and
/// run through the underlying Crankshaft TES backend.
pub struct TesBackend {
    /// The engine configuration.
    config: Arc<Config>,
    /// The backend configuration.
    backend_config: Arc<TesBackendConfig>,
    /// The underlying Crankshaft backend.
    inner: Arc<tes::Backend>,
    /// The maximum number of CPUs available on any one node.
    ///
    /// Task CPU requirements are checked against this value.
    max_cpu: u64,
    /// The maximum memory available on any one node.
    ///
    /// Task memory requirements are checked against this value.
    max_memory: u64,
    /// The task manager for the backend.
    manager: TaskManager<TesTaskRequest>,
    /// The name generator for tasks.
    ///
    /// The generator is shared with the underlying Crankshaft backend.
    names: Arc<Mutex<GeneratorIterator<UniqueAlphanumeric>>>,
}
390
391impl TesBackend {
392    /// Constructs a new TES task execution backend with the given
393    /// configuration.
394    ///
395    /// The provided configuration is expected to have already been validated.
396    pub async fn new(
397        config: Arc<Config>,
398        backend_config: &TesBackendConfig,
399        events: Option<broadcast::Sender<Event>>,
400    ) -> Result<Self> {
401        info!("initializing TES backend");
402
403        // There's no way to ask the TES service for its limits, so use the maximums
404        // allowed
405        let max_cpu = u64::MAX;
406        let max_memory = u64::MAX;
407        let manager = TaskManager::new_unlimited(max_cpu, max_memory);
408
409        let mut http = backend::tes::http::Config::default();
410        match &backend_config.auth {
411            Some(TesBackendAuthConfig::Basic(config)) => {
412                http.auth = Some(HttpAuthConfig::Basic {
413                    username: config.username.clone(),
414                    password: config.password.inner().expose_secret().to_string(),
415                });
416            }
417            Some(TesBackendAuthConfig::Bearer(config)) => {
418                http.auth = Some(HttpAuthConfig::Bearer {
419                    token: config.token.inner().expose_secret().to_string(),
420                });
421            }
422            None => {}
423        }
424
425        http.retries = backend_config.retries;
426        http.max_concurrency = backend_config.max_concurrency.map(|c| c as usize);
427
428        let names = Arc::new(Mutex::new(GeneratorIterator::new(
429            UniqueAlphanumeric::default_with_expected_generations(INITIAL_EXPECTED_NAMES),
430            INITIAL_EXPECTED_NAMES,
431        )));
432
433        let backend = tes::Backend::initialize(
434            backend::tes::Config::builder()
435                .url(backend_config.url.clone().expect("should have URL"))
436                .http(http)
437                .interval(backend_config.interval.unwrap_or(DEFAULT_TES_INTERVAL))
438                .build(),
439            names.clone(),
440            events,
441        )
442        .await;
443
444        Ok(Self {
445            config,
446            backend_config: Arc::new(backend_config.clone()),
447            inner: Arc::new(backend),
448            max_cpu,
449            max_memory,
450            manager,
451            names,
452        })
453    }
454}
455
456impl TaskExecutionBackend for TesBackend {
457    fn max_concurrency(&self) -> u64 {
458        // The TES backend doesn't limit the number of tasks that can be queued at a
459        // time
460        u64::MAX
461    }
462
463    fn constraints(
464        &self,
465        requirements: &HashMap<String, Value>,
466        hints: &HashMap<String, Value>,
467    ) -> Result<TaskExecutionConstraints> {
468        let container = container(requirements, self.config.task.container.as_deref());
469
470        let cpu = cpu(requirements);
471        if (self.max_cpu as f64) < cpu {
472            bail!(
473                "task requires at least {cpu} CPU{s}, but the execution backend has a maximum of \
474                 {max_cpu}",
475                s = if cpu == 1.0 { "" } else { "s" },
476                max_cpu = self.max_cpu,
477            );
478        }
479
480        let memory = memory(requirements)?;
481        if self.max_memory < memory as u64 {
482            // Display the error in GiB, as it is the most common unit for memory
483            let memory = memory as f64 / ONE_GIBIBYTE;
484            let max_memory = self.max_memory as f64 / ONE_GIBIBYTE;
485
486            bail!(
487                "task requires at least {memory} GiB of memory, but the execution backend has a \
488                 maximum of {max_memory} GiB",
489            );
490        }
491
492        // TODO: only parse the disks requirement once
493        let disks = disks(requirements, hints)?
494            .into_iter()
495            .map(|(mp, disk)| (mp.to_string(), disk.size))
496            .collect();
497
498        Ok(TaskExecutionConstraints {
499            container: Some(container.into_owned()),
500            cpu,
501            memory,
502            gpu: Default::default(),
503            fpga: Default::default(),
504            disks,
505        })
506    }
507
508    fn guest_inputs_dir(&self) -> Option<&'static str> {
509        Some(GUEST_INPUTS_DIR)
510    }
511
512    fn needs_local_inputs(&self) -> bool {
513        false
514    }
515
516    fn spawn(
517        &self,
518        request: TaskSpawnRequest,
519        token: CancellationToken,
520    ) -> Result<Receiver<Result<TaskExecutionResult>>> {
521        let (completed_tx, completed_rx) = oneshot::channel();
522
523        let requirements = request.requirements();
524        let hints = request.hints();
525
526        let container = container(requirements, self.config.task.container.as_deref()).into_owned();
527        let cpu = cpu(requirements);
528        let memory = memory(requirements)? as u64;
529        let max_cpu = max_cpu(hints);
530        let max_memory = max_memory(hints)?.map(|i| i as u64);
531        let preemptible = preemptible(hints);
532
533        let name = format!(
534            "{id}-{generated}",
535            id = request.id(),
536            generated = self
537                .names
538                .lock()
539                .expect("generator should always acquire")
540                .next()
541                .expect("generator should never be exhausted")
542        );
543        self.manager.send(
544            TesTaskRequest {
545                config: self.config.clone(),
546                backend_config: self.backend_config.clone(),
547                inner: request,
548                backend: self.inner.clone(),
549                name,
550                container,
551                cpu,
552                memory,
553                max_cpu,
554                max_memory,
555                token,
556                preemptible,
557            },
558            completed_tx,
559        );
560
561        Ok(completed_rx)
562    }
563}