lc/data/
keys.rs

//! Centralized API key management for providers.
//!
//! This module stores and retrieves API keys separately from provider configurations,
//! so provider configs can be shared and version-controlled without exposing secrets.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::PathBuf;
11
12/// Structure for storing API keys and secrets
13#[derive(Debug, Serialize, Deserialize, Clone, Default)]
14pub struct KeysConfig {
15    /// Provider API keys - provider_name -> api_key
16    #[serde(default)]
17    pub api_keys: HashMap<String, String>,
18
19    /// Additional authentication tokens (e.g., for search providers)
20    #[serde(default)]
21    pub tokens: HashMap<String, String>,
22
23    /// Service account JSON strings for providers like Google Vertex AI
24    #[serde(default)]
25    pub service_accounts: HashMap<String, String>,
26
27    /// OAuth tokens for providers that use OAuth
28    #[serde(default)]
29    pub oauth_tokens: HashMap<String, String>,
30
31    /// Custom headers that contain sensitive values (renamed from auth_headers)
32    #[serde(default, alias = "auth_headers")]
33    pub custom_headers: HashMap<String, HashMap<String, String>>,
34}
35
36impl KeysConfig {
37    /// Create a new empty KeysConfig
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Load keys configuration from file
43    pub fn load() -> Result<Self> {
44        let keys_path = Self::keys_file_path()?;
45
46        if keys_path.exists() {
47            let content = fs::read_to_string(&keys_path)?;
48            let config: KeysConfig = toml::from_str(&content)?;
49            Ok(config)
50        } else {
51            // Create default empty keys config
52            let config = KeysConfig::default();
53            // Ensure directory exists
54            if let Some(parent) = keys_path.parent() {
55                fs::create_dir_all(parent)?;
56            }
57            config.save()?;
58            Ok(config)
59        }
60    }
61
62    /// Save keys configuration to file
63    pub fn save(&self) -> Result<()> {
64        let keys_path = Self::keys_file_path()?;
65
66        // Ensure directory exists
67        if let Some(parent) = keys_path.parent() {
68            fs::create_dir_all(parent)?;
69        }
70
71        let content = toml::to_string_pretty(self)?;
72        fs::write(&keys_path, content)?;
73
74        // Set restrictive permissions on keys file (Unix-like systems)
75        #[cfg(unix)]
76        {
77            use std::os::unix::fs::PermissionsExt;
78            let metadata = fs::metadata(&keys_path)?;
79            let mut permissions = metadata.permissions();
80            permissions.set_mode(0o600); // Read/write for owner only
81            fs::set_permissions(&keys_path, permissions)?;
82        }
83
84        Ok(())
85    }
86
87    /// Get the path to the keys.toml file
88    fn keys_file_path() -> Result<PathBuf> {
89        let config_dir = crate::config::Config::config_dir()?;
90        Ok(config_dir.join("keys.toml"))
91    }
92
93    /// Set an API key for a provider
94    pub fn set_api_key(&mut self, provider: String, api_key: String) -> Result<()> {
95        self.api_keys.insert(provider, api_key);
96        self.save()
97    }
98
99    /// Get an API key for a provider
100    #[allow(dead_code)]
101    pub fn get_api_key(&self, provider: &str) -> Option<&String> {
102        self.api_keys.get(provider)
103    }
104
105    /// Remove an API key for a provider
106    #[allow(dead_code)]
107    pub fn remove_api_key(&mut self, provider: &str) -> Result<bool> {
108        let removed = self.api_keys.remove(provider).is_some();
109        if removed {
110            self.save()?;
111        }
112        Ok(removed)
113    }
114
115    /// Set a service account JSON for a provider
116    #[allow(dead_code)]
117    pub fn set_service_account(&mut self, provider: String, sa_json: String) -> Result<()> {
118        self.service_accounts.insert(provider, sa_json);
119        self.save()
120    }
121
122    /// Get a service account JSON for a provider
123    #[allow(dead_code)]
124    pub fn get_service_account(&self, provider: &str) -> Option<&String> {
125        self.service_accounts.get(provider)
126    }
127
128    /// Set an authentication token
129    #[allow(dead_code)]
130    pub fn set_token(&mut self, name: String, token: String) -> Result<()> {
131        self.tokens.insert(name, token);
132        self.save()
133    }
134
135    /// Get an authentication token
136    #[allow(dead_code)]
137    pub fn get_token(&self, name: &str) -> Option<&String> {
138        self.tokens.get(name)
139    }
140
141    /// Set authentication headers for a provider
142    #[allow(dead_code)]
143    pub fn set_auth_headers(
144        &mut self,
145        provider: String,
146        headers: HashMap<String, String>,
147    ) -> Result<()> {
148        self.custom_headers.insert(provider, headers);
149        self.save()
150    }
151
152    /// Get authentication headers for a provider (returns custom headers)
153    #[allow(dead_code)]
154    pub fn get_auth_headers(&self, provider: &str) -> HashMap<String, String> {
155        let mut headers = HashMap::new();
156
157        // Check for API key
158        if let Some(api_key) = self.api_keys.get(provider) {
159            headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
160        }
161
162        // Check for custom headers
163        if let Some(custom) = self.custom_headers.get(provider) {
164            headers.extend(custom.clone());
165        }
166
167        headers
168    }
169
170    /// Get authentication for a provider (returns the appropriate auth type)
171    pub fn get_auth(&self, provider: &str) -> Option<ProviderAuth> {
172        // Check different auth types in order
173        if let Some(api_key) = self.api_keys.get(provider) {
174            return Some(ProviderAuth::ApiKey(api_key.clone()));
175        }
176
177        if let Some(sa) = self.service_accounts.get(provider) {
178            return Some(ProviderAuth::ServiceAccount(sa.clone()));
179        }
180
181        if let Some(oauth) = self.oauth_tokens.get(provider) {
182            return Some(ProviderAuth::OAuthToken(oauth.clone()));
183        }
184
185        if let Some(token) = self.tokens.get(provider) {
186            return Some(ProviderAuth::Token(token.clone()));
187        }
188
189        if let Some(headers) = self.custom_headers.get(provider) {
190            return Some(ProviderAuth::Headers(headers.clone()));
191        }
192
193        None
194    }
195
196    /// List all providers with configured keys
197    #[allow(dead_code)]
198    pub fn list_providers_with_keys(&self) -> Vec<String> {
199        let mut providers = Vec::new();
200
201        for key in self.api_keys.keys() {
202            if !providers.contains(key) {
203                providers.push(key.clone());
204            }
205        }
206
207        for key in self.service_accounts.keys() {
208            if !providers.contains(key) {
209                providers.push(key.clone());
210            }
211        }
212
213        for key in self.custom_headers.keys() {
214            if !providers.contains(key) {
215                providers.push(key.clone());
216            }
217        }
218
219        providers.sort();
220        providers
221    }
222
223    /// Check if a provider has any authentication configured
224    pub fn has_auth(&self, provider: &str) -> bool {
225        self.api_keys.contains_key(provider)
226            || self.service_accounts.contains_key(provider)
227            || self.custom_headers.contains_key(provider)
228            || self.oauth_tokens.contains_key(provider)
229            || self.tokens.contains_key(provider)
230    }
231
232    /// Migrate keys from old provider configs to centralized keys.toml
233    pub fn migrate_from_provider_configs(config: &crate::config::Config) -> Result<Self> {
234        let mut keys_config = Self::load()?;
235        let mut migrated_count = 0;
236
237        for (provider_name, provider_config) in &config.providers {
238            // Migrate API keys
239            if let Some(api_key) = &provider_config.api_key {
240                if !api_key.is_empty() && !keys_config.api_keys.contains_key(provider_name) {
241                    keys_config
242                        .api_keys
243                        .insert(provider_name.clone(), api_key.clone());
244                    migrated_count += 1;
245                    crate::debug_log!("Migrated API key for provider '{}'", provider_name);
246                }
247            }
248
249            // Migrate auth headers
250            let mut auth_headers = HashMap::new();
251            for (header_name, header_value) in &provider_config.headers {
252                let header_lower = header_name.to_lowercase();
253                if header_lower.contains("key")
254                    || header_lower.contains("token")
255                    || header_lower.contains("auth")
256                    || header_lower.contains("secret")
257                {
258                    auth_headers.insert(header_name.clone(), header_value.clone());
259                }
260            }
261
262            if !auth_headers.is_empty() && !keys_config.custom_headers.contains_key(provider_name) {
263                keys_config
264                    .custom_headers
265                    .insert(provider_name.clone(), auth_headers);
266                migrated_count += 1;
267                crate::debug_log!("Migrated auth headers for provider '{}'", provider_name);
268            }
269        }
270
271        if migrated_count > 0 {
272            keys_config.save()?;
273            println!(
274                "✓ Migrated {} authentication configurations to keys.toml",
275                migrated_count
276            );
277        }
278
279        Ok(keys_config)
280    }
281}
282
283/// Helper function to get authentication for a provider from centralized keys
284pub fn get_provider_auth(provider: &str) -> Result<Option<ProviderAuth>> {
285    let keys = KeysConfig::load()?;
286    Ok(keys.get_auth(provider))
287}
288
/// Types of authentication that can be stored for a provider.
///
/// `Debug` is implemented by hand so that secret values are never printed. String
/// credentials show as `"<redacted>"`, and header maps show only their header names.
#[derive(Clone)]
pub enum ProviderAuth {
    /// A plain API key.
    ApiKey(String),
    /// A service account JSON document.
    ServiceAccount(String),
    /// An OAuth token.
    OAuthToken(String),
    /// A generic authentication token.
    Token(String),
    /// Custom header name -> header value pairs.
    Headers(HashMap<String, String>),
}

impl std::fmt::Debug for ProviderAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            ProviderAuth::ApiKey(_) => f.debug_tuple("ApiKey").field(&REDACTED).finish(),
            ProviderAuth::ServiceAccount(_) => {
                f.debug_tuple("ServiceAccount").field(&REDACTED).finish()
            }
            ProviderAuth::OAuthToken(_) => f.debug_tuple("OAuthToken").field(&REDACTED).finish(),
            ProviderAuth::Token(_) => f.debug_tuple("Token").field(&REDACTED).finish(),
            ProviderAuth::Headers(headers) => {
                // Header names are useful for debugging; their values may be secrets.
                let mut names: Vec<&str> = headers.keys().map(String::as_str).collect();
                names.sort_unstable();
                f.debug_tuple("Headers").field(&names).finish()
            }
        }
    }
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::env;
    use tempfile::TempDir;

    /// Create a fresh temporary directory and point `LC_TEST_CONFIG_DIR` at it.
    /// The config lookup presumably honours this variable; confirm against `Config::config_dir`.
    ///
    /// Keep the returned `TempDir` alive for the whole test; dropping it deletes the directory.
    fn isolated_config_dir() -> TempDir {
        let dir = TempDir::new().unwrap();
        env::set_var("LC_TEST_CONFIG_DIR", dir.path());
        dir
    }

    #[test]
    #[serial]
    fn test_keys_config_save_load() {
        let _config_dir = isolated_config_dir();

        // Persist a key through one instance...
        let mut writer = KeysConfig::default();
        writer
            .set_api_key("test-openai-save-load".to_string(), "test-key".to_string())
            .unwrap();

        // ...then read it back through a freshly loaded instance.
        let auth = KeysConfig::load()
            .unwrap()
            .get_auth("test-openai-save-load");
        assert!(auth.is_some());

        match auth {
            Some(ProviderAuth::ApiKey(key)) => assert_eq!(key, "test-key"),
            _ => panic!("Expected API key auth type"),
        }
    }

    #[test]
    #[serial]
    fn test_provider_auth_types() {
        let _config_dir = isolated_config_dir();
        let mut keys = KeysConfig::default();

        // API key
        let api_provider = "test-openai-auth-types";
        keys.set_api_key(api_provider.to_string(), "sk-test".to_string())
            .unwrap();
        assert!(keys.has_auth(api_provider));

        // Service account
        let vertex_provider = "test-vertex-auth-types";
        keys.set_service_account(
            vertex_provider.to_string(),
            "{\"type\":\"service_account\"}".to_string(),
        )
        .unwrap();
        assert!(keys.has_auth(vertex_provider));

        // Custom auth headers
        let custom_provider = "test-custom-auth-types";
        let headers = HashMap::from([("X-API-Key".to_string(), "test-key".to_string())]);
        keys.set_auth_headers(custom_provider.to_string(), headers)
            .unwrap();
        assert!(keys.has_auth(custom_provider));

        // A provider that was never configured
        assert!(!keys.has_auth("nonexistent"));
    }
}
362}