use std::collections::HashMap;
use anyhow::Result;
use tokio::process::Command;
use tokio_util::sync::CancellationToken;
use tracing::info;
use rmcp::{
ServiceExt,
model::{ClientCapabilities, ClientInfo},
transport::{
SseClientTransport, TokioChildProcess,
sse_client::SseClientConfig,
sse_server::{SseServer, SseServerConfig},
streamable_http_client::{
StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
},
},
};
use crate::{SseHandler, ToolFilter};
/// Describes the backend MCP server that [`SseServerBuilder`] connects to.
///
/// `SseServerBuilder::build` matches on the variant and dispatches to the
/// corresponding `connect_*` method.
#[derive(Debug, Clone)]
pub enum BackendConfig {
/// Backend reached by spawning a child process.
Stdio {
/// Program used to create the child command.
command: String,
/// Arguments passed to the command; `None` adds no arguments.
args: Option<Vec<String>>,
/// Environment variables set on the child command; `None` sets none.
env: Option<HashMap<String, String>>,
},
/// Backend reached through an SSE endpoint.
SseUrl {
/// URL used as the SSE endpoint of the client transport.
url: String,
/// Headers installed as default headers on the HTTP client.
headers: Option<HashMap<String, String>>,
},
/// Backend reached through a Streamable HTTP endpoint.
StreamUrl {
/// URI handed to the Streamable HTTP client transport.
url: String,
/// Headers installed as default headers on the HTTP client, except
/// `Authorization`, which is forwarded through the transport
/// configuration's `auth_header` field instead.
headers: Option<HashMap<String, String>>,
},
}
/// Server-side settings collected by [`SseServerBuilder`].
#[derive(Debug, Clone)]
pub struct SseServerBuilderConfig {
/// Path of the SSE stream endpoint (default `/sse`).
pub sse_path: String,
/// Path of the message POST endpoint (default `/message`).
pub post_path: String,
/// Identifier handed to the handler; `build` falls back to `"sse-proxy"`
/// when this is `None`.
pub mcp_id: Option<String>,
/// Optional tool filter; when set, the handler is created with it.
pub tool_filter: Option<ToolFilter>,
/// SSE keep-alive interval in seconds (default 15).
pub keep_alive_secs: u64,
}
/// Default settings: SSE stream at `/sse`, message endpoint at `/message`,
/// no explicit MCP id, no tool filter, and a 15-second keep-alive.
impl Default for SseServerBuilderConfig {
    fn default() -> Self {
        let sse_path = String::from("/sse");
        let post_path = String::from("/message");
        Self {
            sse_path,
            post_path,
            mcp_id: None,
            tool_filter: None,
            keep_alive_secs: 15,
        }
    }
}
/// Builder that connects to a backend MCP server and exposes it over SSE.
///
/// Create it with `SseServerBuilder::new`, adjust the settings with the
/// chained setter methods, then call `build` to obtain the router.
///
/// Every setter takes the builder by value and returns it, so discarding a
/// returned builder silently loses the configuration; `#[must_use]` turns
/// that mistake into a compiler warning.
#[must_use = "a builder does nothing until `build` is called"]
pub struct SseServerBuilder {
    /// Backend the proxy connects to.
    backend_config: BackendConfig,
    /// Paths, identifier, tool filter and keep-alive for the SSE server.
    server_config: SseServerBuilderConfig,
}
impl SseServerBuilder {
/// Creates a builder for `backend` with the default server settings.
pub fn new(backend: BackendConfig) -> Self {
    let server_config = SseServerBuilderConfig::default();
    Self {
        backend_config: backend,
        server_config,
    }
}
pub fn sse_path(mut self, path: impl Into<String>) -> Self {
self.server_config.sse_path = path.into();
self
}
pub fn post_path(mut self, path: impl Into<String>) -> Self {
self.server_config.post_path = path.into();
self
}
pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
self.server_config.mcp_id = Some(id.into());
self
}
pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
self.server_config.tool_filter = Some(filter);
self
}
pub fn keep_alive(mut self, secs: u64) -> Self {
self.server_config.keep_alive_secs = secs;
self
}
/// Connects to the configured backend and assembles the SSE server.
///
/// The steps are:
/// 1. resolve the MCP id (falling back to `"sse-proxy"`),
/// 2. connect a client to the backend selected by [`BackendConfig`],
/// 3. wrap the client in an [`SseHandler`], with the tool filter if one
///    was configured,
/// 4. create the SSE server router around that handler.
///
/// # Returns
/// The router, the cancellation token returned by the SSE server, and a
/// clone of the handler that serves the requests.
///
/// # Errors
/// Propagates any error from connecting to the backend or from creating
/// the server.
pub async fn build(self) -> Result<(axum::Router, CancellationToken, SseHandler)> {
    let mcp_id = match &self.server_config.mcp_id {
        Some(id) => id.clone(),
        None => String::from("sse-proxy"),
    };

    let capabilities = ClientCapabilities::builder()
        .enable_experimental()
        .enable_roots()
        .enable_roots_list_changed()
        .enable_sampling()
        .build();
    let client_info = ClientInfo {
        protocol_version: Default::default(),
        capabilities,
        ..Default::default()
    };

    // Each backend flavour has its own connect routine; all of them yield
    // the same running-client type.
    let connection = match &self.backend_config {
        BackendConfig::Stdio { command, args, env } => {
            self.connect_stdio(command, args, env, &client_info).await
        }
        BackendConfig::SseUrl { url, headers } => {
            self.connect_sse_url(url, headers, &client_info).await
        }
        BackendConfig::StreamUrl { url, headers } => {
            self.connect_stream_url(url, headers, &client_info).await
        }
    };
    let client = connection?;

    let sse_handler = match self.server_config.tool_filter.as_ref() {
        Some(filter) => SseHandler::with_tool_filter(client, mcp_id.clone(), filter.clone()),
        None => SseHandler::with_mcp_id(client, mcp_id.clone()),
    };

    // Keep one handle for the caller; the other is moved into the server.
    let returned_handler = sse_handler.clone();
    let (router, ct) = self.create_server(sse_handler)?;

    info!(
        "[SseServerBuilder] Server created - mcp_id: {}, sse_path: {}, post_path: {}",
        mcp_id, self.server_config.sse_path, self.server_config.post_path
    );

    Ok((router, ct, returned_handler))
}
/// Spawns the backend as a child process and connects an MCP client to it.
///
/// # Parameters
/// - `command`: program used to create the child command.
/// - `args`: optional arguments for the command.
/// - `env`: optional environment variables set on the command.
/// - `client_info`: client description; a clone of it serves the transport.
///
/// # Errors
/// Fails if the child-process transport cannot be created or if serving
/// the client over it fails.
async fn connect_stdio(
    &self,
    command: &str,
    args: &Option<Vec<String>>,
    env: &Option<HashMap<String, String>>,
    client_info: &ClientInfo,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
    let mut child_cmd = Command::new(command);
    // Flattening the `Option` yields nothing for `None`, so absent
    // arguments or variables leave the command untouched.
    child_cmd.args(args.iter().flatten());
    child_cmd.envs(env.iter().flatten());

    info!(
        "[SseServerBuilder] Starting child process - command: {}, args: {:?}",
        command,
        args.as_deref().unwrap_or_default()
    );

    let transport = TokioChildProcess::new(child_cmd)?;
    let running = client_info.clone().serve(transport).await?;
    info!("[SseServerBuilder] Child process connected successfully");
    Ok(running)
}
/// Connects an MCP client to a backend exposed over an SSE endpoint.
///
/// All configured headers become default headers of the HTTP client used
/// by the transport.
///
/// # Parameters
/// - `url`: SSE endpoint of the backend.
/// - `headers`: optional header name/value pairs.
/// - `client_info`: client description; a clone of it serves the transport.
///
/// # Errors
/// Fails on an invalid header name or value, when the HTTP client cannot
/// be built, when the SSE transport cannot be started, or when serving
/// the client over it fails.
async fn connect_sse_url(
    &self,
    url: &str,
    headers: &Option<HashMap<String, String>>,
    client_info: &ClientInfo,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
    info!("[SseServerBuilder] Connecting to SSE URL backend: {}", url);

    let mut default_headers = reqwest::header::HeaderMap::new();
    for (key, value) in headers.iter().flatten() {
        let name = reqwest::header::HeaderName::try_from(key)
            .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?;
        let header_value: reqwest::header::HeaderValue = value
            .parse()
            .map_err(|e| anyhow::anyhow!("Invalid header value for '{}': {}", key, e))?;
        default_headers.insert(name, header_value);
    }

    let http_client = reqwest::Client::builder()
        .default_headers(default_headers)
        .build()
        .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;

    let transport_config = SseClientConfig {
        sse_endpoint: url.to_string().into(),
        ..Default::default()
    };
    let transport = SseClientTransport::start_with_client(http_client, transport_config).await?;

    let running = client_info.clone().serve(transport).await?;
    info!("[SseServerBuilder] SSE URL backend connected successfully");
    Ok(running)
}
/// Connects an MCP client to a backend exposed over Streamable HTTP.
///
/// Configured headers become default headers of the HTTP client, except
/// `Authorization` (matched case-insensitively): its value, with a leading
/// `Bearer ` scheme removed, is passed through the transport
/// configuration's `auth_header` field instead.
///
/// The scheme is matched case-insensitively, so `bearer <token>` and
/// `BEARER <token>` are handled like `Bearer <token>`; a value without
/// the scheme is forwarded unchanged.
///
/// NOTE(review): this presumes the transport sends `auth_header` as a
/// bearer token — confirm against the rmcp version in use.
///
/// # Parameters
/// - `url`: URI of the Streamable HTTP backend.
/// - `headers`: optional header name/value pairs.
/// - `client_info`: client description; a clone of it serves the transport.
///
/// # Errors
/// Fails on an invalid header name or value, when the HTTP client cannot
/// be built, or when serving the client over the transport fails.
async fn connect_stream_url(
    &self,
    url: &str,
    headers: &Option<HashMap<String, String>>,
    client_info: &ClientInfo,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
    // Returns `value` without a leading `Bearer ` scheme, compared
    // case-insensitively because HTTP auth schemes are case-insensitive.
    fn bearer_token(value: &str) -> &str {
        const SCHEME: &str = "Bearer ";
        match value.get(..SCHEME.len()) {
            // `get` succeeded, so `SCHEME.len()` is a valid char boundary.
            Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => &value[SCHEME.len()..],
            _ => value,
        }
    }

    info!(
        "[SseServerBuilder] Connecting to Streamable HTTP URL backend: {}",
        url
    );

    let mut req_headers = reqwest::header::HeaderMap::new();
    let mut auth_header: Option<String> = None;
    if let Some(config_headers) = headers {
        for (key, value) in config_headers {
            if key.eq_ignore_ascii_case("Authorization") {
                // Routed through the transport config, not the default headers.
                auth_header = Some(bearer_token(value).to_string());
                continue;
            }
            req_headers.insert(
                reqwest::header::HeaderName::try_from(key)
                    .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
                value.parse().map_err(|e| {
                    anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
                })?,
            );
        }
    }

    let http_client = reqwest::Client::builder()
        .default_headers(req_headers)
        .build()
        .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;

    let config = StreamableHttpClientTransportConfig {
        uri: url.to_string().into(),
        auth_header,
        ..Default::default()
    };
    let transport = StreamableHttpClientTransport::with_client(http_client, config);

    let client = client_info.clone().serve(transport).await?;
    info!("[SseServerBuilder] Streamable HTTP URL backend connected successfully");
    Ok(client)
}
fn create_server(&self, sse_handler: SseHandler) -> Result<(axum::Router, CancellationToken)> {
let config = SseServerConfig {
bind: "0.0.0.0:0".parse()?,
sse_path: self.server_config.sse_path.clone(),
post_path: self.server_config.post_path.clone(),
ct: CancellationToken::new(),
sse_keep_alive: Some(std::time::Duration::from_secs(
self.server_config.keep_alive_secs,
)),
};
let (sse_server, router) = SseServer::new(config);
let ct = sse_server.with_service(move || sse_handler.clone());
Ok((router, ct))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values passed to the chained setters must land in the stored config.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: String::from("echo"),
            args: Some(vec![String::from("hello")]),
            env: None,
        };
        let builder = SseServerBuilder::new(backend)
            .mcp_id("test")
            .sse_path("/custom/sse")
            .post_path("/custom/message");

        let cfg = &builder.server_config;
        assert!(cfg.mcp_id.is_some());
        assert_eq!(cfg.mcp_id.as_deref(), Some("test"));
        assert_eq!(cfg.sse_path, "/custom/sse");
        assert_eq!(cfg.post_path, "/custom/message");
    }

    /// The default config exposes the documented paths and keep-alive.
    #[test]
    fn test_default_config() {
        let SseServerBuilderConfig {
            sse_path,
            post_path,
            keep_alive_secs,
            ..
        } = SseServerBuilderConfig::default();

        assert_eq!(sse_path, "/sse");
        assert_eq!(post_path, "/message");
        assert_eq!(keep_alive_secs, 15);
    }
}