use std::collections::HashMap;
use serde_json::{json, Value};
use tracing::info;
use crate::config::tool::{McpConfig, McpToolConfig, McpTypeStreamable, McpTypeStdio};
use crate::error::{AgentKitError, Result};
use crate::tools::loader::DynTool;
use crate::utils::{expand_env_vars, get_protocol};
/// Builds [`DynTool`]s backed by MCP servers described by an [`McpToolConfig`].
pub struct McpToolLoader;

impl McpToolLoader {
    /// Loads the tool described by `cfg`, dispatching on its transport kind.
    ///
    /// # Errors
    /// Propagates any error returned by the transport-specific loader.
    pub async fn load(cfg: &McpToolConfig) -> Result<DynTool> {
        match &cfg.mcp_config {
            McpConfig::Streamable(s) => Self::load_streamable(cfg, s).await,
            McpConfig::Stdio(s) => Self::load_stdio(cfg, s).await,
        }
    }

    /// Builds a tool that POSTs its JSON arguments to the configured
    /// streamable-HTTP endpoint (`{protocol}://{url}:{port}{path}`) and
    /// returns the response body parsed as JSON.
    ///
    /// The headers produced by `resolve_headers` are attached to every request.
    /// Those are the auth headers whose values come from environment variables.
    ///
    /// # Errors
    /// The loader itself does not fail. At call time, request/transport errors
    /// and JSON decoding errors of the response propagate via `?`.
    async fn load_streamable(cfg: &McpToolConfig, stream: &McpTypeStreamable) -> Result<DynTool> {
        let protocol = get_protocol();
        let url = format!("{protocol}://{}:{}{}", stream.url, stream.port, stream.path);
        let headers = resolve_headers(cfg);
        // NOTE(review): `cfg.tool_filter` is not applied by this loader.
        let name = cfg.base.name.clone();
        let description = cfg.base.description.clone();
        info!("Loading MCP streamable-HTTP tool '{name}' from {url}");

        // A single client per tool, shared (cheaply cloned) across invocations.
        let http = reqwest::Client::new();
        Ok(DynTool::new(name, description, move |args: Value| {
            // The returned future must own its data, so clone per invocation.
            let url = url.clone();
            let headers = headers.clone();
            let http = http.clone();
            Box::pin(async move {
                let mut request = http.post(&url).json(&args);
                // Previously the resolved headers were captured but never sent,
                // so configured auth was silently dropped.
                for (key, value) in &headers {
                    request = request.header(key.as_str(), value.as_str());
                }
                let resp = request.send().await?.json::<Value>().await?;
                Ok(resp)
            })
        }))
    }

    /// Builds a tool that spawns `stdio.command` (with `stdio.args` and the
    /// expanded `stdio.env`) once per invocation. It writes the JSON input to
    /// the child's stdin and parses the child's stdout as JSON.
    ///
    /// If stdout is not valid JSON, the raw output (lossily decoded as UTF-8)
    /// is returned as `{"stdout": "..."}`. The child's exit status is not
    /// inspected.
    ///
    /// # Errors
    /// At call time:
    /// - Returns `AgentKitError::ToolLoad` if the process cannot be spawned or
    ///   waited on.
    /// - Input serialization errors propagate via `?`.
    async fn load_stdio(cfg: &McpToolConfig, stdio: &McpTypeStdio) -> Result<DynTool> {
        let name = cfg.base.name.clone();
        let description = cfg.base.description.clone();
        let command = stdio.command.clone();
        let args = stdio.args.clone();
        let env = expand_env_vars(stdio.env.as_ref());
        info!("Loading MCP stdio tool '{name}' via command '{command}'");
        Ok(DynTool::new(name, description, move |input: Value| {
            let command = command.clone();
            let args = args.clone();
            let env = env.clone();
            Box::pin(async move {
                use tokio::io::AsyncWriteExt;

                // Serialize before spawning so a failure here cannot leave an
                // already-started child process behind.
                let payload = serde_json::to_vec(&input)?;

                let mut child = tokio::process::Command::new(&command)
                    .args(&args)
                    .envs(&env)
                    .stdin(std::process::Stdio::piped())
                    .stdout(std::process::Stdio::piped())
                    .spawn()
                    .map_err(|e| AgentKitError::ToolLoad {
                        name: command.clone(),
                        reason: e.to_string(),
                    })?;

                // NOTE(review): stdin is written in full before stdout is drained
                // by `wait_with_output`. A child that emits a large amount of
                // output before consuming all of its input could stall both ends.
                if let Some(mut stdin) = child.stdin.take() {
                    // Best-effort write: the child may exit without reading its
                    // input (e.g. broken pipe). Its stdout is still collected below.
                    let _ = stdin.write_all(&payload).await;
                    // `stdin` is dropped at the end of this scope, closing the pipe.
                }

                let output = child
                    .wait_with_output()
                    .await
                    .map_err(|e| AgentKitError::ToolLoad {
                        name: command.clone(),
                        reason: e.to_string(),
                    })?;

                // Lazily build the fallback so the success path does not allocate it.
                let response = serde_json::from_slice::<Value>(&output.stdout).unwrap_or_else(|_| {
                    json!({"stdout": String::from_utf8_lossy(&output.stdout).into_owned()})
                });
                Ok(response)
            })
        }))
    }
}
/// Collects the auth headers configured for a streamable-HTTP MCP tool.
///
/// Each configured entry pairs a header name with the *name of an environment
/// variable* that holds the header's value.
///
/// An entry is skipped when:
/// - either field is missing, or
/// - the variable is unset or not valid Unicode.
///
/// When two entries share a header name, the later one wins. Non-streamable
/// configs and configs without `auth` produce an empty map.
fn resolve_headers(cfg: &McpToolConfig) -> HashMap<String, String> {
    let McpConfig::Streamable(streamable) = &cfg.mcp_config else {
        return HashMap::new();
    };
    let Some(auth) = &streamable.auth else {
        return HashMap::new();
    };

    let mut resolved = HashMap::new();
    for entry in auth.headers.as_deref().into_iter().flatten() {
        let (Some(header), Some(var_name)) =
            (entry.header_name.as_deref(), entry.header_value.as_deref())
        else {
            continue;
        };
        if let Ok(value) = std::env::var(var_name) {
            resolved.insert(header.to_string(), value);
        }
    }
    resolved
}