1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::McpServer,
6 connection::{
7 ConnectContext, McpConnectAttempt, McpConnectOutcome, McpServerConnection, ServerInstructions, Tool,
8 authenticate_http, connect_server,
9 },
10 mcp_client::McpClient,
11 naming::{create_namespaced_tool_name, split_on_server_name},
12 oauth::OAuthHandler,
13 tool_proxy::ToolProxy,
14};
15use futures::future::join_all;
16use rmcp::{
17 RoleClient,
18 model::{
19 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
20 ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
21 },
22 service::RunningService,
23 transport::streamable_http_client::StreamableHttpClientTransportConfig,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::future::Future;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokio::sync::{RwLock, mpsc, oneshot};
32
33pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
34
/// Name under which the tool proxy is registered.
///
/// [`McpManager::add_mcps`] rejects a server carrying this name whenever any server is proxied.
pub const DEFAULT_PROXY_NAME: &str = "proxy";
36
/// Shared, thread-safe factory closure that yields an [`OAuthHandler`] or an error.
pub type OAuthHandlerFactory = Arc<dyn Fn() -> Result<Arc<dyn OAuthHandler>> + Send + Sync>;
38
/// An elicitation request tagged with a server name, bundled with the channel for its result.
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the server associated with the request.
    pub server_name: String,
    /// The elicitation parameters.
    pub request: CreateElicitationRequestParams,
    /// One-shot channel on which the [`CreateElicitationResult`] is sent.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
45
/// An elicitation action together with optional JSON content.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// The action chosen for the elicitation.
    pub action: ElicitationAction,
    /// Optional JSON payload accompanying the action.
    pub content: Option<Value>,
}
51
/// Payload of [`McpClientEvent::UrlElicitationComplete`]: a server name plus an elicitation id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Name of the server the elicitation belongs to.
    pub server_name: String,
    /// Identifier of the elicitation.
    pub elicitation_id: String,
}
57
/// Events delivered over the `mpsc` channel handed to [`McpManager::new`].
///
/// Within this file the manager itself emits `ServerStatusesChanged`, `ToolDefinitionsChanged`
/// and `AuthenticationFailed`; the two elicitation variants are not emitted here
/// (presumably they are produced by the per-server client — confirm against `mcp_client`).
#[derive(Debug)]
pub enum McpClientEvent {
    /// An elicitation request carrying its own response channel.
    Elicitation(ElicitationRequest),
    /// Carries the server name and id of a URL elicitation.
    UrlElicitationComplete(UrlElicitationCompleteParams),
    /// Full snapshot of the server status entries after a change.
    ServerStatusesChanged(Vec<McpServerStatusEntry>),
    /// Full snapshot of the tool definitions after a change.
    ToolDefinitionsChanged(Vec<ToolDefinition>),
    /// An authentication attempt for `server` ended with `error`.
    AuthenticationFailed { server: String, error: String },
}
69
/// Owns the state for a set of MCP servers: the per-server records, the tool catalogue built
/// from them, the optional tool proxy, and the channel used to publish [`McpClientEvent`]s.
pub struct McpManager {
    /// One record per known server, keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in first-registration order; determines the order of `server_statuses`.
    server_order: Vec<String>,
    /// Directly callable tools keyed by namespaced tool name (tools of proxied servers are not stored).
    tools: HashMap<String, Tool>,
    /// Tool definitions returned by `tool_definitions()`.
    tool_definitions: Vec<ToolDefinition>,
    /// The tool proxy, present once `register_proxy` has run.
    proxy: Option<ToolProxy>,
    /// Optional home directory used to locate the proxy tool directory.
    aether_home: Option<PathBuf>,
    /// Client info (including capabilities) handed to every connection attempt.
    client_info: ClientInfo,
    /// Channel on which events are published.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Shared roots list; replaced wholesale by `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Factory for OAuth handlers; `None` disables `authenticate_server_task`.
    oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Cached status snapshot, rebuilt by `refresh_status_entries`.
    server_statuses: Vec<McpServerStatusEntry>,
}
85
86impl McpManager {
87 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler_factory: Option<OAuthHandlerFactory>) -> Self {
88 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
89 if let Some(elicitation) = capabilities.elicitation.as_mut() {
90 elicitation.form = Some(FormElicitationCapability::default());
91 elicitation.url = Some(UrlElicitationCapability::default());
92 }
93
94 Self {
95 servers: HashMap::new(),
96 server_order: Vec::new(),
97 tools: HashMap::new(),
98 tool_definitions: Vec::new(),
99 proxy: None,
100 aether_home: None,
101 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
102 event_sender,
103 roots: Arc::new(RwLock::new(Vec::new())),
104 oauth_handler_factory,
105 server_statuses: Vec::new(),
106 }
107 }
108
109 pub fn with_aether_home(mut self, aether_home: impl Into<PathBuf>) -> Self {
110 self.aether_home = Some(aether_home.into());
111 self
112 }
113
114 pub async fn add_mcps(&mut self, servers: Vec<McpServer>) -> Result<()> {
115 let has_proxy = servers.iter().any(|server| server.proxy);
116 if has_proxy && servers.iter().any(|server| server.name == DEFAULT_PROXY_NAME) {
117 return Err(McpError::Other("server name 'proxy' collides with the tool proxy".into()));
118 }
119
120 let proxied_members: HashSet<String> =
121 servers.iter().filter(|server| server.proxy).map(|server| server.name.clone()).collect();
122 let proxy_tool_dir = if has_proxy {
123 let dir = self.proxy_tool_dir()?;
124 ToolProxy::clean_dir(&dir).await?;
125 Some(dir)
126 } else {
127 None
128 };
129
130 let ctx = self.connect_context();
131 let attempts = join_all(servers.into_iter().map(|server| connect_server(server, &ctx))).await;
132
133 let mut connected_proxied = Vec::new();
134 for McpConnectAttempt { name, proxied, outcome } in attempts {
135 match outcome {
136 McpConnectOutcome::Connected { conn, reauth_config } => {
137 self.register_connection(&name, conn, reauth_config, proxied).await?;
138 if proxied {
139 connected_proxied.push(name);
140 }
141 }
142 McpConnectOutcome::NeedsOAuth { config, error } => {
143 tracing::warn!("Server '{name}' needs OAuth: {error}");
144 self.register_record(&name, McpServerStatus::NeedsOAuth, Some(config), proxied);
145 }
146 McpConnectOutcome::Failed { error } => {
147 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
148 if !self.servers.contains_key(&name) {
149 self.register_record(
150 &name,
151 McpServerStatus::Failed { error: error.to_string() },
152 None,
153 proxied,
154 );
155 }
156 }
157 }
158 }
159
160 if let Some(tool_dir) = proxy_tool_dir {
161 self.write_proxy_tool_files(&connected_proxied, &tool_dir).await;
162 self.register_proxy(tool_dir, proxied_members);
163 }
164
165 Ok(())
166 }
167
168 pub fn get_client_for_tool(
169 &self,
170 namespaced_tool_name: &str,
171 arguments_json: &str,
172 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
173 if !self.tools.contains_key(namespaced_tool_name) {
174 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
175 }
176
177 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
178 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
179
180 if let Some(proxy) = self.proxy.as_ref().filter(|proxy| proxy.name() == server_name) {
181 let call = proxy.resolve_call(arguments_json)?;
182 let conn = self.connection_for(&call.server).ok_or_else(|| {
183 McpError::ServerNotFound(format!("Proxied server '{}' is not connected", call.server))
184 })?;
185 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
186 return Ok((conn.client.clone(), params));
187 }
188
189 let client =
190 self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
191
192 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
193 let mut params = CallToolRequestParams::new(tool_name.to_string());
194 if let Some(args) = arguments {
195 params = params.with_arguments(args);
196 }
197
198 Ok((client, params))
199 }
200
201 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
202 self.tool_definitions.clone()
203 }
204
205 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
206 let mut instructions: Vec<ServerInstructions> = self
207 .servers
208 .iter()
209 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|proxy| !proxy.contains_server(name)))
210 .filter_map(|(name, record)| {
211 record
212 .connection
213 .as_ref()
214 .and_then(|conn| conn.instructions.as_ref())
215 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
216 })
217 .collect();
218
219 if let Some(proxy) = &self.proxy {
220 let descriptions: Vec<(String, String)> = proxy
221 .members()
222 .iter()
223 .filter_map(|member| {
224 let conn = self.connection_for(member)?;
225 Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
226 })
227 .collect();
228 instructions.push(ServerInstructions {
229 server_name: proxy.name().to_string(),
230 instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
231 });
232 }
233
234 instructions
235 }
236
237 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
238 &self.server_statuses
239 }
240
241 pub async fn authenticate_server_task(
242 &mut self,
243 name: &str,
244 ) -> Result<impl Future<Output = McpConnectAttempt> + Send + 'static> {
245 let record = self
246 .servers
247 .get(name)
248 .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
249 if !record.can_authenticate() {
250 return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
251 }
252 if matches!(record.status, McpServerStatus::Authenticating) {
253 return Err(McpError::ConnectionFailed(format!("server '{name}' is already authenticating")));
254 }
255
256 let oauth_handler_factory = self
257 .oauth_handler_factory
258 .clone()
259 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
260 let name = name.to_string();
261 let config = record.reauth_config.clone().expect("checked above");
262 let client_info = self.client_info.clone();
263 let event_sender = self.event_sender.clone();
264 let roots = Arc::clone(&self.roots);
265 let proxied = record.proxied;
266
267 self.set_status(&name, McpServerStatus::Authenticating);
268 self.emit_server_statuses_changed().await;
269
270 Ok(async move {
271 authenticate_http(name, config, client_info, event_sender, roots, oauth_handler_factory, proxied).await
272 })
273 }
274
275 pub async fn apply_connection_attempt(&mut self, attempt: McpConnectAttempt) {
276 let McpConnectAttempt { name, proxied, outcome } = attempt;
277 match outcome {
278 McpConnectOutcome::Connected { conn, reauth_config } => {
279 match self.register_connection(&name, conn, reauth_config, proxied).await {
280 Ok(tools) => {
281 self.refresh_proxy_after_auth(&name, &tools, proxied).await;
282 self.emit_server_statuses_changed().await;
283 self.emit_tool_definitions_changed().await;
284 }
285 Err(error) => self.apply_authentication_failure(name, error.to_string()).await,
286 }
287 }
288 McpConnectOutcome::Failed { error } => {
289 self.apply_authentication_failure(name, error.to_string()).await;
290 }
291 McpConnectOutcome::NeedsOAuth { .. } => {
292 self.apply_authentication_failure(name, "internal error: auth task returned NeedsOAuth".to_string())
293 .await;
294 }
295 }
296 }
297
298 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
300 let futures: Vec<_> = self
301 .servers
302 .iter()
303 .filter_map(|(server_name, record)| {
304 let conn = record.connection.as_ref()?;
305 conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
306 let server_name = server_name.clone();
307 let client = conn.client.clone();
308 Some(async move {
309 let prompts_response = client.list_prompts(None).await.map_err(|e| {
310 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
311 })?;
312
313 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
314 .prompts
315 .into_iter()
316 .map(|prompt| {
317 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
318 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
319 })
320 .collect();
321
322 Ok::<_, McpError>(namespaced_prompts)
323 })
324 })
325 .collect();
326
327 let results = join_all(futures).await;
328 let mut all_prompts = Vec::new();
329 for result in results {
330 all_prompts.extend(result?);
331 }
332
333 Ok(all_prompts)
334 }
335
336 pub async fn get_prompt(
338 &self,
339 namespaced_prompt_name: &str,
340 arguments: Option<serde_json::Map<String, serde_json::Value>>,
341 ) -> Result<rmcp::model::GetPromptResult> {
342 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
343 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
344
345 let server_conn =
346 self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
347
348 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
349 if let Some(args) = arguments {
350 request = request.with_arguments(args);
351 }
352
353 server_conn.client.get_prompt(request).await.map_err(|e| {
354 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
355 })
356 }
357
358 pub async fn shutdown(&mut self) {
360 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
361
362 for (server_name, record) in servers {
363 if let Some(conn) = record.connection
364 && let Some(handle) = conn.server_task
365 {
366 drop(conn.client);
367
368 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
369 Ok(Ok(())) => {
370 tracing::info!("Server '{server_name}' shut down gracefully");
371 }
372 Ok(Err(e)) => {
373 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
374 }
375 Err(_) => {
376 tracing::warn!("Server '{server_name}' shutdown timed out");
377 }
378 }
379 }
380 }
381
382 self.tools.clear();
383 self.tool_definitions.clear();
384 self.proxy = None;
385 }
386
387 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
389 let record = self.servers.remove(server_name);
390
391 if let Some(record) = record {
392 if let Some(conn) = record.connection
393 && let Some(handle) = conn.server_task
394 {
395 drop(conn.client);
396
397 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
398 Ok(Ok(())) => {
399 tracing::info!("Server '{server_name}' shut down gracefully");
400 }
401 Ok(Err(e)) => {
402 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
403 }
404 Err(_) => {
405 tracing::warn!("Server '{server_name}' shutdown timed out");
406 }
407 }
408 }
409
410 self.remove_registered_tools_for_server(server_name);
411 self.refresh_status_entries();
412 }
413
414 Ok(())
415 }
416
417 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
422 {
423 let mut roots = self.roots.write().await;
424 *roots = new_roots;
425 }
426
427 self.notify_roots_changed().await;
428
429 Ok(())
430 }
431
432 async fn emit_server_statuses_changed(&self) {
433 self.emit_event(McpClientEvent::ServerStatusesChanged(self.server_statuses().to_vec())).await;
434 }
435
436 async fn emit_tool_definitions_changed(&self) {
437 self.emit_event(McpClientEvent::ToolDefinitionsChanged(self.tool_definitions())).await;
438 }
439
440 async fn emit_authentication_failed(&self, server: String, error: String) {
441 self.emit_event(McpClientEvent::AuthenticationFailed { server, error }).await;
442 }
443
444 async fn emit_event(&self, event: McpClientEvent) {
445 if let Err(e) = self.event_sender.send(event).await {
446 tracing::warn!("Failed to emit MCP client event: {e}");
447 }
448 }
449
450 fn connect_context(&self) -> ConnectContext<'_> {
451 ConnectContext {
452 client_info: &self.client_info,
453 event_sender: &self.event_sender,
454 roots: &self.roots,
455 oauth_handler_factory: self.oauth_handler_factory.as_ref(),
456 }
457 }
458
459 fn proxy_tool_dir(&self) -> Result<PathBuf> {
460 self.aether_home
461 .as_ref()
462 .map(|home| ToolProxy::dir_in_home(home, DEFAULT_PROXY_NAME))
463 .map_or_else(|| ToolProxy::dir(DEFAULT_PROXY_NAME), Ok)
464 }
465
466 async fn register_connection(
467 &mut self,
468 name: &str,
469 conn: McpServerConnection,
470 reauth_config: Option<StreamableHttpClientTransportConfig>,
471 proxied: bool,
472 ) -> Result<Vec<RmcpTool>> {
473 let tools = conn
474 .list_tools()
475 .await
476 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
477 self.apply_connected(name, conn, &tools, reauth_config, proxied);
478 Ok(tools)
479 }
480
481 fn apply_connected(
482 &mut self,
483 name: &str,
484 conn: McpServerConnection,
485 tools: &[RmcpTool],
486 reauth_config: Option<StreamableHttpClientTransportConfig>,
487 proxied: bool,
488 ) {
489 self.remove_registered_tools_for_server(name);
490
491 let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
492 let final_reauth = reauth_config.or(existing_reauth);
493
494 for rmcp_tool in tools {
495 let tool_name = rmcp_tool.name.to_string();
496 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
497 let tool = Tool::from(rmcp_tool);
498
499 if !proxied {
500 self.tool_definitions.push(ToolDefinition {
501 name: namespaced_tool_name.clone(),
502 description: tool.description.clone(),
503 parameters: tool.parameters.to_string(),
504 server: Some(name.to_string()),
505 });
506 self.tools.insert(namespaced_tool_name, tool);
507 }
508 }
509
510 self.remember_server_order(name);
511 self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth, proxied));
512 self.refresh_status_entries();
513 }
514
515 fn register_proxy(&mut self, tool_dir: std::path::PathBuf, members: HashSet<String>) {
516 self.remove_registered_tools_for_server(DEFAULT_PROXY_NAME);
517 let call_tool_def = ToolProxy::call_tool_definition(DEFAULT_PROXY_NAME);
518 self.tools.insert(
519 call_tool_def.name.clone(),
520 Tool {
521 description: call_tool_def.description.clone(),
522 parameters: serde_json::from_str(&call_tool_def.parameters)
523 .unwrap_or(Value::Object(serde_json::Map::default())),
524 },
525 );
526 self.tool_definitions.push(call_tool_def);
527
528 self.proxy = Some(ToolProxy::new(DEFAULT_PROXY_NAME.to_string(), members, tool_dir));
529 }
530
531 async fn refresh_proxy_after_auth(&mut self, name: &str, tools: &[RmcpTool], proxied: bool) {
532 if !proxied {
533 return;
534 }
535
536 if let Some(proxy) = self.proxy.as_mut() {
537 proxy.add_member(name.to_string());
538 }
539
540 if let Some(tool_dir) = self.proxy.as_ref().map(|proxy| proxy.tool_dir().to_path_buf())
541 && let Err(e) = ToolProxy::write_tool_entries_to_dir(name, tools, &tool_dir).await
542 {
543 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
544 }
545 }
546
547 async fn write_proxy_tool_files(&self, connected_proxied: &[String], tool_dir: &std::path::Path) {
548 let writes = connected_proxied.iter().filter_map(|name| {
549 let client = self.client_for_server(name)?;
550 let dir = tool_dir.to_path_buf();
551 let name = name.clone();
552 Some(async move {
553 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
554 tracing::warn!("Failed to write tool files for proxied server '{name}': {e}");
555 }
556 })
557 });
558 join_all(writes).await;
559 }
560
561 fn refresh_status_entries(&mut self) {
562 self.server_statuses = self
563 .server_order
564 .iter()
565 .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
566 .collect();
567 }
568
569 fn remember_server_order(&mut self, name: &str) {
570 if !self.server_order.iter().any(|n| n == name) {
571 self.server_order.push(name.to_string());
572 }
573 }
574
575 async fn apply_authentication_failure(&mut self, name: String, error: String) {
576 self.set_status(&name, McpServerStatus::Failed { error: error.clone() });
577 self.emit_server_statuses_changed().await;
578 self.emit_authentication_failed(name, error).await;
579 }
580
581 fn set_status(&mut self, name: &str, status: McpServerStatus) {
582 self.remember_server_order(name);
583 let record =
584 self.servers.entry(name.to_string()).or_insert_with(|| ServerRecord::new(status.clone(), None, false));
585 record.status = status;
586 self.refresh_status_entries();
587 }
588
589 fn register_record(
590 &mut self,
591 name: &str,
592 status: McpServerStatus,
593 reauth_config: Option<StreamableHttpClientTransportConfig>,
594 proxied: bool,
595 ) {
596 self.remember_server_order(name);
597 self.servers.insert(name.to_string(), ServerRecord::new(status, reauth_config, proxied));
598 self.refresh_status_entries();
599 }
600
601 fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
602 self.servers.get(server_name).and_then(|record| record.connection.as_ref())
603 }
604
605 fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
606 self.connection_for(server_name).map(|conn| conn.client.clone())
607 }
608
609 fn remove_registered_tools_for_server(&mut self, server_name: &str) {
610 let prefix = format!("{server_name}__");
611 self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
612 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
613 }
614
615 async fn notify_roots_changed(&self) {
616 for (server_name, record) in &self.servers {
617 if let Some(conn) = &record.connection
618 && let Err(e) = conn.client.notify_roots_list_changed().await
619 {
620 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
621 }
622 }
623 }
624}
625
626impl Drop for McpManager {
627 fn drop(&mut self) {
628 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
629 for (server_name, record) in servers {
630 if let Some(conn) = record.connection
631 && let Some(handle) = conn.server_task
632 {
633 handle.abort();
634 tracing::warn!("Server '{server_name}' task aborted during cleanup");
635 }
636 }
637 }
638}
639
/// Everything the manager tracks about one server.
struct ServerRecord {
    /// Live connection; `None` for records created without one (e.g. failed or awaiting OAuth).
    connection: Option<McpServerConnection>,
    /// Current status, as reported in the status entries.
    status: McpServerStatus,
    /// Transport config used for OAuth (re-)authentication; its presence makes the server
    /// authenticatable.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
    /// Whether the server is exposed through the tool proxy instead of directly.
    proxied: bool,
}
647
648impl ServerRecord {
649 fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>, proxied: bool) -> Self {
650 Self { connection: None, status, reauth_config, proxied }
651 }
652
653 fn connected(
654 connection: McpServerConnection,
655 tool_count: usize,
656 reauth_config: Option<StreamableHttpClientTransportConfig>,
657 proxied: bool,
658 ) -> Self {
659 Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config, proxied }
660 }
661
662 fn auth_capability(&self) -> McpServerAuthCapability {
663 if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
664 }
665
666 fn can_authenticate(&self) -> bool {
667 self.reauth_config.is_some()
668 }
669
670 fn status_entry(&self, name: &str) -> McpServerStatusEntry {
671 McpServerStatusEntry::new(name, self.status.clone())
672 .with_auth_capability(self.auth_capability())
673 .with_proxied(self.proxied)
674 }
675}
676
// Unit tests for `McpManager`: OAuth task bookkeeping, status entries, tool removal by
// prefix, and cleanup logging on drop. An in-memory MCP server is used where a live
// connection is required.
#[cfg(test)]
mod tests {
    use super::{DEFAULT_PROXY_NAME, McpClientEvent, McpManager, McpServerStatus, Tool};
    use crate::client::OAuthHandlerFactory;
    use crate::client::config::{McpServer, McpTransport};
    use crate::client::connection::{McpConnectAttempt, McpConnectOutcome};
    use crate::client::oauth::{OAuthCallback, OAuthError, OAuthHandler};
    use crate::status::McpServerAuthCapability;
    use futures::future::BoxFuture;
    use llm::ToolDefinition;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
        transport::streamable_http_client::StreamableHttpClientTransportConfig,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    // Minimal in-memory MCP server exposing a single `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        // Boxes the server so it can be handed to `McpTransport::InMemory`.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }

    // `io::Write` sink that appends into a shared buffer so log output can be inspected.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    // OAuth handler whose `authorize` always reports a user cancellation.
    struct TestOAuthHandler;

    impl OAuthHandler for TestOAuthHandler {
        fn redirect_uri(&self) -> &'static str {
            "http://127.0.0.1:0/oauth2callback"
        }

        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
            Box::pin(async { Err(OAuthError::UserCancelled) })
        }
    }

    fn test_oauth_handler_factory() -> OAuthHandlerFactory {
        Arc::new(|| Ok(Arc::new(TestOAuthHandler)))
    }

    // A record without re-auth config cannot start an OAuth task.
    #[tokio::test]
    async fn authenticate_server_task_rejects_record_without_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record("public", McpServerStatus::Connected { tool_count: 1 }, None, false);

        let error = match manager.authenticate_server_task("public").await {
            Ok(_) => panic!("non-OAuth server should be rejected"),
            Err(error) => error.to_string(),
        };
        assert!(error.contains("not OAuth-authenticatable"));
    }

    // Starting an OAuth task flips the status to `Authenticating` and emits a status event.
    #[tokio::test]
    async fn authenticate_server_task_marks_server_authenticating_and_emits_status() {
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );

        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");

        assert!(matches!(manager.server_statuses()[0].status, McpServerStatus::Authenticating));
        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert!(matches!(status.status, McpServerStatus::Authenticating));
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
    }

    // A second OAuth task for the same server is refused while the first is in flight.
    #[tokio::test]
    async fn authenticate_server_task_rejects_duplicate_same_server_while_in_flight() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );

        let _task = manager.authenticate_server_task("remote").await.expect("first auth should start");
        let error = match manager.authenticate_server_task("remote").await {
            Ok(_) => panic!("duplicate auth should be rejected"),
            Err(error) => error.to_string(),
        };

        assert!(error.contains("already authenticating"));
    }

    // A failed attempt emits status + failure events, keeps the OAuth capability, and
    // leaves the server in a state from which authentication can be retried.
    #[tokio::test]
    async fn apply_connection_attempt_failure_allows_retry() {
        let (event_sender, mut event_receiver) = mpsc::channel(2);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));
        manager.register_record(
            "remote",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
            false,
        );
        let _task = manager.authenticate_server_task("remote").await.expect("auth should start");
        let _authenticating_event = event_receiver.recv().await.expect("authenticating status change event");

        manager
            .apply_connection_attempt(McpConnectAttempt {
                name: "remote".to_string(),
                proxied: false,
                outcome: McpConnectOutcome::Failed {
                    error: crate::client::McpError::ConnectionFailed("boom".to_string()),
                },
            })
            .await;

        let event = event_receiver.recv().await.expect("status change event");
        let McpClientEvent::ServerStatusesChanged(servers) = event else {
            panic!("expected ServerStatusesChanged");
        };
        let auth_event = event_receiver.recv().await.expect("authentication failure event");
        let McpClientEvent::AuthenticationFailed { server, error } = auth_event else {
            panic!("expected AuthenticationFailed");
        };
        assert_eq!(server, "remote");
        assert!(error.contains("boom"));

        let status = servers.iter().find(|entry| entry.name == "remote").expect("remote status");
        assert_eq!(status.auth_capability, McpServerAuthCapability::OAuth);
        assert!(matches!(status.status, McpServerStatus::Failed { ref error } if error.contains("boom")));
        assert!(manager.authenticate_server_task("remote").await.is_ok());
    }

    // The auth capability of a status entry depends only on whether a re-auth config exists.
    #[test]
    fn status_entries_are_derived_from_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(test_oauth_handler_factory()));

        manager.register_record(
            "with-oauth",
            McpServerStatus::Connected { tool_count: 1 },
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
            false,
        );
        manager.register_record("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None, false);
        manager.register_record(
            "needs-oauth",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
            false,
        );

        let statuses = manager.server_statuses();
        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();

        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
    }

    // Status entries list real servers in registration order, flag proxied ones, and never
    // include a row for the proxy itself.
    #[tokio::test]
    async fn server_statuses_mark_direct_and_proxied_servers_without_proxy_row() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcps(vec![
                McpServer::new("direct", McpTransport::InMemory { server: TestServer::default().into_dyn() }, false),
                McpServer::new("math", McpTransport::InMemory { server: TestServer::default().into_dyn() }, true),
            ])
            .await
            .unwrap();

        let statuses = manager.server_statuses();
        assert_eq!(statuses.iter().map(|status| status.name.as_str()).collect::<Vec<_>>(), vec!["direct", "math"]);
        assert!(!statuses.iter().find(|status| status.name == "direct").unwrap().proxied);
        assert!(statuses.iter().find(|status| status.name == "math").unwrap().proxied);
        assert!(!statuses.iter().any(|status| status.name == DEFAULT_PROXY_NAME));
    }

    // Removing the tools of "git" must not touch "github": the match is on "git__".
    #[test]
    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tool_definitions.push(ToolDefinition {
            name: "git__status".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("git".to_string()),
        });
        manager.tool_definitions.push(ToolDefinition {
            name: "github__issue".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("github".to_string()),
        });

        manager.remove_registered_tools_for_server("git");

        assert!(!manager.tools.contains_key("git__status"));
        assert!(manager.tools.contains_key("github__issue"));
        assert_eq!(
            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
            vec!["github__issue"]
        );
    }

    // Dropping a manager with a live server logs the abort through `tracing`.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcps(vec![McpServer::new(
                "test",
                McpTransport::InMemory { server: TestServer::default().into_dyn() },
                false,
            )])
            .await
            .unwrap();

        let output = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer({
                let output = Arc::clone(&output);
                move || SharedWriter(Arc::clone(&output))
            })
            .finish();

        // Install the capturing subscriber only for the duration of the drop.
        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}