1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::PathBuf;
11
12#[derive(Debug, Serialize, Deserialize, Clone, Default)]
14pub struct KeysConfig {
15 #[serde(default)]
17 pub api_keys: HashMap<String, String>,
18
19 #[serde(default)]
21 pub tokens: HashMap<String, String>,
22
23 #[serde(default)]
25 pub service_accounts: HashMap<String, String>,
26
27 #[serde(default)]
29 pub oauth_tokens: HashMap<String, String>,
30
31 #[serde(default, alias = "auth_headers")]
33 pub custom_headers: HashMap<String, HashMap<String, String>>,
34}
35
36impl KeysConfig {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn load() -> Result<Self> {
44 let keys_path = Self::keys_file_path()?;
45
46 if keys_path.exists() {
47 let content = fs::read_to_string(&keys_path)?;
48 let config: KeysConfig = toml::from_str(&content)?;
49 Ok(config)
50 } else {
51 let config = KeysConfig::default();
53 if let Some(parent) = keys_path.parent() {
55 fs::create_dir_all(parent)?;
56 }
57 config.save()?;
58 Ok(config)
59 }
60 }
61
62 pub fn save(&self) -> Result<()> {
64 let keys_path = Self::keys_file_path()?;
65
66 if let Some(parent) = keys_path.parent() {
68 fs::create_dir_all(parent)?;
69 }
70
71 let content = toml::to_string_pretty(self)?;
72 fs::write(&keys_path, content)?;
73
74 #[cfg(unix)]
76 {
77 use std::os::unix::fs::PermissionsExt;
78 let metadata = fs::metadata(&keys_path)?;
79 let mut permissions = metadata.permissions();
80 permissions.set_mode(0o600); fs::set_permissions(&keys_path, permissions)?;
82 }
83
84 Ok(())
85 }
86
87 fn keys_file_path() -> Result<PathBuf> {
89 let config_dir = crate::config::Config::config_dir()?;
90 Ok(config_dir.join("keys.toml"))
91 }
92
93 pub fn set_api_key(&mut self, provider: String, api_key: String) -> Result<()> {
95 self.api_keys.insert(provider, api_key);
96 self.save()
97 }
98
99 #[allow(dead_code)]
101 pub fn get_api_key(&self, provider: &str) -> Option<&String> {
102 self.api_keys.get(provider)
103 }
104
105 #[allow(dead_code)]
107 pub fn remove_api_key(&mut self, provider: &str) -> Result<bool> {
108 let removed = self.api_keys.remove(provider).is_some();
109 if removed {
110 self.save()?;
111 }
112 Ok(removed)
113 }
114
115 #[allow(dead_code)]
117 pub fn set_service_account(&mut self, provider: String, sa_json: String) -> Result<()> {
118 self.service_accounts.insert(provider, sa_json);
119 self.save()
120 }
121
122 #[allow(dead_code)]
124 pub fn get_service_account(&self, provider: &str) -> Option<&String> {
125 self.service_accounts.get(provider)
126 }
127
128 #[allow(dead_code)]
130 pub fn set_token(&mut self, name: String, token: String) -> Result<()> {
131 self.tokens.insert(name, token);
132 self.save()
133 }
134
135 #[allow(dead_code)]
137 pub fn get_token(&self, name: &str) -> Option<&String> {
138 self.tokens.get(name)
139 }
140
141 #[allow(dead_code)]
143 pub fn set_auth_headers(
144 &mut self,
145 provider: String,
146 headers: HashMap<String, String>,
147 ) -> Result<()> {
148 self.custom_headers.insert(provider, headers);
149 self.save()
150 }
151
152 #[allow(dead_code)]
154 pub fn get_auth_headers(&self, provider: &str) -> HashMap<String, String> {
155 let mut headers = HashMap::new();
156
157 if let Some(api_key) = self.api_keys.get(provider) {
159 headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
160 }
161
162 if let Some(custom) = self.custom_headers.get(provider) {
164 headers.extend(custom.clone());
165 }
166
167 headers
168 }
169
170 pub fn get_auth(&self, provider: &str) -> Option<ProviderAuth> {
172 if let Some(api_key) = self.api_keys.get(provider) {
174 return Some(ProviderAuth::ApiKey(api_key.clone()));
175 }
176
177 if let Some(sa) = self.service_accounts.get(provider) {
178 return Some(ProviderAuth::ServiceAccount(sa.clone()));
179 }
180
181 if let Some(oauth) = self.oauth_tokens.get(provider) {
182 return Some(ProviderAuth::OAuthToken(oauth.clone()));
183 }
184
185 if let Some(token) = self.tokens.get(provider) {
186 return Some(ProviderAuth::Token(token.clone()));
187 }
188
189 if let Some(headers) = self.custom_headers.get(provider) {
190 return Some(ProviderAuth::Headers(headers.clone()));
191 }
192
193 None
194 }
195
196 #[allow(dead_code)]
198 pub fn list_providers_with_keys(&self) -> Vec<String> {
199 let mut providers = Vec::new();
200
201 for key in self.api_keys.keys() {
202 if !providers.contains(key) {
203 providers.push(key.clone());
204 }
205 }
206
207 for key in self.service_accounts.keys() {
208 if !providers.contains(key) {
209 providers.push(key.clone());
210 }
211 }
212
213 for key in self.custom_headers.keys() {
214 if !providers.contains(key) {
215 providers.push(key.clone());
216 }
217 }
218
219 providers.sort();
220 providers
221 }
222
223 pub fn has_auth(&self, provider: &str) -> bool {
225 self.api_keys.contains_key(provider)
226 || self.service_accounts.contains_key(provider)
227 || self.custom_headers.contains_key(provider)
228 || self.oauth_tokens.contains_key(provider)
229 || self.tokens.contains_key(provider)
230 }
231
232 pub fn migrate_from_provider_configs(config: &crate::config::Config) -> Result<Self> {
234 let mut keys_config = Self::load()?;
235 let mut migrated_count = 0;
236
237 for (provider_name, provider_config) in &config.providers {
238 if let Some(api_key) = &provider_config.api_key {
240 if !api_key.is_empty() && !keys_config.api_keys.contains_key(provider_name) {
241 keys_config
242 .api_keys
243 .insert(provider_name.clone(), api_key.clone());
244 migrated_count += 1;
245 crate::debug_log!("Migrated API key for provider '{}'", provider_name);
246 }
247 }
248
249 let mut auth_headers = HashMap::new();
251 for (header_name, header_value) in &provider_config.headers {
252 let header_lower = header_name.to_lowercase();
253 if header_lower.contains("key")
254 || header_lower.contains("token")
255 || header_lower.contains("auth")
256 || header_lower.contains("secret")
257 {
258 auth_headers.insert(header_name.clone(), header_value.clone());
259 }
260 }
261
262 if !auth_headers.is_empty() && !keys_config.custom_headers.contains_key(provider_name) {
263 keys_config
264 .custom_headers
265 .insert(provider_name.clone(), auth_headers);
266 migrated_count += 1;
267 crate::debug_log!("Migrated auth headers for provider '{}'", provider_name);
268 }
269 }
270
271 if migrated_count > 0 {
272 keys_config.save()?;
273 println!(
274 "✓ Migrated {} authentication configurations to keys.toml",
275 migrated_count
276 );
277 }
278
279 Ok(keys_config)
280 }
281}
282
283pub fn get_provider_auth(provider: &str) -> Result<Option<ProviderAuth>> {
285 let keys = KeysConfig::load()?;
286 Ok(keys.get_auth(provider))
287}
288
/// A credential resolved for a provider by `KeysConfig::get_auth`.
///
/// Each variant carries the raw secret exactly as stored in `keys.toml`.
/// `PartialEq`/`Eq` are derived so callers and tests can compare resolved
/// credentials directly.
///
/// NOTE: the derived `Debug` prints the secret values verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderAuth {
    /// Value taken from the `api_keys` table.
    ApiKey(String),
    /// Value taken from the `service_accounts` table (stored as a string).
    ServiceAccount(String),
    /// Value taken from the `oauth_tokens` table.
    OAuthToken(String),
    /// Value taken from the `tokens` table.
    Token(String),
    /// Header name -> value map taken from the `custom_headers` table.
    Headers(HashMap<String, String>),
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::env;
    use std::ffi::OsString;
    use tempfile::TempDir;

    /// Environment variable the tests use to redirect the config directory
    /// (presumably honoured by `Config::config_dir` — see crate::config).
    const CONFIG_DIR_VAR: &str = "LC_TEST_CONFIG_DIR";

    /// Points `LC_TEST_CONFIG_DIR` at a fresh temporary directory for the
    /// guard's lifetime and restores the previous value on drop — including
    /// when the test panics — so the process env is never left pointing at a
    /// directory that `TempDir` has already deleted.
    struct ScopedConfigDir {
        previous: Option<OsString>,
        // Held only to keep the directory alive; fields drop after `Drop::drop`
        // runs, so the env var is restored before the directory is removed.
        _dir: TempDir,
    }

    impl ScopedConfigDir {
        fn new() -> Self {
            let dir = TempDir::new().unwrap();
            let previous = env::var_os(CONFIG_DIR_VAR);
            env::set_var(CONFIG_DIR_VAR, dir.path());
            Self {
                previous,
                _dir: dir,
            }
        }
    }

    impl Drop for ScopedConfigDir {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(CONFIG_DIR_VAR, value),
                None => env::remove_var(CONFIG_DIR_VAR),
            }
        }
    }

    #[test]
    #[serial]
    fn test_keys_config_save_load() {
        // Must be a named binding: `let _ = ...` would drop the guard immediately.
        let _config_dir = ScopedConfigDir::new();

        let mut keys = KeysConfig::default();
        keys.set_api_key("test-openai-save-load".to_string(), "test-key".to_string())
            .unwrap();

        // Reload from disk to prove the key was persisted, not just held in memory.
        let loaded_keys = KeysConfig::load().unwrap();

        match loaded_keys.get_auth("test-openai-save-load") {
            Some(ProviderAuth::ApiKey(key)) => assert_eq!(key, "test-key"),
            other => panic!("Expected API key auth type, got {:?}", other),
        }
    }

    #[test]
    #[serial]
    fn test_provider_auth_types() {
        let _config_dir = ScopedConfigDir::new();

        let mut keys = KeysConfig::default();

        keys.set_api_key("test-openai-auth-types".to_string(), "sk-test".to_string())
            .unwrap();
        assert!(keys.has_auth("test-openai-auth-types"));

        keys.set_service_account(
            "test-vertex-auth-types".to_string(),
            "{\"type\":\"service_account\"}".to_string(),
        )
        .unwrap();
        assert!(keys.has_auth("test-vertex-auth-types"));

        let mut headers = HashMap::new();
        headers.insert("X-API-Key".to_string(), "test-key".to_string());
        keys.set_auth_headers("test-custom-auth-types".to_string(), headers)
            .unwrap();
        assert!(keys.has_auth("test-custom-auth-types"));

        assert!(!keys.has_auth("nonexistent"));
    }
}