use crate::proxy_handler::ProxyHandler;
use rmcp::{
model::{ClientCapabilities, ClientInfo},
transport::{
child_process::TokioChildProcess,
sse_server::{SseServer, SseServerConfig},
},
ServiceExt,
};
use std::{collections::HashMap, error::Error as StdError, net::SocketAddr, time::Duration};
use tokio::process::Command;
use tokio_util::sync::CancellationToken;
use tracing::info;
/// Network settings for the SSE side of the proxy.
#[derive(Debug, Clone)]
pub struct SseServerSettings {
    /// Socket address the SSE HTTP server binds to.
    pub bind_addr: SocketAddr,
    /// Optional keep-alive interval.
    ///
    /// NOTE(review): `run_sse_server` in this file does not currently read this value.
    pub keep_alive: Option<Duration>,
}
/// Describes how to launch the stdio-based MCP server that the proxy wraps.
///
/// All three fields are applied to a `tokio::process::Command` by `run_sse_server`.
pub struct StdioServerParameters {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables set on the child process.
    pub env: HashMap<String, String>,
}
/// Launches a stdio MCP server as a child process and exposes it over SSE.
///
/// The child process is built from `stdio_params` (command, arguments and extra
/// environment variables). An MCP client session is initialised against it, and a
/// [`ProxyHandler`] wrapping that session is handed to the SSE server. The SSE server
/// listens on `sse_settings.bind_addr`, with `/sse` as the event-stream path and
/// `/message` as the POST path. The function then waits for Ctrl-C and cancels the
/// token returned by the SSE server before returning.
///
/// NOTE(review): `sse_settings.keep_alive` is not read here.
///
/// # Errors
///
/// Returns an error if any of these steps fail:
/// - spawning the child process;
/// - the MCP client handshake with the child;
/// - starting the SSE server;
/// - waiting for the Ctrl-C signal.
///
/// In the last case the SSE server token is still cancelled before the error is returned.
pub async fn run_sse_server(
    stdio_params: StdioServerParameters,
    sse_settings: SseServerSettings,
) -> Result<(), Box<dyn StdError>> {
    info!(
        "Running SSE server on {:?} with command: {}",
        sse_settings.bind_addr, stdio_params.command,
    );

    let config = SseServerConfig {
        bind: sse_settings.bind_addr,
        sse_path: "/sse".to_string(),
        post_path: "/message".to_string(),
        ct: CancellationToken::new(),
    };

    // Launch the wrapped stdio server with the caller-supplied arguments and environment.
    let mut command = Command::new(&stdio_params.command);
    command.args(&stdio_params.args).envs(&stdio_params.env);
    let tokio_process = TokioChildProcess::new(&mut command)?;

    // Capabilities this proxy advertises when acting as a client of the child server.
    let client_info = ClientInfo {
        protocol_version: Default::default(),
        capabilities: ClientCapabilities::builder()
            .enable_experimental()
            .enable_roots()
            .enable_roots_list_changed()
            .enable_sampling()
            .build(),
        ..Default::default()
    };

    let client = client_info.serve(tokio_process).await?;
    let server_info = client.peer_info();
    info!("Connected to server: {}", server_info.server_info.name);

    let proxy_handler = ProxyHandler::new(client);

    // `config` is not needed after this point, so it is moved rather than cloned.
    let sse_server = SseServer::serve_with_config(config).await?;
    // The closure is invoked to obtain a service per SSE session; each receives a clone
    // of the handler (what a clone shares depends on ProxyHandler's Clone impl).
    let ct = sse_server.with_service(move || proxy_handler.clone());

    // Cancel the server token even if waiting for the signal itself fails, so the
    // SSE server is not left running; only then propagate any signal error.
    let signal = tokio::signal::ctrl_c().await;
    ct.cancel();
    signal?;

    Ok(())
}