use std::collections::HashMap;
use regex::Regex;
use reqwest::header::HeaderMap;
use rmcp::model::{
CallToolRequest, CallToolRequestMethod, ClientRequest, Extensions, GetMeta, ServerResult,
};
use rmcp::transport::sse_client::SseClientConfig;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::ServiceError;
use rmcp::{model::CallToolRequestParam, transport::SseClientTransport, RoleClient};
use tracing::debug;
use crate::types::gateway::{McpDefinition, McpTool, McpTransportType, ServerTools, ToolsFilter};
use rmcp::service::DynService;
use rmcp::service::RunningService;
use rmcp::service::ServiceExt;
/// Errors produced while connecting to MCP servers, listing their tools, or
/// executing a tool call.
///
/// Most variants wrap an underlying error through `#[from]`, so `?` converts
/// those source errors automatically.
#[derive(Debug, thiserror::Error)]
pub enum McpServerError {
/// The server name (or transport type) is not one this module supports.
/// Carries the offending name or a short explanation.
#[error("Invalid server name: {0}")]
InvalidServerName(String),
/// Boxed `rmcp` server-side initialization failure.
#[error("Server initialization error: {0}")]
ServerInitializeError(#[from] Box<rmcp::service::ServerInitializeError>),
/// Failure reported by the SSE client transport.
#[error("SSE transport error: {0}")]
SseTransportError(#[from] rmcp::transport::sse_client::SseTransportError<reqwest::Error>),
/// Boxed `rmcp` client-side initialization failure.
#[error("Client initialization error: {0}")]
ClientInitializeError(#[from] Box<rmcp::service::ClientInitializeError>),
/// Error returned by the `rmcp` service layer.
#[error("Service error: {0}")]
ServiceError(#[from] rmcp::ServiceError),
/// The client service could not be started; carries the stringified cause.
#[error("Client start error: {0}")]
ClientStartError(String),
/// Error from `reqwest`, displayed transparently.
#[error(transparent)]
ReqwestError(#[from] reqwest::Error),
/// Any other boxed error, displayed transparently.
#[error(transparent)]
BoxedError(#[from] Box<dyn std::error::Error + Send + Sync>),
/// Standard I/O error, displayed transparently.
#[error(transparent)]
StdIOError(#[from] std::io::Error),
/// JSON (de)serialization error, displayed transparently.
#[error(transparent)]
ParseError(#[from] serde_json::Error),
/// A tool call returned no usable text content; carries the tool name.
#[error("No text content in tool {0} result")]
NoTextInToolResult(String),
/// A tokio task failed to join.
#[error("Join error: {0}")]
JoinError(#[from] tokio::task::JoinError),
}
/// Lets `?` convert an unboxed client initialization error by boxing it into
/// [`McpServerError::ClientInitializeError`].
impl From<rmcp::service::ClientInitializeError> for McpServerError {
    fn from(source: rmcp::service::ClientInitializeError) -> Self {
        let boxed = Box::new(source);
        Self::ClientInitializeError(boxed)
    }
}
/// Lets `?` convert an unboxed server initialization error by boxing it into
/// [`McpServerError::ServerInitializeError`].
impl From<rmcp::service::ServerInitializeError> for McpServerError {
    fn from(source: rmcp::service::ServerInitializeError) -> Self {
        let boxed = Box::new(source);
        Self::ServerInitializeError(boxed)
    }
}
/// Checks that `name` is one of the accepted in-memory server names.
///
/// # Errors
///
/// Returns [`McpServerError::InvalidServerName`] carrying `name` when it is
/// neither `"websearch"` nor `"Web Search"`.
fn validate_server_name(name: &str) -> Result<(), McpServerError> {
    let is_known = matches!(name, "websearch" | "Web Search");
    if is_known {
        Ok(())
    } else {
        Err(McpServerError::InvalidServerName(name.to_owned()))
    }
}
/// Returns the tokio stdin and stdout handles as a `(stdin, stdout)` pair.
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
    let input = tokio::io::stdin();
    let output = tokio::io::stdout();
    (input, output)
}
/// Builds a `reqwest::Client` that sends `headers` as default headers on every
/// request.
///
/// Entries whose name or value cannot be parsed into a valid HTTP header are
/// skipped with a warning instead of failing the whole client build.
///
/// # Errors
///
/// Returns [`McpServerError::ReqwestError`] if the client itself cannot be built.
fn create_reqwest_client_with_headers(
    headers: &HashMap<String, String>,
) -> Result<reqwest::Client, McpServerError> {
    let mut headers_map = HeaderMap::new();
    for (key, value) in headers {
        match (
            key.parse::<reqwest::header::HeaderName>(),
            value.parse::<reqwest::header::HeaderValue>(),
        ) {
            (Ok(header_name), Ok(header_value)) => {
                headers_map.insert(header_name, header_value);
            }
            _ => {
                // Log only the header name: values frequently carry credentials
                // (API keys, bearer tokens) that must not end up in logs.
                tracing::warn!("Skipping invalid header {:?} (value redacted)", key);
            }
        }
    }
    reqwest::Client::builder()
        .default_headers(headers_map)
        .build()
        .map_err(McpServerError::ReqwestError)
}
/// Opens a client connection to the MCP server described by `definition` and
/// returns the running client service.
///
/// * `Sse`: connects over SSE to `server_url`, sending the configured headers.
/// * `Http`: connects over streamable HTTP to `server_url`, sending the
///   configured headers.
/// * `InMemory`: validates the server name, then connects over SSE to the URL
///   in the `TAVILY_MCP_URL` environment variable, falling back to
///   `http://localhost:8083/sse` when it is unset or unreadable.
///
/// # Errors
///
/// * [`McpServerError::InvalidServerName`] for an unknown in-memory server name
///   or any other transport type.
/// * [`McpServerError::SseTransportError`] if the SSE transport fails to start.
/// * [`McpServerError::ReqwestError`] if the HTTP client cannot be built.
/// * [`McpServerError::ClientStartError`] if the client service fails to start.
pub async fn get_transport(
    definition: &McpDefinition,
) -> Result<RunningService<RoleClient, Box<dyn DynService<RoleClient>>>, McpServerError> {
    match &definition.r#type {
        McpTransportType::Sse {
            server_url,
            headers,
            ..
        } => {
            let reqwest_client = create_reqwest_client_with_headers(headers)?;
            let transport = SseClientTransport::start_with_client(
                reqwest_client,
                SseClientConfig {
                    sse_endpoint: server_url.clone().into(),
                    ..Default::default()
                },
            )
            .await?;
            ().into_dyn()
                .serve(transport)
                .await
                .map_err(|e| McpServerError::ClientStartError(e.to_string()))
        }
        McpTransportType::Http {
            server_url,
            headers,
            ..
        } => {
            let reqwest_client = create_reqwest_client_with_headers(headers)?;
            let transport = StreamableHttpClientTransport::with_client(
                reqwest_client,
                StreamableHttpClientTransportConfig::with_uri(server_url.clone()),
            );
            ().into_dyn()
                .serve(transport)
                .await
                .map_err(|e| McpServerError::ClientStartError(e.to_string()))
        }
        McpTransportType::InMemory { name, .. } => {
            validate_server_name(name)?;
            // Build the fallback URL lazily so it is only allocated when the
            // environment variable is missing.
            let sse_url = std::env::var("TAVILY_MCP_URL")
                .unwrap_or_else(|_| "http://localhost:8083/sse".to_string());
            let transport = SseClientTransport::start(sse_url).await?;
            ().into_dyn()
                .serve(transport)
                .await
                .map_err(|e| McpServerError::ClientStartError(e.to_string()))
        }
        _ => Err(McpServerError::InvalidServerName(
            "Invalid or unsupported server type".to_string(),
        )),
    }
}
pub async fn get_tools(definitions: &[McpDefinition]) -> Result<Vec<ServerTools>, McpServerError> {
let mut all_tools = Vec::new();
for tool_def in definitions {
let mcp_server_name = tool_def.server_name();
let client = get_transport(tool_def).await?;
let tools = client.list_tools(Default::default()).await?;
client.cancel().await?;
let mut tools = tools.tools;
let total_tools = tools.len();
match &tool_def.filter {
ToolsFilter::All => {
tracing::debug!("Loading all {} tools from {}", total_tools, mcp_server_name);
}
ToolsFilter::Selected(selected) => {
let before_count = tools.len();
tools.retain_mut(|tool| {
let found = selected.iter().find(|t| {
if tool.name == t.name {
true
} else if let Ok(name_regex) = Regex::new(&t.name) {
debug!("Matching {} against pattern {}", tool.name, t.name);
name_regex.is_match(&tool.name)
} else {
false
}
});
if let Some(Some(d)) = found.as_ref().map(|t| t.description.as_ref()) {
tool.description = Some(d.clone().into());
}
found.is_some()
});
tracing::debug!(
"Filtered tools for {}: {}/{} tools selected",
mcp_server_name,
tools.len(),
before_count
);
}
}
let mcp_tools = tools
.into_iter()
.map(|t| McpTool(t, tool_def.clone()))
.collect();
all_tools.push(ServerTools {
tools: mcp_tools,
definition: tool_def.clone(),
});
}
tracing::debug!("Loaded {} tool definitions in total", all_tools.len());
Ok(all_tools)
}
/// Connects to the server described by `definitions`, fetches its tool list
/// without any filtering, and closes the connection again.
///
/// # Errors
///
/// Returns the error from connecting, from the list-tools request, or from
/// cancelling the client, whichever occurs first.
pub async fn get_raw_tools(
    definitions: &McpDefinition,
) -> Result<Vec<rmcp::model::Tool>, McpServerError> {
    let session = get_transport(definitions).await?;
    let listing = session.list_tools(Default::default()).await?;
    // Shut the session down before handing the tools back to the caller.
    session.cancel().await?;
    let raw_tools = listing.tools;
    Ok(raw_tools)
}
/// Calls `tool` on the server described by `def` and returns the text of the
/// first text item in the result.
///
/// * `inputs` becomes the tool call's arguments object.
/// * `meta`, when it is a JSON object, has its entries copied into the request
///   metadata; any other JSON value is ignored.
///
/// For an in-memory server named `"websearch"`, when the `TAVILY_API_KEY`
/// environment variable is set, `meta` is replaced wholesale by
/// `{"env_vars": {"TAVILY_API_KEY": <value>}}`, discarding any caller-supplied
/// metadata.
///
/// # Errors
///
/// * Any error from connecting to the server, sending the request, or closing
///   the connection.
/// * [`McpServerError::ServiceError`] if the server answers with something other
///   than a tool-call result.
/// * [`McpServerError::NoTextInToolResult`] if the result holds no text content.
pub async fn execute_mcp_tool(
    def: &McpDefinition,
    tool: &rmcp::model::Tool,
    inputs: HashMap<String, serde_json::Value>,
    mut meta: Option<serde_json::Value>,
) -> Result<String, McpServerError> {
    if matches!(def.r#type, McpTransportType::InMemory { .. }) && def.server_name() == "websearch"
    {
        if let Ok(var) = std::env::var("TAVILY_API_KEY") {
            meta = Some(serde_json::json!({"env_vars": {
                "TAVILY_API_KEY": var
            }}));
        }
    }

    let name = tool.name.clone();
    let client = get_transport(def).await?;

    let args: serde_json::Map<String, serde_json::Value> = inputs.into_iter().collect();
    let mut request = ClientRequest::CallToolRequest(CallToolRequest {
        method: CallToolRequestMethod,
        params: CallToolRequestParam {
            name: tool.name.clone(),
            arguments: Some(args),
        },
        extensions: Extensions::default(),
    });

    // Only object-shaped metadata can be merged into the request; the entries
    // are moved rather than cloned because `meta` is owned here.
    if let Some(serde_json::Value::Object(map)) = meta {
        for (key, value) in map {
            request.get_meta_mut().insert(key, value);
        }
    }

    let response = client.send_request(request).await?;
    client.cancel().await?;

    let result = match response {
        ServerResult::CallToolResult(result) => result,
        _ => return Err(ServiceError::UnexpectedResponse.into()),
    };

    // Scan every content item, not just the first, so a result that leads with
    // non-text content but also contains text is not misreported as "no text".
    let text = result
        .content
        .iter()
        .find_map(|content| content.raw.as_text().map(|t| t.text.clone()));

    match text {
        Some(text) => {
            tracing::debug!("Tool {name}: execution completed successfully", name = name);
            Ok(text)
        }
        None => {
            tracing::error!("Tool {name}: No text content in tool response", name = name);
            Err(McpServerError::NoTextInToolResult(name.to_string()))
        }
    }
}