mcp_utils/client/manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::McpServer,
6    connection::{
7        ConnectContext, McpConnectAttempt, McpConnectOutcome, McpServerConnection, ServerInstructions, Tool,
8        authenticate_http, connect_server,
9    },
10    mcp_client::McpClient,
11    naming::{create_namespaced_tool_name, split_on_server_name},
12    tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17    RoleClient,
18    model::{
19        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20        ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21    },
22    service::RunningService,
23    transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Server name reserved for the tool proxy.
///
/// The proxy's call tool and its tool directory are both named after it, which
/// is why `McpManager::add_mcps` rejects a configured server with this name
/// whenever any server in the same batch is proxied.
pub const DEFAULT_PROXY_NAME: &str = "proxy";
36
37pub type OAuthHandlerFactory = Arc<dyn Fn(OAuthHandlerContext) -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
39/// Context passed to an `OAuthHandlerFactory` so the constructed handler can
40/// dispatch user-facing prompts back to the host through the MCP event channel.
41#[derive(Clone)]
42pub struct OAuthHandlerContext {
43    pub server_name: String,
44    pub tx: mpsc::Sender<McpClientEvent>,
45}
46
47#[derive(Debug)]
48pub struct ElicitationRequest {
49    pub server_name: String,
50    pub request: CreateElicitationRequestParams,
51    pub response_sender: oneshot::Sender<CreateElicitationResult>,
52}
53
54#[derive(Debug, Clone)]
55pub struct ElicitationResponse {
56    pub action: ElicitationAction,
57    pub content: Option<Value>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
61pub struct UrlElicitationCompleteParams {
62    pub server_name: String,
63    pub elicitation_id: String,
64}
65
66/// Events emitted by MCP clients that require attention from the host
67/// (e.g. the relay or TUI). Flows through a single channel from `McpManager`
68/// to the consumer.
69#[derive(Debug)]
70pub enum McpClientEvent {
71    Elicitation(ElicitationRequest),
72    UrlElicitationComplete(UrlElicitationCompleteParams),
73    ServerStatusesChanged(Vec<McpServerStatusEntry>),
74    ToolDefinitionsChanged(Vec<ToolDefinition>),
75    AuthenticationFailed { server: String, error: String },
76}
77
78/// Manages connections to multiple MCP servers and their tools
79pub struct McpManager {
80    servers: HashMap<String, ServerRecord>,
81    server_order: Vec<String>,
82    tools: HashMap<String, Tool>,
83    tool_definitions: Vec<ToolDefinition>,
84    proxy: Option<ToolProxy>,
85    aether_home: Option<PathBuf>,
86    client_info: ClientInfo,
87    event_sender: mpsc::Sender<McpClientEvent>,
88    /// Roots shared with all MCP clients
89    roots: Arc<RwLock<Vec<Root>>>,
90    oauth_handler_factory: Option<OAuthHandlerFactory>,
91    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
92    server_statuses: Vec<McpServerStatusEntry>,
93}
94
95impl McpManager {
96    pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
97        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
98        if let Some(elicitation) = capabilities.elicitation.as_mut() {
99            elicitation.form = Some(FormElicitationCapability::default());
100            elicitation.url = Some(UrlElicitationCapability::default());
101        }
102
103        Self {
104            servers: HashMap::new(),
105            server_order: Vec::new(),
106            tools: HashMap::new(),
107            tool_definitions: Vec::new(),
108            proxy: None,
109            aether_home: None,
110            client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
111            event_sender,
112            roots: Arc::new(RwLock::new(Vec::new())),
113            oauth_handler_factory,
114            oauth_credential_store: None,
115            server_statuses: Vec::new(),
116        }
117    }
118
119    pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
120        self.aether_home = Some(aether_home.into());
121        self
122    }
123
124    pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
125        self.oauth_credential_store = Some(store);
126        self
127    }
128
129    pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
130        let has_proxy = servers.iter().any(|server| server.proxy);
131        if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
132            return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
133        }
134
135        let proxied_members: HashSet<String> =
136            servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
137        let proxy_tool_dir = if has_proxy {
138            let dir = self.proxy_tool_dir()?;
139            ToolProxy::clean_dir(&dir).await?;
140            Some(dir)
141        } else {
142            None
143        };
144
145        let ctx = self.connect_context();
146        let attempts = join_all(servers.into_iter().map(|server| connect_server(server, &ctx))).await;
147
148        let mut connected_proxied = Vec::new();
149        for McpConnectAttempt { name, proxied, outcome } in attempts {
150            match outcome {
151                McpConnectOutcome::Connected { conn, reauth_config } => {
152                    self.register_connection(&name, conn, reauth_config, proxied).await?;
153                    if proxied {
154                        connected_proxied.push(name);
155                    }
156                }
157                McpConnectOutcome::NeedsOAuth { config, error } => {
158                    tracing::warn!("Server '{name}' needs OAuth: {error}");
159                    self.register_record(&name, McpServerStatus::NeedsOAuth, Some(config), proxied);
160                }
161                McpConnectOutcome::Failed { error } => {
162                    tracing::warn!("Failed to connect to MCP server '{name}': {error}");
163                    if !self.servers.contains_key(&name) {
164                        self.register_record(
165                            &name,
166                            McpServerStatus::Failed { error: error.to_string() },
167                            None,
168                            proxied,
169                        );
170                    }
171                }
172            }
173        }
174
175        if let Some(tool_dir) = proxy_tool_dir {
176            self.write_proxy_tool_files(&connected_proxied, &tool_dir).await;
177            self.register_proxy(tool_dir, proxied_members);
178        }
179
180        Ok(())
181    }
182
183    pub fn get_client_for_tool(
184        &self,
185        namespaced_tool_name: &str,
186        arguments_json: &str,
187    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
188        if !self.tools.contains_key(namespaced_tool_name) {
189            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
190        }
191
192        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
193            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
194
195        if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
196            let call = proxy.resolve_call(arguments_json)?;
197            let conn = self.connection_for(&call.server).ok_or_else(|| {
198                McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
199            })?;
200            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
201            return Ok((conn.client.clone(), params));
202        }
203
204        let client =
205            self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
206
207        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
208        let mut params = CallToolRequestParams::new(tool_name.to_string());
209        if let Some(args) = arguments {
210            params = params.with_arguments(args);
211        }
212
213        Ok((client, params))
214    }
215
216    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
217        self.tool_definitions.clone()
218    }
219
220    pub fn server_instructions(&self) -> Vec<ServerInstructions> {
221        let mut instructions: Vec<ServerInstructions> = self
222            .servers
223            .iter()
224            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
225            .filter_map(|(name, record)| {
226                record
227                    .connection
228                    .as_ref()
229                    .and_then(|conn| conn.instructions.as_ref())
230                    .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
231            })
232            .collect();
233
234        if let Some(proxy) = &self.proxy {
235            let descriptions: Vec<(String, String)> = proxy
236                .members()
237                .iter()
238                .filter_map(|member| {
239                    let conn = self.connection_for(member)?;
240                    Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
241                })
242                .collect();
243            instructions.push(ServerInstructions {
244                server_name: proxy.name().to_string(),
245                instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
246            });
247        }
248
249        instructions
250    }
251
252    pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
253        &self.server_statuses
254    }
255
256    pub async fn authenticate_server_task(
257        &mut self,
258        name: &str,
259    ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
260        let record = self
261            .servers
262            .get(name)
263            .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
264        if !record.can_authenticate() {
265            return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
266        }
267
268        let oauth_handler_factory = self
269            .oauth_handler_factory
270            .clone()
271            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
272        let oauth_credential_store = self.oauth_credential_store.clone();
273        let name = name.to_string();
274        let config = record.reauth_config.clone().expect("checked above");
275        let client_info = self.client_info.clone();
276        let event_sender = self.event_sender.clone();
277        let roots = Arc::clone(&self.roots);
278        let proxied = record.proxied;
279
280        self.set_status(&name, McpServerStatus::Authenticating);
281        self.emit_server_statuses_changed().await;
282
283        Ok(async move {
284            authenticate_http(
285                name,
286                config,
287                client_info,
288                event_sender,
289                roots,
290                oauth_handler_factory,
291                oauth_credential_store,
292                proxied,
293            )
294            .await
295        })
296    }
297
298    pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
299        let McpConnectAttempt { name, proxied, outcome } = attempt;
300        match outcome {
301            McpConnectOutcome::Connected { conn, reauth_config } => {
302                match self.register_connection(&name, conn, reauth_config, proxied).await {
303                    Ok(tools) => {
304                        self.refresh_proxy_after_auth(&name, &tools, proxied).await;
305                        self.emit_server_statuses_changed().await;
306                        self.emit_tool_definitions_changed().await;
307                    }
308                    Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
309                }
310            }
311            McpConnectOutcome::Failed { error } => {
312                self.apply_authentication_failure(name, error.to_string()).await;
313            }
314            McpConnectOutcome::NeedsOAuth { .. } => {
315                self.apply_authentication_failure(name, "internal error: auth task returned NeedsOAuth".to_string())
316                    .await;
317            }
318        }
319    }
320
321    /// List all prompts from all connected MCP servers with namespacing
322    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
323        let futures: Vec<_> = self
324            .servers
325            .iter()
326            .filter_map(|(server_name, record)| {
327                let conn = record.connection.as_ref()?;
328                conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
329                let server_name = server_name.clone();
330                let client = conn.client.clone();
331                Some(async move {
332                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
333                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
334                    })?;
335
336                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
337                        .prompts
338                        .into_iter()
339                        .map(|prompt| {
340                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
341                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
342                        })
343                        .collect();
344
345                    Ok::<_, McpError>(namespaced_prompts)
346                })
347            })
348            .collect();
349
350        let results = join_all(futures).await;
351        let mut all_prompts = Vec::new();
352        for result in results {
353            all_prompts.extend(result?);
354        }
355
356        Ok(all_prompts)
357    }
358
359    /// Get a specific prompt by namespaced name
360    pub async fn get_prompt(
361        &self,
362        namespaced_prompt_name: &str,
363        arguments: Option<serde_json::Map<String, serde_json::Value>>,
364    ) -> Result<rmcp::model::GetPromptResult> {
365        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
366            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
367
368        let server_conn =
369            self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
370
371        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
372        if let Some(args) = arguments {
373            request = request.with_arguments(args);
374        }
375
376        server_conn.client.get_prompt(request).await.map_err(|e| {
377            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
378        })
379    }
380
381    /// Shutdown all servers and wait for their tasks to complete
382    pub async fn shutdown(&mut self) {
383        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
384
385        for (server_name, record) in servers {
386            if let Some(conn) = record.connection
387                && let Some(handle) = conn.server_task
388            {
389                drop(conn.client);
390
391                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
392                    Ok(Ok(())) => {
393                        tracing::info!("Server '{server_name}' shut down gracefully");
394                    }
395                    Ok(Err(e)) => {
396                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
397                    }
398                    Err(_) => {
399                        tracing::warn!("Server '{server_name}' shutdown timed out");
400                    }
401                }
402            }
403        }
404
405        self.tools.clear();
406        self.tool_definitions.clear();
407        self.proxy = None;
408    }
409
410    /// Shutdown a specific server by name
411    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
412        let record = self.servers.remove(server_name);
413
414        if let Some(record) = record {
415            if let Some(conn) = record.connection
416                && let Some(handle) = conn.server_task
417            {
418                drop(conn.client);
419
420                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
421                    Ok(Ok(())) => {
422                        tracing::info!("Server '{server_name}' shut down gracefully");
423                    }
424                    Ok(Err(e)) => {
425                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
426                    }
427                    Err(_) => {
428                        tracing::warn!("Server '{server_name}' shutdown timed out");
429                    }
430                }
431            }
432
433            self.remove_registered_tools_for_server(server_name);
434            self.refresh_status_entries();
435        }
436
437        Ok(())
438    }
439
440    /// Set the roots advertised to MCP servers.
441    ///
442    /// This updates the roots and sends notifications to all connected servers
443    /// that support the `roots/list_changed` notification.
444    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
445        {
446            let mut roots = self.roots.write().await;
447            *roots = new_roots;
448        }
449
450        self.notify_roots_changed().await;
451
452        Ok(())
453    }
454
455    async fn emit_server_statuses_changed(&self) {
456        self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses().to_vec())).await;
457    }
458
459    async fn emit_tool_definitions_changed(&self) {
460        self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
461    }
462
463    async fn emit_authentication_failed(&self, server: String, error: String) {
464        self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
465    }
466
467    async fn emit_event(&self, event: McpClientEvent) {
468        if let Err(e) = self.event_sender.send(event).await {
469            tracing::warn!("Failed to emit MCP client event: {e}");
470        }
471    }
472
473    fn connect_context(&self) -> ConnectContext<'_> {
474        ConnectContext {
475            client_info: &self.client_info,
476            event_sender: &self.event_sender,
477            roots: &self.roots,
478            oauth_handler_factory: self.oauth_handler_factory.as_ref(),
479            oauth_credential_store: self.oauth_credential_store.as_ref(),
480        }
481    }
482
483    fn proxy_tool_dir(&self) -> Result<PathBuf> {
484        self.aether_home
485            .as_ref()
486            .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
487            .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
488    }
489
490    async fn register_connection(
491        &mut self,
492        name: &str,
493        conn: McpServerConnection,
494        reauth_config: Option<StreamableHttpClientTransportConfig>,
495        proxied: bool,
496    ) -> Result<Vec<RmcpTool>> {
497        let tools = conn
498            .list_tools()
499            .await
500            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
501        self.apply_connected(name, conn, &tools, reauth_config, proxied);
502        Ok(tools)
503    }
504
505    fn apply_connected(
506        &mut self,
507        name: &str,
508        conn: McpServerConnection,
509        tools: &[RmcpTool],
510        reauth_config: Option<StreamableHttpClientTransportConfig>,
511        proxied: bool,
512    ) {
513        self.remove_registered_tools_for_server(name);
514
515        let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
516        let final_reauth = reauth_config.or(existing_reauth);
517
518        for rmcp_tool in tools {
519            let tool_name = rmcp_tool.name.to_string();
520            let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
521            let tool = Tool::from(rmcp_tool);
522
523            if !proxied {
524                self.tool_definitions.push(ToolDefinition {
525                    name: namespaced_tool_name.clone(),
526                    description: tool.description.clone(),
527                    parameters: tool.parameters.to_string(),
528                    server: Some(name.to_string()),
529                });
530                self.tools.insert(namespaced_tool_name, tool);
531            }
532        }
533
534        self.remember_server_order(name);
535        self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth, proxied));
536        self.refresh_status_entries();
537    }
538
539    fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
540        self.remove_registered_tools_for_server(DEFAULT_PROXY_NAME);
541        let call_tool_def = ToolProxy::call_tool_definition(DEFAULT_PROXY_NAME);
542        self.tools.insert(
543            call_tool_def.name.clone(),
544            Tool {
545                description: call_tool_def.description.clone(),
546                parameters: serde_json::from_str(&call_tool_def.parameters)
547                    .unwrap_or(Value::Object(serde_json::Map::default())),
548            },
549        );
550        self.tool_definitions.push(call_tool_def);
551
552        self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
553    }
554
555    async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
556        if !proxied {
557            return;
558        }
559
560        if let Some(proxy) = self.proxy.as_mut() {
561            proxy.add_member(name.to_string());
562        }
563
564        if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
565            && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
566        {
567            tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
568        }
569    }
570
571    async fn write_proxy_tool_files(&self, connected_proxied: &[String], tool_dir: &std::path::Path) {
572        let writes = connected_proxied.iter().filter_map(|name| {
573            let client = self.client_for_server(name)?;
574            let dir = tool_dir.to_path_buf();
575            let name = name.clone();
576            Some(async move {
577                if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
578                    tracing::warn!("Failed to write tool files for proxied server '{name}': {e}");
579                }
580            })
581        });
582        join_all(writes).await;
583    }
584
585    fn refresh_status_entries(&mut self) {
586        self.server_statuses = self
587            .server_order
588            .iter()
589            .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
590            .collect();
591    }
592
593    fn remember_server_order(&mut self, name: &str) {
594        if !self.server_order.iter().any(|n| n == name) {
595            self.server_order.push(name.to_string());
596        }
597    }
598
599    async fn apply_authentication_failure(&mut self, name: String, error: String) {
600        self.set_status(&name, McpServerStatus::Failed { error: error.clone() });
601        self.emit_server_statuses_changed().await;
602        self.emit_authentication_failed(name, error).await;
603    }
604
605    fn set_status(&mut self, name: &str, status: McpServerStatus) {
606        self.remember_server_order(name);
607        let record =
608            self.servers.entry(name.to_string()).or_insert_with(|| ServerRecord::new(status.clone(), None, false));
609        record.status = status;
610        self.refresh_status_entries();
611    }
612
613    fn register_record(
614        &mut self,
615        name: &str,
616        status: McpServerStatus,
617        reauth_config: Option<StreamableHttpClientTransportConfig>,
618        proxied: bool,
619    ) {
620        self.remember_server_order(name);
621        self.servers.insert(name.to_string(), ServerRecord::new(status, reauth_config, proxied));
622        self.refresh_status_entries();
623    }
624
625    fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
626        self.servers.get(server_name).and_then(|record| record.connection.as_ref())
627    }
628
629    fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
630        self.connection_for(server_name).map(|conn| conn.client.clone())
631    }
632
633    fn remove_registered_tools_for_server(&mut self, server_name: &str) {
634        let prefix = format!("{server_name}__");
635        self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
636        self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
637    }
638
639    async fn notify_roots_changed(&self) {
640        for (server_name, record) in &self.servers {
641            if let Some(conn) = &record.connection
642                && let Err(e) = conn.client.notify_roots_list_changed().await
643            {
644                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
645            }
646        }
647    }
648}
649
650impl Drop for McpManager {
651    fn drop(&mut self) {
652        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
653        for (server_name, record) in servers {
654            if let Some(conn) = record.connection
655                && let Some(handle) = conn.server_task
656            {
657                handle.abort();
658                tracing::warn!("Server '{server_name}' task aborted during cleanup");
659            }
660        }
661    }
662}
663
664/// Internal record holding all mutable state for a single MCP server.
665struct ServerRecord {
666    connection: Option<McpServerConnection>,
667    status: McpServerStatus,
668    reauth_config: Option<StreamableHttpClientTransportConfig>,
669    proxied: bool,
670}
671
672impl ServerRecord {
673    fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
674        Self { connection: None, status, reauth_config, proxied }
675    }
676
677    fn connected(
678        connection: McpServerConnection,
679        tool_count: usize,
680        reauth_config: Option<StreamableHttpClientTransportConfig>,
681        proxied: bool,
682    ) -> Self {
683        Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config, proxied }
684    }
685
686    fn auth_capability(&self) -> McpServerAuthCapability {
687        if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
688    }
689
690    fn can_authenticate(&self) -> bool {
691        self.reauth_config.is_some()
692    }
693
694    fn status_entry(&self, name: &str) -> McpServerStatusEntry {
695        McpServerStatusEntry::new(name, self.status.clone())
696            .with_auth_capability(self.auth_capability())
697            .with_proxied(self.proxied)
698    }
699}
700
701#[cfg(test)]
702mod tests {
703    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, Tool};
704    use crate::client::OAuthHandlerFactory;
705    use crate::client::config::{McpServer, McpTransport};
706    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
707    use crate::status::McpServerAuthCapability;
708    use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
709    use futures::future::BoxFuture;
710    use llm::ToolDefinition;
711    use rmcp::{
712        Json, RoleServer, ServerHandler,
713        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
714        model::{Implementation, ServerCapabilities, ServerInfo},
715        service::DynService,
716        tool, tool_handler, tool_router,
717        transport::streamable_http_client::StreamableHttpClientTransportConfig,
718    };
719    use schemars::JsonSchema;
720    use serde::{Deserialize, Serialize};
721    use serde_json::json;
722    use std::{
723        io,
724        sync::{Arc, Mutex},
725    };
726    use tokio::sync::mpsc;
727
728    #[derive(Clone)]
729    struct TestServer {
730        tool_router: ToolRouter<Self>,
731    }
732
733    #[tool_handler(router = self.tool_router)]
734    impl ServerHandler for TestServer {
735        fn get_info(&self) -> ServerInfo {
736            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
737                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
738        }
739    }
740
741    impl Default for TestServer {
742        fn default() -> Self {
743            Self { tool_router: Self::tool_router() }
744        }
745    }
746
747    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
748    struct EchoRequest {
749        value: String,
750    }
751
752    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
753    struct EchoResult {
754        value: String,
755    }
756
757    #[tool_router]
758    impl TestServer {
759        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
760            Box::new(self)
761        }
762
763        #[tool(description = "Returns the provided value")]
764        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
765            let Parameters(EchoRequest { value }) = request;
766            Json(EchoResult { value })
767        }
768    }
769
770    #[derive(Clone)]
771    struct SharedWriter(Arc<Mutex<Vec<u8>>>);
772
773    impl io::Write for SharedWriter {
774        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
775            self.0.lock().unwrap().extend_from_slice(buf);
776            Ok(buf.len())
777        }
778
779        fn flush(&mut self) -> io::Result<()> {
780            Ok(())
781        }
782    }
783
784    struct TestOAuthHandler;
785
786    impl OAuthHandler for TestOAuthHandler {
787        fn redirect_uri(&self) -> &'static str {
788            "http://127.0.0.1:0/oauth2callback"
789        }
790
791        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
792            Box::pin(async { Err(OAuthError::UserCancelled) })
793        }
794    }
795
796    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
797        Arc::new(|_ctx| Ok(Arc::new(TestOAuthHandler)))
798    }
799
800    #[tokio::test]
801    async fn authenticate_server_task_rejects_record_without_reauth_config() {
802        let (event_sender, _event_receiver) = mpsc::channel(1);
803        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
804        manager.register_record("public", McpServerStatus::Connected { tool_count: 1 }, None, false);
805
806        let error = match manager.authenticate_server_task("public").await {
807            Ok(_) => panic!("non-OAuth server should be rejected"),
808            Err(error) => error.to_string(),
809        };
810        assert!(error.contains("not OAuth-authenticatable"));
811    }
812
813    #[tokio::test]
814    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
815        let (event_sender, mut event_receiver) = mpsc::channel(2);
816        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
817        manager.register_record(
818            "remote",
819            McpServerStatus::NeedsOAuth,
820            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
821            false,
822        );
823
824        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
825
826        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
827        let event = event_receiver.recv().await.expect("status change event");
828        let McpClientEvent::ServerStatusesChanged(servers) = event else {
829            panic!("expected ServerStatusesChanged");
830        };
831        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
832        assert!(matches!(status.status, McpServerStatus::Authenticating));
833        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
834    }
835
836    #[tokio::test]
837    async fn apply_connection_attempt_failure_allows_retry() {
838        let (event_sender, mut event_receiver) = mpsc::channel(2);
839        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
840        manager.register_record(
841            "remote",
842            McpServerStatus::NeedsOAuth,
843            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
844            false,
845        );
846        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
847        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");
848
849        manager
850            .apply_connection_attempt(McpConnectAttempt {
851                name: "remote".to_string(),
852                proxied: false,
853                outcome: McpConnectOutcome::Failed {
854                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
855                },
856            })
857            .await;
858
859        let event = event_receiver.recv().await.expect("status change event");
860        let McpClientEvent::ServerStatusesChanged(servers) = event else {
861            panic!("expected ServerStatusesChanged");
862        };
863        let auth_event = event_receiver.recv().await.expect("authentication failure event");
864        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
865            panic!("expected AuthenticationFailed");
866        };
867        assert_eq!(server, "remote");
868        assert!(error.contains("boom"));
869
870        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
871        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
872        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
873        assert!(manager.authenticate_server_task("remote").await.is_ok());
874    }
875
876    #[test]
877    fn status_entries_are_derived_from_reauth_config() {
878        let (event_sender, _event_receiver) = mpsc::channel(1);
879        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
880
881        manager.register_record(
882            "with-oauth",
883            McpServerStatus::Connected { tool_count: 1 },
884            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
885            false,
886        );
887        manager.register_record("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None, false);
888        manager.register_record(
889            "needs-oauth",
890            McpServerStatus::NeedsOAuth,
891            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
892            false,
893        );
894
895        let statuses = manager.server_statuses();
896        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
897        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
898        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
899
900        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
901        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
902        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
903    }
904
905    #[tokio::test]
906    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
907        let (event_sender, _event_receiver) = mpsc::channel(1);
908        let mut manager = McpManager::new(event_sender, None);
909        manager
910            .add_mcps(vec![
911                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
912                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
913            ])
914            .await
915            .unwrap();
916
917        let statuses = manager.server_statuses();
918        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
919        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
920        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
921        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
922    }
923
924    #[test]
925    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
926        let (event_sender, _event_receiver) = mpsc::channel(1);
927        let mut manager = McpManager::new(event_sender, None);
928        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
929        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
930        manager.tool_definitions.push(ToolDefinition {
931            name: "git__status".to_string(),
932            description: String::new(),
933            parameters: "{}".to_string(),
934            server: Some("git".to_string()),
935        });
936        manager.tool_definitions.push(ToolDefinition {
937            name: "github__issue".to_string(),
938            description: String::new(),
939            parameters: "{}".to_string(),
940            server: Some("github".to_string()),
941        });
942
943        manager.remove_registered_tools_for_server("git");
944
945        assert!(!manager.tools.contains_key("git__status"));
946        assert!(manager.tools.contains_key("github__issue"));
947        assert_eq!(
948            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
949            vec!["github__issue"]
950        );
951    }
952
953    #[tokio::test]
954    async fn drop_logs_cleanup_abort_with_tracing() {
955        let (event_sender, _event_receiver) = mpsc::channel(1);
956        let mut manager = McpManager::new(event_sender, None);
957        manager
958            .add_mcps(vec![McpServer::new(
959                "test",
960                McpTransport::InMemory { server: TestServer::default().into_dyn() },
961                false,
962            )])
963            .await
964            .unwrap();
965
966        let output = Arc::new(Mutex::new(Vec::new()));
967        let subscriber = tracing_subscriber::fmt()
968            .with_ansi(false)
969            .without_time()
970            .with_writer({
971                let output = Arc::clone(&output);
972                move || SharedWriter(Arc::clone(&output))
973            })
974            .finish();
975
976        tracing::subscriber::with_default(subscriber, || {
977            drop(manager);
978        });
979
980        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
981        assert!(logs.contains("Server 'test' task aborted during cleanup"));
982    }
983}