use std::sync::Arc;
use tokio::sync::{Notify, RwLock};
use super::config::{McpServerConfig, McpServerTransport};
use bitrouter_core::errors::{BitrouterError, Result as BResult};
use bitrouter_core::tools::provider::ToolProvider;
use bitrouter_core::tools::result::{ToolCallResult, ToolContent};
use crate::mcp::transports::McpTransport;
use crate::mcp::transports::TransportKind;
use bitrouter_core::api::mcp::gateway::McpClientRequestHandler;
use bitrouter_core::api::mcp::types::McpGatewayError;
use bitrouter_core::api::mcp::types::{
McpContent, McpGetPromptResult, McpPrompt, McpPromptArgument, McpResource, McpResourceContent,
McpResourceTemplate, McpTool, McpToolCallResult,
};
/// A resource advertised by an upstream MCP server whose URI has been prefixed
/// with the server name (`"<server>+<uri>"`), so it can be listed next to
/// resources coming from other upstreams.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NamespacedResource {
    /// Namespaced URI in the form `"<server>+<original uri>"`.
    pub uri: String,
    /// Resource name exactly as reported by the upstream server.
    pub name: String,
    /// Optional description reported by the upstream server.
    pub description: Option<String>,
    /// Optional MIME type reported by the upstream server.
    pub mime_type: Option<String>,
}
/// A resource template advertised by an upstream MCP server whose URI template
/// has been prefixed with the server name (`"<server>+<uri template>"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NamespacedResourceTemplate {
    /// Namespaced URI template in the form `"<server>+<original template>"`.
    pub uri_template: String,
    /// Template name exactly as reported by the upstream server.
    pub name: String,
    /// Optional description reported by the upstream server.
    pub description: Option<String>,
    /// Optional MIME type reported by the upstream server.
    pub mime_type: Option<String>,
}
/// A prompt advertised by an upstream MCP server whose name has been prefixed
/// with the server name (`"<server>/<prompt>"`).
#[derive(Debug, Clone)]
pub struct NamespacedPrompt {
    /// Namespaced prompt name in the form `"<server>/<original name>"`.
    pub name: String,
    /// Optional description reported by the upstream server.
    pub description: Option<String>,
    /// Prompt arguments, copied unchanged from the upstream definition.
    pub arguments: Vec<McpPromptArgument>,
}
/// A live connection to one upstream MCP server, plus locally cached copies of
/// the tools, resources, resource templates and prompts it advertises.
///
/// The caches are filled by `connect` and replaced by the `refresh_*` methods.
pub struct UpstreamConnection {
    // Configured server name; used as the namespace prefix for tools, prompts
    // and resources, and as this connection's `ToolProvider` name.
    name: String,
    // Transport used to talk to the upstream server (currently always HTTP).
    transport: TransportKind,
    // Cached upstream tool list (un-namespaced names).
    tools: Arc<RwLock<Vec<McpTool>>>,
    // Cached upstream resource list (un-namespaced URIs).
    resources: Arc<RwLock<Vec<McpResource>>>,
    // Cached upstream resource-template list (un-namespaced URI templates).
    resource_templates: Arc<RwLock<Vec<McpResourceTemplate>>>,
    // Cached upstream prompt list (un-namespaced names).
    prompts: Arc<RwLock<Vec<McpPrompt>>>,
    // Notify handles shared with the transport client at connect time;
    // presumably signalled when the upstream reports a tool / resource /
    // prompt list change — confirm in the transport implementation.
    tool_notify: Arc<Notify>,
    resource_notify: Arc<Notify>,
    prompt_notify: Arc<Notify>,
}
impl UpstreamConnection {
    /// Validates `config`, connects to the upstream MCP server it describes and
    /// primes the local tool, resource, resource-template and prompt caches.
    ///
    /// `handler` is forwarded unchanged to the transport client. Presumably it
    /// services requests the upstream sends back to the gateway; confirm against
    /// [`McpClientRequestHandler`].
    ///
    /// Only the initial tool listing is mandatory. Resource, resource-template
    /// and prompt listings are best-effort: if one of them fails, that cache
    /// simply starts out empty.
    ///
    /// # Errors
    ///
    /// - [`McpGatewayError::InvalidConfig`] if `config.validate()` rejects the
    ///   configuration, or if the server name contains `"__"`.
    /// - [`McpGatewayError::UpstreamConnect`] if the `initialize` handshake or
    ///   the initial tool listing fails.
    /// - Any error returned by the HTTP client constructor (propagated via `?`).
    pub async fn connect(
        config: McpServerConfig,
        handler: Option<Arc<dyn McpClientRequestHandler>>,
    ) -> Result<Self, McpGatewayError> {
        config
            .validate()
            .map_err(|reason| McpGatewayError::InvalidConfig { reason })?;
        // `__` is reserved as the wire-format separator, so a server name must
        // not contain it.
        if config.name.contains("__") {
            return Err(McpGatewayError::InvalidConfig {
                reason: format!(
                    "server name '{}' must not contain '__' (reserved as wire-format separator)",
                    config.name
                ),
            });
        }
        let name = config.name.clone();
        // These handles are shared with the transport client and also exposed
        // through the `*_change_notify` accessors.
        let tool_notify = Arc::new(Notify::new());
        let resource_notify = Arc::new(Notify::new());
        let prompt_notify = Arc::new(Notify::new());
        match config.transport {
            McpServerTransport::Http {
                ref url,
                ref headers,
            } => {
                use crate::mcp::transports::http::NotifyHandles;
                let notify = NotifyHandles {
                    tool: Arc::clone(&tool_notify),
                    resource: Arc::clone(&resource_notify),
                    prompt: Arc::clone(&prompt_notify),
                };
                let client = crate::mcp::transports::http::McpHttpClient::new(
                    name.clone(),
                    url.clone(),
                    headers,
                    handler,
                    Some(notify),
                )?;
                client
                    .initialize()
                    .await
                    .map_err(|e| McpGatewayError::UpstreamConnect {
                        name: name.clone(),
                        reason: e.to_string(),
                    })?;
                let initial_tools =
                    client
                        .list_tools()
                        .await
                        .map_err(|e| McpGatewayError::UpstreamConnect {
                            name: name.clone(),
                            reason: format!("failed to list tools: {e}"),
                        })?;
                // Deliberately best-effort: an upstream that does not offer
                // these listings still yields a usable (tool-only) connection.
                let initial_resources = client.list_resources().await.unwrap_or_default();
                let initial_templates = client.list_resource_templates().await.unwrap_or_default();
                let initial_prompts = client.list_prompts().await.unwrap_or_default();
                Ok(Self {
                    name,
                    transport: TransportKind::Http(client),
                    tools: Arc::new(RwLock::new(initial_tools)),
                    resources: Arc::new(RwLock::new(initial_resources)),
                    resource_templates: Arc::new(RwLock::new(initial_templates)),
                    prompts: Arc::new(RwLock::new(initial_prompts)),
                    tool_notify,
                    resource_notify,
                    prompt_notify,
                })
            }
        }
    }

    /// Returns a snapshot of the cached tool list with upstream (un-namespaced) names.
    pub async fn raw_tools(&self) -> Vec<McpTool> {
        self.tools.read().await.clone()
    }

    /// Returns a snapshot of the cached resource list with upstream (un-namespaced) URIs.
    pub async fn raw_resources(&self) -> Vec<McpResource> {
        self.resources.read().await.clone()
    }

    /// Returns a snapshot of the cached resource-template list with upstream URI templates.
    pub async fn raw_resource_templates(&self) -> Vec<McpResourceTemplate> {
        self.resource_templates.read().await.clone()
    }

    /// Returns a snapshot of the cached prompt list with upstream (un-namespaced) names.
    pub async fn raw_prompts(&self) -> Vec<McpPrompt> {
        self.prompts.read().await.clone()
    }

    /// Returns the cached tools with each name rewritten to `"<server>/<tool>"`.
    pub async fn namespaced_tools(&self) -> Vec<McpTool> {
        let tools = self.tools.read().await;
        tools
            .iter()
            .map(|t| McpTool {
                name: format!("{}/{}", self.name, t.name),
                description: t.description.clone(),
                input_schema: t.input_schema.clone(),
            })
            .collect()
    }

    /// Re-fetches the upstream tool list and replaces the cache.
    ///
    /// # Errors
    ///
    /// Returns the transport error if listing fails; the cache is then left unchanged.
    pub async fn refresh_tools(&self) -> Result<(), McpGatewayError> {
        let fresh = self.transport.list_tools().await?;
        *self.tools.write().await = fresh;
        Ok(())
    }

    /// Calls `tool_name` on the upstream server. The name is forwarded unchanged,
    /// so callers presumably pass the un-namespaced upstream name.
    pub async fn call_tool(
        &self,
        tool_name: &str,
        arguments: Option<serde_json::Map<String, serde_json::Value>>,
    ) -> Result<McpToolCallResult, McpGatewayError> {
        self.transport.call_tool(tool_name, arguments).await
    }

    /// Returns the number of cached tools.
    pub async fn tool_count(&self) -> usize {
        self.tools.read().await.len()
    }

    /// Returns a shared handle to the tool-change [`Notify`] given to the transport.
    pub fn tool_change_notify(&self) -> Arc<Notify> {
        Arc::clone(&self.tool_notify)
    }

    /// Returns a shared handle to the resource-change [`Notify`] given to the transport.
    pub fn resource_change_notify(&self) -> Arc<Notify> {
        Arc::clone(&self.resource_notify)
    }

    /// Returns a shared handle to the prompt-change [`Notify`] given to the transport.
    pub fn prompt_change_notify(&self) -> Arc<Notify> {
        Arc::clone(&self.prompt_notify)
    }

    /// Returns the cached resources with each URI rewritten to `"<server>+<uri>"`.
    pub async fn namespaced_resources(&self) -> Vec<NamespacedResource> {
        let resources = self.resources.read().await;
        resources
            .iter()
            .map(|r| NamespacedResource {
                uri: format!("{}+{}", self.name, r.uri),
                name: r.name.clone(),
                description: r.description.clone(),
                mime_type: r.mime_type.clone(),
            })
            .collect()
    }

    /// Returns the cached resource templates with each URI template rewritten to
    /// `"<server>+<uri template>"`.
    pub async fn namespaced_resource_templates(&self) -> Vec<NamespacedResourceTemplate> {
        let templates = self.resource_templates.read().await;
        templates
            .iter()
            .map(|t| NamespacedResourceTemplate {
                uri_template: format!("{}+{}", self.name, t.uri_template),
                name: t.name.clone(),
                description: t.description.clone(),
                mime_type: t.mime_type.clone(),
            })
            .collect()
    }

    /// Reads `uri` from the upstream server. The URI is forwarded unchanged, so
    /// callers presumably strip the `"<server>+"` prefix first.
    pub async fn read_resource(
        &self,
        uri: &str,
    ) -> Result<Vec<McpResourceContent>, McpGatewayError> {
        self.transport.read_resource(uri).await
    }

    /// Re-fetches the upstream resource list and resource-template list and
    /// replaces the corresponding caches.
    ///
    /// The resource list is committed as soon as it has been fetched. A failure
    /// while listing templates, which `connect` already treats as optional,
    /// therefore no longer discards an otherwise successful resource refresh.
    ///
    /// # Errors
    ///
    /// Returns the transport error if either listing fails:
    /// - If the resource listing fails, neither cache is modified.
    /// - If only the template listing fails, the resource cache has already been
    ///   updated and the template cache is left unchanged.
    pub async fn refresh_resources(&self) -> Result<(), McpGatewayError> {
        let fresh_resources = self.transport.list_resources().await?;
        // The write guard is a temporary, so it is released before the next await.
        *self.resources.write().await = fresh_resources;

        let fresh_templates = self.transport.list_resource_templates().await?;
        *self.resource_templates.write().await = fresh_templates;
        Ok(())
    }

    /// Returns the cached prompts with each name rewritten to `"<server>/<prompt>"`.
    pub async fn namespaced_prompts(&self) -> Vec<NamespacedPrompt> {
        let prompts = self.prompts.read().await;
        prompts
            .iter()
            .map(|p| NamespacedPrompt {
                name: format!("{}/{}", self.name, p.name),
                description: p.description.clone(),
                arguments: p.arguments.clone(),
            })
            .collect()
    }

    /// Fetches prompt `name` from the upstream server. The name is forwarded
    /// unchanged, so callers presumably pass the un-namespaced upstream name.
    pub async fn get_prompt(
        &self,
        name: &str,
        arguments: Option<std::collections::HashMap<String, String>>,
    ) -> Result<McpGetPromptResult, McpGatewayError> {
        self.transport.get_prompt(name, arguments).await
    }

    /// Re-fetches the upstream prompt list and replaces the cache.
    ///
    /// # Errors
    ///
    /// Returns the transport error if listing fails; the cache is then left unchanged.
    pub async fn refresh_prompts(&self) -> Result<(), McpGatewayError> {
        let fresh = self.transport.list_prompts().await?;
        *self.prompts.write().await = fresh;
        Ok(())
    }
}
impl ToolProvider for UpstreamConnection {
fn provider_name(&self) -> &str {
&self.name
}
async fn call_tool(
&self,
tool_id: &str,
arguments: serde_json::Value,
) -> BResult<ToolCallResult> {
let args = match arguments {
serde_json::Value::Object(map) => Some(map),
serde_json::Value::Null => None,
other => {
return Err(BitrouterError::invalid_request(
Some(&self.name),
format!("tool arguments must be a JSON object, got {}", other),
None,
));
}
};
let mcp_result = self
.transport
.call_tool(tool_id, args)
.await
.map_err(|e| BitrouterError::transport(Some(&self.name), e.to_string()))?;
Ok(mcp_result_to_tool_result(mcp_result))
}
}
fn mcp_result_to_tool_result(mcp: McpToolCallResult) -> ToolCallResult {
let content = mcp
.content
.into_iter()
.map(|c| match c {
McpContent::Text { text } => ToolContent::Text { text },
})
.collect();
ToolCallResult {
content,
is_error: mcp.is_error.unwrap_or(false),
metadata: None,
}
}