use crate::types::{McpServerConfig, McpConnectionStatus, McpTool, ToolDefinition, ToolInputSchema};
use std::process::Stdio;
/// A connection record for a named MCP server and the tools it advertised.
#[derive(Debug, Clone)]
pub struct McpConnection {
    /// Name the server was registered under (as passed to `connect_mcp_server`).
    pub name: String,
    /// Current state of the connection.
    pub status: McpConnectionStatus,
    /// Tool definitions discovered from the server; emptied by [`McpConnection::close`].
    pub tools: Vec<ToolDefinition>,
}

impl McpConnection {
    /// Forgets the discovered tools and marks the connection as disconnected.
    ///
    /// Only in-memory state is updated: this type holds no process or socket handle.
    pub async fn close(&mut self) {
        self.tools.clear();
        self.status = McpConnectionStatus::Disconnected;
    }
}
/// Opens a connection to the MCP server described by `config`.
///
/// Only the stdio transport is implemented. SSE and HTTP configurations do not fail;
/// they yield a connection whose status is `Error` and which has no tools.
///
/// # Errors
/// Propagates any error produced while establishing a stdio connection.
pub async fn connect_mcp_server(
    name: &str,
    config: &McpServerConfig,
) -> Result<McpConnection, crate::error::AgentError> {
    let stdio_config = match config {
        McpServerConfig::Stdio(stdio_config) => stdio_config,
        // Remote transports are not implemented; report them as an errored connection.
        McpServerConfig::Sse(_) | McpServerConfig::Http(_) => {
            return Ok(McpConnection {
                name: name.to_string(),
                status: McpConnectionStatus::Error,
                tools: Vec::new(),
            });
        }
    };
    connect_mcp_stdio(name, stdio_config).await
}
async fn connect_mcp_stdio(
name: &str,
config: &crate::types::McpStdioConfig,
) -> Result<McpConnection, crate::error::AgentError> {
use std::collections::HashMap;
use tokio::process::Command;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
let mut env_vars: HashMap<String, String> = std::env::vars().collect();
if let Some(custom_env) = &config.env {
for (key, value) in custom_env {
env_vars.insert(key.clone(), value.clone());
}
}
let mut child = Command::new(&config.command)
.args(config.args.as_deref().unwrap_or(&[]))
.envs(&env_vars)
.kill_on_drop(true)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.stdin(Stdio::piped())
.spawn()
.map_err(|e| crate::error::AgentError::Mcp(format!("Failed to spawn MCP server: {}", e)))?;
let stdout = child.stdout.take().ok_or_else(|| {
crate::error::AgentError::Mcp("Failed to take stdout".to_string())
})?;
let mut stdin = child.stdin.take().ok_or_else(|| {
crate::error::AgentError::Mcp("Failed to take stdin".to_string())
})?;
let mut stdout_reader = BufReader::new(stdout).lines();
let initialize_request = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": format!("agent-sdk-{}", name),
"version": "1.0.0"
}
}
});
stdin.write_all(format!("{initialize_request}\n").as_bytes()).await
.map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;
stdin.flush().await
.map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;
let _ = stdout_reader.next_line().await;
let list_tools_request = serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list"
});
stdin.write_all(format!("{list_tools_request}\n").as_bytes()).await
.map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;
stdin.flush().await
.map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;
let mut tools = vec![];
if let Ok(Some(response)) = stdout_reader.next_line().await {
if let Ok(resp) = serde_json::from_str::<serde_json::Value>(&response) {
if let Some(result) = resp.get("result") {
if let Some(tools_array) = result.get("tools").and_then(|t| t.as_array()) {
for tool_val in tools_array {
if let Ok(mcp_tool) = serde_json::from_value::<McpTool>(tool_val.clone()) {
let tool_def = create_mcp_tool_definition(name, &mcp_tool);
tools.push(tool_def);
}
}
}
}
}
}
drop(stdin);
Ok(McpConnection {
name: name.to_string(),
status: McpConnectionStatus::Connected,
tools,
})
}
/// Builds the agent-facing `ToolDefinition` for a tool advertised by an MCP server.
///
/// The tool is namespaced as `mcp__<server_name>__<tool_name>`. A missing description
/// becomes `"MCP tool: <tool_name>"`. The schema pieces are read from the tool's
/// `input_schema`, with these fallbacks when absent:
/// * `type` falls back to `"object"` (also when it is not a string);
/// * `properties` falls back to an empty object;
/// * `required` falls back to `None`, and non-string entries in `required` are dropped.
fn create_mcp_tool_definition(server_name: &str, mcp_tool: &McpTool) -> ToolDefinition {
    // Borrow the schema instead of deep-cloning it; only `properties` is cloned.
    let schema = mcp_tool.input_schema.as_ref();
    let schema_type = schema
        .and_then(|s| s.get("type"))
        .and_then(|t| t.as_str())
        .unwrap_or("object")
        .to_string();
    let properties = schema
        .and_then(|s| s.get("properties"))
        .cloned()
        .unwrap_or_else(|| serde_json::json!({}));
    let required = schema
        .and_then(|s| s.get("required"))
        .and_then(|r| r.as_array())
        .map(|names| {
            names
                .iter()
                .filter_map(|n| n.as_str().map(String::from))
                .collect()
        });
    ToolDefinition {
        name: format!("mcp__{}__{}", server_name, mcp_tool.name),
        description: mcp_tool
            .description
            .clone()
            .unwrap_or_else(|| format!("MCP tool: {}", mcp_tool.name)),
        input_schema: ToolInputSchema {
            schema_type,
            properties,
            required,
        },
    }
}
/// Closes every connection in `connections`, then removes them all from the vector.
pub async fn close_all_connections(connections: &mut Vec<McpConnection>) {
    // Close each connection in place first; the vector is emptied only afterwards.
    for connection in &mut *connections {
        connection.close().await;
    }
    connections.clear();
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal MCP tool with no description and no input schema.
    fn bare_tool(name: &str) -> McpTool {
        McpTool {
            name: name.to_string(),
            description: None,
            input_schema: None,
        }
    }

    #[test]
    fn test_mcp_stdio_config() {
        let mut env = HashMap::new();
        env.insert("KEY".to_string(), "value".to_string());
        let config = McpServerConfig::Stdio(crate::types::McpStdioConfig {
            transport_type: Some("stdio".to_string()),
            command: "npx".to_string(),
            args: Some(vec!["-y".to_string(), "some-server".to_string()]),
            env: Some(env),
        });
        match config {
            McpServerConfig::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert_eq!(s.args.unwrap().len(), 2);
            }
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_mcp_tool_definition_creation() {
        let mcp_tool = McpTool {
            name: "test_tool".to_string(),
            description: Some("A test tool".to_string()),
            input_schema: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "arg1": { "type": "string" }
                },
                "required": ["arg1"]
            })),
        };
        let tool_def = create_mcp_tool_definition("myserver", &mcp_tool);
        assert_eq!(tool_def.name, "mcp__myserver__test_tool");
        assert_eq!(tool_def.description, "A test tool");
        assert_eq!(tool_def.input_schema.schema_type, "object");
        assert_eq!(
            tool_def.input_schema.properties,
            serde_json::json!({ "arg1": { "type": "string" } })
        );
        assert_eq!(
            tool_def.input_schema.required,
            Some(vec!["arg1".to_string()])
        );
    }

    #[test]
    fn test_mcp_tool_definition_defaults() {
        let tool_def = create_mcp_tool_definition("srv", &bare_tool("bare"));
        assert_eq!(tool_def.name, "mcp__srv__bare");
        assert_eq!(tool_def.description, "MCP tool: bare");
        assert_eq!(tool_def.input_schema.schema_type, "object");
        assert_eq!(tool_def.input_schema.properties, serde_json::json!({}));
        assert!(tool_def.input_schema.required.is_none());
    }

    #[test]
    fn test_mcp_connection_status() {
        let conn = McpConnection {
            name: "test".to_string(),
            status: McpConnectionStatus::Connected,
            tools: vec![],
        };
        assert_eq!(conn.status, McpConnectionStatus::Connected);
    }

    #[tokio::test]
    async fn test_mcp_connection_close_resets_state() {
        let mut conn = McpConnection {
            name: "test".to_string(),
            status: McpConnectionStatus::Connected,
            tools: vec![create_mcp_tool_definition("test", &bare_tool("t"))],
        };
        conn.close().await;
        assert_eq!(conn.status, McpConnectionStatus::Disconnected);
        assert!(conn.tools.is_empty());
    }

    #[tokio::test]
    async fn test_close_all_connections_empties_list() {
        let mut connections = vec![
            McpConnection {
                name: "a".to_string(),
                status: McpConnectionStatus::Connected,
                tools: vec![create_mcp_tool_definition("a", &bare_tool("t"))],
            },
            McpConnection {
                name: "b".to_string(),
                status: McpConnectionStatus::Error,
                tools: vec![],
            },
        ];
        close_all_connections(&mut connections).await;
        assert!(connections.is_empty());
    }

    #[tokio::test]
    async fn test_unsupported_transport_reports_error_status() {
        let config = McpServerConfig::Sse(crate::types::McpSseConfig {
            transport_type: "sse".to_string(),
            url: "http://localhost:3000/sse".to_string(),
            headers: None,
        });
        let conn = match connect_mcp_server("remote", &config).await {
            Ok(conn) => conn,
            Err(_) => panic!("unsupported transports should yield an errored connection, not Err"),
        };
        assert_eq!(conn.name, "remote");
        assert_eq!(conn.status, McpConnectionStatus::Error);
        assert!(conn.tools.is_empty());
    }

    #[test]
    fn test_sse_config() {
        let config = McpServerConfig::Sse(crate::types::McpSseConfig {
            transport_type: "sse".to_string(),
            url: "http://localhost:3000/sse".to_string(),
            headers: None,
        });
        match config {
            McpServerConfig::Sse(s) => {
                assert_eq!(s.url, "http://localhost:3000/sse");
            }
            _ => panic!("Expected Sse variant"),
        }
    }

    #[test]
    fn test_http_config() {
        let config = McpServerConfig::Http(crate::types::McpHttpConfig {
            transport_type: "http".to_string(),
            url: "http://localhost:3000/mcp".to_string(),
            headers: Some({
                let mut h = HashMap::new();
                h.insert("Authorization".to_string(), "Bearer token".to_string());
                h
            }),
        });
        match config {
            McpServerConfig::Http(h) => {
                assert_eq!(h.url, "http://localhost:3000/mcp");
                assert!(h.headers.is_some());
            }
            _ => panic!("Expected Http variant"),
        }
    }
}