1use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9
10use anyhow::Result;
11use process_wrap::tokio::{CommandWrap, KillOnDrop};
12use tokio_util::sync::CancellationToken;
13use tracing::{info, warn};
14
15use rmcp::{
16 ServiceExt,
17 model::{ClientCapabilities, ClientInfo},
18 transport::{
19 TokioChildProcess,
20 streamable_http_client::{
21 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
22 },
23 streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
24 },
25};
26
27#[cfg(unix)]
29use process_wrap::tokio::ProcessGroup;
30
31#[cfg(windows)]
33use process_wrap::tokio::{CreationFlags, JobObject};
34
35use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
36pub use mcp_common::ToolFilter as CommonToolFilter;
37
/// Describes how the proxy reaches the backend MCP server it forwards to.
#[derive(Clone, Debug)]
pub enum BackendConfig {
    /// Launch the backend as a child process and speak MCP over its standard streams.
    Stdio {
        /// Program to execute.
        command: String,
        /// Command-line arguments passed to `command`, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables for the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to an already running backend over streamable HTTP.
    Url {
        /// Endpoint URL of the backend.
        url: String,
        /// Extra HTTP headers for backend requests; an `Authorization` entry is treated
        /// specially by the builder when it carries a bearer token.
        headers: Option<HashMap<String, String>>,
    },
}
60
61#[derive(Debug, Clone, Default)]
63pub struct StreamServerConfig {
64 pub stateful_mode: bool,
66 pub mcp_id: Option<String>,
68 pub tool_filter: Option<ToolFilter>,
70}
71
72pub struct StreamServerBuilder {
93 backend_config: BackendConfig,
94 server_config: StreamServerConfig,
95}
96
97impl StreamServerBuilder {
98 pub fn new(backend: BackendConfig) -> Self {
100 Self {
101 backend_config: backend,
102 server_config: StreamServerConfig::default(),
103 }
104 }
105
106 pub fn stateful(mut self, enabled: bool) -> Self {
110 self.server_config.stateful_mode = enabled;
111 self
112 }
113
114 pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
118 self.server_config.mcp_id = Some(id.into());
119 self
120 }
121
122 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
124 self.server_config.tool_filter = Some(filter);
125 self
126 }
127
128 pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
134 let mcp_id = self
135 .server_config
136 .mcp_id
137 .clone()
138 .unwrap_or_else(|| "stream-proxy".into());
139
140 let capabilities = ClientCapabilities::builder()
142 .enable_experimental()
143 .enable_roots()
144 .enable_roots_list_changed()
145 .enable_sampling()
146 .build();
147 let client_info = ClientInfo::new(
148 capabilities,
149 rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
150 );
151
152 let client = match &self.backend_config {
154 BackendConfig::Stdio { command, args, env } => {
155 self.connect_stdio(command, args, env, &client_info).await?
156 }
157 BackendConfig::Url { url, headers } => {
158 self.connect_url(url, headers, &client_info).await?
159 }
160 };
161
162 let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
164 ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
165 } else {
166 ProxyHandler::with_mcp_id(client, mcp_id.clone())
167 };
168
169 let handler_for_return = proxy_handler.clone();
171
172 let (router, ct) = self.create_server(proxy_handler).await?;
174
175 info!(
176 "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
177 mcp_id, self.server_config.stateful_mode
178 );
179
180 Ok((router, ct, handler_for_return))
181 }
182
183 async fn connect_stdio(
185 &self,
186 command: &str,
187 args: &Option<Vec<String>>,
188 env: &Option<HashMap<String, String>>,
189 client_info: &ClientInfo,
190 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
191 #[cfg(windows)]
193 let (command, args) = self.preprocess_npx_command_windows(command, args.clone());
194 #[cfg(not(windows))]
195 let args = args.clone();
196
197 let mut wrapped_cmd = CommandWrap::with_new(&command, |cmd| {
201 let (final_path, filtered_env) = mcp_common::prepare_stdio_env(env);
202 if let Some(path) = final_path {
203 cmd.env("PATH", path);
204 } else {
205 warn!("[StreamServerBuilder] PATH not available from parent process or config");
206 }
207
208 if let Some(cmd_args) = &args {
209 cmd.args(cmd_args);
210 }
211
212 if let Some(vars) = filtered_env {
213 for (k, v) in vars {
214 cmd.env(k, v);
215 }
216 }
217 });
218
219 #[cfg(unix)]
221 wrapped_cmd.wrap(ProcessGroup::leader());
222
223 #[cfg(windows)]
225 {
226 use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
227 info!(
228 "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
229 );
230 wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
231 wrapped_cmd.wrap(JobObject);
232 }
233
234 wrapped_cmd.wrap(KillOnDrop);
236
237 info!(
238 "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
239 command,
240 args.as_ref().unwrap_or(&vec![])
241 );
242
243 let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
244
245 mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);
247
248 let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
251 .stderr(Stdio::piped())
252 .spawn()
253 .map_err(|e| {
254 anyhow::anyhow!(
255 "{}",
256 mcp_common::diagnostic::format_spawn_error(mcp_id, &command, &args, e)
257 )
258 })?;
259
260 if let Some(stderr_pipe) = child_stderr {
262 mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
263 }
264
265 let client = client_info.clone().serve(tokio_process).await?;
266
267 info!("[StreamServerBuilder] Child process connected successfully");
268 Ok(client)
269 }
270
271 async fn connect_url(
273 &self,
274 url: &str,
275 headers: &Option<HashMap<String, String>>,
276 client_info: &ClientInfo,
277 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
278 info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
279
280 let mut req_headers = reqwest::header::HeaderMap::new();
282 let mut auth_header: Option<String> = None;
283
284 if let Some(config_headers) = headers {
285 for (key, value) in config_headers {
286 if key.eq_ignore_ascii_case("Authorization") {
288 auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
289 continue;
290 }
291
292 req_headers.insert(
293 reqwest::header::HeaderName::try_from(key)
294 .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
295 value.parse().map_err(|e| {
296 anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
297 })?,
298 );
299 }
300 }
301
302 let http_client = reqwest::Client::builder()
303 .default_headers(req_headers)
304 .build()
305 .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
306
307 let config = StreamableHttpClientTransportConfig {
309 uri: url.to_string().into(),
310 auth_header,
311 ..Default::default()
312 };
313
314 let transport = StreamableHttpClientTransport::with_client(http_client, config);
315 let client = client_info.clone().serve(transport).await?;
316
317 info!("[StreamServerBuilder] URL backend connected successfully");
318 Ok(client)
319 }
320
321 #[cfg(windows)]
326 fn preprocess_npx_command_windows(
327 &self,
328 command: &str,
329 args: Option<Vec<String>>,
330 ) -> (String, Option<Vec<String>>) {
331 let is_npx = command == "npx"
333 || command == "npx.cmd"
334 || command.ends_with("/npx")
335 || command.ends_with("\\npx")
336 || command.ends_with("/npx.cmd")
337 || command.ends_with("\\npx.cmd");
338
339 if !is_npx {
340 return (command.to_string(), args);
341 }
342
343 let args = match args {
344 Some(a) => a,
345 None => return (command.to_string(), None),
346 };
347
348 let package_spec = args.iter().find(|s| !s.starts_with('-') && s.contains('@'));
350
351 let Some(pkg) = package_spec else {
352 return (command.to_string(), Some(args));
353 };
354
355 let package_name = pkg.split('@').next().unwrap_or(pkg);
357
358 if let Some((node_exe, js_entry)) = self.find_npx_package_entry_windows(package_name) {
360 info!(
361 "[StreamServerBuilder] Windows npx 转换: npx {} -> node {}",
362 pkg,
363 js_entry.display()
364 );
365
366 let mut new_args = vec![js_entry.to_string_lossy().to_string()];
368 for arg in &args {
369 if arg != "-y" && arg != pkg {
370 new_args.push(arg.clone());
371 }
372 }
373
374 return (node_exe.to_string_lossy().to_string(), Some(new_args));
375 }
376
377 info!(
379 "[StreamServerBuilder] Windows npx 未找到已安装的包: {},保持原命令",
380 pkg
381 );
382 (command.to_string(), Some(args))
383 }
384
385 #[cfg(windows)]
387 fn find_npx_package_entry_windows(
388 &self,
389 package_name: &str,
390 ) -> Option<(std::path::PathBuf, std::path::PathBuf)> {
391 let node_exe = self.find_node_exe_windows()?;
393
394 let search_paths = self.get_npx_cache_paths_windows();
396
397 for node_modules_dir in search_paths {
398 let package_dir = node_modules_dir.join(package_name);
399 if !package_dir.exists() {
400 continue;
401 }
402
403 let package_json_path = package_dir.join("package.json");
405 if let Ok(content) = std::fs::read_to_string(&package_json_path) {
406 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
407 let bin_entry = json.get("bin").and_then(|b| {
409 if let Some(s) = b.as_str() {
410 Some(s.to_string())
411 } else if let Some(obj) = b.as_object() {
412 obj.get(package_name)
413 .or_else(|| obj.values().next())
414 .and_then(|v| v.as_str())
415 .map(str::to_string)
416 } else {
417 None
418 }
419 });
420
421 if let Some(bin_entry) = bin_entry {
422 let js_entry = package_dir.join(bin_entry);
423 if js_entry.exists() {
424 info!(
425 "[StreamServerBuilder] Windows 找到包入口: {} -> {}",
426 package_name,
427 js_entry.display()
428 );
429 return Some((node_exe.clone(), js_entry));
430 }
431 }
432 }
433 }
434 }
435
436 None
437 }
438
439 #[cfg(windows)]
441 fn find_node_exe_windows(&self) -> Option<std::path::PathBuf> {
442 use std::path::PathBuf;
443
444 if let Ok(node_from_env) = std::env::var("NUWAX_NODE_EXE") {
446 let path = PathBuf::from(node_from_env);
447 if path.exists() {
448 return Some(path);
449 }
450 }
451
452 if let Ok(exe_path) = std::env::current_exe() {
454 if let Some(exe_dir) = exe_path.parent() {
455 let resource_paths = [
456 exe_dir
457 .join("resources")
458 .join("node")
459 .join("bin")
460 .join("node.exe"),
461 exe_dir
462 .parent()
463 .unwrap_or(exe_dir)
464 .join("resources")
465 .join("node")
466 .join("bin")
467 .join("node.exe"),
468 ];
469
470 for path in resource_paths {
471 if path.exists() {
472 return Some(path);
473 }
474 }
475 }
476 }
477
478 which::which("node.exe").ok()
480 }
481
482 #[cfg(windows)]
484 fn get_npx_cache_paths_windows(&self) -> Vec<std::path::PathBuf> {
485 use std::path::PathBuf;
486
487 let mut paths = Vec::new();
488
489 if let Ok(appdata) = std::env::var("APPDATA") {
491 let appdata_path = PathBuf::from(&appdata);
492
493 paths.push(appdata_path.join("npm").join("node_modules"));
495
496 paths.push(
498 appdata_path
499 .join("com.nuwax.agent-tauri-client")
500 .join("node_modules"),
501 );
502
503 paths.push(appdata_path.join("npm-cache").join("_npx"));
505 }
506
507 if let Ok(exe_path) = std::env::current_exe() {
509 if let Some(exe_dir) = exe_path.parent() {
510 let resource_paths = [
511 exe_dir.join("resources").join("node").join("node_modules"),
512 exe_dir
513 .parent()
514 .unwrap_or(exe_dir)
515 .join("resources")
516 .join("node")
517 .join("node_modules"),
518 ];
519
520 for path in resource_paths {
521 if path.exists() {
522 paths.push(path);
523 }
524 }
525 }
526 }
527
528 paths
529 }
530
531 async fn create_server(
533 &self,
534 proxy_handler: ProxyHandler,
535 ) -> Result<(axum::Router, CancellationToken)> {
536 let handler = Arc::new(proxy_handler);
537 let ct = CancellationToken::new();
538
539 if self.server_config.stateful_mode {
540 let session_manager = ProxyAwareSessionManager::new(handler.clone());
542 let handler_for_service = handler.clone();
543
544 let service = StreamableHttpService::new(
545 move || Ok((*handler_for_service).clone()),
546 session_manager.into(),
547 StreamableHttpServerConfig {
548 stateful_mode: true,
549 ..Default::default()
550 },
551 );
552
553 let router = axum::Router::new().fallback_service(service);
554 Ok((router, ct))
555 } else {
556 use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
558
559 let handler_for_service = handler.clone();
560
561 let service = StreamableHttpService::new(
562 move || Ok((*handler_for_service).clone()),
563 LocalSessionManager::default().into(),
564 StreamableHttpServerConfig {
565 stateful_mode: false,
566 ..Default::default()
567 },
568 );
569
570 let router = axum::Router::new().fallback_service(service);
571 Ok((router, ct))
572 }
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// The chained setters record the id and stateful flag on the builder's config.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        };
        let builder = StreamServerBuilder::new(backend).mcp_id("test").stateful(true);
        let config = &builder.server_config;

        assert!(config.mcp_id.is_some());
        assert_eq!(config.mcp_id.as_deref(), Some("test"));
        assert!(config.stateful_mode);
    }

    /// A URL backend keeps its URL and headers untouched inside the builder.
    #[test]
    fn test_url_backend_config() {
        let headers: HashMap<String, String> = [
            ("Authorization", "Bearer token123"),
            ("X-Custom", "value"),
        ]
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        let BackendConfig::Url { url, headers } = &builder.backend_config else {
            panic!("Expected URL backend");
        };
        assert_eq!(url, "http://localhost:8080/mcp");
        assert!(headers.is_some());
    }
}
614}