Skip to main content

aster/agents/
extension.rs

1use crate::agents::chatrecall_extension;
2use crate::agents::code_execution_extension;
3use crate::agents::extension_manager_extension;
4use crate::agents::skills_extension;
5use crate::agents::todo_extension;
6use std::collections::HashMap;
7
8use crate::agents::mcp_client::McpClientTrait;
9use crate::config;
10use crate::config::extensions::name_to_key;
11use crate::config::permission::PermissionLevel;
12use once_cell::sync::Lazy;
13use rmcp::model::Tool;
14use rmcp::service::ClientInitializeError;
15use rmcp::ServiceError as ClientError;
16use serde::Deserializer;
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19use tracing::warn;
20use utoipa::ToSchema;
21
22#[derive(Error, Debug)]
23#[error("process quit before initialization: stderr = {stderr}")]
24pub struct ProcessExit {
25    stderr: String,
26    #[source]
27    source: ClientInitializeError,
28}
29
30impl ProcessExit {
31    pub fn new<T>(stderr: T, source: ClientInitializeError) -> Self
32    where
33        T: Into<String>,
34    {
35        ProcessExit {
36            stderr: stderr.into(),
37            source,
38        }
39    }
40}
41
42pub static PLATFORM_EXTENSIONS: Lazy<HashMap<&'static str, PlatformExtensionDef>> = Lazy::new(
43    || {
44        let mut map = HashMap::new();
45
46        map.insert(
47            todo_extension::EXTENSION_NAME,
48            PlatformExtensionDef {
49                name: todo_extension::EXTENSION_NAME,
50                description:
51                    "Enable a todo list for aster so it can keep track of what it is doing",
52                default_enabled: true,
53                client_factory: |ctx| Box::new(todo_extension::TodoClient::new(ctx).unwrap()),
54            },
55        );
56
57        map.insert(
58            chatrecall_extension::EXTENSION_NAME,
59            PlatformExtensionDef {
60                name: chatrecall_extension::EXTENSION_NAME,
61                description:
62                    "Search past conversations and load session summaries for contextual memory",
63                default_enabled: false,
64                client_factory: |ctx| {
65                    Box::new(chatrecall_extension::ChatRecallClient::new(ctx).unwrap())
66                },
67            },
68        );
69
70        map.insert(
71            "extensionmanager",
72            PlatformExtensionDef {
73                name: extension_manager_extension::EXTENSION_NAME,
74                description:
75                    "Enable extension management tools for discovering, enabling, and disabling extensions",
76                default_enabled: true,
77                client_factory: |ctx| Box::new(extension_manager_extension::ExtensionManagerClient::new(ctx).unwrap()),
78            },
79        );
80
81        map.insert(
82            skills_extension::EXTENSION_NAME,
83            PlatformExtensionDef {
84                name: skills_extension::EXTENSION_NAME,
85                description: "Load and use skills from relevant directories",
86                default_enabled: true,
87                client_factory: |ctx| Box::new(skills_extension::SkillsClient::new(ctx).unwrap()),
88            },
89        );
90
91        map.insert(
92            code_execution_extension::EXTENSION_NAME,
93            PlatformExtensionDef {
94                name: code_execution_extension::EXTENSION_NAME,
95                description: "Execute JavaScript code in a sandboxed environment",
96                default_enabled: false,
97                client_factory: |ctx| {
98                    Box::new(code_execution_extension::CodeExecutionClient::new(ctx).unwrap())
99                },
100            },
101        );
102
103        map
104    },
105);
106
107#[derive(Clone)]
108pub struct PlatformExtensionContext {
109    pub session_id: Option<String>,
110    pub extension_manager:
111        Option<std::sync::Weak<crate::agents::extension_manager::ExtensionManager>>,
112}
113
114#[derive(Debug, Clone)]
115pub struct PlatformExtensionDef {
116    pub name: &'static str,
117    pub description: &'static str,
118    pub default_enabled: bool,
119    pub client_factory: fn(PlatformExtensionContext) -> Box<dyn McpClientTrait>,
120}
121
122/// Errors from Extension operation
123#[derive(Error, Debug)]
124pub enum ExtensionError {
125    #[error("failed a client call to an MCP server: {0}")]
126    Client(#[from] ClientError),
127    #[error("invalid config: {0}")]
128    ConfigError(String),
129    #[error("error during extension setup: {0}")]
130    SetupError(String),
131    #[error("join error occurred during task execution: {0}")]
132    TaskJoinError(#[from] tokio::task::JoinError),
133    #[error("IO error: {0}")]
134    IoError(#[from] std::io::Error),
135    #[error("failed to initialize MCP client: {0}")]
136    InitializeError(#[from] ClientInitializeError),
137    #[error("{0}")]
138    ProcessExit(#[from] ProcessExit),
139}
140
141pub type ExtensionResult<T> = Result<T, ExtensionError>;
142
143#[derive(Debug, Clone, Deserialize, Serialize, Default, ToSchema, PartialEq)]
144pub struct Envs {
145    /// A map of environment variables to set, e.g. API_KEY -> some_secret, HOST -> host
146    #[serde(default)]
147    #[serde(flatten)]
148    map: HashMap<String, String>,
149}
150
151impl Envs {
152    /// List of sensitive env vars that should not be overridden
153    const DISALLOWED_KEYS: [&'static str; 31] = [
154        // 🔧 Binary path manipulation
155        "PATH",       // Controls executable lookup paths — critical for command hijacking
156        "PATHEXT",    // Windows: Determines recognized executable extensions (e.g., .exe, .bat)
157        "SystemRoot", // Windows: Can affect system DLL resolution (e.g., `kernel32.dll`)
158        "windir",     // Windows: Alternative to SystemRoot (used in legacy apps)
159        // 🧬 Dynamic linker hijacking (Linux/macOS)
160        "LD_LIBRARY_PATH",  // Alters shared library resolution
161        "LD_PRELOAD",       // Forces preloading of shared libraries — common attack vector
162        "LD_AUDIT",         // Loads a monitoring library that can intercept execution
163        "LD_DEBUG",         // Enables verbose linker logging (information disclosure risk)
164        "LD_BIND_NOW",      // Forces immediate symbol resolution, affecting ASLR
165        "LD_ASSUME_KERNEL", // Tricks linker into thinking it's running on an older kernel
166        // 🍎 macOS dynamic linker variables
167        "DYLD_LIBRARY_PATH",     // Same as LD_LIBRARY_PATH but for macOS
168        "DYLD_INSERT_LIBRARIES", // macOS equivalent of LD_PRELOAD
169        "DYLD_FRAMEWORK_PATH",   // Overrides framework lookup paths
170        // 🐍 Python / Node / Ruby / Java / Golang hijacking
171        "PYTHONPATH",   // Overrides Python module resolution
172        "PYTHONHOME",   // Overrides Python root directory
173        "NODE_OPTIONS", // Injects options/scripts into every Node.js process
174        "RUBYOPT",      // Injects Ruby execution flags
175        "GEM_PATH",     // Alters where RubyGems looks for installed packages
176        "GEM_HOME",     // Changes RubyGems default install location
177        "CLASSPATH",    // Java: Controls where classes are loaded from — critical for RCE attacks
178        "GO111MODULE",  // Go: Forces use of module proxy or disables it
179        "GOROOT", // Go: Changes root installation directory (could lead to execution hijacking)
180        // 🖥️ Windows-specific process & DLL hijacking
181        "APPINIT_DLLS", // Forces Windows to load a DLL into every process
182        "SESSIONNAME",  // Affects Windows session configuration
183        "ComSpec",      // Determines default command interpreter (can replace `cmd.exe`)
184        "TEMP",
185        "TMP",          // Redirects temporary file storage (useful for injection attacks)
186        "LOCALAPPDATA", // Controls application data paths (can be abused for persistence)
187        "USERPROFILE",  // Windows user directory (can affect profile-based execution paths)
188        "HOMEDRIVE",
189        "HOMEPATH", // Changes where the user's home directory is located
190    ];
191
192    /// Constructs a new Envs, skipping disallowed env vars with a warning
193    pub fn new(map: HashMap<String, String>) -> Self {
194        let mut validated = HashMap::new();
195
196        for (key, value) in map {
197            if Self::is_disallowed(&key) {
198                warn!("Skipping disallowed env var: {}", key);
199                continue;
200            }
201            validated.insert(key, value);
202        }
203
204        Self { map: validated }
205    }
206
207    /// Returns a copy of the validated env vars
208    pub fn get_env(&self) -> HashMap<String, String> {
209        self.map.clone()
210    }
211
212    /// Returns an error if any disallowed env var is present
213    pub fn validate(&self) -> Result<(), Box<ExtensionError>> {
214        for key in self.map.keys() {
215            if Self::is_disallowed(key) {
216                return Err(Box::new(ExtensionError::ConfigError(format!(
217                    "environment variable {} not allowed to be overwritten",
218                    key
219                ))));
220            }
221        }
222        Ok(())
223    }
224
225    fn is_disallowed(key: &str) -> bool {
226        Self::DISALLOWED_KEYS
227            .iter()
228            .any(|disallowed| disallowed.eq_ignore_ascii_case(key))
229    }
230}
231
232/// Represents the different types of MCP extensions that can be added to the manager
233#[derive(Debug, Clone, Deserialize, Serialize, ToSchema, PartialEq)]
234#[serde(tag = "type")]
235pub enum ExtensionConfig {
236    /// SSE transport is no longer supported - kept only for config file compatibility
237    #[serde(rename = "sse")]
238    Sse {
239        #[serde(default)]
240        #[schema(required)]
241        name: String,
242        #[serde(default)]
243        #[serde(deserialize_with = "deserialize_null_with_default")]
244        #[schema(required)]
245        description: String,
246        #[serde(default)]
247        uri: Option<String>,
248    },
249    /// Standard I/O client with command and arguments
250    #[serde(rename = "stdio")]
251    Stdio {
252        /// The name used to identify this extension
253        name: String,
254        #[serde(default)]
255        #[serde(deserialize_with = "deserialize_null_with_default")]
256        #[schema(required)]
257        description: String,
258        cmd: String,
259        args: Vec<String>,
260        #[serde(default)]
261        envs: Envs,
262        #[serde(default)]
263        env_keys: Vec<String>,
264        timeout: Option<u64>,
265        #[serde(default)]
266        bundled: Option<bool>,
267        #[serde(default)]
268        available_tools: Vec<String>,
269    },
270    /// Built-in extension that is part of the bundled aster MCP server
271    #[serde(rename = "builtin")]
272    Builtin {
273        /// The name used to identify this extension
274        name: String,
275        #[serde(default)]
276        #[serde(deserialize_with = "deserialize_null_with_default")]
277        #[schema(required)]
278        description: String,
279        display_name: Option<String>, // needed for the UI
280        timeout: Option<u64>,
281        #[serde(default)]
282        bundled: Option<bool>,
283        #[serde(default)]
284        available_tools: Vec<String>,
285    },
286    /// Platform extensions that have direct access to the agent etc and run in the agent process
287    #[serde(rename = "platform")]
288    Platform {
289        /// The name used to identify this extension
290        name: String,
291        #[serde(deserialize_with = "deserialize_null_with_default")]
292        #[schema(required)]
293        description: String,
294        #[serde(default)]
295        bundled: Option<bool>,
296        #[serde(default)]
297        available_tools: Vec<String>,
298    },
299    /// Streamable HTTP client with a URI endpoint using MCP Streamable HTTP specification
300    #[serde(rename = "streamable_http")]
301    StreamableHttp {
302        /// The name used to identify this extension
303        name: String,
304        #[serde(deserialize_with = "deserialize_null_with_default")]
305        #[schema(required)]
306        description: String,
307        uri: String,
308        #[serde(default)]
309        envs: Envs,
310        #[serde(default)]
311        env_keys: Vec<String>,
312        #[serde(default)]
313        headers: HashMap<String, String>,
314        // NOTE: set timeout to be optional for compatibility.
315        // However, new configurations should include this field.
316        timeout: Option<u64>,
317        #[serde(default)]
318        bundled: Option<bool>,
319        #[serde(default)]
320        available_tools: Vec<String>,
321    },
322    /// Frontend-provided tools that will be called through the frontend
323    #[serde(rename = "frontend")]
324    Frontend {
325        /// The name used to identify this extension
326        name: String,
327        #[serde(deserialize_with = "deserialize_null_with_default")]
328        #[schema(required)]
329        description: String,
330        /// The tools provided by the frontend
331        tools: Vec<Tool>,
332        /// Instructions for how to use these tools
333        instructions: Option<String>,
334        #[serde(default)]
335        bundled: Option<bool>,
336        #[serde(default)]
337        available_tools: Vec<String>,
338    },
339    /// Inline Python code that will be executed using uvx
340    #[serde(rename = "inline_python")]
341    InlinePython {
342        /// The name used to identify this extension
343        name: String,
344        #[serde(deserialize_with = "deserialize_null_with_default")]
345        #[schema(required)]
346        description: String,
347        /// The Python code to execute
348        code: String,
349        /// Timeout in seconds
350        timeout: Option<u64>,
351        /// Python package dependencies required by this extension
352        #[serde(default)]
353        dependencies: Option<Vec<String>>,
354        #[serde(default)]
355        available_tools: Vec<String>,
356    },
357}
358
359impl Default for ExtensionConfig {
360    fn default() -> Self {
361        Self::Builtin {
362            name: config::DEFAULT_EXTENSION.to_string(),
363            display_name: Some(config::DEFAULT_DISPLAY_NAME.to_string()),
364            description: "default".to_string(),
365            timeout: Some(config::DEFAULT_EXTENSION_TIMEOUT),
366            bundled: Some(true),
367            available_tools: Vec::new(),
368        }
369    }
370}
371
372impl ExtensionConfig {
373    pub fn streamable_http<S: Into<String>, T: Into<u64>>(
374        name: S,
375        uri: S,
376        description: S,
377        timeout: T,
378    ) -> Self {
379        Self::StreamableHttp {
380            name: name.into(),
381            uri: uri.into(),
382            envs: Envs::default(),
383            env_keys: Vec::new(),
384            headers: HashMap::new(),
385            description: description.into(),
386            timeout: Some(timeout.into()),
387            bundled: None,
388            available_tools: Vec::new(),
389        }
390    }
391
392    pub fn stdio<S: Into<String>, T: Into<u64>>(
393        name: S,
394        cmd: S,
395        description: S,
396        timeout: T,
397    ) -> Self {
398        Self::Stdio {
399            name: name.into(),
400            cmd: cmd.into(),
401            args: vec![],
402            envs: Envs::default(),
403            env_keys: Vec::new(),
404            description: description.into(),
405            timeout: Some(timeout.into()),
406            bundled: None,
407            available_tools: Vec::new(),
408        }
409    }
410
411    pub fn inline_python<S: Into<String>, T: Into<u64>>(
412        name: S,
413        code: S,
414        description: S,
415        timeout: T,
416    ) -> Self {
417        Self::InlinePython {
418            name: name.into(),
419            code: code.into(),
420            description: description.into(),
421            timeout: Some(timeout.into()),
422            dependencies: None,
423            available_tools: Vec::new(),
424        }
425    }
426
427    pub fn with_args<I, S>(self, args: I) -> Self
428    where
429        I: IntoIterator<Item = S>,
430        S: Into<String>,
431    {
432        match self {
433            Self::Stdio {
434                name,
435                cmd,
436                envs,
437                env_keys,
438                timeout,
439                description,
440                bundled,
441                available_tools,
442                ..
443            } => Self::Stdio {
444                name,
445                cmd,
446                envs,
447                env_keys,
448                args: args.into_iter().map(Into::into).collect(),
449                description,
450                timeout,
451                bundled,
452                available_tools,
453            },
454            other => other,
455        }
456    }
457
458    pub fn key(&self) -> String {
459        let name = self.name();
460        name_to_key(&name)
461    }
462
463    /// Get the extension name regardless of variant
464    pub fn name(&self) -> String {
465        match self {
466            Self::Sse { name, .. } => name,
467            Self::StreamableHttp { name, .. } => name,
468            Self::Stdio { name, .. } => name,
469            Self::Builtin { name, .. } => name,
470            Self::Platform { name, .. } => name,
471            Self::Frontend { name, .. } => name,
472            Self::InlinePython { name, .. } => name,
473        }
474        .to_string()
475    }
476
477    /// Check if a tool should be available to the LLM
478    pub fn is_tool_available(&self, tool_name: &str) -> bool {
479        let available_tools = match self {
480            Self::Sse { .. } => return false, // SSE is unsupported
481            Self::StreamableHttp {
482                available_tools, ..
483            }
484            | Self::Stdio {
485                available_tools, ..
486            }
487            | Self::Builtin {
488                available_tools, ..
489            }
490            | Self::Platform {
491                available_tools, ..
492            }
493            | Self::InlinePython {
494                available_tools, ..
495            }
496            | Self::Frontend {
497                available_tools, ..
498            } => available_tools,
499        };
500
501        // If no tools are specified, all tools are available
502        // If tools are specified, only those tools are available
503        available_tools.is_empty() || available_tools.contains(&tool_name.to_string())
504    }
505}
506
507impl std::fmt::Display for ExtensionConfig {
508    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509        match self {
510            ExtensionConfig::Sse { name, .. } => {
511                write!(f, "SSE({}: unsupported)", name)
512            }
513            ExtensionConfig::StreamableHttp { name, uri, .. } => {
514                write!(f, "StreamableHttp({}: {})", name, uri)
515            }
516            ExtensionConfig::Stdio {
517                name, cmd, args, ..
518            } => {
519                write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
520            }
521            ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
522            ExtensionConfig::Platform { name, .. } => write!(f, "Platform({})", name),
523            ExtensionConfig::Frontend { name, tools, .. } => {
524                write!(f, "Frontend({}: {} tools)", name, tools.len())
525            }
526            ExtensionConfig::InlinePython { name, code, .. } => {
527                write!(f, "InlinePython({}: {} chars)", name, code.len())
528            }
529        }
530    }
531}
532
533/// Information about the extension used for building prompts
534#[derive(Clone, Debug, Serialize)]
535pub struct ExtensionInfo {
536    pub name: String,
537    pub instructions: String,
538    pub has_resources: bool,
539}
540
541impl ExtensionInfo {
542    pub fn new(name: &str, instructions: &str, has_resources: bool) -> Self {
543        Self {
544            name: name.to_string(),
545            instructions: instructions.to_string(),
546            has_resources,
547        }
548    }
549}
550
551fn deserialize_null_with_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
552where
553    T: Default + Deserialize<'de>,
554    D: Deserializer<'de>,
555{
556    let opt = Option::deserialize(deserializer)?;
557    Ok(opt.unwrap_or_default())
558}
559
560/// Information about the tool used for building prompts
561#[derive(Clone, Debug, Serialize, ToSchema)]
562pub struct ToolInfo {
563    pub name: String,
564    pub description: String,
565    pub parameters: Vec<String>,
566    pub permission: Option<PermissionLevel>,
567}
568
569impl ToolInfo {
570    pub fn new(
571        name: &str,
572        description: &str,
573        parameters: Vec<String>,
574        permission: Option<PermissionLevel>,
575    ) -> Self {
576        Self {
577            name: name.to_string(),
578            description: description.to_string(),
579            parameters,
580            permission,
581        }
582    }
583}
584
#[cfg(test)]
mod tests {
    use crate::agents::*;

    /// Parses `yaml` into an [`ExtensionConfig`] and returns the description of the
    /// resulting `Builtin` variant; panics if parsing fails or another variant results.
    fn builtin_description(yaml: &str) -> String {
        let config: ExtensionConfig = serde_yaml::from_str(yaml).unwrap();
        match config {
            ExtensionConfig::Builtin { description, .. } => description,
            other => panic!("unexpected result of deserialization: {}", other),
        }
    }

    #[test]
    fn test_deserialize_missing_description() {
        let description = builtin_description(
            "enabled: true
type: builtin
name: developer
display_name: Developer
timeout: 300
bundled: true
available_tools: []",
        );
        assert_eq!(description, "");
    }

    #[test]
    fn test_deserialize_null_description() {
        let description = builtin_description(
            "enabled: true
type: builtin
name: developer
display_name: Developer
description: null
timeout: 300
bundled: true
available_tools: []
",
        );
        assert_eq!(description, "");
    }

    #[test]
    fn test_deserialize_normal_description() {
        let description = builtin_description(
            "enabled: true
type: builtin
name: developer
display_name: Developer
description: description goes here
timeout: 300
bundled: true
available_tools: []
    ",
        );
        assert_eq!(description, "description goes here");
    }
}