// aster/config/extensions.rs

1use super::base::Config;
2use crate::agents::extension::PLATFORM_EXTENSIONS;
3use crate::agents::ExtensionConfig;
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use serde_yaml::Mapping;
7use tracing::warn;
8use utoipa::ToSchema;
9
/// Identifier of the default extension.
pub const DEFAULT_EXTENSION: &str = "developer";
/// Default timeout for an extension.
/// NOTE(review): units are not established in this file — presumably seconds; confirm at call sites.
pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300;
/// Default (empty) extension description.
pub const DEFAULT_EXTENSION_DESCRIPTION: &str = "";
/// Display name for the default extension (presumably paired with `DEFAULT_EXTENSION` — confirm).
pub const DEFAULT_DISPLAY_NAME: &str = "Developer";
/// Key under which the extensions mapping is read from and written to the global `Config`.
const EXTENSIONS_CONFIG_KEY: &str = "extensions";
15
16#[derive(Debug, Deserialize, Serialize, Clone, ToSchema)]
17pub struct ExtensionEntry {
18    pub enabled: bool,
19    #[serde(flatten)]
20    pub config: ExtensionConfig,
21}
22
/// Normalizes a human-readable extension name into a map key.
///
/// Every Unicode whitespace character is removed and the remaining text is lowercased,
/// so `"My Extension"` becomes `"myextension"`.
pub fn name_to_key(name: &str) -> String {
    let compact: String = name.split_whitespace().collect();
    compact.to_lowercase()
}
29
30fn get_extensions_map() -> IndexMap<String, ExtensionEntry> {
31    let raw: Mapping = Config::global()
32        .get_param(EXTENSIONS_CONFIG_KEY)
33        .unwrap_or_else(|err| {
34            warn!(
35                "Failed to load {}: {err}. Falling back to empty object.",
36                EXTENSIONS_CONFIG_KEY
37            );
38            Default::default()
39        });
40
41    let mut extensions_map = IndexMap::with_capacity(raw.len());
42    for (k, v) in raw {
43        match (k, serde_yaml::from_value::<ExtensionEntry>(v)) {
44            (serde_yaml::Value::String(key), Ok(entry)) => {
45                extensions_map.insert(key, entry);
46            }
47            (k, v) => {
48                warn!(
49                    key = ?k,
50                    value = ?v,
51                    "Skipping malformed extension config entry"
52                );
53            }
54        }
55    }
56
57    if !extensions_map.is_empty() {
58        for (name, def) in PLATFORM_EXTENSIONS.iter() {
59            if !extensions_map.contains_key(*name) {
60                extensions_map.insert(
61                    name.to_string(),
62                    ExtensionEntry {
63                        config: ExtensionConfig::Platform {
64                            name: def.name.to_string(),
65                            description: def.description.to_string(),
66                            bundled: Some(true),
67                            available_tools: Vec::new(),
68                        },
69                        enabled: def.default_enabled,
70                    },
71                );
72            }
73        }
74    }
75    extensions_map
76}
77
78fn save_extensions_map(extensions: IndexMap<String, ExtensionEntry>) {
79    let config = Config::global();
80    if let Err(e) = config.set_param(EXTENSIONS_CONFIG_KEY, &extensions) {
81        // TODO(jack) why is this just a debug statement?
82        tracing::debug!("Failed to save extensions config: {}", e);
83    }
84}
85
86pub fn get_extension_by_name(name: &str) -> Option<ExtensionConfig> {
87    let extensions = get_extensions_map();
88    extensions
89        .values()
90        .find(|entry| entry.config.name() == name)
91        .map(|entry| entry.config.clone())
92}
93
94pub fn set_extension(entry: ExtensionEntry) {
95    let mut extensions = get_extensions_map();
96    let key = entry.config.key();
97    extensions.insert(key, entry);
98    save_extensions_map(extensions);
99}
100
101pub fn remove_extension(key: &str) {
102    let mut extensions = get_extensions_map();
103    extensions.shift_remove(key);
104    save_extensions_map(extensions);
105}
106
107pub fn set_extension_enabled(key: &str, enabled: bool) {
108    let mut extensions = get_extensions_map();
109    if let Some(entry) = extensions.get_mut(key) {
110        entry.enabled = enabled;
111        save_extensions_map(extensions);
112    }
113}
114
115pub fn get_all_extensions() -> Vec<ExtensionEntry> {
116    let extensions = get_extensions_map();
117    extensions.into_values().collect()
118}
119
120pub fn get_all_extension_names() -> Vec<String> {
121    let extensions = get_extensions_map();
122    extensions.keys().cloned().collect()
123}
124
125pub fn is_extension_enabled(key: &str) -> bool {
126    let extensions = get_extensions_map();
127    extensions.get(key).map(|e| e.enabled).unwrap_or(false)
128}
129
130pub fn get_enabled_extensions() -> Vec<ExtensionConfig> {
131    get_all_extensions()
132        .into_iter()
133        .filter(|ext| ext.enabled)
134        .map(|ext| ext.config)
135        .collect()
136}
137
138pub fn get_warnings() -> Vec<String> {
139    let raw: Mapping = Config::global()
140        .get_param(EXTENSIONS_CONFIG_KEY)
141        .unwrap_or_default();
142
143    let mut warnings = Vec::new();
144    for (k, v) in raw {
145        if let (serde_yaml::Value::String(key), Ok(entry)) =
146            (k, serde_yaml::from_value::<ExtensionEntry>(v))
147        {
148            if matches!(entry.config, ExtensionConfig::Sse { .. }) {
149                warnings.push(format!(
150                    "'{}': SSE is unsupported, migrate to streamable_http",
151                    key
152                ));
153            }
154        }
155    }
156    warnings
157}