1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::{McpServerConfig, ServerConfig},
6 connection::{ConnectParams, ConnectResult, McpServerConnection, ServerInstructions, Tool},
7 mcp_client::McpClient,
8 naming::{create_namespaced_tool_name, split_on_server_name},
9 oauth::{OAuthHandler, perform_oauth_flow},
10 tool_proxy::ToolProxy,
11};
12use rmcp::{
13 RoleClient,
14 model::{
15 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
16 ElicitationAction, Implementation, Root,
17 },
18 service::RunningService,
19 transport::streamable_http_client::StreamableHttpClientTransportConfig,
20};
21use serde_json::Value;
22use std::collections::{HashMap, HashSet};
23use std::sync::Arc;
24use tokio::sync::{RwLock, mpsc, oneshot};
25
26pub use crate::status::{McpServerStatus, McpServerStatusEntry};
27
28#[derive(Debug)]
29pub struct ElicitationRequest {
30 pub request: CreateElicitationRequestParams,
31 pub response_sender: oneshot::Sender<CreateElicitationResult>,
32}
33
34#[derive(Debug, Clone)]
35pub struct ElicitationResponse {
36 pub action: ElicitationAction,
37 pub content: Option<Value>,
38}
39
/// How a server's tools are exposed when the server is registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Registration {
    /// Tools are stored for lookup and also added to the tool-definition list.
    Direct,
    /// Tools are stored for lookup only; no tool definitions are added.
    Proxied,
}
49
50pub struct McpManager {
52 servers: HashMap<String, McpServerConnection>,
53 tools: HashMap<String, Tool>,
54 tool_definitions: Vec<ToolDefinition>,
55 client_info: ClientInfo,
56 elicitation_sender: mpsc::Sender<ElicitationRequest>,
57 roots: Arc<RwLock<Vec<Root>>>,
59 oauth_handler: Option<Arc<dyn OAuthHandler>>,
60 server_statuses: Vec<McpServerStatusEntry>,
61 pending_configs: HashMap<String, StreamableHttpClientTransportConfig>,
63 proxy: Option<ToolProxy>,
65}
66
67impl McpManager {
68 pub fn new(
69 elicitation_sender: mpsc::Sender<ElicitationRequest>,
70 oauth_handler: Option<Arc<dyn OAuthHandler>>,
71 ) -> Self {
72 Self {
73 servers: HashMap::new(),
74 tools: HashMap::new(),
75 tool_definitions: Vec::new(),
76 client_info: ClientInfo::new(
77 ClientCapabilities::builder().enable_elicitation().enable_roots().build(),
78 Implementation::new("aether", "0.1.0"),
79 ),
80 elicitation_sender,
81 roots: Arc::new(RwLock::new(Vec::new())),
82 oauth_handler,
83 server_statuses: Vec::new(),
84 pending_configs: HashMap::new(),
85 proxy: None,
86 }
87 }
88
89 fn create_mcp_client(&self) -> McpClient {
90 McpClient::new(self.client_info.clone(), self.elicitation_sender.clone(), Arc::clone(&self.roots))
91 }
92
93 fn connect_params(&self) -> ConnectParams {
94 ConnectParams { mcp_client: self.create_mcp_client(), oauth_handler: self.oauth_handler.clone() }
95 }
96
97 fn set_status(&mut self, name: &str, status: McpServerStatus) {
99 if let Some(entry) = self.server_statuses.iter_mut().find(|s| s.name == name) {
100 entry.status = status;
101 } else {
102 self.server_statuses.push(McpServerStatusEntry { name: name.to_string(), status });
103 }
104 }
105
106 pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
107 for config in configs {
108 let name = config.name().to_string();
109 if let Err(e) = self.add_mcp(config).await {
110 tracing::warn!("Failed to connect to MCP server '{}': {}", name, e);
112 if !self.server_statuses.iter().any(|s| s.name == name) {
114 self.set_status(&name, McpServerStatus::Failed { error: e.to_string() });
115 }
116 }
117 }
118 Ok(())
119 }
120
121 pub async fn add_mcp_with_auth(&mut self, name: String, base_url: &str, auth_header: String) -> Result<()> {
122 let config = ServerConfig::Http {
123 name: name.clone(),
124 config: StreamableHttpClientTransportConfig::with_uri(base_url).auth_header(auth_header),
125 };
126 let params = self.connect_params();
127 match McpServerConnection::connect(config, params).await {
128 ConnectResult::Connected(conn) => self.register_server(&name, conn, Registration::Direct).await,
129 ConnectResult::NeedsOAuth { error, .. } => Err(error),
130 ConnectResult::Failed(e) => Err(e),
131 }
132 }
133
134 pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
135 match config {
136 McpServerConfig::ToolProxy { name, servers } => self.connect_tool_proxy(name, servers).await,
137
138 McpServerConfig::Server(config) => {
139 let name = config.name().to_string();
140 let params = self.connect_params();
141 match McpServerConnection::connect(config, params).await {
142 ConnectResult::Connected(conn) => self.register_server(&name, conn, Registration::Direct).await,
143 ConnectResult::NeedsOAuth { name, config, error } => {
144 self.pending_configs.insert(name.clone(), config);
145 self.set_status(&name, McpServerStatus::NeedsOAuth);
146 Err(error)
147 }
148 ConnectResult::Failed(e) => Err(e),
149 }
150 }
151 }
152 }
153
154 async fn connect_tool_proxy(&mut self, proxy_name: String, servers: Vec<ServerConfig>) -> Result<()> {
158 let tool_dir = ToolProxy::dir(&proxy_name)?;
159 ToolProxy::clean_dir(&tool_dir).await?;
160
161 let mut nested_names = HashSet::new();
162 let mut server_descriptions = Vec::new();
163
164 for config in servers {
165 let server_name = config.name().to_string();
166 let params = self.connect_params();
167
168 let result = match McpServerConnection::connect(config, params).await {
169 ConnectResult::Connected(conn) => self.register_server(&server_name, conn, Registration::Proxied).await,
170 ConnectResult::NeedsOAuth { name, config, error } => {
171 self.pending_configs.insert(name.clone(), config);
172 self.set_status(&name, McpServerStatus::NeedsOAuth);
173 Err(error)
174 }
175 ConnectResult::Failed(e) => Err(e),
176 };
177
178 match result {
179 Ok(()) => {
180 if let Some(conn) = self.servers.get(&server_name) {
182 let client = conn.client.clone();
183 if let Err(e) = ToolProxy::write_tools_to_dir(&server_name, &client, &tool_dir).await {
184 tracing::warn!("Failed to write tool files for nested server '{server_name}': {e}");
185 }
186
187 let description = ToolProxy::extract_server_description(&client, &server_name);
188 server_descriptions.push((server_name.clone(), description));
189 }
190 nested_names.insert(server_name);
191 }
192 Err(e) => {
193 tracing::warn!("Failed to connect nested server '{server_name}': {e}");
194 if self.pending_configs.contains_key(&server_name) {
197 nested_names.insert(server_name);
198 }
199 }
200 }
201 }
202
203 let call_tool_def = ToolProxy::call_tool_definition(&proxy_name);
204 self.tools.insert(
205 call_tool_def.name.clone(),
206 Tool {
207 description: call_tool_def.description.clone(),
208 parameters: serde_json::from_str(&call_tool_def.parameters)
209 .unwrap_or(Value::Object(serde_json::Map::default())),
210 },
211 );
212 self.tool_definitions.push(call_tool_def);
213
214 self.proxy = Some(ToolProxy::new(proxy_name.clone(), nested_names, tool_dir, &server_descriptions));
215
216 self.set_status(&proxy_name, McpServerStatus::Connected { tool_count: 1 });
218
219 Ok(())
220 }
221
222 async fn oauth_and_reconnect(&mut self, name: String, config: StreamableHttpClientTransportConfig) -> Result<()> {
223 let handler = self
224 .oauth_handler
225 .as_ref()
226 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'")))?;
227 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
228 .await
229 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
230
231 let mcp_client = self.create_mcp_client();
232 let conn = McpServerConnection::reconnect_with_auth(&name, config, auth_client, mcp_client).await?;
233
234 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.contains_server(&name)) {
236 let tool_dir = proxy.tool_dir().to_path_buf();
237 self.register_server(&name, conn, Registration::Proxied).await?;
238 if let Some(conn) = self.servers.get(&name) {
240 let client = conn.client.clone();
241 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
242 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
243 }
244 }
245 Ok(())
246 } else {
247 self.register_server(&name, conn, Registration::Direct).await
248 }
249 }
250
251 async fn register_server(
257 &mut self,
258 name: &str,
259 conn: McpServerConnection,
260 registration: Registration,
261 ) -> Result<()> {
262 let tools = conn
263 .list_tools()
264 .await
265 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
266
267 for rmcp_tool in &tools {
268 let tool_name = rmcp_tool.name.to_string();
269 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
270 let tool = Tool::from(rmcp_tool);
271
272 if registration == Registration::Direct {
273 self.tool_definitions.push(ToolDefinition {
274 name: namespaced_tool_name.clone(),
275 description: tool.description.clone(),
276 parameters: tool.parameters.to_string(),
277 server: Some(name.to_string()),
278 });
279 }
280
281 self.tools.insert(namespaced_tool_name, tool);
282 }
283
284 let tool_count = tools.len();
285
286 self.set_status(name, McpServerStatus::Connected { tool_count });
287
288 self.pending_configs.remove(name);
290
291 self.servers.insert(name.to_string(), conn);
292 Ok(())
293 }
294
295 pub fn get_client_for_tool(
301 &self,
302 namespaced_tool_name: &str,
303 arguments_json: &str,
304 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
305 if !self.tools.contains_key(namespaced_tool_name) {
306 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
307 }
308
309 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
310 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
311
312 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
313 let call = proxy.resolve_call(arguments_json)?;
314 let conn = self
315 .servers
316 .get(&call.server)
317 .ok_or_else(|| McpError::ServerNotFound(format!("Nested server '{}' is not connected", call.server)))?;
318 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
319 return Ok((conn.client.clone(), params));
320 }
321
322 let client = self
323 .servers
324 .get(server_name)
325 .map(|server| server.client.clone())
326 .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
327
328 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
329 let mut params = CallToolRequestParams::new(tool_name.to_string());
330 if let Some(args) = arguments {
331 params = params.with_arguments(args);
332 }
333
334 Ok((client, params))
335 }
336
337 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
338 self.tool_definitions.clone()
339 }
340
341 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
344 let mut instructions: Vec<ServerInstructions> = self
345 .servers
346 .iter()
347 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
348 .filter_map(|(name, conn)| {
349 conn.instructions
350 .as_ref()
351 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
352 })
353 .collect();
354
355 if let Some(proxy) = &self.proxy {
356 instructions.push(ServerInstructions {
357 server_name: proxy.name().to_string(),
358 instructions: proxy.instructions().to_string(),
359 });
360 }
361
362 instructions
363 }
364
365 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
366 &self.server_statuses
367 }
368
369 pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
374 let config = self
375 .pending_configs
376 .get(name)
377 .ok_or_else(|| McpError::ConnectionFailed(format!("no pending config for server '{name}'")))?
378 .clone();
379
380 self.oauth_and_reconnect(name.to_string(), config).await
381 }
382
383 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
385 use futures::future::join_all;
386
387 let futures: Vec<_> = self
388 .servers
389 .iter()
390 .filter(|(_, server_conn)| {
391 server_conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref()).is_some()
392 })
393 .map(|(server_name, server_conn)| {
394 let server_name = server_name.clone();
395 let client = server_conn.client.clone();
396 async move {
397 let prompts_response = client.list_prompts(None).await.map_err(|e| {
398 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
399 })?;
400
401 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
402 .prompts
403 .into_iter()
404 .map(|prompt| {
405 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
406 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
407 })
408 .collect();
409
410 Ok::<_, McpError>(namespaced_prompts)
411 }
412 })
413 .collect();
414
415 let results = join_all(futures).await;
416 let mut all_prompts = Vec::new();
417 for result in results {
418 all_prompts.extend(result?);
419 }
420
421 Ok(all_prompts)
422 }
423
424 pub async fn get_prompt(
426 &self,
427 namespaced_prompt_name: &str,
428 arguments: Option<serde_json::Map<String, serde_json::Value>>,
429 ) -> Result<rmcp::model::GetPromptResult> {
430 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
431 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
432
433 let server_conn =
434 self.servers.get(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
435
436 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
437 if let Some(args) = arguments {
438 request = request.with_arguments(args);
439 }
440
441 server_conn.client.get_prompt(request).await.map_err(|e| {
442 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
443 })
444 }
445
446 pub async fn shutdown(&mut self) {
448 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
449
450 for (server_name, server) in servers {
451 if let Some(handle) = server.server_task {
452 drop(server.client);
454
455 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
457 Ok(Ok(())) => {
458 tracing::info!("Server '{server_name}' shut down gracefully");
459 }
460 Ok(Err(e)) => {
461 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
462 }
463 Err(_) => {
464 tracing::warn!("Server '{server_name}' shutdown timed out");
465 }
467 }
468 }
469 }
470
471 self.tools.clear();
472 self.tool_definitions.clear();
473 self.proxy = None;
474 }
475
476 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
478 let server = self.servers.remove(server_name);
479
480 if let Some(server) = server {
481 if let Some(handle) = server.server_task {
482 drop(server.client);
484
485 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
487 Ok(Ok(())) => {
488 tracing::info!("Server '{server_name}' shut down gracefully");
489 }
490 Ok(Err(e)) => {
491 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
492 }
493 Err(_) => {
494 tracing::warn!("Server '{server_name}' shutdown timed out");
495 }
497 }
498 }
499
500 self.tools.retain(|tool_name, _| !tool_name.starts_with(server_name));
502
503 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(server_name));
504 }
505
506 Ok(())
507 }
508
509 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
514 {
516 let mut roots = self.roots.write().await;
517 *roots = new_roots;
518 }
519
520 self.notify_roots_changed().await;
522
523 Ok(())
524 }
525
526 async fn notify_roots_changed(&self) {
531 for (server_name, server_conn) in &self.servers {
532 if let Err(e) = server_conn.client.notify_roots_list_changed().await {
534 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
536 }
537 }
538 }
539}
540
541impl Drop for McpManager {
542 fn drop(&mut self) {
543 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
544 for (server_name, server) in servers {
545 if let Some(handle) = server.server_task {
546 handle.abort();
547 tracing::warn!("Server '{server_name}' task aborted during cleanup");
548 }
549 }
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::McpManager;
    use crate::client::config::ServerConfig;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    /// Input of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    /// Output of the `echo` tool.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    /// Minimal in-memory MCP server exposing a single `echo` tool.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }

    #[tool_router]
    impl TestServer {
        /// Boxes the server so it can be passed as an in-memory server config.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(payload) = request;
            Json(EchoResult { value: payload.value })
        }
    }

    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            let capabilities = ServerCapabilities::builder().enable_tools().build();
            let implementation = Implementation::new("test-server", "0.1.0").with_description("Test MCP server");
            ServerInfo::new(capabilities).with_server_info(implementation)
        }
    }

    /// `io::Write` sink that appends into a shared buffer so the test can
    /// read back what the tracing subscriber wrote.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let mut sink = self.0.lock().unwrap();
            sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Dropping a manager that still holds a connected server must log the
    /// cleanup-abort warning.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (elicitation_sender, _elicitation_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(elicitation_sender, None);

        let in_memory = ServerConfig::InMemory { name: "test".to_string(), server: TestServer::default().into_dyn() };
        manager.add_mcp(in_memory.into()).await.unwrap();

        let captured = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer(move || SharedWriter(Arc::clone(&sink)))
            .finish();

        // Install the subscriber only for the drop, so just the cleanup logs are captured.
        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(captured.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}