1use crate::config::paths::Paths;
2use crate::config::Config;
3use crate::providers::anthropic::AnthropicProvider;
4use crate::providers::base::{ModelInfo, ProviderType};
5use crate::providers::ollama::OllamaProvider;
6use crate::providers::openai::OpenAiProvider;
7use anyhow::Result;
8use include_dir::{include_dir, Dir};
9use once_cell::sync::Lazy;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use std::sync::Mutex;
14use utoipa::ToSchema;
15
/// Provider definitions bundled with the crate: the contents of `src/providers/declarative` are
/// embedded into the binary at compile time by `include_dir!`.
///
/// Only the `.json` entries are read by the loaders in this module, and providers found here are
/// reported as non-editable.
static FIXED_PROVIDERS: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/providers/declarative");
17
18pub fn custom_providers_dir() -> std::path::PathBuf {
19 Paths::config_dir().join("custom_providers")
20}
21
/// API family a declarative provider is compatible with.
///
/// Serialized in lowercase (`"openai"`, `"ollama"`, `"anthropic"`) because of the `rename_all`
/// attribute. The engine selects which concrete provider implementation
/// `register_declarative_provider` instantiates for a config.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProviderEngine {
    /// OpenAI-compatible endpoint, handled by `OpenAiProvider`.
    OpenAI,
    /// Ollama-compatible endpoint, handled by `OllamaProvider`.
    Ollama,
    /// Anthropic-compatible endpoint, handled by `AnthropicProvider`.
    Anthropic,
}
29
/// A provider described purely by data (JSON).
///
/// It is either bundled with the crate or created by the user in the custom providers directory.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct DeclarativeProviderConfig {
    /// Unique provider id. For providers created by `create_custom_provider` it is also the file
    /// stem of `<name>.json`.
    pub name: String,
    /// Which provider implementation serves this config.
    pub engine: ProviderEngine,
    /// Human-readable name.
    pub display_name: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Name of the secret holding the API key (e.g. `CUSTOM_FOO_API_KEY` for generated providers).
    pub api_key_env: String,
    /// Base URL of the provider's API.
    pub base_url: String,
    /// Models offered by this provider.
    pub models: Vec<ModelInfo>,
    /// Optional extra headers; presumably attached to every request by the engine — TODO confirm.
    pub headers: Option<HashMap<String, String>>,
    /// Optional timeout in seconds; `None` leaves the choice to the engine (assumption — confirm).
    pub timeout_seconds: Option<u64>,
    /// Optional streaming-support flag; `None` presumably means "engine default" — confirm.
    pub supports_streaming: Option<bool>,
}
43
44impl DeclarativeProviderConfig {
45 pub fn id(&self) -> &str {
46 &self.name
47 }
48
49 pub fn display_name(&self) -> &str {
50 &self.display_name
51 }
52
53 pub fn models(&self) -> &[ModelInfo] {
54 &self.models
55 }
56}
57
/// A provider config paired with whether the user may modify it.
///
/// `load_provider` sets `is_editable` to `true` for configs read from the custom providers
/// directory and `false` for the bundled definitions.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct LoadedProvider {
    /// The provider definition itself.
    pub config: DeclarativeProviderConfig,
    /// `true` for user-created (custom) providers, `false` for bundled ones.
    pub is_editable: bool,
}
63
64static ID_GENERATION_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
65
66pub fn generate_id(display_name: &str) -> String {
67 let _guard = ID_GENERATION_LOCK.lock().unwrap();
68
69 let normalized = display_name.to_lowercase().replace(' ', "_");
70 let base_id = format!("custom_{}", normalized);
71
72 let custom_dir = custom_providers_dir();
73 let mut candidate_id = base_id.clone();
74 let mut counter = 1;
75
76 while custom_dir.join(format!("{}.json", candidate_id)).exists() {
77 candidate_id = format!("{}_{}", base_id, counter);
78 counter += 1;
79 }
80
81 candidate_id
82}
83
/// Name of the secret that stores the API key for provider `id`.
///
/// The id is upper-cased and `_API_KEY` is appended (e.g. `custom_foo` -> `CUSTOM_FOO_API_KEY`).
pub fn generate_api_key_name(id: &str) -> String {
    let mut key_name = id.to_uppercase();
    key_name.push_str("_API_KEY");
    key_name
}
87
88pub fn create_custom_provider(
89 engine: &str,
90 display_name: String,
91 api_url: String,
92 api_key: String,
93 models: Vec<String>,
94 supports_streaming: Option<bool>,
95 headers: Option<HashMap<String, String>>,
96) -> Result<DeclarativeProviderConfig> {
97 let id = generate_id(&display_name);
98 let api_key_name = generate_api_key_name(&id);
99
100 let config = Config::global();
101 config.set_secret(&api_key_name, &api_key)?;
102
103 let model_infos: Vec<ModelInfo> = models
104 .into_iter()
105 .map(|name| ModelInfo::new(name, 128000))
106 .collect();
107
108 let provider_config = DeclarativeProviderConfig {
109 name: id.clone(),
110 engine: match engine {
111 "openai_compatible" => ProviderEngine::OpenAI,
112 "anthropic_compatible" => ProviderEngine::Anthropic,
113 "ollama_compatible" => ProviderEngine::Ollama,
114 _ => return Err(anyhow::anyhow!("Invalid provider type: {}", engine)),
115 },
116 display_name: display_name.clone(),
117 description: Some(format!("Custom {} provider", display_name)),
118 api_key_env: api_key_name,
119 base_url: api_url,
120 models: model_infos,
121 headers,
122 timeout_seconds: None,
123 supports_streaming,
124 };
125
126 let custom_providers_dir = custom_providers_dir();
127 std::fs::create_dir_all(&custom_providers_dir)?;
128
129 let json_content = serde_json::to_string_pretty(&provider_config)?;
130 let file_path = custom_providers_dir.join(format!("{}.json", id));
131 std::fs::write(file_path, json_content)?;
132
133 Ok(provider_config)
134}
135
136pub fn update_custom_provider(
137 id: &str,
138 provider_type: &str,
139 display_name: String,
140 api_url: String,
141 api_key: String,
142 models: Vec<String>,
143 supports_streaming: Option<bool>,
144) -> Result<()> {
145 let loaded_provider = load_provider(id)?;
146 let existing_config = loaded_provider.config;
147 let editable = loaded_provider.is_editable;
148
149 let config = Config::global();
150 if !api_key.is_empty() {
151 config.set_secret(&existing_config.api_key_env, &api_key)?;
152 }
153
154 if editable {
155 let model_infos: Vec<ModelInfo> = models
156 .into_iter()
157 .map(|name| ModelInfo::new(name, 128000))
158 .collect();
159
160 let updated_config = DeclarativeProviderConfig {
161 name: id.to_string(),
162 engine: match provider_type {
163 "openai_compatible" => ProviderEngine::OpenAI,
164 "anthropic_compatible" => ProviderEngine::Anthropic,
165 "ollama_compatible" => ProviderEngine::Ollama,
166 _ => return Err(anyhow::anyhow!("Invalid provider type: {}", provider_type)),
167 },
168 display_name,
169 description: existing_config.description,
170 api_key_env: existing_config.api_key_env,
171 base_url: api_url,
172 models: model_infos,
173 headers: existing_config.headers,
174 timeout_seconds: existing_config.timeout_seconds,
175 supports_streaming,
176 };
177
178 let file_path = custom_providers_dir().join(format!("{}.json", id));
179 let json_content = serde_json::to_string_pretty(&updated_config)?;
180 std::fs::write(file_path, json_content)?;
181 }
182 Ok(())
183}
184
185pub fn remove_custom_provider(id: &str) -> Result<()> {
186 let config = Config::global();
187 let api_key_name = generate_api_key_name(id);
188 let _ = config.delete_secret(&api_key_name);
189
190 let custom_providers_dir = custom_providers_dir();
191 let file_path = custom_providers_dir.join(format!("{}.json", id));
192
193 if file_path.exists() {
194 std::fs::remove_file(file_path)?;
195 }
196
197 Ok(())
198}
199
200pub fn load_provider(id: &str) -> Result<LoadedProvider> {
201 let custom_file_path = custom_providers_dir().join(format!("{}.json", id));
202
203 if custom_file_path.exists() {
204 let content = std::fs::read_to_string(&custom_file_path)?;
205 let config: DeclarativeProviderConfig = serde_json::from_str(&content)?;
206 return Ok(LoadedProvider {
207 config,
208 is_editable: true,
209 });
210 }
211
212 for file in FIXED_PROVIDERS.files() {
213 if file.path().extension().and_then(|s| s.to_str()) != Some("json") {
214 continue;
215 }
216
217 let content = file
218 .contents_utf8()
219 .ok_or_else(|| anyhow::anyhow!("Failed to read file as UTF-8: {:?}", file.path()))?;
220
221 let config: DeclarativeProviderConfig = serde_json::from_str(content)?;
222 if config.name == id {
223 return Ok(LoadedProvider {
224 config,
225 is_editable: false,
226 });
227 }
228 }
229
230 Err(anyhow::anyhow!("Provider not found: {}", id))
231}
232pub fn load_custom_providers(dir: &Path) -> Result<Vec<DeclarativeProviderConfig>> {
233 if !dir.exists() {
234 return Ok(Vec::new());
235 }
236
237 std::fs::read_dir(dir)?
238 .filter_map(|entry| {
239 let path = entry.ok()?.path();
240 (path.extension()? == "json").then_some(path)
241 })
242 .map(|path| {
243 let content = std::fs::read_to_string(&path)?;
244 serde_json::from_str(&content)
245 .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e))
246 })
247 .collect()
248}
249
250fn load_fixed_providers() -> Result<Vec<DeclarativeProviderConfig>> {
251 let mut res = Vec::new();
252 for file in FIXED_PROVIDERS.files() {
253 if file.path().extension().and_then(|s| s.to_str()) != Some("json") {
254 continue;
255 }
256
257 let content = file
258 .contents_utf8()
259 .ok_or_else(|| anyhow::anyhow!("Failed to read file as UTF-8: {:?}", file.path()))?;
260
261 let config: DeclarativeProviderConfig = serde_json::from_str(content)?;
262 res.push(config)
263 }
264
265 Ok(res)
266}
267
268pub fn register_declarative_providers(
269 registry: &mut crate::providers::provider_registry::ProviderRegistry,
270) -> Result<()> {
271 let dir = custom_providers_dir();
272 let custom_providers = load_custom_providers(&dir)?;
273 let fixed_providers = load_fixed_providers()?;
274 for config in fixed_providers {
275 register_declarative_provider(registry, config, ProviderType::Declarative);
276 }
277
278 for config in custom_providers {
279 register_declarative_provider(registry, config, ProviderType::Custom);
280 }
281
282 Ok(())
283}
284
285pub fn register_declarative_provider(
286 registry: &mut crate::providers::provider_registry::ProviderRegistry,
287 config: DeclarativeProviderConfig,
288 provider_type: ProviderType,
289) {
290 let config_clone = config.clone();
291
292 match config.engine {
293 ProviderEngine::OpenAI => {
294 registry.register_with_name::<OpenAiProvider, _>(
295 &config,
296 provider_type,
297 move |model| OpenAiProvider::from_custom_config(model, config_clone.clone()),
298 );
299 }
300 ProviderEngine::Ollama => {
301 registry.register_with_name::<OllamaProvider, _>(
302 &config,
303 provider_type,
304 move |model| OllamaProvider::from_custom_config(model, config_clone.clone()),
305 );
306 }
307 ProviderEngine::Anthropic => {
308 registry.register_with_name::<AnthropicProvider, _>(
309 &config,
310 provider_type,
311 move |model| AnthropicProvider::from_custom_config(model, config_clone.clone()),
312 );
313 }
314 }
315}