unified-agent-api 0.3.5

Agent-agnostic facade and registry for wrapper backends
Documentation
use std::{
    env,
    ffi::OsString,
    io,
    path::{Path, PathBuf},
    process::{ExitStatus, Stdio},
    time::Duration,
};

use codex::CodexHomeLayout;
use tokio::{
    io::{AsyncRead, AsyncReadExt},
    process::{Child, Command},
    task::JoinHandle,
};

use crate::{
    bounds::{enforce_mcp_output_bound, MCP_STDERR_BOUND_BYTES, MCP_STDOUT_BOUND_BYTES},
    mcp::AgentWrapperMcpCommandOutput,
    AgentWrapperError,
};

use super::{
    backend_error, PINNED_CAPTURE_FAILURE, PINNED_MCP_RUNTIME_CONFLICT,
    PINNED_PREPARE_CODEX_HOME_FAILURE, PINNED_SPAWN_FAILURE, PINNED_WAIT_FAILURE,
};

/// Raw result of running a `codex mcp …` command: the child's exit status plus
/// the stdout/stderr bytes retained by `capture_bounded`, before
/// `finalize_codex_mcp_output` turns it into the public output type.
pub(super) struct CapturedCodexMcpCommandOutput {
    /// Exit status reported by the child process.
    pub(super) status: ExitStatus,
    /// Retained stdout bytes (truncated to at most `MCP_STDOUT_BOUND_BYTES`).
    pub(super) stdout_bytes: Vec<u8>,
    /// `true` when stdout produced more bytes than were retained.
    pub(super) stdout_saw_more: bool,
    /// Retained stderr bytes (truncated to at most `MCP_STDERR_BOUND_BYTES`).
    pub(super) stderr_bytes: Vec<u8>,
    /// `true` when stderr produced more bytes than were retained.
    pub(super) stderr_saw_more: bool,
}

pub(super) async fn capture_codex_mcp_output(
    resolved: &super::resolve::ResolvedCodexMcpCommand,
    argv: &[OsString],
) -> Result<CapturedCodexMcpCommandOutput, AgentWrapperError> {
    if resolved.timeout == Some(Duration::ZERO) {
        return Err(backend_error(super::super::PINNED_TIMEOUT));
    }

    if let Some(codex_home) = resolved.materialize_codex_home.as_ref() {
        CodexHomeLayout::new(codex_home.clone())
            .materialize(true)
            .map_err(|_| backend_error(PINNED_PREPARE_CODEX_HOME_FAILURE))?;
    }

    let mut command = Command::new(&resolved.binary_path);
    command
        .args(argv)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .kill_on_drop(true)
        .env_clear()
        .envs(&resolved.env)
        .current_dir(spawn_current_dir(resolved.working_dir.as_deref()));

    let mut child = command
        .spawn()
        .map_err(|_| backend_error(PINNED_SPAWN_FAILURE))?;

    let Some(stdout) = child.stdout.take() else {
        cleanup_child(&mut child).await;
        return Err(backend_error(PINNED_CAPTURE_FAILURE));
    };
    let Some(stderr) = child.stderr.take() else {
        cleanup_child(&mut child).await;
        return Err(backend_error(PINNED_CAPTURE_FAILURE));
    };

    let stdout_task = tokio::spawn(capture_bounded(stdout, MCP_STDOUT_BOUND_BYTES));
    let stderr_task = tokio::spawn(capture_bounded(stderr, MCP_STDERR_BOUND_BYTES));

    let status = match wait_for_exit(&mut child, resolved.timeout).await {
        Ok(status) => status,
        Err(err) => {
            stdout_task.abort();
            stderr_task.abort();
            return Err(err);
        }
    };

    let (stdout_bytes, stdout_saw_more) = join_capture_task(stdout_task).await?;
    let (stderr_bytes, stderr_saw_more) = join_capture_task(stderr_task).await?;

    Ok(CapturedCodexMcpCommandOutput {
        status,
        stdout_bytes,
        stdout_saw_more,
        stderr_bytes,
        stderr_saw_more,
    })
}

/// Chooses the working directory for the spawned child.
///
/// Preference order: the explicitly configured `working_dir`, then this
/// process's current directory, and finally the system temp directory when
/// the current directory cannot be determined.
fn spawn_current_dir(working_dir: Option<&Path>) -> PathBuf {
    if let Some(dir) = working_dir {
        return dir.to_path_buf();
    }
    match env::current_dir() {
        Ok(cwd) => cwd,
        Err(_) => env::temp_dir(),
    }
}

pub(super) fn finalize_codex_mcp_output(
    argv: &[OsString],
    captured: CapturedCodexMcpCommandOutput,
) -> Result<AgentWrapperMcpCommandOutput, AgentWrapperError> {
    if !captured.status.success()
        && is_manifest_runtime_conflict(argv, &captured.stdout_bytes, &captured.stderr_bytes)
    {
        return Err(backend_error(PINNED_MCP_RUNTIME_CONFLICT));
    }

    let (stdout, stdout_truncated) = enforce_mcp_output_bound(
        &captured.stdout_bytes,
        captured.stdout_saw_more,
        MCP_STDOUT_BOUND_BYTES,
    );
    let (stderr, stderr_truncated) = enforce_mcp_output_bound(
        &captured.stderr_bytes,
        captured.stderr_saw_more,
        MCP_STDERR_BOUND_BYTES,
    );

    Ok(AgentWrapperMcpCommandOutput {
        status: captured.status,
        stdout,
        stderr,
        stdout_truncated,
        stderr_truncated,
    })
}

/// Decodes both streams lossily as UTF-8 and classifies the combined text.
///
/// The classifier receives stderr first, then a newline, then stdout.
fn is_manifest_runtime_conflict(argv: &[OsString], stdout: &[u8], stderr: &[u8]) -> bool {
    let combined = [stderr, stdout].map(String::from_utf8_lossy).join("\n");
    classify_manifest_runtime_conflict_text(argv, &combined)
}

/// Decides whether CLI error `text` indicates that the installed `codex`
/// binary does not understand the `mcp` command shape described by `argv`.
///
/// Matching is ASCII case-insensitive. The text must contain an
/// "unknown"-style signal (unknown/unrecognized/unexpected/invalid/no such/
/// not recognized) and a syntax word (command/subcommand/argument/option/
/// flag). In addition, one of the following must hold:
/// * `argv[1]` is `add` and the text mentions `--env`, `--url` or
///   `--bearer-token-env-var`;
/// * one of the subcommand tokens derived from `argv` (`mcp`, plus
///   `list`/`get`/`add`/`remove` when that is `argv[1]`) appears as a
///   standalone word;
/// * the text mentions `--json` together with flag/option/argument.
///
/// Subcommand tokens are matched on word boundaries so that incidental
/// substrings (`get` in "target", `add` in "address", `list` in "listen")
/// do not misclassify an ordinary failure as a runtime conflict.
pub(super) fn classify_manifest_runtime_conflict_text(argv: &[OsString], text: &str) -> bool {
    let text = text.to_ascii_lowercase();

    let unknown_signal = [
        "unknown",
        "unrecognized",
        "unexpected",
        "invalid",
        "no such",
        "not recognized",
    ]
    .iter()
    .any(|signal| text.contains(signal));

    if !unknown_signal {
        return false;
    }

    let syntax_context = ["command", "subcommand", "argument", "option", "flag"]
        .iter()
        .any(|signal| text.contains(signal));
    if !syntax_context {
        return false;
    }

    if is_add_shape_conflict(argv, &text) {
        return true;
    }

    let subcommand_conflict = manifest_conflict_tokens(argv)
        .into_iter()
        .any(|token| contains_token(&text, token));
    let json_flag_conflict = text.contains("--json")
        && (text.contains("flag") || text.contains("option") || text.contains("argument"));

    subcommand_conflict || json_flag_conflict
}

/// For `mcp add` invocations, reports whether the (lowercased) text mentions
/// one of the add-specific flags that older CLIs may not accept.
fn is_add_shape_conflict(argv: &[OsString], text: &str) -> bool {
    matches!(argv.get(1).and_then(|arg| arg.to_str()), Some("add"))
        && ["--env", "--url", "--bearer-token-env-var"]
            .iter()
            .any(|token| text.contains(token))
}

/// Words whose appearance in an error message points at the command shape:
/// always `mcp`, plus the subcommand in `argv[1]` when it is one of the
/// recognized manifest operations.
fn manifest_conflict_tokens(argv: &[OsString]) -> Vec<&'static str> {
    let mut tokens = vec!["mcp"];
    match argv.get(1).and_then(|arg| arg.to_str()) {
        Some("list") => tokens.push("list"),
        Some("get") => tokens.push("get"),
        Some("add") => tokens.push("add"),
        Some("remove") => tokens.push("remove"),
        _ => {}
    }
    tokens
}

/// Returns `true` when `token` occurs in `text` with no identifier-like
/// character (ASCII alphanumeric, `_` or `-`) immediately before or after it.
fn contains_token(text: &str, token: &str) -> bool {
    let is_word_char = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '-';
    text.match_indices(token).any(|(start, matched)| {
        let end = start + matched.len();
        let bounded_before =
            !matches!(text[..start].chars().next_back(), Some(c) if is_word_char(c));
        let bounded_after = !matches!(text[end..].chars().next(), Some(c) if is_word_char(c));
        bounded_before && bounded_after
    })
}

/// Waits for `child` to exit, optionally bounded by `timeout`.
///
/// When the timeout elapses, the child is killed and reaped through
/// `cleanup_child` and `PINNED_TIMEOUT` is returned. An I/O error while
/// waiting maps to `PINNED_WAIT_FAILURE`. Callers must not pass a zero
/// timeout (debug-asserted).
async fn wait_for_exit(
    child: &mut Child,
    timeout: Option<Duration>,
) -> Result<ExitStatus, AgentWrapperError> {
    let Some(limit) = timeout else {
        return child
            .wait()
            .await
            .map_err(|_| backend_error(PINNED_WAIT_FAILURE));
    };
    debug_assert_ne!(limit, Duration::ZERO);

    let outcome = tokio::time::timeout(limit, child.wait()).await;
    match outcome {
        Ok(waited) => waited.map_err(|_| backend_error(PINNED_WAIT_FAILURE)),
        Err(_elapsed) => {
            cleanup_child(child).await;
            Err(backend_error(super::super::PINNED_TIMEOUT))
        }
    }
}

/// Best-effort teardown: kill the child, then wait on it. Both results are
/// deliberately ignored because the process may already have exited.
async fn cleanup_child(child: &mut Child) {
    let _ = (child.kill().await, child.wait().await);
}

/// Awaits a capture task and unwraps its result.
///
/// A panicked or cancelled task and an I/O error from the reader both
/// collapse to `PINNED_CAPTURE_FAILURE`.
async fn join_capture_task(
    task: JoinHandle<io::Result<(Vec<u8>, bool)>>,
) -> Result<(Vec<u8>, bool), AgentWrapperError> {
    match task.await {
        Ok(Ok(captured)) => Ok(captured),
        Ok(Err(_)) | Err(_) => Err(backend_error(PINNED_CAPTURE_FAILURE)),
    }
}

/// Reads `reader` to EOF, keeping at most `bound_bytes` bytes.
///
/// Returns the retained prefix together with a flag that is `true` when the
/// stream produced more than `bound_bytes` bytes. The reader is always
/// drained completely, so the writer is never blocked by a full pipe.
///
/// # Errors
/// Propagates the first I/O error returned by `reader`.
pub(super) async fn capture_bounded<R>(
    mut reader: R,
    bound_bytes: usize,
) -> io::Result<(Vec<u8>, bool)>
where
    R: AsyncRead + Unpin,
{
    const CHUNK_LEN: usize = 4096;

    // Keep one byte past the bound so that overflow is observable below.
    let keep_limit = bound_bytes.saturating_add(1);
    let mut kept: Vec<u8> = Vec::with_capacity(keep_limit.min(CHUNK_LEN));
    let mut overflowed = false;
    let mut buf = [0u8; CHUNK_LEN];

    loop {
        match reader.read(&mut buf).await? {
            0 => break,
            n => {
                let room = keep_limit.saturating_sub(kept.len());
                let take = room.min(n);
                kept.extend_from_slice(&buf[..take]);
                // Any bytes we could not keep mean the stream exceeded the limit.
                overflowed |= take < n;
            }
        }
    }

    let overflowed = overflowed || kept.len() > bound_bytes;
    kept.truncate(bound_bytes);

    Ok((kept, overflowed))
}