Skip to main content

aster/agents/
extension_manager.rs

1use anyhow::Result;
2use axum::http::{HeaderMap, HeaderName};
3use chrono::{DateTime, Utc};
4use futures::stream::{FuturesUnordered, StreamExt};
5use futures::{future, FutureExt};
6use rand::{distributions::Alphanumeric, Rng};
7use rmcp::service::{ClientInitializeError, ServiceError};
8use rmcp::transport::streamable_http_client::{
9    AuthRequiredError, StreamableHttpClientTransportConfig, StreamableHttpError,
10};
11use rmcp::transport::{
12    ConfigureCommandExt, DynamicTransportError, StreamableHttpClientTransport, TokioChildProcess,
13};
14use std::collections::HashMap;
15use std::option::Option;
16use std::path::PathBuf;
17use std::process::Stdio;
18use std::sync::Arc;
19use std::time::Duration;
20use tempfile::{tempdir, TempDir};
21use tokio::io::AsyncReadExt;
22use tokio::process::Command;
23use tokio::sync::Mutex;
24use tokio::task;
25use tokio_stream::wrappers::ReceiverStream;
26use tokio_util::sync::CancellationToken;
27use tracing::{error, warn};
28
29use super::extension::{
30    ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, PlatformExtensionContext,
31    ToolInfo, PLATFORM_EXTENSIONS,
32};
33use super::tool_execution::ToolCallResult;
34use super::types::SharedProvider;
35use crate::agents::extension::{Envs, ProcessExit};
36use crate::agents::extension_malware_check;
37use crate::agents::mcp_client::{McpClient, McpClientTrait};
38use crate::config::search_path::SearchPaths;
39use crate::config::{get_all_extensions, Config};
40use crate::oauth::oauth_flow;
41use crate::prompt_template;
42use crate::subprocess::configure_command_no_window;
43use rmcp::model::{
44    CallToolRequestParam, Content, ErrorCode, ErrorData, GetPromptResult, Prompt, Resource,
45    ResourceContents, ServerInfo, Tool,
46};
47use rmcp::transport::auth::AuthClient;
48use schemars::_private::NoSerialize;
49use serde_json::Value;
50
51type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
52
53struct Extension {
54    pub config: ExtensionConfig,
55
56    client: McpClientBox,
57    server_info: Option<ServerInfo>,
58    _temp_dir: Option<tempfile::TempDir>,
59}
60
61impl Extension {
62    fn new(
63        config: ExtensionConfig,
64        client: McpClientBox,
65        server_info: Option<ServerInfo>,
66        temp_dir: Option<tempfile::TempDir>,
67    ) -> Self {
68        Self {
69            client,
70            config,
71            server_info,
72            _temp_dir: temp_dir,
73        }
74    }
75
76    fn supports_resources(&self) -> bool {
77        self.server_info
78            .as_ref()
79            .and_then(|info| info.capabilities.resources.as_ref())
80            .is_some()
81    }
82
83    fn get_instructions(&self) -> Option<String> {
84        self.server_info
85            .as_ref()
86            .and_then(|info| info.instructions.clone())
87    }
88
89    fn get_client(&self) -> McpClientBox {
90        self.client.clone()
91    }
92}
93
94/// Manages aster extensions / MCP clients and their interactions
95pub struct ExtensionManager {
96    extensions: Mutex<HashMap<String, Extension>>,
97    context: Mutex<PlatformExtensionContext>,
98    provider: SharedProvider,
99}
100
101/// A flattened representation of a resource used by the agent to prepare inference
102#[derive(Debug, Clone)]
103pub struct ResourceItem {
104    pub client_name: String,      // The name of the client that owns the resource
105    pub uri: String,              // The URI of the resource
106    pub name: String,             // The name of the resource
107    pub content: String,          // The content of the resource
108    pub timestamp: DateTime<Utc>, // The timestamp of the resource
109    pub priority: f32,            // The priority of the resource
110    pub token_count: Option<u32>, // The token count of the resource (filled in by the agent)
111}
112
113impl ResourceItem {
114    pub fn new(
115        client_name: String,
116        uri: String,
117        name: String,
118        content: String,
119        timestamp: DateTime<Utc>,
120        priority: f32,
121    ) -> Self {
122        Self {
123            client_name,
124            uri,
125            name,
126            content,
127            timestamp,
128            priority,
129            token_count: None,
130        }
131    }
132}
133
/// Sanitizes a string into the form used for extension keys.
///
/// Whitespace is removed entirely. ASCII letters, digits, `_` and `-` are kept
/// (lowercased). Every other character, whether ASCII punctuation or non-ASCII,
/// becomes `_`.
fn normalize(input: String) -> String {
    input
        .chars()
        .filter(|ch| !ch.is_whitespace())
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect()
}
147
148/// Generates extension name from server info; adds random suffix on collision.
149fn generate_extension_name(
150    server_info: Option<&ServerInfo>,
151    name_exists: impl Fn(&str) -> bool,
152) -> String {
153    let base = server_info
154        .and_then(|info| {
155            let name = info.server_info.name.as_str();
156            (!name.is_empty()).then(|| normalize(name.to_string()))
157        })
158        .unwrap_or_else(|| "unnamed".to_string());
159
160    if !name_exists(&base) {
161        return base;
162    }
163
164    let suffix: String = rand::thread_rng()
165        .sample_iter(Alphanumeric)
166        .take(6)
167        .map(char::from)
168        .collect();
169
170    format!("{base}_{suffix}")
171}
172
173fn resolve_command(cmd: &str) -> PathBuf {
174    SearchPaths::builder()
175        .with_npm()
176        .resolve(cmd)
177        .unwrap_or_else(|_| {
178            // let the OS raise the error
179            PathBuf::from(cmd)
180        })
181}
182
183fn require_str_parameter<'a>(v: &'a serde_json::Value, name: &str) -> Result<&'a str, ErrorData> {
184    let v = v.get(name).ok_or_else(|| {
185        ErrorData::new(
186            ErrorCode::INVALID_PARAMS,
187            format!("The parameter {name} is required"),
188            None,
189        )
190    })?;
191    match v.as_str() {
192        Some(r) => Ok(r),
193        None => Err(ErrorData::new(
194            ErrorCode::INVALID_PARAMS,
195            format!("The parameter {name} must be a string"),
196            None,
197        )),
198    }
199}
200
201pub fn get_parameter_names(tool: &Tool) -> Vec<String> {
202    let mut names: Vec<String> = tool
203        .input_schema
204        .get("properties")
205        .and_then(|props| props.as_object())
206        .map(|props| props.keys().cloned().collect())
207        .unwrap_or_default();
208    names.sort();
209    names
210}
211
212impl Default for ExtensionManager {
213    fn default() -> Self {
214        Self::new(Arc::new(Mutex::new(None)))
215    }
216}
217
218async fn child_process_client(
219    mut command: Command,
220    timeout: &Option<u64>,
221    provider: SharedProvider,
222) -> ExtensionResult<McpClient> {
223    #[cfg(unix)]
224    command.process_group(0);
225    configure_command_no_window(&mut command);
226
227    if let Ok(path) = SearchPaths::builder().path() {
228        command.env("PATH", path);
229    }
230
231    let (transport, mut stderr) = TokioChildProcess::builder(command)
232        .stderr(Stdio::piped())
233        .spawn()?;
234    let mut stderr = stderr.take().ok_or_else(|| {
235        ExtensionError::SetupError("failed to attach child process stderr".to_owned())
236    })?;
237
238    let stderr_task = tokio::spawn(async move {
239        let mut all_stderr = Vec::new();
240        stderr.read_to_end(&mut all_stderr).await?;
241        Ok::<String, std::io::Error>(String::from_utf8_lossy(&all_stderr).into())
242    });
243
244    let client_result = McpClient::connect(
245        transport,
246        Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT)),
247        provider,
248    )
249    .await;
250
251    match client_result {
252        Ok(client) => Ok(client),
253        Err(error) => {
254            let error_task_out = stderr_task.await?;
255            Err::<McpClient, ExtensionError>(match error_task_out {
256                Ok(stderr_content) => ProcessExit::new(stderr_content, error).into(),
257                Err(e) => e.into(),
258            })
259        }
260    }
261}
262
263fn extract_auth_error(
264    res: &Result<McpClient, ClientInitializeError>,
265) -> Option<&AuthRequiredError> {
266    match res {
267        Ok(_) => None,
268        Err(err) => match err {
269            ClientInitializeError::TransportError {
270                error: DynamicTransportError { error, .. },
271                ..
272            } => error
273                .downcast_ref::<StreamableHttpError<reqwest::Error>>()
274                .and_then(|auth_error| match auth_error {
275                    StreamableHttpError::AuthRequired(auth_required_error) => {
276                        Some(auth_required_error)
277                    }
278                    _ => None,
279                }),
280            _ => None,
281        },
282    }
283}
284
285/// Merge environment variables from direct envs and keychain-stored env_keys
286async fn merge_environments(
287    envs: &Envs,
288    env_keys: &[String],
289    ext_name: &str,
290) -> Result<HashMap<String, String>, ExtensionError> {
291    let mut all_envs = envs.get_env();
292    let config_instance = Config::global();
293
294    for key in env_keys {
295        if all_envs.contains_key(key) {
296            continue;
297        }
298
299        match config_instance.get(key, true) {
300            Ok(value) => {
301                if value.is_null() {
302                    warn!(
303                        key = %key,
304                        ext_name = %ext_name,
305                        "Secret key not found in config (returned null)."
306                    );
307                    continue;
308                }
309
310                if let Some(str_val) = value.as_str() {
311                    all_envs.insert(key.clone(), str_val.to_string());
312                } else {
313                    warn!(
314                        key = %key,
315                        ext_name = %ext_name,
316                        value_type = %value.get("type").and_then(|t| t.as_str()).unwrap_or("unknown"),
317                        "Secret value is not a string; skipping."
318                    );
319                }
320            }
321            Err(e) => {
322                error!(
323                    key = %key,
324                    ext_name = %ext_name,
325                    error = %e,
326                    "Failed to fetch secret from config."
327                );
328                return Err(ExtensionError::ConfigError(format!(
329                    "Failed to fetch secret '{}' from config: {}",
330                    key, e
331                )));
332            }
333        }
334    }
335
336    Ok(all_envs)
337}
338
339/// Substitute environment variables in a string. Supports both ${VAR} and $VAR syntax.
340fn substitute_env_vars(value: &str, env_map: &HashMap<String, String>) -> String {
341    let mut result = value.to_string();
342
343    let re_braces =
344        regex::Regex::new(r"\$\{\s*([A-Za-z_][A-Za-z0-9_]*)\s*\}").expect("valid regex");
345    for cap in re_braces.captures_iter(value) {
346        if let Some(var_name) = cap.get(1) {
347            if let Some(env_value) = env_map.get(var_name.as_str()) {
348                result = result.replace(&cap[0], env_value);
349            }
350        }
351    }
352
353    let re_simple = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").expect("valid regex");
354    for cap in re_simple.captures_iter(&result.clone()) {
355        if let Some(var_name) = cap.get(1) {
356            if !value.contains(&format!("${{{}}}", var_name.as_str())) {
357                if let Some(env_value) = env_map.get(var_name.as_str()) {
358                    result = result.replace(&cap[0], env_value);
359                }
360            }
361        }
362    }
363
364    result
365}
366
367async fn create_streamable_http_client(
368    uri: &str,
369    timeout: Option<u64>,
370    headers: &HashMap<String, String>,
371    name: &str,
372    all_envs: &HashMap<String, String>,
373    provider: SharedProvider,
374) -> ExtensionResult<Box<dyn McpClientTrait>> {
375    let mut default_headers = HeaderMap::new();
376    for (key, value) in headers {
377        let substituted_value = substitute_env_vars(value, all_envs);
378        default_headers.insert(
379            HeaderName::try_from(key)
380                .map_err(|_| ExtensionError::ConfigError(format!("invalid header: {}", key)))?,
381            substituted_value.parse().map_err(|_| {
382                ExtensionError::ConfigError(format!("invalid header value: {}", key))
383            })?,
384        );
385    }
386
387    let http_client = reqwest::Client::builder()
388        .default_headers(default_headers)
389        .build()
390        .map_err(|_| ExtensionError::ConfigError("could not construct http client".to_string()))?;
391
392    let transport = StreamableHttpClientTransport::with_client(
393        http_client,
394        StreamableHttpClientTransportConfig {
395            uri: uri.into(),
396            ..Default::default()
397        },
398    );
399
400    let timeout_duration =
401        Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT));
402
403    let client_res = McpClient::connect(transport, timeout_duration, provider.clone()).await;
404
405    if extract_auth_error(&client_res).is_some() {
406        let am = oauth_flow(&uri.to_string(), &name.to_string())
407            .await
408            .map_err(|_| ExtensionError::SetupError("auth error".to_string()))?;
409        let auth_client = AuthClient::new(reqwest::Client::default(), am);
410        let transport = StreamableHttpClientTransport::with_client(
411            auth_client,
412            StreamableHttpClientTransportConfig {
413                uri: uri.into(),
414                ..Default::default()
415            },
416        );
417        Ok(Box::new(
418            McpClient::connect(transport, timeout_duration, provider).await?,
419        ))
420    } else {
421        Ok(Box::new(client_res?))
422    }
423}
424
425async fn create_stdio_client(
426    cmd: &str,
427    args: &[String],
428    all_envs: HashMap<String, String>,
429    timeout: &Option<u64>,
430    provider: SharedProvider,
431) -> ExtensionResult<Box<dyn McpClientTrait>> {
432    extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?;
433
434    let resolved_cmd = resolve_command(cmd);
435    let command = Command::new(resolved_cmd).configure(|command| {
436        command.args(args).envs(all_envs);
437    });
438
439    Ok(Box::new(
440        child_process_client(command, timeout, provider).await?,
441    ))
442}
443
444impl ExtensionManager {
445    pub fn new(provider: SharedProvider) -> Self {
446        Self {
447            extensions: Mutex::new(HashMap::new()),
448            context: Mutex::new(PlatformExtensionContext {
449                session_id: None,
450                extension_manager: None,
451            }),
452            provider,
453        }
454    }
455
456    /// Create a new ExtensionManager with no provider (useful for tests)
457    pub fn new_without_provider() -> Self {
458        Self::new(Arc::new(Mutex::new(None)))
459    }
460
461    pub async fn set_context(&self, context: PlatformExtensionContext) {
462        *self.context.lock().await = context;
463    }
464
465    pub async fn get_context(&self) -> PlatformExtensionContext {
466        self.context.lock().await.clone()
467    }
468
469    pub async fn supports_resources(&self) -> bool {
470        self.extensions
471            .lock()
472            .await
473            .values()
474            .any(|ext| ext.supports_resources())
475    }
476
477    pub async fn add_extension(&self, config: ExtensionConfig) -> ExtensionResult<()> {
478        let config_name = config.key().to_string();
479        let sanitized_name = normalize(config_name.clone());
480
481        if self.extensions.lock().await.contains_key(&sanitized_name) {
482            return Ok(());
483        }
484
485        let mut temp_dir = None;
486
487        let client: Box<dyn McpClientTrait> = match &config {
488            ExtensionConfig::Sse { .. } => {
489                return Err(ExtensionError::ConfigError(
490                    "SSE is unsupported, migrate to streamable_http".to_string(),
491                ));
492            }
493            ExtensionConfig::StreamableHttp {
494                uri,
495                timeout,
496                headers,
497                name,
498                envs,
499                env_keys,
500                ..
501            } => {
502                let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
503                create_streamable_http_client(
504                    uri,
505                    *timeout,
506                    headers,
507                    name,
508                    &all_envs,
509                    self.provider.clone(),
510                )
511                .await?
512            }
513            ExtensionConfig::Stdio {
514                cmd,
515                args,
516                envs,
517                env_keys,
518                timeout,
519                ..
520            } => {
521                let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
522                create_stdio_client(cmd, args, all_envs, timeout, self.provider.clone()).await?
523            }
524            ExtensionConfig::Builtin { name, timeout, .. } => {
525                let cmd = std::env::current_exe()
526                    .and_then(|path| {
527                        path.to_str().map(|s| s.to_string()).ok_or_else(|| {
528                            std::io::Error::new(
529                                std::io::ErrorKind::InvalidData,
530                                "Invalid UTF-8 in executable path",
531                            )
532                        })
533                    })
534                    .map_err(|e| {
535                        ExtensionError::ConfigError(format!(
536                            "Failed to resolve executable path: {}",
537                            e
538                        ))
539                    })?;
540                let command = Command::new(cmd).configure(|command| {
541                    command.arg("mcp").arg(name);
542                });
543                Box::new(child_process_client(command, timeout, self.provider.clone()).await?)
544            }
545            ExtensionConfig::Platform { name, .. } => {
546                let normalized_key = normalize(name.clone());
547                let def = PLATFORM_EXTENSIONS
548                    .get(normalized_key.as_str())
549                    .ok_or_else(|| {
550                        ExtensionError::ConfigError(format!("Unknown platform extension: {}", name))
551                    })?;
552                let context = self.get_context().await;
553                (def.client_factory)(context)
554            }
555            ExtensionConfig::InlinePython {
556                name,
557                code,
558                timeout,
559                dependencies,
560                ..
561            } => {
562                let dir = tempdir()?;
563                let file_path = dir.path().join(format!("{}.py", name));
564                temp_dir = Some(dir);
565                std::fs::write(&file_path, code)?;
566
567                let command = Command::new("uvx").configure(|command| {
568                    command.arg("--with").arg("mcp");
569                    dependencies.iter().flatten().for_each(|dep| {
570                        command.arg("--with").arg(dep);
571                    });
572                    command.arg("python").arg(file_path.to_str().unwrap());
573                });
574
575                Box::new(child_process_client(command, timeout, self.provider.clone()).await?)
576            }
577            ExtensionConfig::Frontend { .. } => {
578                return Err(ExtensionError::ConfigError(
579                    "Invalid extension type: Frontend extensions cannot be added as server extensions".to_string()
580                ));
581            }
582        };
583
584        let server_info = client.get_info().cloned();
585
586        // Only generate name from server info when config has no name (e.g., CLI --with-*-extension args)
587        let mut extensions = self.extensions.lock().await;
588        let final_name = if sanitized_name.is_empty() {
589            generate_extension_name(server_info.as_ref(), |n| extensions.contains_key(n))
590        } else {
591            sanitized_name
592        };
593        extensions.insert(
594            final_name,
595            Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir),
596        );
597
598        Ok(())
599    }
600
601    pub async fn add_client(
602        &self,
603        name: String,
604        config: ExtensionConfig,
605        client: McpClientBox,
606        info: Option<ServerInfo>,
607        temp_dir: Option<TempDir>,
608    ) {
609        self.extensions
610            .lock()
611            .await
612            .insert(name, Extension::new(config, client, info, temp_dir));
613    }
614
615    /// Get extensions info for building the system prompt
616    pub async fn get_extensions_info(&self) -> Vec<ExtensionInfo> {
617        self.extensions
618            .lock()
619            .await
620            .iter()
621            .map(|(name, ext)| {
622                ExtensionInfo::new(
623                    name,
624                    ext.get_instructions().unwrap_or_default().as_str(),
625                    ext.supports_resources(),
626                )
627            })
628            .collect()
629    }
630
631    /// Get aggregated usage statistics
632    pub async fn remove_extension(&self, name: &str) -> ExtensionResult<()> {
633        let sanitized_name = normalize(name.to_string());
634        self.extensions.lock().await.remove(&sanitized_name);
635        Ok(())
636    }
637
638    pub async fn get_extension_and_tool_counts(&self) -> (usize, usize) {
639        let enabled_extensions_count = self.extensions.lock().await.len();
640
641        let total_tools = self
642            .get_prefixed_tools(None)
643            .await
644            .map(|tools| tools.len())
645            .unwrap_or(0);
646
647        (enabled_extensions_count, total_tools)
648    }
649
650    pub async fn list_extensions(&self) -> ExtensionResult<Vec<String>> {
651        Ok(self.extensions.lock().await.keys().cloned().collect())
652    }
653
654    pub async fn is_extension_enabled(&self, name: &str) -> bool {
655        self.extensions.lock().await.contains_key(name)
656    }
657
658    pub async fn get_extension_configs(&self) -> Vec<ExtensionConfig> {
659        self.extensions
660            .lock()
661            .await
662            .values()
663            .map(|ext| ext.config.clone())
664            .collect()
665    }
666
667    /// Get all tools from all clients with proper prefixing
668    pub async fn get_prefixed_tools(
669        &self,
670        extension_name: Option<String>,
671    ) -> ExtensionResult<Vec<Tool>> {
672        self.get_prefixed_tools_impl(extension_name, None).await
673    }
674
675    async fn get_prefixed_tools_impl(
676        &self,
677        extension_name: Option<String>,
678        exclude: Option<&str>,
679    ) -> ExtensionResult<Vec<Tool>> {
680        // Filter clients based on the provided extension_name or include all if None
681        let filtered_clients: Vec<_> = self
682            .extensions
683            .lock()
684            .await
685            .iter()
686            .filter(|(name, _ext)| {
687                if let Some(excluded) = exclude {
688                    if name.as_str() == excluded {
689                        return false;
690                    }
691                }
692
693                if let Some(ref name_filter) = extension_name {
694                    *name == name_filter
695                } else {
696                    true
697                }
698            })
699            .map(|(name, ext)| (name.clone(), ext.config.clone(), ext.get_client()))
700            .collect();
701
702        let cancel_token = CancellationToken::default();
703        let client_futures = filtered_clients.into_iter().map(|(name, config, client)| {
704            let cancel_token = cancel_token.clone();
705            task::spawn(async move {
706                let mut tools = Vec::new();
707                let client_guard = client.lock().await;
708                let mut client_tools = client_guard.list_tools(None, cancel_token).await?;
709
710                loop {
711                    for tool in client_tools.tools {
712                        let is_available = config.is_tool_available(&tool.name);
713
714                        if is_available {
715                            tools.push(Tool {
716                                name: format!("{}__{}", name, tool.name).into(),
717                                description: tool.description,
718                                input_schema: tool.input_schema,
719                                annotations: tool.annotations,
720                                output_schema: tool.output_schema,
721                                icons: tool.icons,
722                                title: tool.title,
723                                meta: tool.meta,
724                            });
725                        }
726                    }
727
728                    if client_tools.next_cursor.is_none() {
729                        break;
730                    }
731
732                    client_tools = client_guard
733                        .list_tools(client_tools.next_cursor, CancellationToken::default())
734                        .await?;
735                }
736
737                Ok::<Vec<Tool>, ExtensionError>(tools)
738            })
739        });
740
741        // Collect all results concurrently
742        let results = future::join_all(client_futures).await;
743
744        // Aggregate tools and handle errors
745        let mut tools = Vec::new();
746        for result in results {
747            match result {
748                Ok(Ok(client_tools)) => tools.extend(client_tools),
749                Ok(Err(err)) => return Err(err),
750                Err(join_err) => return Err(ExtensionError::from(join_err)),
751            }
752        }
753
754        Ok(tools)
755    }
756
757    pub async fn get_prefixed_tools_excluding(&self, exclude: &str) -> ExtensionResult<Vec<Tool>> {
758        self.get_prefixed_tools_impl(None, Some(exclude)).await
759    }
760
761    /// Get the extension prompt including client instructions
762    pub async fn get_planning_prompt(&self, tools_info: Vec<ToolInfo>) -> String {
763        let mut context: HashMap<&str, Value> = HashMap::new();
764        context.insert("tools", serde_json::to_value(tools_info).unwrap());
765
766        prompt_template::render_global_file("plan.md", &context).expect("Prompt should render")
767    }
768
769    /// Find and return a reference to the appropriate client for a tool call
770    async fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(String, McpClientBox)> {
771        self.extensions
772            .lock()
773            .await
774            .iter()
775            .find(|(key, _)| prefixed_name.starts_with(*key))
776            .map(|(name, extension)| (name.clone(), extension.get_client()))
777    }
778
779    // Function that gets executed for read_resource tool
780    pub async fn read_resource_tool(
781        &self,
782        params: Value,
783        cancellation_token: CancellationToken,
784    ) -> Result<Vec<Content>, ErrorData> {
785        let uri = require_str_parameter(&params, "uri")?;
786
787        let extension_name = params.get("extension_name").and_then(|v| v.as_str());
788
789        // If extension name is provided, we can just look it up
790        if let Some(ext_name) = extension_name {
791            let read_result = self
792                .read_resource(uri, ext_name, cancellation_token.clone())
793                .await?;
794
795            let mut result = Vec::new();
796            for content in read_result.contents {
797                if let ResourceContents::TextResourceContents { text, .. } = content {
798                    let content_str = format!("{}\n\n{}", uri, text);
799                    result.push(Content::text(content_str));
800                }
801            }
802            return Ok(result);
803        }
804
805        // If extension name is not provided, we need to search for the resource across all extensions
806        // Loop through each extension and try to read the resource, don't raise an error if the resource is not found
807        // TODO: do we want to find if a provided uri is in multiple extensions?
808        // currently it will return the first match and skip any others
809
810        // Collect extension names first to avoid holding the lock during iteration
811        let extension_names: Vec<String> = self.extensions.lock().await.keys().cloned().collect();
812
813        for extension_name in extension_names {
814            let read_result = self
815                .read_resource(uri, &extension_name, cancellation_token.clone())
816                .await;
817            match read_result {
818                Ok(read_result) => {
819                    let mut result = Vec::new();
820                    for content in read_result.contents {
821                        if let ResourceContents::TextResourceContents { text, .. } = content {
822                            let content_str = format!("{}\n\n{}", uri, text);
823                            result.push(Content::text(content_str));
824                        }
825                    }
826                    return Ok(result);
827                }
828                Err(_) => continue,
829            }
830        }
831
832        // None of the extensions had the resource so we raise an error
833        let available_extensions = self
834            .extensions
835            .lock()
836            .await
837            .keys()
838            .map(|s| s.as_str())
839            .collect::<Vec<&str>>()
840            .join(", ");
841        let error_msg = format!(
842            "Resource with uri '{}' not found. Here are the available extensions: {}",
843            uri, available_extensions
844        );
845
846        Err(ErrorData::new(
847            ErrorCode::RESOURCE_NOT_FOUND,
848            error_msg,
849            None,
850        ))
851    }
852
853    pub async fn read_resource(
854        &self,
855        uri: &str,
856        extension_name: &str,
857        cancellation_token: CancellationToken,
858    ) -> Result<rmcp::model::ReadResourceResult, ErrorData> {
859        let available_extensions = self
860            .extensions
861            .lock()
862            .await
863            .keys()
864            .map(|s| s.as_str())
865            .collect::<Vec<&str>>()
866            .join(", ");
867        let error_msg = format!(
868            "Extension '{}' not found. Here are the available extensions: {}",
869            extension_name, available_extensions
870        );
871
872        let client = self
873            .get_server_client(extension_name)
874            .await
875            .ok_or(ErrorData::new(ErrorCode::INVALID_PARAMS, error_msg, None))?;
876
877        let client_guard = client.lock().await;
878        client_guard
879            .read_resource(uri, cancellation_token)
880            .await
881            .map_err(|_| {
882                ErrorData::new(
883                    ErrorCode::INTERNAL_ERROR,
884                    format!("Could not read resource with uri: {}", uri),
885                    None,
886                )
887            })
888    }
889
890    pub async fn get_ui_resources(&self) -> Result<Vec<(String, Resource)>, ErrorData> {
891        let mut ui_resources = Vec::new();
892
893        let extensions_to_check: Vec<(String, McpClientBox)> = {
894            let extensions = self.extensions.lock().await;
895            extensions
896                .iter()
897                .map(|(name, ext)| (name.clone(), ext.get_client()))
898                .collect()
899        };
900
901        for (extension_name, client) in extensions_to_check {
902            let client_guard = client.lock().await;
903
904            match client_guard
905                .list_resources(None, CancellationToken::default())
906                .await
907            {
908                Ok(list_response) => {
909                    for resource in list_response.resources {
910                        if resource.uri.starts_with("ui://") {
911                            ui_resources.push((extension_name.clone(), resource));
912                        }
913                    }
914                }
915                Err(e) => {
916                    warn!("Failed to list resources for {}: {:?}", extension_name, e);
917                }
918            }
919        }
920
921        Ok(ui_resources)
922    }
923
924    async fn list_resources_from_extension(
925        &self,
926        extension_name: &str,
927        cancellation_token: CancellationToken,
928    ) -> Result<Vec<Content>, ErrorData> {
929        let client = self
930            .get_server_client(extension_name)
931            .await
932            .ok_or_else(|| {
933                ErrorData::new(
934                    ErrorCode::INVALID_PARAMS,
935                    format!("Extension {} is not valid", extension_name),
936                    None,
937                )
938            })?;
939
940        let client_guard = client.lock().await;
941        client_guard
942            .list_resources(None, cancellation_token)
943            .await
944            .map_err(|e| {
945                ErrorData::new(
946                    ErrorCode::INTERNAL_ERROR,
947                    format!("Unable to list resources for {}, {:?}", extension_name, e),
948                    None,
949                )
950            })
951            .map(|lr| {
952                let resource_list = lr
953                    .resources
954                    .into_iter()
955                    .map(|r| format!("{} - {}, uri: ({})", extension_name, r.name, r.uri))
956                    .collect::<Vec<String>>()
957                    .join("\n");
958
959                vec![Content::text(resource_list)]
960            })
961    }
962
963    pub async fn list_resources(
964        &self,
965        params: Value,
966        cancellation_token: CancellationToken,
967    ) -> Result<Vec<Content>, ErrorData> {
968        let extension = params.get("extension").and_then(|v| v.as_str());
969
970        match extension {
971            Some(extension_name) => {
972                // Handle single extension case
973                self.list_resources_from_extension(extension_name, cancellation_token)
974                    .await
975            }
976            None => {
977                // Handle all extensions case using FuturesUnordered
978                let mut futures = FuturesUnordered::new();
979
980                // Create futures for each resource_capable_extension
981                self.extensions
982                    .lock()
983                    .await
984                    .iter()
985                    .filter(|(_name, ext)| ext.supports_resources())
986                    .map(|(name, _ext)| name.clone())
987                    .for_each(|name| {
988                        let token = cancellation_token.clone();
989                        futures.push(async move {
990                            self.list_resources_from_extension(&name.clone(), token)
991                                .await
992                        });
993                    });
994
995                let mut all_resources = Vec::new();
996                let mut errors = Vec::new();
997
998                // Process results as they complete
999                while let Some(result) = futures.next().await {
1000                    match result {
1001                        Ok(content) => {
1002                            all_resources.extend(content);
1003                        }
1004                        Err(tool_error) => {
1005                            errors.push(tool_error);
1006                        }
1007                    }
1008                }
1009
1010                if !errors.is_empty() {
1011                    tracing::error!(
1012                        errors = ?errors
1013                            .into_iter()
1014                            .map(|e| format!("{:?}", e))
1015                            .collect::<Vec<_>>(),
1016                        "errors from listing resources"
1017                    );
1018                }
1019
1020                Ok(all_resources)
1021            }
1022        }
1023    }
1024
1025    pub async fn dispatch_tool_call(
1026        &self,
1027        tool_call: CallToolRequestParam,
1028        cancellation_token: CancellationToken,
1029    ) -> Result<ToolCallResult> {
1030        // Dispatch tool call based on the prefix naming convention
1031        let (client_name, client) =
1032            self.get_client_for_tool(&tool_call.name)
1033                .await
1034                .ok_or_else(|| {
1035                    ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None)
1036                })?;
1037
1038        // rsplit returns the iterator in reverse, tool_name is then at 0
1039        let tool_name = tool_call
1040            .name
1041            .strip_prefix(client_name.as_str())
1042            .and_then(|s| s.strip_prefix("__"))
1043            .ok_or_else(|| {
1044                ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None)
1045            })?
1046            .to_string();
1047
1048        if let Some(extension) = self.extensions.lock().await.get(&client_name) {
1049            if !extension.config.is_tool_available(&tool_name) {
1050                return Err(ErrorData::new(
1051                    ErrorCode::RESOURCE_NOT_FOUND,
1052                    format!(
1053                        "Tool '{}' is not available for extension '{}'",
1054                        tool_name, client_name
1055                    ),
1056                    None,
1057                )
1058                .into());
1059            }
1060        }
1061
1062        let arguments = tool_call.arguments.clone();
1063        let client = client.clone();
1064        let notifications_receiver = client.lock().await.subscribe().await;
1065
1066        let fut = async move {
1067            let client_guard = client.lock().await;
1068            client_guard
1069                .call_tool(&tool_name, arguments, cancellation_token)
1070                .await
1071                .map_err(|e| match e {
1072                    ServiceError::McpError(error_data) => error_data,
1073                    _ => {
1074                        ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), e.maybe_to_value())
1075                    }
1076                })
1077        };
1078
1079        Ok(ToolCallResult {
1080            result: Box::new(fut.boxed()),
1081            notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
1082        })
1083    }
1084
1085    pub async fn list_prompts_from_extension(
1086        &self,
1087        extension_name: &str,
1088        cancellation_token: CancellationToken,
1089    ) -> Result<Vec<Prompt>, ErrorData> {
1090        let client = self
1091            .get_server_client(extension_name)
1092            .await
1093            .ok_or_else(|| {
1094                ErrorData::new(
1095                    ErrorCode::INVALID_PARAMS,
1096                    format!("Extension {} is not valid", extension_name),
1097                    None,
1098                )
1099            })?;
1100
1101        let client_guard = client.lock().await;
1102        client_guard
1103            .list_prompts(None, cancellation_token)
1104            .await
1105            .map_err(|e| {
1106                ErrorData::new(
1107                    ErrorCode::INTERNAL_ERROR,
1108                    format!("Unable to list prompts for {}, {:?}", extension_name, e),
1109                    None,
1110                )
1111            })
1112            .map(|lp| lp.prompts)
1113    }
1114
1115    pub async fn list_prompts(
1116        &self,
1117        cancellation_token: CancellationToken,
1118    ) -> Result<HashMap<String, Vec<Prompt>>, ErrorData> {
1119        let mut futures = FuturesUnordered::new();
1120
1121        let names: Vec<_> = self.extensions.lock().await.keys().cloned().collect();
1122        for extension_name in names {
1123            let token = cancellation_token.clone();
1124            futures.push(async move {
1125                (
1126                    extension_name.clone(),
1127                    self.list_prompts_from_extension(extension_name.as_str(), token)
1128                        .await,
1129                )
1130            });
1131        }
1132
1133        let mut all_prompts = HashMap::new();
1134        let mut errors = Vec::new();
1135
1136        // Process results as they complete
1137        while let Some(result) = futures.next().await {
1138            let (name, prompts) = result;
1139            match prompts {
1140                Ok(content) => {
1141                    all_prompts.insert(name.to_string(), content);
1142                }
1143                Err(tool_error) => {
1144                    errors.push(tool_error);
1145                }
1146            }
1147        }
1148
1149        if !errors.is_empty() {
1150            tracing::debug!(
1151                errors = ?errors
1152                    .into_iter()
1153                    .map(|e| format!("{:?}", e))
1154                    .collect::<Vec<_>>(),
1155                "errors from listing prompts"
1156            );
1157        }
1158
1159        Ok(all_prompts)
1160    }
1161
1162    pub async fn get_prompt(
1163        &self,
1164        extension_name: &str,
1165        name: &str,
1166        arguments: Value,
1167        cancellation_token: CancellationToken,
1168    ) -> Result<GetPromptResult> {
1169        let client = self
1170            .get_server_client(extension_name)
1171            .await
1172            .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?;
1173
1174        let client_guard = client.lock().await;
1175        client_guard
1176            .get_prompt(name, arguments, cancellation_token)
1177            .await
1178            .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e))
1179    }
1180
1181    pub async fn search_available_extensions(&self) -> Result<Vec<Content>, ErrorData> {
1182        let mut output_parts = vec![];
1183
1184        // First get disabled extensions from current config
1185        let mut disabled_extensions: Vec<String> = vec![];
1186        for extension in get_all_extensions() {
1187            if !extension.enabled {
1188                let config = extension.config.clone();
1189                let description = match &config {
1190                    ExtensionConfig::Builtin {
1191                        description,
1192                        display_name,
1193                        ..
1194                    } => {
1195                        if description.is_empty() {
1196                            display_name.as_deref().unwrap_or("Built-in extension")
1197                        } else {
1198                            description
1199                        }
1200                    }
1201                    ExtensionConfig::Sse { .. } => "SSE extension (unsupported)",
1202                    ExtensionConfig::Platform { description, .. }
1203                    | ExtensionConfig::StreamableHttp { description, .. }
1204                    | ExtensionConfig::Stdio { description, .. }
1205                    | ExtensionConfig::Frontend { description, .. }
1206                    | ExtensionConfig::InlinePython { description, .. } => description,
1207                };
1208                disabled_extensions.push(format!("- {} - {}", config.name(), description));
1209            }
1210        }
1211
1212        // Get currently enabled extensions that can be disabled
1213        let enabled_extensions: Vec<String> =
1214            self.extensions.lock().await.keys().cloned().collect();
1215
1216        // Build output string
1217        if !disabled_extensions.is_empty() {
1218            output_parts.push(format!(
1219                "Extensions available to enable:\n{}\n",
1220                disabled_extensions.join("\n")
1221            ));
1222        } else {
1223            output_parts.push("No extensions available to enable.\n".to_string());
1224        }
1225
1226        if !enabled_extensions.is_empty() {
1227            output_parts.push(format!(
1228                "\n\nExtensions available to disable:\n{}\n",
1229                enabled_extensions
1230                    .iter()
1231                    .map(|name| format!("- {}", name))
1232                    .collect::<Vec<_>>()
1233                    .join("\n")
1234            ));
1235        } else {
1236            output_parts.push("No extensions that can be disabled.\n".to_string());
1237        }
1238
1239        Ok(vec![Content::text(output_parts.join("\n"))])
1240    }
1241
1242    async fn get_server_client(&self, name: impl Into<String>) -> Option<McpClientBox> {
1243        self.extensions
1244            .lock()
1245            .await
1246            .get(&name.into())
1247            .map(|ext| ext.get_client())
1248    }
1249
1250    pub async fn collect_moim(&self) -> Option<String> {
1251        // Use minute-level granularity to prevent conversation changes every second
1252        let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:00").to_string();
1253        let mut content = format!("<info-msg>\nIt is currently {}\n", timestamp);
1254
1255        let platform_clients: Vec<(String, McpClientBox)> = {
1256            let extensions = self.extensions.lock().await;
1257            extensions
1258                .iter()
1259                .filter_map(|(name, extension)| {
1260                    if let ExtensionConfig::Platform { .. } = &extension.config {
1261                        Some((name.clone(), extension.get_client()))
1262                    } else {
1263                        None
1264                    }
1265                })
1266                .collect()
1267        };
1268
1269        for (name, client) in platform_clients {
1270            let client_guard = client.lock().await;
1271            if let Some(moim_content) = client_guard.get_moim().await {
1272                tracing::debug!("MOIM content from {}: {} chars", name, moim_content.len());
1273                content.push('\n');
1274                content.push_str(&moim_content);
1275            }
1276        }
1277
1278        content.push_str("\n</info-msg>");
1279
1280        Some(content)
1281    }
1282}
1283
1284#[cfg(test)]
1285mod tests {
1286    use super::*;
1287    use rmcp::model::CallToolResult;
1288    use rmcp::model::{InitializeResult, JsonObject};
1289    use rmcp::{object, ServiceError as Error};
1290
1291    use rmcp::model::ListPromptsResult;
1292    use rmcp::model::ListResourcesResult;
1293    use rmcp::model::ListToolsResult;
1294    use rmcp::model::ReadResourceResult;
1295    use rmcp::model::ServerNotification;
1296
1297    use tokio::sync::mpsc;
1298
1299    impl ExtensionManager {
1300        async fn add_mock_extension(&self, name: String, client: McpClientBox) {
1301            self.add_mock_extension_with_tools(name, client, vec![])
1302                .await;
1303        }
1304
1305        async fn add_mock_extension_with_tools(
1306            &self,
1307            name: String,
1308            client: McpClientBox,
1309            available_tools: Vec<String>,
1310        ) {
1311            let sanitized_name = normalize(name.clone());
1312            let config = ExtensionConfig::Builtin {
1313                name: name.clone(),
1314                display_name: Some(name.clone()),
1315                description: "built-in".to_string(),
1316                timeout: None,
1317                bundled: None,
1318                available_tools,
1319            };
1320            let extension = Extension::new(config, client, None, None);
1321            self.extensions
1322                .lock()
1323                .await
1324                .insert(sanitized_name, extension);
1325        }
1326    }
1327
1328    struct MockClient {}
1329
1330    #[async_trait::async_trait]
1331    impl McpClientTrait for MockClient {
1332        fn get_info(&self) -> Option<&InitializeResult> {
1333            None
1334        }
1335
1336        async fn list_resources(
1337            &self,
1338            _next_cursor: Option<String>,
1339            _cancellation_token: CancellationToken,
1340        ) -> Result<ListResourcesResult, Error> {
1341            Err(Error::TransportClosed)
1342        }
1343
1344        async fn read_resource(
1345            &self,
1346            _uri: &str,
1347            _cancellation_token: CancellationToken,
1348        ) -> Result<ReadResourceResult, Error> {
1349            Err(Error::TransportClosed)
1350        }
1351
1352        async fn list_tools(
1353            &self,
1354            _next_cursor: Option<String>,
1355            _cancellation_token: CancellationToken,
1356        ) -> Result<ListToolsResult, Error> {
1357            use serde_json::json;
1358            use std::sync::Arc;
1359            Ok(ListToolsResult {
1360                tools: vec![
1361                    Tool::new(
1362                        "tool".to_string(),
1363                        "A basic tool".to_string(),
1364                        Arc::new(json!({}).as_object().unwrap().clone()),
1365                    ),
1366                    Tool::new(
1367                        "available_tool".to_string(),
1368                        "An available tool".to_string(),
1369                        Arc::new(json!({}).as_object().unwrap().clone()),
1370                    ),
1371                    Tool::new(
1372                        "hidden_tool".to_string(),
1373                        "hidden tool".to_string(),
1374                        Arc::new(json!({}).as_object().unwrap().clone()),
1375                    ),
1376                ],
1377                next_cursor: None,
1378                meta: None,
1379            })
1380        }
1381
1382        async fn call_tool(
1383            &self,
1384            name: &str,
1385            _arguments: Option<JsonObject>,
1386            _cancellation_token: CancellationToken,
1387        ) -> Result<CallToolResult, Error> {
1388            match name {
1389                "tool" | "test__tool" | "available_tool" | "hidden_tool" => Ok(CallToolResult {
1390                    content: vec![],
1391                    is_error: None,
1392                    structured_content: None,
1393                    meta: None,
1394                }),
1395                _ => Err(Error::TransportClosed),
1396            }
1397        }
1398
1399        async fn list_prompts(
1400            &self,
1401            _next_cursor: Option<String>,
1402            _cancellation_token: CancellationToken,
1403        ) -> Result<ListPromptsResult, Error> {
1404            Err(Error::TransportClosed)
1405        }
1406
1407        async fn get_prompt(
1408            &self,
1409            _name: &str,
1410            _arguments: Value,
1411            _cancellation_token: CancellationToken,
1412        ) -> Result<GetPromptResult, Error> {
1413            Err(Error::TransportClosed)
1414        }
1415
1416        async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
1417            mpsc::channel(1).1
1418        }
1419    }
1420
1421    #[tokio::test]
1422    async fn test_get_client_for_tool() {
1423        let extension_manager = ExtensionManager::new_without_provider();
1424
1425        // Add some mock clients using the helper method
1426        extension_manager
1427            .add_mock_extension(
1428                "test_client".to_string(),
1429                Arc::new(Mutex::new(Box::new(MockClient {}))),
1430            )
1431            .await;
1432
1433        extension_manager
1434            .add_mock_extension(
1435                "__client".to_string(),
1436                Arc::new(Mutex::new(Box::new(MockClient {}))),
1437            )
1438            .await;
1439
1440        extension_manager
1441            .add_mock_extension(
1442                "__cli__ent__".to_string(),
1443                Arc::new(Mutex::new(Box::new(MockClient {}))),
1444            )
1445            .await;
1446
1447        extension_manager
1448            .add_mock_extension(
1449                "client 🚀".to_string(),
1450                Arc::new(Mutex::new(Box::new(MockClient {}))),
1451            )
1452            .await;
1453
1454        // Test basic case
1455        assert!(extension_manager
1456            .get_client_for_tool("test_client__tool")
1457            .await
1458            .is_some());
1459
1460        // Test leading underscores
1461        assert!(extension_manager
1462            .get_client_for_tool("__client__tool")
1463            .await
1464            .is_some());
1465
1466        // Test multiple underscores in client name, and ending with __
1467        assert!(extension_manager
1468            .get_client_for_tool("__cli__ent____tool")
1469            .await
1470            .is_some());
1471
1472        // Test unicode in tool name, "client 🚀" should become "client_"
1473        assert!(extension_manager
1474            .get_client_for_tool("client___tool")
1475            .await
1476            .is_some());
1477    }
1478
1479    #[tokio::test]
1480    async fn test_dispatch_tool_call() {
1481        // test that dispatch_tool_call parses out the sanitized name correctly, and extracts
1482        // tool_names
1483        let extension_manager = ExtensionManager::new_without_provider();
1484
1485        // Add some mock clients using the helper method
1486        extension_manager
1487            .add_mock_extension(
1488                "test_client".to_string(),
1489                Arc::new(Mutex::new(Box::new(MockClient {}))),
1490            )
1491            .await;
1492
1493        extension_manager
1494            .add_mock_extension(
1495                "__cli__ent__".to_string(),
1496                Arc::new(Mutex::new(Box::new(MockClient {}))),
1497            )
1498            .await;
1499
1500        extension_manager
1501            .add_mock_extension(
1502                "client 🚀".to_string(),
1503                Arc::new(Mutex::new(Box::new(MockClient {}))),
1504            )
1505            .await;
1506
1507        // verify a normal tool call
1508        let tool_call = CallToolRequestParam {
1509            name: "test_client__tool".to_string().into(),
1510            arguments: Some(object!({})),
1511        };
1512
1513        let result = extension_manager
1514            .dispatch_tool_call(tool_call, CancellationToken::default())
1515            .await;
1516        assert!(result.is_ok());
1517
1518        let tool_call = CallToolRequestParam {
1519            name: "test_client__test__tool".to_string().into(),
1520            arguments: Some(object!({})),
1521        };
1522
1523        let result = extension_manager
1524            .dispatch_tool_call(tool_call, CancellationToken::default())
1525            .await;
1526        assert!(result.is_ok());
1527
1528        // verify a multiple underscores dispatch
1529        let tool_call = CallToolRequestParam {
1530            name: "__cli__ent____tool".to_string().into(),
1531            arguments: Some(object!({})),
1532        };
1533
1534        let result = extension_manager
1535            .dispatch_tool_call(tool_call, CancellationToken::default())
1536            .await;
1537        assert!(result.is_ok());
1538
1539        // Test unicode in tool name, "client 🚀" should become "client_"
1540        let tool_call = CallToolRequestParam {
1541            name: "client___tool".to_string().into(),
1542            arguments: Some(object!({})),
1543        };
1544
1545        let result = extension_manager
1546            .dispatch_tool_call(tool_call, CancellationToken::default())
1547            .await;
1548        assert!(result.is_ok());
1549
1550        let tool_call = CallToolRequestParam {
1551            name: "client___test__tool".to_string().into(),
1552            arguments: Some(object!({})),
1553        };
1554
1555        let result = extension_manager
1556            .dispatch_tool_call(tool_call, CancellationToken::default())
1557            .await;
1558        assert!(result.is_ok());
1559
1560        // this should error out, specifically for an ToolError::ExecutionError
1561        let invalid_tool_call = CallToolRequestParam {
1562            name: "client___tools".to_string().into(),
1563            arguments: Some(object!({})),
1564        };
1565
1566        let result = extension_manager
1567            .dispatch_tool_call(invalid_tool_call, CancellationToken::default())
1568            .await
1569            .unwrap()
1570            .result
1571            .await;
1572        assert!(matches!(
1573            result,
1574            Err(ErrorData {
1575                code: ErrorCode::INTERNAL_ERROR,
1576                ..
1577            })
1578        ));
1579
1580        // this should error out, specifically with an ToolError::NotFound
1581        // this client doesn't exist
1582        let invalid_tool_call = CallToolRequestParam {
1583            name: "_client__tools".to_string().into(),
1584            arguments: Some(object!({})),
1585        };
1586
1587        let result = extension_manager
1588            .dispatch_tool_call(invalid_tool_call, CancellationToken::default())
1589            .await;
1590        if let Err(err) = result {
1591            let tool_err = err.downcast_ref::<ErrorData>().expect("Expected ErrorData");
1592            assert_eq!(tool_err.code, ErrorCode::RESOURCE_NOT_FOUND);
1593        } else {
1594            panic!("Expected ErrorData with ErrorCode::RESOURCE_NOT_FOUND");
1595        }
1596    }
1597
1598    #[tokio::test]
1599    async fn test_tool_availability_filtering() {
1600        let extension_manager = ExtensionManager::new_without_provider();
1601
1602        // Only "available_tool" should be available to the LLM
1603        let available_tools = vec!["available_tool".to_string()];
1604
1605        extension_manager
1606            .add_mock_extension_with_tools(
1607                "test_extension".to_string(),
1608                Arc::new(Mutex::new(Box::new(MockClient {}))),
1609                available_tools,
1610            )
1611            .await;
1612
1613        let tools = extension_manager.get_prefixed_tools(None).await.unwrap();
1614
1615        let tool_names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
1616        assert!(!tool_names.iter().any(|name| name == "test_extension__tool")); // Default unavailable
1617        assert!(tool_names
1618            .iter()
1619            .any(|name| name == "test_extension__available_tool"));
1620        assert!(!tool_names
1621            .iter()
1622            .any(|name| name == "test_extension__hidden_tool"));
1623        assert!(tool_names.len() == 1);
1624    }
1625
1626    #[tokio::test]
1627    async fn test_tool_availability_defaults_to_available() {
1628        let extension_manager = ExtensionManager::new_without_provider();
1629
1630        extension_manager
1631            .add_mock_extension_with_tools(
1632                "test_extension".to_string(),
1633                Arc::new(Mutex::new(Box::new(MockClient {}))),
1634                vec![], // Empty available_tools means all tools are available by default
1635            )
1636            .await;
1637
1638        let tools = extension_manager.get_prefixed_tools(None).await.unwrap();
1639
1640        let tool_names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
1641        assert!(tool_names.iter().any(|name| name == "test_extension__tool"));
1642        assert!(tool_names
1643            .iter()
1644            .any(|name| name == "test_extension__available_tool"));
1645        assert!(tool_names
1646            .iter()
1647            .any(|name| name == "test_extension__hidden_tool"));
1648        assert!(tool_names.len() == 3);
1649    }
1650
1651    #[tokio::test]
1652    async fn test_dispatch_unavailable_tool_returns_error() {
1653        let extension_manager = ExtensionManager::new_without_provider();
1654
1655        let available_tools = vec!["available_tool".to_string()];
1656
1657        extension_manager
1658            .add_mock_extension_with_tools(
1659                "test_extension".to_string(),
1660                Arc::new(Mutex::new(Box::new(MockClient {}))),
1661                available_tools,
1662            )
1663            .await;
1664
1665        // Try to call an unavailable tool
1666        let unavailable_tool_call = CallToolRequestParam {
1667            name: "test_extension__tool".to_string().into(),
1668            arguments: Some(object!({})),
1669        };
1670
1671        let result = extension_manager
1672            .dispatch_tool_call(unavailable_tool_call, CancellationToken::default())
1673            .await;
1674
1675        // Should return RESOURCE_NOT_FOUND error
1676        if let Err(err) = result {
1677            let tool_err = err.downcast_ref::<ErrorData>().expect("Expected ErrorData");
1678            assert_eq!(tool_err.code, ErrorCode::RESOURCE_NOT_FOUND);
1679            assert!(tool_err.message.contains("is not available"));
1680        } else {
1681            panic!("Expected ErrorData with ErrorCode::RESOURCE_NOT_FOUND");
1682        }
1683
1684        // Try to call an available tool - should succeed
1685        let available_tool_call = CallToolRequestParam {
1686            name: "test_extension__available_tool".to_string().into(),
1687            arguments: Some(object!({})),
1688        };
1689
1690        let result = extension_manager
1691            .dispatch_tool_call(available_tool_call, CancellationToken::default())
1692            .await;
1693
1694        assert!(result.is_ok());
1695    }
1696
1697    #[tokio::test]
1698    async fn test_streamable_http_header_env_substitution() {
1699        let mut env_map = HashMap::new();
1700        env_map.insert("AUTH_TOKEN".to_string(), "secret123".to_string());
1701        env_map.insert("API_KEY".to_string(), "key456".to_string());
1702
1703        // Test ${VAR} syntax
1704        let result = substitute_env_vars("Bearer ${ AUTH_TOKEN }", &env_map);
1705        assert_eq!(result, "Bearer secret123");
1706
1707        // Test ${VAR} syntax without spaces
1708        let result = substitute_env_vars("Bearer ${AUTH_TOKEN}", &env_map);
1709        assert_eq!(result, "Bearer secret123");
1710
1711        // Test $VAR syntax
1712        let result = substitute_env_vars("Bearer $AUTH_TOKEN", &env_map);
1713        assert_eq!(result, "Bearer secret123");
1714
1715        // Test multiple substitutions
1716        let result = substitute_env_vars("Key: $API_KEY, Token: ${AUTH_TOKEN}", &env_map);
1717        assert_eq!(result, "Key: key456, Token: secret123");
1718
1719        // Test no substitution when variable doesn't exist
1720        let result = substitute_env_vars("Bearer ${UNKNOWN_VAR}", &env_map);
1721        assert_eq!(result, "Bearer ${UNKNOWN_VAR}");
1722
1723        // Test mixed content
1724        let result = substitute_env_vars(
1725            "Authorization: Bearer ${AUTH_TOKEN} and API ${API_KEY}",
1726            &env_map,
1727        );
1728        assert_eq!(result, "Authorization: Bearer secret123 and API key456");
1729    }
1730
1731    mod generate_extension_name_tests {
1732        use super::*;
1733        use rmcp::model::Implementation;
1734        use test_case::test_case;
1735
1736        fn make_info(name: &str) -> ServerInfo {
1737            ServerInfo {
1738                server_info: Implementation {
1739                    name: name.into(),
1740                    ..Default::default()
1741                },
1742                ..Default::default()
1743            }
1744        }
1745
1746        #[test_case(Some("kiwi-mcp-server"), None, "^kiwi-mcp-server$" ; "already normalized server name")]
1747        #[test_case(Some("Context7"), None, "^context7$" ; "mixed case normalized")]
1748        #[test_case(Some("@huggingface/mcp-services"), None, "^_huggingface_mcp-services$" ; "special chars normalized")]
1749        #[test_case(None, None, "^unnamed$" ; "no server info falls back")]
1750        #[test_case(Some(""), None, "^unnamed$" ; "empty server name falls back")]
1751        #[test_case(Some("github-mcp-server"), Some("github-mcp-server"), r"^github-mcp-server_[A-Za-z0-9]{6}$" ; "duplicate adds suffix")]
1752        fn test_generate_name(server_name: Option<&str>, collision: Option<&str>, expected: &str) {
1753            let info = server_name.map(make_info);
1754            let result = generate_extension_name(info.as_ref(), |n| collision == Some(n));
1755            let re = regex::Regex::new(expected).unwrap();
1756            assert!(re.is_match(&result));
1757        }
1758    }
1759
1760    #[tokio::test]
1761    async fn test_collect_moim_uses_minute_granularity() {
1762        let em = ExtensionManager::new_without_provider();
1763
1764        if let Some(moim) = em.collect_moim().await {
1765            // Timestamp should end with :00 (seconds fixed to 00)
1766            assert!(
1767                moim.contains(":00\n"),
1768                "Timestamp should use minute granularity"
1769            );
1770        }
1771    }
1772}