1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::McpServer,
6 connection::{
7 ConnectContext, McpConnectAttempt, McpConnectOutcome, McpServerConnection, ServerInstructions, Tool,
8 authenticate_http, connect_server,
9 },
10 mcp_client::McpClient,
11 naming::{create_namespaced_tool_name, split_on_server_name},
12 tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17 RoleClient,
18 model::{
19 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20 ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21 },
22 service::RunningService,
23 transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Server name under which the tool proxy registers its call tool.
///
/// When any server in a batch passed to `McpManager::add_mcps` is proxied, this name is reserved:
/// a real server using it is rejected.
pub const DEFAULT_PROXY_NAME: &str = "proxy";

/// Factory producing a fresh OAuth handler each time an HTTP server needs to be (re)authenticated.
pub type OAuthHandlerFactory = Arc<dyn Fn() -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
39#[derive(Debug)]
40pub struct ElicitationRequest {
41 pub server_name: String,
42 pub request: CreateElicitationRequestParams,
43 pub response_sender: oneshot::Sender<CreateElicitationResult>,
44}
45
/// A user's answer to an elicitation: the chosen action plus optional form content.
// NOTE(review): not referenced in this module; presumably converted into a
// `CreateElicitationResult` elsewhere — confirm.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// Whether the user accepted, declined, or cancelled.
    pub action: ElicitationAction,
    /// Submitted form content, if any.
    pub content: Option<Value>,
}
51
/// Notification payload indicating that a URL-mode elicitation has finished.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Server that issued the URL elicitation.
    pub server_name: String,
    /// Identifier of the elicitation that completed.
    pub elicitation_id: String,
}
57
/// Events emitted by the MCP layer to the application over the manager's event channel.
#[derive(Debug)]
pub enum McpClientEvent {
    /// A server is asking the user for input.
    Elicitation(ElicitationRequest),
    /// A URL-mode elicitation has been completed.
    UrlElicitationComplete(UrlElicitationCompleteParams),
    /// Snapshot of all server status entries after a status change.
    ServerStatusesChanged(Vec<McpServerStatusEntry>),
    /// Snapshot of the advertised tool definitions after they changed (e.g. after OAuth).
    ToolDefinitionsChanged(Vec<ToolDefinition>),
    /// Authenticating `server` failed with the given error message.
    AuthenticationFailed { server: String, error: String },
}
69
/// Owns the set of configured MCP servers: their connections, status, exposed tools and the
/// optional tool proxy that fronts servers marked as proxied.
pub struct McpManager {
    /// Per-server state keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in first-seen order; used to produce a stable status listing.
    server_order: Vec<String>,
    /// Directly callable tools keyed by namespaced name (non-proxied tools plus the proxy call tool).
    tools: HashMap<String, Tool>,
    /// Tool definitions advertised to the model, kept in sync with `tools`.
    tool_definitions: Vec<ToolDefinition>,
    /// Tool proxy, present once a batch containing proxied servers has been added.
    proxy: Option<ToolProxy>,
    /// Optional home directory under which the proxy tool directory is created.
    aether_home: Option<PathBuf>,
    /// Client capabilities and implementation info sent when connecting.
    client_info: ClientInfo,
    /// Channel used to emit `McpClientEvent`s.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Workspace roots shared with connections; replaced by `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Creates OAuth handlers; required for re-authentication.
    oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Optional persistent storage for OAuth credentials.
    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
    /// Cached status entries in `server_order`; rebuilt by `refresh_status_entries`.
    server_statuses: Vec<McpServerStatusEntry>,
}
86
87impl McpManager {
88 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
89 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
90 if let Some(elicitation) = capabilities.elicitation.as_mut() {
91 elicitation.form = Some(FormElicitationCapability::default());
92 elicitation.url = Some(UrlElicitationCapability::default());
93 }
94
95 Self {
96 servers: HashMap::new(),
97 server_order: Vec::new(),
98 tools: HashMap::new(),
99 tool_definitions: Vec::new(),
100 proxy: None,
101 aether_home: None,
102 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
103 event_sender,
104 roots: Arc::new(RwLock::new(Vec::new())),
105 oauth_handler_factory,
106 oauth_credential_store: None,
107 server_statuses: Vec::new(),
108 }
109 }
110
111 pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
112 self.aether_home = Some(aether_home.into());
113 self
114 }
115
116 pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
117 self.oauth_credential_store = Some(store);
118 self
119 }
120
121 pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
122 let has_proxy = servers.iter().any(|server| server.proxy);
123 if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
124 return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
125 }
126
127 let proxied_members: HashSet<String> =
128 servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
129 let proxy_tool_dir = if has_proxy {
130 let dir = self.proxy_tool_dir()?;
131 ToolProxy::clean_dir(&dir).await?;
132 Some(dir)
133 } else {
134 None
135 };
136
137 let ctx = self.connect_context();
138 let attempts = join_all(servers.into_iter().map(|server| connect_server(server, &ctx))).await;
139
140 let mut connected_proxied = Vec::new();
141 for McpConnectAttempt { name, proxied, outcome } in attempts {
142 match outcome {
143 McpConnectOutcome::Connected { conn, reauth_config } => {
144 self.register_connection(&name, conn, reauth_config, proxied).await?;
145 if proxied {
146 connected_proxied.push(name);
147 }
148 }
149 McpConnectOutcome::NeedsOAuth { config, error } => {
150 tracing::warn!("Server '{name}' needs OAuth: {error}");
151 self.register_record(&name, McpServerStatus::NeedsOAuth, Some(config), proxied);
152 }
153 McpConnectOutcome::Failed { error } => {
154 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
155 if !self.servers.contains_key(&name) {
156 self.register_record(
157 &name,
158 McpServerStatus::Failed { error: error.to_string() },
159 None,
160 proxied,
161 );
162 }
163 }
164 }
165 }
166
167 if let Some(tool_dir) = proxy_tool_dir {
168 self.write_proxy_tool_files(&connected_proxied, &tool_dir).await;
169 self.register_proxy(tool_dir, proxied_members);
170 }
171
172 Ok(())
173 }
174
175 pub fn get_client_for_tool(
176 &self,
177 namespaced_tool_name: &str,
178 arguments_json: &str,
179 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
180 if !self.tools.contains_key(namespaced_tool_name) {
181 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
182 }
183
184 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
185 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
186
187 if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
188 let call = proxy.resolve_call(arguments_json)?;
189 let conn = self.connection_for(&call.server).ok_or_else(|| {
190 McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
191 })?;
192 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
193 return Ok((conn.client.clone(), params));
194 }
195
196 let client =
197 self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
198
199 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
200 let mut params = CallToolRequestParams::new(tool_name.to_string());
201 if let Some(args) = arguments {
202 params = params.with_arguments(args);
203 }
204
205 Ok((client, params))
206 }
207
208 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
209 self.tool_definitions.clone()
210 }
211
212 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
213 let mut instructions: Vec<ServerInstructions> = self
214 .servers
215 .iter()
216 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
217 .filter_map(|(name, record)| {
218 record
219 .connection
220 .as_ref()
221 .and_then(|conn| conn.instructions.as_ref())
222 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
223 })
224 .collect();
225
226 if let Some(proxy) = &self.proxy {
227 let descriptions: Vec<(String, String)> = proxy
228 .members()
229 .iter()
230 .filter_map(|member| {
231 let conn = self.connection_for(member)?;
232 Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
233 })
234 .collect();
235 instructions.push(ServerInstructions {
236 server_name: proxy.name().to_string(),
237 instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
238 });
239 }
240
241 instructions
242 }
243
244 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
245 &self.server_statuses
246 }
247
248 pub async fn authenticate_server_task(
249 &mut self,
250 name: &str,
251 ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
252 let record = self
253 .servers
254 .get(name)
255 .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
256 if !record.can_authenticate() {
257 return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
258 }
259 if matches!(record.status, McpServerStatus::Authenticating) {
260 return Err(McpError::ConnectionFailed(format!("server '{name}' is already authenticating")));
261 }
262
263 let oauth_handler_factory = self
264 .oauth_handler_factory
265 .clone()
266 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
267 let oauth_credential_store = self.oauth_credential_store.clone();
268 let name = name.to_string();
269 let config = record.reauth_config.clone().expect("checked above");
270 let client_info = self.client_info.clone();
271 let event_sender = self.event_sender.clone();
272 let roots = Arc::clone(&self.roots);
273 let proxied = record.proxied;
274
275 self.set_status(&name, McpServerStatus::Authenticating);
276 self.emit_server_statuses_changed().await;
277
278 Ok(async move {
279 authenticate_http(
280 name,
281 config,
282 client_info,
283 event_sender,
284 roots,
285 oauth_handler_factory,
286 oauth_credential_store,
287 proxied,
288 )
289 .await
290 })
291 }
292
293 pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
294 let McpConnectAttempt { name, proxied, outcome } = attempt;
295 match outcome {
296 McpConnectOutcome::Connected { conn, reauth_config } => {
297 match self.register_connection(&name, conn, reauth_config, proxied).await {
298 Ok(tools) => {
299 self.refresh_proxy_after_auth(&name, &tools, proxied).await;
300 self.emit_server_statuses_changed().await;
301 self.emit_tool_definitions_changed().await;
302 }
303 Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
304 }
305 }
306 McpConnectOutcome::Failed { error } => {
307 self.apply_authentication_failure(name, error.to_string()).await;
308 }
309 McpConnectOutcome::NeedsOAuth { .. } => {
310 self.apply_authentication_failure(name, "internal error: auth task returned NeedsOAuth".to_string())
311 .await;
312 }
313 }
314 }
315
316 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
318 let futures: Vec<_> = self
319 .servers
320 .iter()
321 .filter_map(|(server_name, record)| {
322 let conn = record.connection.as_ref()?;
323 conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
324 let server_name = server_name.clone();
325 let client = conn.client.clone();
326 Some(async move {
327 let prompts_response = client.list_prompts(None).await.map_err(|e| {
328 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
329 })?;
330
331 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
332 .prompts
333 .into_iter()
334 .map(|prompt| {
335 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
336 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
337 })
338 .collect();
339
340 Ok::<_, McpError>(namespaced_prompts)
341 })
342 })
343 .collect();
344
345 let results = join_all(futures).await;
346 let mut all_prompts = Vec::new();
347 for result in results {
348 all_prompts.extend(result?);
349 }
350
351 Ok(all_prompts)
352 }
353
354 pub async fn get_prompt(
356 &self,
357 namespaced_prompt_name: &str,
358 arguments: Option<serde_json::Map<String, serde_json::Value>>,
359 ) -> Result<rmcp::model::GetPromptResult> {
360 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
361 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
362
363 let server_conn =
364 self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
365
366 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
367 if let Some(args) = arguments {
368 request = request.with_arguments(args);
369 }
370
371 server_conn.client.get_prompt(request).await.map_err(|e| {
372 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
373 })
374 }
375
376 pub async fn shutdown(&mut self) {
378 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
379
380 for (server_name, record) in servers {
381 if let Some(conn) = record.connection
382 && let Some(handle) = conn.server_task
383 {
384 drop(conn.client);
385
386 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
387 Ok(Ok(())) => {
388 tracing::info!("Server '{server_name}' shut down gracefully");
389 }
390 Ok(Err(e)) => {
391 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
392 }
393 Err(_) => {
394 tracing::warn!("Server '{server_name}' shutdown timed out");
395 }
396 }
397 }
398 }
399
400 self.tools.clear();
401 self.tool_definitions.clear();
402 self.proxy = None;
403 }
404
405 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
407 let record = self.servers.remove(server_name);
408
409 if let Some(record) = record {
410 if let Some(conn) = record.connection
411 && let Some(handle) = conn.server_task
412 {
413 drop(conn.client);
414
415 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
416 Ok(Ok(())) => {
417 tracing::info!("Server '{server_name}' shut down gracefully");
418 }
419 Ok(Err(e)) => {
420 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
421 }
422 Err(_) => {
423 tracing::warn!("Server '{server_name}' shutdown timed out");
424 }
425 }
426 }
427
428 self.remove_registered_tools_for_server(server_name);
429 self.refresh_status_entries();
430 }
431
432 Ok(())
433 }
434
435 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
440 {
441 let mut roots = self.roots.write().await;
442 *roots = new_roots;
443 }
444
445 self.notify_roots_changed().await;
446
447 Ok(())
448 }
449
450 async fn emit_server_statuses_changed(&self) {
451 self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses().to_vec())).await;
452 }
453
454 async fn emit_tool_definitions_changed(&self) {
455 self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
456 }
457
458 async fn emit_authentication_failed(&self, server: String, error: String) {
459 self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
460 }
461
462 async fn emit_event(&self, event: McpClientEvent) {
463 if let Err(e) = self.event_sender.send(event).await {
464 tracing::warn!("Failed to emit MCP client event: {e}");
465 }
466 }
467
468 fn connect_context(&self) -> ConnectContext<'_> {
469 ConnectContext {
470 client_info: &self.client_info,
471 event_sender: &self.event_sender,
472 roots: &self.roots,
473 oauth_handler_factory: self.oauth_handler_factory.as_ref(),
474 oauth_credential_store: self.oauth_credential_store.as_ref(),
475 }
476 }
477
478 fn proxy_tool_dir(&self) -> Result<PathBuf> {
479 self.aether_home
480 .as_ref()
481 .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
482 .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
483 }
484
485 async fn register_connection(
486 &mut self,
487 name: &str,
488 conn: McpServerConnection,
489 reauth_config: Option<StreamableHttpClientTransportConfig>,
490 proxied: bool,
491 ) -> Result<Vec<RmcpTool>> {
492 let tools = conn
493 .list_tools()
494 .await
495 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
496 self.apply_connected(name, conn, &tools, reauth_config, proxied);
497 Ok(tools)
498 }
499
500 fn apply_connected(
501 &mut self,
502 name: &str,
503 conn: McpServerConnection,
504 tools: &[RmcpTool],
505 reauth_config: Option<StreamableHttpClientTransportConfig>,
506 proxied: bool,
507 ) {
508 self.remove_registered_tools_for_server(name);
509
510 let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
511 let final_reauth = reauth_config.or(existing_reauth);
512
513 for rmcp_tool in tools {
514 let tool_name = rmcp_tool.name.to_string();
515 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
516 let tool = Tool::from(rmcp_tool);
517
518 if !proxied {
519 self.tool_definitions.push(ToolDefinition {
520 name: namespaced_tool_name.clone(),
521 description: tool.description.clone(),
522 parameters: tool.parameters.to_string(),
523 server: Some(name.to_string()),
524 });
525 self.tools.insert(namespaced_tool_name, tool);
526 }
527 }
528
529 self.remember_server_order(name);
530 self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth, proxied));
531 self.refresh_status_entries();
532 }
533
534 fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
535 self.remove_registered_tools_for_server(DEFAULT_PROXY_NAME);
536 let call_tool_def = ToolProxy::call_tool_definition(DEFAULT_PROXY_NAME);
537 self.tools.insert(
538 call_tool_def.name.clone(),
539 Tool {
540 description: call_tool_def.description.clone(),
541 parameters: serde_json::from_str(&call_tool_def.parameters)
542 .unwrap_or(Value::Object(serde_json::Map::default())),
543 },
544 );
545 self.tool_definitions.push(call_tool_def);
546
547 self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
548 }
549
550 async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
551 if !proxied {
552 return;
553 }
554
555 if let Some(proxy) = self.proxy.as_mut() {
556 proxy.add_member(name.to_string());
557 }
558
559 if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
560 && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
561 {
562 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
563 }
564 }
565
566 async fn write_proxy_tool_files(&self, connected_proxied: &[String], tool_dir: &std::path::Path) {
567 let writes = connected_proxied.iter().filter_map(|name| {
568 let client = self.client_for_server(name)?;
569 let dir = tool_dir.to_path_buf();
570 let name = name.clone();
571 Some(async move {
572 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
573 tracing::warn!("Failed to write tool files for proxied server '{name}': {e}");
574 }
575 })
576 });
577 join_all(writes).await;
578 }
579
580 fn refresh_status_entries(&mut self) {
581 self.server_statuses = self
582 .server_order
583 .iter()
584 .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
585 .collect();
586 }
587
588 fn remember_server_order(&mut self, name: &str) {
589 if !self.server_order.iter().any(|n| n == name) {
590 self.server_order.push(name.to_string());
591 }
592 }
593
594 async fn apply_authentication_failure(&mut self, name: String, error: String) {
595 self.set_status(&name, McpServerStatus::Failed { error: error.clone() });
596 self.emit_server_statuses_changed().await;
597 self.emit_authentication_failed(name, error).await;
598 }
599
600 fn set_status(&mut self, name: &str, status: McpServerStatus) {
601 self.remember_server_order(name);
602 let record =
603 self.servers.entry(name.to_string()).or_insert_with(|| ServerRecord::new(status.clone(), None, false));
604 record.status = status;
605 self.refresh_status_entries();
606 }
607
608 fn register_record(
609 &mut self,
610 name: &str,
611 status: McpServerStatus,
612 reauth_config: Option<StreamableHttpClientTransportConfig>,
613 proxied: bool,
614 ) {
615 self.remember_server_order(name);
616 self.servers.insert(name.to_string(), ServerRecord::new(status, reauth_config, proxied));
617 self.refresh_status_entries();
618 }
619
620 fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
621 self.servers.get(server_name).and_then(|record| record.connection.as_ref())
622 }
623
624 fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
625 self.connection_for(server_name).map(|conn| conn.client.clone())
626 }
627
628 fn remove_registered_tools_for_server(&mut self, server_name: &str) {
629 let prefix = format!("{server_name}__");
630 self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
631 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
632 }
633
634 async fn notify_roots_changed(&self) {
635 for (server_name, record) in &self.servers {
636 if let Some(conn) = &record.connection
637 && let Err(e) = conn.client.notify_roots_list_changed().await
638 {
639 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
640 }
641 }
642 }
643}
644
645impl Drop for McpManager {
646 fn drop(&mut self) {
647 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
648 for (server_name, record) in servers {
649 if let Some(conn) = record.connection
650 && let Some(handle) = conn.server_task
651 {
652 handle.abort();
653 tracing::warn!("Server '{server_name}' task aborted during cleanup");
654 }
655 }
656 }
657}
658
/// Manager-side state for a single configured server.
struct ServerRecord {
    /// Live connection, present only while the server is connected.
    connection: Option<McpServerConnection>,
    /// Status reported to the application.
    status: McpServerStatus,
    /// HTTP transport config used for OAuth re-authentication; `None` means the server cannot re-authenticate.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
    /// Whether the server is reached through the tool proxy instead of exposing its tools directly.
    proxied: bool,
}
666
667impl ServerRecord {
668 fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
669 Self { connection: None, status, reauth_config, proxied }
670 }
671
672 fn connected(
673 connection: McpServerConnection,
674 tool_count: usize,
675 reauth_config: Option<StreamableHttpClientTransportConfig>,
676 proxied: bool,
677 ) -> Self {
678 Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config, proxied }
679 }
680
681 fn auth_capability(&self) -> McpServerAuthCapability {
682 if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
683 }
684
685 fn can_authenticate(&self) -> bool {
686 self.reauth_config.is_some()
687 }
688
689 fn status_entry(&self, name: &str) -> McpServerStatusEntry {
690 McpServerStatusEntry::new(name, self.status.clone())
691 .with_auth_capability(self.auth_capability())
692 .with_proxied(self.proxied)
693 }
694}
695
/// Unit tests for `McpManager`: OAuth re-authentication flow, status reporting, tool
/// bookkeeping, and cleanup on drop.
#[cfg(test)]
mod tests {
    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, Tool};
    use crate::client::OAuthHandlerFactory;
    use crate::client::config::{McpServer, McpTransport};
    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
    use crate::status::McpServerAuthCapability;
    use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
    use futures::future::BoxFuture;
    use llm::ToolDefinition;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
        transport::streamable_http_client::StreamableHttpClientTransportConfig,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    // Minimal in-memory MCP server exposing a single `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    // Plain `//` comments are used inside macro-processed items so no doc attributes reach the macros.
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }

    /// Input of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    /// Output of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        // Boxes the server for `McpTransport::InMemory`.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }

    /// `io::Write` sink appending into a shared buffer so tests can inspect captured tracing output.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// OAuth handler whose authorization always fails with `UserCancelled`.
    struct TestOAuthHandler;

    impl OAuthHandler for TestOAuthHandler {
        fn redirect_uri(&self) -> &'static str {
            "http://127.0.0.1:0/oauth2callback"
        }

        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
            Box::pin(async { Err(OAuthError::UserCancelled) })
        }
    }

    /// Factory returning `TestOAuthHandler` instances.
    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
        Arc::new(|| Ok(Arc::new(TestOAuthHandler)))
    }

    /// A record without a re-auth config cannot start OAuth.
    #[tokio::test]
    async fn authenticate_server_task_rejects_record_without_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record("public", McpServerStatus::Connected { tool_count: 1 }, None, false);

        let error = match manager.authenticate_server_task("public").await {
            Ok(_) => panic!("non-OAuth server should be rejected"),
            Err(error) => error.to_string(),
        };
        assert!(error.contains("not OAuth-authenticatable"));
    }

    /// Starting auth flips the status to `Authenticating` and emits it before the task runs.
    #[tokio::test]
    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );

        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");

        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert!(matches!(status.status, McpServerStatus::Authenticating));
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
    }

    /// A second auth request for the same server is rejected while the first is pending.
    #[tokio::test]
    async fn authenticate_server_task_rejects_duplicate_same_server_while_in_flight() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );

        let _task = manager.authenticate_server_task("remote").await.expect("first auth should start");
        let error = match manager.authenticate_server_task("remote").await {
            Ok(_) => panic!("duplicate auth should be rejected"),
            Err(error) => error.to_string(),
        };

        assert!(error.contains("already authenticating"));
    }

    /// A failed auth attempt reports `Failed` plus `AuthenticationFailed` and allows retrying.
    #[tokio::test]
    async fn apply_connection_attempt_failure_allows_retry() {
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );
        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");

        manager
            .apply_connection_attempt(McpConnectAttempt {
                name: "remote".to_string(),
                proxied: false,
                outcome: McpConnectOutcome::Failed {
                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
                },
            })
            .await;

        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let auth_event = event_receiver.recv().await.expect("authentication failure event");
        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
            panic!("expected AuthenticationFailed");
        };
        assert_eq!(server, "remote");
        assert!(error.contains("boom"));

        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
        assert!(manager.authenticate_server_task("remote").await.is_ok());
    }

    /// Auth capability is `OAuth` exactly when a re-auth config is stored.
    #[test]
    fn status_entries_are_derived_from_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));

        manager.register_record(
            "with-oauth",
            McpServerStatus::Connected { tool_count: 1 },
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
            false,
        );
        manager.register_record("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None, false);
        manager.register_record(
            "needs-oauth",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
            false,
        );

        let statuses = manager.server_statuses();
        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();

        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
    }

    /// Status entries keep insertion order, flag proxied servers, and never list the proxy itself.
    #[tokio::test]
    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcps(vec![
                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
            ])
            .await
            .unwrap();

        let statuses = manager.server_statuses();
        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
    }

    /// Removing tools for `git` must not touch `github`'s tools, which share a textual prefix.
    #[test]
    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tool_definitions.push(ToolDefinition {
            name: "git__status".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("git".to_string()),
        });
        manager.tool_definitions.push(ToolDefinition {
            name: "github__issue".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("github".to_string()),
        });

        manager.remove_registered_tools_for_server("git");

        assert!(!manager.tools.contains_key("git__status"));
        assert!(manager.tools.contains_key("github__issue"));
        assert_eq!(
            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
            vec!["github__issue"]
        );
    }

    /// Dropping a manager without `shutdown` aborts server tasks and logs the abort.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcps(vec![McpServer::new(
                "test",
                McpTransport::InMemory { server: TestServer::default().into_dyn() },
                false,
            )])
            .await
            .unwrap();

        let output = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer({
                let output = Arc::clone(&output);
                move || SharedWriter(Arc::clone(&output))
            })
            .finish();

        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}
998}