1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::sync::Arc;
5
6use anyhow::{Context, bail};
7use roder_api::processes::{
8 ProcessDescriptor, ProcessOrigin, ProcessState, ProcessStopper, command_summary,
9};
10use roder_api::remote_runner::RunnerCommandRequest;
11use roder_api::tasks::{
12 TaskExecutionContext, TaskExecutionResult, TaskExecutor, TaskOutputStream, TaskSpec,
13};
14use serde::Deserialize;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::{Mutex, oneshot};
18
/// Identifier of the process task executor.
///
/// Returned by `ProcessTaskExecutor::id` and used as the `kind` of its task spec.
pub const PROCESS_TASK_EXECUTOR_ID: &str = "process";
20
21#[derive(Debug, Clone, Deserialize)]
22struct ProcessTaskInput {
23 command: String,
24 #[serde(default)]
25 args: Vec<String>,
26 #[serde(default)]
27 cwd: Option<String>,
28 #[serde(default)]
29 env_overrides: BTreeMap<String, String>,
30}
31
/// Stateless task executor that runs a command as a background process.
#[derive(Clone, Debug)]
pub struct ProcessTaskExecutor;
34
35#[async_trait::async_trait]
36impl TaskExecutor for ProcessTaskExecutor {
37 fn id(&self) -> String {
38 PROCESS_TASK_EXECUTOR_ID.to_string()
39 }
40
41 fn spec(&self) -> TaskSpec {
42 TaskSpec {
43 kind: PROCESS_TASK_EXECUTOR_ID.to_string(),
44 description: "Run a background process inside the workspace.".to_string(),
45 input_schema: serde_json::json!({
46 "type": "object",
47 "required": ["command"],
48 "properties": {
49 "command": { "type": "string" },
50 "args": { "type": "array", "items": { "type": "string" } },
51 "cwd": { "type": "string" },
52 "env_overrides": {
53 "type": "object",
54 "additionalProperties": { "type": "string" }
55 }
56 },
57 "additionalProperties": false
58 }),
59 default_timeout_seconds: None,
60 metadata: serde_json::json!({ "category": "process" }),
61 }
62 }
63
64 async fn execute(
65 &self,
66 ctx: TaskExecutionContext,
67 input: serde_json::Value,
68 ) -> anyhow::Result<TaskExecutionResult> {
69 let input: ProcessTaskInput =
70 serde_json::from_value(input).context("deserialize process task input")?;
71 if input.command.trim().is_empty() {
72 bail!("process task command must not be empty");
73 }
74
75 if ctx.runner_session.is_some() {
76 return execute_remote_process_task(ctx, input).await;
77 }
78
79 let cwd = resolve_cwd(ctx.workspace_root.as_deref(), input.cwd.as_deref())?;
80 let command_parts = std::iter::once(input.command.clone())
81 .chain(input.args.clone())
82 .collect::<Vec<_>>();
83 let mut command = Command::new(&input.command);
84 command
85 .args(&input.args)
86 .current_dir(&cwd)
87 .envs(&input.env_overrides)
88 .stdout(Stdio::piped())
89 .stderr(Stdio::piped())
90 .kill_on_drop(true);
91
92 let mut child = command
93 .spawn()
94 .with_context(|| format!("spawn process task {:?}", input.command))?;
95 let pid = child.id();
96 let stdout = child.stdout.take();
97 let stderr = child.stderr.take();
98 let output = Arc::new(ctx.output);
99 let process_id = format!("task-{}", ctx.task_id);
100 let (stop_tx, stop_rx) = oneshot::channel();
101 if let Some(registry) = ctx.process_registry.as_ref() {
102 registry
103 .register_process(
104 ProcessDescriptor {
105 process_id: process_id.clone(),
106 origin: ProcessOrigin::BackgroundTask,
107 state: ProcessState::Running,
108 command: command_parts.clone(),
109 command_summary: command_summary(&command_parts),
110 cwd: Some(cwd.display().to_string()),
111 pid,
112 task_id: Some(ctx.task_id.clone()),
113 thread_id: ctx.thread_id.clone(),
114 turn_id: ctx.turn_id.clone(),
115 runner_destination_id: None,
116 runner_session_id: None,
117 stoppable: true,
118 started_at: time::OffsetDateTime::now_utc(),
119 updated_at: time::OffsetDateTime::now_utc(),
120 stdout_tail: None,
121 stderr_tail: None,
122 },
123 Some(Arc::new(ChannelProcessStopper::new(stop_tx))),
124 )
125 .await?;
126 }
127
128 let stdout_task = tokio::spawn(stream_pipe(
129 stdout,
130 TaskOutputStream::Stdout,
131 Arc::clone(&output),
132 ));
133 let stderr_task = tokio::spawn(stream_pipe(
134 stderr,
135 TaskOutputStream::Stderr,
136 Arc::clone(&output),
137 ));
138 let (status, stopped_by_registry) = tokio::select! {
139 status = child.wait() => (status.context("wait for process task")?, false),
140 _ = stop_rx => {
141 child.kill().await.context("kill stopped process task")?;
142 if let Some(registry) = ctx.process_registry.as_ref() {
143 registry
144 .mark_process_stopped(&process_id, Some("stop requested".to_string()))
145 .await?;
146 }
147 (child.wait().await.context("wait for stopped process task")?, true)
148 }
149 };
150 stdout_task.await.context("join stdout reader")??;
151 stderr_task.await.context("join stderr reader")??;
152 if let Some(registry) = ctx.process_registry.as_ref()
153 && !stopped_by_registry
154 {
155 let _ = registry
156 .mark_process_exited(&process_id, status.code())
157 .await;
158 }
159
160 Ok(TaskExecutionResult {
161 exit_code: status.code(),
162 payload: serde_json::json!({
163 "command": input.command,
164 "args": input.args,
165 "cwd": cwd.display().to_string(),
166 "success": status.success(),
167 }),
168 })
169 }
170}
171
172struct ChannelProcessStopper {
173 stop_tx: Mutex<Option<oneshot::Sender<Option<String>>>>,
174}
175
176impl ChannelProcessStopper {
177 fn new(stop_tx: oneshot::Sender<Option<String>>) -> Self {
178 Self {
179 stop_tx: Mutex::new(Some(stop_tx)),
180 }
181 }
182}
183
184#[async_trait::async_trait]
185impl ProcessStopper for ChannelProcessStopper {
186 async fn stop(&self, reason: Option<String>) -> anyhow::Result<()> {
187 if let Some(stop_tx) = self.stop_tx.lock().await.take() {
188 let _ = stop_tx.send(reason);
189 }
190 Ok(())
191 }
192}
193
194async fn execute_remote_process_task(
195 ctx: TaskExecutionContext,
196 input: ProcessTaskInput,
197) -> anyhow::Result<TaskExecutionResult> {
198 let Some(session) = ctx.runner_session.clone() else {
199 bail!("remote process task requires runner session");
200 };
201 let command_id = ctx.task_id.clone();
202 let command_parts = std::iter::once(input.command.clone())
203 .chain(input.args.clone())
204 .collect::<Vec<_>>();
205 let state = session.state();
206 let process_id = format!("remote-{}", ctx.task_id);
207 if let Some(registry) = ctx.process_registry.as_ref() {
208 registry
209 .register_process(
210 ProcessDescriptor {
211 process_id: process_id.clone(),
212 origin: ProcessOrigin::RemoteRunner,
213 state: ProcessState::Running,
214 command: command_parts.clone(),
215 command_summary: command_summary(&command_parts),
216 cwd: input.cwd.clone(),
217 pid: None,
218 task_id: Some(ctx.task_id.clone()),
219 thread_id: ctx.thread_id.clone(),
220 turn_id: ctx.turn_id.clone(),
221 runner_destination_id: ctx
222 .runner_destination
223 .as_ref()
224 .map(|destination| destination.id.clone())
225 .or_else(|| Some(state.destination_id.clone())),
226 runner_session_id: Some(state.session_id.clone()),
227 stoppable: true,
228 started_at: time::OffsetDateTime::now_utc(),
229 updated_at: time::OffsetDateTime::now_utc(),
230 stdout_tail: None,
231 stderr_tail: None,
232 },
233 Some(Arc::new(RemoteCommandStopper {
234 session: Arc::clone(&session),
235 command_id: command_id.clone(),
236 })),
237 )
238 .await?;
239 }
240 let output = match session
241 .run_command(RunnerCommandRequest {
242 command_id: command_id.clone(),
243 program: input.command.clone(),
244 args: input.args.clone(),
245 cwd: input.cwd.as_deref().map(PathBuf::from),
246 env: input.env_overrides.clone().into_iter().collect(),
247 })
248 .await
249 {
250 Ok(output) => output,
251 Err(error) => {
252 if let Some(registry) = ctx.process_registry.as_ref() {
253 let _ = registry
254 .mark_process_failed(&process_id, error.to_string())
255 .await;
256 }
257 return Err(error);
258 }
259 };
260 if !output.stdout.is_empty() {
261 ctx.output
262 .write(TaskOutputStream::Stdout, output.stdout.clone())
263 .await?;
264 }
265 if !output.stderr.is_empty() {
266 ctx.output
267 .write(TaskOutputStream::Stderr, output.stderr.clone())
268 .await?;
269 }
270 if let Some(registry) = ctx.process_registry.as_ref() {
271 let _ = registry
272 .mark_process_exited(&process_id, output.exit_code)
273 .await;
274 }
275 Ok(TaskExecutionResult {
276 exit_code: output.exit_code,
277 payload: serde_json::json!({
278 "command": input.command,
279 "args": input.args,
280 "cwd": input.cwd.unwrap_or_else(|| ".".to_string()),
281 "runner_destination": ctx.runner_destination.as_ref().map(|destination| &destination.id),
282 "runner_session": session.state().session_id,
283 "success": output.exit_code == Some(0),
284 }),
285 })
286}
287
288struct RemoteCommandStopper {
289 session: Arc<dyn roder_api::remote_runner::RemoteRunnerSession>,
290 command_id: String,
291}
292
293#[async_trait::async_trait]
294impl ProcessStopper for RemoteCommandStopper {
295 async fn stop(&self, _reason: Option<String>) -> anyhow::Result<()> {
296 let cancelled = self.session.cancel_command(&self.command_id).await?;
297 if cancelled {
298 Ok(())
299 } else {
300 bail!("remote runner did not cancel command {:?}", self.command_id)
301 }
302 }
303}
304
305async fn stream_pipe(
306 pipe: Option<impl tokio::io::AsyncRead + Unpin>,
307 stream: TaskOutputStream,
308 output: Arc<roder_api::tasks::TaskOutputSink>,
309) -> anyhow::Result<()> {
310 let Some(pipe) = pipe else {
311 return Ok(());
312 };
313 let mut reader = BufReader::new(pipe);
314 let mut buf = Vec::new();
315 loop {
316 buf.clear();
317 let bytes = reader.read_until(b'\n', &mut buf).await?;
318 if bytes == 0 {
319 break;
320 }
321 output
322 .write(stream.clone(), String::from_utf8_lossy(&buf).to_string())
323 .await?;
324 }
325 Ok(())
326}
327
328fn resolve_cwd(workspace_root: Option<&str>, cwd: Option<&str>) -> anyhow::Result<PathBuf> {
329 let Some(root) = workspace_root else {
330 return match cwd {
331 Some(cwd) => Ok(PathBuf::from(cwd)),
332 None => std::env::current_dir().context("resolve current directory"),
333 };
334 };
335 let root = std::fs::canonicalize(root).with_context(|| format!("canonicalize root {root}"))?;
336 let candidate = match cwd {
337 Some(cwd) => {
338 let path = Path::new(cwd);
339 if path.is_absolute() {
340 path.to_path_buf()
341 } else {
342 root.join(path)
343 }
344 }
345 None => root.clone(),
346 };
347 let candidate = std::fs::canonicalize(&candidate)
348 .with_context(|| format!("canonicalize cwd {}", candidate.display()))?;
349 if !candidate.starts_with(&root) {
350 bail!(
351 "process task cwd {} escapes workspace root {}",
352 candidate.display(),
353 root.display()
354 );
355 }
356 Ok(candidate)
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Requesting the parent of the workspace root must be rejected.
    #[test]
    fn resolve_cwd_rejects_paths_outside_workspace() {
        let workspace = std::env::current_dir().unwrap();
        let parent = workspace.parent().unwrap_or(&workspace);

        let message = resolve_cwd(
            Some(workspace.to_str().unwrap()),
            Some(parent.to_str().unwrap()),
        )
        .unwrap_err()
        .to_string();

        assert!(message.contains("escapes workspace root"));
    }

    /// Pins the serialized form of the input schema after strict normalization.
    #[test]
    fn schema_snapshot_covers_process_task_input() {
        let normalized = ProcessTaskExecutor
            .spec()
            .normalized_for_model(roder_api::ToolSchemaPolicy::strict());
        let rendered = serde_json::to_string(&normalized.input_schema).unwrap();

        assert!(rendered.starts_with(r#"{"type":"object","required":["command"],"properties":"#));
        assert!(rendered.contains(
            r#""env_overrides":{"type":"object","additionalProperties":{"type":"string"}}"#
        ));
        assert!(rendered.contains(r#""additionalProperties":false"#));
    }
}