use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum McpTransportConfig {
Stdio {
command: String,
args: Vec<String>,
env: HashMap<String, String>,
},
Http {
url: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
pub name: String,
pub transport: McpTransportConfig,
pub auto_connect: bool,
}
impl McpServerConfig {
    /// Builds a config that uses the `Stdio` transport.
    ///
    /// The environment map starts empty (add entries with
    /// [`with_env`](Self::with_env)), and `auto_connect` is `true`.
    #[must_use]
    pub fn stdio(
        name: impl Into<String>,
        command: impl Into<String>,
        args: Vec<String>,
    ) -> Self {
        Self {
            name: name.into(),
            transport: McpTransportConfig::Stdio {
                command: command.into(),
                args,
                env: HashMap::new(),
            },
            auto_connect: true,
        }
    }

    /// Builds a config that uses the `Http` transport at `url`, with
    /// `auto_connect` set to `true`.
    #[must_use]
    pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            transport: McpTransportConfig::Http { url: url.into() },
            auto_connect: true,
        }
    }

    /// Adds one environment variable to a `Stdio` transport, replacing any
    /// previous value for the same key.
    ///
    /// For an `Http` transport this is a silent no-op: the transport has no
    /// environment map, so `key` and `value` are dropped unconverted.
    #[must_use]
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        if let McpTransportConfig::Stdio { env, .. } = &mut self.transport {
            env.insert(key.into(), value.into());
        }
        self
    }

    /// Overrides the `auto_connect` flag (both constructors default it to `true`).
    #[must_use]
    pub fn with_auto_connect(mut self, auto_connect: bool) -> Self {
        self.auto_connect = auto_connect;
        self
    }
}
/// Description of one tool, as returned by `McpClient::list_tools`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct McpToolInfo {
    /// Tool name. NOTE(review): presumably the `tool_name` passed to
    /// `McpClient::call_tool`; confirm.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON value describing the tool's input. NOTE(review): presumably a JSON
    /// Schema object; its contents are not validated here.
    pub input_schema: serde_json::Value,
}
/// Identity information for a server, as returned by `McpClient::server_info`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpServerInfo {
    /// Name reported by the server.
    pub name: String,
    /// Version string reported by the server (free-form; not parsed here).
    pub version: String,
    /// Optional instructions text; `None` when the server provides none.
    pub instructions: Option<String>,
}
#[async_trait]
pub trait McpClient: Send + Sync {
async fn connect(&mut self, config: McpServerConfig) -> crate::agent::error::AgentResult<()>;
async fn disconnect(&mut self, server_name: &str) -> crate::agent::error::AgentResult<()>;
async fn list_tools(
&self,
server_name: &str,
) -> crate::agent::error::AgentResult<Vec<McpToolInfo>>;
async fn call_tool(
&self,
server_name: &str,
tool_name: &str,
arguments: serde_json::Value,
) -> crate::agent::error::AgentResult<serde_json::Value>;
async fn server_info(
&self,
server_name: &str,
) -> crate::agent::error::AgentResult<McpServerInfo>;
fn connected_servers(&self) -> Vec<String>;
fn is_connected(&self, server_name: &str) -> bool;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder chain on a `Stdio` config sets command, args, env and flag.
    #[test]
    fn test_mcp_server_config_stdio() {
        let config = McpServerConfig::stdio("test-server", "node", vec!["server.js".to_string()])
            .with_env("API_KEY", "test-key")
            .with_auto_connect(false);

        assert_eq!(config.name, "test-server");
        assert!(!config.auto_connect);
        match &config.transport {
            McpTransportConfig::Stdio { command, args, env } => {
                assert_eq!(command, "node");
                assert_eq!(args, &["server.js"]);
                assert_eq!(env.get("API_KEY").map(String::as_str), Some("test-key"));
            }
            McpTransportConfig::Http { .. } => panic!("Expected Stdio transport"),
        }
    }

    /// The `http` constructor stores the URL and defaults `auto_connect` to true.
    #[test]
    fn test_mcp_server_config_http() {
        let config = McpServerConfig::http("api-server", "http://localhost:8080/mcp");

        assert_eq!(config.name, "api-server");
        assert!(config.auto_connect);
        match &config.transport {
            McpTransportConfig::Http { url } => assert_eq!(url, "http://localhost:8080/mcp"),
            McpTransportConfig::Stdio { .. } => panic!("Expected Http transport"),
        }
    }

    /// `with_env` must leave an `Http` transport untouched.
    #[test]
    fn test_with_env_is_noop_for_http() {
        let config = McpServerConfig::http("api-server", "http://localhost:8080/mcp")
            .with_env("API_KEY", "ignored");

        assert!(config.auto_connect);
        match &config.transport {
            McpTransportConfig::Http { url } => assert_eq!(url, "http://localhost:8080/mcp"),
            McpTransportConfig::Stdio { .. } => panic!("Expected Http transport"),
        }
    }

    /// A full `McpServerConfig` survives a JSON round trip.
    #[test]
    fn test_mcp_server_config_serde_roundtrip() {
        let config = McpServerConfig::stdio("test-server", "node", vec!["server.js".to_string()])
            .with_env("API_KEY", "test-key");

        let json = serde_json::to_string(&config).expect("config serializes");
        let restored: McpServerConfig =
            serde_json::from_str(&json).expect("config deserializes");

        assert_eq!(restored.name, "test-server");
        assert!(restored.auto_connect);
        match &restored.transport {
            McpTransportConfig::Stdio { command, args, env } => {
                assert_eq!(command, "node");
                assert_eq!(args, &["server.js"]);
                assert_eq!(env.len(), 1);
                assert_eq!(env.get("API_KEY").map(String::as_str), Some("test-key"));
            }
            McpTransportConfig::Http { .. } => panic!("Expected Stdio transport"),
        }
    }

    /// Every field of `McpToolInfo`, including the JSON schema, round-trips.
    #[test]
    fn test_mcp_tool_info_serialization() {
        let tool = McpToolInfo {
            name: "list_repos".to_string(),
            description: "List GitHub repositories".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "owner": { "type": "string" }
                },
                "required": ["owner"]
            }),
        };

        let json = serde_json::to_string(&tool).expect("tool serializes");
        let deserialized: McpToolInfo =
            serde_json::from_str(&json).expect("tool deserializes");

        assert_eq!(deserialized.name, "list_repos");
        assert_eq!(deserialized.description, tool.description);
        assert_eq!(deserialized.input_schema, tool.input_schema);
    }

    /// `McpServerInfo` round-trips with and without instructions.
    #[test]
    fn test_mcp_server_info() {
        let info = McpServerInfo {
            name: "github".to_string(),
            version: "1.0.0".to_string(),
            instructions: Some("GitHub MCP Server".to_string()),
        };
        assert_eq!(info.name, "github");
        assert_eq!(info.version, "1.0.0");
        assert_eq!(info.instructions.as_deref(), Some("GitHub MCP Server"));

        let json = serde_json::to_string(&info).expect("info serializes");
        let restored: McpServerInfo = serde_json::from_str(&json).expect("info deserializes");
        assert_eq!(restored.name, "github");
        assert_eq!(restored.version, "1.0.0");
        assert_eq!(restored.instructions.as_deref(), Some("GitHub MCP Server"));

        let bare = McpServerInfo {
            name: "bare".to_string(),
            version: "0.1.0".to_string(),
            instructions: None,
        };
        let json = serde_json::to_string(&bare).expect("info serializes");
        let restored: McpServerInfo = serde_json::from_str(&json).expect("info deserializes");
        assert_eq!(restored.instructions, None);
    }
}