Skip to main content

roder_ext_task_process/
task.rs

1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::sync::Arc;
5
6use anyhow::{Context, bail};
7use roder_api::processes::{
8    ProcessDescriptor, ProcessOrigin, ProcessState, ProcessStopper, command_summary,
9};
10use roder_api::remote_runner::RunnerCommandRequest;
11use roder_api::tasks::{
12    TaskExecutionContext, TaskExecutionResult, TaskExecutor, TaskOutputStream, TaskSpec,
13};
14use serde::Deserialize;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::{Mutex, oneshot};
18
/// Identifier of the process task executor; also used as the `kind` in its [`TaskSpec`].
pub const PROCESS_TASK_EXECUTOR_ID: &str = "process";
20
21#[derive(Debug, Clone, Deserialize)]
22struct ProcessTaskInput {
23    command: String,
24    #[serde(default)]
25    args: Vec<String>,
26    #[serde(default)]
27    cwd: Option<String>,
28    #[serde(default)]
29    env_overrides: BTreeMap<String, String>,
30}
31
/// Task executor that runs a command as a background process, either locally inside the
/// workspace or through an attached remote runner session.
///
/// The type carries no state, so it is `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProcessTaskExecutor;
34
35#[async_trait::async_trait]
36impl TaskExecutor for ProcessTaskExecutor {
37    fn id(&self) -> String {
38        PROCESS_TASK_EXECUTOR_ID.to_string()
39    }
40
41    fn spec(&self) -> TaskSpec {
42        TaskSpec {
43            kind: PROCESS_TASK_EXECUTOR_ID.to_string(),
44            description: "Run a background process inside the workspace.".to_string(),
45            input_schema: serde_json::json!({
46                "type": "object",
47                "required": ["command"],
48                "properties": {
49                    "command": { "type": "string" },
50                    "args": { "type": "array", "items": { "type": "string" } },
51                    "cwd": { "type": "string" },
52                    "env_overrides": {
53                        "type": "object",
54                        "additionalProperties": { "type": "string" }
55                    }
56                },
57                "additionalProperties": false
58            }),
59            default_timeout_seconds: None,
60            metadata: serde_json::json!({ "category": "process" }),
61        }
62    }
63
64    async fn execute(
65        &self,
66        ctx: TaskExecutionContext,
67        input: serde_json::Value,
68    ) -> anyhow::Result<TaskExecutionResult> {
69        let input: ProcessTaskInput =
70            serde_json::from_value(input).context("deserialize process task input")?;
71        if input.command.trim().is_empty() {
72            bail!("process task command must not be empty");
73        }
74
75        if ctx.runner_session.is_some() {
76            return execute_remote_process_task(ctx, input).await;
77        }
78
79        let cwd = resolve_cwd(ctx.workspace_root.as_deref(), input.cwd.as_deref())?;
80        let command_parts = std::iter::once(input.command.clone())
81            .chain(input.args.clone())
82            .collect::<Vec<_>>();
83        let mut command = Command::new(&input.command);
84        command
85            .args(&input.args)
86            .current_dir(&cwd)
87            .envs(&input.env_overrides)
88            .stdout(Stdio::piped())
89            .stderr(Stdio::piped())
90            .kill_on_drop(true);
91
92        let mut child = command
93            .spawn()
94            .with_context(|| format!("spawn process task {:?}", input.command))?;
95        let pid = child.id();
96        let stdout = child.stdout.take();
97        let stderr = child.stderr.take();
98        let output = Arc::new(ctx.output);
99        let process_id = format!("task-{}", ctx.task_id);
100        let (stop_tx, stop_rx) = oneshot::channel();
101        if let Some(registry) = ctx.process_registry.as_ref() {
102            registry
103                .register_process(
104                    ProcessDescriptor {
105                        process_id: process_id.clone(),
106                        origin: ProcessOrigin::BackgroundTask,
107                        state: ProcessState::Running,
108                        command: command_parts.clone(),
109                        command_summary: command_summary(&command_parts),
110                        cwd: Some(cwd.display().to_string()),
111                        pid,
112                        task_id: Some(ctx.task_id.clone()),
113                        thread_id: ctx.thread_id.clone(),
114                        turn_id: ctx.turn_id.clone(),
115                        runner_destination_id: None,
116                        runner_session_id: None,
117                        stoppable: true,
118                        started_at: time::OffsetDateTime::now_utc(),
119                        updated_at: time::OffsetDateTime::now_utc(),
120                        stdout_tail: None,
121                        stderr_tail: None,
122                    },
123                    Some(Arc::new(ChannelProcessStopper::new(stop_tx))),
124                )
125                .await?;
126        }
127
128        let stdout_task = tokio::spawn(stream_pipe(
129            stdout,
130            TaskOutputStream::Stdout,
131            Arc::clone(&output),
132        ));
133        let stderr_task = tokio::spawn(stream_pipe(
134            stderr,
135            TaskOutputStream::Stderr,
136            Arc::clone(&output),
137        ));
138        let (status, stopped_by_registry) = tokio::select! {
139            status = child.wait() => (status.context("wait for process task")?, false),
140            _ = stop_rx => {
141                child.kill().await.context("kill stopped process task")?;
142                if let Some(registry) = ctx.process_registry.as_ref() {
143                    registry
144                        .mark_process_stopped(&process_id, Some("stop requested".to_string()))
145                        .await?;
146                }
147                (child.wait().await.context("wait for stopped process task")?, true)
148            }
149        };
150        stdout_task.await.context("join stdout reader")??;
151        stderr_task.await.context("join stderr reader")??;
152        if let Some(registry) = ctx.process_registry.as_ref()
153            && !stopped_by_registry
154        {
155            let _ = registry
156                .mark_process_exited(&process_id, status.code())
157                .await;
158        }
159
160        Ok(TaskExecutionResult {
161            exit_code: status.code(),
162            payload: serde_json::json!({
163                "command": input.command,
164                "args": input.args,
165                "cwd": cwd.display().to_string(),
166                "success": status.success(),
167            }),
168        })
169    }
170}
171
172struct ChannelProcessStopper {
173    stop_tx: Mutex<Option<oneshot::Sender<Option<String>>>>,
174}
175
176impl ChannelProcessStopper {
177    fn new(stop_tx: oneshot::Sender<Option<String>>) -> Self {
178        Self {
179            stop_tx: Mutex::new(Some(stop_tx)),
180        }
181    }
182}
183
184#[async_trait::async_trait]
185impl ProcessStopper for ChannelProcessStopper {
186    async fn stop(&self, reason: Option<String>) -> anyhow::Result<()> {
187        if let Some(stop_tx) = self.stop_tx.lock().await.take() {
188            let _ = stop_tx.send(reason);
189        }
190        Ok(())
191    }
192}
193
194async fn execute_remote_process_task(
195    ctx: TaskExecutionContext,
196    input: ProcessTaskInput,
197) -> anyhow::Result<TaskExecutionResult> {
198    let Some(session) = ctx.runner_session.clone() else {
199        bail!("remote process task requires runner session");
200    };
201    let command_id = ctx.task_id.clone();
202    let command_parts = std::iter::once(input.command.clone())
203        .chain(input.args.clone())
204        .collect::<Vec<_>>();
205    let state = session.state();
206    let process_id = format!("remote-{}", ctx.task_id);
207    if let Some(registry) = ctx.process_registry.as_ref() {
208        registry
209            .register_process(
210                ProcessDescriptor {
211                    process_id: process_id.clone(),
212                    origin: ProcessOrigin::RemoteRunner,
213                    state: ProcessState::Running,
214                    command: command_parts.clone(),
215                    command_summary: command_summary(&command_parts),
216                    cwd: input.cwd.clone(),
217                    pid: None,
218                    task_id: Some(ctx.task_id.clone()),
219                    thread_id: ctx.thread_id.clone(),
220                    turn_id: ctx.turn_id.clone(),
221                    runner_destination_id: ctx
222                        .runner_destination
223                        .as_ref()
224                        .map(|destination| destination.id.clone())
225                        .or_else(|| Some(state.destination_id.clone())),
226                    runner_session_id: Some(state.session_id.clone()),
227                    stoppable: true,
228                    started_at: time::OffsetDateTime::now_utc(),
229                    updated_at: time::OffsetDateTime::now_utc(),
230                    stdout_tail: None,
231                    stderr_tail: None,
232                },
233                Some(Arc::new(RemoteCommandStopper {
234                    session: Arc::clone(&session),
235                    command_id: command_id.clone(),
236                })),
237            )
238            .await?;
239    }
240    let output = match session
241        .run_command(RunnerCommandRequest {
242            command_id: command_id.clone(),
243            program: input.command.clone(),
244            args: input.args.clone(),
245            cwd: input.cwd.as_deref().map(PathBuf::from),
246            env: input.env_overrides.clone().into_iter().collect(),
247        })
248        .await
249    {
250        Ok(output) => output,
251        Err(error) => {
252            if let Some(registry) = ctx.process_registry.as_ref() {
253                let _ = registry
254                    .mark_process_failed(&process_id, error.to_string())
255                    .await;
256            }
257            return Err(error);
258        }
259    };
260    if !output.stdout.is_empty() {
261        ctx.output
262            .write(TaskOutputStream::Stdout, output.stdout.clone())
263            .await?;
264    }
265    if !output.stderr.is_empty() {
266        ctx.output
267            .write(TaskOutputStream::Stderr, output.stderr.clone())
268            .await?;
269    }
270    if let Some(registry) = ctx.process_registry.as_ref() {
271        let _ = registry
272            .mark_process_exited(&process_id, output.exit_code)
273            .await;
274    }
275    Ok(TaskExecutionResult {
276        exit_code: output.exit_code,
277        payload: serde_json::json!({
278            "command": input.command,
279            "args": input.args,
280            "cwd": input.cwd.unwrap_or_else(|| ".".to_string()),
281            "runner_destination": ctx.runner_destination.as_ref().map(|destination| &destination.id),
282            "runner_session": session.state().session_id,
283            "success": output.exit_code == Some(0),
284        }),
285    })
286}
287
288struct RemoteCommandStopper {
289    session: Arc<dyn roder_api::remote_runner::RemoteRunnerSession>,
290    command_id: String,
291}
292
293#[async_trait::async_trait]
294impl ProcessStopper for RemoteCommandStopper {
295    async fn stop(&self, _reason: Option<String>) -> anyhow::Result<()> {
296        let cancelled = self.session.cancel_command(&self.command_id).await?;
297        if cancelled {
298            Ok(())
299        } else {
300            bail!("remote runner did not cancel command {:?}", self.command_id)
301        }
302    }
303}
304
305async fn stream_pipe(
306    pipe: Option<impl tokio::io::AsyncRead + Unpin>,
307    stream: TaskOutputStream,
308    output: Arc<roder_api::tasks::TaskOutputSink>,
309) -> anyhow::Result<()> {
310    let Some(pipe) = pipe else {
311        return Ok(());
312    };
313    let mut reader = BufReader::new(pipe);
314    let mut buf = Vec::new();
315    loop {
316        buf.clear();
317        let bytes = reader.read_until(b'\n', &mut buf).await?;
318        if bytes == 0 {
319            break;
320        }
321        output
322            .write(stream.clone(), String::from_utf8_lossy(&buf).to_string())
323            .await?;
324    }
325    Ok(())
326}
327
328fn resolve_cwd(workspace_root: Option<&str>, cwd: Option<&str>) -> anyhow::Result<PathBuf> {
329    let Some(root) = workspace_root else {
330        return match cwd {
331            Some(cwd) => Ok(PathBuf::from(cwd)),
332            None => std::env::current_dir().context("resolve current directory"),
333        };
334    };
335    let root = std::fs::canonicalize(root).with_context(|| format!("canonicalize root {root}"))?;
336    let candidate = match cwd {
337        Some(cwd) => {
338            let path = Path::new(cwd);
339            if path.is_absolute() {
340                path.to_path_buf()
341            } else {
342                root.join(path)
343            }
344        }
345        None => root.clone(),
346    };
347    let candidate = std::fs::canonicalize(&candidate)
348        .with_context(|| format!("canonicalize cwd {}", candidate.display()))?;
349    if !candidate.starts_with(&root) {
350        bail!(
351            "process task cwd {} escapes workspace root {}",
352            candidate.display(),
353            root.display()
354        );
355    }
356    Ok(candidate)
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty directory under the system temp dir, unique to this test
    /// process and `name`.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "roder-process-task-{name}-{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn resolve_cwd_rejects_paths_outside_workspace() {
        // Nest the workspace inside a scratch dir so there is always a parent to escape to,
        // even when the test runs from the filesystem root.
        let scratch = scratch_dir("outside");
        let root = scratch.join("workspace");
        std::fs::create_dir_all(&root).unwrap();
        let root_str = root.to_str().unwrap();

        let absolute = resolve_cwd(Some(root_str), Some(scratch.to_str().unwrap())).unwrap_err();
        assert!(absolute.to_string().contains("escapes workspace root"));

        let relative = resolve_cwd(Some(root_str), Some("..")).unwrap_err();
        assert!(relative.to_string().contains("escapes workspace root"));

        let _ = std::fs::remove_dir_all(&scratch);
    }

    #[test]
    fn resolve_cwd_accepts_paths_inside_workspace() {
        let root = scratch_dir("inside");
        std::fs::create_dir_all(root.join("nested")).unwrap();
        let canonical_root = std::fs::canonicalize(&root).unwrap();
        let root_str = root.to_str().unwrap();

        assert_eq!(resolve_cwd(Some(root_str), None).unwrap(), canonical_root);
        assert_eq!(
            resolve_cwd(Some(root_str), Some("nested")).unwrap(),
            canonical_root.join("nested")
        );
        // Without a workspace root the requested cwd is passed through untouched.
        assert_eq!(
            resolve_cwd(None, Some("relative/dir")).unwrap(),
            PathBuf::from("relative/dir")
        );

        let _ = std::fs::remove_dir_all(&root);
    }

    #[test]
    fn schema_snapshot_covers_process_task_input() {
        let executor = ProcessTaskExecutor;
        let spec = executor
            .spec()
            .normalized_for_model(roder_api::ToolSchemaPolicy::strict());
        let schema = serde_json::to_string(&spec.input_schema).unwrap();

        assert!(schema.starts_with(r#"{"type":"object","required":["command"],"properties":"#));
        assert!(schema.contains(
            r#""env_overrides":{"type":"object","additionalProperties":{"type":"string"}}"#
        ));
        assert!(schema.contains(r#""additionalProperties":false"#));
    }
}