Skip to main content

mcp_stdio_proxy/server/task/
mcp_start_task.rs

1//! MCP Service Start Task
2//!
3//! This module handles starting MCP services using the Builder APIs from
4//! mcp-sse-proxy and mcp-streamable-proxy libraries.
5//!
6//! The refactored implementation removes direct rmcp dependency by delegating
7//! protocol-specific logic to the proxy libraries.
8
9use crate::{
10    AppError, DynamicRouterService, get_proxy_manager,
11    model::GLOBAL_RESTART_TRACKER,
12    model::{
13        CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
14        McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
15    },
16    proxy::{
17        McpHandler, SseBackendConfig, SseServerBuilder, StreamBackendConfig, StreamServerBuilder,
18    },
19};
20
21use anyhow::Result;
22use log::{debug, info};
23
24/// Start an MCP service based on configuration
25///
26/// This function creates and configures an MCP proxy service based on the
27/// provided configuration. It supports both SSE and Streamable HTTP client
28/// protocols, with automatic backend protocol detection for URL-based services.
29pub async fn mcp_start_task(
30    mcp_config: McpConfig,
31) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
32    let mcp_id = mcp_config.mcp_id.clone();
33    let client_protocol = mcp_config.client_protocol.clone();
34
35    // Create router path based on client protocol (determines exposed API interface)
36    let mcp_router_path: McpRouterPath =
37        McpRouterPath::new(mcp_id, client_protocol).map_err(AppError::McpServerError)?;
38
39    let mcp_json_config = mcp_config
40        .mcp_json_config
41        .clone()
42        .expect("mcp_json_config is required");
43
44    let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
45
46    // Use the integrated method to create the server
47    integrate_server_with_axum(
48        mcp_server_config.clone(),
49        mcp_router_path.clone(),
50        mcp_config.clone(),
51    )
52    .await
53}
54
55/// Integrate MCP server with axum router
56///
57/// This function:
58/// 1. Determines backend protocol (stdio, SSE, or Streamable HTTP)
59/// 2. Creates the appropriate server using Builder APIs
60/// 3. Registers the handler with ProxyManager
61/// 4. Sets up dynamic routing
62pub async fn integrate_server_with_axum(
63    mcp_config: McpServerConfig,
64    mcp_router_path: McpRouterPath,
65    full_mcp_config: McpConfig,
66) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
67    let mcp_type = full_mcp_config.mcp_type.clone();
68    let base_path = mcp_router_path.base_path.clone();
69    let mcp_id = mcp_router_path.mcp_id.clone();
70
71    // Determine backend protocol from configuration
72    let backend_protocol = match &mcp_config {
73        // Command-line config: use stdio protocol
74        McpServerConfig::Command(_) => McpProtocol::Stdio,
75        // URL config: parse type field or auto-detect
76        McpServerConfig::Url(url_config) => {
77            // Check type field first
78            if let Some(type_str) = &url_config.r#type {
79                match type_str.parse::<McpProtocol>() {
80                    Ok(protocol) => {
81                        debug!(
82                            "Using configured protocol type: {} -> {:?}",
83                            type_str, protocol
84                        );
85                        protocol
86                    }
87                    Err(_) => {
88                        // If parsing fails, auto-detect
89                        debug!("Protocol type '{}' unrecognized, auto-detecting", type_str);
90                        let detected_protocol = crate::server::detect_mcp_protocol(
91                            url_config.get_url(),
92                        )
93                        .await
94                        .map_err(|e| {
95                            anyhow::anyhow!(
96                                "Protocol type '{}' unrecognized and auto-detection failed: {}",
97                                type_str,
98                                e
99                            )
100                        })?;
101                        debug!(
102                            "Auto-detected protocol: {:?} (original config: '{}')",
103                            detected_protocol, type_str
104                        );
105                        detected_protocol
106                    }
107                }
108            } else {
109                // No type field, auto-detect
110                debug!("No type field specified, auto-detecting protocol");
111
112                crate::server::detect_mcp_protocol(url_config.get_url())
113                    .await
114                    .map_err(|e| anyhow::anyhow!("Auto-detection failed: {}", e))?
115            }
116        }
117    };
118
119    debug!(
120        "MCP ID: {}, client protocol: {:?}, backend protocol: {:?}",
121        mcp_id, mcp_router_path.mcp_protocol, backend_protocol
122    );
123
124    // Create server based on client protocol using Builder APIs
125    let (router, ct, handler) = match mcp_router_path.mcp_protocol.clone() {
126        // ================ Client uses SSE protocol ================
127        McpProtocol::Sse => {
128            let sse_path = match &mcp_router_path.mcp_protocol_path {
129                McpProtocolPath::SsePath(sse_path) => sse_path,
130                _ => unreachable!(),
131            };
132
133            // Build backend config for SSE
134            let backend_config = build_sse_backend_config(&mcp_config, backend_protocol)?;
135
136            debug!(
137                "Creating SSE server, sse_path={}, post_path={}",
138                sse_path.sse_path, sse_path.message_path
139            );
140
141            // 对于 OneShot 服务,使用更短的 keep_alive 间隔(5秒)来保持后端活跃
142            // 防止后端进程因空闲超时而退出
143            let keep_alive_secs = if matches!(mcp_type, McpType::OneShot) {
144                5
145            } else {
146                15
147            };
148
149            // 对于 OneShot 服务,禁用 stateful 模式以加快响应速度
150            // stateful=false 会跳过 MCP 初始化步骤,直接处理请求
151            let stateful = !matches!(mcp_type, McpType::OneShot);
152
153            let (router, ct, handler) = SseServerBuilder::new(backend_config)
154                .mcp_id(mcp_id.clone())
155                .sse_path(sse_path.sse_path.clone())
156                .post_path(sse_path.message_path.clone())
157                .keep_alive(keep_alive_secs)
158                .stateful(stateful)
159                .build()
160                .await?;
161
162            info!(
163                "SSE server started - MCP ID: {}, type: {:?}",
164                mcp_router_path.mcp_id, mcp_type
165            );
166
167            (router, ct, McpHandler::Sse(handler))
168        }
169
170        // ================ Client uses Streamable HTTP protocol ================
171        McpProtocol::Stream => {
172            // Build backend config for Stream
173            let backend_config = build_stream_backend_config(&mcp_config, backend_protocol)?;
174
175            let (router, ct, handler) = StreamServerBuilder::new(backend_config)
176                .mcp_id(mcp_id.clone())
177                .stateful(false)
178                .build()
179                .await?;
180
181            info!(
182                "Streamable HTTP server started - MCP ID: {}, type: {:?}",
183                mcp_router_path.mcp_id, mcp_type
184            );
185
186            (router, ct, McpHandler::Stream(handler))
187        }
188
189        // Client stdio protocol is not supported in server mode
190        McpProtocol::Stdio => {
191            return Err(anyhow::anyhow!(
192                "Client protocol cannot be Stdio. McpRouterPath::new does not support creating Stdio protocol router paths"
193            ));
194        }
195    };
196
197    // Clone cancellation token for monitoring
198    let ct_clone = ct.clone();
199    let mcp_id_clone = mcp_id.clone();
200
201    // Store MCP service status with full mcp_config for auto-restart
202    let mcp_service_status = McpServiceStatus::new(
203        mcp_id_clone.clone(),
204        mcp_type.clone(),
205        mcp_router_path.clone(),
206        ct_clone.clone(),
207        CheckMcpStatusResponseStatus::Ready,
208    )
209    .with_mcp_config(full_mcp_config.clone());
210
211    // Add MCP service status and proxy handler to global manager
212    let proxy_manager = get_proxy_manager();
213    proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(handler));
214
215    // ===== 新增:注册配置到缓存 =====
216    proxy_manager
217        .register_mcp_config(&mcp_id, full_mcp_config.clone())
218        .await;
219
220    // Add base path fallback handler for SSE protocol
221    let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
222        let modified_router = router.fallback(base_path_fallback_handler);
223        info!("SSE base path handler added, base_path: {}", base_path);
224        modified_router
225    } else {
226        router
227    };
228
229    // Register route to global route table
230    info!(
231        "Registering route: base_path={}, mcp_id={}",
232        base_path, mcp_id
233    );
234    info!(
235        "SSE path config: sse_path={}, post_path={}",
236        match &mcp_router_path.mcp_protocol_path {
237            McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
238            _ => "N/A",
239        },
240        match &mcp_router_path.mcp_protocol_path {
241            McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
242            _ => "N/A",
243        }
244    );
245    DynamicRouterService::register_route(&base_path, router.clone());
246    info!("Route registration complete: base_path={}", base_path);
247
248    // 记录重启时间戳(仅在服务成功启动后)
249    GLOBAL_RESTART_TRACKER.record_restart(&mcp_id);
250
251    Ok((router, ct))
252}
253
254/// Build SSE backend configuration from MCP server config
255fn build_sse_backend_config(
256    mcp_config: &McpServerConfig,
257    backend_protocol: McpProtocol,
258) -> Result<SseBackendConfig> {
259    match mcp_config {
260        McpServerConfig::Command(cmd_config) => {
261            log_command_details(cmd_config);
262            Ok(SseBackendConfig::Stdio {
263                command: cmd_config.command.clone(),
264                args: cmd_config.args.clone(),
265                env: cmd_config.env.clone(),
266            })
267        }
268        McpServerConfig::Url(url_config) => match backend_protocol {
269            McpProtocol::Stdio => Err(anyhow::anyhow!(
270                "URL-based MCP service cannot use Stdio protocol"
271            )),
272            McpProtocol::Sse => {
273                info!("Connecting to SSE backend: {}", url_config.get_url());
274                Ok(SseBackendConfig::SseUrl {
275                    url: url_config.get_url().to_string(),
276                    headers: url_config.headers.clone(),
277                })
278            }
279            McpProtocol::Stream => {
280                info!(
281                    "Connecting to Streamable HTTP backend (SSE frontend): {}",
282                    url_config.get_url()
283                );
284                Ok(SseBackendConfig::StreamUrl {
285                    url: url_config.get_url().to_string(),
286                    headers: url_config.headers.clone(),
287                })
288            }
289        },
290    }
291}
292
293/// Build Stream backend configuration from MCP server config
294fn build_stream_backend_config(
295    mcp_config: &McpServerConfig,
296    backend_protocol: McpProtocol,
297) -> Result<StreamBackendConfig> {
298    match mcp_config {
299        McpServerConfig::Command(cmd_config) => {
300            log_command_details(cmd_config);
301            Ok(StreamBackendConfig::Stdio {
302                command: cmd_config.command.clone(),
303                args: cmd_config.args.clone(),
304                env: cmd_config.env.clone(),
305            })
306        }
307        McpServerConfig::Url(url_config) => {
308            match backend_protocol {
309                McpProtocol::Stdio => Err(anyhow::anyhow!(
310                    "URL-based MCP service cannot use Stdio protocol"
311                )),
312                McpProtocol::Sse => {
313                    // Note: StreamServerBuilder currently only supports Streamable HTTP URL backend
314                    // SSE backend with Stream frontend would require protocol conversion
315                    // For now, we return an error for this combination
316                    Err(anyhow::anyhow!(
317                        "SSE backend with Streamable HTTP frontend is not yet supported. \
318                         Please use SSE frontend or configure a Streamable HTTP backend."
319                    ))
320                }
321                McpProtocol::Stream => {
322                    info!(
323                        "Connecting to Streamable HTTP backend: {}",
324                        url_config.get_url()
325                    );
326                    Ok(StreamBackendConfig::Url {
327                        url: url_config.get_url().to_string(),
328                        headers: url_config.headers.clone(),
329                    })
330                }
331            }
332        }
333    }
334}
335
336/// Log command execution details for debugging
337fn log_command_details(mcp_config: &McpServerCommandConfig) {
338    let args_str = mcp_config
339        .args
340        .as_ref()
341        .map_or(String::new(), |args| args.join(" "));
342    let cmd_str = format!("Executing command: {} {}", mcp_config.command, args_str);
343    debug!("{cmd_str}");
344
345    if let Some(env_vars) = &mcp_config.env {
346        let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
347        if !env_vars.is_empty() {
348            debug!("Environment variables: {}", env_vars.join(", "));
349        }
350    }
351
352    debug!("Full command: {:?}", mcp_config.command);
353
354    let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
355        env.iter()
356            .map(|(k, v)| format!("{k}={v}"))
357            .collect::<Vec<String>>()
358            .join(" ")
359    });
360
361    let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
362    info!("Full command string: {:?}", full_command);
363}
364
365/// Base path fallback handler - supports direct access to base path with automatic redirection
366#[axum::debug_handler]
367async fn base_path_fallback_handler(
368    method: axum::http::Method,
369    uri: axum::http::Uri,
370    headers: axum::http::HeaderMap,
371) -> impl axum::response::IntoResponse {
372    let path = uri.path();
373    info!("Base path handler: {} {}", method, path);
374
375    // Determine if SSE or Stream protocol
376    if path.contains("/sse/proxy/") {
377        // SSE protocol handling
378        match method {
379            axum::http::Method::GET => {
380                // Extract MCP ID from path
381                let mcp_id = path.split("/sse/proxy/").nth(1);
382
383                if let Some(mcp_id) = mcp_id {
384                    // Check if MCP service exists
385                    let proxy_manager = get_proxy_manager();
386                    if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
387                        // MCP service not found
388                        (
389                            axum::http::StatusCode::NOT_FOUND,
390                            [("Content-Type", "text/plain".to_string())],
391                            format!("MCP service '{}' not found", mcp_id).to_string(),
392                        )
393                    } else {
394                        // MCP service exists, check Accept header
395                        let accept_header = headers.get("accept");
396                        if let Some(accept) = accept_header {
397                            let accept_str = accept.to_str().unwrap_or("");
398                            if accept_str.contains("text/event-stream") {
399                                // Correct Accept header, redirect to /sse
400                                let redirect_uri = format!("{}/sse", path);
401                                info!("SSE redirect to: {}", redirect_uri);
402                                (
403                                    axum::http::StatusCode::FOUND,
404                                    [("Location", redirect_uri.to_string())],
405                                    "Redirecting to SSE endpoint".to_string(),
406                                )
407                            } else {
408                                // Incorrect Accept header
409                                (
410                                    axum::http::StatusCode::BAD_REQUEST,
411                                    [("Content-Type", "text/plain".to_string())],
412                                    "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
413                                )
414                            }
415                        } else {
416                            // No Accept header
417                            (
418                                axum::http::StatusCode::BAD_REQUEST,
419                                [("Content-Type", "text/plain".to_string())],
420                                "SSE error: Missing Accept header, expected 'text/event-stream'"
421                                    .to_string(),
422                            )
423                        }
424                    }
425                } else {
426                    // Cannot extract MCP ID from path
427                    (
428                        axum::http::StatusCode::BAD_REQUEST,
429                        [("Content-Type", "text/plain".to_string())],
430                        "SSE error: Invalid SSE path".to_string(),
431                    )
432                }
433            }
434            axum::http::Method::POST => {
435                // POST request redirect to /message
436                let redirect_uri = format!("{}/message", path);
437                info!("SSE redirect to: {}", redirect_uri);
438                (
439                    axum::http::StatusCode::FOUND,
440                    [("Location", redirect_uri.to_string())],
441                    "Redirecting to message endpoint".to_string(),
442                )
443            }
444            _ => {
445                // Other methods return 405 Method Not Allowed
446                (
447                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
448                    [("Allow", "GET, POST".to_string())],
449                    "Only GET and POST methods are allowed".to_string(),
450                )
451            }
452        }
453    } else if path.contains("/stream/proxy/") {
454        // Stream protocol handling - return success directly without redirect
455        match method {
456            axum::http::Method::GET => {
457                // GET request returns server info
458                (
459                    axum::http::StatusCode::OK,
460                    [("Content-Type", "application/json".to_string())],
461                    r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
462                )
463            }
464            axum::http::Method::POST => {
465                // POST request returns success, let StreamableHttpService handle
466                (
467                    axum::http::StatusCode::OK,
468                    [("Content-Type", "application/json".to_string())],
469                    r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
470                )
471            }
472            _ => {
473                // Other methods return 405 Method Not Allowed
474                (
475                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
476                    [("Allow", "GET, POST".to_string())],
477                    "Only GET and POST methods are allowed".to_string(),
478                )
479            }
480        }
481    } else {
482        // Unknown protocol
483        (
484            axum::http::StatusCode::BAD_REQUEST,
485            [("Content-Type", "text/plain".to_string())],
486            "Unknown protocol or path".to_string(),
487        )
488    }
489}