lc/core/
provider_installer.rs

1//! Provider configuration installer and manager
2//!
3//! This module handles downloading, installing, and updating provider configurations
4//! from a central repository, keeping API keys separate from the configurations.
5
6use anyhow::Result;
7use colored::Colorize;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fs;
11use std::path::PathBuf;
12
13/// Provider registry that lists available providers and their metadata
14#[derive(Debug, Serialize, Deserialize, Clone)]
15pub struct ProviderRegistry {
16    /// Version of the registry format
17    pub version: String,
18
19    /// List of available providers
20    pub providers: HashMap<String, ProviderMetadata>,
21
22    /// Base URL for downloading provider configs
23    pub base_url: String,
24}
25
26/// Metadata about a provider
27#[derive(Debug, Serialize, Deserialize, Clone)]
28pub struct ProviderMetadata {
29    /// Display name of the provider
30    pub name: String,
31
32    /// Description of the provider
33    pub description: String,
34
35    /// Provider configuration file name
36    pub config_file: String,
37
38    /// Version of the provider config
39    pub version: String,
40
41    /// Required authentication type
42    pub auth_type: AuthType,
43
44    /// Tags for categorization
45    #[serde(default)]
46    pub tags: Vec<String>,
47
48    /// Whether this provider is officially supported
49    #[serde(default)]
50    pub official: bool,
51
52    /// Documentation URL
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub docs_url: Option<String>,
55
56    /// Minimum lc version required
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub min_version: Option<String>,
59}
60
61/// Types of authentication required by providers
62#[derive(Debug, Serialize, Deserialize, Clone)]
63#[serde(rename_all = "snake_case")]
64pub enum AuthType {
65    ApiKey,
66    ServiceAccount,
67    OAuth,
68    Token,
69    Headers,
70    None,
71}
72
/// Downloads, installs, updates, and removes provider configurations.
///
/// Bundles the registry location with the local cache and install directories.
/// Build one with `ProviderInstaller::new`.
pub struct ProviderInstaller {
    /// Where `registry.json` is read from: an HTTP(S) base URL or a `file://` path.
    registry_source: String,

    /// Directory holding the locally cached copy of the registry.
    cache_dir: PathBuf,

    /// Directory that installed provider TOML files are written into.
    providers_dir: PathBuf,
}
84
85impl ProviderInstaller {
86    /// Create a new provider installer
87    pub fn new() -> Result<Self> {
88        let config_dir = crate::config::Config::config_dir()?;
89        let cache_dir = config_dir.join(".provider_cache");
90        let providers_dir = config_dir.join("providers");
91
92        // Default to GitHub repository
93        let registry_source = std::env::var("LC_PROVIDER_REGISTRY").unwrap_or_else(|_| {
94            "https://raw.githubusercontent.com/rajashekar/lc-providers/main".to_string()
95        });
96
97        Ok(Self {
98            registry_source,
99            cache_dir,
100            providers_dir,
101        })
102    }
103
104    /// Fetch the provider registry
105    pub async fn fetch_registry(&self) -> Result<ProviderRegistry> {
106        let registry_url = format!("{}/registry.json", self.registry_source);
107
108        crate::debug_log!("Fetching provider registry from: {}", registry_url);
109
110        // Handle local file paths
111        if registry_url.starts_with("file://") {
112            let path = registry_url
113                .strip_prefix("file://")
114                .ok_or_else(|| anyhow::anyhow!("Invalid file:// URL format"))?;
115            let content = fs::read_to_string(path)
116                .map_err(|e| anyhow::anyhow!("Failed to read local registry: {}", e))?;
117            let registry: ProviderRegistry = serde_json::from_str(&content)
118                .map_err(|e| anyhow::anyhow!("Failed to parse registry: {}", e))?;
119
120            // Cache the registry locally
121            self.cache_registry(&registry)?;
122
123            return Ok(registry);
124        }
125
126        // Create HTTP client for remote URLs
127        let client = reqwest::Client::builder()
128            .timeout(std::time::Duration::from_secs(30))
129            .build()?;
130
131        let response = client
132            .get(&registry_url)
133            .send()
134            .await
135            .map_err(|e| anyhow::anyhow!("Failed to fetch registry: {}", e))?;
136
137        if !response.status().is_success() {
138            anyhow::bail!("Failed to fetch registry: HTTP {}", response.status());
139        }
140
141        let registry: ProviderRegistry = response
142            .json()
143            .await
144            .map_err(|e| anyhow::anyhow!("Failed to parse registry: {}", e))?;
145
146        // Cache the registry locally
147        self.cache_registry(&registry)?;
148
149        Ok(registry)
150    }
151
152    /// Get cached registry if available
153    pub fn get_cached_registry(&self) -> Result<Option<ProviderRegistry>> {
154        let cache_file = self.cache_dir.join("registry.json");
155
156        if !cache_file.exists() {
157            return Ok(None);
158        }
159
160        // Check if cache is fresh (less than 24 hours old)
161        let metadata = fs::metadata(&cache_file)?;
162        if let Ok(modified) = metadata.modified() {
163            let age = std::time::SystemTime::now()
164                .duration_since(modified)
165                .unwrap_or(std::time::Duration::MAX);
166
167            if age > std::time::Duration::from_secs(24 * 60 * 60) {
168                crate::debug_log!("Registry cache is stale (>24 hours old)");
169                return Ok(None);
170            }
171        }
172
173        let content = fs::read_to_string(&cache_file)?;
174        let registry: ProviderRegistry = serde_json::from_str(&content)?;
175
176        Ok(Some(registry))
177    }
178
179    /// Cache the registry locally
180    fn cache_registry(&self, registry: &ProviderRegistry) -> Result<()> {
181        fs::create_dir_all(&self.cache_dir)?;
182
183        let cache_file = self.cache_dir.join("registry.json");
184        let content = serde_json::to_string_pretty(registry)?;
185        fs::write(&cache_file, content)?;
186
187        Ok(())
188    }
189
190    /// List available providers
191    pub async fn list_available(&self) -> Result<Vec<(String, ProviderMetadata)>> {
192        // Try cached registry first
193        let registry = if let Some(cached) = self.get_cached_registry()? {
194            cached
195        } else {
196            self.fetch_registry().await?
197        };
198
199        let mut providers: Vec<_> = registry.providers.into_iter().collect();
200        providers.sort_by(|a, b| a.0.cmp(&b.0));
201
202        Ok(providers)
203    }
204
205    /// Install a provider configuration
206    pub async fn install_provider(&self, provider_id: &str, force: bool) -> Result<()> {
207        println!("{} Installing provider '{}'...", "đŸ“Ļ".blue(), provider_id);
208
209        // Fetch registry
210        let registry = if let Some(cached) = self.get_cached_registry()? {
211            cached
212        } else {
213            println!("{} Fetching provider registry...", "🔄".blue());
214            self.fetch_registry().await?
215        };
216
217        // Find provider metadata
218        let metadata = registry
219            .providers
220            .get(provider_id)
221            .ok_or_else(|| anyhow::anyhow!("Provider '{}' not found in registry", provider_id))?;
222
223        // Check if already installed
224        let target_file = self.providers_dir.join(&metadata.config_file);
225        if target_file.exists() && !force {
226            // Check version
227            if let Ok(existing_config) = fs::read_to_string(&target_file) {
228                if let Ok(existing_toml) = toml::from_str::<toml::Value>(&existing_config) {
229                    if let Some(existing_version) =
230                        existing_toml.get("version").and_then(|v| v.as_str())
231                    {
232                        if existing_version == metadata.version {
233                            println!(
234                                "{} Provider '{}' is already up to date (v{})",
235                                "✓".green(),
236                                provider_id,
237                                metadata.version
238                            );
239                            return Ok(());
240                        }
241                    }
242                }
243            }
244
245            println!(
246                "{} Provider '{}' already exists. Updating to v{}...",
247                "🔄".yellow(),
248                provider_id,
249                metadata.version
250            );
251        }
252
253        // Download provider config
254        let config_url = format!("{}/providers/{}", registry.base_url, metadata.config_file);
255
256        crate::debug_log!("Downloading provider config from: {}", config_url);
257
258        let config_content = if config_url.starts_with("file://") {
259            // Handle local file
260            let path = config_url
261                .strip_prefix("file://")
262                .ok_or_else(|| anyhow::anyhow!("Invalid file:// URL format"))?;
263            fs::read_to_string(path)
264                .map_err(|e| anyhow::anyhow!("Failed to read local provider config: {}", e))?
265        } else {
266            // Handle remote URL
267            let client = reqwest::Client::builder()
268                .timeout(std::time::Duration::from_secs(30))
269                .build()?;
270
271            let response = client
272                .get(&config_url)
273                .send()
274                .await
275                .map_err(|e| anyhow::anyhow!("Failed to download provider config: {}", e))?;
276
277            if !response.status().is_success() {
278                anyhow::bail!(
279                    "Failed to download provider config: HTTP {}",
280                    response.status()
281                );
282            }
283
284            response.text().await?
285        };
286
287        // Validate the downloaded config
288        self.validate_provider_config(&config_content)?;
289
290        // Ensure providers directory exists
291        fs::create_dir_all(&self.providers_dir)?;
292
293        // Write the provider config
294        fs::write(&target_file, &config_content)?;
295
296        println!(
297            "{} Provider '{}' installed successfully (v{})",
298            "✅".green(),
299            provider_id,
300            metadata.version
301        );
302
303        // Show authentication instructions
304        self.show_auth_instructions(provider_id, metadata)?;
305
306        Ok(())
307    }
308
309    /// Update a provider configuration
310    pub async fn update_provider(&self, provider_id: &str) -> Result<()> {
311        self.install_provider(provider_id, true).await
312    }
313
314    /// Update all installed providers
315    pub async fn update_all_providers(&self) -> Result<()> {
316        println!("{} Updating all installed providers...", "🔄".blue());
317
318        // Get list of installed providers
319        let installed = self.list_installed_providers()?;
320
321        if installed.is_empty() {
322            println!("{} No providers installed", "â„šī¸".blue());
323            return Ok(());
324        }
325
326        let mut updated_count = 0;
327        let mut failed_count = 0;
328
329        for provider_id in installed {
330            match self.update_provider(&provider_id).await {
331                Ok(_) => updated_count += 1,
332                Err(e) => {
333                    eprintln!("{} Failed to update '{}': {}", "❌".red(), provider_id, e);
334                    failed_count += 1;
335                }
336            }
337        }
338
339        if failed_count == 0 {
340            println!(
341                "{} All {} providers updated successfully",
342                "✅".green(),
343                updated_count
344            );
345        } else {
346            println!(
347                "{} Updated {} providers, {} failed",
348                "âš ī¸".yellow(),
349                updated_count,
350                failed_count
351            );
352        }
353
354        Ok(())
355    }
356
357    /// List installed providers
358    pub fn list_installed_providers(&self) -> Result<Vec<String>> {
359        if !self.providers_dir.exists() {
360            return Ok(Vec::new());
361        }
362
363        let mut providers = Vec::new();
364
365        for entry in fs::read_dir(&self.providers_dir)? {
366            let entry = entry?;
367            let path = entry.path();
368
369            if path.extension().and_then(|s| s.to_str()) == Some("toml") {
370                if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
371                    providers.push(name.to_string());
372                }
373            }
374        }
375
376        providers.sort();
377        Ok(providers)
378    }
379
380    /// Remove an installed provider
381    pub fn uninstall_provider(&self, provider_id: &str) -> Result<()> {
382        let provider_file = self.providers_dir.join(format!("{}.toml", provider_id));
383
384        if !provider_file.exists() {
385            anyhow::bail!("Provider '{}' is not installed", provider_id);
386        }
387
388        fs::remove_file(&provider_file)?;
389
390        println!(
391            "{} Provider '{}' uninstalled successfully",
392            "✅".green(),
393            provider_id
394        );
395
396        // Check if there are any API keys to clean up
397        let keys = crate::keys::KeysConfig::load()?;
398        if keys.has_auth(provider_id) {
399            println!(
400                "{} Note: API keys for '{}' are still stored in keys.toml",
401                "â„šī¸".blue(),
402                provider_id
403            );
404            println!("  To remove them, use: lc keys remove {}", provider_id);
405        }
406
407        Ok(())
408    }
409
410    /// Validate a provider configuration
411    fn validate_provider_config(&self, config_content: &str) -> Result<()> {
412        // Try to parse as TOML
413        let config: toml::Value = toml::from_str(config_content)
414            .map_err(|e| anyhow::anyhow!("Invalid TOML format: {}", e))?;
415
416        // Check required fields
417        let required_fields = ["endpoint", "models_path", "chat_path"];
418
419        for field in &required_fields {
420            if !config.get(field).is_some() {
421                anyhow::bail!("Provider config missing required field: {}", field);
422            }
423        }
424
425        Ok(())
426    }
427
428    /// Show authentication instructions for a provider
429    fn show_auth_instructions(&self, provider_id: &str, metadata: &ProviderMetadata) -> Result<()> {
430        println!("\n{} Authentication Setup", "🔑".yellow());
431
432        match metadata.auth_type {
433            AuthType::ApiKey => {
434                println!("This provider requires an API key.");
435                println!("To set it up, run:");
436                println!("  {}", format!("lc keys add {}", provider_id).bold());
437            }
438            AuthType::ServiceAccount => {
439                println!("This provider requires a service account JSON.");
440                println!("To set it up, run:");
441                println!("  {}", format!("lc keys add {}", provider_id).bold());
442            }
443            AuthType::OAuth => {
444                println!("This provider uses OAuth authentication.");
445                println!("Follow the provider's documentation to set up OAuth.");
446                if let Some(docs_url) = &metadata.docs_url {
447                    println!("Documentation: {}", docs_url.blue());
448                }
449            }
450            AuthType::Token => {
451                println!("This provider requires an authentication token.");
452                println!("To set it up, run:");
453                println!("  {}", format!("lc keys add {}", provider_id).bold());
454            }
455            AuthType::Headers => {
456                println!("This provider requires custom authentication headers.");
457                println!("To set them up, run:");
458                println!(
459                    "  {}",
460                    format!(
461                        "lc providers headers {} add <header-name> <header-value>",
462                        provider_id
463                    )
464                    .bold()
465                );
466            }
467            AuthType::None => {
468                println!("This provider does not require authentication.");
469            }
470        }
471
472        Ok(())
473    }
474}
475
476/// Create a sample provider registry for testing
477#[allow(dead_code)]
478pub fn create_sample_registry() -> ProviderRegistry {
479    let mut providers = HashMap::new();
480
481    // Add sample providers
482    providers.insert(
483        "openai".to_string(),
484        ProviderMetadata {
485            name: "OpenAI".to_string(),
486            description: "OpenAI GPT models including GPT-4 and GPT-3.5".to_string(),
487            config_file: "openai.toml".to_string(),
488            version: "1.0.0".to_string(),
489            auth_type: AuthType::ApiKey,
490            tags: vec![
491                "official".to_string(),
492                "chat".to_string(),
493                "embeddings".to_string(),
494            ],
495            official: true,
496            docs_url: Some("https://platform.openai.com/docs".to_string()),
497            min_version: None,
498        },
499    );
500
501    providers.insert(
502        "gemini".to_string(),
503        ProviderMetadata {
504            name: "Google Gemini".to_string(),
505            description: "Google's Gemini models".to_string(),
506            config_file: "gemini.toml".to_string(),
507            version: "1.0.0".to_string(),
508            auth_type: AuthType::ApiKey,
509            tags: vec![
510                "official".to_string(),
511                "chat".to_string(),
512                "vision".to_string(),
513            ],
514            official: true,
515            docs_url: Some("https://ai.google.dev/docs".to_string()),
516            min_version: None,
517        },
518    );
519
520    providers.insert(
521        "anthropic".to_string(),
522        ProviderMetadata {
523            name: "Anthropic Claude".to_string(),
524            description: "Anthropic's Claude models".to_string(),
525            config_file: "anthropic.toml".to_string(),
526            version: "1.0.0".to_string(),
527            auth_type: AuthType::ApiKey,
528            tags: vec!["official".to_string(), "chat".to_string()],
529            official: true,
530            docs_url: Some("https://docs.anthropic.com".to_string()),
531            min_version: None,
532        },
533    );
534
535    ProviderRegistry {
536        version: "1.0.0".to_string(),
537        providers,
538        base_url: "https://raw.githubusercontent.com/rajashekar/lc-providers/main".to_string(),
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-tripping metadata through JSON preserves every field, and `None`
    /// optionals are omitted from the serialized form.
    #[test]
    fn test_provider_metadata_serialization() {
        let metadata = ProviderMetadata {
            name: "Test Provider".to_string(),
            description: "A test provider".to_string(),
            config_file: "test.toml".to_string(),
            version: "1.0.0".to_string(),
            auth_type: AuthType::ApiKey,
            tags: vec!["test".to_string()],
            official: false,
            docs_url: None,
            min_version: None,
        };

        let json = serde_json::to_string(&metadata).unwrap();
        assert!(!json.contains("docs_url"));
        assert!(!json.contains("min_version"));

        let deserialized: ProviderMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(metadata.name, deserialized.name);
        assert_eq!(metadata.description, deserialized.description);
        assert_eq!(metadata.config_file, deserialized.config_file);
        assert_eq!(metadata.version, deserialized.version);
        assert_eq!(metadata.tags, deserialized.tags);
        assert_eq!(metadata.official, deserialized.official);
        assert!(matches!(deserialized.auth_type, AuthType::ApiKey));
        assert!(deserialized.docs_url.is_none());
        assert!(deserialized.min_version.is_none());
    }

    /// Pins the snake_case wire names that registry.json files rely on.
    #[test]
    fn test_auth_type_wire_format() {
        assert_eq!(serde_json::to_string(&AuthType::ApiKey).unwrap(), "\"api_key\"");
        assert_eq!(
            serde_json::to_string(&AuthType::ServiceAccount).unwrap(),
            "\"service_account\""
        );
        assert_eq!(serde_json::to_string(&AuthType::Token).unwrap(), "\"token\"");
        assert_eq!(serde_json::to_string(&AuthType::Headers).unwrap(), "\"headers\"");
        assert_eq!(serde_json::to_string(&AuthType::None).unwrap(), "\"none\"");
    }

    /// Optional metadata fields may be left out of registry entries entirely.
    #[test]
    fn test_metadata_defaults_for_optional_fields() {
        let json = r#"{
            "name": "Minimal",
            "description": "Only required fields",
            "config_file": "minimal.toml",
            "version": "0.1.0",
            "auth_type": "token"
        }"#;

        let metadata: ProviderMetadata = serde_json::from_str(json).unwrap();

        assert!(metadata.tags.is_empty());
        assert!(!metadata.official);
        assert!(metadata.docs_url.is_none());
        assert!(metadata.min_version.is_none());
        assert!(matches!(metadata.auth_type, AuthType::Token));
    }

    #[test]
    fn test_registry_creation() {
        let registry = create_sample_registry();

        assert!(registry.providers.contains_key("openai"));
        assert!(registry.providers.contains_key("gemini"));
        assert!(registry.providers.contains_key("anthropic"));

        let openai = &registry.providers["openai"];
        assert_eq!(openai.name, "OpenAI");
        assert!(openai.official);
    }

    /// The registry survives the JSON round trip used by the on-disk cache.
    #[test]
    fn test_sample_registry_json_round_trip() {
        let registry = create_sample_registry();

        let json = serde_json::to_string_pretty(&registry).unwrap();
        let parsed: ProviderRegistry = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.version, registry.version);
        assert_eq!(parsed.base_url, registry.base_url);

        let mut ids: Vec<String> = parsed.providers.keys().cloned().collect();
        ids.sort();
        assert_eq!(ids, ["anthropic", "gemini", "openai"]);

        assert_eq!(parsed.providers["gemini"].config_file, "gemini.toml");
    }
}
579}