Skip to main content

mcp_utils/client/
manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::McpServer,
6    connection::{
7        ConnectConfig, McpConnectAttempt, McpConnectOutcome, McpServerConnection, Tool, authenticate_http,
8        connect_server,
9    },
10    mcp_client::McpClient,
11    naming::{create_namespaced_tool_name, split_on_server_name},
12    tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17    RoleClient,
18    model::{
19        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20        ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21    },
22    service::RunningService,
23    transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{BTreeMap, HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Name under which the tool proxy is registered, also used as the name of the
/// proxy's tool directory. While any server is proxied, no configured server
/// may use this name.
pub const DEFAULT_PROXY_NAME: &str = "proxy";
36
37pub type OAuthHandlerFactory = Arc<dyn Fn(OAuthHandlerContext) -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
39/// Context passed to an `OAuthHandlerFactory` so the constructed handler can
40/// dispatch user-facing prompts back to the host through the MCP event channel.
41#[derive(Clone)]
42pub struct OAuthHandlerContext {
43    pub server_name: String,
44    pub tx: mpsc::Sender<McpClientEvent>,
45}
46
47#[derive(Debug)]
48pub struct ElicitationRequest {
49    pub server_name: String,
50    pub request: CreateElicitationRequestParams,
51    pub response_sender: oneshot::Sender<CreateElicitationResult>,
52}
53
54#[derive(Debug, Clone)]
55pub struct ElicitationResponse {
56    pub action: ElicitationAction,
57    pub content: Option<Value>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
61pub struct UrlElicitationCompleteParams {
62    pub server_name: String,
63    pub elicitation_id: String,
64}
65
66/// Events emitted by MCP clients that require attention from the host
67/// (e.g. the relay or TUI). Flows through a single channel from `McpManager`
68/// to the consumer.
69#[derive(Debug)]
70pub enum McpClientEvent {
71    Elicitation(ElicitationRequest),
72    UrlElicitationComplete(UrlElicitationCompleteParams),
73    ServerStatusesChanged(Vec<McpServerStatusEntry>),
74    ToolDefinitionsChanged(Vec<ToolDefinition>),
75    ServerInstructionsUpdated { server: String, instructions: Option<String> },
76    AuthenticationFailed { server: String, error: String },
77    ConnectionReady(McpConnectionDetails),
78}
79
80#[derive(Debug, Clone)]
81pub struct McpConnectionDetails {
82    pub instructions: BTreeMap<String, String>,
83    pub tool_definitions: Vec<ToolDefinition>,
84    pub server_statuses: Vec<McpServerStatusEntry>,
85}
86
87/// Manages connections to multiple MCP servers and their tools
88pub struct McpManager {
89    servers: HashMap<String, ServerRecord>,
90    server_order: Vec<String>,
91    proxy: Option<ToolProxy>,
92    aether_home: Option<PathBuf>,
93    client_info: ClientInfo,
94    event_sender: mpsc::Sender<McpClientEvent>,
95    /// Roots shared with all MCP clients
96    roots: Arc<RwLock<Vec<Root>>>,
97    oauth_handler_factory: Option<OAuthHandlerFactory>,
98    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
99}
100
101impl McpManager {
102    pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
103        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
104        if let Some(elicitation) = capabilities.elicitation.as_mut() {
105            elicitation.form = Some(FormElicitationCapability::default());
106            elicitation.url = Some(UrlElicitationCapability::default());
107        }
108
109        Self {
110            servers: HashMap::new(),
111            server_order: Vec::new(),
112            proxy: None,
113            aether_home: None,
114            client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
115            event_sender,
116            roots: Arc::new(RwLock::new(Vec::new())),
117            oauth_handler_factory,
118            oauth_credential_store: None,
119        }
120    }
121
122    pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
123        self.aether_home = Some(aether_home.into());
124        self
125    }
126
127    pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
128        self.oauth_credential_store = Some(store);
129        self
130    }
131
132    pub async fn register_pending(&mut self, servers: Vec<McpServer>) -> Result<Vec<McpServer>> {
133        let has_proxy = servers.iter().any(|server| server.proxy);
134        if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
135            return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
136        }
137
138        for server in &servers {
139            self.register_record(&server.name, ServerState::Connecting, None, server.proxy);
140        }
141
142        self.emit_server_statuses_changed().await;
143        Ok(servers)
144    }
145
146    pub async fn bootstrap_proxy_setup(&mut self, servers: &[McpServer]) -> Result<()> {
147        let proxied_members: HashSet<String> =
148            servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
149
150        if proxied_members.is_empty() {
151            return Ok(());
152        }
153
154        let dir = self.proxy_tool_dir()?;
155        ToolProxy::clean_dir(&dir).await?;
156        self.register_proxy(dir, proxied_members);
157        Ok(())
158    }
159
160    pub fn connect_pending_task(&self, server: McpServer) -> impl Future<Output = McpConnectAttempt> + Send + 'static {
161        let ctx = self.connect_config();
162        async move { connect_server(server, &ctx).await }
163    }
164
165    pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
166        self.bootstrap_proxy_setup(&servers).await?;
167        let pending = self.register_pending(servers).await?;
168        let ctx = self.connect_config();
169        let attempts = join_all(pending.into_iter().map(|server| connect_server(server, &ctx))).await;
170        for attempt in attempts {
171            self.apply_connection_attempt(attempt).await;
172        }
173        Ok(())
174    }
175
176    pub fn get_client_for_tool(
177        &self,
178        namespaced_tool_name: &str,
179        arguments_json: &str,
180    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
181        if !self.is_routable_tool(namespaced_tool_name) {
182            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
183        }
184
185        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
186            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
187
188        if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
189            let call = proxy.resolve_call(arguments_json)?;
190            let conn = self.connection_for(&call.server).ok_or_else(|| {
191                McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
192            })?;
193            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
194            return Ok((conn.client.clone(), params));
195        }
196
197        let client =
198            self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
199
200        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
201        let mut params = CallToolRequestParams::new(tool_name.to_string());
202        if let Some(args) = arguments {
203            params = params.with_arguments(args);
204        }
205
206        Ok((client, params))
207    }
208
209    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
210        let mut definitions = Vec::new();
211        if let Some(proxy) = self.proxy.as_ref() {
212            definitions.push(ToolProxy::call_tool_definition(proxy.name()));
213        }
214        for name in &self.server_order {
215            let Some(record) = self.servers.get(name) else { continue };
216            if record.proxied {
217                continue;
218            }
219            definitions.extend(record.tools().iter().map(|tool| ToolDefinition {
220                name: create_namespaced_tool_name(name, &tool.name),
221                description: tool.description.clone(),
222                parameters: tool.parameters.to_string(),
223                server: Some(name.clone()),
224            }));
225        }
226        definitions
227    }
228
229    pub fn server_instructions(&self) -> BTreeMap<String, String> {
230        let mut instructions: BTreeMap<String, String> = self
231            .servers
232            .iter()
233            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
234            .filter_map(|(name, record)| {
235                record
236                    .connection()
237                    .and_then(|conn| conn.instructions.as_ref())
238                    .map(|instr| (name.clone(), instr.clone()))
239            })
240            .collect();
241
242        if let Some((name, body)) = self.proxy_instructions() {
243            instructions.insert(name, body);
244        }
245
246        instructions
247    }
248
249    pub fn server_statuses(&self) -> Vec<McpServerStatusEntry> {
250        self.server_order
251            .iter()
252            .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
253            .collect()
254    }
255
256    pub async fn authenticate_server_task(
257        &mut self,
258        name: &str,
259    ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
260        let record = self
261            .servers
262            .get(name)
263            .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
264        if !record.can_authenticate() {
265            return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
266        }
267        if self.oauth_handler_factory.is_none() {
268            return Err(McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")));
269        }
270
271        let name = name.to_string();
272        let config = record.reauth_config.clone().expect("checked above");
273        let proxied = record.proxied;
274        let ctx = self.connect_config();
275
276        self.set_state(&name, ServerState::Authenticating);
277        self.emit_server_statuses_changed().await;
278
279        Ok(async move { authenticate_http(name, config, ctx, proxied).await })
280    }
281
282    pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
283        let McpConnectAttempt { name, proxied, outcome } = attempt;
284        match outcome {
285            McpConnectOutcome::Connected { conn, reauth_config } => {
286                match self.register_connection(&name, conn, reauth_config, proxied).await {
287                    Ok(tools) => {
288                        self.refresh_proxy_after_auth(&name, &tools, proxied).await;
289                        self.emit_server_statuses_changed().await;
290                        self.emit_tool_definitions_changed().await;
291                        self.emit_instructions_after_connect(&name, proxied).await;
292                    }
293                    Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
294                }
295            }
296            McpConnectOutcome::NeedsOAuth { config, error } => {
297                tracing::warn!("Server '{name}' needs OAuth: {error}");
298                self.register_record(&name, ServerState::NeedsOAuth, Some(config), proxied);
299                self.emit_server_statuses_changed().await;
300            }
301            McpConnectOutcome::Failed { error } => {
302                self.apply_authentication_failure(name, error.to_string()).await;
303            }
304        }
305    }
306
307    /// List all prompts from all connected MCP servers with namespacing
308    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
309        let futures: Vec<_> = self
310            .servers
311            .iter()
312            .filter_map(|(server_name, record)| {
313                let conn = record.connection()?;
314                conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
315                let server_name = server_name.clone();
316                let client = conn.client.clone();
317                Some(async move {
318                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
319                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
320                    })?;
321
322                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
323                        .prompts
324                        .into_iter()
325                        .map(|prompt| {
326                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
327                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
328                        })
329                        .collect();
330
331                    Ok::<_, McpError>(namespaced_prompts)
332                })
333            })
334            .collect();
335
336        let results = join_all(futures).await;
337        let mut all_prompts = Vec::new();
338        for result in results {
339            all_prompts.extend(result?);
340        }
341
342        Ok(all_prompts)
343    }
344
345    /// Get a specific prompt by namespaced name
346    pub async fn get_prompt(
347        &self,
348        namespaced_prompt_name: &str,
349        arguments: Option<serde_json::Map<String, serde_json::Value>>,
350    ) -> Result<rmcp::model::GetPromptResult> {
351        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
352            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
353
354        let server_conn =
355            self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
356
357        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
358        if let Some(args) = arguments {
359            request = request.with_arguments(args);
360        }
361
362        server_conn.client.get_prompt(request).await.map_err(|e| {
363            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
364        })
365    }
366
367    /// Shutdown all servers and wait for their tasks to complete
368    pub async fn shutdown(&mut self) {
369        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
370
371        for (server_name, record) in servers {
372            if let Some(conn) = record.into_connection()
373                && let Some(handle) = conn.server_task
374            {
375                drop(conn.client);
376
377                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
378                    Ok(Ok(())) => {
379                        tracing::info!("Server '{server_name}' shut down gracefully");
380                    }
381                    Ok(Err(e)) => {
382                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
383                    }
384                    Err(_) => {
385                        tracing::warn!("Server '{server_name}' shutdown timed out");
386                    }
387                }
388            }
389        }
390
391        self.server_order.clear();
392        self.proxy = None;
393    }
394
395    /// Shutdown a specific server by name
396    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
397        if let Some(record) = self.servers.remove(server_name)
398            && let Some(conn) = record.into_connection()
399            && let Some(handle) = conn.server_task
400        {
401            drop(conn.client);
402
403            match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
404                Ok(Ok(())) => {
405                    tracing::info!("Server '{server_name}' shut down gracefully");
406                }
407                Ok(Err(e)) => {
408                    tracing::warn!("Server '{server_name}' task panicked: {e:?}");
409                }
410                Err(_) => {
411                    tracing::warn!("Server '{server_name}' shutdown timed out");
412                }
413            }
414        }
415
416        Ok(())
417    }
418
419    /// Set the roots advertised to MCP servers.
420    ///
421    /// This updates the roots and sends notifications to all connected servers
422    /// that support the `roots/list_changed` notification.
423    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
424        {
425            let mut roots = self.roots.write().await;
426            *roots = new_roots;
427        }
428
429        self.notify_roots_changed().await;
430
431        Ok(())
432    }
433
434    async fn emit_server_statuses_changed(&self) {
435        self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses())).await;
436    }
437
438    async fn emit_tool_definitions_changed(&self) {
439        self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
440    }
441
442    async fn emit_instructions_after_connect(&self, server_name: &str, proxied: bool) {
443        if proxied {
444            if let Some((server, body)) = self.proxy_instructions() {
445                self.emit_event(McpClientEvent::ServerInstructionsUpdated { server, instructions: Some(body) }).await;
446            }
447            return;
448        }
449
450        if let Some(instructions) =
451            self.connection_for(server_name).and_then(|conn| conn.instructions.as_ref()).cloned()
452        {
453            self.emit_event(McpClientEvent::ServerInstructionsUpdated {
454                server: server_name.to_string(),
455                instructions: Some(instructions),
456            })
457            .await;
458        }
459    }
460
461    fn proxy_instructions(&self) -> Option<(String, String)> {
462        let proxy = self.proxy.as_ref()?;
463        let descriptions: Vec<(String, String)> = proxy
464            .members()
465            .iter()
466            .filter_map(|member| {
467                let conn = self.connection_for(member)?;
468                Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
469            })
470            .collect();
471        Some((proxy.name().to_string(), ToolProxy::build_instructions(proxy.tool_dir(), &descriptions)))
472    }
473
474    pub async fn emit_connection_ready(&self) {
475        self.emit_event(McpClientEvent::ConnectionReady(McpConnectionDetails {
476            tool_definitions: self.tool_definitions(),
477            instructions: self.server_instructions(),
478            server_statuses: self.server_statuses(),
479        }))
480        .await;
481    }
482
483    async fn emit_authentication_failed(&self, server: String, error: String) {
484        self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
485    }
486
487    async fn emit_event(&self, event: McpClientEvent) {
488        if let Err(e) = self.event_sender.send(event).await {
489            tracing::warn!("Failed to emit MCP client event: {e}");
490        }
491    }
492
493    fn connect_config(&self) -> Arc<ConnectConfig> {
494        Arc::new(ConnectConfig {
495            client_info: self.client_info.clone(),
496            event_sender: self.event_sender.clone(),
497            roots: Arc::clone(&self.roots),
498            oauth_handler_factory: self.oauth_handler_factory.clone(),
499            oauth_credential_store: self.oauth_credential_store.clone(),
500        })
501    }
502
503    fn proxy_tool_dir(&self) -> Result<PathBuf> {
504        self.aether_home
505            .as_ref()
506            .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
507            .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
508    }
509
510    async fn register_connection(
511        &mut self,
512        name: &str,
513        conn: McpServerConnection,
514        reauth_config: Option<StreamableHttpClientTransportConfig>,
515        proxied: bool,
516    ) -> Result<Vec<RmcpTool>> {
517        let tools = conn
518            .list_tools()
519            .await
520            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
521        self.apply_connected(name, conn, &tools, reauth_config, proxied);
522        Ok(tools)
523    }
524
525    fn apply_connected(
526        &mut self,
527        name: &str,
528        conn: McpServerConnection,
529        tools: &[RmcpTool],
530        reauth_config: Option<StreamableHttpClientTransportConfig>,
531        proxied: bool,
532    ) {
533        let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
534        let final_reauth = reauth_config.or(existing_reauth);
535        let tools = tools.iter().map(Tool::from).collect();
536
537        self.remember_server_order(name);
538        self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools, final_reauth, proxied));
539    }
540
541    fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
542        self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
543    }
544
545    async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
546        if !proxied {
547            return;
548        }
549
550        if let Some(proxy) = self.proxy.as_mut() {
551            proxy.add_member(name.to_string());
552        }
553
554        if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
555            && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
556        {
557            tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
558        }
559    }
560
561    fn remember_server_order(&mut self, name: &str) {
562        if !self.server_order.iter().any(|n| n == name) {
563            self.server_order.push(name.to_string());
564        }
565    }
566
567    async fn apply_authentication_failure(&mut self, name: String, error: String) {
568        self.set_state(&name, ServerState::Failed { error: error.clone() });
569        self.emit_server_statuses_changed().await;
570        self.emit_authentication_failed(name, error).await;
571    }
572
573    fn set_state(&mut self, name: &str, state: ServerState) {
574        self.remember_server_order(name);
575        match self.servers.get_mut(name) {
576            Some(record) => record.state = state,
577            None => {
578                self.servers.insert(name.to_string(), ServerRecord::new(state, None, false));
579            }
580        }
581    }
582
583    fn register_record(
584        &mut self,
585        name: &str,
586        state: ServerState,
587        reauth_config: Option<StreamableHttpClientTransportConfig>,
588        proxied: bool,
589    ) {
590        self.remember_server_order(name);
591        self.servers.insert(name.to_string(), ServerRecord::new(state, reauth_config, proxied));
592    }
593
594    fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
595        self.servers.get(server_name).and_then(ServerRecord::connection)
596    }
597
598    fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
599        self.connection_for(server_name).map(|conn| conn.client.clone())
600    }
601
602    fn is_routable_tool(&self, namespaced_tool_name: &str) -> bool {
603        if self.proxy.as_ref().is_some_and(|proxy| proxy.call_tool_name() == namespaced_tool_name) {
604            return true;
605        }
606        match split_on_server_name(namespaced_tool_name) {
607            Some((server_name, tool_name)) => {
608                self.servers.get(server_name).is_some_and(|record| !record.proxied && record.has_tool(tool_name))
609            }
610            None => false,
611        }
612    }
613
614    async fn notify_roots_changed(&self) {
615        for (server_name, record) in &self.servers {
616            if let Some(conn) = record.connection()
617                && let Err(e) = conn.client.notify_roots_list_changed().await
618            {
619                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
620            }
621        }
622    }
623}
624
625impl Drop for McpManager {
626    fn drop(&mut self) {
627        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
628        for (server_name, record) in servers {
629            if let Some(conn) = record.into_connection()
630                && let Some(handle) = conn.server_task
631            {
632                handle.abort();
633                tracing::warn!("Server '{server_name}' task aborted during cleanup");
634            }
635        }
636    }
637}
638
639/// Internal record holding all mutable state for a single MCP server.
640struct ServerRecord {
641    state: ServerState,
642    reauth_config: Option<StreamableHttpClientTransportConfig>,
643    proxied: bool,
644}
645
646enum ServerState {
647    Connecting,
648    Connected { connection: McpServerConnection, tools: Vec<Tool> },
649    Authenticating,
650    Failed { error: String },
651    NeedsOAuth,
652}
653
654impl From<&ServerState> for McpServerStatus {
655    fn from(state: &ServerState) -> Self {
656        match state {
657            ServerState::Connecting => Self::Connecting,
658            ServerState::Connected { tools, .. } => Self::Connected { tool_count: tools.len() },
659            ServerState::Authenticating => Self::Authenticating,
660            ServerState::Failed { error } => Self::Failed { error: error.clone() },
661            ServerState::NeedsOAuth => Self::NeedsOAuth,
662        }
663    }
664}
665
666impl ServerRecord {
667    fn new(state: ServerState, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
668        Self { state, reauth_config, proxied }
669    }
670
671    fn connected(
672        connection: McpServerConnection,
673        tools: Vec<Tool>,
674        reauth_config: Option<StreamableHttpClientTransportConfig>,
675        proxied: bool,
676    ) -> Self {
677        Self::new(ServerState::Connected { connection, tools }, reauth_config, proxied)
678    }
679
680    fn tools(&self) -> &[Tool] {
681        match &self.state {
682            ServerState::Connected { tools, .. } => tools,
683            ServerState::Connecting
684            | ServerState::Authenticating
685            | ServerState::Failed { .. }
686            | ServerState::NeedsOAuth => &[],
687        }
688    }
689
690    fn has_tool(&self, tool_name: &str) -> bool {
691        self.tools().iter().any(|tool| tool.name == tool_name)
692    }
693
694    fn connection(&self) -> Option<&McpServerConnection> {
695        match &self.state {
696            ServerState::Connected { connection, .. } => Some(connection),
697            ServerState::Connecting
698            | ServerState::Authenticating
699            | ServerState::Failed { .. }
700            | ServerState::NeedsOAuth => None,
701        }
702    }
703
704    fn into_connection(self) -> Option<McpServerConnection> {
705        match self.state {
706            ServerState::Connected { connection, .. } => Some(connection),
707            ServerState::Connecting
708            | ServerState::Authenticating
709            | ServerState::Failed { .. }
710            | ServerState::NeedsOAuth => None,
711        }
712    }
713
714    fn auth_capability(&self) -> McpServerAuthCapability {
715        if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
716    }
717
718    fn can_authenticate(&self) -> bool {
719        self.reauth_config.is_some()
720    }
721
722    fn status(&self) -> McpServerStatus {
723        (&self.state).into()
724    }
725
726    fn status_entry(&self, name: &str) -> McpServerStatusEntry {
727        McpServerStatusEntry::new(name, self.status())
728            .with_auth_capability(self.auth_capability())
729            .with_proxied(self.proxied)
730    }
731}
732
733#[cfg(test)]
734mod tests {
735    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, ServerState};
736    use crate::client::OAuthHandlerFactory;
737    use crate::client::config::{McpServer, McpTransport};
738    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
739    use crate::status::McpServerAuthCapability;
740    use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
741    use futures::future::BoxFuture;
742    use rmcp::{
743        Json, RoleServer, ServerHandler,
744        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
745        model::{Implementation, ServerCapabilities, ServerInfo},
746        service::DynService,
747        tool, tool_handler, tool_router,
748        transport::streamable_http_client::StreamableHttpClientTransportConfig,
749    };
750    use schemars::JsonSchema;
751    use serde::{Deserialize, Serialize};
752    use std::{
753        io,
754        sync::{Arc, Mutex},
755    };
756    use tokio::sync::mpsc;
757
758    #[derive(Clone)]
759    struct TestServer {
760        tool_router: ToolRouter<Self>,
761    }
762
763    #[tool_handler(router = self.tool_router)]
764    impl ServerHandler for TestServer {
765        fn get_info(&self) -> ServerInfo {
766            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
767                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
768                .with_instructions("Test server instructions")
769        }
770    }
771
772    impl Default for TestServer {
773        fn default() -> Self {
774            Self { tool_router: Self::tool_router() }
775        }
776    }
777
778    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
779    struct EchoRequest {
780        value: String,
781    }
782
783    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
784    struct EchoResult {
785        value: String,
786    }
787
788    #[tool_router]
789    impl TestServer {
790        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
791            Box::new(self)
792        }
793
794        #[tool(description = "Returns the provided value")]
795        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
796            let Parameters(EchoRequest { value }) = request;
797            Json(EchoResult { value })
798        }
799    }
800
801    #[derive(Clone)]
802    struct SharedWriter(Arc<Mutex<Vec<u8>>>);
803
804    impl io::Write for SharedWriter {
805        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
806            self.0.lock().unwrap().extend_from_slice(buf);
807            Ok(buf.len())
808        }
809
810        fn flush(&mut self) -> io::Result<()> {
811            Ok(())
812        }
813    }
814
815    struct TestOAuthHandler;
816
817    impl OAuthHandler for TestOAuthHandler {
818        fn redirect_uri(&self) -> &'static str {
819            "http://127.0.0.1:0/oauth2callback"
820        }
821
822        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
823            Box::pin(async { Err(OAuthError::UserCancelled) })
824        }
825    }
826
827    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
828        Arc::new(|_ctx| Ok(Arc::new(TestOAuthHandler)))
829    }
830
831    #[tokio::test]
832    async fn authenticate_server_task_rejects_record_without_reauth_config() {
833        let (event_sender, _event_receiver) = mpsc::channel(1);
834        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
835        manager.register_record("public", ServerState::Connecting, None, false);
836
837        let error = match manager.authenticate_server_task("public").await {
838            Ok(_) => panic!("non-OAuth server should be rejected"),
839            Err(error) => error.to_string(),
840        };
841        assert!(error.contains("not OAuth-authenticatable"));
842    }
843
844    #[tokio::test]
845    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
846        let (event_sender, mut event_receiver) = mpsc::channel(2);
847        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
848        manager.register_record(
849            "remote",
850            ServerState::NeedsOAuth,
851            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
852            false,
853        );
854
855        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
856
857        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
858        let event = event_receiver.recv().await.expect("status change event");
859        let McpClientEvent::ServerStatusesChanged(servers) = event else {
860            panic!("expected ServerStatusesChanged");
861        };
862        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
863        assert!(matches!(status.status, McpServerStatus::Authenticating));
864        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
865    }
866
867    #[tokio::test]
868    async fn apply_connection_attempt_failure_allows_retry() {
869        let (event_sender, mut event_receiver) = mpsc::channel(2);
870        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
871        manager.register_record(
872            "remote",
873            ServerState::NeedsOAuth,
874            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
875            false,
876        );
877        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
878        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");
879
880        manager
881            .apply_connection_attempt(McpConnectAttempt {
882                name: "remote".to_string(),
883                proxied: false,
884                outcome: McpConnectOutcome::Failed {
885                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
886                },
887            })
888            .await;
889
890        let event = event_receiver.recv().await.expect("status change event");
891        let McpClientEvent::ServerStatusesChanged(servers) = event else {
892            panic!("expected ServerStatusesChanged");
893        };
894        let auth_event = event_receiver.recv().await.expect("authentication failure event");
895        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
896            panic!("expected AuthenticationFailed");
897        };
898        assert_eq!(server, "remote");
899        assert!(error.contains("boom"));
900
901        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
902        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
903        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
904        assert!(manager.authenticate_server_task("remote").await.is_ok());
905    }
906
907    #[test]
908    fn status_entries_are_derived_from_reauth_config() {
909        let (event_sender, _event_receiver) = mpsc::channel(1);
910        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
911
912        manager.register_record(
913            "with-oauth",
914            ServerState::Connecting,
915            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
916            false,
917        );
918        manager.register_record("without-oauth", ServerState::Connecting, None, false);
919        manager.register_record(
920            "needs-oauth",
921            ServerState::NeedsOAuth,
922            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
923            false,
924        );
925
926        let statuses = manager.server_statuses();
927        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
928        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
929        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
930
931        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
932        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
933        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
934    }
935
936    #[tokio::test]
937    async fn register_pending_marks_every_server_connecting_and_emits_status() {
938        let (event_sender, mut event_receiver) = mpsc::channel(32);
939        let mut manager = McpManager::new(event_sender, None);
940
941        let servers = vec![
942            McpServer::new("alpha", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
943            McpServer::new("beta", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
944        ];
945
946        let returned = manager.register_pending(servers).await.unwrap();
947        assert_eq!(returned.iter().map(|s| s.name.as_str()).collect::<Vec<_>>(), vec!["alpha", "beta"]);
948
949        let statuses = manager.server_statuses();
950        assert_eq!(statuses.len(), 2);
951        assert!(matches!(statuses.iter().find(|s| s.name == "alpha").unwrap().status, McpServerStatus::Connecting));
952        assert!(matches!(statuses.iter().find(|s| s.name == "beta").unwrap().status, McpServerStatus::Connecting));
953        assert!(statuses.iter().find(|s| s.name == "beta").unwrap().proxied);
954
955        let event = event_receiver.try_recv().expect("expected initial ServerStatusesChanged emission");
956        let McpClientEvent::ServerStatusesChanged(emitted) = event else {
957            panic!("expected ServerStatusesChanged, got {event:?}");
958        };
959        assert_eq!(emitted.iter().map(|s| s.name.as_str()).collect::<Vec<_>>(), vec!["alpha", "beta"]);
960    }
961
962    #[tokio::test]
963    async fn apply_connection_attempt_emits_instructions_updated_after_connect() {
964        let (event_sender, mut event_receiver) = mpsc::channel(32);
965        let mut manager = McpManager::new(event_sender, None);
966
967        let servers =
968            vec![McpServer::new("test", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false)];
969        manager.add_mcps(servers).await.unwrap();
970
971        let mut update_for_test = None;
972        while let Ok(event) = event_receiver.try_recv() {
973            if let McpClientEvent::ServerInstructionsUpdated { server, instructions } = event
974                && server == "test"
975            {
976                update_for_test = Some(instructions);
977            }
978        }
979        let instructions = update_for_test.expect("expected ServerInstructionsUpdated for 'test'");
980        assert!(instructions.is_some(), "TestServer publishes instructions, so update should carry Some(_)");
981    }
982
983    #[tokio::test]
984    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
985        let (event_sender, _event_receiver) = mpsc::channel(32);
986        let mut manager = McpManager::new(event_sender, None);
987        manager
988            .add_mcps(vec![
989                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
990                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
991            ])
992            .await
993            .unwrap();
994
995        let statuses = manager.server_statuses();
996        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
997        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
998        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
999        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
1000    }
1001
1002    #[tokio::test]
1003    async fn tool_definitions_drop_when_a_server_shuts_down() {
1004        let (event_sender, _event_receiver) = mpsc::channel(32);
1005        let mut manager = McpManager::new(event_sender, None);
1006        manager
1007            .add_mcps(vec![
1008                McpServer::new("git", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
1009                McpServer::new("github", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
1010            ])
1011            .await
1012            .unwrap();
1013
1014        let names =
1015            |manager: &McpManager| manager.tool_definitions().into_iter().map(|tool| tool.name).collect::<Vec<_>>();
1016        assert!(names(&manager).contains(&"git__echo".to_string()));
1017        assert!(names(&manager).contains(&"github__echo".to_string()));
1018
1019        manager.shutdown_server("git").await.unwrap();
1020
1021        assert!(!names(&manager).iter().any(|name| name.starts_with("git__")));
1022        assert!(names(&manager).contains(&"github__echo".to_string()));
1023    }
1024
1025    #[tokio::test]
1026    async fn drop_logs_cleanup_abort_with_tracing() {
1027        let (event_sender, _event_receiver) = mpsc::channel(32);
1028        let mut manager = McpManager::new(event_sender, None);
1029        manager
1030            .add_mcps(vec![McpServer::new(
1031                "test",
1032                McpTransport::InMemory { server: TestServer::default().into_dyn() },
1033                false,
1034            )])
1035            .await
1036            .unwrap();
1037
1038        let output = Arc::new(Mutex::new(Vec::new()));
1039        let subscriber = tracing_subscriber::fmt()
1040            .with_ansi(false)
1041            .without_time()
1042            .with_writer({
1043                let output = Arc::clone(&output);
1044                move || SharedWriter(Arc::clone(&output))
1045            })
1046            .finish();
1047
1048        tracing::subscriber::with_default(subscriber, || {
1049            drop(manager);
1050        });
1051
1052        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
1053        assert!(logs.contains("Server 'test' task aborted during cleanup"));
1054    }
1055}