use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use anyhow::Result;
use process_wrap::tokio::{CommandWrap, KillOnDrop};
use tokio_util::sync::CancellationToken;
use tracing::info;
use rmcp::{
ServiceExt,
model::{ClientCapabilities, ClientInfo},
transport::{
TokioChildProcess,
streamable_http_client::{
StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
},
streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
},
};
#[cfg(unix)]
use process_wrap::tokio::ProcessGroup;
#[cfg(windows)]
use process_wrap::tokio::{CreationFlags, JobObject};
use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
pub use mcp_common::ToolFilter as CommonToolFilter;
/// Describes the upstream MCP backend that the proxy connects to.
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// A backend reached by spawning a local child process.
    Stdio {
        /// Program to execute.
        command: String,
        /// Optional command-line arguments handed to the program.
        args: Option<Vec<String>>,
        /// Optional environment variables set on the child process.
        env: Option<HashMap<String, String>>,
    },
    /// A backend reached over HTTP at a URL.
    Url {
        /// Endpoint of the remote backend.
        url: String,
        /// Optional HTTP headers sent with requests to the backend.
        headers: Option<HashMap<String, String>>,
    },
}
/// Server-side options collected by [`StreamServerBuilder`].
///
/// The derived `Default` yields `stateful_mode = false` and `None` for both
/// optional fields.
#[derive(Debug, Clone, Default)]
pub struct StreamServerConfig {
    /// Whether the streamable HTTP service was requested to run in stateful mode.
    pub stateful_mode: bool,
    /// Identifier handed to the proxy handler; `build` substitutes
    /// `"stream-proxy"` when this is `None`.
    pub mcp_id: Option<String>,
    /// Optional tool filter applied by the proxy handler.
    pub tool_filter: Option<ToolFilter>,
}
/// Builder that connects to a backend and produces an `axum::Router`
/// serving it over streamable HTTP.
pub struct StreamServerBuilder {
    /// Backend the proxy will connect to when `build` is called.
    backend_config: BackendConfig,
    /// Server options accumulated through the builder methods.
    server_config: StreamServerConfig,
}
impl StreamServerBuilder {
    /// Creates a builder for `backend` using `StreamServerConfig::default()`.
    pub fn new(backend: BackendConfig) -> Self {
        Self {
            backend_config: backend,
            server_config: StreamServerConfig::default(),
        }
    }

    /// Selects whether the HTTP service should run in stateful mode.
    pub fn stateful(mut self, enabled: bool) -> Self {
        self.server_config.stateful_mode = enabled;
        self
    }

    /// Sets the MCP identifier passed to the proxy handler.
    pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
        self.server_config.mcp_id = Some(id.into());
        self
    }

    /// Installs a tool filter that the proxy handler is constructed with.
    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
        self.server_config.tool_filter = Some(filter);
        self
    }

    /// Connects to the configured backend and builds the HTTP service.
    ///
    /// # Returns
    ///
    /// A tuple of the router exposing the service, the cancellation token
    /// created by `create_server`, and a clone of the [`ProxyHandler`] that
    /// backs the service.
    ///
    /// # Errors
    ///
    /// Propagates any error from connecting to the backend (process spawn,
    /// header parsing, HTTP client construction, MCP handshake) or from
    /// creating the server.
    pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
        // Fall back to a fixed id when the caller did not provide one.
        let mcp_id = self
            .server_config
            .mcp_id
            .clone()
            .unwrap_or_else(|| "stream-proxy".into());

        let capabilities = ClientCapabilities::builder()
            .enable_experimental()
            .enable_roots()
            .enable_roots_list_changed()
            .enable_sampling()
            .build();
        let client_info = ClientInfo::new(
            capabilities,
            rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
        );

        let client = match &self.backend_config {
            BackendConfig::Stdio { command, args, env } => {
                self.connect_stdio(command, args, env, &client_info).await?
            }
            BackendConfig::Url { url, headers } => {
                self.connect_url(url, headers, &client_info).await?
            }
        };

        let proxy_handler = match &self.server_config.tool_filter {
            Some(tool_filter) => {
                ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
            }
            None => ProxyHandler::with_mcp_id(client, mcp_id.clone()),
        };

        // Keep a copy for the caller; the original is moved into the service.
        let handler_for_return = proxy_handler.clone();
        let (router, ct) = self.create_server(proxy_handler).await?;

        info!(
            "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
            mcp_id, self.server_config.stateful_mode
        );
        Ok((router, ct, handler_for_return))
    }

    /// Spawns `command` as a child process and runs the MCP client over it.
    ///
    /// The command is wrapped with `ProcessGroup::leader()` on Unix, with
    /// `CreationFlags` and `JobObject` on Windows, and with `KillOnDrop` on
    /// every platform. The child's stderr is piped and handed to
    /// `mcp_common::spawn_stderr_reader`.
    ///
    /// # Errors
    ///
    /// Returns an error formatted by `mcp_common::diagnostic::format_spawn_error`
    /// when the process cannot be spawned, or the error from `serve` when the
    /// client cannot be started on the child process transport.
    async fn connect_stdio(
        &self,
        command: &str,
        args: &Option<Vec<String>>,
        env: &Option<HashMap<String, String>>,
        client_info: &ClientInfo,
    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
        let mut wrapped_cmd = CommandWrap::with_new(command, |cmd| {
            if let Some(cmd_args) = args {
                cmd.args(cmd_args);
            }
            if let Some(env_vars) = env {
                cmd.envs(env_vars);
            }
        });

        #[cfg(unix)]
        wrapped_cmd.wrap(ProcessGroup::leader());
        #[cfg(windows)]
        {
            use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
            info!(
                "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
            );
            wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
            wrapped_cmd.wrap(JobObject);
        }
        wrapped_cmd.wrap(KillOnDrop);

        info!(
            "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
            command,
            args.as_deref().unwrap_or_default()
        );

        // Diagnostics are labelled "unknown" when no id has been configured.
        let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
        mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);

        let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
            .stderr(Stdio::piped())
            .spawn()
            .map_err(|e| {
                anyhow::anyhow!(
                    "{}",
                    mcp_common::diagnostic::format_spawn_error(mcp_id, command, args, e)
                )
            })?;

        if let Some(stderr_pipe) = child_stderr {
            mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
        }

        let client = client_info.clone().serve(tokio_process).await?;
        info!("[StreamServerBuilder] Child process connected successfully");
        Ok(client)
    }

    /// Connects to a remote backend over streamable HTTP.
    ///
    /// An `Authorization` header (matched case-insensitively) is not added to
    /// the default headers; its value, minus a leading `"Bearer "` if present,
    /// is stored in the transport config's `auth_header` instead. Every other
    /// header becomes a default header of the HTTP client.
    ///
    /// # Errors
    ///
    /// Fails when a header name or value is invalid, when the HTTP client
    /// cannot be built, or when `serve` fails on the transport.
    async fn connect_url(
        &self,
        url: &str,
        headers: &Option<HashMap<String, String>>,
        client_info: &ClientInfo,
    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
        info!("[StreamServerBuilder] Connecting to URL backend: {}", url);

        let mut req_headers = reqwest::header::HeaderMap::new();
        let mut auth_header: Option<String> = None;

        if let Some(config_headers) = headers {
            for (key, value) in config_headers {
                if key.eq_ignore_ascii_case("Authorization") {
                    auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
                    continue;
                }
                let name = reqwest::header::HeaderName::try_from(key)
                    .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?;
                let header_value = reqwest::header::HeaderValue::from_str(value)
                    .map_err(|e| anyhow::anyhow!("Invalid header value for '{}': {}", key, e))?;
                req_headers.insert(name, header_value);
            }
        }

        let http_client = reqwest::Client::builder()
            .default_headers(req_headers)
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;

        let mut config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
        config.auth_header = auth_header;

        let transport = StreamableHttpClientTransport::with_client(http_client, config);
        let client = client_info.clone().serve(transport).await?;
        info!("[StreamServerBuilder] URL backend connected successfully");
        Ok(client)
    }

    /// Wraps `proxy_handler` in a streamable HTTP service and mounts it as the
    /// router's fallback service.
    ///
    /// Stateful mode uses a `ProxyAwareSessionManager`; otherwise a
    /// `LocalSessionManager` is used. In both cases `stateful_mode` is written
    /// explicitly into the service config so the service matches what the
    /// builder was asked for.
    ///
    /// The returned `CancellationToken` is created here and handed back to the
    /// caller; this function does not attach it to the service config.
    /// NOTE(review): confirm whether callers expect cancelling it to stop the
    /// service.
    async fn create_server(
        &self,
        proxy_handler: ProxyHandler,
    ) -> Result<(axum::Router, CancellationToken)> {
        let handler = Arc::new(proxy_handler);
        let ct = CancellationToken::new();

        let router = if self.server_config.stateful_mode {
            let session_manager = ProxyAwareSessionManager::new(Arc::clone(&handler));
            let handler_for_service = Arc::clone(&handler);

            let mut server_config = StreamableHttpServerConfig::default();
            server_config.stateful_mode = true;

            let service = StreamableHttpService::new(
                move || Ok((*handler_for_service).clone()),
                session_manager.into(),
                server_config,
            );
            axum::Router::new().fallback_service(service)
        } else {
            use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;

            let handler_for_service = Arc::clone(&handler);

            // Do not rely on the library default here: the stateless request
            // must be reflected in the service config explicitly.
            let mut server_config = StreamableHttpServerConfig::default();
            server_config.stateful_mode = false;

            let service = StreamableHttpService::new(
                move || Ok((*handler_for_service).clone()),
                LocalSessionManager::default().into(),
                server_config,
            );
            axum::Router::new().fallback_service(service)
        };

        Ok((router, ct))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent setters must record the MCP id and the stateful flag.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        };
        let builder = StreamServerBuilder::new(backend)
            .mcp_id("test")
            .stateful(true);

        let StreamServerConfig {
            mcp_id,
            stateful_mode,
            ..
        } = &builder.server_config;
        assert!(mcp_id.is_some());
        assert_eq!(mcp_id.as_deref(), Some("test"));
        assert!(*stateful_mode);
    }

    /// A URL backend must be stored unchanged, headers included.
    #[test]
    fn test_url_backend_config() {
        let headers: HashMap<String, String> = HashMap::from([
            ("Authorization".to_string(), "Bearer token123".to_string()),
            ("X-Custom".to_string(), "value".to_string()),
        ]);
        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        let BackendConfig::Url { url, headers } = &builder.backend_config else {
            panic!("Expected URL backend");
        };
        assert_eq!(url, "http://localhost:8080/mcp");
        assert!(headers.is_some());
    }
}