1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::{McpServerConfig, ServerConfig},
6 connect_mcp::{ConnectOutcome, ConnectionSpec, ProxySpec, Registration, build_plan, connect_mcp},
7 connection::{McpServerConnection, ServerInstructions, Tool},
8 mcp_client::McpClient,
9 naming::{create_namespaced_tool_name, split_on_server_name},
10 oauth::{OAuthHandler, perform_oauth_flow},
11 tool_proxy::ToolProxy,
12};
13use futures::future::join_all;
14use rmcp::{
15 RoleClient,
16 model::{
17 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
18 ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
19 },
20 service::RunningService,
21 transport::streamable_http_client::StreamableHttpClientTransportConfig,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::{HashMap, HashSet};
26use std::sync::Arc;
27use tokio::sync::{RwLock, mpsc, oneshot};
28
29pub use crate::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
30
/// An elicitation request received from the MCP server `server_name`, paired with
/// the one-shot channel on which the host sends back the result.
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the server that issued the request.
    pub server_name: String,
    /// The raw elicitation parameters sent by the server.
    pub request: CreateElicitationRequestParams,
    /// Receives the host's answer; presumably dropping it unanswered is treated as a
    /// failure by the waiting client — TODO confirm in `McpClient`.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
37
/// A host-side answer to an elicitation: the action taken and any submitted content.
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// What the user chose to do with the request.
    pub action: ElicitationAction,
    /// Submitted content, if any, as arbitrary JSON.
    pub content: Option<Value>,
}
43
/// Identifies a URL-mode elicitation that has completed.
///
/// Serialized with the field names as declared (`server_name`, `elicitation_id`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Server that started the elicitation.
    pub server_name: String,
    /// Server-assigned identifier of the elicitation.
    pub elicitation_id: String,
}
49
/// Events delivered to the host over the channel passed to [`McpManager::new`].
#[derive(Debug)]
pub enum McpClientEvent {
    /// A server asked the user for input.
    Elicitation(ElicitationRequest),
    /// A server reported that a URL-mode elicitation finished.
    UrlElicitationComplete(UrlElicitationCompleteParams),
}
58
/// Book-keeping for one server known to the manager.
struct ServerRecord {
    /// Live connection; `None` for servers that failed, still need OAuth, or for the
    /// tool-proxy entry (which is only a status placeholder).
    connection: Option<McpServerConnection>,
    /// Last known status, surfaced through the manager's status list.
    status: McpServerStatus,
    /// HTTP transport config reused to re-run OAuth. Its presence alone marks the
    /// server as OAuth-capable.
    reauth_config: Option<StreamableHttpClientTransportConfig>,
}
65
66impl ServerRecord {
67 fn new(status: McpServerStatus, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
68 Self { connection: None, status, reauth_config }
69 }
70
71 fn connected(
72 connection: McpServerConnection,
73 tool_count: usize,
74 reauth_config: Option<StreamableHttpClientTransportConfig>,
75 ) -> Self {
76 Self { connection: Some(connection), status: McpServerStatus::Connected { tool_count }, reauth_config }
77 }
78
79 fn auth_capability(&self) -> McpServerAuthCapability {
80 if self.reauth_config.is_some() { McpServerAuthCapability::OAuth } else { McpServerAuthCapability::Unavailable }
81 }
82
83 fn can_authenticate(&self) -> bool {
84 self.reauth_config.is_some()
85 }
86
87 fn status_entry(&self, name: &str) -> McpServerStatusEntry {
88 McpServerStatusEntry::new(name, self.status.clone()).with_auth_capability(self.auth_capability())
89 }
90}
91
/// Owns the MCP server connections and the tool catalog derived from them.
pub struct McpManager {
    /// Per-server connection and status, keyed by server name.
    servers: HashMap<String, ServerRecord>,
    /// Server names in first-registration order; drives the order of status entries.
    server_order: Vec<String>,
    /// Every registered tool keyed by namespaced name, including tools of proxied servers.
    tools: HashMap<String, Tool>,
    /// Definitions exposed to the model: tools of directly registered servers plus the
    /// proxy's call tool (proxied servers' tools are not listed here).
    tool_definitions: Vec<ToolDefinition>,
    /// Client identity and capabilities advertised to every server.
    client_info: ClientInfo,
    /// Channel handed to each `McpClient` this manager creates.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with every client; replaced via `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Handler used to run OAuth flows; `None` makes re-authentication fail.
    oauth_handler: Option<Arc<dyn OAuthHandler>>,
    /// Cached status list, rebuilt by `refresh_status_entries`.
    server_statuses: Vec<McpServerStatusEntry>,
    /// The configured tool proxy, if any. Only one is tracked: `register_proxy`
    /// overwrites any previous value.
    proxy: Option<ToolProxy>,
}
107
108impl McpManager {
109 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler: Option<Arc<dyn OAuthHandler>>) -> Self {
110 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
111 if let Some(elicitation) = capabilities.elicitation.as_mut() {
112 elicitation.form = Some(FormElicitationCapability::default());
113 elicitation.url = Some(UrlElicitationCapability::default());
114 }
115
116 Self {
117 servers: HashMap::new(),
118 server_order: Vec::new(),
119 tools: HashMap::new(),
120 tool_definitions: Vec::new(),
121 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
122 event_sender,
123 roots: Arc::new(RwLock::new(Vec::new())),
124 oauth_handler,
125 server_statuses: Vec::new(),
126 proxy: None,
127 }
128 }
129
130 fn refresh_status_entries(&mut self) {
131 self.server_statuses = self
132 .server_order
133 .iter()
134 .filter_map(|name| self.servers.get(name).map(|record| record.status_entry(name)))
135 .collect();
136 }
137
138 fn remember_server_order(&mut self, name: &str) {
139 if !self.server_order.iter().any(|n| n == name) {
140 self.server_order.push(name.to_string());
141 }
142 }
143
144 fn upsert_status(
145 &mut self,
146 name: &str,
147 status: McpServerStatus,
148 reauth_config: Option<StreamableHttpClientTransportConfig>,
149 ) {
150 self.remember_server_order(name);
151 let record = self
152 .servers
153 .entry(name.to_string())
154 .or_insert_with(|| ServerRecord::new(status.clone(), reauth_config.clone()));
155 record.status = status;
156 if reauth_config.is_some() {
157 record.reauth_config = reauth_config;
158 }
159 self.refresh_status_entries();
160 }
161
162 fn connection_for(&self, server_name: &str) -> Option<&McpServerConnection> {
163 self.servers.get(server_name).and_then(|record| record.connection.as_ref())
164 }
165
166 fn client_for_server(&self, server_name: &str) -> Option<Arc<RunningService<RoleClient, McpClient>>> {
167 self.connection_for(server_name).map(|conn| conn.client.clone())
168 }
169
170 pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
171 let (direct, proxies) = build_plan(configs).await?;
172 let outcomes: Vec<ConnectOutcome> = join_all(direct.into_iter().map(|leaf| {
173 connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref())
174 }))
175 .await;
176
177 let mut mcp_proxies: HashMap<String, HashSet<String>> =
178 proxies.iter().map(|p| (p.name.clone(), HashSet::new())).collect();
179
180 let mut connected_mcps_to_proxy: HashMap<String, Vec<String>> = HashMap::new();
181 for outcome in outcomes {
182 match outcome {
183 ConnectOutcome::Ready { name, conn, tools, proxy, registration, reauth_config } => {
184 self.apply_connected(&name, conn, &tools, registration, reauth_config);
185 if let Some(p) = proxy {
186 if let Some(members) = mcp_proxies.get_mut(&p) {
187 members.insert(name.clone());
188 }
189 connected_mcps_to_proxy.entry(p).or_default().push(name);
190 }
191 }
192 ConnectOutcome::NeedsOAuth { name, config, error, proxy } => {
193 tracing::warn!("Server '{name}' needs OAuth: {error}");
194 self.upsert_status(&name, McpServerStatus::NeedsOAuth, Some(config));
195 if let Some(p) = proxy
196 && let Some(members) = mcp_proxies.get_mut(&p)
197 {
198 members.insert(name);
199 }
200 }
201 ConnectOutcome::Failed { name, error } => {
202 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
203 if !self.servers.contains_key(&name) {
204 self.upsert_status(&name, McpServerStatus::Failed { error: error.to_string() }, None);
205 }
206 }
207 }
208 }
209
210 let writes = proxies.iter().flat_map(|proxy| {
211 connected_mcps_to_proxy
212 .get(&proxy.name)
213 .into_iter()
214 .flatten()
215 .filter_map(|member| {
216 let client = self.client_for_server(member)?;
217 let dir = proxy.tool_dir.clone();
218 let name = member.clone();
219 Some(async move {
220 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
221 tracing::warn!("Failed to write tool files for nested server '{name}': {e}");
222 }
223 })
224 })
225 .collect::<Vec<_>>()
226 });
227 join_all(writes).await;
228
229 for proxy in proxies {
230 self.register_proxy(proxy, &mut mcp_proxies, &mut connected_mcps_to_proxy);
231 }
232
233 Ok(())
234 }
235
236 pub async fn add_mcp_with_auth(&mut self, name: String, base_url: &str, auth_header: String) -> Result<()> {
237 let config = ServerConfig::Http {
238 name: name.clone(),
239 config: StreamableHttpClientTransportConfig::with_uri(base_url).auth_header(auth_header),
240 };
241 let leaf = ConnectionSpec { name, config, proxy: None, registration: Registration::Direct };
242 match connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref()).await {
243 ConnectOutcome::Ready { name, conn, tools, registration, .. } => {
244 self.apply_connected(&name, conn, &tools, registration, None);
245 Ok(())
246 }
247 ConnectOutcome::NeedsOAuth { error, .. } | ConnectOutcome::Failed { error, .. } => Err(error),
248 }
249 }
250
251 pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
252 match config {
253 McpServerConfig::ToolProxy { .. } => self.add_mcps(vec![config]).await,
254
255 McpServerConfig::Server(config) => {
256 let name = config.name().to_string();
257 let leaf = ConnectionSpec { name, config, proxy: None, registration: Registration::Direct };
258 match connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref())
259 .await
260 {
261 ConnectOutcome::Ready { name, conn, tools, registration, reauth_config, .. } => {
262 self.apply_connected(&name, conn, &tools, registration, reauth_config);
263 Ok(())
264 }
265 ConnectOutcome::NeedsOAuth { name, config, error, .. } => {
266 self.upsert_status(&name, McpServerStatus::NeedsOAuth, Some(config));
267 Err(error)
268 }
269 ConnectOutcome::Failed { error, .. } => Err(error),
270 }
271 }
272 }
273 }
274
275 fn register_proxy(
276 &mut self,
277 proxy: ProxySpec,
278 proxy_members: &mut HashMap<String, HashSet<String>>,
279 ready_for_proxy: &mut HashMap<String, Vec<String>>,
280 ) {
281 let members = ready_for_proxy.remove(&proxy.name).unwrap_or_default();
282
283 let server_descriptions: Vec<(String, String)> = members
284 .iter()
285 .filter_map(|member| {
286 self.connection_for(member)
287 .map(|conn| (member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
288 })
289 .collect();
290
291 self.remove_registered_tools_for_server(&proxy.name);
292 let call_tool_def = ToolProxy::call_tool_definition(&proxy.name);
293 self.tools.insert(
294 call_tool_def.name.clone(),
295 Tool {
296 description: call_tool_def.description.clone(),
297 parameters: serde_json::from_str(&call_tool_def.parameters)
298 .unwrap_or(Value::Object(serde_json::Map::default())),
299 },
300 );
301 self.tool_definitions.push(call_tool_def);
302
303 let nested = proxy_members.remove(&proxy.name).unwrap_or_default();
304 self.proxy = Some(ToolProxy::new(proxy.name.clone(), nested, proxy.tool_dir, &server_descriptions));
305 self.upsert_status(&proxy.name, McpServerStatus::Connected { tool_count: 1 }, None);
306 }
307
308 async fn oauth_and_reconnect(&mut self, name: String, config: StreamableHttpClientTransportConfig) -> Result<()> {
309 let handler = self
310 .oauth_handler
311 .as_ref()
312 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'")))?;
313 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
314 .await
315 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
316
317 let mcp_client =
318 McpClient::new(self.client_info.clone(), name.clone(), self.event_sender.clone(), Arc::clone(&self.roots));
319 let conn = McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await?;
320
321 let is_proxied = self.proxy.as_ref().is_some_and(|p| p.contains_server(&name));
322 if is_proxied {
323 let tool_dir = self.proxy.as_ref().expect("checked above").tool_dir().to_path_buf();
324 self.register_server(&name, conn, Registration::Proxied, Some(config)).await?;
325 if let Some(proxy) = self.proxy.as_mut() {
326 proxy.add_member(name.clone());
327 }
328 if let Some(conn) = self.connection_for(&name) {
329 let client = conn.client.clone();
330 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
331 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
332 }
333 }
334 Ok(())
335 } else {
336 self.register_server(&name, conn, Registration::Direct, Some(config)).await
337 }
338 }
339
340 async fn register_server(
341 &mut self,
342 name: &str,
343 conn: McpServerConnection,
344 registration: Registration,
345 reauth_config: Option<StreamableHttpClientTransportConfig>,
346 ) -> Result<()> {
347 let tools = conn
348 .list_tools()
349 .await
350 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
351 self.apply_connected(name, conn, &tools, registration, reauth_config);
352 Ok(())
353 }
354
355 fn apply_connected(
356 &mut self,
357 name: &str,
358 conn: McpServerConnection,
359 tools: &[RmcpTool],
360 registration: Registration,
361 reauth_config: Option<StreamableHttpClientTransportConfig>,
362 ) {
363 self.remove_registered_tools_for_server(name);
364
365 let existing_reauth = self.servers.get(name).and_then(|r| r.reauth_config.clone());
366 let final_reauth = reauth_config.or(existing_reauth);
367
368 for rmcp_tool in tools {
369 let tool_name = rmcp_tool.name.to_string();
370 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
371 let tool = Tool::from(rmcp_tool);
372
373 if registration == Registration::Direct {
374 self.tool_definitions.push(ToolDefinition {
375 name: namespaced_tool_name.clone(),
376 description: tool.description.clone(),
377 parameters: tool.parameters.to_string(),
378 server: Some(name.to_string()),
379 });
380 }
381
382 self.tools.insert(namespaced_tool_name, tool);
383 }
384
385 self.remember_server_order(name);
386 self.servers.insert(name.to_string(), ServerRecord::connected(conn, tools.len(), final_reauth));
387 self.refresh_status_entries();
388 }
389
390 fn remove_registered_tools_for_server(&mut self, server_name: &str) {
391 let prefix = format!("{server_name}__");
392 self.tools.retain(|tool_name, _| !tool_name.starts_with(&prefix));
393 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(&prefix));
394 }
395
396 pub fn get_client_for_tool(
397 &self,
398 namespaced_tool_name: &str,
399 arguments_json: &str,
400 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
401 if !self.tools.contains_key(namespaced_tool_name) {
402 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
403 }
404
405 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
406 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
407
408 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
409 let call = proxy.resolve_call(arguments_json)?;
410 let conn = self
411 .connection_for(&call.server)
412 .ok_or_else(|| McpError::ServerNotFound(format!("Nested server '{}' is not connected", call.server)))?;
413 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
414 return Ok((conn.client.clone(), params));
415 }
416
417 let client =
418 self.client_for_server(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
419
420 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
421 let mut params = CallToolRequestParams::new(tool_name.to_string());
422 if let Some(args) = arguments {
423 params = params.with_arguments(args);
424 }
425
426 Ok((client, params))
427 }
428
429 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
430 self.tool_definitions.clone()
431 }
432
433 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
434 let mut instructions: Vec<ServerInstructions> = self
435 .servers
436 .iter()
437 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
438 .filter_map(|(name, record)| {
439 record
440 .connection
441 .as_ref()
442 .and_then(|conn| conn.instructions.as_ref())
443 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
444 })
445 .collect();
446
447 if let Some(proxy) = &self.proxy {
448 let descriptions: Vec<(String, String)> = proxy
449 .members()
450 .iter()
451 .filter_map(|member| {
452 let conn = self.connection_for(member)?;
453 Some((member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
454 })
455 .collect();
456 instructions.push(ServerInstructions {
457 server_name: proxy.name().to_string(),
458 instructions: ToolProxy::build_instructions(proxy.tool_dir(), &descriptions),
459 });
460 }
461
462 instructions
463 }
464
465 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
466 &self.server_statuses
467 }
468
469 pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
470 let record = self
471 .servers
472 .get(name)
473 .ok_or_else(|| McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")))?;
474 if !record.can_authenticate() {
475 return Err(McpError::ConnectionFailed(format!("server '{name}' is not OAuth-authenticatable")));
476 }
477
478 self.oauth_and_reconnect(name.to_string(), record.reauth_config.clone().expect("checked above")).await
479 }
480
481 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
483 use futures::future::join_all;
484
485 let futures: Vec<_> = self
486 .servers
487 .iter()
488 .filter_map(|(server_name, record)| {
489 let conn = record.connection.as_ref()?;
490 conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref())?;
491 let server_name = server_name.clone();
492 let client = conn.client.clone();
493 Some(async move {
494 let prompts_response = client.list_prompts(None).await.map_err(|e| {
495 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
496 })?;
497
498 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
499 .prompts
500 .into_iter()
501 .map(|prompt| {
502 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
503 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
504 })
505 .collect();
506
507 Ok::<_, McpError>(namespaced_prompts)
508 })
509 })
510 .collect();
511
512 let results = join_all(futures).await;
513 let mut all_prompts = Vec::new();
514 for result in results {
515 all_prompts.extend(result?);
516 }
517
518 Ok(all_prompts)
519 }
520
521 pub async fn get_prompt(
523 &self,
524 namespaced_prompt_name: &str,
525 arguments: Option<serde_json::Map<String, serde_json::Value>>,
526 ) -> Result<rmcp::model::GetPromptResult> {
527 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
528 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
529
530 let server_conn =
531 self.connection_for(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
532
533 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
534 if let Some(args) = arguments {
535 request = request.with_arguments(args);
536 }
537
538 server_conn.client.get_prompt(request).await.map_err(|e| {
539 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
540 })
541 }
542
543 pub async fn shutdown(&mut self) {
545 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
546
547 for (server_name, record) in servers {
548 if let Some(conn) = record.connection
549 && let Some(handle) = conn.server_task
550 {
551 drop(conn.client);
552
553 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
554 Ok(Ok(())) => {
555 tracing::info!("Server '{server_name}' shut down gracefully");
556 }
557 Ok(Err(e)) => {
558 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
559 }
560 Err(_) => {
561 tracing::warn!("Server '{server_name}' shutdown timed out");
562 }
563 }
564 }
565 }
566
567 self.tools.clear();
568 self.tool_definitions.clear();
569 self.proxy = None;
570 }
571
572 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
574 let record = self.servers.remove(server_name);
575
576 if let Some(record) = record {
577 if let Some(conn) = record.connection
578 && let Some(handle) = conn.server_task
579 {
580 drop(conn.client);
581
582 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
583 Ok(Ok(())) => {
584 tracing::info!("Server '{server_name}' shut down gracefully");
585 }
586 Ok(Err(e)) => {
587 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
588 }
589 Err(_) => {
590 tracing::warn!("Server '{server_name}' shutdown timed out");
591 }
592 }
593 }
594
595 self.remove_registered_tools_for_server(server_name);
596 self.refresh_status_entries();
597 }
598
599 Ok(())
600 }
601
602 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
607 {
608 let mut roots = self.roots.write().await;
609 *roots = new_roots;
610 }
611
612 self.notify_roots_changed().await;
613
614 Ok(())
615 }
616
617 async fn notify_roots_changed(&self) {
618 for (server_name, record) in &self.servers {
619 if let Some(conn) = &record.connection
620 && let Err(e) = conn.client.notify_roots_list_changed().await
621 {
622 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
623 }
624 }
625 }
626}
627
628impl Drop for McpManager {
629 fn drop(&mut self) {
630 let servers: Vec<(String, ServerRecord)> = self.servers.drain().collect();
631 for (server_name, record) in servers {
632 if let Some(conn) = record.connection
633 && let Some(handle) = conn.server_task
634 {
635 handle.abort();
636 tracing::warn!("Server '{server_name}' task aborted during cleanup");
637 }
638 }
639 }
640}
641
// Unit tests for `McpManager`: OAuth re-authentication gating, status/auth-capability
// derivation, namespaced tool removal, and Drop-time cleanup logging.
#[cfg(test)]
mod tests {
    use super::{McpManager, McpServerStatus, Tool};
    use crate::client::config::ServerConfig;
    use crate::client::oauth::{OAuthCallback, OAuthError, OAuthHandler};
    use crate::status::McpServerAuthCapability;
    use futures::future::BoxFuture;
    use llm::ToolDefinition;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
        transport::streamable_http_client::StreamableHttpClientTransportConfig,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    // Minimal MCP server with tools enabled and a single `echo` tool, used in-memory.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }

    // Plain `//` comments are used here on purpose: `///` docs would feed into the
    // schemars-generated tool schemas.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        // Boxes the server for `ServerConfig::InMemory`.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }

    // `io::Write` sink that appends into a shared buffer so tracing output can be inspected.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    // OAuth handler whose authorization step always reports that the user cancelled.
    struct TestOAuthHandler;

    impl OAuthHandler for TestOAuthHandler {
        fn redirect_uri(&self) -> &'static str {
            "http://127.0.0.1:0/oauth2callback"
        }

        fn authorize(&self, _auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
            Box::pin(async { Err(OAuthError::UserCancelled) })
        }
    }

    // A record without a re-auth config must be rejected before any OAuth flow starts.
    #[tokio::test]
    async fn authenticate_server_rejects_record_without_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(Arc::new(TestOAuthHandler)));
        manager.upsert_status("public", McpServerStatus::Connected { tool_count: 1 }, None);

        let error = manager.authenticate_server("public").await.unwrap_err().to_string();
        assert!(error.contains("not OAuth-authenticatable"));
    }

    // A connected server with a re-auth config passes the capability check, so the error
    // must come from the OAuth flow itself. Assumes nothing is listening on port 19999.
    #[tokio::test]
    async fn authenticate_server_uses_reauth_config_for_connected_oauth_server() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(Arc::new(TestOAuthHandler)));
        manager.upsert_status(
            "remote",
            McpServerStatus::Connected { tool_count: 1 },
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost:19999/mcp")),
        );

        let error = manager.authenticate_server("remote").await.unwrap_err().to_string();
        assert!(!error.contains("not OAuth-authenticatable"));
        assert!(error.contains("OAuth failed") || error.contains("UserCancelled"));
    }

    // Auth capability depends only on whether a re-auth config is stored, not on the status.
    #[test]
    fn status_entries_are_derived_from_reauth_config() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, Some(Arc::new(TestOAuthHandler)));

        manager.upsert_status(
            "with-oauth",
            McpServerStatus::Connected { tool_count: 1 },
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp")),
        );
        manager.upsert_status("without-oauth", McpServerStatus::Connected { tool_count: 2 }, None);
        manager.upsert_status(
            "needs-oauth",
            McpServerStatus::NeedsOAuth,
            Some(StreamableHttpClientTransportConfig::with_uri("http://localhost/mcp2")),
        );

        let statuses = manager.server_statuses();
        let with_oauth = statuses.iter().find(|s| s.name == "with-oauth").unwrap();
        let without_oauth = statuses.iter().find(|s| s.name == "without-oauth").unwrap();
        let needs_oauth = statuses.iter().find(|s| s.name == "needs-oauth").unwrap();

        assert_eq!(with_oauth.auth_capability, McpServerAuthCapability::OAuth);
        assert_eq!(without_oauth.auth_capability, McpServerAuthCapability::Unavailable);
        assert_eq!(needs_oauth.auth_capability, McpServerAuthCapability::OAuth);
    }

    // Removing "git" must not touch "github": the match is on the full "git__" prefix.
    #[test]
    fn remove_registered_tools_for_server_uses_namespaced_prefix() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager.tools.insert("git__status".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tools.insert("github__issue".to_string(), Tool { description: String::new(), parameters: json!({}) });
        manager.tool_definitions.push(ToolDefinition {
            name: "git__status".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("git".to_string()),
        });
        manager.tool_definitions.push(ToolDefinition {
            name: "github__issue".to_string(),
            description: String::new(),
            parameters: "{}".to_string(),
            server: Some("github".to_string()),
        });

        manager.remove_registered_tools_for_server("git");

        assert!(!manager.tools.contains_key("git__status"));
        assert!(manager.tools.contains_key("github__issue"));
        assert_eq!(
            manager.tool_definitions.iter().map(|tool| tool.name.as_str()).collect::<Vec<_>>(),
            vec!["github__issue"]
        );
    }

    // Dropping a manager with a live in-memory server must abort its task and log it.
    // The capturing subscriber is installed only around `drop`, so only Drop's output is seen.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcp(
                ServerConfig::InMemory { name: "test".to_string(), server: TestServer::default().into_dyn() }.into(),
            )
            .await
            .unwrap();

        let output = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer({
                let output = Arc::clone(&output);
                move || SharedWriter(Arc::clone(&output))
            })
            .finish();

        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}