wdl_engine/backend/local.rs

1//! Implementation of the local backend.
2
3use std::collections::HashMap;
4use std::ffi::OsStr;
5use std::fs;
6use std::fs::File;
7use std::path::Path;
8use std::process::Stdio;
9use std::sync::Arc;
10
11use anyhow::Context;
12use anyhow::Result;
13use anyhow::bail;
14use futures::FutureExt;
15use futures::future::BoxFuture;
16use tokio::process::Command;
17use tokio::select;
18use tokio::sync::oneshot;
19use tokio::task::JoinSet;
20use tokio_util::sync::CancellationToken;
21use tracing::info;
22use tracing::warn;
23
24use super::TaskExecutionBackend;
25use super::TaskExecutionConstraints;
26use super::TaskExecutionEvents;
27use super::TaskManager;
28use super::TaskManagerRequest;
29use super::TaskSpawnRequest;
30use crate::COMMAND_FILE_NAME;
31use crate::Input;
32use crate::ONE_GIBIBYTE;
33use crate::PrimitiveValue;
34use crate::STDERR_FILE_NAME;
35use crate::STDOUT_FILE_NAME;
36use crate::SYSTEM;
37use crate::TaskExecutionResult;
38use crate::Value;
39use crate::WORK_DIR_NAME;
40use crate::config::Config;
41use crate::config::DEFAULT_TASK_SHELL;
42use crate::config::LocalBackendConfig;
43use crate::config::TaskResourceLimitBehavior;
44use crate::convert_unit_string;
45use crate::http::Downloader;
46use crate::http::HttpDownloader;
47use crate::http::Location;
48use crate::path::EvaluationPath;
49use crate::v1::cpu;
50use crate::v1::memory;
51
52/// Represents a local task request.
53///
54/// This request contains the requested cpu and memory reservations for the task
55/// as well as the result receiver channel.
56#[derive(Debug)]
57struct LocalTaskRequest {
58    /// The engine configuration.
59    config: Arc<Config>,
60    /// The inner task spawn request.
61    inner: TaskSpawnRequest,
62    /// The requested CPU reservation for the task.
63    ///
64    /// Note that CPU isn't actually reserved for the task process.
65    cpu: f64,
66    /// The requested memory reservation for the task.
67    ///
68    /// Note that memory isn't actually reserved for the task process.
69    memory: u64,
70    /// The cancellation token for the request.
71    token: CancellationToken,
72}
73
74impl TaskManagerRequest for LocalTaskRequest {
75    fn cpu(&self) -> f64 {
76        self.cpu
77    }
78
79    fn memory(&self) -> u64 {
80        self.memory
81    }
82
83    async fn run(self, spawned: oneshot::Sender<()>) -> Result<TaskExecutionResult> {
84        // Create the working directory
85        let work_dir = self.inner.attempt_dir().join(WORK_DIR_NAME);
86        fs::create_dir_all(&work_dir).with_context(|| {
87            format!(
88                "failed to create directory `{path}`",
89                path = work_dir.display()
90            )
91        })?;
92
93        // Write the evaluated command to disk
94        let command_path = self.inner.attempt_dir().join(COMMAND_FILE_NAME);
95        fs::write(&command_path, self.inner.command()).with_context(|| {
96            format!(
97                "failed to write command contents to `{path}`",
98                path = command_path.display()
99            )
100        })?;
101
102        // Create a file for the stdout
103        let stdout_path = self.inner.attempt_dir().join(STDOUT_FILE_NAME);
104        let stdout = File::create(&stdout_path).with_context(|| {
105            format!(
106                "failed to create stdout file `{path}`",
107                path = stdout_path.display()
108            )
109        })?;
110
111        // Create a file for the stderr
112        let stderr_path = self.inner.attempt_dir().join(STDERR_FILE_NAME);
113        let stderr = File::create(&stderr_path).with_context(|| {
114            format!(
115                "failed to create stderr file `{path}`",
116                path = stderr_path.display()
117            )
118        })?;
119
120        let mut command = Command::new(
121            self.config
122                .task
123                .shell
124                .as_deref()
125                .unwrap_or(DEFAULT_TASK_SHELL),
126        );
127        command
128            .current_dir(&work_dir)
129            .arg(command_path)
130            .stdin(Stdio::null())
131            .stdout(stdout)
132            .stderr(stderr)
133            .envs(
134                self.inner
135                    .env()
136                    .iter()
137                    .map(|(k, v)| (OsStr::new(k), OsStr::new(v))),
138            )
139            .kill_on_drop(true);
140
141        // Set the PATH variable for the child on Windows to get consistent PATH
142        // searching. See: https://github.com/rust-lang/rust/issues/122660
143        #[cfg(windows)]
144        if let Ok(path) = std::env::var("PATH") {
145            command.env("PATH", path);
146        }
147
148        let mut child = command.spawn().context("failed to spawn shell")?;
149
150        // Notify that the process has spawned
151        spawned.send(()).ok();
152
153        let id = child.id().expect("should have id");
154        info!("spawned local shell process {id} for task execution");
155
156        select! {
157            // Poll the cancellation token before the child future
158            biased;
159
160            _ = self.token.cancelled() => {
161                bail!("task was cancelled");
162            }
163            status = child.wait() => {
164                let status = status.with_context(|| {
165                    format!("failed to wait for termination of task child process {id}")
166                })?;
167
168                #[cfg(unix)]
169                {
170                    use std::os::unix::process::ExitStatusExt;
171                    if let Some(signal) = status.signal() {
172                        tracing::warn!("task process {id} has terminated with signal {signal}");
173
174                        bail!(
175                            "task child process {id} has terminated with signal {signal}; see stderr file \
176                            `{path}` for more details",
177                            path = stderr_path.display()
178                        );
179                    }
180                }
181
182                let exit_code = status.code().expect("process should have exited");
183                info!("task process {id} has terminated with status code {exit_code}");
184                Ok(TaskExecutionResult {
185                    inputs: self.inner.info.inputs,
186                    exit_code,
187                    work_dir: EvaluationPath::Local(work_dir),
188                    stdout: PrimitiveValue::new_file(stdout_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
189                    stderr: PrimitiveValue::new_file(stderr_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
190                })
191            }
192        }
193    }
194}
195
196/// Represents a task execution backend that locally executes tasks.
197///
198/// <div class="warning">
199/// Warning: the local task execution backend spawns processes on the host
200/// directly without the use of a container; only use this backend on trusted
201/// WDL. </div>
202pub struct LocalBackend {
203    /// The engine configuration.
204    config: Arc<Config>,
205    /// The total CPU of the host.
206    cpu: u64,
207    /// The total memory of the host.
208    memory: u64,
209    /// The underlying task manager.
210    manager: TaskManager<LocalTaskRequest>,
211}
212
213impl LocalBackend {
214    /// Constructs a new local task execution backend with the given
215    /// configuration.
216    ///
217    /// The provided configuration is expected to have already been validated.
218    pub fn new(config: Arc<Config>, backend_config: &LocalBackendConfig) -> Result<Self> {
219        info!("initializing local backend");
220
221        let cpu = backend_config
222            .cpu
223            .unwrap_or_else(|| SYSTEM.cpus().len() as u64);
224        let memory = backend_config
225            .memory
226            .as_ref()
227            .map(|s| convert_unit_string(s).expect("value should be valid"))
228            .unwrap_or_else(|| SYSTEM.total_memory());
229        let manager = TaskManager::new(cpu, cpu, memory, memory);
230
231        Ok(Self {
232            config,
233            cpu,
234            memory,
235            manager,
236        })
237    }
238}
239
240impl TaskExecutionBackend for LocalBackend {
241    fn max_concurrency(&self) -> u64 {
242        self.cpu
243    }
244
245    fn constraints(
246        &self,
247        requirements: &HashMap<String, Value>,
248        _: &HashMap<String, Value>,
249    ) -> Result<TaskExecutionConstraints> {
250        let mut cpu = cpu(requirements);
251        if (self.cpu as f64) < cpu {
252            let env_specific = if self.config.suppress_env_specific_output {
253                String::new()
254            } else {
255                format!(
256                    ", but the host only has {total_cpu} available",
257                    total_cpu = self.cpu
258                )
259            };
260            match self.config.task.cpu_limit_behavior {
261                TaskResourceLimitBehavior::TryWithMax => {
262                    warn!(
263                        "task requires at least {cpu} CPU{s}{env_specific}",
264                        s = if cpu == 1.0 { "" } else { "s" },
265                    );
266                    // clamp the reported constraint to what's available
267                    cpu = self.cpu as f64;
268                }
269                TaskResourceLimitBehavior::Deny => {
270                    bail!(
271                        "task requires at least {cpu} CPU{s}{env_specific}",
272                        s = if cpu == 1.0 { "" } else { "s" },
273                    );
274                }
275            }
276        }
277
278        let mut memory = memory(requirements)?;
279        if self.memory < memory as u64 {
280            let env_specific = if self.config.suppress_env_specific_output {
281                String::new()
282            } else {
283                format!(
284                    ", but the host only has {total_memory} GiB available",
285                    total_memory = self.memory as f64 / ONE_GIBIBYTE,
286                )
287            };
288            match self.config.task.memory_limit_behavior {
289                TaskResourceLimitBehavior::TryWithMax => {
290                    warn!(
291                        "task requires at least {memory} GiB of memory{env_specific}",
292                        // Display the error in GiB, as it is the most common unit for memory
293                        memory = memory as f64 / ONE_GIBIBYTE,
294                    );
295                    // clamp the reported constraint to what's available
296                    memory = self.memory.try_into().unwrap_or(i64::MAX);
297                }
298                TaskResourceLimitBehavior::Deny => {
299                    bail!(
300                        "task requires at least {memory} GiB of memory{env_specific}",
301                        // Display the error in GiB, as it is the most common unit for memory
302                        memory = memory as f64 / ONE_GIBIBYTE,
303                    );
304                }
305            }
306        }
307
308        Ok(TaskExecutionConstraints {
309            container: None,
310            cpu,
311            memory,
312            gpu: Default::default(),
313            fpga: Default::default(),
314            disks: Default::default(),
315        })
316    }
317
318    fn guest_work_dir(&self) -> Option<&Path> {
319        // Local execution does not use a container
320        None
321    }
322
323    fn localize_inputs<'a, 'b, 'c, 'd>(
324        &'a self,
325        downloader: &'b HttpDownloader,
326        inputs: &'c mut [Input],
327    ) -> BoxFuture<'d, Result<()>>
328    where
329        'a: 'd,
330        'b: 'd,
331        'c: 'd,
332        Self: 'd,
333    {
334        async move {
335            let mut downloads = JoinSet::new();
336
337            for (idx, input) in inputs.iter_mut().enumerate() {
338                match input.path() {
339                    EvaluationPath::Local(path) => {
340                        let location = Location::Path(path.clone().into());
341                        let guest_path = location
342                            .to_str()
343                            .with_context(|| {
344                                format!("path `{path}` is not UTF-8", path = path.display())
345                            })?
346                            .to_string();
347                        input.set_location(location.into_owned());
348                        input.set_guest_path(guest_path);
349                    }
350                    EvaluationPath::Remote(url) => {
351                        let downloader = downloader.clone();
352                        let url = url.clone();
353                        downloads.spawn(async move {
354                            let location_result = downloader.download(&url).await;
355
356                            match location_result {
357                                Ok(location) => Ok((idx, location.into_owned())),
358                                Err(e) => bail!("failed to localize `{url}`: {e:?}"),
359                            }
360                        });
361                    }
362                }
363            }
364
365            while let Some(result) = downloads.join_next().await {
366                match result {
367                    Ok(Ok((idx, location))) => {
368                        let guest_path = location
369                            .to_str()
370                            .with_context(|| {
371                                format!(
372                                    "downloaded path `{path}` is not UTF-8",
373                                    path = location.display()
374                                )
375                            })?
376                            .to_string();
377
378                        let input = inputs.get_mut(idx).expect("index should be valid");
379                        input.set_location(location);
380                        input.set_guest_path(guest_path);
381                    }
382                    Ok(Err(e)) => {
383                        // Futures are aborted when the `JoinSet` is dropped.
384                        bail!(e);
385                    }
386                    Err(e) => {
387                        // Futures are aborted when the `JoinSet` is dropped.
388                        bail!("download task failed: {e}");
389                    }
390                }
391            }
392
393            Ok(())
394        }
395        .boxed()
396    }
397
398    fn spawn(
399        &self,
400        request: TaskSpawnRequest,
401        token: CancellationToken,
402    ) -> Result<TaskExecutionEvents> {
403        let (spawned_tx, spawned_rx) = oneshot::channel();
404        let (completed_tx, completed_rx) = oneshot::channel();
405
406        let requirements = request.requirements();
407        let mut cpu = cpu(requirements);
408        if let TaskResourceLimitBehavior::TryWithMax = self.config.task.cpu_limit_behavior {
409            cpu = std::cmp::min(cpu.ceil() as u64, self.cpu) as f64;
410        }
411        let mut memory = memory(requirements)? as u64;
412        if let TaskResourceLimitBehavior::TryWithMax = self.config.task.memory_limit_behavior {
413            memory = std::cmp::min(memory, self.memory);
414        }
415
416        self.manager.send(
417            LocalTaskRequest {
418                config: self.config.clone(),
419                inner: request,
420                cpu,
421                memory,
422                token,
423            },
424            spawned_tx,
425            completed_tx,
426        );
427
428        Ok(TaskExecutionEvents {
429            spawned: spawned_rx,
430            completed: completed_rx,
431        })
432    }
433}