abu-agent 0.2.0

Agent development library
Documentation
use std::{collections::HashMap, ffi::OsStr, path::Path};
use abu_mcp::{client::McpClient, transport::process::McpProcessTransport, McpToolCall, McpToolCallResult, McpToolCallResultContent};
use abu_tool::{ToolCallResult, ToolError};
use thiserrorctx::Context;
use serde::Deserialize;
use tracing::{debug, warn};
use crate::AgentResult;

pub struct McpManager {
    pub default_protocol_version: String,
    pub stdio_servers: Vec<McpClient<McpProcessTransport>>
}

#[derive(Debug, Clone, Deserialize)]
pub struct McpConfig {
    #[serde(default = "default_protocol_version", alias = "defaultProtocolVersion")]
    pub default_protocol_version: String,
    #[serde(alias = "mcpServers")]
    pub mcp_servers: HashMap<String, McpServerConfig>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct McpServerConfig {
    pub transport: String,
    pub command: String,
    #[serde(default)]
    pub args: Vec<String>,
    #[serde(default)]
    pub env: HashMap<String, String>,
}

impl McpManager {
    pub fn new() -> Self {
        Self {
            default_protocol_version: default_protocol_version(),
            stdio_servers: vec![],
        }
    }

    pub async fn load_config(path: impl AsRef<Path>) -> AgentResult<Self> {
        debug!("load mcp config from {}", path.as_ref().display());
        let context = std::fs::read_to_string(path).context("read config file")?;
        let config: McpConfig = serde_json::from_str(&context).context("parse config file")?;

        let mut mcp_manager = McpManager { default_protocol_version: config.default_protocol_version, stdio_servers: vec![],};
        for (name, server_config) in config.mcp_servers {
            debug!("add mcp server {}", name);
            match server_config.transport.as_str() {
                "stdio" => {
                    mcp_manager.add_stdio_server(server_config.command, server_config.args) 
                        .await.with_context(|| format!("init client {}", name))?;
                }
                transport => warn!("unsupport transport '{}' in mcpserver {}", transport, name),
            };
        }

        Ok(mcp_manager)
    }

    pub async fn add_stdio_server<S, I>(&mut self, cmd: S, args: I) -> AgentResult<&McpClient<McpProcessTransport>> 
    where 
        I: IntoIterator<Item = S>,
        S: AsRef<OsStr>,
    {
        let client = Self::init_stdio_clinet(cmd, args).await?;
        self.stdio_servers.push(client);
        Ok(self.stdio_servers.last().unwrap())
    }

    pub async fn execute_toolcall(&mut self, name: &str, arguments: serde_json::Value) -> AgentResult<ToolCallResult> {
        for client in self.stdio_servers.iter_mut() {
            if client.has_tool(&name) {
                let mcp_tool_call = McpToolCall {
                    name: name.to_string(), arguments: Some(arguments)
                };
                let mcp_tool_call_result = client.tools_call(mcp_tool_call).await?;
                let tool_call_result = mcp_tool_call_result_to_tool_call_result(mcp_tool_call_result);
                return Ok(tool_call_result)
            }
        }
        Err(ToolError::ToolNotFound(name.to_string()))?
    }

    pub fn has_tool(&self, tool_name: &str) -> bool {
        for client in self.stdio_servers.iter() {
            if client.has_tool(tool_name) {
                return true;
            }
        }
        false
    }

    pub async fn init_stdio_clinet<I, S>(cmd: S, args: I) -> AgentResult<McpClient<McpProcessTransport>> 
    where 
        I: IntoIterator<Item = S>,
        S: AsRef<OsStr>,
    {
        let transport = McpProcessTransport::new(cmd, args)
            .context("new process transport")?;
        let mut client = McpClient::new(transport);
        client.initialize().await.context("initialize mcpserver")?;
        client.tools_list().await.context("tools_list mcpserver")?;
        Ok(client)
    }
}

fn default_protocol_version() -> String {
    abu_mcp::LATEST_PROTOCOL_VERSION.to_string()
}

fn mcp_tool_call_result_to_tool_call_result(result: McpToolCallResult) -> ToolCallResult {
    let is_error = result.is_error.unwrap_or(false);
    let context = result
        .content
        .iter()
        .map(|content| {
            match content {
                McpToolCallResultContent::Text { text } => text.as_str(),
            }
        })
        .collect::<Vec<&str>>()
        .join("\n");
    ToolCallResult { is_error, context }
}