Skip to main content

aster/config/
declarative_providers.rs

1use crate::config::paths::Paths;
2use crate::config::Config;
3use crate::providers::anthropic::AnthropicProvider;
4use crate::providers::base::{ModelInfo, ProviderType};
5use crate::providers::ollama::OllamaProvider;
6use crate::providers::openai::OpenAiProvider;
7use anyhow::Result;
8use include_dir::{include_dir, Dir};
9use once_cell::sync::Lazy;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use std::sync::Mutex;
14use utoipa::ToSchema;
15
16static FIXED_PROVIDERS: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/providers/declarative");
17
18pub fn custom_providers_dir() -> std::path::PathBuf {
19    Paths::config_dir().join("custom_providers")
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
23#[serde(rename_all = "lowercase")]
24pub enum ProviderEngine {
25    OpenAI,
26    Ollama,
27    Anthropic,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
31pub struct DeclarativeProviderConfig {
32    pub name: String,
33    pub engine: ProviderEngine,
34    pub display_name: String,
35    pub description: Option<String>,
36    pub api_key_env: String,
37    pub base_url: String,
38    pub models: Vec<ModelInfo>,
39    pub headers: Option<HashMap<String, String>>,
40    pub timeout_seconds: Option<u64>,
41    pub supports_streaming: Option<bool>,
42}
43
44impl DeclarativeProviderConfig {
45    pub fn id(&self) -> &str {
46        &self.name
47    }
48
49    pub fn display_name(&self) -> &str {
50        &self.display_name
51    }
52
53    pub fn models(&self) -> &[ModelInfo] {
54        &self.models
55    }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
59pub struct LoadedProvider {
60    pub config: DeclarativeProviderConfig,
61    pub is_editable: bool,
62}
63
64static ID_GENERATION_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
65
66pub fn generate_id(display_name: &str) -> String {
67    let _guard = ID_GENERATION_LOCK.lock().unwrap();
68
69    let normalized = display_name.to_lowercase().replace(' ', "_");
70    let base_id = format!("custom_{}", normalized);
71
72    let custom_dir = custom_providers_dir();
73    let mut candidate_id = base_id.clone();
74    let mut counter = 1;
75
76    while custom_dir.join(format!("{}.json", candidate_id)).exists() {
77        candidate_id = format!("{}_{}", base_id, counter);
78        counter += 1;
79    }
80
81    candidate_id
82}
83
/// Name of the secret that stores the API key for provider `id`: the ID
/// uppercased with `_API_KEY` appended (e.g. `custom_foo` → `CUSTOM_FOO_API_KEY`).
pub fn generate_api_key_name(id: &str) -> String {
    let mut key_name = id.to_uppercase();
    key_name.push_str("_API_KEY");
    key_name
}
87
88pub fn create_custom_provider(
89    engine: &str,
90    display_name: String,
91    api_url: String,
92    api_key: String,
93    models: Vec<String>,
94    supports_streaming: Option<bool>,
95    headers: Option<HashMap<String, String>>,
96) -> Result<DeclarativeProviderConfig> {
97    let id = generate_id(&display_name);
98    let api_key_name = generate_api_key_name(&id);
99
100    let config = Config::global();
101    config.set_secret(&api_key_name, &api_key)?;
102
103    let model_infos: Vec<ModelInfo> = models
104        .into_iter()
105        .map(|name| ModelInfo::new(name, 128000))
106        .collect();
107
108    let provider_config = DeclarativeProviderConfig {
109        name: id.clone(),
110        engine: match engine {
111            "openai_compatible" => ProviderEngine::OpenAI,
112            "anthropic_compatible" => ProviderEngine::Anthropic,
113            "ollama_compatible" => ProviderEngine::Ollama,
114            _ => return Err(anyhow::anyhow!("Invalid provider type: {}", engine)),
115        },
116        display_name: display_name.clone(),
117        description: Some(format!("Custom {} provider", display_name)),
118        api_key_env: api_key_name,
119        base_url: api_url,
120        models: model_infos,
121        headers,
122        timeout_seconds: None,
123        supports_streaming,
124    };
125
126    let custom_providers_dir = custom_providers_dir();
127    std::fs::create_dir_all(&custom_providers_dir)?;
128
129    let json_content = serde_json::to_string_pretty(&provider_config)?;
130    let file_path = custom_providers_dir.join(format!("{}.json", id));
131    std::fs::write(file_path, json_content)?;
132
133    Ok(provider_config)
134}
135
136pub fn update_custom_provider(
137    id: &str,
138    provider_type: &str,
139    display_name: String,
140    api_url: String,
141    api_key: String,
142    models: Vec<String>,
143    supports_streaming: Option<bool>,
144) -> Result<()> {
145    let loaded_provider = load_provider(id)?;
146    let existing_config = loaded_provider.config;
147    let editable = loaded_provider.is_editable;
148
149    let config = Config::global();
150    if !api_key.is_empty() {
151        config.set_secret(&existing_config.api_key_env, &api_key)?;
152    }
153
154    if editable {
155        let model_infos: Vec<ModelInfo> = models
156            .into_iter()
157            .map(|name| ModelInfo::new(name, 128000))
158            .collect();
159
160        let updated_config = DeclarativeProviderConfig {
161            name: id.to_string(),
162            engine: match provider_type {
163                "openai_compatible" => ProviderEngine::OpenAI,
164                "anthropic_compatible" => ProviderEngine::Anthropic,
165                "ollama_compatible" => ProviderEngine::Ollama,
166                _ => return Err(anyhow::anyhow!("Invalid provider type: {}", provider_type)),
167            },
168            display_name,
169            description: existing_config.description,
170            api_key_env: existing_config.api_key_env,
171            base_url: api_url,
172            models: model_infos,
173            headers: existing_config.headers,
174            timeout_seconds: existing_config.timeout_seconds,
175            supports_streaming,
176        };
177
178        let file_path = custom_providers_dir().join(format!("{}.json", id));
179        let json_content = serde_json::to_string_pretty(&updated_config)?;
180        std::fs::write(file_path, json_content)?;
181    }
182    Ok(())
183}
184
185pub fn remove_custom_provider(id: &str) -> Result<()> {
186    let config = Config::global();
187    let api_key_name = generate_api_key_name(id);
188    let _ = config.delete_secret(&api_key_name);
189
190    let custom_providers_dir = custom_providers_dir();
191    let file_path = custom_providers_dir.join(format!("{}.json", id));
192
193    if file_path.exists() {
194        std::fs::remove_file(file_path)?;
195    }
196
197    Ok(())
198}
199
200pub fn load_provider(id: &str) -> Result<LoadedProvider> {
201    let custom_file_path = custom_providers_dir().join(format!("{}.json", id));
202
203    if custom_file_path.exists() {
204        let content = std::fs::read_to_string(&custom_file_path)?;
205        let config: DeclarativeProviderConfig = serde_json::from_str(&content)?;
206        return Ok(LoadedProvider {
207            config,
208            is_editable: true,
209        });
210    }
211
212    for file in FIXED_PROVIDERS.files() {
213        if file.path().extension().and_then(|s| s.to_str()) != Some("json") {
214            continue;
215        }
216
217        let content = file
218            .contents_utf8()
219            .ok_or_else(|| anyhow::anyhow!("Failed to read file as UTF-8: {:?}", file.path()))?;
220
221        let config: DeclarativeProviderConfig = serde_json::from_str(content)?;
222        if config.name == id {
223            return Ok(LoadedProvider {
224                config,
225                is_editable: false,
226            });
227        }
228    }
229
230    Err(anyhow::anyhow!("Provider not found: {}", id))
231}
232pub fn load_custom_providers(dir: &Path) -> Result<Vec<DeclarativeProviderConfig>> {
233    if !dir.exists() {
234        return Ok(Vec::new());
235    }
236
237    std::fs::read_dir(dir)?
238        .filter_map(|entry| {
239            let path = entry.ok()?.path();
240            (path.extension()? == "json").then_some(path)
241        })
242        .map(|path| {
243            let content = std::fs::read_to_string(&path)?;
244            serde_json::from_str(&content)
245                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e))
246        })
247        .collect()
248}
249
250fn load_fixed_providers() -> Result<Vec<DeclarativeProviderConfig>> {
251    let mut res = Vec::new();
252    for file in FIXED_PROVIDERS.files() {
253        if file.path().extension().and_then(|s| s.to_str()) != Some("json") {
254            continue;
255        }
256
257        let content = file
258            .contents_utf8()
259            .ok_or_else(|| anyhow::anyhow!("Failed to read file as UTF-8: {:?}", file.path()))?;
260
261        let config: DeclarativeProviderConfig = serde_json::from_str(content)?;
262        res.push(config)
263    }
264
265    Ok(res)
266}
267
268pub fn register_declarative_providers(
269    registry: &mut crate::providers::provider_registry::ProviderRegistry,
270) -> Result<()> {
271    let dir = custom_providers_dir();
272    let custom_providers = load_custom_providers(&dir)?;
273    let fixed_providers = load_fixed_providers()?;
274    for config in fixed_providers {
275        register_declarative_provider(registry, config, ProviderType::Declarative);
276    }
277
278    for config in custom_providers {
279        register_declarative_provider(registry, config, ProviderType::Custom);
280    }
281
282    Ok(())
283}
284
285pub fn register_declarative_provider(
286    registry: &mut crate::providers::provider_registry::ProviderRegistry,
287    config: DeclarativeProviderConfig,
288    provider_type: ProviderType,
289) {
290    let config_clone = config.clone();
291
292    match config.engine {
293        ProviderEngine::OpenAI => {
294            registry.register_with_name::<OpenAiProvider, _>(
295                &config,
296                provider_type,
297                move |model| OpenAiProvider::from_custom_config(model, config_clone.clone()),
298            );
299        }
300        ProviderEngine::Ollama => {
301            registry.register_with_name::<OllamaProvider, _>(
302                &config,
303                provider_type,
304                move |model| OllamaProvider::from_custom_config(model, config_clone.clone()),
305            );
306        }
307        ProviderEngine::Anthropic => {
308            registry.register_with_name::<AnthropicProvider, _>(
309                &config,
310                provider_type,
311                move |model| AnthropicProvider::from_custom_config(model, config_clone.clone()),
312            );
313        }
314    }
315}