use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Notify;
use bitrouter_core::api::mcp::gateway::McpClientRequestHandler;
use bitrouter_core::api::mcp::types::McpGatewayError;
/// Bundle of shared [`Notify`] handles, one per kind of upstream list.
///
/// `BitrouterClientHandler` calls `notify_one` on the matching handle whenever
/// the corresponding "list changed" callback fires. Each handle is wrapped in
/// an [`Arc`] so it can be shared with whoever awaits the notification.
pub(crate) struct NotifyHandles {
/// Signalled from `on_tool_list_changed`.
pub tool: Arc<Notify>,
/// Signalled from `on_resource_list_changed`.
pub resource: Arc<Notify>,
/// Signalled from `on_prompt_list_changed`.
pub prompt: Arc<Notify>,
}
/// `rmcp` client handler that bridges server-initiated requests and
/// notifications into bitrouter's own types.
pub(crate) struct BitrouterClientHandler {
/// Name passed as the first argument to `handle_sampling` /
/// `handle_elicitation` so the handler knows which server is asking.
server_name: String,
/// Handles signalled when the tool / resource / prompt list changes.
notify: NotifyHandles,
/// Optional handler for sampling and elicitation requests. When `None`,
/// both request kinds are rejected with `METHOD_NOT_FOUND`.
handler: Option<Arc<dyn McpClientRequestHandler>>,
}
impl BitrouterClientHandler {
    /// Creates a client handler for the upstream server called `server_name`.
    ///
    /// * `server_name` - converted into an owned `String` and stored.
    /// * `notify` - handles signalled by the list-changed callbacks.
    /// * `handler` - optional sampling / elicitation handler; pass `None` to
    ///   have those requests rejected.
    pub fn new(
        server_name: impl Into<String>,
        notify: NotifyHandles,
        handler: Option<Arc<dyn McpClientRequestHandler>>,
    ) -> Self {
        let server_name: String = server_name.into();
        Self {
            server_name,
            notify,
            handler,
        }
    }
}
/// Builds the `internal_error` returned when a `serde_json` conversion step fails.
///
/// `step` is the failing direction (`"serialize"` or `"deserialize"`) and
/// `subject` names the payload; the message is `failed to {step} {subject}: {err}`.
fn json_bridge_error(
    step: &str,
    subject: &str,
    err: serde_json::Error,
) -> rmcp::model::ErrorData {
    rmcp::model::ErrorData::internal_error(format!("failed to {step} {subject}: {err}"), None)
}

/// Bridges `rmcp` client callbacks to the optional bitrouter request handler
/// and to the list-changed notification handles.
impl rmcp::handler::client::ClientHandler for BitrouterClientHandler {
    /// Handles a sampling (`create_message`) request from the server.
    ///
    /// The `rmcp` params are converted to bitrouter's `CreateMessageParams` by a
    /// JSON round-trip, passed to `handle_sampling` together with the server
    /// name, and the result is converted back the same way.
    ///
    /// # Errors
    ///
    /// * `METHOD_NOT_FOUND` when no handler is configured.
    /// * An internal error when any JSON conversion step fails.
    /// * The handler's own error, re-wrapped with its code, message and data.
    async fn create_message(
        &self,
        params: rmcp::model::CreateMessageRequestParams,
        _context: rmcp::service::RequestContext<rmcp::service::RoleClient>,
    ) -> Result<rmcp::model::CreateMessageResult, rmcp::model::ErrorData> {
        let Some(ref handler) = self.handler else {
            return Err(rmcp::model::ErrorData::new(
                rmcp::model::ErrorCode::METHOD_NOT_FOUND,
                "sampling not supported",
                None,
            ));
        };

        // The two crates define separate types for the same payload, so the
        // conversion goes through an intermediate `serde_json::Value`.
        let raw_params = serde_json::to_value(&params)
            .map_err(|e| json_bridge_error("serialize", "sampling params", e))?;
        let core_params: bitrouter_core::api::mcp::types::CreateMessageParams =
            serde_json::from_value(raw_params)
                .map_err(|e| json_bridge_error("deserialize", "sampling params", e))?;

        let result = handler
            .handle_sampling(&self.server_name, core_params)
            .await
            .map_err(|e| {
                // NOTE(review): `as i32` truncates silently if `e.code` is wider
                // than `i32`; confirm the field type.
                rmcp::model::ErrorData::new(
                    rmcp::model::ErrorCode(e.code as i32),
                    e.message,
                    e.data,
                )
            })?;

        let raw_result = serde_json::to_value(&result)
            .map_err(|e| json_bridge_error("serialize", "sampling result", e))?;
        serde_json::from_value(raw_result)
            .map_err(|e| json_bridge_error("deserialize", "sampling result", e))
    }

    /// Handles an elicitation request from the server.
    ///
    /// Mirrors `create_message`: params are converted to bitrouter's
    /// `ElicitationCreateParams` by a JSON round-trip, passed to
    /// `handle_elicitation`, and the result is converted back.
    ///
    /// # Errors
    ///
    /// * `METHOD_NOT_FOUND` when no handler is configured.
    /// * An internal error when any JSON conversion step fails.
    /// * The handler's own error, re-wrapped with its code, message and data.
    async fn create_elicitation(
        &self,
        params: rmcp::model::CreateElicitationRequestParams,
        _context: rmcp::service::RequestContext<rmcp::service::RoleClient>,
    ) -> Result<rmcp::model::CreateElicitationResult, rmcp::model::ErrorData> {
        let Some(ref handler) = self.handler else {
            return Err(rmcp::model::ErrorData::new(
                rmcp::model::ErrorCode::METHOD_NOT_FOUND,
                "elicitation not supported",
                None,
            ));
        };

        let raw_params = serde_json::to_value(&params)
            .map_err(|e| json_bridge_error("serialize", "elicitation params", e))?;
        let core_params: bitrouter_core::api::mcp::types::ElicitationCreateParams =
            serde_json::from_value(raw_params)
                .map_err(|e| json_bridge_error("deserialize", "elicitation params", e))?;

        let result = handler
            .handle_elicitation(&self.server_name, core_params)
            .await
            .map_err(|e| {
                // NOTE(review): same potentially truncating cast as in
                // `create_message`; confirm the type of `e.code`.
                rmcp::model::ErrorData::new(
                    rmcp::model::ErrorCode(e.code as i32),
                    e.message,
                    e.data,
                )
            })?;

        let raw_result = serde_json::to_value(&result)
            .map_err(|e| json_bridge_error("serialize", "elicitation result", e))?;
        serde_json::from_value(raw_result)
            .map_err(|e| json_bridge_error("deserialize", "elicitation result", e))
    }

    /// Signals the `tool` handle; the notification is sent synchronously and
    /// the returned future is already complete.
    fn on_tool_list_changed(
        &self,
        _context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
    ) -> impl Future<Output = ()> + Send + '_ {
        self.notify.tool.notify_one();
        std::future::ready(())
    }

    /// Signals the `resource` handle; the returned future is already complete.
    fn on_resource_list_changed(
        &self,
        _context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
    ) -> impl Future<Output = ()> + Send + '_ {
        self.notify.resource.notify_one();
        std::future::ready(())
    }

    /// Signals the `prompt` handle; the returned future is already complete.
    fn on_prompt_list_changed(
        &self,
        _context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
    ) -> impl Future<Output = ()> + Send + '_ {
        self.notify.prompt.notify_one();
        std::future::ready(())
    }

    /// Returns the default `ClientInfo`.
    fn get_info(&self) -> rmcp::model::ClientInfo {
        rmcp::model::ClientInfo::default()
    }
}
/// Owns a running client service behind a boxed trait object, so the concrete
/// handler type of the service does not appear in this struct's type.
///
/// Built with `ConnectedPeer::from_service`; the underlying peer is reached
/// through `ConnectedPeer::peer`.
pub(crate) struct ConnectedPeer {
/// Type-erased wrapper around the running service.
inner: Box<dyn RunningServiceHandle>,
}
/// Private object-safe view of a running client service, used by
/// `ConnectedPeer` to store services with different handler types.
trait RunningServiceHandle: Send + Sync {
/// Returns the peer used to talk to the connected server.
fn peer(&self) -> &rmcp::service::Peer<rmcp::service::RoleClient>;
}
/// Newtype holding a `RunningService` by value so that it can implement
/// `RunningServiceHandle` for any handler type `S`.
struct Wrapper<S: rmcp::handler::client::ClientHandler>(
rmcp::service::RunningService<rmcp::service::RoleClient, S>,
);
impl<S: rmcp::handler::client::ClientHandler> RunningServiceHandle for Wrapper<S> {
fn peer(&self) -> &rmcp::service::Peer<rmcp::service::RoleClient> {
self.0.peer()
}
}
impl ConnectedPeer {
pub fn from_service<S: rmcp::handler::client::ClientHandler>(
service: rmcp::service::RunningService<rmcp::service::RoleClient, S>,
) -> Self {
Self {
inner: Box::new(Wrapper(service)),
}
}
pub fn peer(&self) -> &rmcp::service::Peer<rmcp::service::RoleClient> {
self.inner.peer()
}
}
/// Builds a streamable-HTTP client transport for the upstream server `name`.
///
/// Every entry of `headers` is parsed into a header name / value pair and
/// installed as a default header on a fresh `reqwest::Client`, which is then
/// combined with a transport config pointing at `url`.
///
/// # Errors
///
/// Returns `McpGatewayError::UpstreamConnect` (carrying `name`) when a header
/// name or value fails to parse, or when the HTTP client cannot be built.
pub(crate) fn build_http_transport(
    url: &str,
    headers: &HashMap<String, String>,
    name: &str,
) -> Result<rmcp::transport::StreamableHttpClientTransport<reqwest::Client>, McpGatewayError> {
    use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
    use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;

    // Single place that attaches the server name to a failure reason.
    let connect_failure = |reason: String| McpGatewayError::UpstreamConnect {
        name: name.to_owned(),
        reason,
    };

    let mut default_headers = HeaderMap::new();
    for (raw_name, raw_value) in headers {
        let parsed_name = raw_name
            .parse::<HeaderName>()
            .map_err(|e| connect_failure(format!("invalid header name '{raw_name}': {e}")))?;
        let parsed_value = raw_value.parse::<HeaderValue>().map_err(|e| {
            connect_failure(format!("invalid header value for '{raw_name}': {e}"))
        })?;
        // `insert` (not `append`): a later entry with the same parsed name
        // replaces the earlier one.
        default_headers.insert(parsed_name, parsed_value);
    }

    let client = reqwest::Client::builder()
        .default_headers(default_headers)
        .build()
        .map_err(|e| connect_failure(format!("failed to build HTTP client: {e}")))?;

    let transport_config = StreamableHttpClientTransportConfig::with_uri(url);
    Ok(rmcp::transport::StreamableHttpClientTransport::with_client(
        client,
        transport_config,
    ))
}