1use crate::agents::chatrecall_extension;
2use crate::agents::code_execution_extension;
3use crate::agents::extension_manager_extension;
4use crate::agents::skills_extension;
5use crate::agents::todo_extension;
6use std::collections::HashMap;
7
8use crate::agents::mcp_client::McpClientTrait;
9use crate::config;
10use crate::config::extensions::name_to_key;
11use crate::config::permission::PermissionLevel;
12use once_cell::sync::Lazy;
13use rmcp::model::Tool;
14use rmcp::service::ClientInitializeError;
15use rmcp::ServiceError as ClientError;
16use serde::Deserializer;
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19use tracing::warn;
20use utoipa::ToSchema;
21
22#[derive(Error, Debug)]
23#[error("process quit before initialization: stderr = {stderr}")]
24pub struct ProcessExit {
25 stderr: String,
26 #[source]
27 source: ClientInitializeError,
28}
29
30impl ProcessExit {
31 pub fn new<T>(stderr: T, source: ClientInitializeError) -> Self
32 where
33 T: Into<String>,
34 {
35 ProcessExit {
36 stderr: stderr.into(),
37 source,
38 }
39 }
40}
41
42pub static PLATFORM_EXTENSIONS: Lazy<HashMap<&'static str, PlatformExtensionDef>> = Lazy::new(
43 || {
44 let mut map = HashMap::new();
45
46 map.insert(
47 todo_extension::EXTENSION_NAME,
48 PlatformExtensionDef {
49 name: todo_extension::EXTENSION_NAME,
50 description:
51 "Enable a todo list for aster so it can keep track of what it is doing",
52 default_enabled: true,
53 client_factory: |ctx| Box::new(todo_extension::TodoClient::new(ctx).unwrap()),
54 },
55 );
56
57 map.insert(
58 chatrecall_extension::EXTENSION_NAME,
59 PlatformExtensionDef {
60 name: chatrecall_extension::EXTENSION_NAME,
61 description:
62 "Search past conversations and load session summaries for contextual memory",
63 default_enabled: false,
64 client_factory: |ctx| {
65 Box::new(chatrecall_extension::ChatRecallClient::new(ctx).unwrap())
66 },
67 },
68 );
69
70 map.insert(
71 "extensionmanager",
72 PlatformExtensionDef {
73 name: extension_manager_extension::EXTENSION_NAME,
74 description:
75 "Enable extension management tools for discovering, enabling, and disabling extensions",
76 default_enabled: true,
77 client_factory: |ctx| Box::new(extension_manager_extension::ExtensionManagerClient::new(ctx).unwrap()),
78 },
79 );
80
81 map.insert(
82 skills_extension::EXTENSION_NAME,
83 PlatformExtensionDef {
84 name: skills_extension::EXTENSION_NAME,
85 description: "Load and use skills from relevant directories",
86 default_enabled: true,
87 client_factory: |ctx| Box::new(skills_extension::SkillsClient::new(ctx).unwrap()),
88 },
89 );
90
91 map.insert(
92 code_execution_extension::EXTENSION_NAME,
93 PlatformExtensionDef {
94 name: code_execution_extension::EXTENSION_NAME,
95 description: "Execute JavaScript code in a sandboxed environment",
96 default_enabled: false,
97 client_factory: |ctx| {
98 Box::new(code_execution_extension::CodeExecutionClient::new(ctx).unwrap())
99 },
100 },
101 );
102
103 map
104 },
105);
106
107#[derive(Clone)]
108pub struct PlatformExtensionContext {
109 pub session_id: Option<String>,
110 pub extension_manager:
111 Option<std::sync::Weak<crate::agents::extension_manager::ExtensionManager>>,
112}
113
114#[derive(Debug, Clone)]
115pub struct PlatformExtensionDef {
116 pub name: &'static str,
117 pub description: &'static str,
118 pub default_enabled: bool,
119 pub client_factory: fn(PlatformExtensionContext) -> Box<dyn McpClientTrait>,
120}
121
122#[derive(Error, Debug)]
124pub enum ExtensionError {
125 #[error("failed a client call to an MCP server: {0}")]
126 Client(#[from] ClientError),
127 #[error("invalid config: {0}")]
128 ConfigError(String),
129 #[error("error during extension setup: {0}")]
130 SetupError(String),
131 #[error("join error occurred during task execution: {0}")]
132 TaskJoinError(#[from] tokio::task::JoinError),
133 #[error("IO error: {0}")]
134 IoError(#[from] std::io::Error),
135 #[error("failed to initialize MCP client: {0}")]
136 InitializeError(#[from] ClientInitializeError),
137 #[error("{0}")]
138 ProcessExit(#[from] ProcessExit),
139}
140
141pub type ExtensionResult<T> = Result<T, ExtensionError>;
142
143#[derive(Debug, Clone, Deserialize, Serialize, Default, ToSchema, PartialEq)]
144pub struct Envs {
145 #[serde(default)]
147 #[serde(flatten)]
148 map: HashMap<String, String>,
149}
150
151impl Envs {
152 const DISALLOWED_KEYS: [&'static str; 31] = [
154 "PATH", "PATHEXT", "SystemRoot", "windir", "LD_LIBRARY_PATH", "LD_PRELOAD", "LD_AUDIT", "LD_DEBUG", "LD_BIND_NOW", "LD_ASSUME_KERNEL", "DYLD_LIBRARY_PATH", "DYLD_INSERT_LIBRARIES", "DYLD_FRAMEWORK_PATH", "PYTHONPATH", "PYTHONHOME", "NODE_OPTIONS", "RUBYOPT", "GEM_PATH", "GEM_HOME", "CLASSPATH", "GO111MODULE", "GOROOT", "APPINIT_DLLS", "SESSIONNAME", "ComSpec", "TEMP",
185 "TMP", "LOCALAPPDATA", "USERPROFILE", "HOMEDRIVE",
189 "HOMEPATH", ];
191
192 pub fn new(map: HashMap<String, String>) -> Self {
194 let mut validated = HashMap::new();
195
196 for (key, value) in map {
197 if Self::is_disallowed(&key) {
198 warn!("Skipping disallowed env var: {}", key);
199 continue;
200 }
201 validated.insert(key, value);
202 }
203
204 Self { map: validated }
205 }
206
207 pub fn get_env(&self) -> HashMap<String, String> {
209 self.map.clone()
210 }
211
212 pub fn validate(&self) -> Result<(), Box<ExtensionError>> {
214 for key in self.map.keys() {
215 if Self::is_disallowed(key) {
216 return Err(Box::new(ExtensionError::ConfigError(format!(
217 "environment variable {} not allowed to be overwritten",
218 key
219 ))));
220 }
221 }
222 Ok(())
223 }
224
225 fn is_disallowed(key: &str) -> bool {
226 Self::DISALLOWED_KEYS
227 .iter()
228 .any(|disallowed| disallowed.eq_ignore_ascii_case(key))
229 }
230}
231
232#[derive(Debug, Clone, Deserialize, Serialize, ToSchema, PartialEq)]
234#[serde(tag = "type")]
235pub enum ExtensionConfig {
236 #[serde(rename = "sse")]
238 Sse {
239 #[serde(default)]
240 #[schema(required)]
241 name: String,
242 #[serde(default)]
243 #[serde(deserialize_with = "deserialize_null_with_default")]
244 #[schema(required)]
245 description: String,
246 #[serde(default)]
247 uri: Option<String>,
248 },
249 #[serde(rename = "stdio")]
251 Stdio {
252 name: String,
254 #[serde(default)]
255 #[serde(deserialize_with = "deserialize_null_with_default")]
256 #[schema(required)]
257 description: String,
258 cmd: String,
259 args: Vec<String>,
260 #[serde(default)]
261 envs: Envs,
262 #[serde(default)]
263 env_keys: Vec<String>,
264 timeout: Option<u64>,
265 #[serde(default)]
266 bundled: Option<bool>,
267 #[serde(default)]
268 available_tools: Vec<String>,
269 },
270 #[serde(rename = "builtin")]
272 Builtin {
273 name: String,
275 #[serde(default)]
276 #[serde(deserialize_with = "deserialize_null_with_default")]
277 #[schema(required)]
278 description: String,
279 display_name: Option<String>, timeout: Option<u64>,
281 #[serde(default)]
282 bundled: Option<bool>,
283 #[serde(default)]
284 available_tools: Vec<String>,
285 },
286 #[serde(rename = "platform")]
288 Platform {
289 name: String,
291 #[serde(deserialize_with = "deserialize_null_with_default")]
292 #[schema(required)]
293 description: String,
294 #[serde(default)]
295 bundled: Option<bool>,
296 #[serde(default)]
297 available_tools: Vec<String>,
298 },
299 #[serde(rename = "streamable_http")]
301 StreamableHttp {
302 name: String,
304 #[serde(deserialize_with = "deserialize_null_with_default")]
305 #[schema(required)]
306 description: String,
307 uri: String,
308 #[serde(default)]
309 envs: Envs,
310 #[serde(default)]
311 env_keys: Vec<String>,
312 #[serde(default)]
313 headers: HashMap<String, String>,
314 timeout: Option<u64>,
317 #[serde(default)]
318 bundled: Option<bool>,
319 #[serde(default)]
320 available_tools: Vec<String>,
321 },
322 #[serde(rename = "frontend")]
324 Frontend {
325 name: String,
327 #[serde(deserialize_with = "deserialize_null_with_default")]
328 #[schema(required)]
329 description: String,
330 tools: Vec<Tool>,
332 instructions: Option<String>,
334 #[serde(default)]
335 bundled: Option<bool>,
336 #[serde(default)]
337 available_tools: Vec<String>,
338 },
339 #[serde(rename = "inline_python")]
341 InlinePython {
342 name: String,
344 #[serde(deserialize_with = "deserialize_null_with_default")]
345 #[schema(required)]
346 description: String,
347 code: String,
349 timeout: Option<u64>,
351 #[serde(default)]
353 dependencies: Option<Vec<String>>,
354 #[serde(default)]
355 available_tools: Vec<String>,
356 },
357}
358
359impl Default for ExtensionConfig {
360 fn default() -> Self {
361 Self::Builtin {
362 name: config::DEFAULT_EXTENSION.to_string(),
363 display_name: Some(config::DEFAULT_DISPLAY_NAME.to_string()),
364 description: "default".to_string(),
365 timeout: Some(config::DEFAULT_EXTENSION_TIMEOUT),
366 bundled: Some(true),
367 available_tools: Vec::new(),
368 }
369 }
370}
371
372impl ExtensionConfig {
373 pub fn streamable_http<S: Into<String>, T: Into<u64>>(
374 name: S,
375 uri: S,
376 description: S,
377 timeout: T,
378 ) -> Self {
379 Self::StreamableHttp {
380 name: name.into(),
381 uri: uri.into(),
382 envs: Envs::default(),
383 env_keys: Vec::new(),
384 headers: HashMap::new(),
385 description: description.into(),
386 timeout: Some(timeout.into()),
387 bundled: None,
388 available_tools: Vec::new(),
389 }
390 }
391
392 pub fn stdio<S: Into<String>, T: Into<u64>>(
393 name: S,
394 cmd: S,
395 description: S,
396 timeout: T,
397 ) -> Self {
398 Self::Stdio {
399 name: name.into(),
400 cmd: cmd.into(),
401 args: vec![],
402 envs: Envs::default(),
403 env_keys: Vec::new(),
404 description: description.into(),
405 timeout: Some(timeout.into()),
406 bundled: None,
407 available_tools: Vec::new(),
408 }
409 }
410
411 pub fn inline_python<S: Into<String>, T: Into<u64>>(
412 name: S,
413 code: S,
414 description: S,
415 timeout: T,
416 ) -> Self {
417 Self::InlinePython {
418 name: name.into(),
419 code: code.into(),
420 description: description.into(),
421 timeout: Some(timeout.into()),
422 dependencies: None,
423 available_tools: Vec::new(),
424 }
425 }
426
427 pub fn with_args<I, S>(self, args: I) -> Self
428 where
429 I: IntoIterator<Item = S>,
430 S: Into<String>,
431 {
432 match self {
433 Self::Stdio {
434 name,
435 cmd,
436 envs,
437 env_keys,
438 timeout,
439 description,
440 bundled,
441 available_tools,
442 ..
443 } => Self::Stdio {
444 name,
445 cmd,
446 envs,
447 env_keys,
448 args: args.into_iter().map(Into::into).collect(),
449 description,
450 timeout,
451 bundled,
452 available_tools,
453 },
454 other => other,
455 }
456 }
457
458 pub fn key(&self) -> String {
459 let name = self.name();
460 name_to_key(&name)
461 }
462
463 pub fn name(&self) -> String {
465 match self {
466 Self::Sse { name, .. } => name,
467 Self::StreamableHttp { name, .. } => name,
468 Self::Stdio { name, .. } => name,
469 Self::Builtin { name, .. } => name,
470 Self::Platform { name, .. } => name,
471 Self::Frontend { name, .. } => name,
472 Self::InlinePython { name, .. } => name,
473 }
474 .to_string()
475 }
476
477 pub fn is_tool_available(&self, tool_name: &str) -> bool {
479 let available_tools = match self {
480 Self::Sse { .. } => return false, Self::StreamableHttp {
482 available_tools, ..
483 }
484 | Self::Stdio {
485 available_tools, ..
486 }
487 | Self::Builtin {
488 available_tools, ..
489 }
490 | Self::Platform {
491 available_tools, ..
492 }
493 | Self::InlinePython {
494 available_tools, ..
495 }
496 | Self::Frontend {
497 available_tools, ..
498 } => available_tools,
499 };
500
501 available_tools.is_empty() || available_tools.contains(&tool_name.to_string())
504 }
505}
506
507impl std::fmt::Display for ExtensionConfig {
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 match self {
510 ExtensionConfig::Sse { name, .. } => {
511 write!(f, "SSE({}: unsupported)", name)
512 }
513 ExtensionConfig::StreamableHttp { name, uri, .. } => {
514 write!(f, "StreamableHttp({}: {})", name, uri)
515 }
516 ExtensionConfig::Stdio {
517 name, cmd, args, ..
518 } => {
519 write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
520 }
521 ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
522 ExtensionConfig::Platform { name, .. } => write!(f, "Platform({})", name),
523 ExtensionConfig::Frontend { name, tools, .. } => {
524 write!(f, "Frontend({}: {} tools)", name, tools.len())
525 }
526 ExtensionConfig::InlinePython { name, code, .. } => {
527 write!(f, "InlinePython({}: {} chars)", name, code.len())
528 }
529 }
530 }
531}
532
533#[derive(Clone, Debug, Serialize)]
535pub struct ExtensionInfo {
536 pub name: String,
537 pub instructions: String,
538 pub has_resources: bool,
539}
540
541impl ExtensionInfo {
542 pub fn new(name: &str, instructions: &str, has_resources: bool) -> Self {
543 Self {
544 name: name.to_string(),
545 instructions: instructions.to_string(),
546 has_resources,
547 }
548 }
549}
550
551fn deserialize_null_with_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
552where
553 T: Default + Deserialize<'de>,
554 D: Deserializer<'de>,
555{
556 let opt = Option::deserialize(deserializer)?;
557 Ok(opt.unwrap_or_default())
558}
559
560#[derive(Clone, Debug, Serialize, ToSchema)]
562pub struct ToolInfo {
563 pub name: String,
564 pub description: String,
565 pub parameters: Vec<String>,
566 pub permission: Option<PermissionLevel>,
567}
568
569impl ToolInfo {
570 pub fn new(
571 name: &str,
572 description: &str,
573 parameters: Vec<String>,
574 permission: Option<PermissionLevel>,
575 ) -> Self {
576 Self {
577 name: name.to_string(),
578 description: description.to_string(),
579 parameters,
580 permission,
581 }
582 }
583}
584
#[cfg(test)]
mod tests {
    use crate::agents::*;

    /// Parses `yaml` into an `ExtensionConfig` and returns its description.
    ///
    /// Panics if parsing fails or the result is not the `Builtin` variant.
    fn builtin_description(yaml: &str) -> String {
        let config: ExtensionConfig = serde_yaml::from_str(yaml).unwrap();
        match config {
            ExtensionConfig::Builtin { description, .. } => description,
            other => panic!("unexpected result of deserialization: {}", other),
        }
    }

    #[test]
    fn test_deserialize_missing_description() {
        let yaml = "enabled: true
type: builtin
name: developer
display_name: Developer
timeout: 300
bundled: true
available_tools: []";
        assert_eq!(builtin_description(yaml), "");
    }

    #[test]
    fn test_deserialize_null_description() {
        let yaml = "enabled: true
type: builtin
name: developer
display_name: Developer
description: null
timeout: 300
bundled: true
available_tools: []
";
        assert_eq!(builtin_description(yaml), "");
    }

    #[test]
    fn test_deserialize_normal_description() {
        let yaml = "enabled: true
type: builtin
name: developer
display_name: Developer
description: description goes here
timeout: 300
bundled: true
available_tools: []
        ";
        assert_eq!(builtin_description(yaml), "description goes here");
    }
}