1use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9
10use anyhow::Result;
11use process_wrap::tokio::{CommandWrap, KillOnDrop};
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14
15use rmcp::{
16 ServiceExt,
17 model::{ClientCapabilities, ClientInfo},
18 transport::{
19 TokioChildProcess,
20 streamable_http_client::{
21 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
22 },
23 streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
24 },
25};
26
27#[cfg(unix)]
29use process_wrap::tokio::ProcessGroup;
30
31#[cfg(windows)]
33use process_wrap::tokio::{CreationFlags, JobObject};
34
35use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
36pub use mcp_common::ToolFilter as CommonToolFilter;
37
/// Describes how the proxy reaches the backend MCP server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendConfig {
    /// Spawn the backend as a child process and talk MCP over its stdio.
    Stdio {
        /// Program to launch.
        command: String,
        /// Command-line arguments for the program, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables set on the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to an already-running backend over streamable HTTP.
    Url {
        /// Endpoint URL of the backend.
        url: String,
        /// HTTP headers to send to the backend, if any.
        headers: Option<HashMap<String, String>>,
    },
}
60
61#[derive(Debug, Clone, Default)]
63pub struct StreamServerConfig {
64 pub stateful_mode: bool,
66 pub mcp_id: Option<String>,
68 pub tool_filter: Option<ToolFilter>,
70}
71
72pub struct StreamServerBuilder {
93 backend_config: BackendConfig,
94 server_config: StreamServerConfig,
95}
96
97impl StreamServerBuilder {
98 pub fn new(backend: BackendConfig) -> Self {
100 Self {
101 backend_config: backend,
102 server_config: StreamServerConfig::default(),
103 }
104 }
105
106 pub fn stateful(mut self, enabled: bool) -> Self {
110 self.server_config.stateful_mode = enabled;
111 self
112 }
113
114 pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
118 self.server_config.mcp_id = Some(id.into());
119 self
120 }
121
122 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
124 self.server_config.tool_filter = Some(filter);
125 self
126 }
127
128 pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
134 let mcp_id = self
135 .server_config
136 .mcp_id
137 .clone()
138 .unwrap_or_else(|| "stream-proxy".into());
139
140 let capabilities = ClientCapabilities::builder()
142 .enable_experimental()
143 .enable_roots()
144 .enable_roots_list_changed()
145 .enable_sampling()
146 .build();
147 let client_info = ClientInfo::new(
148 capabilities,
149 rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
150 );
151
152 let client = match &self.backend_config {
154 BackendConfig::Stdio { command, args, env } => {
155 self.connect_stdio(command, args, env, &client_info).await?
156 }
157 BackendConfig::Url { url, headers } => {
158 self.connect_url(url, headers, &client_info).await?
159 }
160 };
161
162 let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
164 ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
165 } else {
166 ProxyHandler::with_mcp_id(client, mcp_id.clone())
167 };
168
169 let handler_for_return = proxy_handler.clone();
171
172 let (router, ct) = self.create_server(proxy_handler).await?;
174
175 info!(
176 "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
177 mcp_id, self.server_config.stateful_mode
178 );
179
180 Ok((router, ct, handler_for_return))
181 }
182
183 async fn connect_stdio(
185 &self,
186 command: &str,
187 args: &Option<Vec<String>>,
188 env: &Option<HashMap<String, String>>,
189 client_info: &ClientInfo,
190 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
191 #[cfg(windows)]
193 let (command, args) = self.preprocess_npx_command_windows(command, args.clone());
194 #[cfg(not(windows))]
195 let args = args.clone();
196
197 let mut wrapped_cmd = CommandWrap::with_new(command, |cmd| {
202 if let Some(cmd_args) = &args {
203 cmd.args(cmd_args);
204 }
205 if let Some(env_vars) = env {
207 for (k, v) in env_vars {
208 cmd.env(k, v);
209 }
210 }
211 });
212
213 #[cfg(unix)]
215 wrapped_cmd.wrap(ProcessGroup::leader());
216
217 #[cfg(windows)]
219 {
220 use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
221 info!(
222 "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
223 );
224 wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
225 wrapped_cmd.wrap(JobObject);
226 }
227
228 wrapped_cmd.wrap(KillOnDrop);
230
231 info!(
232 "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
233 command,
234 args.as_ref().unwrap_or(&vec![])
235 );
236
237 let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
238
239 mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);
241
242 let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
245 .stderr(Stdio::piped())
246 .spawn()
247 .map_err(|e| {
248 anyhow::anyhow!(
249 "{}",
250 mcp_common::diagnostic::format_spawn_error(mcp_id, command, &args, e)
251 )
252 })?;
253
254 if let Some(stderr_pipe) = child_stderr {
256 mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
257 }
258
259 let client = client_info.clone().serve(tokio_process).await?;
260
261 info!("[StreamServerBuilder] Child process connected successfully");
262 Ok(client)
263 }
264
265 async fn connect_url(
267 &self,
268 url: &str,
269 headers: &Option<HashMap<String, String>>,
270 client_info: &ClientInfo,
271 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
272 info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
273
274 let mut req_headers = reqwest::header::HeaderMap::new();
276 let mut auth_header: Option<String> = None;
277
278 if let Some(config_headers) = headers {
279 for (key, value) in config_headers {
280 if key.eq_ignore_ascii_case("Authorization") {
282 auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
283 continue;
284 }
285
286 req_headers.insert(
287 reqwest::header::HeaderName::try_from(key)
288 .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
289 value.parse().map_err(|e| {
290 anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
291 })?,
292 );
293 }
294 }
295
296 let http_client = reqwest::Client::builder()
297 .default_headers(req_headers)
298 .build()
299 .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
300
301 let config = StreamableHttpClientTransportConfig {
303 uri: url.to_string().into(),
304 auth_header,
305 ..Default::default()
306 };
307
308 let transport = StreamableHttpClientTransport::with_client(http_client, config);
309 let client = client_info.clone().serve(transport).await?;
310
311 info!("[StreamServerBuilder] URL backend connected successfully");
312 Ok(client)
313 }
314
315 #[cfg(windows)]
320 fn preprocess_npx_command_windows(
321 &self,
322 command: &str,
323 args: Option<Vec<String>>,
324 ) -> (String, Option<Vec<String>>) {
325 let is_npx = command == "npx"
327 || command == "npx.cmd"
328 || command.ends_with("/npx")
329 || command.ends_with("\\npx")
330 || command.ends_with("/npx.cmd")
331 || command.ends_with("\\npx.cmd");
332
333 if !is_npx {
334 return (command.to_string(), args);
335 }
336
337 let args = match args {
338 Some(a) => a,
339 None => return (command.to_string(), None),
340 };
341
342 let package_spec = args.iter().find(|s| !s.starts_with('-') && s.contains('@'));
344
345 let Some(pkg) = package_spec else {
346 return (command.to_string(), Some(args));
347 };
348
349 let package_name = pkg.split('@').next().unwrap_or(pkg);
351
352 if let Some((node_exe, js_entry)) = self.find_npx_package_entry_windows(package_name) {
354 info!(
355 "[StreamServerBuilder] Windows npx 转换: npx {} -> node {}",
356 pkg,
357 js_entry.display()
358 );
359
360 let mut new_args = vec![js_entry.to_string_lossy().to_string()];
362 for arg in &args {
363 if arg != "-y" && arg != pkg {
364 new_args.push(arg.clone());
365 }
366 }
367
368 return (node_exe.to_string_lossy().to_string(), Some(new_args));
369 }
370
371 info!(
373 "[StreamServerBuilder] Windows npx 未找到已安装的包: {},保持原命令",
374 pkg
375 );
376 (command.to_string(), Some(args))
377 }
378
379 #[cfg(windows)]
381 fn find_npx_package_entry_windows(
382 &self,
383 package_name: &str,
384 ) -> Option<(std::path::PathBuf, std::path::PathBuf)> {
385 let node_exe = self.find_node_exe_windows()?;
387
388 let search_paths = self.get_npx_cache_paths_windows();
390
391 for node_modules_dir in search_paths {
392 let package_dir = node_modules_dir.join(package_name);
393 if !package_dir.exists() {
394 continue;
395 }
396
397 let package_json_path = package_dir.join("package.json");
399 if let Ok(content) = std::fs::read_to_string(&package_json_path) {
400 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
401 let bin_entry = json.get("bin").and_then(|b| {
403 if let Some(s) = b.as_str() {
404 Some(s.to_string())
405 } else if let Some(obj) = b.as_object() {
406 obj.get(package_name)
407 .or_else(|| obj.values().next())
408 .and_then(|v| v.as_str())
409 .map(str::to_string)
410 } else {
411 None
412 }
413 });
414
415 if let Some(bin_entry) = bin_entry {
416 let js_entry = package_dir.join(bin_entry);
417 if js_entry.exists() {
418 info!(
419 "[StreamServerBuilder] Windows 找到包入口: {} -> {}",
420 package_name,
421 js_entry.display()
422 );
423 return Some((node_exe.clone(), js_entry));
424 }
425 }
426 }
427 }
428 }
429
430 None
431 }
432
433 #[cfg(windows)]
435 fn find_node_exe_windows(&self) -> Option<std::path::PathBuf> {
436 use std::path::PathBuf;
437
438 if let Ok(node_from_env) = std::env::var("NUWAX_NODE_EXE") {
440 let path = PathBuf::from(node_from_env);
441 if path.exists() {
442 return Some(path);
443 }
444 }
445
446 if let Ok(exe_path) = std::env::current_exe() {
448 if let Some(exe_dir) = exe_path.parent() {
449 let resource_paths = [
450 exe_dir
451 .join("resources")
452 .join("node")
453 .join("bin")
454 .join("node.exe"),
455 exe_dir
456 .parent()
457 .unwrap_or(exe_dir)
458 .join("resources")
459 .join("node")
460 .join("bin")
461 .join("node.exe"),
462 ];
463
464 for path in resource_paths {
465 if path.exists() {
466 return Some(path);
467 }
468 }
469 }
470 }
471
472 which::which("node.exe").ok()
474 }
475
476 #[cfg(windows)]
478 fn get_npx_cache_paths_windows(&self) -> Vec<std::path::PathBuf> {
479 use std::path::PathBuf;
480
481 let mut paths = Vec::new();
482
483 if let Ok(appdata) = std::env::var("APPDATA") {
485 let appdata_path = PathBuf::from(&appdata);
486
487 paths.push(appdata_path.join("npm").join("node_modules"));
489
490 paths.push(
492 appdata_path
493 .join("com.nuwax.agent-tauri-client")
494 .join("node_modules"),
495 );
496
497 paths.push(appdata_path.join("npm-cache").join("_npx"));
499 }
500
501 if let Ok(exe_path) = std::env::current_exe() {
503 if let Some(exe_dir) = exe_path.parent() {
504 let resource_paths = [
505 exe_dir.join("resources").join("node").join("node_modules"),
506 exe_dir
507 .parent()
508 .unwrap_or(exe_dir)
509 .join("resources")
510 .join("node")
511 .join("node_modules"),
512 ];
513
514 for path in resource_paths {
515 if path.exists() {
516 paths.push(path);
517 }
518 }
519 }
520 }
521
522 paths
523 }
524
525 async fn create_server(
527 &self,
528 proxy_handler: ProxyHandler,
529 ) -> Result<(axum::Router, CancellationToken)> {
530 let handler = Arc::new(proxy_handler);
531 let ct = CancellationToken::new();
532
533 if self.server_config.stateful_mode {
534 let session_manager = ProxyAwareSessionManager::new(handler.clone());
536 let handler_for_service = handler.clone();
537
538 let service = StreamableHttpService::new(
539 move || Ok((*handler_for_service).clone()),
540 session_manager.into(),
541 StreamableHttpServerConfig {
542 stateful_mode: true,
543 ..Default::default()
544 },
545 );
546
547 let router = axum::Router::new().fallback_service(service);
548 Ok((router, ct))
549 } else {
550 use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
552
553 let handler_for_service = handler.clone();
554
555 let service = StreamableHttpService::new(
556 move || Ok((*handler_for_service).clone()),
557 LocalSessionManager::default().into(),
558 StreamableHttpServerConfig {
559 stateful_mode: false,
560 ..Default::default()
561 },
562 );
563
564 let router = axum::Router::new().fallback_service(service);
565 Ok((router, ct))
566 }
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained setters store their values on the builder's server config.
    #[test]
    fn test_builder_creation() {
        let builder = StreamServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        })
        .mcp_id("test")
        .stateful(true);

        assert_eq!(builder.server_config.mcp_id.as_deref(), Some("test"));
        assert!(builder.server_config.stateful_mode);
        assert!(builder.server_config.tool_filter.is_none());
    }

    /// A fresh builder is stateless, has no explicit id, and applies no tool filter.
    #[test]
    fn test_builder_defaults() {
        let builder = StreamServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args: None,
            env: None,
        });

        assert!(builder.server_config.mcp_id.is_none());
        assert!(!builder.server_config.stateful_mode);
        assert!(builder.server_config.tool_filter.is_none());
    }

    /// A URL backend keeps its URL and every configured header unchanged.
    #[test]
    fn test_url_backend_config() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer token123".into());
        headers.insert("X-Custom".into(), "value".into());

        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        let BackendConfig::Url { url, headers } = &builder.backend_config else {
            panic!("Expected URL backend");
        };
        assert_eq!(url, "http://localhost:8080/mcp");

        let headers = headers.as_ref().expect("headers should be preserved");
        assert_eq!(headers.len(), 2);
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer token123")
        );
        assert_eq!(headers.get("X-Custom").map(String::as_str), Some("value"));
    }
}
608}