1use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::Result;
10use tokio::process::Command;
11use tokio_util::sync::CancellationToken;
12use tracing::info;
13
14use rmcp::{
15 ServiceExt,
16 model::{ClientCapabilities, ClientInfo},
17 transport::{
18 TokioChildProcess,
19 streamable_http_client::{
20 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
21 },
22 streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
23 },
24};
25
26use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
27pub use mcp_common::ToolFilter as CommonToolFilter;
28
/// Describes how the proxy reaches the upstream MCP server it exposes.
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// Spawn a local child process and talk MCP over its stdio.
    Stdio {
        /// Executable to launch.
        command: String,
        /// Command-line arguments passed to `command`, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables set on the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP server over the streamable HTTP transport.
    Url {
        /// Endpoint URL of the remote server.
        url: String,
        /// Extra HTTP headers attached to every backend request, if any.
        /// An `Authorization` entry may be routed through the transport's
        /// bearer-token setting instead of being sent as a plain header.
        headers: Option<HashMap<String, String>>,
    },
}
51
/// Server-side options collected by `StreamServerBuilder`.
///
/// The derived `Default` is stateless mode, no `mcp_id`, and no tool filter.
#[derive(Debug, Clone, Default)]
pub struct StreamServerConfig {
    /// When `true`, sessions are managed by `ProxyAwareSessionManager` and the
    /// HTTP service runs in stateful mode; when `false`, a `LocalSessionManager`
    /// is used with stateful mode disabled.
    pub stateful_mode: bool,
    /// Identifier handed to the `ProxyHandler`; the builder falls back to
    /// `"stream-proxy"` when this is `None`.
    pub mcp_id: Option<String>,
    /// Optional tool filter passed to `ProxyHandler::with_tool_filter`.
    pub tool_filter: Option<ToolFilter>,
}
62
/// Builder that connects to an MCP backend and exposes it as a streamable
/// HTTP server in the form of an `axum::Router`.
///
/// Configure it with `stateful`, `mcp_id` and `tool_filter`, then call `build`.
pub struct StreamServerBuilder {
    /// How to reach the backend MCP server.
    backend_config: BackendConfig,
    /// Server-side options (session mode, identifier, tool filter).
    server_config: StreamServerConfig,
}
87
88impl StreamServerBuilder {
89 pub fn new(backend: BackendConfig) -> Self {
91 Self {
92 backend_config: backend,
93 server_config: StreamServerConfig::default(),
94 }
95 }
96
97 pub fn stateful(mut self, enabled: bool) -> Self {
101 self.server_config.stateful_mode = enabled;
102 self
103 }
104
105 pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
109 self.server_config.mcp_id = Some(id.into());
110 self
111 }
112
113 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
115 self.server_config.tool_filter = Some(filter);
116 self
117 }
118
119 pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
125 let mcp_id = self
126 .server_config
127 .mcp_id
128 .clone()
129 .unwrap_or_else(|| "stream-proxy".into());
130
131 let client_info = ClientInfo {
133 protocol_version: Default::default(),
134 capabilities: ClientCapabilities::builder()
135 .enable_experimental()
136 .enable_roots()
137 .enable_roots_list_changed()
138 .enable_sampling()
139 .build(),
140 ..Default::default()
141 };
142
143 let client = match &self.backend_config {
145 BackendConfig::Stdio { command, args, env } => {
146 self.connect_stdio(command, args, env, &client_info).await?
147 }
148 BackendConfig::Url { url, headers } => {
149 self.connect_url(url, headers, &client_info).await?
150 }
151 };
152
153 let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
155 ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
156 } else {
157 ProxyHandler::with_mcp_id(client, mcp_id.clone())
158 };
159
160 let handler_for_return = proxy_handler.clone();
162
163 let (router, ct) = self.create_server(proxy_handler).await?;
165
166 info!(
167 "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
168 mcp_id, self.server_config.stateful_mode
169 );
170
171 Ok((router, ct, handler_for_return))
172 }
173
174 async fn connect_stdio(
176 &self,
177 command: &str,
178 args: &Option<Vec<String>>,
179 env: &Option<HashMap<String, String>>,
180 client_info: &ClientInfo,
181 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
182 let mut cmd = Command::new(command);
183
184 if let Some(cmd_args) = args {
185 cmd.args(cmd_args);
186 }
187
188 if let Some(env_vars) = env {
189 for (k, v) in env_vars {
190 cmd.env(k, v);
191 }
192 }
193
194 info!(
195 "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
196 command,
197 args.as_ref().unwrap_or(&vec![])
198 );
199
200 let tokio_process = TokioChildProcess::new(cmd)?;
201 let client = client_info.clone().serve(tokio_process).await?;
202
203 info!("[StreamServerBuilder] Child process connected successfully");
204 Ok(client)
205 }
206
207 async fn connect_url(
209 &self,
210 url: &str,
211 headers: &Option<HashMap<String, String>>,
212 client_info: &ClientInfo,
213 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
214 info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
215
216 let mut req_headers = reqwest::header::HeaderMap::new();
218 let mut auth_header: Option<String> = None;
219
220 if let Some(config_headers) = headers {
221 for (key, value) in config_headers {
222 if key.eq_ignore_ascii_case("Authorization") {
224 auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
225 continue;
226 }
227
228 req_headers.insert(
229 reqwest::header::HeaderName::try_from(key)
230 .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
231 value.parse().map_err(|e| {
232 anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
233 })?,
234 );
235 }
236 }
237
238 let http_client = reqwest::Client::builder()
239 .default_headers(req_headers)
240 .build()
241 .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
242
243 let config = StreamableHttpClientTransportConfig {
245 uri: url.to_string().into(),
246 auth_header,
247 ..Default::default()
248 };
249
250 let transport = StreamableHttpClientTransport::with_client(http_client, config);
251 let client = client_info.clone().serve(transport).await?;
252
253 info!("[StreamServerBuilder] URL backend connected successfully");
254 Ok(client)
255 }
256
257 async fn create_server(
259 &self,
260 proxy_handler: ProxyHandler,
261 ) -> Result<(axum::Router, CancellationToken)> {
262 let handler = Arc::new(proxy_handler);
263 let ct = CancellationToken::new();
264
265 if self.server_config.stateful_mode {
266 let session_manager = ProxyAwareSessionManager::new(handler.clone());
268 let handler_for_service = handler.clone();
269
270 let service = StreamableHttpService::new(
271 move || Ok((*handler_for_service).clone()),
272 session_manager.into(),
273 StreamableHttpServerConfig {
274 stateful_mode: true,
275 ..Default::default()
276 },
277 );
278
279 let router = axum::Router::new().fallback_service(service);
280 Ok((router, ct))
281 } else {
282 use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
284
285 let handler_for_service = handler.clone();
286
287 let service = StreamableHttpService::new(
288 move || Ok((*handler_for_service).clone()),
289 LocalSessionManager::default().into(),
290 StreamableHttpServerConfig {
291 stateful_mode: false,
292 ..Default::default()
293 },
294 );
295
296 let router = axum::Router::new().fallback_service(service);
297 Ok((router, ct))
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters are recorded in the server config.
    #[test]
    fn test_builder_creation() {
        let builder = StreamServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        })
        .mcp_id("test")
        .stateful(true);

        assert_eq!(builder.server_config.mcp_id.as_deref(), Some("test"));
        assert!(builder.server_config.stateful_mode);
        assert!(builder.server_config.tool_filter.is_none());
    }

    /// A fresh builder is stateless with no id and no tool filter.
    #[test]
    fn test_builder_defaults() {
        let builder = StreamServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args: None,
            env: None,
        });

        assert!(!builder.server_config.stateful_mode);
        assert!(builder.server_config.mcp_id.is_none());
        assert!(builder.server_config.tool_filter.is_none());
    }

    /// A URL backend keeps its URL and every configured header unchanged.
    #[test]
    fn test_url_backend_config() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer token123".into());
        headers.insert("X-Custom".into(), "value".into());

        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        match &builder.backend_config {
            BackendConfig::Url { url, headers } => {
                assert_eq!(url, "http://localhost:8080/mcp");
                let headers = headers.as_ref().expect("headers should be preserved");
                assert_eq!(headers.len(), 2);
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer token123")
                );
                assert_eq!(headers.get("X-Custom").map(String::as_str), Some("value"));
            }
            BackendConfig::Stdio { .. } => panic!("Expected URL backend"),
        }
    }
}