acpr 0.3.0

Run agents from the ACP registry
Documentation
pub mod cli;
pub mod registry;

pub use cli::*;
pub use registry::*;

use agent_client_protocol::{Agent as AcpAgent, ByteStreams, Client, ConnectTo};
use std::path::PathBuf;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::process::Command;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use tracing::{debug, info};

/// Convenience entry point: run the named agent with default settings.
///
/// Builds an [`Acpr`] for `agent_name` with no overrides and runs it on the
/// process's standard input and output.
pub async fn run(agent_name: &str) -> Result<(), Box<dyn std::error::Error>> {
    let runner = Acpr::new(agent_name);
    runner.run().await
}

/// Main library interface for acpr
///
/// Holds the id of the agent to run plus optional overrides that are set
/// through the `with_*` builder methods.
pub struct Acpr {
    /// Id of the agent to look up in the registry (compared against each entry's `id`).
    pub agent_name: String,
    /// Cache directory override; when `None`, `run_with_streams` falls back to
    /// `dirs::cache_dir()` joined with `acpr`.
    cache_dir: Option<PathBuf>,
    /// Optional registry file path, passed through to `fetch_registry`.
    registry_file: Option<PathBuf>,
    /// Optional force setting, passed through to `fetch_registry` and `download_binary`.
    force: Option<ForceOption>,
}

impl Acpr {
    /// Create a new Acpr instance for the specified agent
    pub fn new(agent_name: &str) -> Self {
        Self {
            agent_name: agent_name.to_string(),
            cache_dir: None,
            registry_file: None,
            force: None,
        }
    }

    /// Set a custom cache directory
    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
        self.cache_dir = Some(cache_dir);
        self
    }

    /// Set a custom registry file
    pub fn with_registry_file(mut self, registry_file: PathBuf) -> Self {
        self.registry_file = Some(registry_file);
        self
    }

    /// Set force option
    pub fn with_force(mut self, force: ForceOption) -> Self {
        self.force = Some(force);
        self
    }

    /// Run the agent with default stdio
    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
        self.run_with_streams(tokio::io::stdin(), tokio::io::stdout())
            .await
    }

    /// Run the agent with custom stdio streams
    pub async fn run_with_streams<R, W>(
        &self,
        stdin: R,
        stdout: W,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        R: AsyncRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        let cache_dir = self.cache_dir.clone().unwrap_or_else(|| {
            dirs::cache_dir()
                .expect("No cache directory found")
                .join("acpr")
        });

        tokio::fs::create_dir_all(&cache_dir).await?;
        let registry =
            fetch_registry(&cache_dir, self.force.as_ref(), self.registry_file.as_ref()).await?;
        let agent = registry
            .agents
            .iter()
            .find(|a| a.id == self.agent_name)
            .ok_or("Agent not found")?;

        debug!("Running agent: {}", agent.id);

        let mut cmd = self.build_command(agent, &cache_dir).await?;
        cmd.stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::inherit());
        debug!("Running cmd: {cmd:?}");

        let mut child = cmd.spawn()?;
        let child_stdin = child.stdin.take().unwrap();
        let child_stdout = child.stdout.take().unwrap();

        let stdin_future = async {
            let mut stdin = stdin;
            let mut child_stdin = child_stdin;
            let mut buf = [0u8; 8192];
            loop {
                match stdin.read(&mut buf).await {
                    Ok(0) => {
                        debug!("stdin: EOF received");
                        break;
                    }
                    Ok(n) => {
                        debug!("stdin: received {} bytes", n);
                        if let Err(e) = child_stdin.write_all(&buf[..n]).await {
                            tracing::debug!("stdin write error: {}", e);
                            break;
                        }
                        if let Err(e) = child_stdin.flush().await {
                            tracing::debug!("stdin flush error: {}", e);
                            break;
                        }
                        debug!("stdin: forwarded {} bytes to child", n);
                    }
                    Err(e) => {
                        tracing::debug!("stdin read error: {}", e);
                        break;
                    }
                }
            }
            Ok::<(), std::io::Error>(())
        };

        let stdout_future = async {
            let mut child_stdout = child_stdout;
            let mut stdout = stdout;
            let mut buf = [0u8; 8192];
            loop {
                match child_stdout.read(&mut buf).await {
                    Ok(0) => {
                        debug!("stdout: EOF from child");
                        break;
                    }
                    Ok(n) => {
                        debug!("stdout: received {} bytes from child", n);
                        if let Err(e) = stdout.write_all(&buf[..n]).await {
                            tracing::debug!("stdout write error: {}", e);
                            break;
                        }
                        if let Err(e) = stdout.flush().await {
                            tracing::debug!("stdout flush error: {}", e);
                            break;
                        }
                        debug!("stdout: forwarded {} bytes", n);
                    }
                    Err(e) => {
                        tracing::debug!("stdout read error: {}", e);
                        break;
                    }
                }
            }
            Ok::<(), std::io::Error>(())
        };

        tokio::try_join!(
            async { child.wait().await.map_err(|e| e.into()) },
            stdin_future,
            stdout_future
        )?;

        Ok(())
    }

    async fn build_command(
        &self,
        agent: &Agent,
        cache_dir: &PathBuf,
    ) -> Result<Command, Box<dyn std::error::Error>> {
        if let Some(npx) = &agent.distribution.npx {
            info!("Executing npx package: {}", npx.package);
            let mut cmd = Command::new("npx");
            cmd.arg("-y");
            let package_arg = if npx.package.contains('@') && npx.package.matches('@').count() > 1 {
                npx.package.clone()
            } else {
                format!("{}@latest", npx.package)
            };
            cmd.arg(package_arg).args(&npx.args);
            Ok(cmd)
        } else if let Some(uvx) = &agent.distribution.uvx {
            info!("Executing uvx package: {}", uvx.package);
            let mut cmd = Command::new("uvx");
            cmd.arg(&uvx.package).args(&uvx.args);
            Ok(cmd)
        } else if !agent.distribution.binary.is_empty() {
            let platform = get_platform();
            debug!("Platform detected: {}", platform);
            if let Some(binary_dist) = agent.distribution.binary.get(&platform) {
                let binary_path =
                    download_binary(agent, binary_dist, cache_dir, self.force.as_ref()).await?;
                info!("Executing binary: {:?}", binary_path);
                let mut cmd = Command::new(&binary_path);
                cmd.args(&binary_dist.args);
                Ok(cmd)
            } else {
                Err(format!("No binary available for platform: {}", platform).into())
            }
        } else {
            Err("No supported distribution method found".into())
        }
    }
}

/// Implement ConnectTo<Client> so Acpr can act as an ACP agent
impl ConnectTo<Client> for Acpr {
    async fn connect_to(
        self,
        client: impl ConnectTo<AcpAgent>,
    ) -> Result<(), agent_client_protocol::Error> {
        debug!("ConnectTo: creating duplex streams");
        let (client_stdin, agent_stdin) = tokio::io::duplex(8192);
        let (agent_stdout, client_stdout) = tokio::io::duplex(8192);

        debug!("ConnectTo: creating ByteStreams for sacp");
        let byte_streams = ByteStreams::new(client_stdin.compat_write(), client_stdout.compat());

        debug!("ConnectTo: starting agent and client tasks");
        tokio::try_join!(
            async {
                debug!("ConnectTo: starting agent process");
                self.run_with_streams(agent_stdin, agent_stdout)
                    .await
                    .map_err(|e| agent_client_protocol::Error::internal_error().data(e.to_string()))
            },
            async {
                debug!("ConnectTo: starting sacp client connection");
                ConnectTo::<Client>::connect_to(byte_streams, client).await
            }
        )?;

        debug!("ConnectTo: both tasks completed successfully");
        Ok(())
    }
}

/// Return the `<os>-<arch>` platform key for the current build target.
///
/// macOS is reported as `darwin`; Linux and Windows keep their names. Only
/// `aarch64` and `x86_64` are recognised. Any other OS or architecture yields
/// `"unknown"`.
pub fn get_platform() -> String {
    use std::env::consts::{ARCH, OS};

    let os_label = match OS {
        "macos" => "darwin",
        "linux" => "linux",
        "windows" => "windows",
        _ => return String::from("unknown"),
    };

    match ARCH {
        "aarch64" | "x86_64" => format!("{}-{}", os_label, ARCH),
        _ => String::from("unknown"),
    }
}

/// Ensure the agent's binary is present in the cache and return its path.
///
/// The binary lives at `<cache_dir>/<agent.id>/<cmd>`, where `<cmd>` is
/// `binary_dist.cmd` with any leading `./` removed. It is downloaded from
/// `binary_dist.archive` when it is not cached yet, or when `force` is
/// `ForceOption::All` or `ForceOption::Binary`. `.zip`, `.tar.gz` and `.tgz`
/// archives are extracted into the agent's cache directory; any other URL is
/// written verbatim as the binary itself. On unix the binary is then marked
/// executable (mode `0o755`).
///
/// # Errors
///
/// Returns an error when the cache directory cannot be created, the download
/// fails or answers with an HTTP error status, extraction fails, the expected
/// binary is missing after the download, or permissions cannot be set.
pub async fn download_binary(
    agent: &Agent,
    binary_dist: &BinaryDist,
    cache_dir: &PathBuf,
    force: Option<&ForceOption>,
) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let agent_cache_dir = cache_dir.join(&agent.id);
    tokio::fs::create_dir_all(&agent_cache_dir).await?;
    let binary_name = binary_dist.cmd.trim_start_matches("./");
    let binary_path = agent_cache_dir.join(binary_name);

    let should_download = match force {
        Some(ForceOption::All | ForceOption::Binary) => {
            debug!("Force download requested for binary");
            true
        }
        _ => {
            let exists = binary_path.exists();
            debug!("Binary exists at {:?}: {}", binary_path, exists);
            !exists
        }
    };

    if should_download {
        info!("Downloading binary from: {}", binary_dist.archive);
        // Treat HTTP error statuses as failures so that an error page is
        // never unpacked or cached as if it were the binary.
        let response = reqwest::get(&binary_dist.archive)
            .await?
            .error_for_status()?;
        let archive_data = response.bytes().await?;
        debug!("Downloaded {} bytes", archive_data.len());

        if binary_dist.archive.ends_with(".zip") {
            debug!("Extracting zip archive");
            extract_zip(&archive_data, &agent_cache_dir).await?;
        } else if binary_dist.archive.ends_with(".tar.gz") || binary_dist.archive.ends_with(".tgz")
        {
            debug!("Extracting tar.gz archive");
            extract_tar_gz(&archive_data, &agent_cache_dir).await?;
        } else {
            debug!("Writing raw binary");
            tokio::fs::write(&binary_path, &archive_data).await?;
        }

        // An archive that lacks the expected command would otherwise only
        // fail later (and not at all on non-unix targets), with a less
        // helpful message.
        if !binary_path.exists() {
            return Err(format!(
                "Binary {:?} not found after downloading {}",
                binary_path, binary_dist.archive
            )
            .into());
        }

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mut perms = tokio::fs::metadata(&binary_path).await?.permissions();
            perms.set_mode(0o755);
            tokio::fs::set_permissions(&binary_path, perms).await?;
            debug!("Set executable permissions on binary");
        }

        info!("Binary ready at: {:?}", binary_path);
    } else {
        debug!("Using cached binary: {:?}", binary_path);
    }

    Ok(binary_path)
}

/// Extract an in-memory zip archive into `dest`.
///
/// The work runs on the blocking thread pool because the zip reader and the
/// file writes are synchronous. Entries whose path would land outside `dest`
/// (absolute paths or `..` components) are rejected and abort the extraction.
///
/// # Errors
///
/// Returns an error when the archive cannot be read, an entry has an unsafe
/// path, a directory or file cannot be created, or the blocking task fails.
async fn extract_zip(data: &[u8], dest: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    // The blocking task must own its inputs, hence the copies.
    let data = data.to_vec();
    let dest = dest.clone();
    tokio::task::spawn_blocking(move || -> Result<(), String> {
        let cursor = std::io::Cursor::new(data);
        let mut archive = zip::ZipArchive::new(cursor).map_err(|e| e.to_string())?;
        for i in 0..archive.len() {
            let mut file = archive.by_index(i).map_err(|e| e.to_string())?;
            // Never join the raw entry name: a crafted archive could use it
            // to write outside the destination directory.
            let outpath = match file.enclosed_name() {
                Some(relative) => dest.join(relative),
                None => {
                    return Err(format!(
                        "Refusing to extract zip entry with unsafe path: {}",
                        file.name()
                    ))
                }
            };
            if file.is_dir() {
                std::fs::create_dir_all(&outpath).map_err(|e| e.to_string())?;
            } else {
                if let Some(parent) = outpath.parent() {
                    std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
                }
                let mut outfile = std::fs::File::create(&outpath).map_err(|e| e.to_string())?;
                std::io::copy(&mut file, &mut outfile).map_err(|e| e.to_string())?;
            }
        }
        Ok(())
    })
    .await
    .map_err(|e| e.to_string())??;
    Ok(())
}

/// Unpack an in-memory gzip-compressed tar archive into `dest`.
///
/// Decompression and unpacking are synchronous, so they run on the blocking
/// thread pool. Failures (including a failed blocking task) are reported as
/// string-based errors.
async fn extract_tar_gz(data: &[u8], dest: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    // The blocking closure must own everything it touches.
    let bytes = data.to_vec();
    let target = dest.clone();

    let unpack = move || -> Result<(), String> {
        let gz = flate2::read::GzDecoder::new(bytes.as_slice());
        tar::Archive::new(gz)
            .unpack(&target)
            .map_err(|e| e.to_string())
    };

    let joined = tokio::task::spawn_blocking(unpack).await;
    joined.map_err(|e| e.to_string())??;
    Ok(())
}