1use anyhow::Result;
7use colored::Colorize;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fs;
11use std::path::PathBuf;
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
15pub struct ProviderRegistry {
16 pub version: String,
18
19 pub providers: HashMap<String, ProviderMetadata>,
21
22 pub base_url: String,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
28pub struct ProviderMetadata {
29 pub name: String,
31
32 pub description: String,
34
35 pub config_file: String,
37
38 pub version: String,
40
41 pub auth_type: AuthType,
43
44 #[serde(default)]
46 pub tags: Vec<String>,
47
48 #[serde(default)]
50 pub official: bool,
51
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub docs_url: Option<String>,
55
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub min_version: Option<String>,
59}
60
61#[derive(Debug, Serialize, Deserialize, Clone)]
63#[serde(rename_all = "snake_case")]
64pub enum AuthType {
65 ApiKey,
66 ServiceAccount,
67 OAuth,
68 Token,
69 Headers,
70 None,
71}
72
/// Fetches the provider registry and installs, updates and removes provider configs.
pub struct ProviderInstaller {
    /// Base URL (`http(s)://` or `file://`) that `registry.json` is resolved against.
    registry_source: String,

    /// Directory holding the cached copy of `registry.json`.
    cache_dir: PathBuf,

    /// Directory that installed provider `.toml` configs are written to.
    providers_dir: PathBuf,
}
84
85impl ProviderInstaller {
86 pub fn new() -> Result<Self> {
88 let config_dir = crate::config::Config::config_dir()?;
89 let cache_dir = config_dir.join(".provider_cache");
90 let providers_dir = config_dir.join("providers");
91
92 let registry_source = std::env::var("LC_PROVIDER_REGISTRY").unwrap_or_else(|_| {
94 "https://raw.githubusercontent.com/rajashekar/lc-providers/main".to_string()
95 });
96
97 Ok(Self {
98 registry_source,
99 cache_dir,
100 providers_dir,
101 })
102 }
103
104 pub async fn fetch_registry(&self) -> Result<ProviderRegistry> {
106 let registry_url = format!("{}/registry.json", self.registry_source);
107
108 crate::debug_log!("Fetching provider registry from: {}", registry_url);
109
110 if registry_url.starts_with("file://") {
112 let path = registry_url
113 .strip_prefix("file://")
114 .ok_or_else(|| anyhow::anyhow!("Invalid file:// URL format"))?;
115 let content = fs::read_to_string(path)
116 .map_err(|e| anyhow::anyhow!("Failed to read local registry: {}", e))?;
117 let registry: ProviderRegistry = serde_json::from_str(&content)
118 .map_err(|e| anyhow::anyhow!("Failed to parse registry: {}", e))?;
119
120 self.cache_registry(®istry)?;
122
123 return Ok(registry);
124 }
125
126 let client = reqwest::Client::builder()
128 .timeout(std::time::Duration::from_secs(30))
129 .build()?;
130
131 let response = client
132 .get(®istry_url)
133 .send()
134 .await
135 .map_err(|e| anyhow::anyhow!("Failed to fetch registry: {}", e))?;
136
137 if !response.status().is_success() {
138 anyhow::bail!("Failed to fetch registry: HTTP {}", response.status());
139 }
140
141 let registry: ProviderRegistry = response
142 .json()
143 .await
144 .map_err(|e| anyhow::anyhow!("Failed to parse registry: {}", e))?;
145
146 self.cache_registry(®istry)?;
148
149 Ok(registry)
150 }
151
152 pub fn get_cached_registry(&self) -> Result<Option<ProviderRegistry>> {
154 let cache_file = self.cache_dir.join("registry.json");
155
156 if !cache_file.exists() {
157 return Ok(None);
158 }
159
160 let metadata = fs::metadata(&cache_file)?;
162 if let Ok(modified) = metadata.modified() {
163 let age = std::time::SystemTime::now()
164 .duration_since(modified)
165 .unwrap_or(std::time::Duration::MAX);
166
167 if age > std::time::Duration::from_secs(24 * 60 * 60) {
168 crate::debug_log!("Registry cache is stale (>24 hours old)");
169 return Ok(None);
170 }
171 }
172
173 let content = fs::read_to_string(&cache_file)?;
174 let registry: ProviderRegistry = serde_json::from_str(&content)?;
175
176 Ok(Some(registry))
177 }
178
179 fn cache_registry(&self, registry: &ProviderRegistry) -> Result<()> {
181 fs::create_dir_all(&self.cache_dir)?;
182
183 let cache_file = self.cache_dir.join("registry.json");
184 let content = serde_json::to_string_pretty(registry)?;
185 fs::write(&cache_file, content)?;
186
187 Ok(())
188 }
189
190 pub async fn list_available(&self) -> Result<Vec<(String, ProviderMetadata)>> {
192 let registry = if let Some(cached) = self.get_cached_registry()? {
194 cached
195 } else {
196 self.fetch_registry().await?
197 };
198
199 let mut providers: Vec<_> = registry.providers.into_iter().collect();
200 providers.sort_by(|a, b| a.0.cmp(&b.0));
201
202 Ok(providers)
203 }
204
205 pub async fn install_provider(&self, provider_id: &str, force: bool) -> Result<()> {
207 println!("{} Installing provider '{}'...", "đĻ".blue(), provider_id);
208
209 let registry = if let Some(cached) = self.get_cached_registry()? {
211 cached
212 } else {
213 println!("{} Fetching provider registry...", "đ".blue());
214 self.fetch_registry().await?
215 };
216
217 let metadata = registry
219 .providers
220 .get(provider_id)
221 .ok_or_else(|| anyhow::anyhow!("Provider '{}' not found in registry", provider_id))?;
222
223 let target_file = self.providers_dir.join(&metadata.config_file);
225 if target_file.exists() && !force {
226 if let Ok(existing_config) = fs::read_to_string(&target_file) {
228 if let Ok(existing_toml) = toml::from_str::<toml::Value>(&existing_config) {
229 if let Some(existing_version) =
230 existing_toml.get("version").and_then(|v| v.as_str())
231 {
232 if existing_version == metadata.version {
233 println!(
234 "{} Provider '{}' is already up to date (v{})",
235 "â".green(),
236 provider_id,
237 metadata.version
238 );
239 return Ok(());
240 }
241 }
242 }
243 }
244
245 println!(
246 "{} Provider '{}' already exists. Updating to v{}...",
247 "đ".yellow(),
248 provider_id,
249 metadata.version
250 );
251 }
252
253 let config_url = format!("{}/providers/{}", registry.base_url, metadata.config_file);
255
256 crate::debug_log!("Downloading provider config from: {}", config_url);
257
258 let config_content = if config_url.starts_with("file://") {
259 let path = config_url
261 .strip_prefix("file://")
262 .ok_or_else(|| anyhow::anyhow!("Invalid file:// URL format"))?;
263 fs::read_to_string(path)
264 .map_err(|e| anyhow::anyhow!("Failed to read local provider config: {}", e))?
265 } else {
266 let client = reqwest::Client::builder()
268 .timeout(std::time::Duration::from_secs(30))
269 .build()?;
270
271 let response = client
272 .get(&config_url)
273 .send()
274 .await
275 .map_err(|e| anyhow::anyhow!("Failed to download provider config: {}", e))?;
276
277 if !response.status().is_success() {
278 anyhow::bail!(
279 "Failed to download provider config: HTTP {}",
280 response.status()
281 );
282 }
283
284 response.text().await?
285 };
286
287 self.validate_provider_config(&config_content)?;
289
290 fs::create_dir_all(&self.providers_dir)?;
292
293 fs::write(&target_file, &config_content)?;
295
296 println!(
297 "{} Provider '{}' installed successfully (v{})",
298 "â
".green(),
299 provider_id,
300 metadata.version
301 );
302
303 self.show_auth_instructions(provider_id, metadata)?;
305
306 Ok(())
307 }
308
309 pub async fn update_provider(&self, provider_id: &str) -> Result<()> {
311 self.install_provider(provider_id, true).await
312 }
313
314 pub async fn update_all_providers(&self) -> Result<()> {
316 println!("{} Updating all installed providers...", "đ".blue());
317
318 let installed = self.list_installed_providers()?;
320
321 if installed.is_empty() {
322 println!("{} No providers installed", "âšī¸".blue());
323 return Ok(());
324 }
325
326 let mut updated_count = 0;
327 let mut failed_count = 0;
328
329 for provider_id in installed {
330 match self.update_provider(&provider_id).await {
331 Ok(_) => updated_count += 1,
332 Err(e) => {
333 eprintln!("{} Failed to update '{}': {}", "â".red(), provider_id, e);
334 failed_count += 1;
335 }
336 }
337 }
338
339 if failed_count == 0 {
340 println!(
341 "{} All {} providers updated successfully",
342 "â
".green(),
343 updated_count
344 );
345 } else {
346 println!(
347 "{} Updated {} providers, {} failed",
348 "â ī¸".yellow(),
349 updated_count,
350 failed_count
351 );
352 }
353
354 Ok(())
355 }
356
357 pub fn list_installed_providers(&self) -> Result<Vec<String>> {
359 if !self.providers_dir.exists() {
360 return Ok(Vec::new());
361 }
362
363 let mut providers = Vec::new();
364
365 for entry in fs::read_dir(&self.providers_dir)? {
366 let entry = entry?;
367 let path = entry.path();
368
369 if path.extension().and_then(|s| s.to_str()) == Some("toml") {
370 if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
371 providers.push(name.to_string());
372 }
373 }
374 }
375
376 providers.sort();
377 Ok(providers)
378 }
379
380 pub fn uninstall_provider(&self, provider_id: &str) -> Result<()> {
382 let provider_file = self.providers_dir.join(format!("{}.toml", provider_id));
383
384 if !provider_file.exists() {
385 anyhow::bail!("Provider '{}' is not installed", provider_id);
386 }
387
388 fs::remove_file(&provider_file)?;
389
390 println!(
391 "{} Provider '{}' uninstalled successfully",
392 "â
".green(),
393 provider_id
394 );
395
396 let keys = crate::keys::KeysConfig::load()?;
398 if keys.has_auth(provider_id) {
399 println!(
400 "{} Note: API keys for '{}' are still stored in keys.toml",
401 "âšī¸".blue(),
402 provider_id
403 );
404 println!(" To remove them, use: lc keys remove {}", provider_id);
405 }
406
407 Ok(())
408 }
409
410 fn validate_provider_config(&self, config_content: &str) -> Result<()> {
412 let config: toml::Value = toml::from_str(config_content)
414 .map_err(|e| anyhow::anyhow!("Invalid TOML format: {}", e))?;
415
416 let required_fields = ["endpoint", "models_path", "chat_path"];
418
419 for field in &required_fields {
420 if !config.get(field).is_some() {
421 anyhow::bail!("Provider config missing required field: {}", field);
422 }
423 }
424
425 Ok(())
426 }
427
428 fn show_auth_instructions(&self, provider_id: &str, metadata: &ProviderMetadata) -> Result<()> {
430 println!("\n{} Authentication Setup", "đ".yellow());
431
432 match metadata.auth_type {
433 AuthType::ApiKey => {
434 println!("This provider requires an API key.");
435 println!("To set it up, run:");
436 println!(" {}", format!("lc keys add {}", provider_id).bold());
437 }
438 AuthType::ServiceAccount => {
439 println!("This provider requires a service account JSON.");
440 println!("To set it up, run:");
441 println!(" {}", format!("lc keys add {}", provider_id).bold());
442 }
443 AuthType::OAuth => {
444 println!("This provider uses OAuth authentication.");
445 println!("Follow the provider's documentation to set up OAuth.");
446 if let Some(docs_url) = &metadata.docs_url {
447 println!("Documentation: {}", docs_url.blue());
448 }
449 }
450 AuthType::Token => {
451 println!("This provider requires an authentication token.");
452 println!("To set it up, run:");
453 println!(" {}", format!("lc keys add {}", provider_id).bold());
454 }
455 AuthType::Headers => {
456 println!("This provider requires custom authentication headers.");
457 println!("To set them up, run:");
458 println!(
459 " {}",
460 format!(
461 "lc providers headers {} add <header-name> <header-value>",
462 provider_id
463 )
464 .bold()
465 );
466 }
467 AuthType::None => {
468 println!("This provider does not require authentication.");
469 }
470 }
471
472 Ok(())
473 }
474}
475
476#[allow(dead_code)]
478pub fn create_sample_registry() -> ProviderRegistry {
479 let mut providers = HashMap::new();
480
481 providers.insert(
483 "openai".to_string(),
484 ProviderMetadata {
485 name: "OpenAI".to_string(),
486 description: "OpenAI GPT models including GPT-4 and GPT-3.5".to_string(),
487 config_file: "openai.toml".to_string(),
488 version: "1.0.0".to_string(),
489 auth_type: AuthType::ApiKey,
490 tags: vec![
491 "official".to_string(),
492 "chat".to_string(),
493 "embeddings".to_string(),
494 ],
495 official: true,
496 docs_url: Some("https://platform.openai.com/docs".to_string()),
497 min_version: None,
498 },
499 );
500
501 providers.insert(
502 "gemini".to_string(),
503 ProviderMetadata {
504 name: "Google Gemini".to_string(),
505 description: "Google's Gemini models".to_string(),
506 config_file: "gemini.toml".to_string(),
507 version: "1.0.0".to_string(),
508 auth_type: AuthType::ApiKey,
509 tags: vec![
510 "official".to_string(),
511 "chat".to_string(),
512 "vision".to_string(),
513 ],
514 official: true,
515 docs_url: Some("https://ai.google.dev/docs".to_string()),
516 min_version: None,
517 },
518 );
519
520 providers.insert(
521 "anthropic".to_string(),
522 ProviderMetadata {
523 name: "Anthropic Claude".to_string(),
524 description: "Anthropic's Claude models".to_string(),
525 config_file: "anthropic.toml".to_string(),
526 version: "1.0.0".to_string(),
527 auth_type: AuthType::ApiKey,
528 tags: vec!["official".to_string(), "chat".to_string()],
529 official: true,
530 docs_url: Some("https://docs.anthropic.com".to_string()),
531 min_version: None,
532 },
533 );
534
535 ProviderRegistry {
536 version: "1.0.0".to_string(),
537 providers,
538 base_url: "https://raw.githubusercontent.com/rajashekar/lc-providers/main".to_string(),
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// A metadata entry survives a JSON round trip with its name and version intact.
    #[test]
    fn test_provider_metadata_serialization() {
        let original = ProviderMetadata {
            name: "Test Provider".to_string(),
            description: "A test provider".to_string(),
            config_file: "test.toml".to_string(),
            version: "1.0.0".to_string(),
            auth_type: AuthType::ApiKey,
            tags: vec!["test".to_string()],
            official: false,
            docs_url: None,
            min_version: None,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped: ProviderMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.name, round_tripped.name);
        assert_eq!(original.version, round_tripped.version);
    }

    /// The sample registry exposes all three built-in providers.
    #[test]
    fn test_registry_creation() {
        let registry = create_sample_registry();

        for id in ["openai", "gemini", "anthropic"].iter() {
            assert!(registry.providers.contains_key(*id));
        }

        let openai = &registry.providers["openai"];
        assert_eq!(openai.name, "OpenAI");
        assert!(openai.official);
    }
}