roder-ext-task-process 0.1.1

Agentic software development tools and SDKs for Roder.
Documentation
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;

use anyhow::{Context, bail};
use roder_api::processes::{
    ProcessDescriptor, ProcessOrigin, ProcessState, ProcessStopper, command_summary,
};
use roder_api::remote_runner::RunnerCommandRequest;
use roder_api::tasks::{
    TaskExecutionContext, TaskExecutionResult, TaskExecutor, TaskOutputStream, TaskSpec,
};
use serde::Deserialize;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::{Mutex, oneshot};

/// Identifier returned by `ProcessTaskExecutor::id` and used as the `kind` of
/// the `TaskSpec` it advertises.
pub const PROCESS_TASK_EXECUTOR_ID: &str = "process";

/// Input payload of a process task, deserialized from the JSON value handed to
/// `ProcessTaskExecutor::execute`.
///
/// Only `command` is mandatory; every other field carries `#[serde(default)]`
/// and may be omitted.
// NOTE(review): the schema published by `spec` sets `additionalProperties: false`,
// but this struct does not use `deny_unknown_fields`; presumably unknown keys are
// rejected before deserialization — confirm.
#[derive(Debug, Clone, Deserialize)]
struct ProcessTaskInput {
    /// Program to run; rejected by `execute` when blank after trimming.
    command: String,
    /// Arguments passed to the program, in order.
    #[serde(default)]
    args: Vec<String>,
    /// Requested working directory; for local runs it is resolved through
    /// `resolve_cwd`, for remote runs it is forwarded to the runner as given.
    #[serde(default)]
    cwd: Option<String>,
    /// Environment variables set for the command (passed to `Command::envs`
    /// locally, or sent as `env` in the remote runner request).
    #[serde(default)]
    env_overrides: BTreeMap<String, String>,
}

/// Task executor that runs a command as a local child process, or through a
/// remote runner session when the execution context carries one.
#[derive(Debug, Clone)]
pub struct ProcessTaskExecutor;

#[async_trait::async_trait]
impl TaskExecutor for ProcessTaskExecutor {
    fn id(&self) -> String {
        PROCESS_TASK_EXECUTOR_ID.to_string()
    }

    fn spec(&self) -> TaskSpec {
        TaskSpec {
            kind: PROCESS_TASK_EXECUTOR_ID.to_string(),
            description: "Run a background process inside the workspace.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "required": ["command"],
                "properties": {
                    "command": { "type": "string" },
                    "args": { "type": "array", "items": { "type": "string" } },
                    "cwd": { "type": "string" },
                    "env_overrides": {
                        "type": "object",
                        "additionalProperties": { "type": "string" }
                    }
                },
                "additionalProperties": false
            }),
            default_timeout_seconds: None,
            metadata: serde_json::json!({ "category": "process" }),
        }
    }

    async fn execute(
        &self,
        ctx: TaskExecutionContext,
        input: serde_json::Value,
    ) -> anyhow::Result<TaskExecutionResult> {
        let input: ProcessTaskInput =
            serde_json::from_value(input).context("deserialize process task input")?;
        if input.command.trim().is_empty() {
            bail!("process task command must not be empty");
        }

        if ctx.runner_session.is_some() {
            return execute_remote_process_task(ctx, input).await;
        }

        let cwd = resolve_cwd(ctx.workspace_root.as_deref(), input.cwd.as_deref())?;
        let command_parts = std::iter::once(input.command.clone())
            .chain(input.args.clone())
            .collect::<Vec<_>>();
        let mut command = Command::new(&input.command);
        command
            .args(&input.args)
            .current_dir(&cwd)
            .envs(&input.env_overrides)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .kill_on_drop(true);

        let mut child = command
            .spawn()
            .with_context(|| format!("spawn process task {:?}", input.command))?;
        let pid = child.id();
        let stdout = child.stdout.take();
        let stderr = child.stderr.take();
        let output = Arc::new(ctx.output);
        let process_id = format!("task-{}", ctx.task_id);
        let (stop_tx, stop_rx) = oneshot::channel();
        if let Some(registry) = ctx.process_registry.as_ref() {
            registry
                .register_process(
                    ProcessDescriptor {
                        process_id: process_id.clone(),
                        origin: ProcessOrigin::BackgroundTask,
                        state: ProcessState::Running,
                        command: command_parts.clone(),
                        command_summary: command_summary(&command_parts),
                        cwd: Some(cwd.display().to_string()),
                        pid,
                        task_id: Some(ctx.task_id.clone()),
                        thread_id: ctx.thread_id.clone(),
                        turn_id: ctx.turn_id.clone(),
                        runner_destination_id: None,
                        runner_session_id: None,
                        stoppable: true,
                        started_at: time::OffsetDateTime::now_utc(),
                        updated_at: time::OffsetDateTime::now_utc(),
                        stdout_tail: None,
                        stderr_tail: None,
                    },
                    Some(Arc::new(ChannelProcessStopper::new(stop_tx))),
                )
                .await?;
        }

        let stdout_task = tokio::spawn(stream_pipe(
            stdout,
            TaskOutputStream::Stdout,
            Arc::clone(&output),
        ));
        let stderr_task = tokio::spawn(stream_pipe(
            stderr,
            TaskOutputStream::Stderr,
            Arc::clone(&output),
        ));
        let (status, stopped_by_registry) = tokio::select! {
            status = child.wait() => (status.context("wait for process task")?, false),
            _ = stop_rx => {
                child.kill().await.context("kill stopped process task")?;
                if let Some(registry) = ctx.process_registry.as_ref() {
                    registry
                        .mark_process_stopped(&process_id, Some("stop requested".to_string()))
                        .await?;
                }
                (child.wait().await.context("wait for stopped process task")?, true)
            }
        };
        stdout_task.await.context("join stdout reader")??;
        stderr_task.await.context("join stderr reader")??;
        if let Some(registry) = ctx.process_registry.as_ref()
            && !stopped_by_registry
        {
            let _ = registry
                .mark_process_exited(&process_id, status.code())
                .await;
        }

        Ok(TaskExecutionResult {
            exit_code: status.code(),
            payload: serde_json::json!({
                "command": input.command,
                "args": input.args,
                "cwd": cwd.display().to_string(),
                "success": status.success(),
            }),
        })
    }
}

/// `ProcessStopper` that requests a stop by sending on a one-shot channel.
struct ChannelProcessStopper {
    /// Sender half of the stop channel. It is wrapped in `Option` so it can be
    /// taken out exactly once, and in a `Mutex` because `stop` only has `&self`.
    stop_tx: Mutex<Option<oneshot::Sender<Option<String>>>>,
}

impl ChannelProcessStopper {
    /// Builds a stopper that will deliver its stop request through `stop_tx`.
    fn new(stop_tx: oneshot::Sender<Option<String>>) -> Self {
        let slot = Mutex::new(Some(stop_tx));
        Self { stop_tx: slot }
    }
}

#[async_trait::async_trait]
impl ProcessStopper for ChannelProcessStopper {
    /// Sends `reason` through the stop channel the first time it is called.
    ///
    /// Later calls find the sender already taken and do nothing. The call always
    /// returns `Ok(())`; the outcome of the send itself is deliberately ignored.
    async fn stop(&self, reason: Option<String>) -> anyhow::Result<()> {
        let pending = self.stop_tx.lock().await.take();
        let Some(sender) = pending else {
            return Ok(());
        };
        let _ = sender.send(reason);
        Ok(())
    }
}

/// Runs a process task through the remote runner session attached to `ctx`.
///
/// The command is registered with the process registry (when one is present)
/// under the id `remote-<task_id>`, executed with `session.run_command`, and its
/// captured stdout/stderr are written to `ctx.output` once the command returns.
///
/// The returned payload contains `command`, `args`, `cwd` (defaulting to `"."`),
/// the runner destination and session ids, and a `success` flag that is true
/// only when the exit code is `Some(0)`.
///
/// # Errors
///
/// Fails when `ctx` carries no runner session, when registration fails, when the
/// remote command fails (the registry entry is then marked failed), or when the
/// captured output cannot be written to `ctx.output`.
async fn execute_remote_process_task(
    ctx: TaskExecutionContext,
    input: ProcessTaskInput,
) -> anyhow::Result<TaskExecutionResult> {
    let Some(session) = ctx.runner_session.clone() else {
        bail!("remote process task requires runner session");
    };
    let command_id = ctx.task_id.clone();
    let command_parts = std::iter::once(input.command.clone())
        .chain(input.args.clone())
        .collect::<Vec<_>>();
    let state = session.state();
    let process_id = format!("remote-{}", ctx.task_id);
    if let Some(registry) = ctx.process_registry.as_ref() {
        // One timestamp for both fields so a fresh entry starts out consistent.
        let registered_at = time::OffsetDateTime::now_utc();
        registry
            .register_process(
                ProcessDescriptor {
                    process_id: process_id.clone(),
                    origin: ProcessOrigin::RemoteRunner,
                    state: ProcessState::Running,
                    command: command_parts.clone(),
                    command_summary: command_summary(&command_parts),
                    cwd: input.cwd.clone(),
                    pid: None,
                    task_id: Some(ctx.task_id.clone()),
                    thread_id: ctx.thread_id.clone(),
                    turn_id: ctx.turn_id.clone(),
                    // Prefer the destination attached to the context; fall back to
                    // the one reported by the session state.
                    runner_destination_id: ctx
                        .runner_destination
                        .as_ref()
                        .map(|destination| destination.id.clone())
                        .or_else(|| Some(state.destination_id.clone())),
                    runner_session_id: Some(state.session_id.clone()),
                    stoppable: true,
                    started_at: registered_at,
                    updated_at: registered_at,
                    stdout_tail: None,
                    stderr_tail: None,
                },
                Some(Arc::new(RemoteCommandStopper {
                    session: Arc::clone(&session),
                    command_id: command_id.clone(),
                })),
            )
            .await?;
    }
    let output = match session
        .run_command(RunnerCommandRequest {
            command_id: command_id.clone(),
            program: input.command.clone(),
            args: input.args.clone(),
            cwd: input.cwd.as_deref().map(PathBuf::from),
            env: input.env_overrides.clone().into_iter().collect(),
        })
        .await
    {
        Ok(output) => output,
        Err(error) => {
            if let Some(registry) = ctx.process_registry.as_ref() {
                let _ = registry
                    .mark_process_failed(&process_id, error.to_string())
                    .await;
            }
            return Err(error);
        }
    };
    // The command has finished by the time `run_command` returns. Record the
    // exit before forwarding output, so a failing output sink cannot leave the
    // registry entry stuck in the `Running` state.
    if let Some(registry) = ctx.process_registry.as_ref() {
        let _ = registry
            .mark_process_exited(&process_id, output.exit_code)
            .await;
    }
    if !output.stdout.is_empty() {
        ctx.output
            .write(TaskOutputStream::Stdout, output.stdout.clone())
            .await?;
    }
    if !output.stderr.is_empty() {
        ctx.output
            .write(TaskOutputStream::Stderr, output.stderr.clone())
            .await?;
    }
    Ok(TaskExecutionResult {
        exit_code: output.exit_code,
        payload: serde_json::json!({
            "command": input.command,
            "args": input.args,
            "cwd": input.cwd.unwrap_or_else(|| ".".to_string()),
            "runner_destination": ctx.runner_destination.as_ref().map(|destination| &destination.id),
            "runner_session": session.state().session_id,
            "success": output.exit_code == Some(0),
        }),
    })
}

/// `ProcessStopper` that cancels a command running on a remote runner session.
struct RemoteCommandStopper {
    /// Session on which the command was started.
    session: Arc<dyn roder_api::remote_runner::RemoteRunnerSession>,
    /// Id of the command to cancel; `execute_remote_process_task` uses the task id.
    command_id: String,
}

#[async_trait::async_trait]
impl ProcessStopper for RemoteCommandStopper {
    /// Asks the session to cancel the command; the stop reason is not forwarded.
    ///
    /// # Errors
    ///
    /// Propagates any error from `cancel_command`, and fails when the session
    /// reports that the command was not cancelled.
    async fn stop(&self, _reason: Option<String>) -> anyhow::Result<()> {
        if !self.session.cancel_command(&self.command_id).await? {
            bail!("remote runner did not cancel command {:?}", self.command_id);
        }
        Ok(())
    }
}

/// Copies `pipe` into `output`, one newline-terminated chunk at a time.
///
/// Each chunk (including its trailing `\n`, when present) is decoded lossily as
/// UTF-8 and written to `output` tagged with `stream`. A missing pipe is not an
/// error: the function returns immediately. Reading stops at end of input.
///
/// # Errors
///
/// Propagates read errors from the pipe and write errors from `output`.
async fn stream_pipe(
    pipe: Option<impl tokio::io::AsyncRead + Unpin>,
    stream: TaskOutputStream,
    output: Arc<roder_api::tasks::TaskOutputSink>,
) -> anyhow::Result<()> {
    if let Some(pipe) = pipe {
        let mut reader = BufReader::new(pipe);
        // Reused across iterations; emptied after every forwarded chunk.
        let mut line = Vec::new();
        while reader.read_until(b'\n', &mut line).await? != 0 {
            let text = String::from_utf8_lossy(&line).into_owned();
            output.write(stream.clone(), text).await?;
            line.clear();
        }
    }
    Ok(())
}

fn resolve_cwd(workspace_root: Option<&str>, cwd: Option<&str>) -> anyhow::Result<PathBuf> {
    let Some(root) = workspace_root else {
        return match cwd {
            Some(cwd) => Ok(PathBuf::from(cwd)),
            None => std::env::current_dir().context("resolve current directory"),
        };
    };
    let root = std::fs::canonicalize(root).with_context(|| format!("canonicalize root {root}"))?;
    let candidate = match cwd {
        Some(cwd) => {
            let path = Path::new(cwd);
            if path.is_absolute() {
                path.to_path_buf()
            } else {
                root.join(path)
            }
        }
        None => root.clone(),
    };
    let candidate = std::fs::canonicalize(&candidate)
        .with_context(|| format!("canonicalize cwd {}", candidate.display()))?;
    if !candidate.starts_with(&root) {
        bail!(
            "process task cwd {} escapes workspace root {}",
            candidate.display(),
            root.display()
        );
    }
    Ok(candidate)
}

/// Unit tests for working-directory resolution and the published input schema.
#[cfg(test)]
mod tests {
    use super::*;

    /// Uses the current directory as the workspace root and its parent as the
    /// requested cwd, and expects `resolve_cwd` to reject it.
    // NOTE(review): when the current directory has no parent, `outside` falls back
    // to `root` itself; `resolve_cwd` would then succeed and `unwrap_err` panic.
    #[test]
    fn resolve_cwd_rejects_paths_outside_workspace() {
        let root = std::env::current_dir().unwrap();
        let outside = root.parent().unwrap_or(&root);
        let err = resolve_cwd(
            Some(root.to_str().unwrap()),
            Some(outside.to_str().unwrap()),
        )
        .unwrap_err();

        assert!(err.to_string().contains("escapes workspace root"));
    }

    /// Checks fragments of the serialized input schema after it has been
    /// normalized with the strict tool-schema policy.
    #[test]
    fn schema_snapshot_covers_process_task_input() {
        let executor = ProcessTaskExecutor;
        let spec = executor
            .spec()
            .normalized_for_model(roder_api::ToolSchemaPolicy::strict());
        let schema = serde_json::to_string(&spec.input_schema).unwrap();

        // The prefix assertion depends on the key order of the serialized object.
        assert!(schema.starts_with(r#"{"type":"object","required":["command"],"properties":"#));
        assert!(schema.contains(
            r#""env_overrides":{"type":"object","additionalProperties":{"type":"string"}}"#
        ));
        assert!(schema.contains(r#""additionalProperties":false"#));
    }
}