1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::{McpServerConfig, ServerConfig},
6 connection::{ConnectParams, ConnectResult, McpServerConnection, ServerInstructions, Tool},
7 mcp_client::McpClient,
8 naming::{create_namespaced_tool_name, split_on_server_name},
9 oauth::{OAuthHandler, perform_oauth_flow},
10 tool_proxy::ToolProxy,
11};
12use rmcp::{
13 RoleClient,
14 model::{
15 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
16 ElicitationAction, FormElicitationCapability, Implementation, Root, UrlElicitationCapability,
17 },
18 service::RunningService,
19 transport::streamable_http_client::StreamableHttpClientTransportConfig,
20};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::{HashMap, HashSet};
24use std::sync::Arc;
25use tokio::sync::{RwLock, mpsc, oneshot};
26
27pub use crate::status::{McpServerStatus, McpServerStatusEntry};
28
29#[derive(Debug)]
30pub struct ElicitationRequest {
31 pub server_name: String,
32 pub request: CreateElicitationRequestParams,
33 pub response_sender: oneshot::Sender<CreateElicitationResult>,
34}
35
36#[derive(Debug, Clone)]
37pub struct ElicitationResponse {
38 pub action: ElicitationAction,
39 pub content: Option<Value>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
43pub struct UrlElicitationCompleteParams {
44 pub server_name: String,
45 pub elicitation_id: String,
46}
47
/// Events carried over the manager's event channel; the sender half is cloned into
/// every `McpClient` the manager creates.
#[derive(Debug)]
pub enum McpClientEvent {
    /// An elicitation request paired with its one-shot reply channel.
    Elicitation(ElicitationRequest),
    /// Completion notice for a URL-mode elicitation, identified by server and id.
    UrlElicitationComplete(UrlElicitationCompleteParams),
}
56
/// How a connected server's tools are surfaced once registered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Registration {
    /// Tools are added to the model-facing `tool_definitions` list.
    Direct,
    /// Tools are tracked internally but reachable only through the tool proxy.
    Proxied,
}
66
/// Owns every MCP server connection plus the tool namespace derived from them.
pub struct McpManager {
    /// Live connections keyed by server name, including servers nested behind the tool proxy.
    servers: HashMap<String, McpServerConnection>,
    /// Every discovered tool keyed by its namespaced name (direct and proxied).
    tools: HashMap<String, Tool>,
    /// Tool definitions exposed to the model: direct servers' tools plus the proxy's call tool.
    tool_definitions: Vec<ToolDefinition>,
    /// Client identity and capabilities sent to every server.
    client_info: ClientInfo,
    /// Sender cloned into each `McpClient` for delivering `McpClientEvent`s.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with every `McpClient`; replaced wholesale by `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Optional handler used to run OAuth flows.
    oauth_handler: Option<Arc<dyn OAuthHandler>>,
    /// Per-server status entries, in the order each server was first seen.
    server_statuses: Vec<McpServerStatusEntry>,
    /// HTTP configs of servers awaiting OAuth, replayed by `authenticate_server`.
    pending_configs: HashMap<String, StreamableHttpClientTransportConfig>,
    /// The tool proxy, when one has been configured via `connect_tool_proxy`.
    proxy: Option<ToolProxy>,
}
83
84impl McpManager {
85 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler: Option<Arc<dyn OAuthHandler>>) -> Self {
86 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
87 if let Some(elicitation) = capabilities.elicitation.as_mut() {
88 elicitation.form = Some(FormElicitationCapability::default());
89 elicitation.url = Some(UrlElicitationCapability::default());
90 }
91
92 Self {
93 servers: HashMap::new(),
94 tools: HashMap::new(),
95 tool_definitions: Vec::new(),
96 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
97 event_sender,
98 roots: Arc::new(RwLock::new(Vec::new())),
99 oauth_handler,
100 server_statuses: Vec::new(),
101 pending_configs: HashMap::new(),
102 proxy: None,
103 }
104 }
105
106 fn create_mcp_client(&self, server_name: &str) -> McpClient {
107 McpClient::new(
108 self.client_info.clone(),
109 server_name.to_string(),
110 self.event_sender.clone(),
111 Arc::clone(&self.roots),
112 )
113 }
114
115 fn connect_params(&self, server_name: &str) -> ConnectParams {
116 ConnectParams { mcp_client: self.create_mcp_client(server_name), oauth_handler: self.oauth_handler.clone() }
117 }
118
119 fn set_status(&mut self, name: &str, status: McpServerStatus) {
121 if let Some(entry) = self.server_statuses.iter_mut().find(|s| s.name == name) {
122 entry.status = status;
123 } else {
124 self.server_statuses.push(McpServerStatusEntry { name: name.to_string(), status });
125 }
126 }
127
128 pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
129 for config in configs {
130 let name = config.name().to_string();
131 if let Err(e) = self.add_mcp(config).await {
132 tracing::warn!("Failed to connect to MCP server '{}': {}", name, e);
134 if !self.server_statuses.iter().any(|s| s.name == name) {
136 self.set_status(&name, McpServerStatus::Failed { error: e.to_string() });
137 }
138 }
139 }
140 Ok(())
141 }
142
143 pub async fn add_mcp_with_auth(&mut self, name: String, base_url: &str, auth_header: String) -> Result<()> {
144 let config = ServerConfig::Http {
145 name: name.clone(),
146 config: StreamableHttpClientTransportConfig::with_uri(base_url).auth_header(auth_header),
147 };
148 let params = self.connect_params(&name);
149 match McpServerConnection::connect(config, params).await {
150 ConnectResult::Connected(conn) => self.register_server(&name, conn, Registration::Direct).await,
151 ConnectResult::NeedsOAuth { error, .. } => Err(error),
152 ConnectResult::Failed(e) => Err(e),
153 }
154 }
155
156 pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
157 match config {
158 McpServerConfig::ToolProxy { name, servers } => self.connect_tool_proxy(name, servers).await,
159
160 McpServerConfig::Server(config) => {
161 let name = config.name().to_string();
162 let params = self.connect_params(&name);
163 match McpServerConnection::connect(config, params).await {
164 ConnectResult::Connected(conn) => self.register_server(&name, conn, Registration::Direct).await,
165 ConnectResult::NeedsOAuth { name, config, error } => {
166 self.pending_configs.insert(name.clone(), config);
167 self.set_status(&name, McpServerStatus::NeedsOAuth);
168 Err(error)
169 }
170 ConnectResult::Failed(e) => Err(e),
171 }
172 }
173 }
174 }
175
176 async fn connect_tool_proxy(&mut self, proxy_name: String, servers: Vec<ServerConfig>) -> Result<()> {
180 let tool_dir = ToolProxy::dir(&proxy_name)?;
181 ToolProxy::clean_dir(&tool_dir).await?;
182
183 let mut nested_names = HashSet::new();
184 let mut server_descriptions = Vec::new();
185
186 for config in servers {
187 let server_name = config.name().to_string();
188 let params = self.connect_params(&server_name);
189
190 let result = match McpServerConnection::connect(config, params).await {
191 ConnectResult::Connected(conn) => self.register_server(&server_name, conn, Registration::Proxied).await,
192 ConnectResult::NeedsOAuth { name, config, error } => {
193 self.pending_configs.insert(name.clone(), config);
194 self.set_status(&name, McpServerStatus::NeedsOAuth);
195 Err(error)
196 }
197 ConnectResult::Failed(e) => Err(e),
198 };
199
200 match result {
201 Ok(()) => {
202 if let Some(conn) = self.servers.get(&server_name) {
204 let client = conn.client.clone();
205 if let Err(e) = ToolProxy::write_tools_to_dir(&server_name, &client, &tool_dir).await {
206 tracing::warn!("Failed to write tool files for nested server '{server_name}': {e}");
207 }
208
209 let description = ToolProxy::extract_server_description(&client, &server_name);
210 server_descriptions.push((server_name.clone(), description));
211 }
212 nested_names.insert(server_name);
213 }
214 Err(e) => {
215 tracing::warn!("Failed to connect nested server '{server_name}': {e}");
216 if self.pending_configs.contains_key(&server_name) {
219 nested_names.insert(server_name);
220 }
221 }
222 }
223 }
224
225 let call_tool_def = ToolProxy::call_tool_definition(&proxy_name);
226 self.tools.insert(
227 call_tool_def.name.clone(),
228 Tool {
229 description: call_tool_def.description.clone(),
230 parameters: serde_json::from_str(&call_tool_def.parameters)
231 .unwrap_or(Value::Object(serde_json::Map::default())),
232 },
233 );
234 self.tool_definitions.push(call_tool_def);
235
236 self.proxy = Some(ToolProxy::new(proxy_name.clone(), nested_names, tool_dir, &server_descriptions));
237
238 self.set_status(&proxy_name, McpServerStatus::Connected { tool_count: 1 });
240
241 Ok(())
242 }
243
244 async fn oauth_and_reconnect(&mut self, name: String, config: StreamableHttpClientTransportConfig) -> Result<()> {
245 let handler = self
246 .oauth_handler
247 .as_ref()
248 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'")))?;
249 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
250 .await
251 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
252
253 let mcp_client = self.create_mcp_client(&name);
254 let conn = McpServerConnection::reconnect_with_auth(&name, config, auth_client, mcp_client).await?;
255
256 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.contains_server(&name)) {
258 let tool_dir = proxy.tool_dir().to_path_buf();
259 self.register_server(&name, conn, Registration::Proxied).await?;
260 if let Some(conn) = self.servers.get(&name) {
262 let client = conn.client.clone();
263 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
264 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
265 }
266 }
267 Ok(())
268 } else {
269 self.register_server(&name, conn, Registration::Direct).await
270 }
271 }
272
273 async fn register_server(
279 &mut self,
280 name: &str,
281 conn: McpServerConnection,
282 registration: Registration,
283 ) -> Result<()> {
284 let tools = conn
285 .list_tools()
286 .await
287 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
288
289 for rmcp_tool in &tools {
290 let tool_name = rmcp_tool.name.to_string();
291 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
292 let tool = Tool::from(rmcp_tool);
293
294 if registration == Registration::Direct {
295 self.tool_definitions.push(ToolDefinition {
296 name: namespaced_tool_name.clone(),
297 description: tool.description.clone(),
298 parameters: tool.parameters.to_string(),
299 server: Some(name.to_string()),
300 });
301 }
302
303 self.tools.insert(namespaced_tool_name, tool);
304 }
305
306 let tool_count = tools.len();
307
308 self.set_status(name, McpServerStatus::Connected { tool_count });
309
310 self.pending_configs.remove(name);
312
313 self.servers.insert(name.to_string(), conn);
314 Ok(())
315 }
316
317 pub fn get_client_for_tool(
323 &self,
324 namespaced_tool_name: &str,
325 arguments_json: &str,
326 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
327 if !self.tools.contains_key(namespaced_tool_name) {
328 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
329 }
330
331 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
332 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
333
334 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
335 let call = proxy.resolve_call(arguments_json)?;
336 let conn = self
337 .servers
338 .get(&call.server)
339 .ok_or_else(|| McpError::ServerNotFound(format!("Nested server '{}' is not connected", call.server)))?;
340 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
341 return Ok((conn.client.clone(), params));
342 }
343
344 let client = self
345 .servers
346 .get(server_name)
347 .map(|server| server.client.clone())
348 .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
349
350 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
351 let mut params = CallToolRequestParams::new(tool_name.to_string());
352 if let Some(args) = arguments {
353 params = params.with_arguments(args);
354 }
355
356 Ok((client, params))
357 }
358
359 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
360 self.tool_definitions.clone()
361 }
362
363 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
366 let mut instructions: Vec<ServerInstructions> = self
367 .servers
368 .iter()
369 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
370 .filter_map(|(name, conn)| {
371 conn.instructions
372 .as_ref()
373 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
374 })
375 .collect();
376
377 if let Some(proxy) = &self.proxy {
378 instructions.push(ServerInstructions {
379 server_name: proxy.name().to_string(),
380 instructions: proxy.instructions().to_string(),
381 });
382 }
383
384 instructions
385 }
386
387 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
388 &self.server_statuses
389 }
390
391 pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
396 let config = self
397 .pending_configs
398 .get(name)
399 .ok_or_else(|| McpError::ConnectionFailed(format!("no pending config for server '{name}'")))?
400 .clone();
401
402 self.oauth_and_reconnect(name.to_string(), config).await
403 }
404
405 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
407 use futures::future::join_all;
408
409 let futures: Vec<_> = self
410 .servers
411 .iter()
412 .filter(|(_, server_conn)| {
413 server_conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref()).is_some()
414 })
415 .map(|(server_name, server_conn)| {
416 let server_name = server_name.clone();
417 let client = server_conn.client.clone();
418 async move {
419 let prompts_response = client.list_prompts(None).await.map_err(|e| {
420 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
421 })?;
422
423 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
424 .prompts
425 .into_iter()
426 .map(|prompt| {
427 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
428 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
429 })
430 .collect();
431
432 Ok::<_, McpError>(namespaced_prompts)
433 }
434 })
435 .collect();
436
437 let results = join_all(futures).await;
438 let mut all_prompts = Vec::new();
439 for result in results {
440 all_prompts.extend(result?);
441 }
442
443 Ok(all_prompts)
444 }
445
446 pub async fn get_prompt(
448 &self,
449 namespaced_prompt_name: &str,
450 arguments: Option<serde_json::Map<String, serde_json::Value>>,
451 ) -> Result<rmcp::model::GetPromptResult> {
452 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
453 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
454
455 let server_conn =
456 self.servers.get(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
457
458 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
459 if let Some(args) = arguments {
460 request = request.with_arguments(args);
461 }
462
463 server_conn.client.get_prompt(request).await.map_err(|e| {
464 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
465 })
466 }
467
468 pub async fn shutdown(&mut self) {
470 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
471
472 for (server_name, server) in servers {
473 if let Some(handle) = server.server_task {
474 drop(server.client);
476
477 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
479 Ok(Ok(())) => {
480 tracing::info!("Server '{server_name}' shut down gracefully");
481 }
482 Ok(Err(e)) => {
483 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
484 }
485 Err(_) => {
486 tracing::warn!("Server '{server_name}' shutdown timed out");
487 }
489 }
490 }
491 }
492
493 self.tools.clear();
494 self.tool_definitions.clear();
495 self.proxy = None;
496 }
497
498 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
500 let server = self.servers.remove(server_name);
501
502 if let Some(server) = server {
503 if let Some(handle) = server.server_task {
504 drop(server.client);
506
507 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
509 Ok(Ok(())) => {
510 tracing::info!("Server '{server_name}' shut down gracefully");
511 }
512 Ok(Err(e)) => {
513 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
514 }
515 Err(_) => {
516 tracing::warn!("Server '{server_name}' shutdown timed out");
517 }
519 }
520 }
521
522 self.tools.retain(|tool_name, _| !tool_name.starts_with(server_name));
524
525 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(server_name));
526 }
527
528 Ok(())
529 }
530
531 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
536 {
538 let mut roots = self.roots.write().await;
539 *roots = new_roots;
540 }
541
542 self.notify_roots_changed().await;
544
545 Ok(())
546 }
547
548 async fn notify_roots_changed(&self) {
553 for (server_name, server_conn) in &self.servers {
554 if let Err(e) = server_conn.client.notify_roots_list_changed().await {
556 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
558 }
559 }
560 }
561}
562
563impl Drop for McpManager {
564 fn drop(&mut self) {
565 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
566 for (server_name, server) in servers {
567 if let Some(handle) = server.server_task {
568 handle.abort();
569 tracing::warn!("Server '{server_name}' task aborted during cleanup");
570 }
571 }
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::McpManager;
    use crate::client::config::ServerConfig;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    /// Minimal in-memory MCP server exposing a single `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            let capabilities = ServerCapabilities::builder().enable_tools().build();
            let implementation = Implementation::new("test-server", "0.1.0").with_description("Test MCP server");
            ServerInfo::new(capabilities).with_server_info(implementation)
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            let tool_router = Self::tool_router();
            Self { tool_router }
        }
    }

    /// Input of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    /// Output of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, Parameters(EchoRequest { value }): Parameters<EchoRequest>) -> Json<EchoResult> {
            Json(EchoResult { value })
        }
    }

    /// `io::Write` sink appending into a shared buffer so captured logs can be inspected.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let mut sink = self.0.lock().unwrap();
            sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (tx, _rx) = mpsc::channel(1);
        let mut manager = McpManager::new(tx, None);
        let config = ServerConfig::InMemory { name: "test".to_string(), server: TestServer::default().into_dyn() };
        manager.add_mcp(config.into()).await.unwrap();

        let captured = Arc::new(Mutex::new(Vec::new()));
        let make_writer = {
            let captured = Arc::clone(&captured);
            move || SharedWriter(Arc::clone(&captured))
        };
        let subscriber =
            tracing_subscriber::fmt().with_ansi(false).without_time().with_writer(make_writer).finish();

        tracing::subscriber::with_default(subscriber, move || drop(manager));

        let bytes = captured.lock().unwrap().clone();
        let logs = String::from_utf8(bytes).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}