mcp_hub/
aggregator.rs

1use anyhow::{Context, Result};
2use rmcp::{
3    model::{
4        CallToolRequestParam, CallToolResult, ErrorCode, ErrorData, Implementation,
5        InitializeRequestParam, InitializeResult, ListToolsResult, PaginatedRequestParam,
6        ProtocolVersion, ServerCapabilities, Tool, ToolsCapability,
7    },
8    service::{Peer, RequestContext, RoleServer},
9    Error as McpError, ServerHandler,
10};
11use std::sync::Arc;
12use tokio::sync::{mpsc, oneshot, Mutex};
13use tracing::{debug, error, info, warn};
14
15use crate::{
16    config::Config,
17    hub_actor::{spawn_hub_actor, HubMessage},
18    metrics::{HubMetrics, MetricsTimer},
19    supervisor::SupervisorConfig,
20};
21
22/// Main MCP aggregator that coordinates with the hub actor
23#[derive(Debug, Clone)]
24pub struct McpAggregator {
25    /// Channel to communicate with the hub actor
26    hub_sender: mpsc::Sender<HubMessage>,
27    /// Server information
28    server_info: Implementation,
29    /// Instructions for the aggregator
30    pub instructions: Option<String>,
31    /// Peer connection (set during initialization)
32    peer: Option<Peer<RoleServer>>,
33    /// Shutdown coordination
34    shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
35    /// Metrics collector
36    metrics: HubMetrics,
37}
38
39impl McpAggregator {
40    /// Create a new MCP aggregator
41    pub fn new(config: Config, supervisor_config: SupervisorConfig) -> Self {
42        let (server_info, instructions) = Self::create_server_info(&config);
43
44        // Create shutdown coordination
45        let (shutdown_tx, shutdown_rx) = oneshot::channel();
46        let shutdown_tx = Arc::new(Mutex::new(Some(shutdown_tx)));
47
48        // Spawn the hub actor
49        let hub_sender = spawn_hub_actor(config, supervisor_config, shutdown_rx);
50
51        Self {
52            hub_sender,
53            server_info,
54            instructions,
55            peer: None,
56            shutdown_tx,
57            metrics: HubMetrics::new(),
58        }
59    }
60
61    /// Create server info from configuration
62    fn create_server_info(config: &Config) -> (Implementation, Option<String>) {
63        let server_names = config
64            .servers()
65            .iter()
66            .map(|s| s.name.as_str())
67            .collect::<Vec<_>>()
68            .join(", ");
69
70        let server_info = Implementation {
71            name: "MCP Hub Aggregator".into(),
72            version: env!("CARGO_PKG_VERSION").into(),
73        };
74
75        let instructions = Some(format!(
76            "MCP Hub aggregating {} servers: {}",
77            config.servers().len(),
78            server_names
79        ));
80
81        (server_info, instructions)
82    }
83
84    /// Initialize the hub actor
85    async fn initialize_hub(&self) -> Result<Vec<Tool>> {
86        debug!("Initializing hub actor - entering method");
87        debug!(
88            "Hub sender channel capacity: {}",
89            self.hub_sender.capacity()
90        );
91
92        let (response_tx, response_rx) = oneshot::channel();
93
94        debug!("Sending Initialize message to hub actor");
95        self.hub_sender
96            .send(HubMessage::Initialize {
97                response: response_tx,
98            })
99            .await
100            .context("Failed to send initialize message to hub actor")?;
101
102        debug!("Initialize message sent, waiting for response");
103
104        let result = response_rx
105            .await
106            .context("Hub actor did not respond to initialize")?;
107
108        debug!(
109            "Received response from hub actor with {} tools",
110            result.as_ref().map(|t| t.len()).unwrap_or(0)
111        );
112        result
113    }
114
115    /// Get aggregated tools from the hub
116    async fn get_aggregated_tools(&self) -> Vec<Tool> {
117        let (response_tx, response_rx) = oneshot::channel();
118
119        if self
120            .hub_sender
121            .send(HubMessage::ListTools {
122                response: response_tx,
123            })
124            .await
125            .is_err()
126        {
127            warn!("Failed to send list_tools message to hub actor");
128            return Vec::new();
129        }
130
131        response_rx.await.unwrap_or_default()
132    }
133
134    /// Route a tool call to the hub
135    async fn route_tool_call(&self, params: CallToolRequestParam) -> Result<CallToolResult> {
136        debug!("Routing tool call: {}", params.name);
137
138        // Create metrics timer for this tool call
139        let timer = MetricsTimer::new(
140            self.metrics.clone(),
141            "hub".to_string(), // We'll use "hub" as server name for aggregated metrics
142            params.name.to_string(),
143        );
144
145        let (response_tx, response_rx) = oneshot::channel();
146
147        let send_result = self
148            .hub_sender
149            .send(HubMessage::CallTool {
150                params,
151                response: response_tx,
152            })
153            .await;
154
155        match send_result {
156            Ok(_) => {
157                let response_result = response_rx
158                    .await
159                    .context("Hub actor did not respond to call_tool");
160
161                match response_result {
162                    Ok(result) => {
163                        timer.finish_with_status("success");
164                        result
165                    }
166                    Err(e) => {
167                        timer.finish_with_error("hub_timeout");
168                        Err(e)
169                    }
170                }
171            }
172            Err(e) => {
173                timer.finish_with_error("send_failed");
174                Err(anyhow::anyhow!(
175                    "Failed to send call_tool message to hub actor: {}",
176                    e
177                ))
178            }
179        }
180    }
181
182    /// Reload configuration in the hub
183    #[allow(dead_code)]
184    pub async fn reload_config(&self, config: Config) -> Result<()> {
185        let (response_tx, response_rx) = oneshot::channel();
186
187        self.hub_sender
188            .send(HubMessage::ReloadConfig {
189                config,
190                response: response_tx,
191            })
192            .await
193            .context("Failed to send reload_config message to hub actor")?;
194
195        response_rx
196            .await
197            .context("Hub actor did not respond to reload_config")?
198    }
199
200    /// Get status of all supervisors
201    #[allow(dead_code)]
202    pub async fn get_status(
203        &self,
204    ) -> std::collections::HashMap<String, crate::supervisor::ProcessState> {
205        let (response_tx, response_rx) = oneshot::channel();
206
207        if self
208            .hub_sender
209            .send(HubMessage::GetStatus {
210                response: response_tx,
211            })
212            .await
213            .is_err()
214        {
215            warn!("Failed to send get_status message to hub actor");
216            return std::collections::HashMap::new();
217        }
218
219        response_rx.await.unwrap_or_default()
220    }
221
222    /// Get a handle to communicate with the hub actor
223    pub fn hub_handle(&self) -> mpsc::Sender<HubMessage> {
224        self.hub_sender.clone()
225    }
226}
227
228impl ServerHandler for McpAggregator {
229    fn get_info(&self) -> InitializeResult {
230        debug!("get_info() called - returning server capabilities");
231        InitializeResult {
232            protocol_version: ProtocolVersion::default(),
233            capabilities: ServerCapabilities {
234                experimental: None,
235                logging: None,
236                prompts: None,
237                resources: None,
238                tools: Some(ToolsCapability {
239                    list_changed: Some(true),
240                }),
241            },
242            server_info: self.server_info.clone(),
243            instructions: self.instructions.clone(),
244        }
245    }
246
247    fn get_peer(&self) -> Option<Peer<RoleServer>> {
248        self.peer.clone()
249    }
250
251    fn set_peer(&mut self, peer: Peer<RoleServer>) {
252        self.peer = Some(peer);
253    }
254
255    fn initialize(
256        &self,
257        _request: InitializeRequestParam,
258        _context: RequestContext<RoleServer>,
259    ) -> impl std::future::Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
260        async move {
261            info!("Initializing MCP aggregator");
262            debug!("About to call initialize_hub()");
263
264            // Initialize the hub actor
265            let _tools = self.initialize_hub().await.map_err(|e| {
266                error!("Failed to initialize hub: {}", e);
267                ErrorData {
268                    code: ErrorCode::INTERNAL_ERROR,
269                    message: format!("Failed to initialize hub: {}", e).into(),
270                    data: None,
271                }
272            })?;
273
274            debug!("Hub initialization completed, got {} tools", _tools.len());
275            info!("Aggregator initialized successfully");
276
277            let server_capabilities = ServerCapabilities {
278                experimental: None,
279                logging: None,
280                prompts: None,
281                resources: None,
282                tools: Some(ToolsCapability {
283                    list_changed: Some(true),
284                }),
285            };
286
287            Ok(InitializeResult {
288                protocol_version: ProtocolVersion::default(),
289                capabilities: server_capabilities,
290                server_info: self.server_info.clone(),
291                instructions: self.instructions.clone(),
292            })
293        }
294    }
295
296    async fn list_tools(
297        &self,
298        _request: PaginatedRequestParam,
299        _context: RequestContext<RoleServer>,
300    ) -> Result<ListToolsResult, McpError> {
301        debug!("Listing all aggregated tools");
302
303        let tools = self.get_aggregated_tools().await;
304
305        debug!("Returning {} tools", tools.len());
306
307        Ok(ListToolsResult {
308            tools,
309            next_cursor: None,
310        })
311    }
312
313    async fn call_tool(
314        &self,
315        request: CallToolRequestParam,
316        _context: RequestContext<RoleServer>,
317    ) -> Result<CallToolResult, McpError> {
318        info!("Calling tool: {}", request.name);
319
320        self.route_tool_call(request).await.map_err(|e| {
321            error!("Tool call failed: {}", e);
322            ErrorData {
323                code: ErrorCode::METHOD_NOT_FOUND,
324                message: format!("Tool call failed: {}", e).into(),
325                data: None,
326            }
327        })
328    }
329
330    async fn ping(&self, _context: RequestContext<RoleServer>) -> Result<(), McpError> {
331        debug!("Ping received");
332        Ok(())
333    }
334
335    async fn on_initialized(&self) {
336        info!("MCP aggregator initialized and ready");
337    }
338}
339
340impl Drop for McpAggregator {
341    fn drop(&mut self) {
342        // Trigger shutdown when aggregator is dropped
343        if let Some(shutdown_tx) = self
344            .shutdown_tx
345            .try_lock()
346            .ok()
347            .and_then(|mut guard| guard.take())
348        {
349            let _ = shutdown_tx.send(());
350        }
351    }
352}
353
354/// Builder for creating an MCP aggregator
355pub struct AggregatorBuilder {
356    config: Option<Config>,
357    supervisor_config: Option<SupervisorConfig>,
358}
359
360impl AggregatorBuilder {
361    /// Create a new builder
362    pub fn new() -> Self {
363        Self {
364            config: None,
365            supervisor_config: None,
366        }
367    }
368
369    /// Set configuration
370    pub fn with_config(mut self, config: Config) -> Self {
371        self.config = Some(config);
372        self
373    }
374
375    /// Set supervisor configuration
376    pub fn with_supervisor_config(mut self, supervisor_config: SupervisorConfig) -> Self {
377        self.supervisor_config = Some(supervisor_config);
378        self
379    }
380
381    /// Build the aggregator
382    pub fn build(self) -> Result<McpAggregator> {
383        let config = self
384            .config
385            .ok_or_else(|| anyhow::anyhow!("Configuration is required"))?;
386
387        let supervisor_config = self.supervisor_config.unwrap_or_default();
388
389        Ok(McpAggregator::new(config, supervisor_config))
390    }
391}
392
393impl Default for AggregatorBuilder {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ServerConfig;
    use std::collections::HashMap;

    /// An `echo`-backed server entry with the given identity and decoration.
    fn echo_server(name: &str, arg: &str, prefix: &str, suffix: &str) -> ServerConfig {
        ServerConfig {
            name: name.to_string(),
            cmd: "echo".to_string(),
            args: vec![arg.to_string()],
            env: HashMap::new(),
            whitelist: None,
            blacklist: None,
            prefix: Some(prefix.to_string()),
            description_suffix: Some(suffix.to_string()),
            description_prefix: None,
        }
    }

    /// Two-server configuration shared by the tests below.
    fn create_test_config() -> Config {
        Config {
            server: vec![
                echo_server("test1", "hello", "test1_", " (test server 1)"),
                echo_server("test2", "world", "test2_", " (test server 2)"),
            ],
        }
    }

    #[tokio::test]
    async fn test_aggregator_creation() {
        let aggregator = McpAggregator::new(create_test_config(), SupervisorConfig::default());

        assert_eq!(aggregator.get_info().server_info.name, "MCP Hub Aggregator");

        let instructions = aggregator
            .instructions
            .as_deref()
            .expect("instructions are always populated");
        assert!(instructions.contains("test1, test2"));
    }

    #[tokio::test]
    async fn test_builder_pattern() -> Result<()> {
        let aggregator = AggregatorBuilder::new()
            .with_config(create_test_config())
            .with_supervisor_config(SupervisorConfig::default())
            .build()?;

        let info = aggregator.get_info();
        assert_eq!(info.server_info.name, "MCP Hub Aggregator");
        Ok(())
    }

    #[tokio::test]
    async fn test_builder_requires_config() {
        let Err(err) = AggregatorBuilder::new().build() else {
            panic!("building without a configuration must fail");
        };
        assert!(err.to_string().contains("Configuration is required"));
    }
}