aster/config/
extensions.rs1use super::base::Config;
2use crate::agents::extension::PLATFORM_EXTENSIONS;
3use crate::agents::ExtensionConfig;
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use serde_yaml::Mapping;
7use tracing::warn;
8use utoipa::ToSchema;
9
/// Top-level key in the global `Config` under which the extensions mapping is stored
/// (read with `get_param` and written with `set_param` in this module).
const EXTENSIONS_CONFIG_KEY: &str = "extensions";

/// Identifier of the default extension.
pub const DEFAULT_EXTENSION: &str = "developer";
/// Default display name: the title-cased form of [`DEFAULT_EXTENSION`].
pub const DEFAULT_DISPLAY_NAME: &str = "Developer";
/// Default extension description: the empty string.
pub const DEFAULT_EXTENSION_DESCRIPTION: &str = "";
/// Default extension timeout.
///
/// NOTE(review): the unit is not visible in this module (presumably seconds) — confirm at
/// the call sites.
pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300;
15
16#[derive(Debug, Deserialize, Serialize, Clone, ToSchema)]
17pub struct ExtensionEntry {
18 pub enabled: bool,
19 #[serde(flatten)]
20 pub config: ExtensionConfig,
21}
22
/// Derives a key from an extension name: every Unicode whitespace character is removed
/// and the remaining text is lowercased.
///
/// Whitespace is stripped *before* lowercasing on purpose — `to_lowercase` is
/// context-sensitive (e.g. a word-final Greek capital sigma), so the order affects output.
pub fn name_to_key(name: &str) -> String {
    // `split_whitespace` uses the same Unicode White_Space definition as
    // `char::is_whitespace`; concatenating the pieces keeps every other char in order.
    let compact: String = name.split_whitespace().collect();
    compact.to_lowercase()
}
29
30fn get_extensions_map() -> IndexMap<String, ExtensionEntry> {
31 let raw: Mapping = Config::global()
32 .get_param(EXTENSIONS_CONFIG_KEY)
33 .unwrap_or_else(|err| {
34 warn!(
35 "Failed to load {}: {err}. Falling back to empty object.",
36 EXTENSIONS_CONFIG_KEY
37 );
38 Default::default()
39 });
40
41 let mut extensions_map = IndexMap::with_capacity(raw.len());
42 for (k, v) in raw {
43 match (k, serde_yaml::from_value::<ExtensionEntry>(v)) {
44 (serde_yaml::Value::String(key), Ok(entry)) => {
45 extensions_map.insert(key, entry);
46 }
47 (k, v) => {
48 warn!(
49 key = ?k,
50 value = ?v,
51 "Skipping malformed extension config entry"
52 );
53 }
54 }
55 }
56
57 if !extensions_map.is_empty() {
58 for (name, def) in PLATFORM_EXTENSIONS.iter() {
59 if !extensions_map.contains_key(*name) {
60 extensions_map.insert(
61 name.to_string(),
62 ExtensionEntry {
63 config: ExtensionConfig::Platform {
64 name: def.name.to_string(),
65 description: def.description.to_string(),
66 bundled: Some(true),
67 available_tools: Vec::new(),
68 },
69 enabled: def.default_enabled,
70 },
71 );
72 }
73 }
74 }
75 extensions_map
76}
77
78fn save_extensions_map(extensions: IndexMap<String, ExtensionEntry>) {
79 let config = Config::global();
80 if let Err(e) = config.set_param(EXTENSIONS_CONFIG_KEY, &extensions) {
81 tracing::debug!("Failed to save extensions config: {}", e);
83 }
84}
85
86pub fn get_extension_by_name(name: &str) -> Option<ExtensionConfig> {
87 let extensions = get_extensions_map();
88 extensions
89 .values()
90 .find(|entry| entry.config.name() == name)
91 .map(|entry| entry.config.clone())
92}
93
94pub fn set_extension(entry: ExtensionEntry) {
95 let mut extensions = get_extensions_map();
96 let key = entry.config.key();
97 extensions.insert(key, entry);
98 save_extensions_map(extensions);
99}
100
101pub fn remove_extension(key: &str) {
102 let mut extensions = get_extensions_map();
103 extensions.shift_remove(key);
104 save_extensions_map(extensions);
105}
106
107pub fn set_extension_enabled(key: &str, enabled: bool) {
108 let mut extensions = get_extensions_map();
109 if let Some(entry) = extensions.get_mut(key) {
110 entry.enabled = enabled;
111 save_extensions_map(extensions);
112 }
113}
114
115pub fn get_all_extensions() -> Vec<ExtensionEntry> {
116 let extensions = get_extensions_map();
117 extensions.into_values().collect()
118}
119
120pub fn get_all_extension_names() -> Vec<String> {
121 let extensions = get_extensions_map();
122 extensions.keys().cloned().collect()
123}
124
125pub fn is_extension_enabled(key: &str) -> bool {
126 let extensions = get_extensions_map();
127 extensions.get(key).map(|e| e.enabled).unwrap_or(false)
128}
129
130pub fn get_enabled_extensions() -> Vec<ExtensionConfig> {
131 get_all_extensions()
132 .into_iter()
133 .filter(|ext| ext.enabled)
134 .map(|ext| ext.config)
135 .collect()
136}
137
138pub fn get_warnings() -> Vec<String> {
139 let raw: Mapping = Config::global()
140 .get_param(EXTENSIONS_CONFIG_KEY)
141 .unwrap_or_default();
142
143 let mut warnings = Vec::new();
144 for (k, v) in raw {
145 if let (serde_yaml::Value::String(key), Ok(entry)) =
146 (k, serde_yaml::from_value::<ExtensionEntry>(v))
147 {
148 if matches!(entry.config, ExtensionConfig::Sse { .. }) {
149 warnings.push(format!(
150 "'{}': SSE is unsupported, migrate to streamable_http",
151 key
152 ));
153 }
154 }
155 }
156 warnings
157}