// velocia 0.3.2
//
// velocia – production-ready AI agent framework using ADK-Rust, A2A protocol, and AWS DynamoDB
//! MCP (Model Context Protocol) tool loader.
//!
//! Supports both streamable-HTTP and stdio transports, mirroring Python's
//! `MCPToolLoaderADK`.

use std::collections::HashMap;

use serde_json::{json, Value};
use tracing::info;

use crate::config::tool::{McpConfig, McpToolConfig, McpTypeStreamable, McpTypeStdio};
use crate::error::{AgentKitError, Result};
use crate::tools::loader::DynTool;
use crate::utils::{expand_env_vars, get_protocol};

/// Stateless entry point for building [`DynTool`]s from MCP tool configuration.
///
/// The type carries no data; all functionality is exposed through associated
/// functions such as `McpToolLoader::load`.
#[derive(Debug, Clone, Copy, Default)]
pub struct McpToolLoader;

impl McpToolLoader {
    /// Load an MCP tool from configuration, returning a `DynTool` that proxies
    /// calls to the MCP server.
    pub async fn load(cfg: &McpToolConfig) -> Result<DynTool> {
        match &cfg.mcp_config {
            McpConfig::Streamable(s) => Self::load_streamable(cfg, s).await,
            McpConfig::Stdio(s) => Self::load_stdio(cfg, s).await,
        }
    }

    async fn load_streamable(cfg: &McpToolConfig, stream: &McpTypeStreamable) -> Result<DynTool> {
        let protocol = get_protocol();
        let url = format!(
            "{protocol}://{}:{}{}",
            stream.url, stream.port, stream.path
        );
        let headers = resolve_headers(cfg);
        let _filter = cfg.tool_filter.clone().unwrap_or_default();
        let name = cfg.base.name.clone();
        let description = cfg.base.description.clone();

        info!("Loading MCP streamable-HTTP tool '{name}' from {url}");

        // ── adk-rust MCP integration ──────────────────────────────────────────
        // TODO: replace the HTTP stub with the real adk-rust MCPToolset once
        // crate integration is validated:
        //
        //   use adk_rust::tool::mcp::{McpToolset, StreamableHttpConnectionParams};
        //   let toolset = McpToolset::builder()
        //       .connection(StreamableHttpConnectionParams { url, headers })
        //       .tool_filter(filter)
        //       .build()
        //       .await?;
        //
        // For now we wrap a simple HTTP forward as a DynTool.

        let http = reqwest::Client::new();
        let url_clone = url.clone();

        Ok(DynTool::new(name.clone(), description, move |args: Value| {
            let url = url_clone.clone();
            let _headers = headers.clone();
            let http = http.clone();
            Box::pin(async move {
                let resp = http
                    .post(&url)
                    .json(&args)
                    .send()
                    .await?
                    .json::<Value>()
                    .await?;
                Ok(resp)
            })
        }))
    }

    async fn load_stdio(cfg: &McpToolConfig, stdio: &McpTypeStdio) -> Result<DynTool> {
        let name = cfg.base.name.clone();
        let description = cfg.base.description.clone();
        let command = stdio.command.clone();
        let args = stdio.args.clone();
        let env = expand_env_vars(stdio.env.as_ref());

        info!("Loading MCP stdio tool '{name}' via command '{command}'");

        // TODO: replace with adk-rust stdio MCP toolset:
        //
        //   use adk_rust::tool::mcp::{McpToolset, StdioConnectionParams, StdioServerParameters};
        //   let toolset = McpToolset::builder()
        //       .connection(StdioConnectionParams {
        //           server_params: StdioServerParameters { command, args, env },
        //       })
        //       .build()
        //       .await?;

        Ok(DynTool::new(name, description, move |input: Value| {
            let command = command.clone();
            let args = args.clone();
            let env = env.clone();
            Box::pin(async move {
                let mut child = tokio::process::Command::new(&command)
                    .args(&args)
                    .envs(&env)
                    .stdin(std::process::Stdio::piped())
                    .stdout(std::process::Stdio::piped())
                    .spawn()
                    .map_err(|e| AgentKitError::ToolLoad {
                        name: command.clone(),
                        reason: e.to_string(),
                    })?;

                if let Some(stdin) = child.stdin.take() {
                    use tokio::io::AsyncWriteExt;
                    let mut stdin = stdin;
                    let payload = serde_json::to_vec(&input)?;
                    stdin.write_all(&payload).await.ok();
                }

                let output = child
                    .wait_with_output()
                    .await
                    .map_err(|e| AgentKitError::ToolLoad {
                        name: command.clone(),
                        reason: e.to_string(),
                    })?;

                let response: Value = serde_json::from_slice(&output.stdout)
                    .unwrap_or(json!({"stdout": String::from_utf8_lossy(&output.stdout).to_string()}));
                Ok(response)
            })
        }))
    }
}

// ── Helpers ───────────────────────────────────────────────────────────────────

/// Collect the HTTP headers configured for a streamable-HTTP MCP tool.
///
/// Each configured header entry names a header and an environment variable;
/// the header value is read from that variable. Entries with a missing name,
/// a missing variable name, or an unreadable environment variable are skipped.
/// Returns an empty map for stdio configurations and when no auth section is
/// present.
fn resolve_headers(cfg: &McpToolConfig) -> HashMap<String, String> {
    let mut resolved = HashMap::new();

    // Only the streamable-HTTP transport carries auth headers.
    let auth = match &cfg.mcp_config {
        McpConfig::Streamable(stream) => stream.auth.as_ref(),
        _ => None,
    };

    if let Some(auth) = auth {
        let entries = auth.headers.as_deref().unwrap_or_default();
        for entry in entries.iter() {
            let header = entry.header_name.as_deref();
            let variable = entry.header_value.as_deref();
            if let (Some(header), Some(variable)) = (header, variable) {
                // `header_value` holds the NAME of the environment variable,
                // not the literal header value.
                if let Ok(value) = std::env::var(variable) {
                    resolved.insert(header.to_string(), value);
                }
            }
        }
    }

    resolved
}