rmcp_proxy/
sse_server.rs

1use crate::proxy_handler::ProxyHandler;
2/**
3 * Create a local SSE server that proxies requests to a stdio MCP server.
4 */
5use rmcp::{
6    model::{ClientCapabilities, ClientInfo},
7    transport::{
8        child_process::TokioChildProcess,
9        sse_server::{SseServer, SseServerConfig},
10    },
11    ServiceExt,
12};
13use std::{collections::HashMap, error::Error as StdError, net::SocketAddr, time::Duration};
14use tokio::process::Command;
15use tokio_util::sync::CancellationToken;
16use tracing::info;
17
/// Settings for the local SSE server created by `run_sse_server`.
#[derive(Debug, Clone)]
pub struct SseServerSettings {
    /// Socket address the SSE HTTP server binds to.
    pub bind_addr: SocketAddr,
    /// Requested SSE keep-alive interval.
    ///
    /// NOTE: this value is not currently forwarded to the SSE transport by `run_sse_server`.
    pub keep_alive: Option<Duration>,
}
23
24/// StdioServerParameters holds parameters for the stdio client.
25pub struct StdioServerParameters {
26    pub command: String,
27    pub args: Vec<String>,
28    pub env: HashMap<String, String>,
29}
30
31/// Run the SSE server with a stdio client
32///
33/// This function connects to a stdio server and exposes it as an SSE server.
34pub async fn run_sse_server(
35    stdio_params: StdioServerParameters,
36    sse_settings: SseServerSettings,
37) -> Result<(), Box<dyn StdError>> {
38    info!(
39        "Running SSE server on {:?} with command: {}",
40        sse_settings.bind_addr, stdio_params.command,
41    );
42
43    // Configure SSE server
44    let config = SseServerConfig {
45        bind: sse_settings.bind_addr,
46        sse_path: "/sse".to_string(),
47        post_path: "/message".to_string(),
48        ct: CancellationToken::new(),
49        // sse_keep_alive: sse_settings.keep_alive,
50    };
51
52    let mut command = Command::new(&stdio_params.command);
53    command.args(&stdio_params.args);
54
55    for (key, value) in &stdio_params.env {
56        command.env(key, value);
57    }
58
59    // Create child process
60    let tokio_process = TokioChildProcess::new(&mut command)?;
61
62    let client_info = ClientInfo {
63        protocol_version: Default::default(),
64        capabilities: ClientCapabilities::builder()
65            .enable_experimental()
66            .enable_roots()
67            .enable_roots_list_changed()
68            .enable_sampling()
69            .build(),
70        ..Default::default()
71    };
72
73    // Create client service
74    let client = client_info.serve(tokio_process).await?;
75
76    // Get server info
77    let server_info = client.peer_info();
78    info!("Connected to server: {}", server_info.server_info.name);
79
80    // Create proxy handler
81    let proxy_handler = ProxyHandler::new(client);
82
83    // Start the SSE server
84    let sse_server = SseServer::serve_with_config(config.clone()).await?;
85
86    // Register the proxy handler with the SSE server
87    let ct = sse_server.with_service(move || proxy_handler.clone());
88
89    // Wait for Ctrl+C to shut down
90    tokio::signal::ctrl_c().await?;
91    ct.cancel();
92
93    Ok(())
94}