Skip to main content

mcp_utils/client/
manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::McpServer,
6    connection::{
7        ConnectConfig, McpConnectAttempt, McpConnectOutcome, McpServerConnection, Tool, authenticate_http,
8        connect_server,
9    },
10    mcp_client::McpClient,
11    naming::{create_namespaced_tool_name, split_on_server_name},
12    tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17    RoleClient,
18    model::{
19        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20        ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21    },
22    service::RunningService,
23    transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{BTreeMap, HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Name of the tool proxy. It is also the namespace of the proxy's call tool,
/// so no real MCP server may use it while a proxy is active.
pub const DEFAULT_PROXY_NAME: &str = "proxy";
36
37pub type OAuthHandlerFactory = Arc<dyn Fn(OAuthHandlerContext) -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
39/// Context passed to an `OAuthHandlerFactory` so the constructed handler can
40/// dispatch user-facing prompts back to the host through the MCP event channel.
41#[derive(Clone)]
42pub struct OAuthHandlerContext {
43    pub server_name: String,
44    pub tx: mpsc::Sender<McpClientEvent>,
45}
46
47#[derive(Debug)]
48pub struct ElicitationRequest {
49    pub server_name: String,
50    pub request: CreateElicitationRequestParams,
51    pub response_sender: oneshot::Sender<CreateElicitationResult>,
52}
53
54#[derive(Debug, Clone)]
55pub struct ElicitationResponse {
56    pub action: ElicitationAction,
57    pub content: Option<Value>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
61pub struct UrlElicitationCompleteParams {
62    pub server_name: String,
63    pub elicitation_id: String,
64}
65
66/// Events emitted by MCP clients that require attention from the host
67/// (e.g. the relay or TUI). Flows through a single channel from `McpManager`
68/// to the consumer.
69#[derive(Debug)]
70pub enum McpClientEvent {
71    Elicitation(ElicitationRequest),
72    UrlElicitationComplete(UrlElicitationCompleteParams),
73    ServerStatusesChanged(Vec<McpServerStatusEntry>),
74    ToolDefinitionsChanged(Vec<ToolDefinition>),
75    ServerInstructionsUpdated { server: String, instructions: Option<String> },
76    AuthenticationFailed { server: String, error: String },
77    ConnectionReady(McpConnectionDetails),
78}
79
80#[derive(Debug, Clone)]
81pub struct McpConnectionDetails {
82    pub instructions: BTreeMap<String, String>,
83    pub tool_definitions: Vec<ToolDefinition>,
84    pub server_statuses: Vec<McpServerStatusEntry>,
85}
86
87/// Manages connections to multiple MCP servers and their tools
88pub struct McpManager {
89    servers: HashMap<String, ServerRecord>,
90    server_order: Vec<String>,
91    proxy: Option<ToolProxy>,
92    aether_home: Option<PathBuf>,
93    client_info: ClientInfo,
94    event_sender: mpsc::Sender<McpClientEvent>,
95    /// Roots shared with all MCP clients
96    roots: Arc<RwLock<Vec<Root>>>,
97    oauth_handler_factory: Option<OAuthHandlerFactory>,
98    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
99}
100
101impl McpManager {
102    pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
103        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
104        if let Some(elicitation) = capabilities.elicitation.as_mut() {
105            elicitation.form = Some(FormElicitationCapability::default());
106            elicitation.url = Some(UrlElicitationCapability::default());
107        }
108
109        Self {
110            servers: HashMap::new(),
111            server_order: Vec::new(),
112            proxy: None,
113            aether_home: None,
114            client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
115            event_sender,
116            roots: Arc::new(RwLock::new(Vec::new())),
117            oauth_handler_factory,
118            oauth_credential_store: None,
119        }
120    }
121
122    pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
123        self.aether_home = Some(aether_home.into());
124        self
125    }
126
127    pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
128        self.oauth_credential_store = Some(store);
129        self
130    }
131
132    pub async fn register_pending(&mut self, servers: Vec<McpServer>) -> Result<Vec<McpServer>> {
133        let has_proxy = servers.iter().any(|server| server.proxy);
134        if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
135            return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
136        }
137
138        for server in &servers {
139            self.register_record(&server.name, ServerState::Connecting, None, server.proxy);
140        }
141
142        self.emit_server_statuses_changed().await;
143        Ok(servers)
144    }
145
146    pub async fn bootstrap_proxy_setup(&mut self, servers: &[McpServer]) -> Result<()> {
147        let proxied_members: HashSet<String> =
148            servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
149
150        if proxied_members.is_empty() {
151            return Ok(());
152        }
153
154        let dir = self.proxy_tool_dir()?;
155        ToolProxy::clean_dir(&dir).await?;
156        self.register_proxy(dir, proxied_members);
157        Ok(())
158    }
159
160    pub fn connect_pending_task(&self, server: McpServer) -> impl Future<Output = McpConnectAttempt> + Send + 'static {
161        let ctx = self.connect_config();
162        async move { connect_server(server, &ctx).await }
163    }
164
165    pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
166        self.bootstrap_proxy_setup(&servers).await?;
167        let pending = self.register_pending(servers).await?;
168        let ctx = self.connect_config();
169        let attempts = join_all(pending.into_iter().map(|server| connect_server(server, &ctx))).await;
170        for attempt in attempts {
171            self.apply_connection_attempt(attempt).await;
172        }
173        Ok(())
174    }
175
176    pub fn get_client_for_tool(
177        &self,
178        namespaced_tool_name: &str,
179        arguments_json: &str,
180    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
181        if !self.is_routable_tool(namespaced_tool_name) {
182            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
183        }
184
185        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
186            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
187
188        if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
189            let call = proxy.resolve_call(arguments_json)?;
190            let conn = self.connection_for(&call.server).ok_or_else(|| {
191                McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
192            })?;
193            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
194            return Ok((conn.client.clone(), params));
195        }
196
197        let client =
198            self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
199
200        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
201        let mut params = CallToolRequestParams::new(tool_name.to_string());
202        if let Some(args) = arguments {
203            params = params.with_arguments(args);
204        }
205
206        Ok((client, params))
207    }
208
209    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
210        let mut definitions = Vec::new();
211        if let Some(proxy) = self.proxy.as_ref() {
212            definitions.push(ToolProxy::call_tool_definition(proxy.name()));
213        }
214        for name in &self.server_order {
215            let Some(record) = self.servers.get(name) else { continue };
216            if record.proxied {
217                continue;
218            }
219            definitions.extend(record.tools().iter().map(|tool| {
220                ToolDefinition::new(
221                    create_namespaced_tool_name(name, &tool.name),
222                    tool.description.clone(),
223                    tool.parameters.to_string(),
224                )
225                .with_server(name.clone())
226                .with_annotations(tool.annotations.clone())
227            }));
228        }
229        definitions
230    }
231
232    pub fn server_instructions(&self) -> BTreeMap<String, String> {
233        let mut instructions: BTreeMap<String, String> = self
234            .servers
235            .iter()
236            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
237            .filter_map(|(name, record)| {
238                record
239                    .connection()
240                    .and_then(|conn| conn.instructions.as_ref())
241                    .map(|instr| (name.clone(), instr.clone()))
242            })
243            .collect();
244
245        if let Some((name, body)) = self.proxy_instructions() {
246            instructions.insert(name, body);
247        }
248
249        instructions
250    }
251
252    pub fn server_statuses(&self) -> Vec<McpServerStatusEntry> {
253        self.server_order
254            .iter()
255            .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
256            .collect()
257    }
258
259    pub async fn authenticate_server_task(
260        &mut self,
261        name: &str,
262    ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
263        let record = self
264            .servers
265            .get(name)
266            .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
267        if !record.can_authenticate() {
268            return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
269        }
270        if self.oauth_handler_factory.is_none() {
271            return Err(McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")));
272        }
273
274        let name = name.to_string();
275        let config = record.reauth_config.clone().expect("checked above");
276        let proxied = record.proxied;
277        let ctx = self.connect_config();
278
279        self.set_state(&name, ServerState::Authenticating);
280        self.emit_server_statuses_changed().await;
281
282        Ok(async move { authenticate_http(name, config, ctx, proxied).await })
283    }
284
285    pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
286        let McpConnectAttempt { name, proxied, outcome } = attempt;
287        match outcome {
288            McpConnectOutcome::Connected { conn, reauth_config } => {
289                match self.register_connection(&name, conn, reauth_config, proxied).await {
290                    Ok(tools) => {
291                        self.refresh_proxy_after_auth(&name, &tools, proxied).await;
292                        self.emit_server_statuses_changed().await;
293                        self.emit_tool_definitions_changed().await;
294                        self.emit_instructions_after_connect(&name, proxied).await;
295                    }
296                    Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
297                }
298            }
299            McpConnectOutcome::NeedsOAuth { config, error } => {
300                tracing::warn!("Server '{name}' needs OAuth: {error}");
301                self.register_record(&name, ServerState::NeedsOAuth, Some(config), proxied);
302                self.emit_server_statuses_changed().await;
303            }
304            McpConnectOutcome::Failed { error } => {
305                self.apply_authentication_failure(name, error.to_string()).await;
306            }
307        }
308    }
309
310    /// List all prompts from all connected MCP servers with namespacing
311    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
312        let futures: Vec<_> = self
313            .servers
314            .iter()
315            .filter_map(|(server_name, record)| {
316                let conn = record.connection()?;
317                conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
318                let server_name = server_name.clone();
319                let client = conn.client.clone();
320                Some(async move {
321                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
322                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
323                    })?;
324
325                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
326                        .prompts
327                        .into_iter()
328                        .map(|prompt| {
329                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
330                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
331                        })
332                        .collect();
333
334                    Ok::<_, McpError>(namespaced_prompts)
335                })
336            })
337            .collect();
338
339        let results = join_all(futures).await;
340        let mut all_prompts = Vec::new();
341        for result in results {
342            all_prompts.extend(result?);
343        }
344
345        Ok(all_prompts)
346    }
347
348    /// Get a specific prompt by namespaced name
349    pub async fn get_prompt(
350        &self,
351        namespaced_prompt_name: &str,
352        arguments: Option<serde_json::Map<String, serde_json::Value>>,
353    ) -> Result<rmcp::model::GetPromptResult> {
354        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
355            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
356
357        let server_conn =
358            self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
359
360        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
361        if let Some(args) = arguments {
362            request = request.with_arguments(args);
363        }
364
365        server_conn.client.get_prompt(request).await.map_err(|e| {
366            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
367        })
368    }
369
370    /// Shutdown all servers and wait for their tasks to complete
371    pub async fn shutdown(&mut self) {
372        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
373
374        for (server_name, record) in servers {
375            if let Some(conn) = record.into_connection()
376                && let Some(handle) = conn.server_task
377            {
378                drop(conn.client);
379
380                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
381                    Ok(Ok(())) => {
382                        tracing::info!("Server '{server_name}' shut down gracefully");
383                    }
384                    Ok(Err(e)) => {
385                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
386                    }
387                    Err(_) => {
388                        tracing::warn!("Server '{server_name}' shutdown timed out");
389                    }
390                }
391            }
392        }
393
394        self.server_order.clear();
395        self.proxy = None;
396    }
397
398    /// Shutdown a specific server by name
399    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
400        if let Some(record) = self.servers.remove(server_name)
401            && let Some(conn) = record.into_connection()
402            && let Some(handle) = conn.server_task
403        {
404            drop(conn.client);
405
406            match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
407                Ok(Ok(())) => {
408                    tracing::info!("Server '{server_name}' shut down gracefully");
409                }
410                Ok(Err(e)) => {
411                    tracing::warn!("Server '{server_name}' task panicked: {e:?}");
412                }
413                Err(_) => {
414                    tracing::warn!("Server '{server_name}' shutdown timed out");
415                }
416            }
417        }
418
419        Ok(())
420    }
421
422    /// Set the roots advertised to MCP servers.
423    ///
424    /// This updates the roots and sends notifications to all connected servers
425    /// that support the `roots/list_changed` notification.
426    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
427        {
428            let mut roots = self.roots.write().await;
429            *roots = new_roots;
430        }
431
432        self.notify_roots_changed().await;
433
434        Ok(())
435    }
436
437    async fn emit_server_statuses_changed(&self) {
438        self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses())).await;
439    }
440
441    async fn emit_tool_definitions_changed(&self) {
442        self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
443    }
444
445    async fn emit_instructions_after_connect(&self, server_name: &str, proxied: bool) {
446        if proxied {
447            if let Some((server, body)) = self.proxy_instructions() {
448                self.emit_event(McpClientEvent::ServerInstructionsUpdated { server, instructions: Some(body) }).await;
449            }
450            return;
451        }
452
453        if let Some(instructions) =
454            self.connection_for(server_name).and_then(|conn| conn.instructions.as_ref()).cloned()
455        {
456            self.emit_event(McpClientEvent::ServerInstructionsUpdated {
457                server: server_name.to_string(),
458                instructions: Some(instructions),
459            })
460            .await;
461        }
462    }
463
464    fn proxy_instructions(&self) -> Option<(String, String)> {
465        let proxy = self.proxy.as_ref()?;
466        let descriptions: Vec<(String, String)> = proxy
467            .members()
468            .iter()
469            .filter_map(|member| {
470                let conn = self.connection_for(member)?;
471                Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
472            })
473            .collect();
474        Some((proxy.name().to_string(), ToolProxy::build_instructions(proxy.tool_dir(), &descriptions)))
475    }
476
477    pub async fn emit_connection_ready(&self) {
478        self.emit_event(McpClientEvent::ConnectionReady(McpConnectionDetails {
479            tool_definitions: self.tool_definitions(),
480            instructions: self.server_instructions(),
481            server_statuses: self.server_statuses(),
482        }))
483        .await;
484    }
485
486    async fn emit_authentication_failed(&self, server: String, error: String) {
487        self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
488    }
489
490    async fn emit_event(&self, event: McpClientEvent) {
491        if let Err(e) = self.event_sender.send(event).await {
492            tracing::warn!("Failed to emit MCP client event: {e}");
493        }
494    }
495
496    fn connect_config(&self) -> Arc<ConnectConfig> {
497        Arc::new(ConnectConfig {
498            client_info: self.client_info.clone(),
499            event_sender: self.event_sender.clone(),
500            roots: Arc::clone(&self.roots),
501            oauth_handler_factory: self.oauth_handler_factory.clone(),
502            oauth_credential_store: self.oauth_credential_store.clone(),
503        })
504    }
505
506    fn proxy_tool_dir(&self) -> Result<PathBuf> {
507        self.aether_home
508            .as_ref()
509            .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
510            .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
511    }
512
513    async fn register_connection(
514        &mut self,
515        name: &str,
516        conn: McpServerConnection,
517        reauth_config: Option<StreamableHttpClientTransportConfig>,
518        proxied: bool,
519    ) -> Result<Vec<RmcpTool>> {
520        let tools = conn
521            .list_tools()
522            .await
523            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
524        self.apply_connected(name, conn, &tools, reauth_config, proxied);
525        Ok(tools)
526    }
527
528    fn apply_connected(
529        &mut self,
530        name: &str,
531        conn: McpServerConnection,
532        tools: &[RmcpTool],
533        reauth_config: Option<StreamableHttpClientTransportConfig>,
534        proxied: bool,
535    ) {
536        let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
537        let final_reauth = reauth_config.or(existing_reauth);
538        let tools = tools.iter().map(Tool::from).collect();
539
540        self.remember_server_order(name);
541        self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools, final_reauth, proxied));
542    }
543
544    fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
545        self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
546    }
547
548    async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
549        if !proxied {
550            return;
551        }
552
553        if let Some(proxy) = self.proxy.as_mut() {
554            proxy.add_member(name.to_string());
555        }
556
557        if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
558            && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
559        {
560            tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
561        }
562    }
563
564    fn remember_server_order(&mut self, name: &str) {
565        if !self.server_order.iter().any(|n| n == name) {
566            self.server_order.push(name.to_string());
567        }
568    }
569
570    async fn apply_authentication_failure(&mut self, name: String, error: String) {
571        self.set_state(&name, ServerState::Failed { error: error.clone() });
572        self.emit_server_statuses_changed().await;
573        self.emit_authentication_failed(name, error).await;
574    }
575
576    fn set_state(&mut self, name: &str, state: ServerState) {
577        self.remember_server_order(name);
578        match self.servers.get_mut(name) {
579            Some(record) => record.state = state,
580            None => {
581                self.servers.insert(name.to_string(), ServerRecord::new(state, None, false));
582            }
583        }
584    }
585
586    fn register_record(
587        &mut self,
588        name: &str,
589        state: ServerState,
590        reauth_config: Option<StreamableHttpClientTransportConfig>,
591        proxied: bool,
592    ) {
593        self.remember_server_order(name);
594        self.servers.insert(name.to_string(), ServerRecord::new(state, reauth_config, proxied));
595    }
596
597    fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
598        self.servers.get(server_name).and_then(ServerRecord::connection)
599    }
600
601    fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
602        self.connection_for(server_name).map(|conn| conn.client.clone())
603    }
604
605    fn is_routable_tool(&self, namespaced_tool_name: &str) -> bool {
606        if self.proxy.as_ref().is_some_and(|proxy| proxy.call_tool_name() == namespaced_tool_name) {
607            return true;
608        }
609        match split_on_server_name(namespaced_tool_name) {
610            Some((server_name, tool_name)) => {
611                self.servers.get(server_name).is_some_and(|record| !record.proxied && record.has_tool(tool_name))
612            }
613            None => false,
614        }
615    }
616
617    async fn notify_roots_changed(&self) {
618        for (server_name, record) in &self.servers {
619            if let Some(conn) = record.connection()
620                && let Err(e) = conn.client.notify_roots_list_changed().await
621            {
622                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
623            }
624        }
625    }
626}
627
628impl Drop for McpManager {
629    fn drop(&mut self) {
630        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
631        for (server_name, record) in servers {
632            if let Some(conn) = record.into_connection()
633                && let Some(handle) = conn.server_task
634            {
635                handle.abort();
636                tracing::warn!("Server '{server_name}' task aborted during cleanup");
637            }
638        }
639    }
640}
641
642/// Internal record holding all mutable state for a single MCP server.
643struct ServerRecord {
644    state: ServerState,
645    reauth_config: Option<StreamableHttpClientTransportConfig>,
646    proxied: bool,
647}
648
649enum ServerState {
650    Connecting,
651    Connected { connection: McpServerConnection, tools: Vec<Tool> },
652    Authenticating,
653    Failed { error: String },
654    NeedsOAuth,
655}
656
657impl From<&ServerState> for McpServerStatus {
658    fn from(state: &ServerState) -> Self {
659        match state {
660            ServerState::Connecting => Self::Connecting,
661            ServerState::Connected { tools, .. } => Self::Connected { tool_count: tools.len() },
662            ServerState::Authenticating => Self::Authenticating,
663            ServerState::Failed { error } => Self::Failed { error: error.clone() },
664            ServerState::NeedsOAuth => Self::NeedsOAuth,
665        }
666    }
667}
668
669impl ServerRecord {
670    fn new(state: ServerState, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
671        Self { state, reauth_config, proxied }
672    }
673
674    fn connected(
675        connection: McpServerConnection,
676        tools: Vec<Tool>,
677        reauth_config: Option<StreamableHttpClientTransportConfig>,
678        proxied: bool,
679    ) -> Self {
680        Self::new(ServerState::Connected { connection, tools }, reauth_config, proxied)
681    }
682
683    fn tools(&self) -> &[Tool] {
684        match &self.state {
685            ServerState::Connected { tools, .. } => tools,
686            ServerState::Connecting
687            | ServerState::Authenticating
688            | ServerState::Failed { .. }
689            | ServerState::NeedsOAuth => &[],
690        }
691    }
692
693    fn has_tool(&self, tool_name: &str) -> bool {
694        self.tools().iter().any(|tool| tool.name == tool_name)
695    }
696
697    fn connection(&self) -> Option<&McpServerConnection> {
698        match &self.state {
699            ServerState::Connected { connection, .. } => Some(connection),
700            ServerState::Connecting
701            | ServerState::Authenticating
702            | ServerState::Failed { .. }
703            | ServerState::NeedsOAuth => None,
704        }
705    }
706
707    fn into_connection(self) -> Option<McpServerConnection> {
708        match self.state {
709            ServerState::Connected { connection, .. } => Some(connection),
710            ServerState::Connecting
711            | ServerState::Authenticating
712            | ServerState::Failed { .. }
713            | ServerState::NeedsOAuth => None,
714        }
715    }
716
717    fn auth_capability(&self) -> McpServerAuthCapability {
718        if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
719    }
720
721    fn can_authenticate(&self) -> bool {
722        self.reauth_config.is_some()
723    }
724
725    fn status(&self) -> McpServerStatus {
726        (&self.state).into()
727    }
728
729    fn status_entry(&self, name: &str) -> McpServerStatusEntry {
730        McpServerStatusEntry::new(name, self.status())
731            .with_auth_capability(self.auth_capability())
732            .with_proxied(self.proxied)
733    }
734}
735
736#[cfg(test)]
737mod tests {
738    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, ServerState};
739    use crate::client::OAuthHandlerFactory;
740    use crate::client::config::{McpServer, McpTransport};
741    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
742    use crate::status::McpServerAuthCapability;
743    use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
744    use futures::future::BoxFuture;
745    use rmcp::{
746        Json, RoleServer, ServerHandler,
747        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
748        model::{Implementation, ServerCapabilities, ServerInfo},
749        service::DynService,
750        tool, tool_handler, tool_router,
751        transport::streamable_http_client::StreamableHttpClientTransportConfig,
752    };
753    use schemars::JsonSchema;
754    use serde::{Deserialize, Serialize};
755    use std::{
756        io,
757        sync::{Arc, Mutex},
758    };
759    use tokio::sync::mpsc;
760
761    #[derive(Clone)]
762    struct TestServer {
763        tool_router: ToolRouter<Self>,
764    }
765
766    #[tool_handler(router = self.tool_router)]
767    impl ServerHandler for TestServer {
768        fn get_info(&self) -> ServerInfo {
769            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
770                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
771                .with_instructions("Test server instructions")
772        }
773    }
774
775    impl Default for TestServer {
776        fn default() -> Self {
777            Self { tool_router: Self::tool_router() }
778        }
779    }
780
781    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
782    struct EchoRequest {
783        value: String,
784    }
785
786    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
787    struct EchoResult {
788        value: String,
789    }
790
    // `#[tool_router]` presumably generates `Self::tool_router()` (called from `Default`) and
    // registers the `#[tool]` methods below; no hand-written definition of it appears here.
    #[tool_router]
    impl TestServer {
        /// Boxes the server as a type-erased `DynService<RoleServer>`, the form passed to
        /// `McpTransport::InMemory { server }` throughout these tests.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        // Returns its input unchanged. The read-only / closed-world annotations are asserted by
        // `tool_definitions_preserve_annotations` on the namespaced `test__echo` definition.
        // `//` rather than `///` so nothing competes with the explicit `description`.
        #[tool(description = "Returns the provided value", annotations(read_only_hint = true, open_world_hint = false))]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }
803
    /// Cloneable `io::Write` sink that appends into a shared byte buffer, so the tracing test
    /// can read back everything the subscriber formatted after the fact.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);
806
807    impl io::Write for SharedWriter {
808        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
809            self.0.lock().unwrap().extend_from_slice(buf);
810            Ok(buf.len())
811        }
812
813        fn flush(&mut self) -> io::Result<()> {
814            Ok(())
815        }
816    }
817
    /// Stub `OAuthHandler` whose `authorize` always fails with `OAuthError::UserCancelled`.
    struct TestOAuthHandler;
819
820    impl OAuthHandler for TestOAuthHandler {
821        fn redirect_uri(&self) -> &'static str {
822            "http://127.0.0.1:0/oauth2callback"
823        }
824
825        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
826            Box::pin(async { Err(OAuthError::UserCancelled) })
827        }
828    }
829
830    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
831        Arc::new(|_ctx| Ok(Arc::new(TestOAuthHandler)))
832    }
833
834    #[tokio::test]
835    async fn authenticate_server_task_rejects_record_without_reauth_config() {
836        let (event_sender, _event_receiver) = mpsc::channel(1);
837        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
838        manager.register_record("public", ServerState::Connecting, None, false);
839
840        let error = match manager.authenticate_server_task("public").await {
841            Ok(_) => panic!("non-OAuth server should be rejected"),
842            Err(error) => error.to_string(),
843        };
844        assert!(error.contains("not OAuth-authenticatable"));
845    }
846
847    #[tokio::test]
848    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
849        let (event_sender, mut event_receiver) = mpsc::channel(2);
850        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
851        manager.register_record(
852            "remote",
853            ServerState::NeedsOAuth,
854            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
855            false,
856        );
857
858        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
859
860        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
861        let event = event_receiver.recv().await.expect("status change event");
862        let McpClientEvent::ServerStatusesChanged(servers) = event else {
863            panic!("expected ServerStatusesChanged");
864        };
865        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
866        assert!(matches!(status.status, McpServerStatus::Authenticating));
867        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
868    }
869
870    #[tokio::test]
871    async fn apply_connection_attempt_failure_allows_retry() {
872        let (event_sender, mut event_receiver) = mpsc::channel(2);
873        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
874        manager.register_record(
875            "remote",
876            ServerState::NeedsOAuth,
877            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
878            false,
879        );
880        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
881        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");
882
883        manager
884            .apply_connection_attempt(McpConnectAttempt {
885                name: "remote".to_string(),
886                proxied: false,
887                outcome: McpConnectOutcome::Failed {
888                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
889                },
890            })
891            .await;
892
893        let event = event_receiver.recv().await.expect("status change event");
894        let McpClientEvent::ServerStatusesChanged(servers) = event else {
895            panic!("expected ServerStatusesChanged");
896        };
897        let auth_event = event_receiver.recv().await.expect("authentication failure event");
898        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
899            panic!("expected AuthenticationFailed");
900        };
901        assert_eq!(server, "remote");
902        assert!(error.contains("boom"));
903
904        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
905        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
906        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
907        assert!(manager.authenticate_server_task("remote").await.is_ok());
908    }
909
910    #[test]
911    fn status_entries_are_derived_from_reauth_config() {
912        let (event_sender, _event_receiver) = mpsc::channel(1);
913        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
914
915        manager.register_record(
916            "with-oauth",
917            ServerState::Connecting,
918            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
919            false,
920        );
921        manager.register_record("without-oauth", ServerState::Connecting, None, false);
922        manager.register_record(
923            "needs-oauth",
924            ServerState::NeedsOAuth,
925            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
926            false,
927        );
928
929        let statuses = manager.server_statuses();
930        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
931        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
932        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
933
934        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
935        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
936        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
937    }
938
939    #[tokio::test]
940    async fn register_pending_marks_every_server_connecting_and_emits_status() {
941        let (event_sender, mut event_receiver) = mpsc::channel(32);
942        let mut manager = McpManager::new(event_sender, None);
943
944        let servers = vec![
945            McpServer::new("alpha", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
946            McpServer::new("beta", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
947        ];
948
949        let returned = manager.register_pending(servers).await.unwrap();
950        assert_eq!(returned.iter().map(|s| s.name.as_str()).collect::<Vec<_>>(), vec!["alpha", "beta"]);
951
952        let statuses = manager.server_statuses();
953        assert_eq!(statuses.len(), 2);
954        assert!(matches!(statuses.iter().find(|s| s.name == "alpha").unwrap().status, McpServerStatus::Connecting));
955        assert!(matches!(statuses.iter().find(|s| s.name == "beta").unwrap().status, McpServerStatus::Connecting));
956        assert!(statuses.iter().find(|s| s.name == "beta").unwrap().proxied);
957
958        let event = event_receiver.try_recv().expect("expected initial ServerStatusesChanged emission");
959        let McpClientEvent::ServerStatusesChanged(emitted) = event else {
960            panic!("expected ServerStatusesChanged, got {event:?}");
961        };
962        assert_eq!(emitted.iter().map(|s| s.name.as_str()).collect::<Vec<_>>(), vec!["alpha", "beta"]);
963    }
964
965    #[tokio::test]
966    async fn apply_connection_attempt_emits_instructions_updated_after_connect() {
967        let (event_sender, mut event_receiver) = mpsc::channel(32);
968        let mut manager = McpManager::new(event_sender, None);
969
970        let servers =
971            vec![McpServer::new("test", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false)];
972        manager.add_mcps(servers).await.unwrap();
973
974        let mut update_for_test = None;
975        while let Ok(event) = event_receiver.try_recv() {
976            if let McpClientEvent::ServerInstructionsUpdated { server, instructions } = event
977                && server == "test"
978            {
979                update_for_test = Some(instructions);
980            }
981        }
982        let instructions = update_for_test.expect("expected ServerInstructionsUpdated for 'test'");
983        assert!(instructions.is_some(), "TestServer publishes instructions, so update should carry Some(_)");
984    }
985
986    #[tokio::test]
987    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
988        let (event_sender, _event_receiver) = mpsc::channel(32);
989        let mut manager = McpManager::new(event_sender, None);
990        manager
991            .add_mcps(vec![
992                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
993                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
994            ])
995            .await
996            .unwrap();
997
998        let statuses = manager.server_statuses();
999        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
1000        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
1001        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
1002        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
1003    }
1004
1005    #[tokio::test]
1006    async fn tool_definitions_drop_when_a_server_shuts_down() {
1007        let (event_sender, _event_receiver) = mpsc::channel(32);
1008        let mut manager = McpManager::new(event_sender, None);
1009        manager
1010            .add_mcps(vec![
1011                McpServer::new("git", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
1012                McpServer::new("github", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
1013            ])
1014            .await
1015            .unwrap();
1016
1017        let names =
1018            |manager: &McpManager| manager.tool_definitions().into_iter().map(|tool| tool.name).collect::<Vec<_>>();
1019        assert!(names(&manager).contains(&"git__echo".to_string()));
1020        assert!(names(&manager).contains(&"github__echo".to_string()));
1021
1022        manager.shutdown_server("git").await.unwrap();
1023
1024        assert!(!names(&manager).iter().any(|name| name.starts_with("git__")));
1025        assert!(names(&manager).contains(&"github__echo".to_string()));
1026    }
1027
1028    #[tokio::test]
1029    async fn tool_definitions_preserve_annotations() {
1030        let (event_sender, _event_receiver) = mpsc::channel(32);
1031        let mut manager = McpManager::new(event_sender, None);
1032        manager
1033            .add_mcps(vec![McpServer::new(
1034                "test",
1035                McpTransport::InMemory { server: TestServer::default().into_dyn() },
1036                false,
1037            )])
1038            .await
1039            .unwrap();
1040
1041        let tools = manager.tool_definitions();
1042        let echo = tools.iter().find(|tool| tool.name == "test__echo").expect("echo tool");
1043        let annotations = echo.annotations.as_ref().expect("annotations should be preserved");
1044        assert_eq!(annotations.read_only_hint, Some(true));
1045        assert_eq!(annotations.open_world_hint, Some(false));
1046    }
1047
1048    #[tokio::test]
1049    async fn drop_logs_cleanup_abort_with_tracing() {
1050        let (event_sender, _event_receiver) = mpsc::channel(32);
1051        let mut manager = McpManager::new(event_sender, None);
1052        manager
1053            .add_mcps(vec![McpServer::new(
1054                "test",
1055                McpTransport::InMemory { server: TestServer::default().into_dyn() },
1056                false,
1057            )])
1058            .await
1059            .unwrap();
1060
1061        let output = Arc::new(Mutex::new(Vec::new()));
1062        let subscriber = tracing_subscriber::fmt()
1063            .with_ansi(false)
1064            .without_time()
1065            .with_writer({
1066                let output = Arc::clone(&output);
1067                move || SharedWriter(Arc::clone(&output))
1068            })
1069            .finish();
1070
1071        tracing::subscriber::with_default(subscriber, || {
1072            drop(manager);
1073        });
1074
1075        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
1076        assert!(logs.contains("Server 'test' task aborted during cleanup"));
1077    }
1078}