Skip to main content

mcp_utils/client/
manager.rs

1use llm::ToolDefinition;
2
3use super::{
4    McpError, Result,
5    config::McpServer,
6    connection::{
7        ConnectContext, McpConnectAttempt, McpConnectOutcome, McpServerConnection, ServerInstructions, Tool,
8        authenticate_http, connect_server,
9    },
10    mcp_client::McpClient,
11    naming::{create_namespaced_tool_name, split_on_server_name},
12    oauth::OAuthHandler,
13    tool_proxy::ToolProxy,
14};
15use futures::future::join_all;
16use rmcp::{
17    RoleClient,
18    model::{
19        CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20        ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21    },
22    service::RunningService,
23    transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Server name under which the tool proxy is registered. `McpManager::add_mcps`
/// rejects a configured server with this name whenever any server is proxied.
pub const DEFAULT_PROXY_NAME: &str = "proxy";

/// Shared, thread-safe factory that yields an [`OAuthHandler`] (or an error)
/// on each call.
pub type OAuthHandlerFactory = Arc<dyn Fn() -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
/// An elicitation request forwarded to the host, paired with the one-shot
/// channel on which the host sends back its result.
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the server the request is associated with.
    pub server_name: String,
    /// The elicitation parameters, as modelled by `rmcp`.
    pub request: CreateElicitationRequestParams,
    /// One-shot sender for the [`CreateElicitationResult`] reply.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
45
/// An elicitation answer: the chosen action plus optional JSON content.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// The action selected for the elicitation.
    pub action: ElicitationAction,
    /// Optional JSON payload accompanying the action.
    pub content: Option<Value>,
}
51
/// Payload of [`McpClientEvent::UrlElicitationComplete`]: identifies an
/// elicitation by server name and elicitation id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Name of the server the elicitation belongs to.
    pub server_name: String,
    /// Identifier of the elicitation.
    pub elicitation_id: String,
}
57
/// Events emitted by MCP clients that require attention from the host
/// (e.g. the relay or TUI). Flows through a single channel from `McpManager`
/// to the consumer.
#[derive(Debug)]
pub enum McpClientEvent {
    /// An elicitation request, carrying the one-shot sender for the reply.
    Elicitation(ElicitationRequest),
    /// Identifies a URL elicitation by server name and elicitation id.
    /// NOTE(review): presumably signals that the URL flow finished — confirm at the emit site.
    UrlElicitationComplete(UrlElicitationCompleteParams),
    /// Snapshot of every server's status entry, sent after a status change.
    ServerStatusesChanged(Vec<McpServerStatusEntry>),
    /// Snapshot of the current tool definitions, sent after a connection
    /// attempt was applied successfully.
    ToolDefinitionsChanged(Vec<ToolDefinition>),
    /// An authentication attempt for `server` failed; `error` is the rendered message.
    AuthenticationFailed { server: String, error: String },
}
69
/// Manages connections to multiple MCP servers and their tools
pub struct McpManager {
    /// Per-server records keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in the order they were first registered; drives the order
    /// of `server_statuses`.
    server_order: Vec<String>,
    /// Directly callable tools keyed by namespaced tool name (non-proxied
    /// server tools plus the proxy's call tool).
    tools: HashMap<String, Tool>,
    /// Tool definitions handed out by `tool_definitions()`.
    tool_definitions: Vec<ToolDefinition>,
    /// Tool proxy, present once `add_mcps` has seen at least one proxied server.
    proxy: Option<ToolProxy>,
    /// Optional home directory used to locate the proxy tool directory.
    aether_home: Option<PathBuf>,
    /// Client info (capabilities and implementation) passed along when connecting.
    client_info: ClientInfo,
    /// Channel on which `McpClientEvent`s are sent to the host.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with all MCP clients
    roots: Arc<RwLock<Vec<Root>>>,
    /// Factory for OAuth handlers; authentication is refused when absent.
    oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Cached status entries, rebuilt by `refresh_status_entries`.
    server_statuses: Vec<McpServerStatusEntry>,
}
85
86impl McpManager {
87    pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
88        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
89        if let Some(elicitation) = capabilities.elicitation.as_mut() {
90            elicitation.form = Some(FormElicitationCapability::default());
91            elicitation.url = Some(UrlElicitationCapability::default());
92        }
93
94        Self {
95            servers: HashMap::new(),
96            server_order: Vec::new(),
97            tools: HashMap::new(),
98            tool_definitions: Vec::new(),
99            proxy: None,
100            aether_home: None,
101            client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
102            event_sender,
103            roots: Arc::new(RwLock::new(Vec::new())),
104            oauth_handler_factory,
105            server_statuses: Vec::new(),
106        }
107    }
108
    /// Builder-style setter for the home directory used to locate the proxy
    /// tool directory (see `proxy_tool_dir`). Consumes and returns `self`.
    pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
        self.aether_home = Some(aether_home.into());
        self
    }
113
    /// Connects to every given server concurrently and records the outcome of
    /// each attempt; when at least one server is proxied, also (re)creates the
    /// tool proxy.
    ///
    /// # Errors
    ///
    /// - A server is named [`DEFAULT_PROXY_NAME`] while any server is proxied.
    /// - The proxy tool directory cannot be resolved or cleaned.
    /// - Listing tools fails for a server that connected successfully.
    ///
    /// Connection failures and servers that need OAuth are not errors; they
    /// are recorded as status entries instead.
    pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
        let has_proxy = servers.iter().any(|server| server.proxy);
        // The proxy is registered under DEFAULT_PROXY_NAME, so a real server
        // with that name would clash with it.
        if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
            return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
        }

        // Every server configured as proxied, whether or not it ends up connected.
        let proxied_members: HashSet<String> =
            servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
        // The tool directory is cleaned before any connection is attempted.
        let proxy_tool_dir = if has_proxy {
            let dir = self.proxy_tool_dir()?;
            ToolProxy::clean_dir(&dir).await?;
            Some(dir)
        } else {
            None
        };

        let ctx = self.connect_context();
        let attempts = join_all(servers.into_iter().map(|server| connect_server(server, &ctx))).await;

        // Proxied servers that actually connected; only these get tool files written.
        let mut connected_proxied = Vec::new();
        for McpConnectAttempt { name, proxied, outcome } in attempts {
            match outcome {
                McpConnectOutcome::Connected { conn, reauth_config } => {
                    // NOTE(review): `?` returns early here, so the remaining
                    // attempts are dropped and the proxy is not registered —
                    // confirm this is intended.
                    self.register_connection(&name, conn, reauth_config, proxied).await?;
                    if proxied {
                        connected_proxied.push(name);
                    }
                }
                McpConnectOutcome::NeedsOAuth { config, error } => {
                    tracing::warn!("Server '{name}' needs OAuth: {error}");
                    // Keep the transport config so the server can be authenticated later.
                    self.register_record(&name, McpServerStatus::NeedsOAuth, Some(config), proxied);
                }
                McpConnectOutcome::Failed { error } => {
                    tracing::warn!("Failed to connect to MCP server '{name}': {error}");
                    // An existing record for this name is left untouched.
                    if !self.servers.contains_key(&name) {
                        self.register_record(
                            &name,
                            McpServerStatus::Failed { error: error.to_string() },
                            None,
                            proxied,
                        );
                    }
                }
            }
        }

        if let Some(tool_dir) = proxy_tool_dir {
            self.write_proxy_tool_files(&connected_proxied, &tool_dir).await;
            self.register_proxy(tool_dir, proxied_members);
        }

        Ok(())
    }
167
168    pub fn get_client_for_tool(
169        &self,
170        namespaced_tool_name: &str,
171        arguments_json: &str,
172    ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
173        if !self.tools.contains_key(namespaced_tool_name) {
174            return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
175        }
176
177        let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
178            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
179
180        if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
181            let call = proxy.resolve_call(arguments_json)?;
182            let conn = self.connection_for(&call.server).ok_or_else(|| {
183                McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
184            })?;
185            let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
186            return Ok((conn.client.clone(), params));
187        }
188
189        let client =
190            self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
191
192        let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
193        let mut params = CallToolRequestParams::new(tool_name.to_string());
194        if let Some(args) = arguments {
195            params = params.with_arguments(args);
196        }
197
198        Ok((client, params))
199    }
200
201    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
202        self.tool_definitions.clone()
203    }
204
205    pub fn server_instructions(&self) -> Vec<ServerInstructions> {
206        let mut instructions: Vec<ServerInstructions> = self
207            .servers
208            .iter()
209            .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
210            .filter_map(|(name, record)| {
211                record
212                    .connection
213                    .as_ref()
214                    .and_then(|conn| conn.instructions.as_ref())
215                    .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
216            })
217            .collect();
218
219        if let Some(proxy) = &self.proxy {
220            let descriptions: Vec<(String, String)> = proxy
221                .members()
222                .iter()
223                .filter_map(|member| {
224                    let conn = self.connection_for(member)?;
225                    Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
226                })
227                .collect();
228            instructions.push(ServerInstructions {
229                server_name: proxy.name().to_string(),
230                instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
231            });
232        }
233
234        instructions
235    }
236
237    pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
238        &self.server_statuses
239    }
240
241    pub async fn authenticate_server_task(
242        &mut self,
243        name: &str,
244    ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
245        let record = self
246            .servers
247            .get(name)
248            .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
249        if !record.can_authenticate() {
250            return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
251        }
252        if matches!(record.status, McpServerStatus::Authenticating) {
253            return Err(McpError::ConnectionFailed(format!("server '{name}' is already authenticating")));
254        }
255
256        let oauth_handler_factory = self
257            .oauth_handler_factory
258            .clone()
259            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
260        let name = name.to_string();
261        let config = record.reauth_config.clone().expect("checked above");
262        let client_info = self.client_info.clone();
263        let event_sender = self.event_sender.clone();
264        let roots = Arc::clone(&self.roots);
265        let proxied = record.proxied;
266
267        self.set_status(&name, McpServerStatus::Authenticating);
268        self.emit_server_statuses_changed().await;
269
270        Ok(async move {
271            authenticate_http(name, config, client_info, event_sender, roots, oauth_handler_factory, proxied).await
272        })
273    }
274
275    pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
276        let McpConnectAttempt { name, proxied, outcome } = attempt;
277        match outcome {
278            McpConnectOutcome::Connected { conn, reauth_config } => {
279                match self.register_connection(&name, conn, reauth_config, proxied).await {
280                    Ok(tools) => {
281                        self.refresh_proxy_after_auth(&name, &tools, proxied).await;
282                        self.emit_server_statuses_changed().await;
283                        self.emit_tool_definitions_changed().await;
284                    }
285                    Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
286                }
287            }
288            McpConnectOutcome::Failed { error } => {
289                self.apply_authentication_failure(name, error.to_string()).await;
290            }
291            McpConnectOutcome::NeedsOAuth { .. } => {
292                self.apply_authentication_failure(name, "internal error: auth task returned NeedsOAuth".to_string())
293                    .await;
294            }
295        }
296    }
297
298    /// List all prompts from all connected MCP servers with namespacing
299    pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
300        let futures: Vec<_> = self
301            .servers
302            .iter()
303            .filter_map(|(server_name, record)| {
304                let conn = record.connection.as_ref()?;
305                conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
306                let server_name = server_name.clone();
307                let client = conn.client.clone();
308                Some(async move {
309                    let prompts_response = client.list_prompts(None).await.map_err(|e| {
310                        McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
311                    })?;
312
313                    let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
314                        .prompts
315                        .into_iter()
316                        .map(|prompt| {
317                            let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
318                            rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
319                        })
320                        .collect();
321
322                    Ok::<_, McpError>(namespaced_prompts)
323                })
324            })
325            .collect();
326
327        let results = join_all(futures).await;
328        let mut all_prompts = Vec::new();
329        for result in results {
330            all_prompts.extend(result?);
331        }
332
333        Ok(all_prompts)
334    }
335
336    /// Get a specific prompt by namespaced name
337    pub async fn get_prompt(
338        &self,
339        namespaced_prompt_name: &str,
340        arguments: Option<serde_json::Map<String, serde_json::Value>>,
341    ) -> Result<rmcp::model::GetPromptResult> {
342        let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
343            .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
344
345        let server_conn =
346            self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
347
348        let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
349        if let Some(args) = arguments {
350            request = request.with_arguments(args);
351        }
352
353        server_conn.client.get_prompt(request).await.map_err(|e| {
354            McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
355        })
356    }
357
358    /// Shutdown all servers and wait for their tasks to complete
359    pub async fn shutdown(&mut self) {
360        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
361
362        for (server_name, record) in servers {
363            if let Some(conn) = record.connection
364                && let Some(handle) = conn.server_task
365            {
366                drop(conn.client);
367
368                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
369                    Ok(Ok(())) => {
370                        tracing::info!("Server '{server_name}' shut down gracefully");
371                    }
372                    Ok(Err(e)) => {
373                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
374                    }
375                    Err(_) => {
376                        tracing::warn!("Server '{server_name}' shutdown timed out");
377                    }
378                }
379            }
380        }
381
382        self.tools.clear();
383        self.tool_definitions.clear();
384        self.proxy = None;
385    }
386
387    /// Shutdown a specific server by name
388    pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
389        let record = self.servers.remove(server_name);
390
391        if let Some(record) = record {
392            if let Some(conn) = record.connection
393                && let Some(handle) = conn.server_task
394            {
395                drop(conn.client);
396
397                match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
398                    Ok(Ok(())) => {
399                        tracing::info!("Server '{server_name}' shut down gracefully");
400                    }
401                    Ok(Err(e)) => {
402                        tracing::warn!("Server '{server_name}' task panicked: {e:?}");
403                    }
404                    Err(_) => {
405                        tracing::warn!("Server '{server_name}' shutdown timed out");
406                    }
407                }
408            }
409
410            self.remove_registered_tools_for_server(server_name);
411            self.refresh_status_entries();
412        }
413
414        Ok(())
415    }
416
417    /// Set the roots advertised to MCP servers.
418    ///
419    /// This updates the roots and sends notifications to all connected servers
420    /// that support the `roots/list_changed` notification.
421    pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
422        {
423            let mut roots = self.roots.write().await;
424            *roots = new_roots;
425        }
426
427        self.notify_roots_changed().await;
428
429        Ok(())
430    }
431
432    async fn emit_server_statuses_changed(&self) {
433        self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses().to_vec())).await;
434    }
435
436    async fn emit_tool_definitions_changed(&self) {
437        self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
438    }
439
440    async fn emit_authentication_failed(&self, server: String, error: String) {
441        self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
442    }
443
444    async fn emit_event(&self, event: McpClientEvent) {
445        if let Err(e) = self.event_sender.send(event).await {
446            tracing::warn!("Failed to emit MCP client event: {e}");
447        }
448    }
449
    /// Bundles borrowed references to the client info, event sender, roots and
    /// optional OAuth handler factory into the context passed to `connect_server`.
    fn connect_context(&self) -> ConnectContext<'_> {
        ConnectContext {
            client_info: &self.client_info,
            event_sender: &self.event_sender,
            roots: &self.roots,
            oauth_handler_factory: self.oauth_handler_factory.as_ref(),
        }
    }
458
459    fn proxy_tool_dir(&self) -> Result<PathBuf> {
460        self.aether_home
461            .as_ref()
462            .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
463            .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
464    }
465
466    async fn register_connection(
467        &mut self,
468        name: &str,
469        conn: McpServerConnection,
470        reauth_config: Option<StreamableHttpClientTransportConfig>,
471        proxied: bool,
472    ) -> Result<Vec<RmcpTool>> {
473        let tools = conn
474            .list_tools()
475            .await
476            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
477        self.apply_connected(name, conn, &tools, reauth_config, proxied);
478        Ok(tools)
479    }
480
481    fn apply_connected(
482        &mut self,
483        name: &str,
484        conn: McpServerConnection,
485        tools: &[RmcpTool],
486        reauth_config: Option<StreamableHttpClientTransportConfig>,
487        proxied: bool,
488    ) {
489        self.remove_registered_tools_for_server(name);
490
491        let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
492        let final_reauth = reauth_config.or(existing_reauth);
493
494        for rmcp_tool in tools {
495            let tool_name = rmcp_tool.name.to_string();
496            let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
497            let tool = Tool::from(rmcp_tool);
498
499            if !proxied {
500                self.tool_definitions.push(ToolDefinition {
501                    name: namespaced_tool_name.clone(),
502                    description: tool.description.clone(),
503                    parameters: tool.parameters.to_string(),
504                    server: Some(name.to_string()),
505                });
506                self.tools.insert(namespaced_tool_name, tool);
507            }
508        }
509
510        self.remember_server_order(name);
511        self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth, proxied));
512        self.refresh_status_entries();
513    }
514
515    fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
516        self.remove_registered_tools_for_server(DEFAULT_PROXY_NAME);
517        let call_tool_def = ToolProxy::call_tool_definition(DEFAULT_PROXY_NAME);
518        self.tools.insert(
519            call_tool_def.name.clone(),
520            Tool {
521                description: call_tool_def.description.clone(),
522                parameters: serde_json::from_str(&call_tool_def.parameters)
523                    .unwrap_or(Value::Object(serde_json::Map::default())),
524            },
525        );
526        self.tool_definitions.push(call_tool_def);
527
528        self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
529    }
530
531    async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
532        if !proxied {
533            return;
534        }
535
536        if let Some(proxy) = self.proxy.as_mut() {
537            proxy.add_member(name.to_string());
538        }
539
540        if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
541            && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
542        {
543            tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
544        }
545    }
546
547    async fn write_proxy_tool_files(&self, connected_proxied: &[String], tool_dir: &std::path::Path) {
548        let writes = connected_proxied.iter().filter_map(|name| {
549            let client = self.client_for_server(name)?;
550            let dir = tool_dir.to_path_buf();
551            let name = name.clone();
552            Some(async move {
553                if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
554                    tracing::warn!("Failed to write tool files for proxied server '{name}': {e}");
555                }
556            })
557        });
558        join_all(writes).await;
559    }
560
561    fn refresh_status_entries(&mut self) {
562        self.server_statuses = self
563            .server_order
564            .iter()
565            .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
566            .collect();
567    }
568
569    fn remember_server_order(&mut self, name: &str) {
570        if !self.server_order.iter().any(|n| n == name) {
571            self.server_order.push(name.to_string());
572        }
573    }
574
575    async fn apply_authentication_failure(&mut self, name: String, error: String) {
576        self.set_status(&name, McpServerStatus::Failed { error: error.clone() });
577        self.emit_server_statuses_changed().await;
578        self.emit_authentication_failed(name, error).await;
579    }
580
581    fn set_status(&mut self, name: &str, status: McpServerStatus) {
582        self.remember_server_order(name);
583        let record =
584            self.servers.entry(name.to_string()).or_insert_with(|| ServerRecord::new(status.clone(), None, false));
585        record.status = status;
586        self.refresh_status_entries();
587    }
588
589    fn register_record(
590        &mut self,
591        name: &str,
592        status: McpServerStatus,
593        reauth_config: Option<StreamableHttpClientTransportConfig>,
594        proxied: bool,
595    ) {
596        self.remember_server_order(name);
597        self.servers.insert(name.to_string(), ServerRecord::new(status, reauth_config, proxied));
598        self.refresh_status_entries();
599    }
600
601    fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
602        self.servers.get(server_name).and_then(|record| record.connection.as_ref())
603    }
604
605    fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
606        self.connection_for(server_name).map(|conn| conn.client.clone())
607    }
608
609    fn remove_registered_tools_for_server(&mut self, server_name: &str) {
610        let prefix = format!("{server_name}__");
611        self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
612        self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
613    }
614
615    async fn notify_roots_changed(&self) {
616        for (server_name, record) in &self.servers {
617            if let Some(conn) = &record.connection
618                && let Err(e) = conn.client.notify_roots_list_changed().await
619            {
620                tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
621            }
622        }
623    }
624}
625
626impl Drop for McpManager {
627    fn drop(&mut self) {
628        let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
629        for (server_name, record) in servers {
630            if let Some(conn) = record.connection
631                && let Some(handle) = conn.server_task
632            {
633                handle.abort();
634                tracing::warn!("Server '{server_name}' task aborted during cleanup");
635            }
636        }
637    }
638}
639
/// Internal record holding all mutable state for a single MCP server.
struct ServerRecord {
    /// Live connection; `None` for records created via `ServerRecord::new`.
    connection: Option<McpServerConnection>,
    /// Most recently recorded status.
    status: McpServerStatus,
    /// HTTP transport config kept for OAuth (re)authentication; its presence
    /// is what marks the server as OAuth-capable.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
    /// Whether the server's tools are reached through the tool proxy rather
    /// than being registered directly.
    proxied: bool,
}
647
648impl ServerRecord {
649    fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
650        Self { connection: None, status, reauth_config, proxied }
651    }
652
653    fn connected(
654        connection: McpServerConnection,
655        tool_count: usize,
656        reauth_config: Option<StreamableHttpClientTransportConfig>,
657        proxied: bool,
658    ) -> Self {
659        Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config, proxied }
660    }
661
662    fn auth_capability(&self) -> McpServerAuthCapability {
663        if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
664    }
665
666    fn can_authenticate(&self) -> bool {
667        self.reauth_config.is_some()
668    }
669
670    fn status_entry(&self, name: &str) -> McpServerStatusEntry {
671        McpServerStatusEntry::new(name, self.status.clone())
672            .with_auth_capability(self.auth_capability())
673            .with_proxied(self.proxied)
674    }
675}
676
677#[cfg(test)]
678mod tests {
679    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, Tool};
680    use crate::client::OAuthHandlerFactory;
681    use crate::client::config::{McpServer, McpTransport};
682    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
683    use crate::client::oauth::{OAuthCallback, OAuthError, OAuthHandler};
684    use crate::status::McpServerAuthCapability;
685    use futures::future::BoxFuture;
686    use llm::ToolDefinition;
687    use rmcp::{
688        Json, RoleServer, ServerHandler,
689        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
690        model::{Implementation, ServerCapabilities, ServerInfo},
691        service::DynService,
692        tool, tool_handler, tool_router,
693        transport::streamable_http_client::StreamableHttpClientTransportConfig,
694    };
695    use schemars::JsonSchema;
696    use serde::{Deserialize, Serialize};
697    use serde_json::json;
698    use std::{
699        io,
700        sync::{Arc, Mutex},
701    };
702    use tokio::sync::mpsc;
703
    // Minimal MCP server fixture; its tool router is built in `Default::default`.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }
708
    // Reports the tools capability and a "test-server" 0.1.0 implementation
    // with a short description.
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }
716
    // Builds the fixture with the router produced by `Self::tool_router()`.
    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }
722
    // Input of the `echo` tool. Plain comments are used on purpose so nothing
    // is added to the derived JSON schema.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }
727
    // Output of the `echo` tool; carries the request's value back unchanged.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }
732
    // Tool implementations of the fixture server.
    #[tool_router]
    impl TestServer {
        // Boxes the server as a dynamically typed service.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        // Echo tool: returns the request's `value` unchanged.
        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }
745
746    #[derive(Clone)]
747    struct SharedWriter(Arc<Mutex<Vec<u8>>>);
748
749    impl io::Write for SharedWriter {
750        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
751            self.0.lock().unwrap().extend_from_slice(buf);
752            Ok(buf.len())
753        }
754
755        fn flush(&mut self) -> io::Result<()> {
756            Ok(())
757        }
758    }
759
    // OAuth handler stub whose authorization always fails with `UserCancelled`.
    struct TestOAuthHandler;

    impl OAuthHandler for TestOAuthHandler {
        // Fixed loopback redirect URI.
        fn redirect_uri(&self) -> &'static str {
            "http://127.0.0.1:0/oauth2callback"
        }

        // Ignores the authorization URL and reports a cancellation.
        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
            Box::pin(async { Err(OAuthError::UserCancelled) })
        }
    }
771
    // Factory that hands out a new `TestOAuthHandler` on every call.
    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
        Arc::new(|| Ok(Arc::new(TestOAuthHandler)))
    }
775
776    #[tokio::test]
777    async fn authenticate_server_task_rejects_record_without_reauth_config() {
778        let (event_sender, _event_receiver) = mpsc::channel(1);
779        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
780        manager.register_record("public", McpServerStatus::Connected { tool_count: 1 }, None, false);
781
782        let error = match manager.authenticate_server_task("public").await {
783            Ok(_) => panic!("non-OAuth server should be rejected"),
784            Err(error) => error.to_string(),
785        };
786        assert!(error.contains("not OAuth-authenticatable"));
787    }
788
789    #[tokio::test]
790    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
791        let (event_sender, mut event_receiver) = mpsc::channel(2);
792        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
793        manager.register_record(
794            "remote",
795            McpServerStatus::NeedsOAuth,
796            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
797            false,
798        );
799
800        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
801
802        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
803        let event = event_receiver.recv().await.expect("status change event");
804        let McpClientEvent::ServerStatusesChanged(servers) = event else {
805            panic!("expected ServerStatusesChanged");
806        };
807        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
808        assert!(matches!(status.status, McpServerStatus::Authenticating));
809        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
810    }
811
812    #[tokio::test]
813    async fn authenticate_server_task_rejects_duplicate_same_server_while_in_flight() {
814        let (event_sender, _event_receiver) = mpsc::channel(1);
815        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
816        manager.register_record(
817            "remote",
818            McpServerStatus::NeedsOAuth,
819            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
820            false,
821        );
822
823        let _task = manager.authenticate_server_task("remote").await.expect("first auth should start");
824        let error = match manager.authenticate_server_task("remote").await {
825            Ok(_) => panic!("duplicate auth should be rejected"),
826            Err(error) => error.to_string(),
827        };
828
829        assert!(error.contains("already authenticating"));
830    }
831
832    #[tokio::test]
833    async fn apply_connection_attempt_failure_allows_retry() {
834        let (event_sender, mut event_receiver) = mpsc::channel(2);
835        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
836        manager.register_record(
837            "remote",
838            McpServerStatus::NeedsOAuth,
839            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
840            false,
841        );
842        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
843        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");
844
845        manager
846            .apply_connection_attempt(McpConnectAttempt {
847                name: "remote".to_string(),
848                proxied: false,
849                outcome: McpConnectOutcome::Failed {
850                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
851                },
852            })
853            .await;
854
855        let event = event_receiver.recv().await.expect("status change event");
856        let McpClientEvent::ServerStatusesChanged(servers) = event else {
857            panic!("expected ServerStatusesChanged");
858        };
859        let auth_event = event_receiver.recv().await.expect("authentication failure event");
860        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
861            panic!("expected AuthenticationFailed");
862        };
863        assert_eq!(server, "remote");
864        assert!(error.contains("boom"));
865
866        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
867        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
868        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
869        assert!(manager.authenticate_server_task("remote").await.is_ok());
870    }
871
872    #[test]
873    fn status_entries_are_derived_from_reauth_config() {
874        let (event_sender, _event_receiver) = mpsc::channel(1);
875        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
876
877        manager.register_record(
878            "with-oauth",
879            McpServerStatus::Connected { tool_count: 1 },
880            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
881            false,
882        );
883        manager.register_record("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None, false);
884        manager.register_record(
885            "needs-oauth",
886            McpServerStatus::NeedsOAuth,
887            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
888            false,
889        );
890
891        let statuses = manager.server_statuses();
892        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
893        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
894        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
895
896        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
897        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
898        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
899    }
900
901    #[tokio::test]
902    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
903        let (event_sender, _event_receiver) = mpsc::channel(1);
904        let mut manager = McpManager::new(event_sender, None);
905        manager
906            .add_mcps(vec![
907                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
908                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
909            ])
910            .await
911            .unwrap();
912
913        let statuses = manager.server_statuses();
914        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
915        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
916        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
917        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
918    }
919
920    #[test]
921    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
922        let (event_sender, _event_receiver) = mpsc::channel(1);
923        let mut manager = McpManager::new(event_sender, None);
924        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
925        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
926        manager.tool_definitions.push(ToolDefinition {
927            name: "git__status".to_string(),
928            description: String::new(),
929            parameters: "{}".to_string(),
930            server: Some("git".to_string()),
931        });
932        manager.tool_definitions.push(ToolDefinition {
933            name: "github__issue".to_string(),
934            description: String::new(),
935            parameters: "{}".to_string(),
936            server: Some("github".to_string()),
937        });
938
939        manager.remove_registered_tools_for_server("git");
940
941        assert!(!manager.tools.contains_key("git__status"));
942        assert!(manager.tools.contains_key("github__issue"));
943        assert_eq!(
944            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
945            vec!["github__issue"]
946        );
947    }
948
949    #[tokio::test]
950    async fn drop_logs_cleanup_abort_with_tracing() {
951        let (event_sender, _event_receiver) = mpsc::channel(1);
952        let mut manager = McpManager::new(event_sender, None);
953        manager
954            .add_mcps(vec![McpServer::new(
955                "test",
956                McpTransport::InMemory { server: TestServer::default().into_dyn() },
957                false,
958            )])
959            .await
960            .unwrap();
961
962        let output = Arc::new(Mutex::new(Vec::new()));
963        let subscriber = tracing_subscriber::fmt()
964            .with_ansi(false)
965            .without_time()
966            .with_writer({
967                let output = Arc::clone(&output);
968                move || SharedWriter(Arc::clone(&output))
969            })
970            .finish();
971
972        tracing::subscriber::with_default(subscriber, || {
973            drop(manager);
974        });
975
976        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
977        assert!(logs.contains("Server 'test' task aborted during cleanup"));
978    }
979}