1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::{McpServerConfig, ServerConfig},
6 connection::{ConnectParams, ConnectResult, McpServerConnection, ServerInstructions, Tool},
7 mcp_client::McpClient,
8 naming::{create_namespaced_tool_name, split_on_server_name},
9 oauth::{OAuthHandler, perform_oauth_flow},
10 tool_proxy::ToolProxy,
11};
12use rmcp::{
13 RoleClient,
14 model::{
15 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams,
16 CreateElicitationResult, ElicitationAction, Implementation, Root,
17 },
18 service::RunningService,
19 transport::streamable_http_client::StreamableHttpClientTransportConfig,
20};
21use serde_json::Value;
22use std::collections::{HashMap, HashSet};
23use std::sync::Arc;
24use tokio::sync::{RwLock, mpsc, oneshot};
25
26pub use crate::status::{McpServerStatus, McpServerStatusEntry};
27
28#[derive(Debug)]
29pub struct ElicitationRequest {
30 pub request: CreateElicitationRequestParams,
31 pub response_sender: oneshot::Sender<CreateElicitationResult>,
32}
33
34#[derive(Debug, Clone)]
35pub struct ElicitationResponse {
36 pub action: ElicitationAction,
37 pub content: Option<Value>,
38}
39
/// How a connected server's tools are exposed by the manager.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Registration {
    /// Each tool is recorded and also published as its own tool definition.
    Direct,
    /// Tools are recorded internally only and are reached through the tool proxy.
    Proxied,
}
49
50pub struct McpManager {
52 servers: HashMap<String, McpServerConnection>,
53 tools: HashMap<String, Tool>,
54 tool_definitions: Vec<ToolDefinition>,
55 client_info: ClientInfo,
56 elicitation_sender: mpsc::Sender<ElicitationRequest>,
57 roots: Arc<RwLock<Vec<Root>>>,
59 oauth_handler: Option<Arc<dyn OAuthHandler>>,
60 server_statuses: Vec<McpServerStatusEntry>,
61 pending_configs: HashMap<String, StreamableHttpClientTransportConfig>,
63 proxy: Option<ToolProxy>,
65}
66
67impl McpManager {
68 pub fn new(
69 elicitation_sender: mpsc::Sender<ElicitationRequest>,
70 oauth_handler: Option<Arc<dyn OAuthHandler>>,
71 ) -> Self {
72 Self {
73 servers: HashMap::new(),
74 tools: HashMap::new(),
75 tool_definitions: Vec::new(),
76 client_info: ClientInfo::new(
77 ClientCapabilities::builder()
78 .enable_elicitation()
79 .enable_roots()
80 .build(),
81 Implementation::new("aether", "0.1.0"),
82 ),
83 elicitation_sender,
84 roots: Arc::new(RwLock::new(Vec::new())),
85 oauth_handler,
86 server_statuses: Vec::new(),
87 pending_configs: HashMap::new(),
88 proxy: None,
89 }
90 }
91
92 fn create_mcp_client(&self) -> McpClient {
93 McpClient::new(
94 self.client_info.clone(),
95 self.elicitation_sender.clone(),
96 Arc::clone(&self.roots),
97 )
98 }
99
100 fn connect_params(&self) -> ConnectParams {
101 ConnectParams {
102 mcp_client: self.create_mcp_client(),
103 oauth_handler: self.oauth_handler.clone(),
104 }
105 }
106
107 fn set_status(&mut self, name: &str, status: McpServerStatus) {
109 if let Some(entry) = self.server_statuses.iter_mut().find(|s| s.name == name) {
110 entry.status = status;
111 } else {
112 self.server_statuses.push(McpServerStatusEntry {
113 name: name.to_string(),
114 status,
115 });
116 }
117 }
118
119 pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
120 for config in configs {
121 let name = config.name().to_string();
122 if let Err(e) = self.add_mcp(config).await {
123 tracing::warn!("Failed to connect to MCP server '{}': {}", name, e);
125 if !self.server_statuses.iter().any(|s| s.name == name) {
127 self.set_status(
128 &name,
129 McpServerStatus::Failed {
130 error: e.to_string(),
131 },
132 );
133 }
134 }
135 }
136 Ok(())
137 }
138
139 pub async fn add_mcp_with_auth(
140 &mut self,
141 name: String,
142 base_url: &str,
143 auth_header: String,
144 ) -> Result<()> {
145 let config = ServerConfig::Http {
146 name: name.clone(),
147 config: StreamableHttpClientTransportConfig::with_uri(base_url)
148 .auth_header(auth_header),
149 };
150 let params = self.connect_params();
151 match McpServerConnection::connect(config, params).await {
152 ConnectResult::Connected(conn) => {
153 self.register_server(&name, conn, Registration::Direct)
154 .await
155 }
156 ConnectResult::NeedsOAuth { error, .. } => Err(error),
157 ConnectResult::Failed(e) => Err(e),
158 }
159 }
160
161 pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
162 match config {
163 McpServerConfig::ToolProxy { name, servers } => {
164 self.connect_tool_proxy(name, servers).await
165 }
166
167 McpServerConfig::Server(config) => {
168 let name = config.name().to_string();
169 let params = self.connect_params();
170 match McpServerConnection::connect(config, params).await {
171 ConnectResult::Connected(conn) => {
172 self.register_server(&name, conn, Registration::Direct)
173 .await
174 }
175 ConnectResult::NeedsOAuth {
176 name,
177 config,
178 error,
179 } => {
180 self.pending_configs.insert(name.clone(), config);
181 self.set_status(&name, McpServerStatus::NeedsOAuth);
182 Err(error)
183 }
184 ConnectResult::Failed(e) => Err(e),
185 }
186 }
187 }
188 }
189
190 async fn connect_tool_proxy(
194 &mut self,
195 proxy_name: String,
196 servers: Vec<ServerConfig>,
197 ) -> Result<()> {
198 let tool_dir = ToolProxy::dir(&proxy_name)?;
199 ToolProxy::clean_dir(&tool_dir).await?;
200
201 let mut nested_names = HashSet::new();
202 let mut server_descriptions = Vec::new();
203
204 for config in servers {
205 let server_name = config.name().to_string();
206 let params = self.connect_params();
207
208 let result = match McpServerConnection::connect(config, params).await {
209 ConnectResult::Connected(conn) => {
210 self.register_server(&server_name, conn, Registration::Proxied)
211 .await
212 }
213 ConnectResult::NeedsOAuth {
214 name,
215 config,
216 error,
217 } => {
218 self.pending_configs.insert(name.clone(), config);
219 self.set_status(&name, McpServerStatus::NeedsOAuth);
220 Err(error)
221 }
222 ConnectResult::Failed(e) => Err(e),
223 };
224
225 match result {
226 Ok(()) => {
227 if let Some(conn) = self.servers.get(&server_name) {
229 let client = conn.client.clone();
230 if let Err(e) =
231 ToolProxy::write_tools_to_dir(&server_name, &client, &tool_dir).await
232 {
233 tracing::warn!(
234 "Failed to write tool files for nested server '{server_name}': {e}"
235 );
236 }
237
238 let description =
239 ToolProxy::extract_server_description(&client, &server_name);
240 server_descriptions.push((server_name.clone(), description));
241 }
242 nested_names.insert(server_name);
243 }
244 Err(e) => {
245 tracing::warn!("Failed to connect nested server '{server_name}': {e}");
246 if self.pending_configs.contains_key(&server_name) {
249 nested_names.insert(server_name);
250 }
251 }
252 }
253 }
254
255 let call_tool_def = ToolProxy::call_tool_definition(&proxy_name);
256 self.tools.insert(
257 call_tool_def.name.clone(),
258 Tool {
259 description: call_tool_def.description.clone(),
260 parameters: serde_json::from_str(&call_tool_def.parameters)
261 .unwrap_or(Value::Object(serde_json::Map::default())),
262 },
263 );
264 self.tool_definitions.push(call_tool_def);
265
266 self.proxy = Some(ToolProxy::new(
267 proxy_name.clone(),
268 nested_names,
269 tool_dir,
270 &server_descriptions,
271 ));
272
273 self.set_status(&proxy_name, McpServerStatus::Connected { tool_count: 1 });
275
276 Ok(())
277 }
278
279 async fn oauth_and_reconnect(
280 &mut self,
281 name: String,
282 config: StreamableHttpClientTransportConfig,
283 ) -> Result<()> {
284 let handler = self.oauth_handler.as_ref().ok_or_else(|| {
285 McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'"))
286 })?;
287 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
288 .await
289 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
290
291 let mcp_client = self.create_mcp_client();
292 let conn = McpServerConnection::reconnect_with_auth(&name, config, auth_client, mcp_client)
293 .await?;
294
295 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.contains_server(&name)) {
297 let tool_dir = proxy.tool_dir().to_path_buf();
298 self.register_server(&name, conn, Registration::Proxied)
299 .await?;
300 if let Some(conn) = self.servers.get(&name) {
302 let client = conn.client.clone();
303 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
304 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
305 }
306 }
307 Ok(())
308 } else {
309 self.register_server(&name, conn, Registration::Direct)
310 .await
311 }
312 }
313
314 async fn register_server(
320 &mut self,
321 name: &str,
322 conn: McpServerConnection,
323 registration: Registration,
324 ) -> Result<()> {
325 let tools = conn.list_tools().await.map_err(|e| {
326 McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}"))
327 })?;
328
329 for rmcp_tool in &tools {
330 let tool_name = rmcp_tool.name.to_string();
331 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
332 let tool = Tool::from(rmcp_tool);
333
334 if registration == Registration::Direct {
335 self.tool_definitions.push(ToolDefinition {
336 name: namespaced_tool_name.clone(),
337 description: tool.description.clone(),
338 parameters: tool.parameters.to_string(),
339 server: Some(name.to_string()),
340 });
341 }
342
343 self.tools.insert(namespaced_tool_name, tool);
344 }
345
346 let tool_count = tools.len();
347
348 self.set_status(name, McpServerStatus::Connected { tool_count });
349
350 self.pending_configs.remove(name);
352
353 self.servers.insert(name.to_string(), conn);
354 Ok(())
355 }
356
357 pub fn get_client_for_tool(
363 &self,
364 namespaced_tool_name: &str,
365 arguments_json: &str,
366 ) -> Result<(
367 Arc<RunningService<RoleClient, McpClient>>,
368 CallToolRequestParams,
369 )> {
370 if !self.tools.contains_key(namespaced_tool_name) {
371 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
372 }
373
374 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
375 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
376
377 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
378 let call = proxy.resolve_call(arguments_json)?;
379 let conn = self.servers.get(&call.server).ok_or_else(|| {
380 McpError::ServerNotFound(format!(
381 "Nested server '{}' is not connected",
382 call.server
383 ))
384 })?;
385 let params = CallToolRequestParams::new(call.tool)
386 .with_arguments(call.arguments.unwrap_or_default());
387 return Ok((conn.client.clone(), params));
388 }
389
390 let client = self
391 .servers
392 .get(server_name)
393 .map(|server| server.client.clone())
394 .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
395
396 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?
397 .as_object()
398 .cloned();
399 let mut params = CallToolRequestParams::new(tool_name.to_string());
400 if let Some(args) = arguments {
401 params = params.with_arguments(args);
402 }
403
404 Ok((client, params))
405 }
406
407 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
408 self.tool_definitions.clone()
409 }
410
411 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
414 let mut instructions: Vec<ServerInstructions> = self
415 .servers
416 .iter()
417 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
418 .filter_map(|(name, conn)| {
419 conn.instructions.as_ref().map(|instr| ServerInstructions {
420 server_name: name.clone(),
421 instructions: instr.clone(),
422 })
423 })
424 .collect();
425
426 if let Some(proxy) = &self.proxy {
427 instructions.push(ServerInstructions {
428 server_name: proxy.name().to_string(),
429 instructions: proxy.instructions().to_string(),
430 });
431 }
432
433 instructions
434 }
435
436 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
437 &self.server_statuses
438 }
439
440 pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
445 let config = self
446 .pending_configs
447 .get(name)
448 .ok_or_else(|| {
449 McpError::ConnectionFailed(format!("no pending config for server '{name}'"))
450 })?
451 .clone();
452
453 self.oauth_and_reconnect(name.to_string(), config).await
454 }
455
456 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
458 use futures::future::join_all;
459
460 let futures: Vec<_> = self
461 .servers
462 .iter()
463 .filter(|(_, server_conn)| {
464 server_conn
465 .client
466 .peer_info()
467 .and_then(|info| info.capabilities.prompts.as_ref())
468 .is_some()
469 })
470 .map(|(server_name, server_conn)| {
471 let server_name = server_name.clone();
472 let client = server_conn.client.clone();
473 async move {
474 let prompts_response = client.list_prompts(None).await.map_err(|e| {
475 McpError::PromptListFailed(format!(
476 "Failed to list prompts for {server_name}: {e}"
477 ))
478 })?;
479
480 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
481 .prompts
482 .into_iter()
483 .map(|prompt| {
484 let namespaced_name =
485 create_namespaced_tool_name(&server_name, &prompt.name);
486 rmcp::model::Prompt::new(
487 namespaced_name,
488 prompt.description,
489 prompt.arguments,
490 )
491 })
492 .collect();
493
494 Ok::<_, McpError>(namespaced_prompts)
495 }
496 })
497 .collect();
498
499 let results = join_all(futures).await;
500 let mut all_prompts = Vec::new();
501 for result in results {
502 all_prompts.extend(result?);
503 }
504
505 Ok(all_prompts)
506 }
507
508 pub async fn get_prompt(
510 &self,
511 namespaced_prompt_name: &str,
512 arguments: Option<serde_json::Map<String, serde_json::Value>>,
513 ) -> Result<rmcp::model::GetPromptResult> {
514 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
515 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
516
517 let server_conn = self
518 .servers
519 .get(server_name)
520 .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
521
522 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
523 if let Some(args) = arguments {
524 request = request.with_arguments(args);
525 }
526
527 server_conn.client.get_prompt(request).await.map_err(|e| {
528 McpError::PromptGetFailed(format!(
529 "Failed to get prompt '{prompt_name}' from {server_name}: {e}"
530 ))
531 })
532 }
533
534 pub async fn shutdown(&mut self) {
536 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
537
538 for (server_name, server) in servers {
539 if let Some(handle) = server.server_task {
540 drop(server.client);
542
543 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
545 Ok(Ok(())) => {
546 tracing::info!("Server '{server_name}' shut down gracefully");
547 }
548 Ok(Err(e)) => {
549 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
550 }
551 Err(_) => {
552 tracing::warn!("Server '{server_name}' shutdown timed out");
553 }
555 }
556 }
557 }
558
559 self.tools.clear();
560 self.tool_definitions.clear();
561 self.proxy = None;
562 }
563
564 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
566 let server = self.servers.remove(server_name);
567
568 if let Some(server) = server {
569 if let Some(handle) = server.server_task {
570 drop(server.client);
572
573 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
575 Ok(Ok(())) => {
576 tracing::info!("Server '{server_name}' shut down gracefully");
577 }
578 Ok(Err(e)) => {
579 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
580 }
581 Err(_) => {
582 tracing::warn!("Server '{server_name}' shutdown timed out");
583 }
585 }
586 }
587
588 self.tools
590 .retain(|tool_name, _| !tool_name.starts_with(server_name));
591
592 self.tool_definitions
593 .retain(|tool_def| !tool_def.name.starts_with(server_name));
594 }
595
596 Ok(())
597 }
598
599 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
604 {
606 let mut roots = self.roots.write().await;
607 *roots = new_roots;
608 }
609
610 self.notify_roots_changed().await;
612
613 Ok(())
614 }
615
616 async fn notify_roots_changed(&self) {
621 for (server_name, server_conn) in &self.servers {
622 if let Err(e) = server_conn.client.notify_roots_list_changed().await {
624 tracing::debug!(
626 "Note: server '{server_name}' did not accept roots notification: {e}"
627 );
628 }
629 }
630 }
631}
632
633impl Drop for McpManager {
634 fn drop(&mut self) {
635 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
636 for (server_name, server) in servers {
637 if let Some(handle) = server.server_task {
638 handle.abort();
639 tracing::warn!("Server '{server_name}' task aborted during cleanup");
640 }
641 }
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::McpManager;
    use crate::client::config::ServerConfig;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    // Minimal MCP server used in-memory by the tests; it exposes one `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    // Tool calls are dispatched through `self.tool_router`; the server only
    // advertises the tools capability.
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build()).with_server_info(
                Implementation::new("test-server", "0.1.0").with_description("Test MCP server"),
            )
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self {
                // `tool_router()` is not written by hand here; it comes from the
                // `#[tool_router]` impl below.
                tool_router: Self::tool_router(),
            }
        }
    }

    // Input of the `echo` tool. Plain `//` comments are used on purpose: doc
    // comments would be picked up by `JsonSchema` and change the tool schema.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    // Output of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        // Boxes the server as a type-erased service for `ServerConfig::InMemory`.
        fn as_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }

    // `io::Write` sink that appends into a shared buffer, so the test can read back
    // what the tracing subscriber wrote.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    // Dropping a manager that still owns a server task (no `shutdown` call) should
    // abort the task and report it through `tracing`.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (elicitation_sender, _elicitation_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(elicitation_sender, None);
        manager
            .add_mcp(
                ServerConfig::InMemory {
                    name: "test".to_string(),
                    server: TestServer::default().as_dyn(),
                }
                .into(),
            )
            .await
            .unwrap();

        // Capture formatted log lines (no ANSI colors, no timestamps) into `output`.
        let output = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer({
                let output = Arc::clone(&output);
                move || SharedWriter(Arc::clone(&output))
            })
            .finish();

        // The subscriber is only active inside this closure, so the drop must happen here.
        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}