1use anyhow::Result;
2use axum::http::{HeaderMap, HeaderName};
3use chrono::{DateTime, Utc};
4use futures::stream::{FuturesUnordered, StreamExt};
5use futures::{future, FutureExt};
6use rand::{distributions::Alphanumeric, Rng};
7use rmcp::service::{ClientInitializeError, ServiceError};
8use rmcp::transport::streamable_http_client::{
9 AuthRequiredError, StreamableHttpClientTransportConfig, StreamableHttpError,
10};
11use rmcp::transport::{
12 ConfigureCommandExt, DynamicTransportError, StreamableHttpClientTransport, TokioChildProcess,
13};
14use std::collections::HashMap;
15use std::option::Option;
16use std::path::PathBuf;
17use std::process::Stdio;
18use std::sync::Arc;
19use std::time::Duration;
20use tempfile::{tempdir, TempDir};
21use tokio::io::AsyncReadExt;
22use tokio::process::Command;
23use tokio::sync::Mutex;
24use tokio::task;
25use tokio_stream::wrappers::ReceiverStream;
26use tokio_util::sync::CancellationToken;
27use tracing::{error, warn};
28
29use super::extension::{
30 ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, PlatformExtensionContext,
31 ToolInfo, PLATFORM_EXTENSIONS,
32};
33use super::tool_execution::ToolCallResult;
34use super::types::SharedProvider;
35use crate::agents::extension::{Envs, ProcessExit};
36use crate::agents::extension_malware_check;
37use crate::agents::mcp_client::{McpClient, McpClientTrait};
38use crate::config::search_path::SearchPaths;
39use crate::config::{get_all_extensions, Config};
40use crate::oauth::oauth_flow;
41use crate::prompt_template;
42use crate::subprocess::configure_command_no_window;
43use rmcp::model::{
44 CallToolRequestParam, Content, ErrorCode, ErrorData, GetPromptResult, Prompt, Resource,
45 ResourceContents, ServerInfo, Tool,
46};
47use rmcp::transport::auth::AuthClient;
48use schemars::_private::NoSerialize;
49use serde_json::Value;
50
51type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
52
53struct Extension {
54 pub config: ExtensionConfig,
55
56 client: McpClientBox,
57 server_info: Option<ServerInfo>,
58 _temp_dir: Option<tempfile::TempDir>,
59}
60
61impl Extension {
62 fn new(
63 config: ExtensionConfig,
64 client: McpClientBox,
65 server_info: Option<ServerInfo>,
66 temp_dir: Option<tempfile::TempDir>,
67 ) -> Self {
68 Self {
69 client,
70 config,
71 server_info,
72 _temp_dir: temp_dir,
73 }
74 }
75
76 fn supports_resources(&self) -> bool {
77 self.server_info
78 .as_ref()
79 .and_then(|info| info.capabilities.resources.as_ref())
80 .is_some()
81 }
82
83 fn get_instructions(&self) -> Option<String> {
84 self.server_info
85 .as_ref()
86 .and_then(|info| info.instructions.clone())
87 }
88
89 fn get_client(&self) -> McpClientBox {
90 self.client.clone()
91 }
92}
93
/// Registry of active extensions plus the shared state used to create them.
pub struct ExtensionManager {
    // Active extensions keyed by name; prefixed tool names are routed by these keys.
    extensions: Mutex<HashMap<String, Extension>>,
    // Context handed to platform-extension client factories in `add_extension`.
    context: Mutex<PlatformExtensionContext>,
    // Provider passed to every MCP client this manager connects.
    provider: SharedProvider,
}
100
101#[derive(Debug, Clone)]
103pub struct ResourceItem {
104 pub client_name: String, pub uri: String, pub name: String, pub content: String, pub timestamp: DateTime<Utc>, pub priority: f32, pub token_count: Option<u32>, }
112
113impl ResourceItem {
114 pub fn new(
115 client_name: String,
116 uri: String,
117 name: String,
118 content: String,
119 timestamp: DateTime<Utc>,
120 priority: f32,
121 ) -> Self {
122 Self {
123 client_name,
124 uri,
125 name,
126 content,
127 timestamp,
128 priority,
129 token_count: None,
130 }
131 }
132}
133
/// Turns an arbitrary name into a lowercase key: ASCII letters, digits, `_` and
/// `-` are kept, whitespace is dropped, and every other character becomes `_`.
fn normalize(input: String) -> String {
    input
        .chars()
        .filter(|ch| !ch.is_whitespace())
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' {
                ch
            } else {
                '_'
            }
        })
        .collect::<String>()
        .to_lowercase()
}
147
148fn generate_extension_name(
150 server_info: Option<&ServerInfo>,
151 name_exists: impl Fn(&str) -> bool,
152) -> String {
153 let base = server_info
154 .and_then(|info| {
155 let name = info.server_info.name.as_str();
156 (!name.is_empty()).then(|| normalize(name.to_string()))
157 })
158 .unwrap_or_else(|| "unnamed".to_string());
159
160 if !name_exists(&base) {
161 return base;
162 }
163
164 let suffix: String = rand::thread_rng()
165 .sample_iter(Alphanumeric)
166 .take(6)
167 .map(char::from)
168 .collect();
169
170 format!("{base}_{suffix}")
171}
172
173fn resolve_command(cmd: &str) -> PathBuf {
174 SearchPaths::builder()
175 .with_npm()
176 .resolve(cmd)
177 .unwrap_or_else(|_| {
178 PathBuf::from(cmd)
180 })
181}
182
183fn require_str_parameter<'a>(v: &'a serde_json::Value, name: &str) -> Result<&'a str, ErrorData> {
184 let v = v.get(name).ok_or_else(|| {
185 ErrorData::new(
186 ErrorCode::INVALID_PARAMS,
187 format!("The parameter {name} is required"),
188 None,
189 )
190 })?;
191 match v.as_str() {
192 Some(r) => Ok(r),
193 None => Err(ErrorData::new(
194 ErrorCode::INVALID_PARAMS,
195 format!("The parameter {name} must be a string"),
196 None,
197 )),
198 }
199}
200
201pub fn get_parameter_names(tool: &Tool) -> Vec<String> {
202 let mut names: Vec<String> = tool
203 .input_schema
204 .get("properties")
205 .and_then(|props| props.as_object())
206 .map(|props| props.keys().cloned().collect())
207 .unwrap_or_default();
208 names.sort();
209 names
210}
211
212impl Default for ExtensionManager {
213 fn default() -> Self {
214 Self::new(Arc::new(Mutex::new(None)))
215 }
216}
217
218async fn child_process_client(
219 mut command: Command,
220 timeout: &Option<u64>,
221 provider: SharedProvider,
222) -> ExtensionResult<McpClient> {
223 #[cfg(unix)]
224 command.process_group(0);
225 configure_command_no_window(&mut command);
226
227 if let Ok(path) = SearchPaths::builder().path() {
228 command.env("PATH", path);
229 }
230
231 let (transport, mut stderr) = TokioChildProcess::builder(command)
232 .stderr(Stdio::piped())
233 .spawn()?;
234 let mut stderr = stderr.take().ok_or_else(|| {
235 ExtensionError::SetupError("failed to attach child process stderr".to_owned())
236 })?;
237
238 let stderr_task = tokio::spawn(async move {
239 let mut all_stderr = Vec::new();
240 stderr.read_to_end(&mut all_stderr).await?;
241 Ok::<String, std::io::Error>(String::from_utf8_lossy(&all_stderr).into())
242 });
243
244 let client_result = McpClient::connect(
245 transport,
246 Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT)),
247 provider,
248 )
249 .await;
250
251 match client_result {
252 Ok(client) => Ok(client),
253 Err(error) => {
254 let error_task_out = stderr_task.await?;
255 Err::<McpClient, ExtensionError>(match error_task_out {
256 Ok(stderr_content) => ProcessExit::new(stderr_content, error).into(),
257 Err(e) => e.into(),
258 })
259 }
260 }
261}
262
263fn extract_auth_error(
264 res: &Result<McpClient, ClientInitializeError>,
265) -> Option<&AuthRequiredError> {
266 match res {
267 Ok(_) => None,
268 Err(err) => match err {
269 ClientInitializeError::TransportError {
270 error: DynamicTransportError { error, .. },
271 ..
272 } => error
273 .downcast_ref::<StreamableHttpError<reqwest::Error>>()
274 .and_then(|auth_error| match auth_error {
275 StreamableHttpError::AuthRequired(auth_required_error) => {
276 Some(auth_required_error)
277 }
278 _ => None,
279 }),
280 _ => None,
281 },
282 }
283}
284
285async fn merge_environments(
287 envs: &Envs,
288 env_keys: &[String],
289 ext_name: &str,
290) -> Result<HashMap<String, String>, ExtensionError> {
291 let mut all_envs = envs.get_env();
292 let config_instance = Config::global();
293
294 for key in env_keys {
295 if all_envs.contains_key(key) {
296 continue;
297 }
298
299 match config_instance.get(key, true) {
300 Ok(value) => {
301 if value.is_null() {
302 warn!(
303 key = %key,
304 ext_name = %ext_name,
305 "Secret key not found in config (returned null)."
306 );
307 continue;
308 }
309
310 if let Some(str_val) = value.as_str() {
311 all_envs.insert(key.clone(), str_val.to_string());
312 } else {
313 warn!(
314 key = %key,
315 ext_name = %ext_name,
316 value_type = %value.get("type").and_then(|t| t.as_str()).unwrap_or("unknown"),
317 "Secret value is not a string; skipping."
318 );
319 }
320 }
321 Err(e) => {
322 error!(
323 key = %key,
324 ext_name = %ext_name,
325 error = %e,
326 "Failed to fetch secret from config."
327 );
328 return Err(ExtensionError::ConfigError(format!(
329 "Failed to fetch secret '{}' from config: {}",
330 key, e
331 )));
332 }
333 }
334 }
335
336 Ok(all_envs)
337}
338
339fn substitute_env_vars(value: &str, env_map: &HashMap<String, String>) -> String {
341 let mut result = value.to_string();
342
343 let re_braces =
344 regex::Regex::new(r"\$\{\s*([A-Za-z_][A-Za-z0-9_]*)\s*\}").expect("valid regex");
345 for cap in re_braces.captures_iter(value) {
346 if let Some(var_name) = cap.get(1) {
347 if let Some(env_value) = env_map.get(var_name.as_str()) {
348 result = result.replace(&cap[0], env_value);
349 }
350 }
351 }
352
353 let re_simple = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").expect("valid regex");
354 for cap in re_simple.captures_iter(&result.clone()) {
355 if let Some(var_name) = cap.get(1) {
356 if !value.contains(&format!("${{{}}}", var_name.as_str())) {
357 if let Some(env_value) = env_map.get(var_name.as_str()) {
358 result = result.replace(&cap[0], env_value);
359 }
360 }
361 }
362 }
363
364 result
365}
366
367async fn create_streamable_http_client(
368 uri: &str,
369 timeout: Option<u64>,
370 headers: &HashMap<String, String>,
371 name: &str,
372 all_envs: &HashMap<String, String>,
373 provider: SharedProvider,
374) -> ExtensionResult<Box<dyn McpClientTrait>> {
375 let mut default_headers = HeaderMap::new();
376 for (key, value) in headers {
377 let substituted_value = substitute_env_vars(value, all_envs);
378 default_headers.insert(
379 HeaderName::try_from(key)
380 .map_err(|_| ExtensionError::ConfigError(format!("invalid header: {}", key)))?,
381 substituted_value.parse().map_err(|_| {
382 ExtensionError::ConfigError(format!("invalid header value: {}", key))
383 })?,
384 );
385 }
386
387 let http_client = reqwest::Client::builder()
388 .default_headers(default_headers)
389 .build()
390 .map_err(|_| ExtensionError::ConfigError("could not construct http client".to_string()))?;
391
392 let transport = StreamableHttpClientTransport::with_client(
393 http_client,
394 StreamableHttpClientTransportConfig {
395 uri: uri.into(),
396 ..Default::default()
397 },
398 );
399
400 let timeout_duration =
401 Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT));
402
403 let client_res = McpClient::connect(transport, timeout_duration, provider.clone()).await;
404
405 if extract_auth_error(&client_res).is_some() {
406 let am = oauth_flow(&uri.to_string(), &name.to_string())
407 .await
408 .map_err(|_| ExtensionError::SetupError("auth error".to_string()))?;
409 let auth_client = AuthClient::new(reqwest::Client::default(), am);
410 let transport = StreamableHttpClientTransport::with_client(
411 auth_client,
412 StreamableHttpClientTransportConfig {
413 uri: uri.into(),
414 ..Default::default()
415 },
416 );
417 Ok(Box::new(
418 McpClient::connect(transport, timeout_duration, provider).await?,
419 ))
420 } else {
421 Ok(Box::new(client_res?))
422 }
423}
424
425async fn create_stdio_client(
426 cmd: &str,
427 args: &[String],
428 all_envs: HashMap<String, String>,
429 timeout: &Option<u64>,
430 provider: SharedProvider,
431) -> ExtensionResult<Box<dyn McpClientTrait>> {
432 extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?;
433
434 let resolved_cmd = resolve_command(cmd);
435 let command = Command::new(resolved_cmd).configure(|command| {
436 command.args(args).envs(all_envs);
437 });
438
439 Ok(Box::new(
440 child_process_client(command, timeout, provider).await?,
441 ))
442}
443
444impl ExtensionManager {
445 pub fn new(provider: SharedProvider) -> Self {
446 Self {
447 extensions: Mutex::new(HashMap::new()),
448 context: Mutex::new(PlatformExtensionContext {
449 session_id: None,
450 extension_manager: None,
451 }),
452 provider,
453 }
454 }
455
456 pub fn new_without_provider() -> Self {
458 Self::new(Arc::new(Mutex::new(None)))
459 }
460
461 pub async fn set_context(&self, context: PlatformExtensionContext) {
462 *self.context.lock().await = context;
463 }
464
465 pub async fn get_context(&self) -> PlatformExtensionContext {
466 self.context.lock().await.clone()
467 }
468
469 pub async fn supports_resources(&self) -> bool {
470 self.extensions
471 .lock()
472 .await
473 .values()
474 .any(|ext| ext.supports_resources())
475 }
476
477 pub async fn add_extension(&self, config: ExtensionConfig) -> ExtensionResult<()> {
478 let config_name = config.key().to_string();
479 let sanitized_name = normalize(config_name.clone());
480
481 if self.extensions.lock().await.contains_key(&sanitized_name) {
482 return Ok(());
483 }
484
485 let mut temp_dir = None;
486
487 let client: Box<dyn McpClientTrait> = match &config {
488 ExtensionConfig::Sse { .. } => {
489 return Err(ExtensionError::ConfigError(
490 "SSE is unsupported, migrate to streamable_http".to_string(),
491 ));
492 }
493 ExtensionConfig::StreamableHttp {
494 uri,
495 timeout,
496 headers,
497 name,
498 envs,
499 env_keys,
500 ..
501 } => {
502 let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
503 create_streamable_http_client(
504 uri,
505 *timeout,
506 headers,
507 name,
508 &all_envs,
509 self.provider.clone(),
510 )
511 .await?
512 }
513 ExtensionConfig::Stdio {
514 cmd,
515 args,
516 envs,
517 env_keys,
518 timeout,
519 ..
520 } => {
521 let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
522 create_stdio_client(cmd, args, all_envs, timeout, self.provider.clone()).await?
523 }
524 ExtensionConfig::Builtin { name, timeout, .. } => {
525 let cmd = std::env::current_exe()
526 .and_then(|path| {
527 path.to_str().map(|s| s.to_string()).ok_or_else(|| {
528 std::io::Error::new(
529 std::io::ErrorKind::InvalidData,
530 "Invalid UTF-8 in executable path",
531 )
532 })
533 })
534 .map_err(|e| {
535 ExtensionError::ConfigError(format!(
536 "Failed to resolve executable path: {}",
537 e
538 ))
539 })?;
540 let command = Command::new(cmd).configure(|command| {
541 command.arg("mcp").arg(name);
542 });
543 Box::new(child_process_client(command, timeout, self.provider.clone()).await?)
544 }
545 ExtensionConfig::Platform { name, .. } => {
546 let normalized_key = normalize(name.clone());
547 let def = PLATFORM_EXTENSIONS
548 .get(normalized_key.as_str())
549 .ok_or_else(|| {
550 ExtensionError::ConfigError(format!("Unknown platform extension: {}", name))
551 })?;
552 let context = self.get_context().await;
553 (def.client_factory)(context)
554 }
555 ExtensionConfig::InlinePython {
556 name,
557 code,
558 timeout,
559 dependencies,
560 ..
561 } => {
562 let dir = tempdir()?;
563 let file_path = dir.path().join(format!("{}.py", name));
564 temp_dir = Some(dir);
565 std::fs::write(&file_path, code)?;
566
567 let command = Command::new("uvx").configure(|command| {
568 command.arg("--with").arg("mcp");
569 dependencies.iter().flatten().for_each(|dep| {
570 command.arg("--with").arg(dep);
571 });
572 command.arg("python").arg(file_path.to_str().unwrap());
573 });
574
575 Box::new(child_process_client(command, timeout, self.provider.clone()).await?)
576 }
577 ExtensionConfig::Frontend { .. } => {
578 return Err(ExtensionError::ConfigError(
579 "Invalid extension type: Frontend extensions cannot be added as server extensions".to_string()
580 ));
581 }
582 };
583
584 let server_info = client.get_info().cloned();
585
586 let mut extensions = self.extensions.lock().await;
588 let final_name = if sanitized_name.is_empty() {
589 generate_extension_name(server_info.as_ref(), |n| extensions.contains_key(n))
590 } else {
591 sanitized_name
592 };
593 extensions.insert(
594 final_name,
595 Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir),
596 );
597
598 Ok(())
599 }
600
601 pub async fn add_client(
602 &self,
603 name: String,
604 config: ExtensionConfig,
605 client: McpClientBox,
606 info: Option<ServerInfo>,
607 temp_dir: Option<TempDir>,
608 ) {
609 self.extensions
610 .lock()
611 .await
612 .insert(name, Extension::new(config, client, info, temp_dir));
613 }
614
615 pub async fn get_extensions_info(&self) -> Vec<ExtensionInfo> {
617 self.extensions
618 .lock()
619 .await
620 .iter()
621 .map(|(name, ext)| {
622 ExtensionInfo::new(
623 name,
624 ext.get_instructions().unwrap_or_default().as_str(),
625 ext.supports_resources(),
626 )
627 })
628 .collect()
629 }
630
631 pub async fn remove_extension(&self, name: &str) -> ExtensionResult<()> {
633 let sanitized_name = normalize(name.to_string());
634 self.extensions.lock().await.remove(&sanitized_name);
635 Ok(())
636 }
637
638 pub async fn get_extension_and_tool_counts(&self) -> (usize, usize) {
639 let enabled_extensions_count = self.extensions.lock().await.len();
640
641 let total_tools = self
642 .get_prefixed_tools(None)
643 .await
644 .map(|tools| tools.len())
645 .unwrap_or(0);
646
647 (enabled_extensions_count, total_tools)
648 }
649
650 pub async fn list_extensions(&self) -> ExtensionResult<Vec<String>> {
651 Ok(self.extensions.lock().await.keys().cloned().collect())
652 }
653
654 pub async fn is_extension_enabled(&self, name: &str) -> bool {
655 self.extensions.lock().await.contains_key(name)
656 }
657
658 pub async fn get_extension_configs(&self) -> Vec<ExtensionConfig> {
659 self.extensions
660 .lock()
661 .await
662 .values()
663 .map(|ext| ext.config.clone())
664 .collect()
665 }
666
667 pub async fn get_prefixed_tools(
669 &self,
670 extension_name: Option<String>,
671 ) -> ExtensionResult<Vec<Tool>> {
672 self.get_prefixed_tools_impl(extension_name, None).await
673 }
674
675 async fn get_prefixed_tools_impl(
676 &self,
677 extension_name: Option<String>,
678 exclude: Option<&str>,
679 ) -> ExtensionResult<Vec<Tool>> {
680 let filtered_clients: Vec<_> = self
682 .extensions
683 .lock()
684 .await
685 .iter()
686 .filter(|(name, _ext)| {
687 if let Some(excluded) = exclude {
688 if name.as_str() == excluded {
689 return false;
690 }
691 }
692
693 if let Some(ref name_filter) = extension_name {
694 *name == name_filter
695 } else {
696 true
697 }
698 })
699 .map(|(name, ext)| (name.clone(), ext.config.clone(), ext.get_client()))
700 .collect();
701
702 let cancel_token = CancellationToken::default();
703 let client_futures = filtered_clients.into_iter().map(|(name, config, client)| {
704 let cancel_token = cancel_token.clone();
705 task::spawn(async move {
706 let mut tools = Vec::new();
707 let client_guard = client.lock().await;
708 let mut client_tools = client_guard.list_tools(None, cancel_token).await?;
709
710 loop {
711 for tool in client_tools.tools {
712 let is_available = config.is_tool_available(&tool.name);
713
714 if is_available {
715 tools.push(Tool {
716 name: format!("{}__{}", name, tool.name).into(),
717 description: tool.description,
718 input_schema: tool.input_schema,
719 annotations: tool.annotations,
720 output_schema: tool.output_schema,
721 icons: tool.icons,
722 title: tool.title,
723 meta: tool.meta,
724 });
725 }
726 }
727
728 if client_tools.next_cursor.is_none() {
729 break;
730 }
731
732 client_tools = client_guard
733 .list_tools(client_tools.next_cursor, CancellationToken::default())
734 .await?;
735 }
736
737 Ok::<Vec<Tool>, ExtensionError>(tools)
738 })
739 });
740
741 let results = future::join_all(client_futures).await;
743
744 let mut tools = Vec::new();
746 for result in results {
747 match result {
748 Ok(Ok(client_tools)) => tools.extend(client_tools),
749 Ok(Err(err)) => return Err(err),
750 Err(join_err) => return Err(ExtensionError::from(join_err)),
751 }
752 }
753
754 Ok(tools)
755 }
756
757 pub async fn get_prefixed_tools_excluding(&self, exclude: &str) -> ExtensionResult<Vec<Tool>> {
758 self.get_prefixed_tools_impl(None, Some(exclude)).await
759 }
760
761 pub async fn get_planning_prompt(&self, tools_info: Vec<ToolInfo>) -> String {
763 let mut context: HashMap<&str, Value> = HashMap::new();
764 context.insert("tools", serde_json::to_value(tools_info).unwrap());
765
766 prompt_template::render_global_file("plan.md", &context).expect("Prompt should render")
767 }
768
769 async fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(String, McpClientBox)> {
771 self.extensions
772 .lock()
773 .await
774 .iter()
775 .find(|(key, _)| prefixed_name.starts_with(*key))
776 .map(|(name, extension)| (name.clone(), extension.get_client()))
777 }
778
779 pub async fn read_resource_tool(
781 &self,
782 params: Value,
783 cancellation_token: CancellationToken,
784 ) -> Result<Vec<Content>, ErrorData> {
785 let uri = require_str_parameter(¶ms, "uri")?;
786
787 let extension_name = params.get("extension_name").and_then(|v| v.as_str());
788
789 if let Some(ext_name) = extension_name {
791 let read_result = self
792 .read_resource(uri, ext_name, cancellation_token.clone())
793 .await?;
794
795 let mut result = Vec::new();
796 for content in read_result.contents {
797 if let ResourceContents::TextResourceContents { text, .. } = content {
798 let content_str = format!("{}\n\n{}", uri, text);
799 result.push(Content::text(content_str));
800 }
801 }
802 return Ok(result);
803 }
804
805 let extension_names: Vec<String> = self.extensions.lock().await.keys().cloned().collect();
812
813 for extension_name in extension_names {
814 let read_result = self
815 .read_resource(uri, &extension_name, cancellation_token.clone())
816 .await;
817 match read_result {
818 Ok(read_result) => {
819 let mut result = Vec::new();
820 for content in read_result.contents {
821 if let ResourceContents::TextResourceContents { text, .. } = content {
822 let content_str = format!("{}\n\n{}", uri, text);
823 result.push(Content::text(content_str));
824 }
825 }
826 return Ok(result);
827 }
828 Err(_) => continue,
829 }
830 }
831
832 let available_extensions = self
834 .extensions
835 .lock()
836 .await
837 .keys()
838 .map(|s| s.as_str())
839 .collect::<Vec<&str>>()
840 .join(", ");
841 let error_msg = format!(
842 "Resource with uri '{}' not found. Here are the available extensions: {}",
843 uri, available_extensions
844 );
845
846 Err(ErrorData::new(
847 ErrorCode::RESOURCE_NOT_FOUND,
848 error_msg,
849 None,
850 ))
851 }
852
853 pub async fn read_resource(
854 &self,
855 uri: &str,
856 extension_name: &str,
857 cancellation_token: CancellationToken,
858 ) -> Result<rmcp::model::ReadResourceResult, ErrorData> {
859 let available_extensions = self
860 .extensions
861 .lock()
862 .await
863 .keys()
864 .map(|s| s.as_str())
865 .collect::<Vec<&str>>()
866 .join(", ");
867 let error_msg = format!(
868 "Extension '{}' not found. Here are the available extensions: {}",
869 extension_name, available_extensions
870 );
871
872 let client = self
873 .get_server_client(extension_name)
874 .await
875 .ok_or(ErrorData::new(ErrorCode::INVALID_PARAMS, error_msg, None))?;
876
877 let client_guard = client.lock().await;
878 client_guard
879 .read_resource(uri, cancellation_token)
880 .await
881 .map_err(|_| {
882 ErrorData::new(
883 ErrorCode::INTERNAL_ERROR,
884 format!("Could not read resource with uri: {}", uri),
885 None,
886 )
887 })
888 }
889
890 pub async fn get_ui_resources(&self) -> Result<Vec<(String, Resource)>, ErrorData> {
891 let mut ui_resources = Vec::new();
892
893 let extensions_to_check: Vec<(String, McpClientBox)> = {
894 let extensions = self.extensions.lock().await;
895 extensions
896 .iter()
897 .map(|(name, ext)| (name.clone(), ext.get_client()))
898 .collect()
899 };
900
901 for (extension_name, client) in extensions_to_check {
902 let client_guard = client.lock().await;
903
904 match client_guard
905 .list_resources(None, CancellationToken::default())
906 .await
907 {
908 Ok(list_response) => {
909 for resource in list_response.resources {
910 if resource.uri.starts_with("ui://") {
911 ui_resources.push((extension_name.clone(), resource));
912 }
913 }
914 }
915 Err(e) => {
916 warn!("Failed to list resources for {}: {:?}", extension_name, e);
917 }
918 }
919 }
920
921 Ok(ui_resources)
922 }
923
924 async fn list_resources_from_extension(
925 &self,
926 extension_name: &str,
927 cancellation_token: CancellationToken,
928 ) -> Result<Vec<Content>, ErrorData> {
929 let client = self
930 .get_server_client(extension_name)
931 .await
932 .ok_or_else(|| {
933 ErrorData::new(
934 ErrorCode::INVALID_PARAMS,
935 format!("Extension {} is not valid", extension_name),
936 None,
937 )
938 })?;
939
940 let client_guard = client.lock().await;
941 client_guard
942 .list_resources(None, cancellation_token)
943 .await
944 .map_err(|e| {
945 ErrorData::new(
946 ErrorCode::INTERNAL_ERROR,
947 format!("Unable to list resources for {}, {:?}", extension_name, e),
948 None,
949 )
950 })
951 .map(|lr| {
952 let resource_list = lr
953 .resources
954 .into_iter()
955 .map(|r| format!("{} - {}, uri: ({})", extension_name, r.name, r.uri))
956 .collect::<Vec<String>>()
957 .join("\n");
958
959 vec![Content::text(resource_list)]
960 })
961 }
962
963 pub async fn list_resources(
964 &self,
965 params: Value,
966 cancellation_token: CancellationToken,
967 ) -> Result<Vec<Content>, ErrorData> {
968 let extension = params.get("extension").and_then(|v| v.as_str());
969
970 match extension {
971 Some(extension_name) => {
972 self.list_resources_from_extension(extension_name, cancellation_token)
974 .await
975 }
976 None => {
977 let mut futures = FuturesUnordered::new();
979
980 self.extensions
982 .lock()
983 .await
984 .iter()
985 .filter(|(_name, ext)| ext.supports_resources())
986 .map(|(name, _ext)| name.clone())
987 .for_each(|name| {
988 let token = cancellation_token.clone();
989 futures.push(async move {
990 self.list_resources_from_extension(&name.clone(), token)
991 .await
992 });
993 });
994
995 let mut all_resources = Vec::new();
996 let mut errors = Vec::new();
997
998 while let Some(result) = futures.next().await {
1000 match result {
1001 Ok(content) => {
1002 all_resources.extend(content);
1003 }
1004 Err(tool_error) => {
1005 errors.push(tool_error);
1006 }
1007 }
1008 }
1009
1010 if !errors.is_empty() {
1011 tracing::error!(
1012 errors = ?errors
1013 .into_iter()
1014 .map(|e| format!("{:?}", e))
1015 .collect::<Vec<_>>(),
1016 "errors from listing resources"
1017 );
1018 }
1019
1020 Ok(all_resources)
1021 }
1022 }
1023 }
1024
1025 pub async fn dispatch_tool_call(
1026 &self,
1027 tool_call: CallToolRequestParam,
1028 cancellation_token: CancellationToken,
1029 ) -> Result<ToolCallResult> {
1030 let (client_name, client) =
1032 self.get_client_for_tool(&tool_call.name)
1033 .await
1034 .ok_or_else(|| {
1035 ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None)
1036 })?;
1037
1038 let tool_name = tool_call
1040 .name
1041 .strip_prefix(client_name.as_str())
1042 .and_then(|s| s.strip_prefix("__"))
1043 .ok_or_else(|| {
1044 ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None)
1045 })?
1046 .to_string();
1047
1048 if let Some(extension) = self.extensions.lock().await.get(&client_name) {
1049 if !extension.config.is_tool_available(&tool_name) {
1050 return Err(ErrorData::new(
1051 ErrorCode::RESOURCE_NOT_FOUND,
1052 format!(
1053 "Tool '{}' is not available for extension '{}'",
1054 tool_name, client_name
1055 ),
1056 None,
1057 )
1058 .into());
1059 }
1060 }
1061
1062 let arguments = tool_call.arguments.clone();
1063 let client = client.clone();
1064 let notifications_receiver = client.lock().await.subscribe().await;
1065
1066 let fut = async move {
1067 let client_guard = client.lock().await;
1068 client_guard
1069 .call_tool(&tool_name, arguments, cancellation_token)
1070 .await
1071 .map_err(|e| match e {
1072 ServiceError::McpError(error_data) => error_data,
1073 _ => {
1074 ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), e.maybe_to_value())
1075 }
1076 })
1077 };
1078
1079 Ok(ToolCallResult {
1080 result: Box::new(fut.boxed()),
1081 notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
1082 })
1083 }
1084
1085 pub async fn list_prompts_from_extension(
1086 &self,
1087 extension_name: &str,
1088 cancellation_token: CancellationToken,
1089 ) -> Result<Vec<Prompt>, ErrorData> {
1090 let client = self
1091 .get_server_client(extension_name)
1092 .await
1093 .ok_or_else(|| {
1094 ErrorData::new(
1095 ErrorCode::INVALID_PARAMS,
1096 format!("Extension {} is not valid", extension_name),
1097 None,
1098 )
1099 })?;
1100
1101 let client_guard = client.lock().await;
1102 client_guard
1103 .list_prompts(None, cancellation_token)
1104 .await
1105 .map_err(|e| {
1106 ErrorData::new(
1107 ErrorCode::INTERNAL_ERROR,
1108 format!("Unable to list prompts for {}, {:?}", extension_name, e),
1109 None,
1110 )
1111 })
1112 .map(|lp| lp.prompts)
1113 }
1114
1115 pub async fn list_prompts(
1116 &self,
1117 cancellation_token: CancellationToken,
1118 ) -> Result<HashMap<String, Vec<Prompt>>, ErrorData> {
1119 let mut futures = FuturesUnordered::new();
1120
1121 let names: Vec<_> = self.extensions.lock().await.keys().cloned().collect();
1122 for extension_name in names {
1123 let token = cancellation_token.clone();
1124 futures.push(async move {
1125 (
1126 extension_name.clone(),
1127 self.list_prompts_from_extension(extension_name.as_str(), token)
1128 .await,
1129 )
1130 });
1131 }
1132
1133 let mut all_prompts = HashMap::new();
1134 let mut errors = Vec::new();
1135
1136 while let Some(result) = futures.next().await {
1138 let (name, prompts) = result;
1139 match prompts {
1140 Ok(content) => {
1141 all_prompts.insert(name.to_string(), content);
1142 }
1143 Err(tool_error) => {
1144 errors.push(tool_error);
1145 }
1146 }
1147 }
1148
1149 if !errors.is_empty() {
1150 tracing::debug!(
1151 errors = ?errors
1152 .into_iter()
1153 .map(|e| format!("{:?}", e))
1154 .collect::<Vec<_>>(),
1155 "errors from listing prompts"
1156 );
1157 }
1158
1159 Ok(all_prompts)
1160 }
1161
1162 pub async fn get_prompt(
1163 &self,
1164 extension_name: &str,
1165 name: &str,
1166 arguments: Value,
1167 cancellation_token: CancellationToken,
1168 ) -> Result<GetPromptResult> {
1169 let client = self
1170 .get_server_client(extension_name)
1171 .await
1172 .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?;
1173
1174 let client_guard = client.lock().await;
1175 client_guard
1176 .get_prompt(name, arguments, cancellation_token)
1177 .await
1178 .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e))
1179 }
1180
1181 pub async fn search_available_extensions(&self) -> Result<Vec<Content>, ErrorData> {
1182 let mut output_parts = vec![];
1183
1184 let mut disabled_extensions: Vec<String> = vec![];
1186 for extension in get_all_extensions() {
1187 if !extension.enabled {
1188 let config = extension.config.clone();
1189 let description = match &config {
1190 ExtensionConfig::Builtin {
1191 description,
1192 display_name,
1193 ..
1194 } => {
1195 if description.is_empty() {
1196 display_name.as_deref().unwrap_or("Built-in extension")
1197 } else {
1198 description
1199 }
1200 }
1201 ExtensionConfig::Sse { .. } => "SSE extension (unsupported)",
1202 ExtensionConfig::Platform { description, .. }
1203 | ExtensionConfig::StreamableHttp { description, .. }
1204 | ExtensionConfig::Stdio { description, .. }
1205 | ExtensionConfig::Frontend { description, .. }
1206 | ExtensionConfig::InlinePython { description, .. } => description,
1207 };
1208 disabled_extensions.push(format!("- {} - {}", config.name(), description));
1209 }
1210 }
1211
1212 let enabled_extensions: Vec<String> =
1214 self.extensions.lock().await.keys().cloned().collect();
1215
1216 if !disabled_extensions.is_empty() {
1218 output_parts.push(format!(
1219 "Extensions available to enable:\n{}\n",
1220 disabled_extensions.join("\n")
1221 ));
1222 } else {
1223 output_parts.push("No extensions available to enable.\n".to_string());
1224 }
1225
1226 if !enabled_extensions.is_empty() {
1227 output_parts.push(format!(
1228 "\n\nExtensions available to disable:\n{}\n",
1229 enabled_extensions
1230 .iter()
1231 .map(|name| format!("- {}", name))
1232 .collect::<Vec<_>>()
1233 .join("\n")
1234 ));
1235 } else {
1236 output_parts.push("No extensions that can be disabled.\n".to_string());
1237 }
1238
1239 Ok(vec![Content::text(output_parts.join("\n"))])
1240 }
1241
1242 async fn get_server_client(&self, name: impl Into<String>) -> Option<McpClientBox> {
1243 self.extensions
1244 .lock()
1245 .await
1246 .get(&name.into())
1247 .map(|ext| ext.get_client())
1248 }
1249
1250 pub async fn collect_moim(&self) -> Option<String> {
1251 let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:00").to_string();
1253 let mut content = format!("<info-msg>\nIt is currently {}\n", timestamp);
1254
1255 let platform_clients: Vec<(String, McpClientBox)> = {
1256 let extensions = self.extensions.lock().await;
1257 extensions
1258 .iter()
1259 .filter_map(|(name, extension)| {
1260 if let ExtensionConfig::Platform { .. } = &extension.config {
1261 Some((name.clone(), extension.get_client()))
1262 } else {
1263 None
1264 }
1265 })
1266 .collect()
1267 };
1268
1269 for (name, client) in platform_clients {
1270 let client_guard = client.lock().await;
1271 if let Some(moim_content) = client_guard.get_moim().await {
1272 tracing::debug!("MOIM content from {}: {} chars", name, moim_content.len());
1273 content.push('\n');
1274 content.push_str(&moim_content);
1275 }
1276 }
1277
1278 content.push_str("\n</info-msg>");
1279
1280 Some(content)
1281 }
1282}
1283
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::CallToolResult;
    use rmcp::model::{InitializeResult, JsonObject};
    use rmcp::{object, ServiceError as Error};

    use rmcp::model::ListPromptsResult;
    use rmcp::model::ListResourcesResult;
    use rmcp::model::ListToolsResult;
    use rmcp::model::ReadResourceResult;
    use rmcp::model::ServerNotification;

    use tokio::sync::mpsc;

    impl ExtensionManager {
        /// Registers `client` under `name` with an empty tool allow-list.
        async fn add_mock_extension(&self, name: String, client: McpClientBox) {
            self.add_mock_extension_with_tools(name, client, vec![]).await;
        }

        /// Registers `client` as a builtin extension keyed by `normalize(name)`.
        ///
        /// `available_tools` is stored in the config's `available_tools` field.
        /// The availability tests below check that an empty list exposes every tool
        /// and a non-empty list exposes only the named ones.
        async fn add_mock_extension_with_tools(
            &self,
            name: String,
            client: McpClientBox,
            available_tools: Vec<String>,
        ) {
            let sanitized_name = normalize(name.clone());
            let config = ExtensionConfig::Builtin {
                name: name.clone(),
                display_name: Some(name.clone()),
                description: "built-in".to_string(),
                timeout: None,
                bundled: None,
                available_tools,
            };
            let extension = Extension::new(config, client, None, None);
            self.extensions
                .lock()
                .await
                .insert(sanitized_name, extension);
        }
    }

    /// Minimal `McpClientTrait` implementation used by the tests.
    ///
    /// `list_tools` reports three fixed tools. `call_tool` succeeds for a fixed set
    /// of tool names. Every other request fails with `Error::TransportClosed`.
    struct MockClient {}

    #[async_trait::async_trait]
    impl McpClientTrait for MockClient {
        fn get_info(&self) -> Option<&InitializeResult> {
            None
        }

        async fn list_resources(
            &self,
            _next_cursor: Option<String>,
            _cancellation_token: CancellationToken,
        ) -> Result<ListResourcesResult, Error> {
            Err(Error::TransportClosed)
        }

        async fn read_resource(
            &self,
            _uri: &str,
            _cancellation_token: CancellationToken,
        ) -> Result<ReadResourceResult, Error> {
            Err(Error::TransportClosed)
        }

        async fn list_tools(
            &self,
            _next_cursor: Option<String>,
            _cancellation_token: CancellationToken,
        ) -> Result<ListToolsResult, Error> {
            use serde_json::json;
            Ok(ListToolsResult {
                tools: vec![
                    Tool::new(
                        "tool".to_string(),
                        "A basic tool".to_string(),
                        Arc::new(json!({}).as_object().unwrap().clone()),
                    ),
                    Tool::new(
                        "available_tool".to_string(),
                        "An available tool".to_string(),
                        Arc::new(json!({}).as_object().unwrap().clone()),
                    ),
                    Tool::new(
                        "hidden_tool".to_string(),
                        "hidden tool".to_string(),
                        Arc::new(json!({}).as_object().unwrap().clone()),
                    ),
                ],
                next_cursor: None,
                meta: None,
            })
        }

        async fn call_tool(
            &self,
            name: &str,
            _arguments: Option<JsonObject>,
            _cancellation_token: CancellationToken,
        ) -> Result<CallToolResult, Error> {
            match name {
                "tool" | "test__tool" | "available_tool" | "hidden_tool" => Ok(CallToolResult {
                    content: vec![],
                    is_error: None,
                    structured_content: None,
                    meta: None,
                }),
                _ => Err(Error::TransportClosed),
            }
        }

        async fn list_prompts(
            &self,
            _next_cursor: Option<String>,
            _cancellation_token: CancellationToken,
        ) -> Result<ListPromptsResult, Error> {
            Err(Error::TransportClosed)
        }

        async fn get_prompt(
            &self,
            _name: &str,
            _arguments: Value,
            _cancellation_token: CancellationToken,
        ) -> Result<GetPromptResult, Error> {
            Err(Error::TransportClosed)
        }

        async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
            mpsc::channel(1).1
        }
    }

    /// Extension names with leading, embedded or trailing `__` and non-ASCII
    /// characters still resolve to a client for `<extension>__<tool>` names.
    #[tokio::test]
    async fn test_get_client_for_tool() {
        let extension_manager = ExtensionManager::new_without_provider();

        extension_manager
            .add_mock_extension(
                "test_client".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        extension_manager
            .add_mock_extension(
                "__client".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        extension_manager
            .add_mock_extension(
                "__cli__ent__".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        extension_manager
            .add_mock_extension(
                "client 🚀".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        assert!(extension_manager
            .get_client_for_tool("test_client__tool")
            .await
            .is_some());

        assert!(extension_manager
            .get_client_for_tool("__client__tool")
            .await
            .is_some());

        assert!(extension_manager
            .get_client_for_tool("__cli__ent____tool")
            .await
            .is_some());

        assert!(extension_manager
            .get_client_for_tool("client___tool")
            .await
            .is_some());
    }

    /// Valid prefixed tool names dispatch successfully. An unknown tool on a known
    /// extension surfaces `INTERNAL_ERROR` from the call. An unknown extension
    /// fails dispatch with `RESOURCE_NOT_FOUND`.
    #[tokio::test]
    async fn test_dispatch_tool_call() {
        let extension_manager = ExtensionManager::new_without_provider();

        extension_manager
            .add_mock_extension(
                "test_client".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        extension_manager
            .add_mock_extension(
                "__cli__ent__".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        extension_manager
            .add_mock_extension(
                "client 🚀".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
            )
            .await;

        let tool_call = CallToolRequestParam {
            name: "test_client__tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(tool_call, CancellationToken::default())
            .await;
        assert!(result.is_ok());

        let tool_call = CallToolRequestParam {
            name: "test_client__test__tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(tool_call, CancellationToken::default())
            .await;
        assert!(result.is_ok());

        let tool_call = CallToolRequestParam {
            name: "__cli__ent____tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(tool_call, CancellationToken::default())
            .await;
        assert!(result.is_ok());

        let tool_call = CallToolRequestParam {
            name: "client___tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(tool_call, CancellationToken::default())
            .await;
        assert!(result.is_ok());

        let tool_call = CallToolRequestParam {
            name: "client___test__tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(tool_call, CancellationToken::default())
            .await;
        assert!(result.is_ok());

        // Known extension, unknown tool: dispatch succeeds but the call itself fails.
        let invalid_tool_call = CallToolRequestParam {
            name: "client___tools".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(invalid_tool_call, CancellationToken::default())
            .await
            .unwrap()
            .result
            .await;
        assert!(matches!(
            result,
            Err(ErrorData {
                code: ErrorCode::INTERNAL_ERROR,
                ..
            })
        ));

        // Unknown extension: dispatch itself fails.
        let invalid_tool_call = CallToolRequestParam {
            name: "_client__tools".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(invalid_tool_call, CancellationToken::default())
            .await;
        if let Err(err) = result {
            let tool_err = err.downcast_ref::<ErrorData>().expect("Expected ErrorData");
            assert_eq!(tool_err.code, ErrorCode::RESOURCE_NOT_FOUND);
        } else {
            panic!("Expected ErrorData with ErrorCode::RESOURCE_NOT_FOUND");
        }
    }

    /// A non-empty allow-list exposes only the listed tools.
    #[tokio::test]
    async fn test_tool_availability_filtering() {
        let extension_manager = ExtensionManager::new_without_provider();

        let available_tools = vec!["available_tool".to_string()];

        extension_manager
            .add_mock_extension_with_tools(
                "test_extension".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
                available_tools,
            )
            .await;

        let tools = extension_manager.get_prefixed_tools(None).await.unwrap();

        let tool_names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
        assert!(!tool_names.iter().any(|name| name == "test_extension__tool"));
        assert!(tool_names
            .iter()
            .any(|name| name == "test_extension__available_tool"));
        assert!(!tool_names
            .iter()
            .any(|name| name == "test_extension__hidden_tool"));
        assert_eq!(tool_names.len(), 1, "unexpected tools: {:?}", tool_names);
    }

    /// An empty allow-list exposes every tool the client lists.
    #[tokio::test]
    async fn test_tool_availability_defaults_to_available() {
        let extension_manager = ExtensionManager::new_without_provider();

        extension_manager
            .add_mock_extension_with_tools(
                "test_extension".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
                vec![],
            )
            .await;

        let tools = extension_manager.get_prefixed_tools(None).await.unwrap();

        let tool_names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
        assert!(tool_names.iter().any(|name| name == "test_extension__tool"));
        assert!(tool_names
            .iter()
            .any(|name| name == "test_extension__available_tool"));
        assert!(tool_names
            .iter()
            .any(|name| name == "test_extension__hidden_tool"));
        assert_eq!(tool_names.len(), 3, "unexpected tools: {:?}", tool_names);
    }

    /// Dispatching a tool outside the allow-list fails with `RESOURCE_NOT_FOUND`.
    /// An allowed tool on the same extension still dispatches.
    #[tokio::test]
    async fn test_dispatch_unavailable_tool_returns_error() {
        let extension_manager = ExtensionManager::new_without_provider();

        let available_tools = vec!["available_tool".to_string()];

        extension_manager
            .add_mock_extension_with_tools(
                "test_extension".to_string(),
                Arc::new(Mutex::new(Box::new(MockClient {}))),
                available_tools,
            )
            .await;

        let unavailable_tool_call = CallToolRequestParam {
            name: "test_extension__tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(unavailable_tool_call, CancellationToken::default())
            .await;

        if let Err(err) = result {
            let tool_err = err.downcast_ref::<ErrorData>().expect("Expected ErrorData");
            assert_eq!(tool_err.code, ErrorCode::RESOURCE_NOT_FOUND);
            assert!(tool_err.message.contains("is not available"));
        } else {
            panic!("Expected ErrorData with ErrorCode::RESOURCE_NOT_FOUND");
        }

        let available_tool_call = CallToolRequestParam {
            name: "test_extension__available_tool".to_string().into(),
            arguments: Some(object!({})),
        };

        let result = extension_manager
            .dispatch_tool_call(available_tool_call, CancellationToken::default())
            .await;

        assert!(result.is_ok());
    }

    /// `substitute_env_vars` accepts `${VAR}`, `${ VAR }` and `$VAR` forms.
    /// Unknown variables are left verbatim. The function is synchronous, so no
    /// async runtime is needed.
    #[test]
    fn test_streamable_http_header_env_substitution() {
        let mut env_map = HashMap::new();
        env_map.insert("AUTH_TOKEN".to_string(), "secret123".to_string());
        env_map.insert("API_KEY".to_string(), "key456".to_string());

        let result = substitute_env_vars("Bearer ${ AUTH_TOKEN }", &env_map);
        assert_eq!(result, "Bearer secret123");

        let result = substitute_env_vars("Bearer ${AUTH_TOKEN}", &env_map);
        assert_eq!(result, "Bearer secret123");

        let result = substitute_env_vars("Bearer $AUTH_TOKEN", &env_map);
        assert_eq!(result, "Bearer secret123");

        let result = substitute_env_vars("Key: $API_KEY, Token: ${AUTH_TOKEN}", &env_map);
        assert_eq!(result, "Key: key456, Token: secret123");

        let result = substitute_env_vars("Bearer ${UNKNOWN_VAR}", &env_map);
        assert_eq!(result, "Bearer ${UNKNOWN_VAR}");

        let result = substitute_env_vars(
            "Authorization: Bearer ${AUTH_TOKEN} and API ${API_KEY}",
            &env_map,
        );
        assert_eq!(result, "Authorization: Bearer secret123 and API key456");
    }

    mod generate_extension_name_tests {
        use super::*;
        use rmcp::model::Implementation;
        use test_case::test_case;

        /// Builds a `ServerInfo` whose implementation name is `name`.
        fn make_info(name: &str) -> ServerInfo {
            ServerInfo {
                server_info: Implementation {
                    name: name.into(),
                    ..Default::default()
                },
                ..Default::default()
            }
        }

        #[test_case(Some("kiwi-mcp-server"), None, "^kiwi-mcp-server$" ; "already normalized server name")]
        #[test_case(Some("Context7"), None, "^context7$" ; "mixed case normalized")]
        #[test_case(Some("@huggingface/mcp-services"), None, "^_huggingface_mcp-services$" ; "special chars normalized")]
        #[test_case(None, None, "^unnamed$" ; "no server info falls back")]
        #[test_case(Some(""), None, "^unnamed$" ; "empty server name falls back")]
        #[test_case(Some("github-mcp-server"), Some("github-mcp-server"), r"^github-mcp-server_[A-Za-z0-9]{6}$" ; "duplicate adds suffix")]
        fn test_generate_name(server_name: Option<&str>, collision: Option<&str>, expected: &str) {
            let info = server_name.map(make_info);
            let result = generate_extension_name(info.as_ref(), |n| collision == Some(n));
            let re = regex::Regex::new(expected).unwrap();
            assert!(re.is_match(&result));
        }
    }

    /// The MOIM timestamp is rendered with its seconds field fixed to `00`.
    #[tokio::test]
    async fn test_collect_moim_uses_minute_granularity() {
        let em = ExtensionManager::new_without_provider();

        // Fail loudly instead of passing vacuously if no message is produced.
        let moim = em
            .collect_moim()
            .await
            .expect("collect_moim should always produce a message");
        assert!(
            moim.contains(":00\n"),
            "Timestamp should use minute granularity"
        );
    }
}