1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::McpServer,
6 connection::{
7 ConnectConfig, McpConnectAttempt, McpConnectOutcome, McpServerConnection, Tool, authenticate_http,
8 connect_server,
9 },
10 mcp_client::McpClient,
11 naming::{create_namespaced_tool_name, split_on_server_name},
12 tool_proxy::ToolProxy,
13};
14use aether_auth::{OAuthCredentialStorage, OAuthHandler};
15use futures::future::join_all;
16use rmcp::{
17 RoleClient,
18 model::{
19 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20 ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21 },
22 service::RunningService,
23 transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{BTreeMap, HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Name under which the tool proxy is exposed (its call tool is namespaced
/// with it). Registering a server with this name alongside proxied servers is
/// rejected.
pub const DEFAULT_PROXY_NAME: &str = "proxy";

/// Builds an [`OAuthHandler`] for one server connection, given the server's
/// name and an event channel in [`OAuthHandlerContext`].
pub type OAuthHandlerFactory = Arc<dyn Fn(OAuthHandlerContext) -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
/// Input handed to an [`OAuthHandlerFactory`] when a handler is needed.
#[derive(Clone)]
pub struct OAuthHandlerContext {
    /// Name of the server the handler will authenticate against.
    pub server_name: String,
    /// Channel for emitting [`McpClientEvent`]s; presumably lets the handler
    /// surface auth progress to the host — confirm in the factory implementations.
    pub tx: mpsc::Sender<McpClientEvent>,
}
46
/// An elicitation request raised by a server, carried by
/// [`McpClientEvent::Elicitation`] together with a channel for the reply.
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Server that issued the request.
    pub server_name: String,
    /// Request parameters as received from the server.
    pub request: CreateElicitationRequestParams,
    /// One-shot sender for the result. NOTE(review): presumably the server's
    /// request stays pending until this is used or dropped — confirm in the
    /// client handler.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
53
/// An answer to an elicitation: the action taken plus optional content.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// Action chosen for the elicitation.
    pub action: ElicitationAction,
    /// Submitted content, if any; presumably the form values as a JSON
    /// object — confirm against the producer.
    pub content: Option<Value>,
}
59
/// Identifies a URL-mode elicitation reported as complete; carried by
/// [`McpClientEvent::UrlElicitationComplete`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Server that owns the elicitation.
    pub server_name: String,
    /// Identifier of the completed elicitation.
    pub elicitation_id: String,
}
65
/// Events emitted by the MCP client layer to the host.
#[derive(Debug)]
pub enum McpClientEvent {
    /// A server requested input; reply through the contained sender.
    Elicitation(ElicitationRequest),
    /// A URL-mode elicitation was reported complete.
    UrlElicitationComplete(UrlElicitationCompleteParams),
    /// Full, ordered snapshot of every server's status after a state change.
    ServerStatusesChanged(Vec<McpServerStatusEntry>),
    /// Full current set of tool definitions (emitted after a server connects).
    ToolDefinitionsChanged(Vec<ToolDefinition>),
    /// Instructions for `server` (or for the proxy, on behalf of a proxied
    /// server) became available after a connect.
    ServerInstructionsUpdated { server: String, instructions: Option<String> },
    /// A connection attempt for `server` failed. Despite the name, this is sent
    /// for every failed attempt (including tool discovery failures), not only
    /// OAuth failures.
    AuthenticationFailed { server: String, error: String },
    /// Snapshot of tools, instructions and statuses sent by
    /// `McpManager::emit_connection_ready`.
    ConnectionReady(McpConnectionDetails),
}
79
/// Aggregate view of the manager's state, carried by
/// [`McpClientEvent::ConnectionReady`].
#[derive(Debug, Clone)]
pub struct McpConnectionDetails {
    /// Server (or proxy) name to instructions text.
    pub instructions: BTreeMap<String, String>,
    /// All tool definitions currently exposed to the model.
    pub tool_definitions: Vec<ToolDefinition>,
    /// Status of every registered server, in registration order.
    pub server_statuses: Vec<McpServerStatusEntry>,
}
86
/// Owns every configured MCP server connection, routes tool calls to them, and
/// reports state changes through [`McpClientEvent`]s.
pub struct McpManager {
    /// Per-server state, keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in first-registration order; drives the order of statuses
    /// and tool definitions.
    server_order: Vec<String>,
    /// Tool proxy fronting servers configured with `proxy`, if any.
    proxy: Option<ToolProxy>,
    /// Optional home directory under which the proxy tool directory is placed.
    aether_home: Option<PathBuf>,
    /// Client info and capabilities passed to every connection attempt.
    client_info: ClientInfo,
    /// Destination for all emitted events.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with every connection; replaced by `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Factory for OAuth handlers, forwarded to connection attempts.
    oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Credential store forwarded to connection attempts.
    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
}
100
101impl McpManager {
102 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
103 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
104 if let Some(elicitation) = capabilities.elicitation.as_mut() {
105 elicitation.form = Some(FormElicitationCapability::default());
106 elicitation.url = Some(UrlElicitationCapability::default());
107 }
108
109 Self {
110 servers: HashMap::new(),
111 server_order: Vec::new(),
112 proxy: None,
113 aether_home: None,
114 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
115 event_sender,
116 roots: Arc::new(RwLock::new(Vec::new())),
117 oauth_handler_factory,
118 oauth_credential_store: None,
119 }
120 }
121
122 pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
123 self.aether_home = Some(aether_home.into());
124 self
125 }
126
127 pub fn with_oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
128 self.oauth_credential_store = Some(store);
129 self
130 }
131
132 pub async fn register_pending(&mut self, servers: Vec<McpServer>) -> Result<Vec<McpServer>> {
133 let has_proxy = servers.iter().any(|server| server.proxy);
134 if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
135 return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
136 }
137
138 for server in &servers {
139 self.register_record(&server.name, ServerState::Connecting, None, server.proxy);
140 }
141
142 self.emit_server_statuses_changed().await;
143 Ok(servers)
144 }
145
146 pub async fn bootstrap_proxy_setup(&mut self, servers: &[McpServer]) -> Result<()> {
147 let proxied_members: HashSet<String> =
148 servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
149
150 if proxied_members.is_empty() {
151 return Ok(());
152 }
153
154 let dir = self.proxy_tool_dir()?;
155 ToolProxy::clean_dir(&dir).await?;
156 self.register_proxy(dir, proxied_members);
157 Ok(())
158 }
159
160 pub fn connect_pending_task(&self, server: McpServer) -> impl Future<Output = McpConnectAttempt> + Send + 'static {
161 let ctx = self.connect_config();
162 async move { connect_server(server, &ctx).await }
163 }
164
165 pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
166 self.bootstrap_proxy_setup(&servers).await?;
167 let pending = self.register_pending(servers).await?;
168 let ctx = self.connect_config();
169 let attempts = join_all(pending.into_iter().map(|server| connect_server(server, &ctx))).await;
170 for attempt in attempts {
171 self.apply_connection_attempt(attempt).await;
172 }
173 Ok(())
174 }
175
176 pub fn get_client_for_tool(
177 &self,
178 namespaced_tool_name: &str,
179 arguments_json: &str,
180 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
181 if !self.is_routable_tool(namespaced_tool_name) {
182 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
183 }
184
185 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
186 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
187
188 if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
189 let call = proxy.resolve_call(arguments_json)?;
190 let conn = self.connection_for(&call.server).ok_or_else(|| {
191 McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
192 })?;
193 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
194 return Ok((conn.client.clone(), params));
195 }
196
197 let client =
198 self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
199
200 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
201 let mut params = CallToolRequestParams::new(tool_name.to_string());
202 if let Some(args) = arguments {
203 params = params.with_arguments(args);
204 }
205
206 Ok((client, params))
207 }
208
209 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
210 let mut definitions = Vec::new();
211 if let Some(proxy) = self.proxy.as_ref() {
212 definitions.push(ToolProxy::call_tool_definition(proxy.name()));
213 }
214 for name in &self.server_order {
215 let Some(record) = self.servers.get(name) else { continue };
216 if record.proxied {
217 continue;
218 }
219 definitions.extend(record.tools().iter().map(|tool| {
220 ToolDefinition::new(
221 create_namespaced_tool_name(name, &tool.name),
222 tool.description.clone(),
223 tool.parameters.to_string(),
224 )
225 .with_server(name.clone())
226 .with_annotations(tool.annotations.clone())
227 }));
228 }
229 definitions
230 }
231
232 pub fn server_instructions(&self) -> BTreeMap<String, String> {
233 let mut instructions: BTreeMap<String, String> = self
234 .servers
235 .iter()
236 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
237 .filter_map(|(name, record)| {
238 record
239 .connection()
240 .and_then(|conn| conn.instructions.as_ref())
241 .map(|instr| (name.clone(), instr.clone()))
242 })
243 .collect();
244
245 if let Some((name, body)) = self.proxy_instructions() {
246 instructions.insert(name, body);
247 }
248
249 instructions
250 }
251
252 pub fn server_statuses(&self) -> Vec<McpServerStatusEntry> {
253 self.server_order
254 .iter()
255 .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
256 .collect()
257 }
258
259 pub async fn authenticate_server_task(
260 &mut self,
261 name: &str,
262 ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
263 let record = self
264 .servers
265 .get(name)
266 .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
267 if !record.can_authenticate() {
268 return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
269 }
270 if self.oauth_handler_factory.is_none() {
271 return Err(McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")));
272 }
273
274 let name = name.to_string();
275 let config = record.reauth_config.clone().expect("checked above");
276 let proxied = record.proxied;
277 let ctx = self.connect_config();
278
279 self.set_state(&name, ServerState::Authenticating);
280 self.emit_server_statuses_changed().await;
281
282 Ok(async move { authenticate_http(name, config, ctx, proxied).await })
283 }
284
285 pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
286 let McpConnectAttempt { name, proxied, outcome } = attempt;
287 match outcome {
288 McpConnectOutcome::Connected { conn, reauth_config } => {
289 match self.register_connection(&name, conn, reauth_config, proxied).await {
290 Ok(tools) => {
291 self.refresh_proxy_after_auth(&name, &tools, proxied).await;
292 self.emit_server_statuses_changed().await;
293 self.emit_tool_definitions_changed().await;
294 self.emit_instructions_after_connect(&name, proxied).await;
295 }
296 Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
297 }
298 }
299 McpConnectOutcome::NeedsOAuth { config, error } => {
300 tracing::warn!("Server '{name}' needs OAuth: {error}");
301 self.register_record(&name, ServerState::NeedsOAuth, Some(config), proxied);
302 self.emit_server_statuses_changed().await;
303 }
304 McpConnectOutcome::Failed { error } => {
305 self.apply_authentication_failure(name, error.to_string()).await;
306 }
307 }
308 }
309
310 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
312 let futures: Vec<_> = self
313 .servers
314 .iter()
315 .filter_map(|(server_name, record)| {
316 let conn = record.connection()?;
317 conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
318 let server_name = server_name.clone();
319 let client = conn.client.clone();
320 Some(async move {
321 let prompts_response = client.list_prompts(None).await.map_err(|e| {
322 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
323 })?;
324
325 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
326 .prompts
327 .into_iter()
328 .map(|prompt| {
329 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
330 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
331 })
332 .collect();
333
334 Ok::<_, McpError>(namespaced_prompts)
335 })
336 })
337 .collect();
338
339 let results = join_all(futures).await;
340 let mut all_prompts = Vec::new();
341 for result in results {
342 all_prompts.extend(result?);
343 }
344
345 Ok(all_prompts)
346 }
347
348 pub async fn get_prompt(
350 &self,
351 namespaced_prompt_name: &str,
352 arguments: Option<serde_json::Map<String, serde_json::Value>>,
353 ) -> Result<rmcp::model::GetPromptResult> {
354 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
355 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
356
357 let server_conn =
358 self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
359
360 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
361 if let Some(args) = arguments {
362 request = request.with_arguments(args);
363 }
364
365 server_conn.client.get_prompt(request).await.map_err(|e| {
366 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
367 })
368 }
369
370 pub async fn shutdown(&mut self) {
372 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
373
374 for (server_name, record) in servers {
375 if let Some(conn) = record.into_connection()
376 && let Some(handle) = conn.server_task
377 {
378 drop(conn.client);
379
380 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
381 Ok(Ok(())) => {
382 tracing::info!("Server '{server_name}' shut down gracefully");
383 }
384 Ok(Err(e)) => {
385 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
386 }
387 Err(_) => {
388 tracing::warn!("Server '{server_name}' shutdown timed out");
389 }
390 }
391 }
392 }
393
394 self.server_order.clear();
395 self.proxy = None;
396 }
397
398 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
400 if let Some(record) = self.servers.remove(server_name)
401 && let Some(conn) = record.into_connection()
402 && let Some(handle) = conn.server_task
403 {
404 drop(conn.client);
405
406 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
407 Ok(Ok(())) => {
408 tracing::info!("Server '{server_name}' shut down gracefully");
409 }
410 Ok(Err(e)) => {
411 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
412 }
413 Err(_) => {
414 tracing::warn!("Server '{server_name}' shutdown timed out");
415 }
416 }
417 }
418
419 Ok(())
420 }
421
422 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
427 {
428 let mut roots = self.roots.write().await;
429 *roots = new_roots;
430 }
431
432 self.notify_roots_changed().await;
433
434 Ok(())
435 }
436
437 async fn emit_server_statuses_changed(&self) {
438 self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses())).await;
439 }
440
441 async fn emit_tool_definitions_changed(&self) {
442 self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
443 }
444
445 async fn emit_instructions_after_connect(&self, server_name: &str, proxied: bool) {
446 if proxied {
447 if let Some((server, body)) = self.proxy_instructions() {
448 self.emit_event(McpClientEvent::ServerInstructionsUpdated { server, instructions: Some(body) }).await;
449 }
450 return;
451 }
452
453 if let Some(instructions) =
454 self.connection_for(server_name).and_then(|conn| conn.instructions.as_ref()).cloned()
455 {
456 self.emit_event(McpClientEvent::ServerInstructionsUpdated {
457 server: server_name.to_string(),
458 instructions: Some(instructions),
459 })
460 .await;
461 }
462 }
463
464 fn proxy_instructions(&self) -> Option<(String, String)> {
465 let proxy = self.proxy.as_ref()?;
466 let descriptions: Vec<(String, String)> = proxy
467 .members()
468 .iter()
469 .filter_map(|member| {
470 let conn = self.connection_for(member)?;
471 Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
472 })
473 .collect();
474 Some((proxy.name().to_string(), ToolProxy::build_instructions(proxy.tool_dir(), &descriptions)))
475 }
476
477 pub async fn emit_connection_ready(&self) {
478 self.emit_event(McpClientEvent::ConnectionReady(McpConnectionDetails {
479 tool_definitions: self.tool_definitions(),
480 instructions: self.server_instructions(),
481 server_statuses: self.server_statuses(),
482 }))
483 .await;
484 }
485
486 async fn emit_authentication_failed(&self, server: String, error: String) {
487 self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
488 }
489
490 async fn emit_event(&self, event: McpClientEvent) {
491 if let Err(e) = self.event_sender.send(event).await {
492 tracing::warn!("Failed to emit MCP client event: {e}");
493 }
494 }
495
496 fn connect_config(&self) -> Arc<ConnectConfig> {
497 Arc::new(ConnectConfig {
498 client_info: self.client_info.clone(),
499 event_sender: self.event_sender.clone(),
500 roots: Arc::clone(&self.roots),
501 oauth_handler_factory: self.oauth_handler_factory.clone(),
502 oauth_credential_store: self.oauth_credential_store.clone(),
503 })
504 }
505
506 fn proxy_tool_dir(&self) -> Result<PathBuf> {
507 self.aether_home
508 .as_ref()
509 .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
510 .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
511 }
512
513 async fn register_connection(
514 &mut self,
515 name: &str,
516 conn: McpServerConnection,
517 reauth_config: Option<StreamableHttpClientTransportConfig>,
518 proxied: bool,
519 ) -> Result<Vec<RmcpTool>> {
520 let tools = conn
521 .list_tools()
522 .await
523 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
524 self.apply_connected(name, conn, &tools, reauth_config, proxied);
525 Ok(tools)
526 }
527
528 fn apply_connected(
529 &mut self,
530 name: &str,
531 conn: McpServerConnection,
532 tools: &[RmcpTool],
533 reauth_config: Option<StreamableHttpClientTransportConfig>,
534 proxied: bool,
535 ) {
536 let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
537 let final_reauth = reauth_config.or(existing_reauth);
538 let tools = tools.iter().map(Tool::from).collect();
539
540 self.remember_server_order(name);
541 self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools, final_reauth, proxied));
542 }
543
544 fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
545 self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
546 }
547
548 async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
549 if !proxied {
550 return;
551 }
552
553 if let Some(proxy) = self.proxy.as_mut() {
554 proxy.add_member(name.to_string());
555 }
556
557 if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
558 && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
559 {
560 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
561 }
562 }
563
564 fn remember_server_order(&mut self, name: &str) {
565 if !self.server_order.iter().any(|n| n == name) {
566 self.server_order.push(name.to_string());
567 }
568 }
569
570 async fn apply_authentication_failure(&mut self, name: String, error: String) {
571 self.set_state(&name, ServerState::Failed { error: error.clone() });
572 self.emit_server_statuses_changed().await;
573 self.emit_authentication_failed(name, error).await;
574 }
575
576 fn set_state(&mut self, name: &str, state: ServerState) {
577 self.remember_server_order(name);
578 match self.servers.get_mut(name) {
579 Some(record) => record.state = state,
580 None => {
581 self.servers.insert(name.to_string(), ServerRecord::new(state, None, false));
582 }
583 }
584 }
585
586 fn register_record(
587 &mut self,
588 name: &str,
589 state: ServerState,
590 reauth_config: Option<StreamableHttpClientTransportConfig>,
591 proxied: bool,
592 ) {
593 self.remember_server_order(name);
594 self.servers.insert(name.to_string(), ServerRecord::new(state, reauth_config, proxied));
595 }
596
597 fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
598 self.servers.get(server_name).and_then(ServerRecord::connection)
599 }
600
601 fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
602 self.connection_for(server_name).map(|conn| conn.client.clone())
603 }
604
605 fn is_routable_tool(&self, namespaced_tool_name: &str) -> bool {
606 if self.proxy.as_ref().is_some_and(|proxy| proxy.call_tool_name() == namespaced_tool_name) {
607 return true;
608 }
609 match split_on_server_name(namespaced_tool_name) {
610 Some((server_name, tool_name)) => {
611 self.servers.get(server_name).is_some_and(|record| !record.proxied && record.has_tool(tool_name))
612 }
613 None => false,
614 }
615 }
616
617 async fn notify_roots_changed(&self) {
618 for (server_name, record) in &self.servers {
619 if let Some(conn) = record.connection()
620 && let Err(e) = conn.client.notify_roots_list_changed().await
621 {
622 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
623 }
624 }
625 }
626}
627
628impl Drop for McpManager {
629 fn drop(&mut self) {
630 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
631 for (server_name, record) in servers {
632 if let Some(conn) = record.into_connection()
633 && let Some(handle) = conn.server_task
634 {
635 handle.abort();
636 tracing::warn!("Server '{server_name}' task aborted during cleanup");
637 }
638 }
639 }
640}
641
/// Book-keeping for one configured server.
struct ServerRecord {
    /// Current lifecycle state; holds the live connection once connected.
    state: ServerState,
    /// HTTP transport config kept for re-running OAuth. Its presence is what
    /// makes the server authenticatable.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
    /// Whether the server's tools are reached through the tool proxy instead of
    /// being exposed directly.
    proxied: bool,
}
648
/// Lifecycle of a server as tracked by the manager.
enum ServerState {
    /// Registered; a connection attempt is pending or in flight.
    Connecting,
    /// Connected, with the tools discovered at connect time.
    Connected { connection: McpServerConnection, tools: Vec<Tool> },
    /// An OAuth re-authentication was started.
    Authenticating,
    /// The last attempt failed with `error`.
    Failed { error: String },
    /// The server requires OAuth before it can connect.
    NeedsOAuth,
}
656
657impl From<&ServerState> for McpServerStatus {
658 fn from(state: &ServerState) -> Self {
659 match state {
660 ServerState::Connecting => Self::Connecting,
661 ServerState::Connected { tools, .. } => Self::Connected { tool_count: tools.len() },
662 ServerState::Authenticating => Self::Authenticating,
663 ServerState::Failed { error } => Self::Failed { error: error.clone() },
664 ServerState::NeedsOAuth => Self::NeedsOAuth,
665 }
666 }
667}
668
669impl ServerRecord {
670 fn new(state: ServerState, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
671 Self { state, reauth_config, proxied }
672 }
673
674 fn connected(
675 connection: McpServerConnection,
676 tools: Vec<Tool>,
677 reauth_config: Option<StreamableHttpClientTransportConfig>,
678 proxied: bool,
679 ) -> Self {
680 Self::new(ServerState::Connected { connection, tools }, reauth_config, proxied)
681 }
682
683 fn tools(&self) -> &[Tool] {
684 match &self.state {
685 ServerState::Connected { tools, .. } => tools,
686 ServerState::Connecting
687 | ServerState::Authenticating
688 | ServerState::Failed { .. }
689 | ServerState::NeedsOAuth => &[],
690 }
691 }
692
693 fn has_tool(&self, tool_name: &str) -> bool {
694 self.tools().iter().any(|tool| tool.name == tool_name)
695 }
696
697 fn connection(&self) -> Option<&McpServerConnection> {
698 match &self.state {
699 ServerState::Connected { connection, .. } => Some(connection),
700 ServerState::Connecting
701 | ServerState::Authenticating
702 | ServerState::Failed { .. }
703 | ServerState::NeedsOAuth => None,
704 }
705 }
706
707 fn into_connection(self) -> Option<McpServerConnection> {
708 match self.state {
709 ServerState::Connected { connection, .. } => Some(connection),
710 ServerState::Connecting
711 | ServerState::Authenticating
712 | ServerState::Failed { .. }
713 | ServerState::NeedsOAuth => None,
714 }
715 }
716
717 fn auth_capability(&self) -> McpServerAuthCapability {
718 if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
719 }
720
721 fn can_authenticate(&self) -> bool {
722 self.reauth_config.is_some()
723 }
724
725 fn status(&self) -> McpServerStatus {
726 (&self.state).into()
727 }
728
729 fn status_entry(&self, name: &str) -> McpServerStatusEntry {
730 McpServerStatusEntry::new(name, self.status())
731 .with_auth_capability(self.auth_capability())
732 .with_proxied(self.proxied)
733 }
734}
735
736#[cfg(test)]
737mod tests {
738 use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, ServerState};
739 use crate::client::OAuthHandlerFactory;
740 use crate::client::config::{McpServer, McpTransport};
741 use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
742 use crate::status::McpServerAuthCapability;
743 use aether_auth::{OAuthCallback, OAuthError, OAuthHandler};
744 use futures::future::BoxFuture;
745 use rmcp::{
746 Json, RoleServer, ServerHandler,
747 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
748 model::{Implementation, ServerCapabilities, ServerInfo},
749 service::DynService,
750 tool, tool_handler, tool_router,
751 transport::streamable_http_client::StreamableHttpClientTransportConfig,
752 };
753 use schemars::JsonSchema;
754 use serde::{Deserialize, Serialize};
755 use std::{
756 io,
757 sync::{Arc, Mutex},
758 };
759 use tokio::sync::mpsc;
760
761 #[derive(Clone)]
762 struct TestServer {
763 tool_router: ToolRouter<Self>,
764 }
765
766 #[tool_handler(router = self.tool_router)]
767 impl ServerHandler for TestServer {
768 fn get_info(&self) -> ServerInfo {
769 ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
770 .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
771 .with_instructions("Test server instructions")
772 }
773 }
774
775 impl Default for TestServer {
776 fn default() -> Self {
777 Self { tool_router: Self::tool_router() }
778 }
779 }
780
781 #[derive(Debug, Deserialize, Serialize, JsonSchema)]
782 struct EchoRequest {
783 value: String,
784 }
785
786 #[derive(Debug, Deserialize, Serialize, JsonSchema)]
787 struct EchoResult {
788 value: String,
789 }
790
791 #[tool_router]
792 impl TestServer {
793 fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
794 Box::new(self)
795 }
796
797 #[tool(description = "Returns the provided value", annotations(read_only_hint = true, open_world_hint = false))]
798 async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
799 let Parameters(EchoRequest { value }) = request;
800 Json(EchoResult { value })
801 }
802 }
803
804 #[derive(Clone)]
805 struct SharedWriter(Arc<Mutex<Vec<u8>>>);
806
807 impl io::Write for SharedWriter {
808 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
809 self.0.lock().unwrap().extend_from_slice(buf);
810 Ok(buf.len())
811 }
812
813 fn flush(&mut self) -> io::Result<()> {
814 Ok(())
815 }
816 }
817
818 struct TestOAuthHandler;
819
820 impl OAuthHandler for TestOAuthHandler {
821 fn redirect_uri(&self) -> &'static str {
822 "http://127.0.0.1:0/oauth2callback"
823 }
824
825 fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
826 Box::pin(async { Err(OAuthError::UserCancelled) })
827 }
828 }
829
830 fn test_oauth_handler_factory() -> OAuthHandlerFactory {
831 Arc::new(|_ctx| Ok(Arc::new(TestOAuthHandler)))
832 }
833
834 #[tokio::test]
835 async fn authenticate_server_task_rejects_record_without_reauth_config() {
836 let (event_sender, _event_receiver) = mpsc::channel(1);
837 let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
838 manager.register_record("public", ServerState::Connecting, None, false);
839
840 let error = match manager.authenticate_server_task("public").await {
841 Ok(_) => panic!("non-OAuth server should be rejected"),
842 Err(error) => error.to_string(),
843 };
844 assert!(error.contains("not OAuth-authenticatable"));
845 }
846
847 #[tokio::test]
848 async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
849 let (event_sender, mut event_receiver) = mpsc::channel(2);
850 let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
851 manager.register_record(
852 "remote",
853 ServerState::NeedsOAuth,
854 Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
855 false,
856 );
857
858 let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
859
860 assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
861 let event = event_receiver.recv().await.expect("status change event");
862 let McpClientEvent::ServerStatusesChanged(servers) = event else {
863 panic!("expected ServerStatusesChanged");
864 };
865 let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
866 assert!(matches!(status.status, McpServerStatus::Authenticating));
867 assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
868 }
869
870 #[tokio::test]
871 async fn apply_connection_attempt_failure_allows_retry() {
872 let (event_sender, mut event_receiver) = mpsc::channel(2);
873 let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
874 manager.register_record(
875 "remote",
876 ServerState::NeedsOAuth,
877 Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
878 false,
879 );
880 let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
881 let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");
882
883 manager
884 .apply_connection_attempt(McpConnectAttempt {
885 name: "remote".to_string(),
886 proxied: false,
887 outcome: McpConnectOutcome::Failed {
888 error: crate::client::McpError::ConnectionFailed("boom".to_string()),
889 },
890 })
891 .await;
892
893 let event = event_receiver.recv().await.expect("status change event");
894 let McpClientEvent::ServerStatusesChanged(servers) = event else {
895 panic!("expected ServerStatusesChanged");
896 };
897 let auth_event = event_receiver.recv().await.expect("authentication failure event");
898 let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
899 panic!("expected AuthenticationFailed");
900 };
901 assert_eq!(server, "remote");
902 assert!(error.contains("boom"));
903
904 let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
905 assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
906 assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
907 assert!(manager.authenticate_server_task("remote").await.is_ok());
908 }
909
910 #[test]
911 fn status_entries_are_derived_from_reauth_config() {
912 let (event_sender, _event_receiver) = mpsc::channel(1);
913 let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
914
915 manager.register_record(
916 "with-oauth",
917 ServerState::Connecting,
918 Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
919 false,
920 );
921 manager.register_record("without-oauth", ServerState::Connecting, None, false);
922 manager.register_record(
923 "needs-oauth",
924 ServerState::NeedsOAuth,
925 Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
926 false,
927 );
928
929 let statuses = manager.server_statuses();
930 let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
931 let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
932 let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();
933
934 assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
935 assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
936 assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
937 }
938
939 #[tokio::test]
940 async fn register_pending_marks_every_server_connecting_and_emits_status() {
941 let (event_sender, mut event_receiver) = mpsc::channel(32);
942 let mut manager = McpManager::new(event_sender, None);
943
944 let servers = vec![
945 McpServer::new("alpha", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
946 McpServer::new("beta", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
947 ];
948
949 let returned = manager.register_pending(servers).await.unwrap();
950 assert_eq!(returned.iter().map(|s| s.name.as_str()).collect::<Vec<_>>(), vec!["alpha", "beta"]);
951
952 let statuses = manager.server_statuses();
953 assert_eq!(statuses.len(), 2);
954 assert!(matches!(statuses.iter().find(|s| s.name == "alpha").unwrap().status, McpServerStatus::Connecting));
955 assert!(matches!(statuses.iter().find(|s| s.name == "beta").unwrap().status, McpServerStatus::Connecting));
956 assert!(statuses.iter().find(|s| s.name == "beta").unwrap().proxied);
957
958 let event = event_receiver.try_recv().expect("expected initial ServerStatusesChanged emission");
959 let McpClientEvent::ServerStatusesChanged(emitted) = event else {
960 panic!("expected ServerStatusesChanged, got {event:?}");
961 };
962 assert_eq!(emitted.iter().map(|s| s.name.as_str()).collect::<Vec<_>>(), vec!["alpha", "beta"]);
963 }
964
965 #[tokio::test]
966 async fn apply_connection_attempt_emits_instructions_updated_after_connect() {
967 let (event_sender, mut event_receiver) = mpsc::channel(32);
968 let mut manager = McpManager::new(event_sender, None);
969
970 let servers =
971 vec![McpServer::new("test", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false)];
972 manager.add_mcps(servers).await.unwrap();
973
974 let mut update_for_test = None;
975 while let Ok(event) = event_receiver.try_recv() {
976 if let McpClientEvent::ServerInstructionsUpdated { server, instructions } = event
977 && server == "test"
978 {
979 update_for_test = Some(instructions);
980 }
981 }
982 let instructions = update_for_test.expect("expected ServerInstructionsUpdated for 'test'");
983 assert!(instructions.is_some(), "TestServer publishes instructions, so update should carry Some(_)");
984 }
985
986 #[tokio::test]
987 async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
988 let (event_sender, _event_receiver) = mpsc::channel(32);
989 let mut manager = McpManager::new(event_sender, None);
990 manager
991 .add_mcps(vec![
992 McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
993 McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
994 ])
995 .await
996 .unwrap();
997
998 let statuses = manager.server_statuses();
999 assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
1000 assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
1001 assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
1002 assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
1003 }
1004
1005 #[tokio::test]
1006 async fn tool_definitions_drop_when_a_server_shuts_down() {
1007 let (event_sender, _event_receiver) = mpsc::channel(32);
1008 let mut manager = McpManager::new(event_sender, None);
1009 manager
1010 .add_mcps(vec![
1011 McpServer::new("git", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
1012 McpServer::new("github", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
1013 ])
1014 .await
1015 .unwrap();
1016
1017 let names =
1018 |manager: &McpManager| manager.tool_definitions().into_iter().map(|tool| tool.name).collect::<Vec<_>>();
1019 assert!(names(&manager).contains(&"git__echo".to_string()));
1020 assert!(names(&manager).contains(&"github__echo".to_string()));
1021
1022 manager.shutdown_server("git").await.unwrap();
1023
1024 assert!(!names(&manager).iter().any(|name| name.starts_with("git__")));
1025 assert!(names(&manager).contains(&"github__echo".to_string()));
1026 }
1027
1028 #[tokio::test]
1029 async fn tool_definitions_preserve_annotations() {
1030 let (event_sender, _event_receiver) = mpsc::channel(32);
1031 let mut manager = McpManager::new(event_sender, None);
1032 manager
1033 .add_mcps(vec![McpServer::new(
1034 "test",
1035 McpTransport::InMemory { server: TestServer::default().into_dyn() },
1036 false,
1037 )])
1038 .await
1039 .unwrap();
1040
1041 let tools = manager.tool_definitions();
1042 let echo = tools.iter().find(|tool| tool.name == "test__echo").expect("echo tool");
1043 let annotations = echo.annotations.as_ref().expect("annotations should be preserved");
1044 assert_eq!(annotations.read_only_hint, Some(true));
1045 assert_eq!(annotations.open_world_hint, Some(false));
1046 }
1047
1048 #[tokio::test]
1049 async fn drop_logs_cleanup_abort_with_tracing() {
1050 let (event_sender, _event_receiver) = mpsc::channel(32);
1051 let mut manager = McpManager::new(event_sender, None);
1052 manager
1053 .add_mcps(vec![McpServer::new(
1054 "test",
1055 McpTransport::InMemory { server: TestServer::default().into_dyn() },
1056 false,
1057 )])
1058 .await
1059 .unwrap();
1060
1061 let output = Arc::new(Mutex::new(Vec::new()));
1062 let subscriber = tracing_subscriber::fmt()
1063 .with_ansi(false)
1064 .without_time()
1065 .with_writer({
1066 let output = Arc::clone(&output);
1067 move || SharedWriter(Arc::clone(&output))
1068 })
1069 .finish();
1070
1071 tracing::subscriber::with_default(subscriber, || {
1072 drop(manager);
1073 });
1074
1075 let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
1076 assert!(logs.contains("Server 'test' task aborted during cleanup"));
1077 }
1078}