1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::McpServer,
6 connection::{
7 ConnectContext, McpConnectAttempt, McpConnectOutcome, McpServerConnection, ServerInstructions, Tool,
8 authenticate_http, connect_server,
9 },
10 mcp_client::McpClient,
11 naming::{create_namespaced_tool_name, split_on_server_name},
12 tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17 RoleClient,
18 model::{
19 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20 ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21 },
22 service::RunningService,
23 transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Name under which the aggregated tool proxy is created and through which its tool calls are
/// routed (`<name>__...`). `McpManager::add_mcps` rejects a configured server using this name
/// when proxied servers are present.
pub const DEFAULT_PROXY_NAME: &str = "proxy";
36
/// Factory that builds an [`OAuthHandler`] for one server.
///
/// It receives an [`OAuthHandlerContext`] naming the server and carrying the client event
/// channel. Construction may fail, in which case the error is returned to the caller.
pub type OAuthHandlerFactory = Arc<dyn Fn(OAuthHandlerContext) -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
/// Input handed to an [`OAuthHandlerFactory`] when an OAuth handler is needed.
#[derive(Clone)]
pub struct OAuthHandlerContext {
    /// Name of the MCP server the handler is being created for.
    pub server_name: String,
    /// Sender for [`McpClientEvent`]s, so the handler can report back to the event consumer.
    pub tx: mpsc::Sender<McpClientEvent>,
}
46
/// An elicitation request issued by an MCP server, together with the channel for its answer.
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the server that issued the request.
    pub server_name: String,
    /// The elicitation parameters as sent by the server.
    pub request: CreateElicitationRequestParams,
    /// One-shot channel on which the elicitation result is delivered back to the server side.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
53
/// An answer to an elicitation request: the chosen action plus optional content.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// The action taken in response to the elicitation.
    pub action: ElicitationAction,
    /// Optional content supplied with the action (a JSON value).
    pub content: Option<Value>,
}
59
/// Identifies a completed URL-mode elicitation: the server it belongs to and its id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Name of the server that owns the elicitation.
    pub server_name: String,
    /// Identifier of the elicitation that completed.
    pub elicitation_id: String,
}
65
/// Events delivered on the channel given to [`McpManager::new`].
#[derive(Debug)]
pub enum McpClientEvent {
    /// A server requested input; reply through the embedded oneshot sender.
    Elicitation(ElicitationRequest),
    /// A URL-mode elicitation finished.
    UrlElicitationComplete(UrlElicitationCompleteParams),
    /// Snapshot of all server status entries after a change, in server registration order.
    ServerStatusesChanged(Vec<McpServerStatusEntry>),
    /// Snapshot of the full list of advertised tool definitions after a change.
    ToolDefinitionsChanged(Vec<ToolDefinition>),
    /// An authentication attempt for `server` ended in failure with the given `error` text.
    AuthenticationFailed { server: String, error: String },
}
77
/// Owns the MCP server connections, the tools they expose, and the optional tool proxy.
///
/// Direct tools are exposed under namespaced names (`<server>__<tool>`). Servers flagged as
/// proxied register no direct tools and are reached through the proxy's call tool instead.
pub struct McpManager {
    /// Per-server state (connection, status, reauth config, proxied flag), keyed by name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in first-seen order; drives the order of status entries.
    server_order: Vec<String>,
    /// Callable tools keyed by namespaced name (direct server tools plus the proxy call tool).
    tools: HashMap<String, Tool>,
    /// Tool definitions advertised to the model; kept in step with `tools`.
    tool_definitions: Vec<ToolDefinition>,
    /// The tool proxy, present once a batch containing proxied servers has been added.
    proxy: Option<ToolProxy>,
    /// Home directory for proxy tool files; `ToolProxy::dir` is used when unset.
    aether_home: Option<PathBuf>,
    /// Client identity and capabilities presented to servers.
    client_info: ClientInfo,
    /// Channel on which [`McpClientEvent`]s are emitted.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Root list shared with connections; replaced by `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Builds OAuth handlers for servers that need authentication.
    oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Optional storage for OAuth credentials.
    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
    /// Cached status entries, rebuilt by `refresh_status_entries`.
    server_statuses: Vec<McpServerStatusEntry>,
}
94
95impl McpManager {
96 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
97 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
98 if let Some(elicitation) = capabilities.elicitation.as_mut() {
99 elicitation.form = Some(FormElicitationCapability::default());
100 elicitation.url = Some(UrlElicitationCapability::default());
101 }
102
103 Self {
104 servers: HashMap::new(),
105 server_order: Vec::new(),
106 tools: HashMap::new(),
107 tool_definitions: Vec::new(),
108 proxy: None,
109 aether_home: None,
110 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
111 event_sender,
112 roots: Arc::new(RwLock::new(Vec::new())),
113 oauth_handler_factory,
114 oauth_credential_store: None,
115 server_statuses: Vec::new(),
116 }
117 }
118
119 pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
120 self.aether_home = Some(aether_home.into());
121 self
122 }
123
124 pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
125 self.oauth_credential_store = Some(store);
126 self
127 }
128
129 pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
130 let has_proxy = servers.iter().any(|server| server.proxy);
131 if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
132 return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
133 }
134
135 let proxied_members: HashSet<String> =
136 servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
137 let proxy_tool_dir = if has_proxy {
138 let dir = self.proxy_tool_dir()?;
139 ToolProxy::clean_dir(&dir).await?;
140 Some(dir)
141 } else {
142 None
143 };
144
145 let ctx = self.connect_context();
146 let attempts = join_all(servers.into_iter().map(|server| connect_server(server, &ctx))).await;
147
148 let mut connected_proxied = Vec::new();
149 for McpConnectAttempt { name, proxied, outcome } in attempts {
150 match outcome {
151 McpConnectOutcome::Connected { conn, reauth_config } => {
152 self.register_connection(&name, conn, reauth_config, proxied).await?;
153 if proxied {
154 connected_proxied.push(name);
155 }
156 }
157 McpConnectOutcome::NeedsOAuth { config, error } => {
158 tracing::warn!("Server '{name}' needs OAuth: {error}");
159 self.register_record(&name, McpServerStatus::NeedsOAuth, Some(config), proxied);
160 }
161 McpConnectOutcome::Failed { error } => {
162 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
163 if !self.servers.contains_key(&name) {
164 self.register_record(
165 &name,
166 McpServerStatus::Failed { error: error.to_string() },
167 None,
168 proxied,
169 );
170 }
171 }
172 }
173 }
174
175 if let Some(tool_dir) = proxy_tool_dir {
176 self.write_proxy_tool_files(&connected_proxied, &tool_dir).await;
177 self.register_proxy(tool_dir, proxied_members);
178 }
179
180 Ok(())
181 }
182
183 pub fn get_client_for_tool(
184 &self,
185 namespaced_tool_name: &str,
186 arguments_json: &str,
187 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
188 if !self.tools.contains_key(namespaced_tool_name) {
189 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
190 }
191
192 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
193 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
194
195 if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
196 let call = proxy.resolve_call(arguments_json)?;
197 let conn = self.connection_for(&call.server).ok_or_else(|| {
198 McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
199 })?;
200 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
201 return Ok((conn.client.clone(), params));
202 }
203
204 let client =
205 self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
206
207 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
208 let mut params = CallToolRequestParams::new(tool_name.to_string());
209 if let Some(args) = arguments {
210 params = params.with_arguments(args);
211 }
212
213 Ok((client, params))
214 }
215
216 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
217 self.tool_definitions.clone()
218 }
219
220 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
221 let mut instructions: Vec<ServerInstructions> = self
222 .servers
223 .iter()
224 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
225 .filter_map(|(name, record)| {
226 record
227 .connection
228 .as_ref()
229 .and_then(|conn| conn.instructions.as_ref())
230 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
231 })
232 .collect();
233
234 if let Some(proxy) = &self.proxy {
235 let descriptions: Vec<(String, String)> = proxy
236 .members()
237 .iter()
238 .filter_map(|member| {
239 let conn = self.connection_for(member)?;
240 Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
241 })
242 .collect();
243 instructions.push(ServerInstructions {
244 server_name: proxy.name().to_string(),
245 instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
246 });
247 }
248
249 instructions
250 }
251
252 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
253 &self.server_statuses
254 }
255
256 pub async fn authenticate_server_task(
257 &mut self,
258 name: &str,
259 ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
260 let record = self
261 .servers
262 .get(name)
263 .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
264 if !record.can_authenticate() {
265 return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
266 }
267
268 let oauth_handler_factory = self
269 .oauth_handler_factory
270 .clone()
271 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
272 let oauth_credential_store = self.oauth_credential_store.clone();
273 let name = name.to_string();
274 let config = record.reauth_config.clone().expect("checked above");
275 let client_info = self.client_info.clone();
276 let event_sender = self.event_sender.clone();
277 let roots = Arc::clone(&self.roots);
278 let proxied = record.proxied;
279
280 self.set_status(&name, McpServerStatus::Authenticating);
281 self.emit_server_statuses_changed().await;
282
283 Ok(async move {
284 authenticate_http(
285 name,
286 config,
287 client_info,
288 event_sender,
289 roots,
290 oauth_handler_factory,
291 oauth_credential_store,
292 proxied,
293 )
294 .await
295 })
296 }
297
298 pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
299 let McpConnectAttempt { name, proxied, outcome } = attempt;
300 match outcome {
301 McpConnectOutcome::Connected { conn, reauth_config } => {
302 match self.register_connection(&name, conn, reauth_config, proxied).await {
303 Ok(tools) => {
304 self.refresh_proxy_after_auth(&name, &tools, proxied).await;
305 self.emit_server_statuses_changed().await;
306 self.emit_tool_definitions_changed().await;
307 }
308 Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
309 }
310 }
311 McpConnectOutcome::Failed { error } => {
312 self.apply_authentication_failure(name, error.to_string()).await;
313 }
314 McpConnectOutcome::NeedsOAuth { .. } => {
315 self.apply_authentication_failure(name, "internal error: auth task returned NeedsOAuth".to_string())
316 .await;
317 }
318 }
319 }
320
321 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
323 let futures: Vec<_> = self
324 .servers
325 .iter()
326 .filter_map(|(server_name, record)| {
327 let conn = record.connection.as_ref()?;
328 conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
329 let server_name = server_name.clone();
330 let client = conn.client.clone();
331 Some(async move {
332 let prompts_response = client.list_prompts(None).await.map_err(|e| {
333 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
334 })?;
335
336 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
337 .prompts
338 .into_iter()
339 .map(|prompt| {
340 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
341 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
342 })
343 .collect();
344
345 Ok::<_, McpError>(namespaced_prompts)
346 })
347 })
348 .collect();
349
350 let results = join_all(futures).await;
351 let mut all_prompts = Vec::new();
352 for result in results {
353 all_prompts.extend(result?);
354 }
355
356 Ok(all_prompts)
357 }
358
359 pub async fn get_prompt(
361 &self,
362 namespaced_prompt_name: &str,
363 arguments: Option<serde_json::Map<String, serde_json::Value>>,
364 ) -> Result<rmcp::model::GetPromptResult> {
365 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
366 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
367
368 let server_conn =
369 self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
370
371 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
372 if let Some(args) = arguments {
373 request = request.with_arguments(args);
374 }
375
376 server_conn.client.get_prompt(request).await.map_err(|e| {
377 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
378 })
379 }
380
381 pub async fn shutdown(&mut self) {
383 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
384
385 for (server_name, record) in servers {
386 if let Some(conn) = record.connection
387 && let Some(handle) = conn.server_task
388 {
389 drop(conn.client);
390
391 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
392 Ok(Ok(())) => {
393 tracing::info!("Server '{server_name}' shut down gracefully");
394 }
395 Ok(Err(e)) => {
396 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
397 }
398 Err(_) => {
399 tracing::warn!("Server '{server_name}' shutdown timed out");
400 }
401 }
402 }
403 }
404
405 self.tools.clear();
406 self.tool_definitions.clear();
407 self.proxy = None;
408 }
409
410 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
412 let record = self.servers.remove(server_name);
413
414 if let Some(record) = record {
415 if let Some(conn) = record.connection
416 && let Some(handle) = conn.server_task
417 {
418 drop(conn.client);
419
420 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
421 Ok(Ok(())) => {
422 tracing::info!("Server '{server_name}' shut down gracefully");
423 }
424 Ok(Err(e)) => {
425 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
426 }
427 Err(_) => {
428 tracing::warn!("Server '{server_name}' shutdown timed out");
429 }
430 }
431 }
432
433 self.remove_registered_tools_for_server(server_name);
434 self.refresh_status_entries();
435 }
436
437 Ok(())
438 }
439
440 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
445 {
446 let mut roots = self.roots.write().await;
447 *roots = new_roots;
448 }
449
450 self.notify_roots_changed().await;
451
452 Ok(())
453 }
454
455 async fn emit_server_statuses_changed(&self) {
456 self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses().to_vec())).await;
457 }
458
459 async fn emit_tool_definitions_changed(&self) {
460 self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
461 }
462
463 async fn emit_authentication_failed(&self, server: String, error: String) {
464 self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
465 }
466
467 async fn emit_event(&self, event: McpClientEvent) {
468 if let Err(e) = self.event_sender.send(event).await {
469 tracing::warn!("Failed to emit MCP client event: {e}");
470 }
471 }
472
473 fn connect_context(&self) -> ConnectContext<'_> {
474 ConnectContext {
475 client_info: &self.client_info,
476 event_sender: &self.event_sender,
477 roots: &self.roots,
478 oauth_handler_factory: self.oauth_handler_factory.as_ref(),
479 oauth_credential_store: self.oauth_credential_store.as_ref(),
480 }
481 }
482
483 fn proxy_tool_dir(&self) -> Result<PathBuf> {
484 self.aether_home
485 .as_ref()
486 .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
487 .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
488 }
489
490 async fn register_connection(
491 &mut self,
492 name: &str,
493 conn: McpServerConnection,
494 reauth_config: Option<StreamableHttpClientTransportConfig>,
495 proxied: bool,
496 ) -> Result<Vec<RmcpTool>> {
497 let tools = conn
498 .list_tools()
499 .await
500 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
501 self.apply_connected(name, conn, &tools, reauth_config, proxied);
502 Ok(tools)
503 }
504
505 fn apply_connected(
506 &mut self,
507 name: &str,
508 conn: McpServerConnection,
509 tools: &[RmcpTool],
510 reauth_config: Option<StreamableHttpClientTransportConfig>,
511 proxied: bool,
512 ) {
513 self.remove_registered_tools_for_server(name);
514
515 let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
516 let final_reauth = reauth_config.or(existing_reauth);
517
518 for rmcp_tool in tools {
519 let tool_name = rmcp_tool.name.to_string();
520 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
521 let tool = Tool::from(rmcp_tool);
522
523 if !proxied {
524 self.tool_definitions.push(ToolDefinition {
525 name: namespaced_tool_name.clone(),
526 description: tool.description.clone(),
527 parameters: tool.parameters.to_string(),
528 server: Some(name.to_string()),
529 });
530 self.tools.insert(namespaced_tool_name, tool);
531 }
532 }
533
534 self.remember_server_order(name);
535 self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth, proxied));
536 self.refresh_status_entries();
537 }
538
539 fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
540 self.remove_registered_tools_for_server(DEFAULT_PROXY_NAME);
541 let call_tool_def = ToolProxy::call_tool_definition(DEFAULT_PROXY_NAME);
542 self.tools.insert(
543 call_tool_def.name.clone(),
544 Tool {
545 description: call_tool_def.description.clone(),
546 parameters: serde_json::from_str(&call_tool_def.parameters)
547 .unwrap_or(Value::Object(serde_json::Map::default())),
548 },
549 );
550 self.tool_definitions.push(call_tool_def);
551
552 self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
553 }
554
555 async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
556 if !proxied {
557 return;
558 }
559
560 if let Some(proxy) = self.proxy.as_mut() {
561 proxy.add_member(name.to_string());
562 }
563
564 if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
565 && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
566 {
567 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
568 }
569 }
570
571 async fn write_proxy_tool_files(&self, connected_proxied: &[String], tool_dir: &std::path::Path) {
572 let writes = connected_proxied.iter().filter_map(|name| {
573 let client = self.client_for_server(name)?;
574 let dir = tool_dir.to_path_buf();
575 let name = name.clone();
576 Some(async move {
577 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
578 tracing::warn!("Failed to write tool files for proxied server '{name}': {e}");
579 }
580 })
581 });
582 join_all(writes).await;
583 }
584
585 fn refresh_status_entries(&mut self) {
586 self.server_statuses = self
587 .server_order
588 .iter()
589 .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
590 .collect();
591 }
592
593 fn remember_server_order(&mut self, name: &str) {
594 if !self.server_order.iter().any(|n| n == name) {
595 self.server_order.push(name.to_string());
596 }
597 }
598
599 async fn apply_authentication_failure(&mut self, name: String, error: String) {
600 self.set_status(&name, McpServerStatus::Failed { error: error.clone() });
601 self.emit_server_statuses_changed().await;
602 self.emit_authentication_failed(name, error).await;
603 }
604
605 fn set_status(&mut self, name: &str, status: McpServerStatus) {
606 self.remember_server_order(name);
607 let record =
608 self.servers.entry(name.to_string()).or_insert_with(|| ServerRecord::new(status.clone(), None, false));
609 record.status = status;
610 self.refresh_status_entries();
611 }
612
613 fn register_record(
614 &mut self,
615 name: &str,
616 status: McpServerStatus,
617 reauth_config: Option<StreamableHttpClientTransportConfig>,
618 proxied: bool,
619 ) {
620 self.remember_server_order(name);
621 self.servers.insert(name.to_string(), ServerRecord::new(status, reauth_config, proxied));
622 self.refresh_status_entries();
623 }
624
625 fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
626 self.servers.get(server_name).and_then(|record| record.connection.as_ref())
627 }
628
629 fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
630 self.connection_for(server_name).map(|conn| conn.client.clone())
631 }
632
633 fn remove_registered_tools_for_server(&mut self, server_name: &str) {
634 let prefix = format!("{server_name}__");
635 self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
636 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
637 }
638
639 async fn notify_roots_changed(&self) {
640 for (server_name, record) in &self.servers {
641 if let Some(conn) = &record.connection
642 && let Err(e) = conn.client.notify_roots_list_changed().await
643 {
644 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
645 }
646 }
647 }
648}
649
650impl Drop for McpManager {
651 fn drop(&mut self) {
652 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
653 for (server_name, record) in servers {
654 if let Some(conn) = record.connection
655 && let Some(handle) = conn.server_task
656 {
657 handle.abort();
658 tracing::warn!("Server '{server_name}' task aborted during cleanup");
659 }
660 }
661 }
662}
663
/// Per-server bookkeeping kept by [`McpManager`].
struct ServerRecord {
    /// Live connection, present while the server is connected.
    connection: Option<McpServerConnection>,
    /// Last known status, surfaced through the status entries.
    status: McpServerStatus,
    /// HTTP transport config used to (re)run OAuth; its presence makes the server
    /// authenticatable.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
    /// Whether the server is reached through the tool proxy rather than direct tools.
    proxied: bool,
}
671
672impl ServerRecord {
673 fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
674 Self { connection: None, status, reauth_config, proxied }
675 }
676
677 fn connected(
678 connection: McpServerConnection,
679 tool_count: usize,
680 reauth_config: Option<StreamableHttpClientTransportConfig>,
681 proxied: bool,
682 ) -> Self {
683 Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config, proxied }
684 }
685
686 fn auth_capability(&self) -> McpServerAuthCapability {
687 if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
688 }
689
690 fn can_authenticate(&self) -> bool {
691 self.reauth_config.is_some()
692 }
693
694 fn status_entry(&self, name: &str) -> McpServerStatusEntry {
695 McpServerStatusEntry::new(name, self.status.clone())
696 .with_auth_capability(self.auth_capability())
697 .with_proxied(self.proxied)
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, Tool};
    use crate::client::OAuthHandlerFactory;
    use crate::client::config::{McpServer, McpTransport};
    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
    use crate::status::McpServerAuthCapability;
    use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
    use futures::future::BoxFuture;
    use llm::ToolDefinition;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
        transport::streamable_http_client::StreamableHttpClientTransportConfig,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    // Minimal in-memory MCP server that exposes a single `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    // Plain `//` comments are used on macro-processed items so generated metadata is unchanged.
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }

    // Arguments of the `echo` tool (doc comments here would leak into the JSON schema).
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    // Structured result of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        // Boxes the server so it can be used as an in-memory transport.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }

    /// `io::Write` sink that appends into a shared buffer, used to capture tracing output.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// OAuth handler whose authorization always fails with `UserCancelled`.
    struct TestOAuthHandler;

    impl OAuthHandler for TestOAuthHandler {
        fn redirect_uri(&self) -> &'static str {
            "http://127.0.0.1:0/oauth2callback"
        }

        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
            Box::pin(async { Err(OAuthError::UserCancelled) })
        }
    }

    /// Factory that always hands out a [`TestOAuthHandler`].
    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
        Arc::new(|_ctx| Ok(Arc::new(TestOAuthHandler)))
    }

    /// A record without a reauth config must not be accepted for authentication.
    #[tokio::test]
    async fn authenticate_server_task_rejects_record_without_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record("public", McpServerStatus::Connected { tool_count: 1 }, None, false);

        let error = match manager.authenticate_server_task("public").await {
            Ok(_) => panic!("non-OAuth server should be rejected"),
            Err(error) => error.to_string(),
        };
        assert!(error.contains("not OAuth-authenticatable"));
    }

    /// Starting authentication flips the status and emits a status-change event immediately.
    #[tokio::test]
    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );

        // The returned task is never polled; only the synchronous side effects are checked.
        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");

        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert!(matches!(status.status, McpServerStatus::Authenticating));
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
    }

    /// A failed attempt marks the server failed, emits both events, and keeps it re-authenticatable.
    #[tokio::test]
    async fn apply_connection_attempt_failure_allows_retry() {
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );
        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");

        manager
            .apply_connection_attempt(McpConnectAttempt {
                name: "remote".to_string(),
                proxied: false,
                outcome: McpConnectOutcome::Failed {
                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
                },
            })
            .await;

        // Status change is emitted before the authentication-failure event.
        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let auth_event = event_receiver.recv().await.expect("authentication failure event");
        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
            panic!("expected AuthenticationFailed");
        };
        assert_eq!(server, "remote");
        assert!(error.contains("boom"));

        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
        assert!(manager.authenticate_server_task("remote").await.is_ok());
    }

    /// The advertised auth capability follows the presence of a reauth config, not the status.
    #[test]
    fn status_entries_are_derived_from_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));

        manager.register_record(
            "with-oauth",
            McpServerStatus::Connected { tool_count: 1 },
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
            false,
        );
        manager.register_record("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None, false);
        manager.register_record(
            "needs-oauth",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
            false,
        );

        let statuses = manager.server_statuses();
        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();

        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
    }

    /// Statuses list direct and proxied servers in order, flag the proxied one, and omit the proxy.
    #[tokio::test]
    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcps(vec![
                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
            ])
            .await
            .unwrap();

        let statuses = manager.server_statuses();
        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
    }

    /// Removing `git` must not touch `github` tools, since the match uses the `git__` prefix.
    #[test]
    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tool_definitions.push(ToolDefinition {
            name: "git__status".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("git".to_string()),
        });
        manager.tool_definitions.push(ToolDefinition {
            name: "github__issue".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("github".to_string()),
        });

        manager.remove_registered_tools_for_server("git");

        assert!(!manager.tools.contains_key("git__status"));
        assert!(manager.tools.contains_key("github__issue"));
        assert_eq!(
            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
            vec!["github__issue"]
        );
    }

    /// Dropping a manager without `shutdown` aborts server tasks and logs it through `tracing`.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcps(vec![McpServer::new(
                "test",
                McpTransport::InMemory { server: TestServer::default().into_dyn() },
                false,
            )])
            .await
            .unwrap();

        // Capture log output in memory for the duration of the drop.
        let output = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer({
                let output = Arc::clone(&output);
                move || SharedWriter(Arc::clone(&output))
            })
            .finish();

        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}
983}