// mcp_utils/client/manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::McpServer,
6    connection::{
7        ConnectContext, McpConnectAttempt, McpConnectOutcome, McpServerConnection, ServerInstructions, Tool,
8        authenticate_http, connect_server,
9    },
10    mcp_client::McpClient,
11    naming::{create_namespaced_tool_name, split_on_server_name},
12    tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17    RoleClient,
18    model::{
19        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20        ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21    },
22    service::RunningService,
23    transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Name of the synthetic tool-proxy server. Its tools are namespaced under this name, and
/// `McpManager::add_mcps` rejects a real server with this name whenever any configured
/// server is proxied.
pub const DEFAULT_PROXY_NAME: &str = "proxy";

/// Factory that produces an [`OAuthHandler`] for interactive OAuth flows. `McpManager` hands
/// it to the connection layer both for initial connections and for explicit
/// re-authentication of a server.
pub type OAuthHandlerFactory = Arc<dyn Fn() -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
/// An elicitation request raised by an MCP server and forwarded to the host.
///
/// The host answers by sending a [`CreateElicitationResult`] through `response_sender`.
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the MCP server that issued the request.
    pub server_name: String,
    /// The server's original elicitation parameters.
    pub request: CreateElicitationRequestParams,
    /// One-shot channel on which the host delivers its answer.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
45
/// A host-side answer to an elicitation: the action taken and an optional content payload.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// The action chosen for the elicitation.
    pub action: ElicitationAction,
    /// Optional JSON content accompanying the action.
    pub content: Option<Value>,
}
51
/// Identifies a URL-mode elicitation that has completed.
///
/// Serialized with serde using the field names exactly as written (no rename attributes).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Name of the MCP server the elicitation belongs to.
    pub server_name: String,
    /// Identifier of the completed elicitation.
    pub elicitation_id: String,
}
57
/// Events emitted by MCP clients that require attention from the host
/// (e.g. the relay or TUI). Flows through a single channel from `McpManager`
/// to the consumer.
#[derive(Debug)]
pub enum McpClientEvent {
    /// A server requested user input; answer through the request's `response_sender`.
    Elicitation(ElicitationRequest),
    /// A URL-mode elicitation identified by the params has completed
    /// (emitted outside this file — presumably by the MCP client handler).
    UrlElicitationComplete(UrlElicitationCompleteParams),
    /// Full snapshot of all server status entries after a status change.
    ServerStatusesChanged(Vec<McpServerStatusEntry>),
    /// Full snapshot of the exposed tool definitions after they changed.
    ToolDefinitionsChanged(Vec<ToolDefinition>),
    /// An authentication attempt for `server` failed with the given error message.
    AuthenticationFailed { server: String, error: String },
}
69
/// Manages connections to multiple MCP servers and their tools
pub struct McpManager {
    /// Per-server state (connection, status, OAuth re-auth config), keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in first-registration order; used to order status entries.
    server_order: Vec<String>,
    /// Directly callable tools keyed by namespaced tool name. Tools of proxied servers are
    /// not registered here; the proxy's call tool is.
    tools: HashMap<String, Tool>,
    /// Tool definitions exposed to the model; added and removed alongside `tools`.
    tool_definitions: Vec<ToolDefinition>,
    /// Tool proxy fronting servers configured with `proxy`, once one has been registered.
    proxy: Option<ToolProxy>,
    /// Optional home directory in which the proxy tool directory is placed.
    aether_home: Option<PathBuf>,
    /// Client identity and capabilities advertised to every server.
    client_info: ClientInfo,
    /// Channel on which `McpClientEvent`s are delivered to the host.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with all MCP clients
    roots: Arc<RwLock<Vec<Root>>>,
    /// Produces OAuth handlers for servers that require authentication.
    oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Optional OAuth credential storage passed through to the connection layer.
    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
    /// Cached status snapshot, rebuilt from `servers` in `server_order` order.
    server_statuses: Vec<McpServerStatusEntry>,
}
86
87impl McpManager {
88    pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
89        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
90        if let Some(elicitation) = capabilities.elicitation.as_mut() {
91            elicitation.form = Some(FormElicitationCapability::default());
92            elicitation.url = Some(UrlElicitationCapability::default());
93        }
94
95        Self {
96            servers: HashMap::new(),
97            server_order: Vec::new(),
98            tools: HashMap::new(),
99            tool_definitions: Vec::new(),
100            proxy: None,
101            aether_home: None,
102            client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
103            event_sender,
104            roots: Arc::new(RwLock::new(Vec::new())),
105            oauth_handler_factory,
106            oauth_credential_store: None,
107            server_statuses: Vec::new(),
108        }
109    }
110
111    pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
112        self.aether_home = Some(aether_home.into());
113        self
114    }
115
116    pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
117        self.oauth_credential_store = Some(store);
118        self
119    }
120
121    pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
122        let has_proxy = servers.iter().any(|server| server.proxy);
123        if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
124            return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
125        }
126
127        let proxied_members: HashSet<String> =
128            servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
129        let proxy_tool_dir = if has_proxy {
130            let dir = self.proxy_tool_dir()?;
131            ToolProxy::clean_dir(&dir).await?;
132            Some(dir)
133        } else {
134            None
135        };
136
137        let ctx = self.connect_context();
138        let attempts = join_all(servers.into_iter().map(|server| connect_server(server, &ctx))).await;
139
140        let mut connected_proxied = Vec::new();
141        for McpConnectAttempt { name, proxied, outcome } in attempts {
142            match outcome {
143                McpConnectOutcome::Connected { conn, reauth_config } => {
144                    self.register_connection(&name, conn, reauth_config, proxied).await?;
145                    if proxied {
146                        connected_proxied.push(name);
147                    }
148                }
149                McpConnectOutcome::NeedsOAuth { config, error } => {
150                    tracing::warn!("Server '{name}' needs OAuth: {error}");
151                    self.register_record(&name, McpServerStatus::NeedsOAuth, Some(config), proxied);
152                }
153                McpConnectOutcome::Failed { error } => {
154                    tracing::warn!("Failed to connect to MCP server '{name}': {error}");
155                    if !self.servers.contains_key(&name) {
156                        self.register_record(
157                            &name,
158                            McpServerStatus::Failed { error: error.to_string() },
159                            None,
160                            proxied,
161                        );
162                    }
163                }
164            }
165        }
166
167        if let Some(tool_dir) = proxy_tool_dir {
168            self.write_proxy_tool_files(&connected_proxied, &tool_dir).await;
169            self.register_proxy(tool_dir, proxied_members);
170        }
171
172        Ok(())
173    }
174
175    pub fn get_client_for_tool(
176        &self,
177        namespaced_tool_name: &str,
178        arguments_json: &str,
179    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
180        if !self.tools.contains_key(namespaced_tool_name) {
181            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
182        }
183
184        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
185            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
186
187        if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
188            let call = proxy.resolve_call(arguments_json)?;
189            let conn = self.connection_for(&call.server).ok_or_else(|| {
190                McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
191            })?;
192            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
193            return Ok((conn.client.clone(), params));
194        }
195
196        let client =
197            self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
198
199        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
200        let mut params = CallToolRequestParams::new(tool_name.to_string());
201        if let Some(args) = arguments {
202            params = params.with_arguments(args);
203        }
204
205        Ok((client, params))
206    }
207
208    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
209        self.tool_definitions.clone()
210    }
211
212    pub fn server_instructions(&self) -> Vec<ServerInstructions> {
213        let mut instructions: Vec<ServerInstructions> = self
214            .servers
215            .iter()
216            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
217            .filter_map(|(name, record)| {
218                record
219                    .connection
220                    .as_ref()
221                    .and_then(|conn| conn.instructions.as_ref())
222                    .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
223            })
224            .collect();
225
226        if let Some(proxy) = &self.proxy {
227            let descriptions: Vec<(String, String)> = proxy
228                .members()
229                .iter()
230                .filter_map(|member| {
231                    let conn = self.connection_for(member)?;
232                    Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
233                })
234                .collect();
235            instructions.push(ServerInstructions {
236                server_name: proxy.name().to_string(),
237                instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
238            });
239        }
240
241        instructions
242    }
243
244    pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
245        &self.server_statuses
246    }
247
248    pub async fn authenticate_server_task(
249        &mut self,
250        name: &str,
251    ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
252        let record = self
253            .servers
254            .get(name)
255            .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
256        if !record.can_authenticate() {
257            return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
258        }
259        if matches!(record.status, McpServerStatus::Authenticating) {
260            return Err(McpError::ConnectionFailed(format!("server '{name}' is already authenticating")));
261        }
262
263        let oauth_handler_factory = self
264            .oauth_handler_factory
265            .clone()
266            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
267        let oauth_credential_store = self.oauth_credential_store.clone();
268        let name = name.to_string();
269        let config = record.reauth_config.clone().expect("checked above");
270        let client_info = self.client_info.clone();
271        let event_sender = self.event_sender.clone();
272        let roots = Arc::clone(&self.roots);
273        let proxied = record.proxied;
274
275        self.set_status(&name, McpServerStatus::Authenticating);
276        self.emit_server_statuses_changed().await;
277
278        Ok(async move {
279            authenticate_http(
280                name,
281                config,
282                client_info,
283                event_sender,
284                roots,
285                oauth_handler_factory,
286                oauth_credential_store,
287                proxied,
288            )
289            .await
290        })
291    }
292
293    pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
294        let McpConnectAttempt { name, proxied, outcome } = attempt;
295        match outcome {
296            McpConnectOutcome::Connected { conn, reauth_config } => {
297                match self.register_connection(&name, conn, reauth_config, proxied).await {
298                    Ok(tools) => {
299                        self.refresh_proxy_after_auth(&name, &tools, proxied).await;
300                        self.emit_server_statuses_changed().await;
301                        self.emit_tool_definitions_changed().await;
302                    }
303                    Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
304                }
305            }
306            McpConnectOutcome::Failed { error } => {
307                self.apply_authentication_failure(name, error.to_string()).await;
308            }
309            McpConnectOutcome::NeedsOAuth { .. } => {
310                self.apply_authentication_failure(name, "internal error: auth task returned NeedsOAuth".to_string())
311                    .await;
312            }
313        }
314    }
315
316    /// List all prompts from all connected MCP servers with namespacing
317    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
318        let futures: Vec<_> = self
319            .servers
320            .iter()
321            .filter_map(|(server_name, record)| {
322                let conn = record.connection.as_ref()?;
323                conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
324                let server_name = server_name.clone();
325                let client = conn.client.clone();
326                Some(async move {
327                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
328                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
329                    })?;
330
331                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
332                        .prompts
333                        .into_iter()
334                        .map(|prompt| {
335                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
336                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
337                        })
338                        .collect();
339
340                    Ok::<_, McpError>(namespaced_prompts)
341                })
342            })
343            .collect();
344
345        let results = join_all(futures).await;
346        let mut all_prompts = Vec::new();
347        for result in results {
348            all_prompts.extend(result?);
349        }
350
351        Ok(all_prompts)
352    }
353
354    /// Get a specific prompt by namespaced name
355    pub async fn get_prompt(
356        &self,
357        namespaced_prompt_name: &str,
358        arguments: Option<serde_json::Map<String, serde_json::Value>>,
359    ) -> Result<rmcp::model::GetPromptResult> {
360        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
361            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
362
363        let server_conn =
364            self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
365
366        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
367        if let Some(args) = arguments {
368            request = request.with_arguments(args);
369        }
370
371        server_conn.client.get_prompt(request).await.map_err(|e| {
372            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
373        })
374    }
375
376    /// Shutdown all servers and wait for their tasks to complete
377    pub async fn shutdown(&mut self) {
378        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
379
380        for (server_name, record) in servers {
381            if let Some(conn) = record.connection
382                && let Some(handle) = conn.server_task
383            {
384                drop(conn.client);
385
386                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
387                    Ok(Ok(())) => {
388                        tracing::info!("Server '{server_name}' shut down gracefully");
389                    }
390                    Ok(Err(e)) => {
391                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
392                    }
393                    Err(_) => {
394                        tracing::warn!("Server '{server_name}' shutdown timed out");
395                    }
396                }
397            }
398        }
399
400        self.tools.clear();
401        self.tool_definitions.clear();
402        self.proxy = None;
403    }
404
405    /// Shutdown a specific server by name
406    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
407        let record = self.servers.remove(server_name);
408
409        if let Some(record) = record {
410            if let Some(conn) = record.connection
411                && let Some(handle) = conn.server_task
412            {
413                drop(conn.client);
414
415                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
416                    Ok(Ok(())) => {
417                        tracing::info!("Server '{server_name}' shut down gracefully");
418                    }
419                    Ok(Err(e)) => {
420                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
421                    }
422                    Err(_) => {
423                        tracing::warn!("Server '{server_name}' shutdown timed out");
424                    }
425                }
426            }
427
428            self.remove_registered_tools_for_server(server_name);
429            self.refresh_status_entries();
430        }
431
432        Ok(())
433    }
434
435    /// Set the roots advertised to MCP servers.
436    ///
437    /// This updates the roots and sends notifications to all connected servers
438    /// that support the `roots/list_changed` notification.
439    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
440        {
441            let mut roots = self.roots.write().await;
442            *roots = new_roots;
443        }
444
445        self.notify_roots_changed().await;
446
447        Ok(())
448    }
449
450    async fn emit_server_statuses_changed(&self) {
451        self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses().to_vec())).await;
452    }
453
454    async fn emit_tool_definitions_changed(&self) {
455        self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
456    }
457
458    async fn emit_authentication_failed(&self, server: String, error: String) {
459        self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
460    }
461
462    async fn emit_event(&self, event: McpClientEvent) {
463        if let Err(e) = self.event_sender.send(event).await {
464            tracing::warn!("Failed to emit MCP client event: {e}");
465        }
466    }
467
468    fn connect_context(&self) -> ConnectContext<'_> {
469        ConnectContext {
470            client_info: &self.client_info,
471            event_sender: &self.event_sender,
472            roots: &self.roots,
473            oauth_handler_factory: self.oauth_handler_factory.as_ref(),
474            oauth_credential_store: self.oauth_credential_store.as_ref(),
475        }
476    }
477
478    fn proxy_tool_dir(&self) -> Result<PathBuf> {
479        self.aether_home
480            .as_ref()
481            .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
482            .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
483    }
484
485    async fn register_connection(
486        &mut self,
487        name: &str,
488        conn: McpServerConnection,
489        reauth_config: Option<StreamableHttpClientTransportConfig>,
490        proxied: bool,
491    ) -> Result<Vec<RmcpTool>> {
492        let tools = conn
493            .list_tools()
494            .await
495            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
496        self.apply_connected(name, conn, &tools, reauth_config, proxied);
497        Ok(tools)
498    }
499
500    fn apply_connected(
501        &mut self,
502        name: &str,
503        conn: McpServerConnection,
504        tools: &[RmcpTool],
505        reauth_config: Option<StreamableHttpClientTransportConfig>,
506        proxied: bool,
507    ) {
508        self.remove_registered_tools_for_server(name);
509
510        let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
511        let final_reauth = reauth_config.or(existing_reauth);
512
513        for rmcp_tool in tools {
514            let tool_name = rmcp_tool.name.to_string();
515            let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
516            let tool = Tool::from(rmcp_tool);
517
518            if !proxied {
519                self.tool_definitions.push(ToolDefinition {
520                    name: namespaced_tool_name.clone(),
521                    description: tool.description.clone(),
522                    parameters: tool.parameters.to_string(),
523                    server: Some(name.to_string()),
524                });
525                self.tools.insert(namespaced_tool_name, tool);
526            }
527        }
528
529        self.remember_server_order(name);
530        self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth, proxied));
531        self.refresh_status_entries();
532    }
533
534    fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
535        self.remove_registered_tools_for_server(DEFAULT_PROXY_NAME);
536        let call_tool_def = ToolProxy::call_tool_definition(DEFAULT_PROXY_NAME);
537        self.tools.insert(
538            call_tool_def.name.clone(),
539            Tool {
540                description: call_tool_def.description.clone(),
541                parameters: serde_json::from_str(&call_tool_def.parameters)
542                    .unwrap_or(Value::Object(serde_json::Map::default())),
543            },
544        );
545        self.tool_definitions.push(call_tool_def);
546
547        self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
548    }
549
550    async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
551        if !proxied {
552            return;
553        }
554
555        if let Some(proxy) = self.proxy.as_mut() {
556            proxy.add_member(name.to_string());
557        }
558
559        if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
560            && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
561        {
562            tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
563        }
564    }
565
566    async fn write_proxy_tool_files(&self, connected_proxied: &[String], tool_dir: &std::path::Path) {
567        let writes = connected_proxied.iter().filter_map(|name| {
568            let client = self.client_for_server(name)?;
569            let dir = tool_dir.to_path_buf();
570            let name = name.clone();
571            Some(async move {
572                if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
573                    tracing::warn!("Failed to write tool files for proxied server '{name}': {e}");
574                }
575            })
576        });
577        join_all(writes).await;
578    }
579
580    fn refresh_status_entries(&mut self) {
581        self.server_statuses = self
582            .server_order
583            .iter()
584            .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
585            .collect();
586    }
587
588    fn remember_server_order(&mut self, name: &str) {
589        if !self.server_order.iter().any(|n| n == name) {
590            self.server_order.push(name.to_string());
591        }
592    }
593
594    async fn apply_authentication_failure(&mut self, name: String, error: String) {
595        self.set_status(&name, McpServerStatus::Failed { error: error.clone() });
596        self.emit_server_statuses_changed().await;
597        self.emit_authentication_failed(name, error).await;
598    }
599
600    fn set_status(&mut self, name: &str, status: McpServerStatus) {
601        self.remember_server_order(name);
602        let record =
603            self.servers.entry(name.to_string()).or_insert_with(|| ServerRecord::new(status.clone(), None, false));
604        record.status = status;
605        self.refresh_status_entries();
606    }
607
608    fn register_record(
609        &mut self,
610        name: &str,
611        status: McpServerStatus,
612        reauth_config: Option<StreamableHttpClientTransportConfig>,
613        proxied: bool,
614    ) {
615        self.remember_server_order(name);
616        self.servers.insert(name.to_string(), ServerRecord::new(status, reauth_config, proxied));
617        self.refresh_status_entries();
618    }
619
620    fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
621        self.servers.get(server_name).and_then(|record| record.connection.as_ref())
622    }
623
624    fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
625        self.connection_for(server_name).map(|conn| conn.client.clone())
626    }
627
628    fn remove_registered_tools_for_server(&mut self, server_name: &str) {
629        let prefix = format!("{server_name}__");
630        self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
631        self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
632    }
633
634    async fn notify_roots_changed(&self) {
635        for (server_name, record) in &self.servers {
636            if let Some(conn) = &record.connection
637                && let Err(e) = conn.client.notify_roots_list_changed().await
638            {
639                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
640            }
641        }
642    }
643}
644
645impl Drop for McpManager {
646    fn drop(&mut self) {
647        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
648        for (server_name, record) in servers {
649            if let Some(conn) = record.connection
650                && let Some(handle) = conn.server_task
651            {
652                handle.abort();
653                tracing::warn!("Server '{server_name}' task aborted during cleanup");
654            }
655        }
656    }
657}
658
/// Internal record holding all mutable state for a single MCP server.
struct ServerRecord {
    /// Live connection; `None` for records created without one (e.g. `NeedsOAuth` or `Failed`).
    connection: Option<McpServerConnection>,
    /// Current status, reported through `status_entry`.
    status: McpServerStatus,
    /// HTTP transport config used to re-run OAuth; `Some` means the server can be re-authenticated.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
    /// Whether this server's tools are reached through the tool proxy instead of being registered directly.
    proxied: bool,
}
666
667impl ServerRecord {
668    fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
669        Self { connection: None, status, reauth_config, proxied }
670    }
671
672    fn connected(
673        connection: McpServerConnection,
674        tool_count: usize,
675        reauth_config: Option<StreamableHttpClientTransportConfig>,
676        proxied: bool,
677    ) -> Self {
678        Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config, proxied }
679    }
680
681    fn auth_capability(&self) -> McpServerAuthCapability {
682        if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
683    }
684
685    fn can_authenticate(&self) -> bool {
686        self.reauth_config.is_some()
687    }
688
689    fn status_entry(&self, name: &str) -> McpServerStatusEntry {
690        McpServerStatusEntry::new(name, self.status.clone())
691            .with_auth_capability(self.auth_capability())
692            .with_proxied(self.proxied)
693    }
694}
695
696#[cfg(test)]
697mod tests {
698    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, Tool};
699    use crate::client::OAuthHandlerFactory;
700    use crate::client::config::{McpServer, McpTransport};
701    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
702    use crate::status::McpServerAuthCapability;
703    use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
704    use futures::future::BoxFuture;
705    use llm::ToolDefinition;
706    use rmcp::{
707        Json, RoleServer, ServerHandler,
708        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
709        model::{Implementation, ServerCapabilities, ServerInfo},
710        service::DynService,
711        tool, tool_handler, tool_router,
712        transport::streamable_http_client::StreamableHttpClientTransportConfig,
713    };
714    use schemars::JsonSchema;
715    use serde::{Deserialize, Serialize};
716    use serde_json::json;
717    use std::{
718        io,
719        sync::{Arc, Mutex},
720    };
721    use tokio::sync::mpsc;
722
723    #[derive(Clone)]
724    struct TestServer {
725        tool_router: ToolRouter<Self>,
726    }
727
728    #[tool_handler(router = self.tool_router)]
729    impl ServerHandler for TestServer {
730        fn get_info(&self) -> ServerInfo {
731            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
732                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
733        }
734    }
735
736    impl Default for TestServer {
737        fn default() -> Self {
738            Self { tool_router: Self::tool_router() }
739        }
740    }
741
742    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
743    struct EchoRequest {
744        value: String,
745    }
746
747    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
748    struct EchoResult {
749        value: String,
750    }
751
752    #[tool_router]
753    impl TestServer {
754        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
755            Box::new(self)
756        }
757
758        #[tool(description = "Returns the provided value")]
759        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
760            let Parameters(EchoRequest { value }) = request;
761            Json(EchoResult { value })
762        }
763    }
764
765    #[derive(Clone)]
766    struct SharedWriter(Arc<Mutex<Vec<u8>>>);
767
768    impl io::Write for SharedWriter {
769        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
770            self.0.lock().unwrap().extend_from_slice(buf);
771            Ok(buf.len())
772        }
773
774        fn flush(&mut self) -> io::Result<()> {
775            Ok(())
776        }
777    }
778
779    struct TestOAuthHandler;
780
781    impl OAuthHandler for TestOAuthHandler {
782        fn redirect_uri(&self) -> &'static str {
783            "http://127.0.0.1:0/oauth2callback"
784        }
785
786        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
787            Box::pin(async { Err(OAuthError::UserCancelled) })
788        }
789    }
790
791    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
792        Arc::new(|| Ok(Arc::new(TestOAuthHandler)))
793    }
794
795    #[tokio::test]
796    async fn authenticate_server_task_rejects_record_without_reauth_config() {
797        let (event_sender, _event_receiver) = mpsc::channel(1);
798        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
799        manager.register_record("public", McpServerStatus::Connected { tool_count: 1 }, None, false);
800
801        let error = match manager.authenticate_server_task("public").await {
802            Ok(_) => panic!("non-OAuth server should be rejected"),
803            Err(error) => error.to_string(),
804        };
805        assert!(error.contains("not OAuth-authenticatable"));
806    }
807
808    #[tokio::test]
809    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
810        let (event_sender, mut event_receiver) = mpsc::channel(2);
811        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
812        manager.register_record(
813            "remote",
814            McpServerStatus::NeedsOAuth,
815            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
816            false,
817        );
818
819        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
820
821        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
822        let event = event_receiver.recv().await.expect("status change event");
823        let McpClientEvent::ServerStatusesChanged(servers) = event else {
824            panic!("expected ServerStatusesChanged");
825        };
826        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
827        assert!(matches!(status.status, McpServerStatus::Authenticating));
828        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
829    }
830
831    #[tokio::test]
832    async fn authenticate_server_task_rejects_duplicate_same_server_while_in_flight() {
833        let (event_sender, _event_receiver) = mpsc::channel(1);
834        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
835        manager.register_record(
836            "remote",
837            McpServerStatus::NeedsOAuth,
838            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
839            false,
840        );
841
842        let _task = manager.authenticate_server_task("remote").await.expect("first auth should start");
843        let error = match manager.authenticate_server_task("remote").await {
844            Ok(_) => panic!("duplicate auth should be rejected"),
845            Err(error) => error.to_string(),
846        };
847
848        assert!(error.contains("already authenticating"));
849    }
850
    // With an OAuth attempt in flight, a failed connection attempt for the same
    // server must: emit an updated status list (server `Failed`, still OAuth-capable),
    // emit an `AuthenticationFailed` event carrying the error text, and leave the
    // server in a state where a new `authenticate_server_task` call is accepted.
    #[tokio::test]
    async fn apply_connection_attempt_failure_allows_retry() {
        // NOTE(review): capacity 2 presumably lets both events from the failure path
        // be buffered, since `apply_connection_attempt` is awaited before any `recv`.
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );
        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
        // Drain the `Authenticating` status event so the next reads see only the failure path.
        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");

        manager
            .apply_connection_attempt(McpConnectAttempt {
                name: "remote".to_string(),
                proxied: false,
                outcome: McpConnectOutcome::Failed {
                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
                },
            })
            .await;

        // Expected order: status list first, then the authentication-failure notice.
        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let auth_event = event_receiver.recv().await.expect("authentication failure event");
        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
            panic!("expected AuthenticationFailed");
        };
        assert_eq!(server, "remote");
        assert!(error.contains("boom"));

        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
        // The in-flight marker must have been cleared: retrying is accepted.
        assert!(manager.authenticate_server_task("remote").await.is_ok());
    }
890
891    #[test]
892    fn status_entries_are_derived_from_reauth_config() {
893        let (event_sender, _event_receiver) = mpsc::channel(1);
894        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
895
896        manager.register_record(
897            "with-oauth",
898            McpServerStatus::Connected { tool_count: 1 },
899            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
900            false,
901        );
902        manager.register_record("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None, false);
903        manager.register_record(
904            "needs-oauth",
905            McpServerStatus::NeedsOAuth,
906            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
907            false,
908        );
909
910        let statuses = manager.server_statuses();
911        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
912        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
913        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
914
915        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
916        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
917        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
918    }
919
920    #[tokio::test]
921    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
922        let (event_sender, _event_receiver) = mpsc::channel(1);
923        let mut manager = McpManager::new(event_sender, None);
924        manager
925            .add_mcps(vec![
926                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
927                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
928            ])
929            .await
930            .unwrap();
931
932        let statuses = manager.server_statuses();
933        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
934        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
935        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
936        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
937    }
938
939    #[test]
940    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
941        let (event_sender, _event_receiver) = mpsc::channel(1);
942        let mut manager = McpManager::new(event_sender, None);
943        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
944        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
945        manager.tool_definitions.push(ToolDefinition {
946            name: "git__status".to_string(),
947            description: String::new(),
948            parameters: "{}".to_string(),
949            server: Some("git".to_string()),
950        });
951        manager.tool_definitions.push(ToolDefinition {
952            name: "github__issue".to_string(),
953            description: String::new(),
954            parameters: "{}".to_string(),
955            server: Some("github".to_string()),
956        });
957
958        manager.remove_registered_tools_for_server("git");
959
960        assert!(!manager.tools.contains_key("git__status"));
961        assert!(manager.tools.contains_key("github__issue"));
962        assert_eq!(
963            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
964            vec!["github__issue"]
965        );
966    }
967
968    #[tokio::test]
969    async fn drop_logs_cleanup_abort_with_tracing() {
970        let (event_sender, _event_receiver) = mpsc::channel(1);
971        let mut manager = McpManager::new(event_sender, None);
972        manager
973            .add_mcps(vec![McpServer::new(
974                "test",
975                McpTransport::InMemory { server: TestServer::default().into_dyn() },
976                false,
977            )])
978            .await
979            .unwrap();
980
981        let output = Arc::new(Mutex::new(Vec::new()));
982        let subscriber = tracing_subscriber::fmt()
983            .with_ansi(false)
984            .without_time()
985            .with_writer({
986                let output = Arc::clone(&output);
987                move || SharedWriter(Arc::clone(&output))
988            })
989            .finish();
990
991        tracing::subscriber::with_default(subscriber, || {
992            drop(manager);
993        });
994
995        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
996        assert!(logs.contains("Server 'test' task aborted during cleanup"));
997    }
998}