Skip to main content

mcp_utils/client/manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::{McpServerConfig, ServerConfig},
6    connection::{ConnectParams, ConnectResult, McpServerConnection, ServerInstructions, Tool},
7    mcp_client::McpClient,
8    naming::{create_namespaced_tool_name, split_on_server_name},
9    oauth::{OAuthHandler, perform_oauth_flow},
10    tool_proxy::ToolProxy,
11};
12use rmcp::{
13    RoleClient,
14    model::{
15        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
16        ElicitationAction, Implementation, Root,
17    },
18    service::RunningService,
19    transport::streamable_http_client::StreamableHttpClientTransportConfig,
20};
21use serde_json::Value;
22use std::collections::{HashMap, HashSet};
23use std::sync::Arc;
24use tokio::sync::{RwLock, mpsc, oneshot};
25
26pub use crate::status::{McpServerStatus, McpServerStatusEntry};
27
28#[derive(Debug)]
29pub struct ElicitationRequest {
30    pub request: CreateElicitationRequestParams,
31    pub response_sender: oneshot::Sender<CreateElicitationResult>,
32}
33
34#[derive(Debug, Clone)]
35pub struct ElicitationResponse {
36    pub action: ElicitationAction,
37    pub content: Option<Value>,
38}
39
/// How a connected server's tools are registered by the manager.
///
/// Every tool is always stored for internal routing; this only decides whether
/// its definition is also published to the agent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Registration {
    /// Publish tool definitions in `tool_definitions` so the agent can see them.
    Direct,
    /// Keep tools routable in `self.tools` but hidden from the agent.
    Proxied,
}
49
50/// Manages connections to multiple MCP servers and their tools
51pub struct McpManager {
52    servers: HashMap<String, McpServerConnection>,
53    tools: HashMap<String, Tool>,
54    tool_definitions: Vec<ToolDefinition>,
55    client_info: ClientInfo,
56    elicitation_sender: mpsc::Sender<ElicitationRequest>,
57    /// Roots shared with all MCP clients
58    roots: Arc<RwLock<Vec<Root>>>,
59    oauth_handler: Option<Arc<dyn OAuthHandler>>,
60    server_statuses: Vec<McpServerStatusEntry>,
61    /// Configs for failed HTTP servers so we can retry OAuth later
62    pending_configs: HashMap<String, StreamableHttpClientTransportConfig>,
63    /// Optional tool-proxy that wraps multiple servers behind a single `call_tool`.
64    proxy: Option<ToolProxy>,
65}
66
67impl McpManager {
68    pub fn new(
69        elicitation_sender: mpsc::Sender<ElicitationRequest>,
70        oauth_handler: Option<Arc<dyn OAuthHandler>>,
71    ) -> Self {
72        Self {
73            servers: HashMap::new(),
74            tools: HashMap::new(),
75            tool_definitions: Vec::new(),
76            client_info: ClientInfo::new(
77                ClientCapabilities::builder().enable_elicitation().enable_roots().build(),
78                Implementation::new("aether", "0.1.0"),
79            ),
80            elicitation_sender,
81            roots: Arc::new(RwLock::new(Vec::new())),
82            oauth_handler,
83            server_statuses: Vec::new(),
84            pending_configs: HashMap::new(),
85            proxy: None,
86        }
87    }
88
89    fn create_mcp_client(&self) -> McpClient {
90        McpClient::new(self.client_info.clone(), self.elicitation_sender.clone(), Arc::clone(&self.roots))
91    }
92
93    fn connect_params(&self) -> ConnectParams {
94        ConnectParams { mcp_client: self.create_mcp_client(), oauth_handler: self.oauth_handler.clone() }
95    }
96
97    /// Update or insert the status entry for a server.
98    fn set_status(&mut self, name: &str, status: McpServerStatus) {
99        if let Some(entry) = self.server_statuses.iter_mut().find(|s| s.name == name) {
100            entry.status = status;
101        } else {
102            self.server_statuses.push(McpServerStatusEntry { name: name.to_string(), status });
103        }
104    }
105
106    pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
107        for config in configs {
108            let name = config.name().to_string();
109            if let Err(e) = self.add_mcp(config).await {
110                // Log warning but continue with other servers
111                tracing::warn!("Failed to connect to MCP server '{}': {}", name, e);
112                // Only record Failed if not already recorded by connect logic
113                if !self.server_statuses.iter().any(|s| s.name == name) {
114                    self.set_status(&name, McpServerStatus::Failed { error: e.to_string() });
115                }
116            }
117        }
118        Ok(())
119    }
120
121    pub async fn add_mcp_with_auth(&mut self, name: String, base_url: &str, auth_header: String) -> Result<()> {
122        let config = ServerConfig::Http {
123            name: name.clone(),
124            config: StreamableHttpClientTransportConfig::with_uri(base_url).auth_header(auth_header),
125        };
126        let params = self.connect_params();
127        match McpServerConnection::connect(config, params).await {
128            ConnectResult::Connected(conn) => self.register_server(&name, conn, Registration::Direct).await,
129            ConnectResult::NeedsOAuth { error, .. } => Err(error),
130            ConnectResult::Failed(e) => Err(e),
131        }
132    }
133
134    pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
135        match config {
136            McpServerConfig::ToolProxy { name, servers } => self.connect_tool_proxy(name, servers).await,
137
138            McpServerConfig::Server(config) => {
139                let name = config.name().to_string();
140                let params = self.connect_params();
141                match McpServerConnection::connect(config, params).await {
142                    ConnectResult::Connected(conn) => self.register_server(&name, conn, Registration::Direct).await,
143                    ConnectResult::NeedsOAuth { name, config, error } => {
144                        self.pending_configs.insert(name.clone(), config);
145                        self.set_status(&name, McpServerStatus::NeedsOAuth);
146                        Err(error)
147                    }
148                    ConnectResult::Failed(e) => Err(e),
149                }
150            }
151        }
152    }
153
154    /// Connect a tool-proxy: register each nested server individually through
155    /// the manager (getting OAuth for free), then inject a single `call_tool`
156    /// virtual tool for the agent.
157    async fn connect_tool_proxy(&mut self, proxy_name: String, servers: Vec<ServerConfig>) -> Result<()> {
158        let tool_dir = ToolProxy::dir(&proxy_name)?;
159        ToolProxy::clean_dir(&tool_dir).await?;
160
161        let mut nested_names = HashSet::new();
162        let mut server_descriptions = Vec::new();
163
164        for config in servers {
165            let server_name = config.name().to_string();
166            let params = self.connect_params();
167
168            let result = match McpServerConnection::connect(config, params).await {
169                ConnectResult::Connected(conn) => self.register_server(&server_name, conn, Registration::Proxied).await,
170                ConnectResult::NeedsOAuth { name, config, error } => {
171                    self.pending_configs.insert(name.clone(), config);
172                    self.set_status(&name, McpServerStatus::NeedsOAuth);
173                    Err(error)
174                }
175                ConnectResult::Failed(e) => Err(e),
176            };
177
178            match result {
179                Ok(()) => {
180                    // Write tool files to disk for agent browsing
181                    if let Some(conn) = self.servers.get(&server_name) {
182                        let client = conn.client.clone();
183                        if let Err(e) = ToolProxy::write_tools_to_dir(&server_name, &client, &tool_dir).await {
184                            tracing::warn!("Failed to write tool files for nested server '{server_name}': {e}");
185                        }
186
187                        let description = ToolProxy::extract_server_description(&client, &server_name);
188                        server_descriptions.push((server_name.clone(), description));
189                    }
190                    nested_names.insert(server_name);
191                }
192                Err(e) => {
193                    tracing::warn!("Failed to connect nested server '{server_name}': {e}");
194                    // If it was stashed as NeedsOAuth, record the membership so
195                    // authenticate_server can write tool files later.
196                    if self.pending_configs.contains_key(&server_name) {
197                        nested_names.insert(server_name);
198                    }
199                }
200            }
201        }
202
203        let call_tool_def = ToolProxy::call_tool_definition(&proxy_name);
204        self.tools.insert(
205            call_tool_def.name.clone(),
206            Tool {
207                description: call_tool_def.description.clone(),
208                parameters: serde_json::from_str(&call_tool_def.parameters)
209                    .unwrap_or(Value::Object(serde_json::Map::default())),
210            },
211        );
212        self.tool_definitions.push(call_tool_def);
213
214        self.proxy = Some(ToolProxy::new(proxy_name.clone(), nested_names, tool_dir, &server_descriptions));
215
216        // Add proxy status entry
217        self.set_status(&proxy_name, McpServerStatus::Connected { tool_count: 1 });
218
219        Ok(())
220    }
221
222    async fn oauth_and_reconnect(&mut self, name: String, config: StreamableHttpClientTransportConfig) -> Result<()> {
223        let handler = self
224            .oauth_handler
225            .as_ref()
226            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'")))?;
227        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
228            .await
229            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
230
231        let mcp_client = self.create_mcp_client();
232        let conn = McpServerConnection::reconnect_with_auth(&name, config, auth_client, mcp_client).await?;
233
234        // If this server is proxied, register without exposing tools to the agent
235        if let Some(proxy) = self.proxy.as_ref().filter(|p| p.contains_server(&name)) {
236            let tool_dir = proxy.tool_dir().to_path_buf();
237            self.register_server(&name, conn, Registration::Proxied).await?;
238            // Write tool files now that connection succeeded
239            if let Some(conn) = self.servers.get(&name) {
240                let client = conn.client.clone();
241                if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
242                    tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
243                }
244            }
245            Ok(())
246        } else {
247            self.register_server(&name, conn, Registration::Direct).await
248        }
249    }
250
251    /// Register a connected server and discover its tools.
252    ///
253    /// When `registration` is `Direct`, discovered tools are added to
254    /// `self.tool_definitions` (visible to the agent). When `Proxied`, tools are
255    /// only stored in `self.tools` for internal routing.
256    async fn register_server(
257        &mut self,
258        name: &str,
259        conn: McpServerConnection,
260        registration: Registration,
261    ) -> Result<()> {
262        let tools = conn
263            .list_tools()
264            .await
265            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
266
267        for rmcp_tool in &tools {
268            let tool_name = rmcp_tool.name.to_string();
269            let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
270            let tool = Tool::from(rmcp_tool);
271
272            if registration == Registration::Direct {
273                self.tool_definitions.push(ToolDefinition {
274                    name: namespaced_tool_name.clone(),
275                    description: tool.description.clone(),
276                    parameters: tool.parameters.to_string(),
277                    server: Some(name.to_string()),
278                });
279            }
280
281            self.tools.insert(namespaced_tool_name, tool);
282        }
283
284        let tool_count = tools.len();
285
286        self.set_status(name, McpServerStatus::Connected { tool_count });
287
288        // Remove from pending configs if it was there
289        self.pending_configs.remove(name);
290
291        self.servers.insert(name.to_string(), conn);
292        Ok(())
293    }
294
295    /// Resolve and route a tool call.
296    ///
297    /// Returns the target MCP client and normalized call params. For proxy
298    /// `call_tool`, this parses the wrapper arguments and forwards to the
299    /// selected nested server/tool.
300    pub fn get_client_for_tool(
301        &self,
302        namespaced_tool_name: &str,
303        arguments_json: &str,
304    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
305        if !self.tools.contains_key(namespaced_tool_name) {
306            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
307        }
308
309        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
310            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
311
312        if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
313            let call = proxy.resolve_call(arguments_json)?;
314            let conn = self
315                .servers
316                .get(&call.server)
317                .ok_or_else(|| McpError::ServerNotFound(format!("Nested server '{}' is not connected", call.server)))?;
318            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
319            return Ok((conn.client.clone(), params));
320        }
321
322        let client = self
323            .servers
324            .get(server_name)
325            .map(|server| server.client.clone())
326            .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
327
328        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
329        let mut params = CallToolRequestParams::new(tool_name.to_string());
330        if let Some(args) = arguments {
331            params = params.with_arguments(args);
332        }
333
334        Ok((client, params))
335    }
336
337    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
338        self.tool_definitions.clone()
339    }
340
341    /// Returns instructions from all connected MCP servers that provide them,
342    /// plus synthesized instructions for tool-proxy groups.
343    pub fn server_instructions(&self) -> Vec<ServerInstructions> {
344        let mut instructions: Vec<ServerInstructions> = self
345            .servers
346            .iter()
347            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
348            .filter_map(|(name, conn)| {
349                conn.instructions
350                    .as_ref()
351                    .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
352            })
353            .collect();
354
355        if let Some(proxy) = &self.proxy {
356            instructions.push(ServerInstructions {
357                server_name: proxy.name().to_string(),
358                instructions: proxy.instructions().to_string(),
359            });
360        }
361
362        instructions
363    }
364
365    pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
366        &self.server_statuses
367    }
368
369    /// Authenticate a server that previously failed with `NeedsOAuth`.
370    ///
371    /// Looks up the pending config, runs the OAuth flow, and updates the status
372    /// entry on success.
373    pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
374        let config = self
375            .pending_configs
376            .get(name)
377            .ok_or_else(|| McpError::ConnectionFailed(format!("no pending config for server '{name}'")))?
378            .clone();
379
380        self.oauth_and_reconnect(name.to_string(), config).await
381    }
382
383    /// List all prompts from all connected MCP servers with namespacing
384    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
385        use futures::future::join_all;
386
387        let futures: Vec<_> = self
388            .servers
389            .iter()
390            .filter(|(_, server_conn)| {
391                server_conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref()).is_some()
392            })
393            .map(|(server_name, server_conn)| {
394                let server_name = server_name.clone();
395                let client = server_conn.client.clone();
396                async move {
397                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
398                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
399                    })?;
400
401                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
402                        .prompts
403                        .into_iter()
404                        .map(|prompt| {
405                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
406                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
407                        })
408                        .collect();
409
410                    Ok::<_, McpError>(namespaced_prompts)
411                }
412            })
413            .collect();
414
415        let results = join_all(futures).await;
416        let mut all_prompts = Vec::new();
417        for result in results {
418            all_prompts.extend(result?);
419        }
420
421        Ok(all_prompts)
422    }
423
424    /// Get a specific prompt by namespaced name
425    pub async fn get_prompt(
426        &self,
427        namespaced_prompt_name: &str,
428        arguments: Option<serde_json::Map<String, serde_json::Value>>,
429    ) -> Result<rmcp::model::GetPromptResult> {
430        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
431            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
432
433        let server_conn =
434            self.servers.get(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
435
436        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
437        if let Some(args) = arguments {
438            request = request.with_arguments(args);
439        }
440
441        server_conn.client.get_prompt(request).await.map_err(|e| {
442            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
443        })
444    }
445
446    /// Shutdown all servers and wait for their tasks to complete
447    pub async fn shutdown(&mut self) {
448        let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
449
450        for (server_name, server) in servers {
451            if let Some(handle) = server.server_task {
452                // Drop the client first to signal shutdown
453                drop(server.client);
454
455                // Wait for the server task to complete (with a timeout)
456                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
457                    Ok(Ok(())) => {
458                        tracing::info!("Server '{server_name}' shut down gracefully");
459                    }
460                    Ok(Err(e)) => {
461                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
462                    }
463                    Err(_) => {
464                        tracing::warn!("Server '{server_name}' shutdown timed out");
465                        // Task will be cancelled when the handle is dropped
466                    }
467                }
468            }
469        }
470
471        self.tools.clear();
472        self.tool_definitions.clear();
473        self.proxy = None;
474    }
475
476    /// Shutdown a specific server by name
477    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
478        let server = self.servers.remove(server_name);
479
480        if let Some(server) = server {
481            if let Some(handle) = server.server_task {
482                // Drop the client first to signal shutdown
483                drop(server.client);
484
485                // Wait for the server task to complete (with a timeout)
486                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
487                    Ok(Ok(())) => {
488                        tracing::info!("Server '{server_name}' shut down gracefully");
489                    }
490                    Ok(Err(e)) => {
491                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
492                    }
493                    Err(_) => {
494                        tracing::warn!("Server '{server_name}' shutdown timed out");
495                        // Task will be cancelled when the handle is dropped
496                    }
497                }
498            }
499
500            // Remove tools from this server
501            self.tools.retain(|tool_name, _| !tool_name.starts_with(server_name));
502
503            self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(server_name));
504        }
505
506        Ok(())
507    }
508
509    /// Set the roots advertised to MCP servers.
510    ///
511    /// This updates the roots and sends notifications to all connected servers
512    /// that support the `roots/list_changed` notification.
513    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
514        // Update stored roots
515        {
516            let mut roots = self.roots.write().await;
517            *roots = new_roots;
518        }
519
520        // Notify all connected servers
521        self.notify_roots_changed().await;
522
523        Ok(())
524    }
525
526    /// Send `roots/list_changed` notification to all connected servers.
527    ///
528    /// This prompts servers to re-request the roots via the roots/list endpoint.
529    /// Servers that don't support roots will simply ignore the notification.
530    async fn notify_roots_changed(&self) {
531        for (server_name, server_conn) in &self.servers {
532            // Try to send notification - servers that don't support roots will ignore it
533            if let Err(e) = server_conn.client.notify_roots_list_changed().await {
534                // Only log errors for debugging; it's expected that some servers may not support roots
535                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
536            }
537        }
538    }
539}
540
541impl Drop for McpManager {
542    fn drop(&mut self) {
543        let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
544        for (server_name, server) in servers {
545            if let Some(handle) = server.server_task {
546                handle.abort();
547                tracing::warn!("Server '{server_name}' task aborted during cleanup");
548            }
549        }
550    }
551}
552
#[cfg(test)]
mod tests {
    use super::McpManager;
    use crate::client::config::ServerConfig;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    /// Minimal in-process MCP server exposing a single `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            let capabilities = ServerCapabilities::builder().enable_tools().build();
            let implementation = Implementation::new("test-server", "0.1.0").with_description("Test MCP server");
            ServerInfo::new(capabilities).with_server_info(implementation)
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            let tool_router = Self::tool_router();
            Self { tool_router }
        }
    }

    // Input of the `echo` tool. Plain `//` comments are used on purpose:
    // doc comments would leak into the JsonSchema-derived tool schema.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    // Output of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(payload) = request;
            Json(EchoResult { value: payload.value })
        }
    }

    /// `io::Write` sink appending into a shared buffer so emitted log output
    /// can be inspected after the fact.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let mut sink = self.0.lock().unwrap();
            sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (tx, _rx) = mpsc::channel(1);
        let mut manager = McpManager::new(tx, None);

        let config = ServerConfig::InMemory { name: "test".to_string(), server: TestServer::default().into_dyn() };
        manager.add_mcp(config.into()).await.unwrap();

        let captured = Arc::new(Mutex::new(Vec::new()));
        let writer_target = Arc::clone(&captured);
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer(move || SharedWriter(Arc::clone(&writer_target)))
            .finish();

        // Dropping inside the scoped subscriber routes Drop's warning into `captured`.
        tracing::subscriber::with_default(subscriber, move || drop(manager));

        let logs = String::from_utf8(captured.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}