1use llm::ToolDefinition;
2
3use super::{
4 McpError, Result,
5 config::{McpServerConfig, ServerConfig},
6 connect_mcp::{ConnectOutcome, ConnectionSpec, ProxySpec, Registration, build_plan, connect_mcp},
7 connection::{McpServerConnection, ServerInstructions, Tool},
8 mcp_client::McpClient,
9 naming::{create_namespaced_tool_name, split_on_server_name},
10 oauth::{OAuthHandler, perform_oauth_flow},
11 tool_proxy::ToolProxy,
12};
13use futures::future::join_all;
14use rmcp::{
15 RoleClient,
16 model::{
17 CallToolRequestParams, ClientCapabilities, ClientInfo, CreateElicitationRequestParams, CreateElicitationResult,
18 ElicitationAction, FormElicitationCapability, Implementation, Root, Tool as RmcpTool, UrlElicitationCapability,
19 },
20 service::RunningService,
21 transport::streamable_http_client::StreamableHttpClientTransportConfig,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::{HashMap, HashSet};
26use std::sync::Arc;
27use tokio::sync::{RwLock, mpsc, oneshot};
28
29pub use crate::status::{McpServerStatus, McpServerStatusEntry};
30
/// A server-initiated elicitation request forwarded to the host for an answer.
///
/// The host replies by sending a [`CreateElicitationResult`] through
/// `response_sender`; presumably the originating client awaits the paired
/// receiver (TODO confirm against `McpClient`).
#[derive(Debug)]
pub struct ElicitationRequest {
    /// Name of the MCP server that issued the request.
    pub server_name: String,
    /// The elicitation parameters as received from the server.
    pub request: CreateElicitationRequestParams,
    /// One-shot channel on which the host's answer must be sent.
    pub response_sender: oneshot::Sender<CreateElicitationResult>,
}
37
/// A host-side answer to an elicitation: the action taken and, optionally, the
/// submitted content (presumably only meaningful for an accepting action —
/// NOTE(review): confirm with the code that builds the `CreateElicitationResult`).
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// The action chosen in response to the elicitation.
    pub action: ElicitationAction,
    /// Optional JSON content supplied with the response.
    pub content: Option<Value>,
}
43
/// Identifies a URL-mode elicitation, carried by
/// [`McpClientEvent::UrlElicitationComplete`] (presumably sent when a server
/// reports that the out-of-band URL flow finished — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UrlElicitationCompleteParams {
    /// Name of the MCP server the elicitation belongs to.
    pub server_name: String,
    /// Identifier of the elicitation that completed.
    pub elicitation_id: String,
}
49
/// Events delivered to the host over the `mpsc` channel given to
/// [`McpManager::new`], which is cloned into every MCP client.
#[derive(Debug)]
pub enum McpClientEvent {
    /// A server asked the user for input; answer via the request's `response_sender`.
    Elicitation(ElicitationRequest),
    /// A URL-mode elicitation was reported complete.
    UrlElicitationComplete(UrlElicitationCompleteParams),
}
58
/// Owns the connections to the configured MCP servers and the tools they expose.
///
/// Tools are addressed by namespaced names built from the server name and the tool
/// name. A server registered directly publishes its tools as [`ToolDefinition`]s.
/// A server registered behind the tool proxy does not; only the proxy's single call
/// tool is published, and nested tools are reached through it.
pub struct McpManager {
    /// Live connections, keyed by server name.
    servers: HashMap<String, McpServerConnection>,
    /// Every registered tool (direct, proxied, and the proxy's call tool), keyed by
    /// namespaced name; used to validate calls in `get_client_for_tool`.
    tools: HashMap<String, Tool>,
    /// Definitions returned by `tool_definitions()`; tools of servers registered
    /// behind the proxy are not included.
    tool_definitions: Vec<ToolDefinition>,
    /// Client identity and capabilities presented to every server.
    client_info: ClientInfo,
    /// Cloned into every `McpClient`; presumably used to forward server-initiated
    /// events (elicitations) to the host.
    event_sender: mpsc::Sender<McpClientEvent>,
    /// Roots shared with every client; replaced by `set_roots`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Handler used to run OAuth flows; with `None`, `authenticate_server` fails.
    oauth_handler: Option<Arc<dyn OAuthHandler>>,
    /// Per-server status entries, in the order each server was first recorded.
    server_statuses: Vec<McpServerStatusEntry>,
    /// HTTP configs of servers that still need OAuth, keyed by server name.
    pending_configs: HashMap<String, StreamableHttpClientTransportConfig>,
    /// The installed tool proxy, if any; only one proxy is kept at a time.
    proxy: Option<ToolProxy>,
}
75
76impl McpManager {
77 pub fn new(event_sender: mpsc::Sender<McpClientEvent>, oauth_handler: Option<Arc<dyn OAuthHandler>>) -> Self {
78 let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
79 if let Some(elicitation) = capabilities.elicitation.as_mut() {
80 elicitation.form = Some(FormElicitationCapability::default());
81 elicitation.url = Some(UrlElicitationCapability::default());
82 }
83
84 Self {
85 servers: HashMap::new(),
86 tools: HashMap::new(),
87 tool_definitions: Vec::new(),
88 client_info: ClientInfo::new(capabilities, Implementation::new("aether", "0.1.0")),
89 event_sender,
90 roots: Arc::new(RwLock::new(Vec::new())),
91 oauth_handler,
92 server_statuses: Vec::new(),
93 pending_configs: HashMap::new(),
94 proxy: None,
95 }
96 }
97
98 fn set_status(&mut self, name: &str, status: McpServerStatus) {
100 if let Some(entry) = self.server_statuses.iter_mut().find(|s| s.name == name) {
101 entry.status = status;
102 } else {
103 self.server_statuses.push(McpServerStatusEntry { name: name.to_string(), status });
104 }
105 }
106
    /// Connects a batch of MCP server / tool-proxy configurations.
    ///
    /// `build_plan` splits the configs into connection specs and proxy specs. All
    /// connections are then attempted concurrently. Per server:
    /// - `Ready`: the server is registered (directly or behind its proxy).
    /// - `NeedsOAuth`: the config is parked in `pending_configs` and the status is
    ///   set to `NeedsOAuth`.
    /// - `Failed`: a warning is logged and a `Failed` status is recorded unless the
    ///   server already has a status entry.
    ///
    /// After that, tool files for connected proxy members are written, and finally
    /// each proxy is registered.
    ///
    /// # Errors
    /// Only errors from `build_plan` are returned; individual connection failures
    /// are reported through statuses and logs instead.
    pub async fn add_mcps(&mut self, configs: Vec<McpServerConfig>) -> Result<()> {
        let (direct, proxies) = build_plan(configs).await?;
        // Attempt every connection concurrently.
        let outcomes: Vec<ConnectOutcome> = join_all(direct.into_iter().map(|leaf| {
            connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref())
        }))
        .await;

        // proxy name -> every member server assigned to it (connected or awaiting OAuth).
        let mut mcp_proxies: HashMap<String, HashSet<String>> =
            proxies.iter().map(|p| (p.name.clone(), HashSet::new())).collect();

        // proxy name -> members that actually connected (used for descriptions/tool files).
        let mut connected_mcps_to_proxy: HashMap<String, Vec<String>> = HashMap::new();
        for outcome in outcomes {
            match outcome {
                ConnectOutcome::Ready { name, conn, tools, proxy, registration } => {
                    self.apply_connected(&name, conn, &tools, registration);
                    if let Some(p) = proxy {
                        if let Some(members) = mcp_proxies.get_mut(&p) {
                            members.insert(name.clone());
                        }
                        connected_mcps_to_proxy.entry(p).or_default().push(name);
                    }
                }
                ConnectOutcome::NeedsOAuth { name, config, error, proxy } => {
                    tracing::warn!("Server '{name}' needs OAuth: {error}");
                    self.pending_configs.insert(name.clone(), config);
                    self.set_status(&name, McpServerStatus::NeedsOAuth);
                    // Still record proxy membership so that, after OAuth,
                    // `oauth_and_reconnect` registers the server behind the proxy.
                    if let Some(p) = proxy
                        && let Some(members) = mcp_proxies.get_mut(&p)
                    {
                        members.insert(name);
                    }
                }
                ConnectOutcome::Failed { name, error } => {
                    tracing::warn!("Failed to connect to MCP server '{name}': {error}");
                    // Do not overwrite a status that is already recorded for this name.
                    if !self.server_statuses.iter().any(|s| s.name == name) {
                        self.set_status(&name, McpServerStatus::Failed { error: error.to_string() });
                    }
                }
            }
        }

        // Write tool files for every connected proxy member concurrently; failures
        // are logged and do not abort registration.
        let writes = proxies.iter().flat_map(|proxy| {
            connected_mcps_to_proxy
                .get(&proxy.name)
                .into_iter()
                .flatten()
                .filter_map(|member| {
                    let client = self.servers.get(member)?.client.clone();
                    let dir = proxy.tool_dir.clone();
                    let name = member.clone();
                    Some(async move {
                        if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &dir).await {
                            tracing::warn!("Failed to write tool files for nested server '{name}': {e}");
                        }
                    })
                })
                .collect::<Vec<_>>()
        });
        join_all(writes).await;

        for proxy in proxies {
            self.register_proxy(proxy, &mut mcp_proxies, &mut connected_mcps_to_proxy);
        }

        Ok(())
    }
173
174 pub async fn add_mcp_with_auth(&mut self, name: String, base_url: &str, auth_header: String) -> Result<()> {
175 let config = ServerConfig::Http {
176 name: name.clone(),
177 config: StreamableHttpClientTransportConfig::with_uri(base_url).auth_header(auth_header),
178 };
179 let leaf = ConnectionSpec { name, config, proxy: None, registration: Registration::Direct };
180 match connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref()).await {
181 ConnectOutcome::Ready { name, conn, tools, registration, .. } => {
182 self.apply_connected(&name, conn, &tools, registration);
183 Ok(())
184 }
185 ConnectOutcome::NeedsOAuth { error, .. } | ConnectOutcome::Failed { error, .. } => Err(error),
186 }
187 }
188
189 pub async fn add_mcp(&mut self, config: McpServerConfig) -> Result<()> {
190 match config {
191 McpServerConfig::ToolProxy { .. } => self.add_mcps(vec![config]).await,
192
193 McpServerConfig::Server(config) => {
194 let leaf = ConnectionSpec {
195 name: config.name().to_string(),
196 config,
197 proxy: None,
198 registration: Registration::Direct,
199 };
200 match connect_mcp(leaf, &self.client_info, &self.event_sender, &self.roots, self.oauth_handler.as_ref())
201 .await
202 {
203 ConnectOutcome::Ready { name, conn, tools, registration, .. } => {
204 self.apply_connected(&name, conn, &tools, registration);
205 Ok(())
206 }
207 ConnectOutcome::NeedsOAuth { name, config, error, .. } => {
208 self.pending_configs.insert(name.clone(), config);
209 self.set_status(&name, McpServerStatus::NeedsOAuth);
210 Err(error)
211 }
212 ConnectOutcome::Failed { error, .. } => Err(error),
213 }
214 }
215 }
216 }
217
218 fn register_proxy(
222 &mut self,
223 proxy: ProxySpec,
224 proxy_members: &mut HashMap<String, HashSet<String>>,
225 ready_for_proxy: &mut HashMap<String, Vec<String>>,
226 ) {
227 let members = ready_for_proxy.remove(&proxy.name).unwrap_or_default();
228
229 let server_descriptions: Vec<(String, String)> = members
230 .iter()
231 .filter_map(|member| {
232 self.servers
233 .get(member)
234 .map(|conn| (member.clone(), ToolProxy::extract_server_description(&conn.client, member)))
235 })
236 .collect();
237
238 let call_tool_def = ToolProxy::call_tool_definition(&proxy.name);
239 self.tools.insert(
240 call_tool_def.name.clone(),
241 Tool {
242 description: call_tool_def.description.clone(),
243 parameters: serde_json::from_str(&call_tool_def.parameters)
244 .unwrap_or(Value::Object(serde_json::Map::default())),
245 },
246 );
247 self.tool_definitions.push(call_tool_def);
248
249 let nested = proxy_members.remove(&proxy.name).unwrap_or_default();
250 self.proxy = Some(ToolProxy::new(proxy.name.clone(), nested, proxy.tool_dir, &server_descriptions));
251 self.set_status(&proxy.name, McpServerStatus::Connected { tool_count: 1 });
252 }
253
254 async fn oauth_and_reconnect(&mut self, name: String, config: StreamableHttpClientTransportConfig) -> Result<()> {
255 let handler = self
256 .oauth_handler
257 .as_ref()
258 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler available for '{name}'")))?;
259 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
260 .await
261 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
262
263 let mcp_client =
264 McpClient::new(self.client_info.clone(), name.clone(), self.event_sender.clone(), Arc::clone(&self.roots));
265 let conn = McpServerConnection::reconnect_with_auth(&name, config, auth_client, mcp_client).await?;
266
267 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.contains_server(&name)) {
269 let tool_dir = proxy.tool_dir().to_path_buf();
270 self.register_server(&name, conn, Registration::Proxied).await?;
271 if let Some(conn) = self.servers.get(&name) {
273 let client = conn.client.clone();
274 if let Err(e) = ToolProxy::write_tools_to_dir(&name, &client, &tool_dir).await {
275 tracing::warn!("Failed to write tool files for '{name}' after OAuth: {e}");
276 }
277 }
278 Ok(())
279 } else {
280 self.register_server(&name, conn, Registration::Direct).await
281 }
282 }
283
284 async fn register_server(
289 &mut self,
290 name: &str,
291 conn: McpServerConnection,
292 registration: Registration,
293 ) -> Result<()> {
294 let tools = conn
295 .list_tools()
296 .await
297 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools for {name}: {e}")))?;
298 self.apply_connected(name, conn, &tools, registration);
299 Ok(())
300 }
301
302 fn apply_connected(
308 &mut self,
309 name: &str,
310 conn: McpServerConnection,
311 tools: &[RmcpTool],
312 registration: Registration,
313 ) {
314 for rmcp_tool in tools {
315 let tool_name = rmcp_tool.name.to_string();
316 let namespaced_tool_name = create_namespaced_tool_name(name, &tool_name);
317 let tool = Tool::from(rmcp_tool);
318
319 if registration == Registration::Direct {
320 self.tool_definitions.push(ToolDefinition {
321 name: namespaced_tool_name.clone(),
322 description: tool.description.clone(),
323 parameters: tool.parameters.to_string(),
324 server: Some(name.to_string()),
325 });
326 }
327
328 self.tools.insert(namespaced_tool_name, tool);
329 }
330
331 self.set_status(name, McpServerStatus::Connected { tool_count: tools.len() });
332 self.pending_configs.remove(name);
333 self.servers.insert(name.to_string(), conn);
334 }
335
336 pub fn get_client_for_tool(
342 &self,
343 namespaced_tool_name: &str,
344 arguments_json: &str,
345 ) -> Result<(Arc<RunningService<RoleClient, McpClient>>, CallToolRequestParams)> {
346 if !self.tools.contains_key(namespaced_tool_name) {
347 return Err(McpError::ToolNotFound(namespaced_tool_name.to_string()));
348 }
349
350 let (server_name, tool_name) = split_on_server_name(namespaced_tool_name)
351 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_tool_name.to_string()))?;
352
353 if let Some(proxy) = self.proxy.as_ref().filter(|p| p.name() == server_name) {
354 let call = proxy.resolve_call(arguments_json)?;
355 let conn = self
356 .servers
357 .get(&call.server)
358 .ok_or_else(|| McpError::ServerNotFound(format!("Nested server '{}' is not connected", call.server)))?;
359 let params = CallToolRequestParams::new(call.tool).with_arguments(call.arguments.unwrap_or_default());
360 return Ok((conn.client.clone(), params));
361 }
362
363 let client = self
364 .servers
365 .get(server_name)
366 .map(|server| server.client.clone())
367 .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
368
369 let arguments = serde_json::from_str::<serde_json::Value>(arguments_json)?.as_object().cloned();
370 let mut params = CallToolRequestParams::new(tool_name.to_string());
371 if let Some(args) = arguments {
372 params = params.with_arguments(args);
373 }
374
375 Ok((client, params))
376 }
377
378 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
379 self.tool_definitions.clone()
380 }
381
382 pub fn server_instructions(&self) -> Vec<ServerInstructions> {
385 let mut instructions: Vec<ServerInstructions> = self
386 .servers
387 .iter()
388 .filter(|(name, _)| self.proxy.as_ref().is_none_or(|p| !p.contains_server(name)))
389 .filter_map(|(name, conn)| {
390 conn.instructions
391 .as_ref()
392 .map(|instr| ServerInstructions { server_name: name.clone(), instructions: instr.clone() })
393 })
394 .collect();
395
396 if let Some(proxy) = &self.proxy {
397 instructions.push(ServerInstructions {
398 server_name: proxy.name().to_string(),
399 instructions: proxy.instructions().to_string(),
400 });
401 }
402
403 instructions
404 }
405
406 pub fn server_statuses(&self) -> &[McpServerStatusEntry] {
407 &self.server_statuses
408 }
409
410 pub async fn authenticate_server(&mut self, name: &str) -> Result<()> {
415 let config = self
416 .pending_configs
417 .get(name)
418 .ok_or_else(|| McpError::ConnectionFailed(format!("no pending config for server '{name}'")))?
419 .clone();
420
421 self.oauth_and_reconnect(name.to_string(), config).await
422 }
423
424 pub async fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>> {
426 use futures::future::join_all;
427
428 let futures: Vec<_> = self
429 .servers
430 .iter()
431 .filter(|(_, server_conn)| {
432 server_conn.client.peer_info().and_then(|info| info.capabilities.prompts.as_ref()).is_some()
433 })
434 .map(|(server_name, server_conn)| {
435 let server_name = server_name.clone();
436 let client = server_conn.client.clone();
437 async move {
438 let prompts_response = client.list_prompts(None).await.map_err(|e| {
439 McpError::PromptListFailed(format!("Failed to list prompts for {server_name}: {e}"))
440 })?;
441
442 let namespaced_prompts: Vec<rmcp::model::Prompt> = prompts_response
443 .prompts
444 .into_iter()
445 .map(|prompt| {
446 let namespaced_name = create_namespaced_tool_name(&server_name, &prompt.name);
447 rmcp::model::Prompt::new(namespaced_name, prompt.description, prompt.arguments)
448 })
449 .collect();
450
451 Ok::<_, McpError>(namespaced_prompts)
452 }
453 })
454 .collect();
455
456 let results = join_all(futures).await;
457 let mut all_prompts = Vec::new();
458 for result in results {
459 all_prompts.extend(result?);
460 }
461
462 Ok(all_prompts)
463 }
464
465 pub async fn get_prompt(
467 &self,
468 namespaced_prompt_name: &str,
469 arguments: Option<serde_json::Map<String, serde_json::Value>>,
470 ) -> Result<rmcp::model::GetPromptResult> {
471 let (server_name, prompt_name) = split_on_server_name(namespaced_prompt_name)
472 .ok_or_else(|| McpError::InvalidToolNameFormat(namespaced_prompt_name.to_string()))?;
473
474 let server_conn =
475 self.servers.get(server_name).ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
476
477 let mut request = rmcp::model::GetPromptRequestParams::new(prompt_name);
478 if let Some(args) = arguments {
479 request = request.with_arguments(args);
480 }
481
482 server_conn.client.get_prompt(request).await.map_err(|e| {
483 McpError::PromptGetFailed(format!("Failed to get prompt '{prompt_name}' from {server_name}: {e}"))
484 })
485 }
486
487 pub async fn shutdown(&mut self) {
489 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
490
491 for (server_name, server) in servers {
492 if let Some(handle) = server.server_task {
493 drop(server.client);
495
496 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
498 Ok(Ok(())) => {
499 tracing::info!("Server '{server_name}' shut down gracefully");
500 }
501 Ok(Err(e)) => {
502 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
503 }
504 Err(_) => {
505 tracing::warn!("Server '{server_name}' shutdown timed out");
506 }
508 }
509 }
510 }
511
512 self.tools.clear();
513 self.tool_definitions.clear();
514 self.proxy = None;
515 }
516
517 pub async fn shutdown_server(&mut self, server_name: &str) -> Result<()> {
519 let server = self.servers.remove(server_name);
520
521 if let Some(server) = server {
522 if let Some(handle) = server.server_task {
523 drop(server.client);
525
526 match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
528 Ok(Ok(())) => {
529 tracing::info!("Server '{server_name}' shut down gracefully");
530 }
531 Ok(Err(e)) => {
532 tracing::warn!("Server '{server_name}' task panicked: {e:?}");
533 }
534 Err(_) => {
535 tracing::warn!("Server '{server_name}' shutdown timed out");
536 }
538 }
539 }
540
541 self.tools.retain(|tool_name, _| !tool_name.starts_with(server_name));
543
544 self.tool_definitions.retain(|tool_def| !tool_def.name.starts_with(server_name));
545 }
546
547 Ok(())
548 }
549
550 pub async fn set_roots(&mut self, new_roots: Vec<Root>) -> Result<()> {
555 {
557 let mut roots = self.roots.write().await;
558 *roots = new_roots;
559 }
560
561 self.notify_roots_changed().await;
563
564 Ok(())
565 }
566
567 async fn notify_roots_changed(&self) {
572 for (server_name, server_conn) in &self.servers {
573 if let Err(e) = server_conn.client.notify_roots_list_changed().await {
575 tracing::debug!("Note: server '{server_name}' did not accept roots notification: {e}");
577 }
578 }
579 }
580}
581
582impl Drop for McpManager {
583 fn drop(&mut self) {
584 let servers: Vec<(String, McpServerConnection)> = self.servers.drain().collect();
585 for (server_name, server) in servers {
586 if let Some(handle) = server.server_task {
587 handle.abort();
588 tracing::warn!("Server '{server_name}' task aborted during cleanup");
589 }
590 }
591 }
592}
593
#[cfg(test)]
mod tests {
    // Tests for `McpManager` using an in-memory MCP server.
    use super::McpManager;
    use crate::client::config::ServerConfig;
    use rmcp::{
        Json, RoleServer, ServerHandler,
        handler::server::{router::tool::ToolRouter, wrapper::Parameters},
        model::{Implementation, ServerCapabilities, ServerInfo},
        service::DynService,
        tool, tool_handler, tool_router,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::{
        io,
        sync::{Arc, Mutex},
    };
    use tokio::sync::mpsc;

    // Minimal MCP server exposing a single `echo` tool, routed via `tool_router`.
    #[derive(Clone)]
    struct TestServer {
        tool_router: ToolRouter<Self>,
    }

    // Advertises tool support only.
    #[tool_handler(router = self.tool_router)]
    impl ServerHandler for TestServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_server_info(Implementation::new("test-server", "0.1.0").with_description("Test MCP server"))
        }
    }

    impl Default for TestServer {
        fn default() -> Self {
            Self { tool_router: Self::tool_router() }
        }
    }

    // Input/output of the `echo` tool. Plain `//` comments are used here because
    // doc comments would feed into the derived JSON schema.
    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoRequest {
        value: String,
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoResult {
        value: String,
    }

    #[tool_router]
    impl TestServer {
        // Boxes the server so it can be passed as `ServerConfig::InMemory`.
        fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
            Box::new(self)
        }

        #[tool(description = "Returns the provided value")]
        async fn echo(&self, request: Parameters<EchoRequest>) -> Json<EchoResult> {
            let Parameters(EchoRequest { value }) = request;
            Json(EchoResult { value })
        }
    }

    // `io::Write` sink that appends into a shared buffer so the test can read
    // back what the tracing subscriber wrote.
    #[derive(Clone)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    // Dropping a manager with a live server must abort its task and emit the
    // cleanup warning through `tracing`.
    #[tokio::test]
    async fn drop_logs_cleanup_abort_with_tracing() {
        let (event_sender, _event_receiver) = mpsc::channel(1);
        let mut manager = McpManager::new(event_sender, None);
        manager
            .add_mcp(
                ServerConfig::InMemory { name: "test".to_string(), server: TestServer::default().into_dyn() }.into(),
            )
            .await
            .unwrap();

        let output = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_ansi(false)
            .without_time()
            .with_writer({
                let output = Arc::clone(&output);
                move || SharedWriter(Arc::clone(&output))
            })
            .finish();

        // Install the capturing subscriber only while the manager is dropped.
        tracing::subscriber::with_default(subscriber, || {
            drop(manager);
        });

        let logs = String::from_utf8(output.lock().unwrap().clone()).unwrap();
        assert!(logs.contains("Server 'test' task aborted during cleanup"));
    }
}