Skip to main content

mcp_utils/client/
manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::{McpServerConfig, ServerConfig},
6    connect_mcp::{ConnectOutcome, ConnectionSpec, ProxySpec, Registration, build_plan, connect_mcp},
7    connection::{McpServerConnection, ServerInstructions, Tool},
8    mcp_client::McpClient,
9    naming::{create_namespaced_tool_name, split_on_server_name},
10    oauth::{OAuthHandler, perform_oauth_flow},
11    tool_proxy::ToolProxy,
12};
13use futures::future::join_all;
14use rmcp::{
15    RoleClient,
16    model::{
17        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
18        ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
19    },
20    service::RunningService,
21    transport::streamable_http_client::StreamableHttpClientTransportConfig,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::{HashMap, HashSet};
26use std::sync::Arc;
27use tokio::sync::{RwLock, mpsc, oneshot};
28
29pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
30
/// An elicitation request associated with a named MCP server.
///
/// Bundles the request parameters with a one-shot sender typed to carry the
/// corresponding [`CreateElicitationResult`].
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the server this request is associated with.
    pub server_name: String,
    /// The elicitation parameters as defined by `rmcp`.
    pub request: CreateElicitationRequestParams,
    /// One-shot sender for the result of this elicitation.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
37
/// The action chosen for an elicitation together with optional JSON content.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// The elicitation action that was selected.
    pub action: ElicitationAction,
    /// Optional JSON payload accompanying the action.
    pub content: Option<Value>,
}
43
/// Payload of [`McpClientEvent::UrlElicitationComplete`]: identifies a URL
/// elicitation by server name and elicitation id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Name of the server the elicitation belongs to.
    pub server_name: String,
    /// Identifier of the elicitation.
    pub elicitation_id: String,
}
49
/// Events emitted by MCP clients that require attention from the host
/// (e.g. the relay or TUI). Flows through a single channel from `McpManager`
/// to the consumer.
#[derive(Debug)]
pub enum McpClientEvent {
    /// An elicitation request, including the sender used to deliver its result.
    Elicitation(ElicitationRequest),
    /// Carries the server name and id of a URL elicitation.
    UrlElicitationComplete(UrlElicitationCompleteParams),
}
58
/// Internal record holding all mutable state for a single MCP server.
struct ServerRecord {
    /// Live connection, or `None` when the record only tracks a status.
    connection: Option<McpServerConnection>,
    /// Status reported for this server in status entries.
    status: McpServerStatus,
    /// Transport config kept for OAuth (re-)authentication; its presence is
    /// what makes the server report the OAuth auth capability.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
}
65
66impl ServerRecord {
67    fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
68        Self { connection: None, status, reauth_config }
69    }
70
71    fn connected(
72        connection: McpServerConnection,
73        tool_count: usize,
74        reauth_config: Option<StreamableHttpClientTransportConfig>,
75    ) -> Self {
76        Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config }
77    }
78
79    fn auth_capability(&self) -> McpServerAuthCapability {
80        if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
81    }
82
83    fn can_authenticate(&self) -> bool {
84        self.reauth_config.is_some()
85    }
86
87    fn status_entry(&self, name: &str) -> McpServerStatusEntry {
88        McpServerStatusEntry::new(name, self.status.clone()).with_auth_capability(self.auth_capability())
89    }
90}
91
/// Manages connections to multiple MCP servers and their tools
pub struct McpManager {
    /// Per-server state keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in the order they were first registered; drives the
    /// ordering of `server_statuses`.
    server_order: Vec<String>,
    /// Known tools keyed by namespaced tool name.
    tools: HashMap<String, Tool>,
    /// Tool definitions handed out by `tool_definitions()`.
    tool_definitions: Vec<ToolDefinition>,
    /// Client info passed along whenever a server connection is created.
    client_info: ClientInfo,
    /// Sender handed to every MCP client for emitting `McpClientEvent`s.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with all MCP clients
    roots: Arc<RwLock<Vec<Root>>>,
    /// Handler used for OAuth flows, when one is available.
    oauth_handler: Option<Arc<dyn OAuthHandler>>,
    /// Cached status entries, rebuilt by `refresh_status_entries`.
    server_statuses: Vec<McpServerStatusEntry>,
    /// Optional tool-proxy that wraps multiple servers behind a single `call_tool`.
    proxy: Option<ToolProxy>,
}
107
108impl McpManager {
109    pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler: Option<Arc<dyn OAuthHandler>>) -> Self {
110        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
111        if let Some(elicitation) = capabilities.elicitation.as_mut() {
112            elicitation.form = Some(FormElicitationCapability::default());
113            elicitation.url = Some(UrlElicitationCapability::default());
114        }
115
116        Self {
117            servers: HashMap::new(),
118            server_order: Vec::new(),
119            tools: HashMap::new(),
120            tool_definitions: Vec::new(),
121            client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
122            event_sender,
123            roots: Arc::new(RwLock::new(Vec::new())),
124            oauth_handler,
125            server_statuses: Vec::new(),
126            proxy: None,
127        }
128    }
129
130    fn refresh_status_entries(&mut self) {
131        self.server_statuses = self
132            .server_order
133            .iter()
134            .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
135            .collect();
136    }
137
138    fn remember_server_order(&mut self, name: &str) {
139        if !self.server_order.iter().any(|n| n == name) {
140            self.server_order.push(name.to_string());
141        }
142    }
143
144    fn upsert_status(
145        &mut self,
146        name: &str,
147        status: McpServerStatus,
148        reauth_config: Option<StreamableHttpClientTransportConfig>,
149    ) {
150        self.remember_server_order(name);
151        let record = self
152            .servers
153            .entry(name.to_string())
154            .or_insert_with(|| ServerRecord::new(status.clone(), reauth_config.clone()));
155        record.status = status;
156        if reauth_config.is_some() {
157            record.reauth_config = reauth_config;
158        }
159        self.refresh_status_entries();
160    }
161
162    fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
163        self.servers.get(server_name).and_then(|record| record.connection.as_ref())
164    }
165
166    fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
167        self.connection_for(server_name).map(|conn| conn.client.clone())
168    }
169
170    pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
171        let (direct, proxies) = build_plan(configs).await?;
172        let outcomes: Vec<ConnectOutcome> = join_all(direct.into_iter().map(|leaf| {
173            connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref())
174        }))
175        .await;
176
177        let mut mcp_proxies: HashMap<String, HashSet<String>> =
178            proxies.iter().map(|p| (p.name.clone(), HashSet::new())).collect();
179
180        let mut connected_mcps_to_proxy: HashMap<String, Vec<String>> = HashMap::new();
181        for outcome in outcomes {
182            match outcome {
183                ConnectOutcome::Ready { name, conn, tools, proxy, registration, reauth_config } => {
184                    self.apply_connected(&name, conn, &tools, registration, reauth_config);
185                    if let Some(p) = proxy {
186                        if let Some(members) = mcp_proxies.get_mut(&p) {
187                            members.insert(name.clone());
188                        }
189                        connected_mcps_to_proxy.entry(p).or_default().push(name);
190                    }
191                }
192                ConnectOutcome::NeedsOAuth { name, config, error, proxy } => {
193                    tracing::warn!("Server '{name}' needs OAuth: {error}");
194                    self.upsert_status(&name, McpServerStatus::NeedsOAuth, Some(config));
195                    if let Some(p) = proxy
196                        && let Some(members) = mcp_proxies.get_mut(&p)
197                    {
198                        members.insert(name);
199                    }
200                }
201                ConnectOutcome::Failed { name, error } => {
202                    tracing::warn!("Failed to connect to MCP server '{name}': {error}");
203                    if !self.servers.contains_key(&name) {
204                        self.upsert_status(&name, McpServerStatus::Failed { error: error.to_string() }, None);
205                    }
206                }
207            }
208        }
209
210        let writes = proxies.iter().flat_map(|proxy| {
211            connected_mcps_to_proxy
212                .get(&proxy.name)
213                .into_iter()
214                .flatten()
215                .filter_map(|member| {
216                    let client = self.client_for_server(member)?;
217                    let dir = proxy.tool_dir.clone();
218                    let name = member.clone();
219                    Some(async move {
220                        if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
221                            tracing::warn!("Failed to write tool files for nested server '{name}': {e}");
222                        }
223                    })
224                })
225                .collect::<Vec<_>>()
226        });
227        join_all(writes).await;
228
229        for proxy in proxies {
230            self.register_proxy(proxy, &mut mcp_proxies, &mut connected_mcps_to_proxy);
231        }
232
233        Ok(())
234    }
235
236    pub async fn add_mcp_with_auth(&mut self, name: String, base_url: &str, auth_header: String) -> Result<()> {
237        let config = ServerConfig::Http {
238            name: name.clone(),
239            config: StreamableHttpClientTransportConfig::with_uri(base_url).auth_header(auth_header),
240        };
241        let leaf = ConnectionSpec { name, config, proxy: None, registration: Registration::Direct };
242        match connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref()).await {
243            ConnectOutcome::Ready { name, conn, tools, registration, .. } => {
244                self.apply_connected(&name, conn, &tools, registration, None);
245                Ok(())
246            }
247            ConnectOutcome::NeedsOAuth { error, .. } | ConnectOutcome::Failed { error, .. } => Err(error),
248        }
249    }
250
251    pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
252        match config {
253            McpServerConfig::ToolProxy { .. } => self.add_mcps(vec![config]).await,
254
255            McpServerConfig::Server(config) => {
256                let name = config.name().to_string();
257                let leaf = ConnectionSpec { name, config, proxy: None, registration: Registration::Direct };
258                match connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref())
259                    .await
260                {
261                    ConnectOutcome::Ready { name, conn, tools, registration, reauth_config, .. } => {
262                        self.apply_connected(&name, conn, &tools, registration, reauth_config);
263                        Ok(())
264                    }
265                    ConnectOutcome::NeedsOAuth { name, config, error, .. } => {
266                        self.upsert_status(&name, McpServerStatus::NeedsOAuth, Some(config));
267                        Err(error)
268                    }
269                    ConnectOutcome::Failed { error, .. } => Err(error),
270                }
271            }
272        }
273    }
274
275    fn register_proxy(
276        &mut self,
277        proxy: ProxySpec,
278        proxy_members: &mut HashMap<String, HashSet<String>>,
279        ready_for_proxy: &mut HashMap<String, Vec<String>>,
280    ) {
281        let members = ready_for_proxy.remove(&proxy.name).unwrap_or_default();
282
283        let server_descriptions: Vec<(String, String)> = members
284            .iter()
285            .filter_map(|member| {
286                self.connection_for(member)
287                    .map(|conn| (member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
288            })
289            .collect();
290
291        self.remove_registered_tools_for_server(&proxy.name);
292        let call_tool_def = ToolProxy::call_tool_definition(&proxy.name);
293        self.tools.insert(
294            call_tool_def.name.clone(),
295            Tool {
296                description: call_tool_def.description.clone(),
297                parameters: serde_json::from_str(&call_tool_def.parameters)
298                    .unwrap_or(Value::Object(serde_json::Map::default())),
299            },
300        );
301        self.tool_definitions.push(call_tool_def);
302
303        let nested = proxy_members.remove(&proxy.name).unwrap_or_default();
304        self.proxy = Some(ToolProxy::new(proxy.name.clone(), nested, proxy.tool_dir, &server_descriptions));
305        self.upsert_status(&proxy.name, McpServerStatus::Connected { tool_count: 1 }, None);
306    }
307
308    async fn oauth_and_reconnect(&mut self, name: String, config: StreamableHttpClientTransportConfig) -> Result<()> {
309        let handler = self
310            .oauth_handler
311            .as_ref()
312            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'")))?;
313        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
314            .await
315            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
316
317        let mcp_client =
318            McpClient::new(self.client_info.clone(), name.clone(), self.event_sender.clone(), Arc::clone(&self.roots));
319        let conn = McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await?;
320
321        let is_proxied = self.proxy.as_ref().is_some_and(|p| p.contains_server(&name));
322        if is_proxied {
323            let tool_dir = self.proxy.as_ref().expect("checked above").tool_dir().to_path_buf();
324            self.register_server(&name, conn, Registration::Proxied, Some(config)).await?;
325            if let Some(proxy) = self.proxy.as_mut() {
326                proxy.add_member(name.clone());
327            }
328            if let Some(conn) = self.connection_for(&name) {
329                let client = conn.client.clone();
330                if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
331                    tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
332                }
333            }
334            Ok(())
335        } else {
336            self.register_server(&name, conn, Registration::Direct, Some(config)).await
337        }
338    }
339
340    async fn register_server(
341        &mut self,
342        name: &str,
343        conn: McpServerConnection,
344        registration: Registration,
345        reauth_config: Option<StreamableHttpClientTransportConfig>,
346    ) -> Result<()> {
347        let tools = conn
348            .list_tools()
349            .await
350            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
351        self.apply_connected(name, conn, &tools, registration, reauth_config);
352        Ok(())
353    }
354
355    fn apply_connected(
356        &mut self,
357        name: &str,
358        conn: McpServerConnection,
359        tools: &[RmcpTool],
360        registration: Registration,
361        reauth_config: Option<StreamableHttpClientTransportConfig>,
362    ) {
363        self.remove_registered_tools_for_server(name);
364
365        let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
366        let final_reauth = reauth_config.or(existing_reauth);
367
368        for rmcp_tool in tools {
369            let tool_name = rmcp_tool.name.to_string();
370            let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
371            let tool = Tool::from(rmcp_tool);
372
373            if registration == Registration::Direct {
374                self.tool_definitions.push(ToolDefinition {
375                    name: namespaced_tool_name.clone(),
376                    description: tool.description.clone(),
377                    parameters: tool.parameters.to_string(),
378                    server: Some(name.to_string()),
379                });
380            }
381
382            self.tools.insert(namespaced_tool_name, tool);
383        }
384
385        self.remember_server_order(name);
386        self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth));
387        self.refresh_status_entries();
388    }
389
390    fn remove_registered_tools_for_server(&mut self, server_name: &str) {
391        let prefix = format!("{server_name}__");
392        self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
393        self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
394    }
395
396    pub fn get_client_for_tool(
397        &self,
398        namespaced_tool_name: &str,
399        arguments_json: &str,
400    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
401        if !self.tools.contains_key(namespaced_tool_name) {
402            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
403        }
404
405        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
406            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
407
408        if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
409            let call = proxy.resolve_call(arguments_json)?;
410            let conn = self
411                .connection_for(&call.server)
412                .ok_or_else(|| McpError::ServerNotFound(format!("Nested server '{}' is not connected", call.server)))?;
413            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
414            return Ok((conn.client.clone(), params));
415        }
416
417        let client =
418            self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
419
420        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
421        let mut params = CallToolRequestParams::new(tool_name.to_string());
422        if let Some(args) = arguments {
423            params = params.with_arguments(args);
424        }
425
426        Ok((client, params))
427    }
428
429    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
430        self.tool_definitions.clone()
431    }
432
433    pub fn server_instructions(&self) -> Vec<ServerInstructions> {
434        let mut instructions: Vec<ServerInstructions> = self
435            .servers
436            .iter()
437            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
438            .filter_map(|(name, record)| {
439                record
440                    .connection
441                    .as_ref()
442                    .and_then(|conn| conn.instructions.as_ref())
443                    .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
444            })
445            .collect();
446
447        if let Some(proxy) = &self.proxy {
448            let descriptions: Vec<(String, String)> = proxy
449                .members()
450                .iter()
451                .filter_map(|member| {
452                    let conn = self.connection_for(member)?;
453                    Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
454                })
455                .collect();
456            instructions.push(ServerInstructions {
457                server_name: proxy.name().to_string(),
458                instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
459            });
460        }
461
462        instructions
463    }
464
465    pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
466        &self.server_statuses
467    }
468
469    pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
470        let record = self
471            .servers
472            .get(name)
473            .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
474        if !record.can_authenticate() {
475            return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
476        }
477
478        self.oauth_and_reconnect(name.to_string(), record.reauth_config.clone().expect("checked above")).await
479    }
480
481    /// List all prompts from all connected MCP servers with namespacing
482    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
483        use futures::future::join_all;
484
485        let futures: Vec<_> = self
486            .servers
487            .iter()
488            .filter_map(|(server_name, record)| {
489                let conn = record.connection.as_ref()?;
490                conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
491                let server_name = server_name.clone();
492                let client = conn.client.clone();
493                Some(async move {
494                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
495                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
496                    })?;
497
498                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
499                        .prompts
500                        .into_iter()
501                        .map(|prompt| {
502                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
503                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
504                        })
505                        .collect();
506
507                    Ok::<_, McpError>(namespaced_prompts)
508                })
509            })
510            .collect();
511
512        let results = join_all(futures).await;
513        let mut all_prompts = Vec::new();
514        for result in results {
515            all_prompts.extend(result?);
516        }
517
518        Ok(all_prompts)
519    }
520
521    /// Get a specific prompt by namespaced name
522    pub async fn get_prompt(
523        &self,
524        namespaced_prompt_name: &str,
525        arguments: Option<serde_json::Map<String, serde_json::Value>>,
526    ) -> Result<rmcp::model::GetPromptResult> {
527        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
528            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
529
530        let server_conn =
531            self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
532
533        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
534        if let Some(args) = arguments {
535            request = request.with_arguments(args);
536        }
537
538        server_conn.client.get_prompt(request).await.map_err(|e| {
539            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
540        })
541    }
542
543    /// Shutdown all servers and wait for their tasks to complete
544    pub async fn shutdown(&mut self) {
545        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
546
547        for (server_name, record) in servers {
548            if let Some(conn) = record.connection
549                && let Some(handle) = conn.server_task
550            {
551                drop(conn.client);
552
553                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
554                    Ok(Ok(())) => {
555                        tracing::info!("Server '{server_name}' shut down gracefully");
556                    }
557                    Ok(Err(e)) => {
558                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
559                    }
560                    Err(_) => {
561                        tracing::warn!("Server '{server_name}' shutdown timed out");
562                    }
563                }
564            }
565        }
566
567        self.tools.clear();
568        self.tool_definitions.clear();
569        self.proxy = None;
570    }
571
572    /// Shutdown a specific server by name
573    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
574        let record = self.servers.remove(server_name);
575
576        if let Some(record) = record {
577            if let Some(conn) = record.connection
578                && let Some(handle) = conn.server_task
579            {
580                drop(conn.client);
581
582                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
583                    Ok(Ok(())) => {
584                        tracing::info!("Server '{server_name}' shut down gracefully");
585                    }
586                    Ok(Err(e)) => {
587                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
588                    }
589                    Err(_) => {
590                        tracing::warn!("Server '{server_name}' shutdown timed out");
591                    }
592                }
593            }
594
595            self.remove_registered_tools_for_server(server_name);
596            self.refresh_status_entries();
597        }
598
599        Ok(())
600    }
601
602    /// Set the roots advertised to MCP servers.
603    ///
604    /// This updates the roots and sends notifications to all connected servers
605    /// that support the `roots/list_changed` notification.
606    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
607        {
608            let mut roots = self.roots.write().await;
609            *roots = new_roots;
610        }
611
612        self.notify_roots_changed().await;
613
614        Ok(())
615    }
616
617    async fn notify_roots_changed(&self) {
618        for (server_name, record) in &self.servers {
619            if let Some(conn) = &record.connection
620                && let Err(e) = conn.client.notify_roots_list_changed().await
621            {
622                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
623            }
624        }
625    }
626}
627
628impl Drop for McpManager {
629    fn drop(&mut self) {
630        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
631        for (server_name, record) in servers {
632            if let Some(conn) = record.connection
633                && let Some(handle) = conn.server_task
634            {
635                handle.abort();
636                tracing::warn!("Server '{server_name}' task aborted during cleanup");
637            }
638        }
639    }
640}
641
642#[cfg(test)]
643mod tests {
644    use super::{McpManager, McpServerStatus, Tool};
645    use crate::client::config::ServerConfig;
646    use crate::client::oauth::{OAuthCallback, OAuthError, OAuthHandler};
647    use crate::status::McpServerAuthCapability;
648    use futures::future::BoxFuture;
649    use llm::ToolDefinition;
650    use rmcp::{
651        Json, RoleServer, ServerHandler,
652        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
653        model::{Implementation, ServerCapabilities, ServerInfo},
654        service::DynService,
655        tool, tool_handler, tool_router,
656        transport::streamable_http_client::StreamableHttpClientTransportConfig,
657    };
658    use schemars::JsonSchema;
659    use serde::{Deserialize, Serialize};
660    use serde_json::json;
661    use std::{
662        io,
663        sync::{Arc, Mutex},
664    };
665    use tokio::sync::mpsc;
666
    /// Minimal MCP server used by the tests; it only holds its tool router.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }
671
    // Server handler for the test server; tool handling is generated by
    // `tool_handler` from `self.tool_router`.
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        // Advertises the tools capability and identifies the server as
        // `test-server` 0.1.0 with a short description.
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }
679
680    impl Default for TestServer {
681        fn default() -> Self {
682            Self { tool_router: Self::tool_router() }
683        }
684    }
685
    // Input of the `echo` test tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }
690
    // Output of the `echo` test tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }
695
    // Tool implementations of the test server; `tool_router` generates the
    // router returned by `Self::tool_router()`.
    #[tool_router]
    impl TestServer {
        // Boxes the server as a dynamic rmcp service, the form used by
        // `ServerConfig::InMemory` in the tests below.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        // Echoes the request's `value` back unchanged.
        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }
708
709    #[derive(Clone)]
710    struct SharedWriter(Arc<Mutex<Vec<u8>>>);
711
712    impl io::Write for SharedWriter {
713        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
714            self.0.lock().unwrap().extend_from_slice(buf);
715            Ok(buf.len())
716        }
717
718        fn flush(&mut self) -> io::Result<()> {
719            Ok(())
720        }
721    }
722
    /// OAuth handler for tests whose authorization step always fails with
    /// `OAuthError::UserCancelled`.
    struct TestOAuthHandler;

    impl OAuthHandler for TestOAuthHandler {
        /// Fixed loopback redirect URI.
        fn redirect_uri(&self) -> &'static str {
            "http://127.0.0.1:0/oauth2callback"
        }

        /// Ignores the authorization URL and immediately reports a user cancellation.
        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
            Box::pin(async { Err(OAuthError::UserCancelled) })
        }
    }
734
735    #[tokio::test]
736    async fn authenticate_server_rejects_record_without_reauth_config() {
737        let (event_sender, _event_receiver) = mpsc::channel(1);
738        let mut manager = McpManager::new(event_sender, Some(Arc::new(TestOAuthHandler)));
739        manager.upsert_status("public", McpServerStatus::Connected { tool_count: 1 }, None);
740
741        let error = manager.authenticate_server("public").await.unwrap_err().to_string();
742        assert!(error.contains("not OAuth-authenticatable"));
743    }
744
745    #[tokio::test]
746    async fn authenticate_server_uses_reauth_config_for_connected_oauth_server() {
747        let (event_sender, _event_receiver) = mpsc::channel(1);
748        let mut manager = McpManager::new(event_sender, Some(Arc::new(TestOAuthHandler)));
749        manager.upsert_status(
750            "remote",
751            McpServerStatus::Connected { tool_count: 1 },
752            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
753        );
754
755        let error = manager.authenticate_server("remote").await.unwrap_err().to_string();
756        assert!(!error.contains("not OAuth-authenticatable"));
757        assert!(error.contains("OAuth failed") || error.contains("UserCancelled"));
758    }
759
760    #[test]
761    fn status_entries_are_derived_from_reauth_config() {
762        let (event_sender, _event_receiver) = mpsc::channel(1);
763        let mut manager = McpManager::new(event_sender, Some(Arc::new(TestOAuthHandler)));
764
765        manager.upsert_status(
766            "with-oauth",
767            McpServerStatus::Connected { tool_count: 1 },
768            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
769        );
770        manager.upsert_status("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None);
771        manager.upsert_status(
772            "needs-oauth",
773            McpServerStatus::NeedsOAuth,
774            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
775        );
776
777        let statuses = manager.server_statuses();
778        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
779        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
780        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
781
782        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
783        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
784        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
785    }
786
787    #[test]
788    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
789        let (event_sender, _event_receiver) = mpsc::channel(1);
790        let mut manager = McpManager::new(event_sender, None);
791        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
792        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
793        manager.tool_definitions.push(ToolDefinition {
794            name: "git__status".to_string(),
795            description: String::new(),
796            parameters: "{}".to_string(),
797            server: Some("git".to_string()),
798        });
799        manager.tool_definitions.push(ToolDefinition {
800            name: "github__issue".to_string(),
801            description: String::new(),
802            parameters: "{}".to_string(),
803            server: Some("github".to_string()),
804        });
805
806        manager.remove_registered_tools_for_server("git");
807
808        assert!(!manager.tools.contains_key("git__status"));
809        assert!(manager.tools.contains_key("github__issue"));
810        assert_eq!(
811            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
812            vec!["github__issue"]
813        );
814    }
815
816    #[tokio::test]
817    async fn drop_logs_cleanup_abort_with_tracing() {
818        let (event_sender, _event_receiver) = mpsc::channel(1);
819        let mut manager = McpManager::new(event_sender, None);
820        manager
821            .add_mcp(
822                ServerConfig::InMemory { name: "test".to_string(), server: TestServer::default().into_dyn() }.into(),
823            )
824            .await
825            .unwrap();
826
827        let output = Arc::new(Mutex::new(Vec::new()));
828        let subscriber = tracing_subscriber::fmt()
829            .with_ansi(false)
830            .without_time()
831            .with_writer({
832                let output = Arc::clone(&output);
833                move || SharedWriter(Arc::clone(&output))
834            })
835            .finish();
836
837        tracing::subscriber::with_default(subscriber, || {
838            drop(manager);
839        });
840
841        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
842        assert!(logs.contains("Server 'test' task aborted during cleanup"));
843    }
844}