1use anyhow::{Context, Result};
2use rmcp::{
3 model::{
4 CallToolRequestParam, CallToolResult, ErrorCode, ErrorData, Implementation,
5 InitializeRequestParam, InitializeResult, ListToolsResult, PaginatedRequestParam,
6 ProtocolVersion, ServerCapabilities, Tool, ToolsCapability,
7 },
8 service::{Peer, RequestContext, RoleServer},
9 Error as McpError, ServerHandler,
10};
11use std::sync::Arc;
12use tokio::sync::{mpsc, oneshot, Mutex};
13use tracing::{debug, error, info, warn};
14
15use crate::{
16 config::Config,
17 hub_actor::{spawn_hub_actor, HubMessage},
18 metrics::{HubMetrics, MetricsTimer},
19 supervisor::SupervisorConfig,
20};
21
22#[derive(Debug, Clone)]
24pub struct McpAggregator {
25 hub_sender: mpsc::Sender<HubMessage>,
27 server_info: Implementation,
29 pub instructions: Option<String>,
31 peer: Option<Peer<RoleServer>>,
33 shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
35 metrics: HubMetrics,
37}
38
39impl McpAggregator {
40 pub fn new(config: Config, supervisor_config: SupervisorConfig) -> Self {
42 let (server_info, instructions) = Self::create_server_info(&config);
43
44 let (shutdown_tx, shutdown_rx) = oneshot::channel();
46 let shutdown_tx = Arc::new(Mutex::new(Some(shutdown_tx)));
47
48 let hub_sender = spawn_hub_actor(config, supervisor_config, shutdown_rx);
50
51 Self {
52 hub_sender,
53 server_info,
54 instructions,
55 peer: None,
56 shutdown_tx,
57 metrics: HubMetrics::new(),
58 }
59 }
60
61 fn create_server_info(config: &Config) -> (Implementation, Option<String>) {
63 let server_names = config
64 .servers()
65 .iter()
66 .map(|s| s.name.as_str())
67 .collect::<Vec<_>>()
68 .join(", ");
69
70 let server_info = Implementation {
71 name: "MCP Hub Aggregator".into(),
72 version: env!("CARGO_PKG_VERSION").into(),
73 };
74
75 let instructions = Some(format!(
76 "MCP Hub aggregating {} servers: {}",
77 config.servers().len(),
78 server_names
79 ));
80
81 (server_info, instructions)
82 }
83
84 async fn initialize_hub(&self) -> Result<Vec<Tool>> {
86 debug!("Initializing hub actor - entering method");
87 debug!(
88 "Hub sender channel capacity: {}",
89 self.hub_sender.capacity()
90 );
91
92 let (response_tx, response_rx) = oneshot::channel();
93
94 debug!("Sending Initialize message to hub actor");
95 self.hub_sender
96 .send(HubMessage::Initialize {
97 response: response_tx,
98 })
99 .await
100 .context("Failed to send initialize message to hub actor")?;
101
102 debug!("Initialize message sent, waiting for response");
103
104 let result = response_rx
105 .await
106 .context("Hub actor did not respond to initialize")?;
107
108 debug!(
109 "Received response from hub actor with {} tools",
110 result.as_ref().map(|t| t.len()).unwrap_or(0)
111 );
112 result
113 }
114
115 async fn get_aggregated_tools(&self) -> Vec<Tool> {
117 let (response_tx, response_rx) = oneshot::channel();
118
119 if self
120 .hub_sender
121 .send(HubMessage::ListTools {
122 response: response_tx,
123 })
124 .await
125 .is_err()
126 {
127 warn!("Failed to send list_tools message to hub actor");
128 return Vec::new();
129 }
130
131 response_rx.await.unwrap_or_default()
132 }
133
134 async fn route_tool_call(&self, params: CallToolRequestParam) -> Result<CallToolResult> {
136 debug!("Routing tool call: {}", params.name);
137
138 let timer = MetricsTimer::new(
140 self.metrics.clone(),
141 "hub".to_string(), params.name.to_string(),
143 );
144
145 let (response_tx, response_rx) = oneshot::channel();
146
147 let send_result = self
148 .hub_sender
149 .send(HubMessage::CallTool {
150 params,
151 response: response_tx,
152 })
153 .await;
154
155 match send_result {
156 Ok(_) => {
157 let response_result = response_rx
158 .await
159 .context("Hub actor did not respond to call_tool");
160
161 match response_result {
162 Ok(result) => {
163 timer.finish_with_status("success");
164 result
165 }
166 Err(e) => {
167 timer.finish_with_error("hub_timeout");
168 Err(e)
169 }
170 }
171 }
172 Err(e) => {
173 timer.finish_with_error("send_failed");
174 Err(anyhow::anyhow!(
175 "Failed to send call_tool message to hub actor: {}",
176 e
177 ))
178 }
179 }
180 }
181
182 #[allow(dead_code)]
184 pub async fn reload_config(&self, config: Config) -> Result<()> {
185 let (response_tx, response_rx) = oneshot::channel();
186
187 self.hub_sender
188 .send(HubMessage::ReloadConfig {
189 config,
190 response: response_tx,
191 })
192 .await
193 .context("Failed to send reload_config message to hub actor")?;
194
195 response_rx
196 .await
197 .context("Hub actor did not respond to reload_config")?
198 }
199
200 #[allow(dead_code)]
202 pub async fn get_status(
203 &self,
204 ) -> std::collections::HashMap<String, crate::supervisor::ProcessState> {
205 let (response_tx, response_rx) = oneshot::channel();
206
207 if self
208 .hub_sender
209 .send(HubMessage::GetStatus {
210 response: response_tx,
211 })
212 .await
213 .is_err()
214 {
215 warn!("Failed to send get_status message to hub actor");
216 return std::collections::HashMap::new();
217 }
218
219 response_rx.await.unwrap_or_default()
220 }
221
222 pub fn hub_handle(&self) -> mpsc::Sender<HubMessage> {
224 self.hub_sender.clone()
225 }
226}
227
228impl ServerHandler for McpAggregator {
229 fn get_info(&self) -> InitializeResult {
230 debug!("get_info() called - returning server capabilities");
231 InitializeResult {
232 protocol_version: ProtocolVersion::default(),
233 capabilities: ServerCapabilities {
234 experimental: None,
235 logging: None,
236 prompts: None,
237 resources: None,
238 tools: Some(ToolsCapability {
239 list_changed: Some(true),
240 }),
241 },
242 server_info: self.server_info.clone(),
243 instructions: self.instructions.clone(),
244 }
245 }
246
247 fn get_peer(&self) -> Option<Peer<RoleServer>> {
248 self.peer.clone()
249 }
250
251 fn set_peer(&mut self, peer: Peer<RoleServer>) {
252 self.peer = Some(peer);
253 }
254
255 fn initialize(
256 &self,
257 _request: InitializeRequestParam,
258 _context: RequestContext<RoleServer>,
259 ) -> impl std::future::Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
260 async move {
261 info!("Initializing MCP aggregator");
262 debug!("About to call initialize_hub()");
263
264 let _tools = self.initialize_hub().await.map_err(|e| {
266 error!("Failed to initialize hub: {}", e);
267 ErrorData {
268 code: ErrorCode::INTERNAL_ERROR,
269 message: format!("Failed to initialize hub: {}", e).into(),
270 data: None,
271 }
272 })?;
273
274 debug!("Hub initialization completed, got {} tools", _tools.len());
275 info!("Aggregator initialized successfully");
276
277 let server_capabilities = ServerCapabilities {
278 experimental: None,
279 logging: None,
280 prompts: None,
281 resources: None,
282 tools: Some(ToolsCapability {
283 list_changed: Some(true),
284 }),
285 };
286
287 Ok(InitializeResult {
288 protocol_version: ProtocolVersion::default(),
289 capabilities: server_capabilities,
290 server_info: self.server_info.clone(),
291 instructions: self.instructions.clone(),
292 })
293 }
294 }
295
296 async fn list_tools(
297 &self,
298 _request: PaginatedRequestParam,
299 _context: RequestContext<RoleServer>,
300 ) -> Result<ListToolsResult, McpError> {
301 debug!("Listing all aggregated tools");
302
303 let tools = self.get_aggregated_tools().await;
304
305 debug!("Returning {} tools", tools.len());
306
307 Ok(ListToolsResult {
308 tools,
309 next_cursor: None,
310 })
311 }
312
313 async fn call_tool(
314 &self,
315 request: CallToolRequestParam,
316 _context: RequestContext<RoleServer>,
317 ) -> Result<CallToolResult, McpError> {
318 info!("Calling tool: {}", request.name);
319
320 self.route_tool_call(request).await.map_err(|e| {
321 error!("Tool call failed: {}", e);
322 ErrorData {
323 code: ErrorCode::METHOD_NOT_FOUND,
324 message: format!("Tool call failed: {}", e).into(),
325 data: None,
326 }
327 })
328 }
329
330 async fn ping(&self, _context: RequestContext<RoleServer>) -> Result<(), McpError> {
331 debug!("Ping received");
332 Ok(())
333 }
334
335 async fn on_initialized(&self) {
336 info!("MCP aggregator initialized and ready");
337 }
338}
339
340impl Drop for McpAggregator {
341 fn drop(&mut self) {
342 if let Some(shutdown_tx) = self
344 .shutdown_tx
345 .try_lock()
346 .ok()
347 .and_then(|mut guard| guard.take())
348 {
349 let _ = shutdown_tx.send(());
350 }
351 }
352}
353
354pub struct AggregatorBuilder {
356 config: Option<Config>,
357 supervisor_config: Option<SupervisorConfig>,
358}
359
360impl AggregatorBuilder {
361 pub fn new() -> Self {
363 Self {
364 config: None,
365 supervisor_config: None,
366 }
367 }
368
369 pub fn with_config(mut self, config: Config) -> Self {
371 self.config = Some(config);
372 self
373 }
374
375 pub fn with_supervisor_config(mut self, supervisor_config: SupervisorConfig) -> Self {
377 self.supervisor_config = Some(supervisor_config);
378 self
379 }
380
381 pub fn build(self) -> Result<McpAggregator> {
383 let config = self
384 .config
385 .ok_or_else(|| anyhow::anyhow!("Configuration is required"))?;
386
387 let supervisor_config = self.supervisor_config.unwrap_or_default();
388
389 Ok(McpAggregator::new(config, supervisor_config))
390 }
391}
392
393impl Default for AggregatorBuilder {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ServerConfig;
    use std::collections::HashMap;

    /// An `echo`-backed server entry named `name` that prints `word`, prefixed with
    /// `"{name}_"` and suffixed with `" ({label})"`.
    fn echo_server(name: &str, word: &str, label: &str) -> ServerConfig {
        ServerConfig {
            name: name.to_string(),
            cmd: "echo".to_string(),
            args: vec![word.to_string()],
            env: HashMap::new(),
            whitelist: None,
            blacklist: None,
            prefix: Some(format!("{name}_")),
            description_suffix: Some(format!(" ({label})")),
            description_prefix: None,
        }
    }

    /// Two-server configuration shared by the tests in this module.
    fn create_test_config() -> Config {
        Config {
            server: vec![
                echo_server("test1", "hello", "test server 1"),
                echo_server("test2", "world", "test server 2"),
            ],
        }
    }

    #[tokio::test]
    async fn test_aggregator_creation() {
        let aggregator = McpAggregator::new(create_test_config(), SupervisorConfig::default());

        assert_eq!(aggregator.get_info().server_info.name, "MCP Hub Aggregator");

        let instructions = aggregator
            .instructions
            .as_deref()
            .expect("instructions are always populated by create_server_info");
        assert!(instructions.contains("test1, test2"));
    }

    #[tokio::test]
    async fn test_builder_pattern() -> Result<()> {
        let built = AggregatorBuilder::new()
            .with_config(create_test_config())
            .with_supervisor_config(SupervisorConfig::default())
            .build()?;

        assert_eq!(built.get_info().server_info.name, "MCP Hub Aggregator");
        Ok(())
    }

    #[tokio::test]
    async fn test_builder_requires_config() {
        let err = match AggregatorBuilder::new().build() {
            Ok(_) => panic!("building without a config must fail"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("Configuration is required"));
    }
}