1use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9
10use anyhow::Result;
11use process_wrap::tokio::{CommandWrap, KillOnDrop};
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14
15use rmcp::{
16 ServiceExt,
17 model::{ClientCapabilities, ClientInfo},
18 transport::{
19 TokioChildProcess,
20 streamable_http_client::{
21 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
22 },
23 streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
24 },
25};
26
27#[cfg(unix)]
29use process_wrap::tokio::ProcessGroup;
30
31#[cfg(windows)]
33use process_wrap::tokio::{CreationFlags, JobObject};
34
35use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
36pub use mcp_common::ToolFilter as CommonToolFilter;
37
/// Describes how the proxy reaches the backend MCP server it forwards to.
///
/// NOTE: the derived `Debug` prints header and environment values verbatim, so
/// avoid logging this type when it carries credentials (e.g. `Authorization`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendConfig {
    /// Spawn a local child process and speak MCP with it over stdio.
    Stdio {
        /// Program to execute.
        command: String,
        /// Command-line arguments passed to `command`, if any.
        args: Option<Vec<String>>,
        /// Additional environment variables set on the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP server over the streamable HTTP transport.
    Url {
        /// Endpoint URL of the remote MCP server.
        url: String,
        /// Additional HTTP headers installed as default headers on the HTTP client.
        /// `Authorization` entries are handled specially when connecting: bearer
        /// credentials are handed to the transport's own auth mechanism.
        headers: Option<HashMap<String, String>>,
    },
}
60
/// Options controlling the HTTP-server side of the streamable proxy.
///
/// The derived `Default` yields: stateless mode, no explicit `mcp_id`, no tool filter.
#[derive(Debug, Clone, Default)]
pub struct StreamServerConfig {
    /// Copied into `StreamableHttpServerConfig::stateful_mode`. When `true` the
    /// builder uses `ProxyAwareSessionManager`, otherwise `LocalSessionManager`.
    pub stateful_mode: bool,
    /// Identifier for the proxied backend, used by the handler and in log output.
    /// When `None` the builder falls back to a built-in default.
    pub mcp_id: Option<String>,
    /// Filter handed to `ProxyHandler::with_tool_filter`. When `None` the handler
    /// is built with `ProxyHandler::with_mcp_id` instead.
    pub tool_filter: Option<ToolFilter>,
}
71
72pub struct StreamServerBuilder {
93 backend_config: BackendConfig,
94 server_config: StreamServerConfig,
95}
96
97impl StreamServerBuilder {
98 pub fn new(backend: BackendConfig) -> Self {
100 Self {
101 backend_config: backend,
102 server_config: StreamServerConfig::default(),
103 }
104 }
105
106 pub fn stateful(mut self, enabled: bool) -> Self {
110 self.server_config.stateful_mode = enabled;
111 self
112 }
113
114 pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
118 self.server_config.mcp_id = Some(id.into());
119 self
120 }
121
122 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
124 self.server_config.tool_filter = Some(filter);
125 self
126 }
127
128 pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
134 let mcp_id = self
135 .server_config
136 .mcp_id
137 .clone()
138 .unwrap_or_else(|| "stream-proxy".into());
139
140 let capabilities = ClientCapabilities::builder()
142 .enable_experimental()
143 .enable_roots()
144 .enable_roots_list_changed()
145 .enable_sampling()
146 .build();
147 let client_info = ClientInfo::new(
148 capabilities,
149 rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
150 );
151
152 let client = match &self.backend_config {
154 BackendConfig::Stdio { command, args, env } => {
155 self.connect_stdio(command, args, env, &client_info).await?
156 }
157 BackendConfig::Url { url, headers } => {
158 self.connect_url(url, headers, &client_info).await?
159 }
160 };
161
162 let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
164 ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
165 } else {
166 ProxyHandler::with_mcp_id(client, mcp_id.clone())
167 };
168
169 let handler_for_return = proxy_handler.clone();
171
172 let (router, ct) = self.create_server(proxy_handler).await?;
174
175 info!(
176 "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
177 mcp_id, self.server_config.stateful_mode
178 );
179
180 Ok((router, ct, handler_for_return))
181 }
182
183 async fn connect_stdio(
185 &self,
186 command: &str,
187 args: &Option<Vec<String>>,
188 env: &Option<HashMap<String, String>>,
189 client_info: &ClientInfo,
190 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
191 let args = args.clone();
192
193 let mut wrapped_cmd = CommandWrap::with_new(command, |cmd| {
198 if let Some(cmd_args) = &args {
199 cmd.args(cmd_args);
200 }
201 if let Some(env_vars) = env {
203 for (k, v) in env_vars {
204 cmd.env(k, v);
205 }
206 }
207 });
208
209 #[cfg(unix)]
211 wrapped_cmd.wrap(ProcessGroup::leader());
212
213 #[cfg(windows)]
215 {
216 use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
217 info!(
218 "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
219 );
220 wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
221 wrapped_cmd.wrap(JobObject);
222 }
223
224 wrapped_cmd.wrap(KillOnDrop);
226
227 info!(
228 "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
229 command,
230 args.as_ref().unwrap_or(&vec![])
231 );
232
233 let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
234
235 mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);
237
238 let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
241 .stderr(Stdio::piped())
242 .spawn()
243 .map_err(|e| {
244 anyhow::anyhow!(
245 "{}",
246 mcp_common::diagnostic::format_spawn_error(mcp_id, command, &args, e)
247 )
248 })?;
249
250 if let Some(stderr_pipe) = child_stderr {
252 mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
253 }
254
255 let client = client_info.clone().serve(tokio_process).await?;
256
257 info!("[StreamServerBuilder] Child process connected successfully");
258 Ok(client)
259 }
260
261 async fn connect_url(
263 &self,
264 url: &str,
265 headers: &Option<HashMap<String, String>>,
266 client_info: &ClientInfo,
267 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
268 info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
269
270 let mut req_headers = reqwest::header::HeaderMap::new();
272 let mut auth_header: Option<String> = None;
273
274 if let Some(config_headers) = headers {
275 for (key, value) in config_headers {
276 if key.eq_ignore_ascii_case("Authorization") {
278 auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
279 continue;
280 }
281
282 req_headers.insert(
283 reqwest::header::HeaderName::try_from(key)
284 .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
285 value.parse().map_err(|e| {
286 anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
287 })?,
288 );
289 }
290 }
291
292 let http_client = reqwest::Client::builder()
293 .default_headers(req_headers)
294 .build()
295 .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
296
297 let config = StreamableHttpClientTransportConfig {
299 uri: url.to_string().into(),
300 auth_header,
301 ..Default::default()
302 };
303
304 let transport = StreamableHttpClientTransport::with_client(http_client, config);
305 let client = client_info.clone().serve(transport).await?;
306
307 info!("[StreamServerBuilder] URL backend connected successfully");
308 Ok(client)
309 }
310
311 async fn create_server(
313 &self,
314 proxy_handler: ProxyHandler,
315 ) -> Result<(axum::Router, CancellationToken)> {
316 let handler = Arc::new(proxy_handler);
317 let ct = CancellationToken::new();
318
319 if self.server_config.stateful_mode {
320 let session_manager = ProxyAwareSessionManager::new(handler.clone());
322 let handler_for_service = handler.clone();
323
324 let service = StreamableHttpService::new(
325 move || Ok((*handler_for_service).clone()),
326 session_manager.into(),
327 StreamableHttpServerConfig {
328 stateful_mode: true,
329 ..Default::default()
330 },
331 );
332
333 let router = axum::Router::new().fallback_service(service);
334 Ok((router, ct))
335 } else {
336 use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
338
339 let handler_for_service = handler.clone();
340
341 let service = StreamableHttpService::new(
342 move || Ok((*handler_for_service).clone()),
343 LocalSessionManager::default().into(),
344 StreamableHttpServerConfig {
345 stateful_mode: false,
346 ..Default::default()
347 },
348 );
349
350 let router = axum::Router::new().fallback_service(service);
351 Ok((router, ct))
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_URL: &str = "http://localhost:8080/mcp";

    /// Chained setters record their values in the server config.
    #[test]
    fn test_builder_creation() {
        let builder = StreamServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        })
        .mcp_id("test")
        .stateful(true);

        assert_eq!(builder.server_config.mcp_id.as_deref(), Some("test"));
        assert!(builder.server_config.stateful_mode);
        assert!(builder.server_config.tool_filter.is_none());
    }

    /// A fresh builder is stateless with no id and no tool filter.
    #[test]
    fn test_builder_defaults() {
        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: TEST_URL.into(),
            headers: None,
        });

        assert!(builder.server_config.mcp_id.is_none());
        assert!(!builder.server_config.stateful_mode);
        assert!(builder.server_config.tool_filter.is_none());
    }

    /// The URL backend config is stored unchanged, including every header.
    #[test]
    fn test_url_backend_config() {
        let headers = HashMap::from([
            ("Authorization".to_string(), "Bearer token123".to_string()),
            ("X-Custom".to_string(), "value".to_string()),
        ]);

        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: TEST_URL.into(),
            headers: Some(headers),
        });

        let BackendConfig::Url { url, headers } = &builder.backend_config else {
            panic!("Expected URL backend");
        };
        assert_eq!(url, TEST_URL);

        let headers = headers.as_ref().expect("headers should be preserved");
        assert_eq!(headers.len(), 2);
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer token123")
        );
        assert_eq!(headers.get("X-Custom").map(String::as_str), Some("value"));
    }
}
394}