use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct KeysConfig {
#[serde(default)]
pub api_keys: HashMap<String, String>,
#[serde(default)]
pub tokens: HashMap<String, String>,
#[serde(default)]
pub service_accounts: HashMap<String, String>,
#[serde(default)]
pub oauth_tokens: HashMap<String, String>,
#[serde(default, alias = "auth_headers")]
pub custom_headers: HashMap<String, HashMap<String, String>>,
}
impl KeysConfig {
pub fn new() -> Self {
Self::default()
}
pub fn load() -> Result<Self> {
let keys_path = Self::keys_file_path()?;
if keys_path.exists() {
let content = fs::read_to_string(&keys_path)?;
let config: KeysConfig = toml::from_str(&content)?;
Ok(config)
} else {
let config = KeysConfig::default();
if let Some(parent) = keys_path.parent() {
fs::create_dir_all(parent)?;
}
config.save()?;
Ok(config)
}
}
pub fn save(&self) -> Result<()> {
let keys_path = Self::keys_file_path()?;
if let Some(parent) = keys_path.parent() {
fs::create_dir_all(parent)?;
}
let content = toml::to_string_pretty(self)?;
fs::write(&keys_path, content)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let metadata = fs::metadata(&keys_path)?;
let mut permissions = metadata.permissions();
permissions.set_mode(0o600); fs::set_permissions(&keys_path, permissions)?;
}
Ok(())
}
fn keys_file_path() -> Result<PathBuf> {
let config_dir = crate::config::Config::config_dir()?;
Ok(config_dir.join("keys.toml"))
}
pub fn set_api_key(&mut self, provider: String, api_key: String) -> Result<()> {
self.api_keys.insert(provider, api_key);
self.save()
}
#[allow(dead_code)]
pub fn get_api_key(&self, provider: &str) -> Option<&String> {
self.api_keys.get(provider)
}
#[allow(dead_code)]
pub fn remove_api_key(&mut self, provider: &str) -> Result<bool> {
let removed = self.api_keys.remove(provider).is_some();
if removed {
self.save()?;
}
Ok(removed)
}
#[allow(dead_code)]
pub fn set_service_account(&mut self, provider: String, sa_json: String) -> Result<()> {
self.service_accounts.insert(provider, sa_json);
self.save()
}
#[allow(dead_code)]
pub fn get_service_account(&self, provider: &str) -> Option<&String> {
self.service_accounts.get(provider)
}
#[allow(dead_code)]
pub fn set_token(&mut self, name: String, token: String) -> Result<()> {
self.tokens.insert(name, token);
self.save()
}
#[allow(dead_code)]
pub fn get_token(&self, name: &str) -> Option<&String> {
self.tokens.get(name)
}
#[allow(dead_code)]
pub fn set_auth_headers(
&mut self,
provider: String,
headers: HashMap<String, String>,
) -> Result<()> {
self.custom_headers.insert(provider, headers);
self.save()
}
#[allow(dead_code)]
pub fn get_auth_headers(&self, provider: &str) -> HashMap<String, String> {
let mut headers = HashMap::new();
if let Some(api_key) = self.api_keys.get(provider) {
headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
}
if let Some(custom) = self.custom_headers.get(provider) {
headers.extend(custom.clone());
}
headers
}
pub fn get_auth(&self, provider: &str) -> Option<ProviderAuth> {
if let Some(api_key) = self.api_keys.get(provider) {
return Some(ProviderAuth::ApiKey(api_key.clone()));
}
if let Some(sa) = self.service_accounts.get(provider) {
return Some(ProviderAuth::ServiceAccount(sa.clone()));
}
if let Some(oauth) = self.oauth_tokens.get(provider) {
return Some(ProviderAuth::OAuthToken(oauth.clone()));
}
if let Some(token) = self.tokens.get(provider) {
return Some(ProviderAuth::Token(token.clone()));
}
if let Some(headers) = self.custom_headers.get(provider) {
return Some(ProviderAuth::Headers(headers.clone()));
}
None
}
#[allow(dead_code)]
pub fn list_providers_with_keys(&self) -> Vec<String> {
let mut providers = Vec::new();
for key in self.api_keys.keys() {
if !providers.contains(key) {
providers.push(key.clone());
}
}
for key in self.service_accounts.keys() {
if !providers.contains(key) {
providers.push(key.clone());
}
}
for key in self.custom_headers.keys() {
if !providers.contains(key) {
providers.push(key.clone());
}
}
providers.sort();
providers
}
pub fn has_auth(&self, provider: &str) -> bool {
self.api_keys.contains_key(provider)
|| self.service_accounts.contains_key(provider)
|| self.custom_headers.contains_key(provider)
|| self.oauth_tokens.contains_key(provider)
|| self.tokens.contains_key(provider)
}
pub fn migrate_from_provider_configs(config: &crate::config::Config) -> Result<Self> {
let mut keys_config = Self::load()?;
let mut migrated_count = 0;
for (provider_name, provider_config) in &config.providers {
if let Some(api_key) = &provider_config.api_key {
if !api_key.is_empty() && !keys_config.api_keys.contains_key(provider_name) {
keys_config
.api_keys
.insert(provider_name.clone(), api_key.clone());
migrated_count += 1;
crate::debug_log!("Migrated API key for provider '{}'", provider_name);
}
}
let mut auth_headers = HashMap::new();
for (header_name, header_value) in &provider_config.headers {
let header_lower = header_name.to_lowercase();
if header_lower.contains("key")
|| header_lower.contains("token")
|| header_lower.contains("auth")
|| header_lower.contains("secret")
{
auth_headers.insert(header_name.clone(), header_value.clone());
}
}
if !auth_headers.is_empty() && !keys_config.custom_headers.contains_key(provider_name) {
keys_config
.custom_headers
.insert(provider_name.clone(), auth_headers);
migrated_count += 1;
crate::debug_log!("Migrated auth headers for provider '{}'", provider_name);
}
}
if migrated_count > 0 {
keys_config.save()?;
println!(
"✓ Migrated {} authentication configurations to keys.toml",
migrated_count
);
}
Ok(keys_config)
}
}
pub fn get_provider_auth(provider: &str) -> Result<Option<ProviderAuth>> {
let keys = KeysConfig::load()?;
Ok(keys.get_auth(provider))
}
#[derive(Debug, Clone)]
pub enum ProviderAuth {
ApiKey(String),
ServiceAccount(String),
OAuthToken(String),
Token(String),
Headers(HashMap<String, String>),
}
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
use std::env;
use tempfile::TempDir;
#[test]
#[serial]
fn test_keys_config_save_load() {
let temp_dir = TempDir::new().unwrap();
env::set_var("LC_TEST_CONFIG_DIR", temp_dir.path());
let mut keys = KeysConfig::default();
keys.set_api_key("test-openai-save-load".to_string(), "test-key".to_string())
.unwrap();
let loaded_keys = KeysConfig::load().unwrap();
let auth = loaded_keys.get_auth("test-openai-save-load");
assert!(auth.is_some());
if let Some(ProviderAuth::ApiKey(key)) = auth {
assert_eq!(key, "test-key");
} else {
panic!("Expected API key auth type");
}
}
#[test]
#[serial]
fn test_provider_auth_types() {
let temp_dir = TempDir::new().unwrap();
env::set_var("LC_TEST_CONFIG_DIR", temp_dir.path());
let mut keys = KeysConfig::default();
keys.set_api_key("test-openai-auth-types".to_string(), "sk-test".to_string())
.unwrap();
assert!(keys.has_auth("test-openai-auth-types"));
keys.set_service_account(
"test-vertex-auth-types".to_string(),
"{\"type\":\"service_account\"}".to_string(),
)
.unwrap();
assert!(keys.has_auth("test-vertex-auth-types"));
let mut headers = HashMap::new();
headers.insert("X-API-Key".to_string(), "test-key".to_string());
keys.set_auth_headers("test-custom-auth-types".to_string(), headers)
.unwrap();
assert!(keys.has_auth("test-custom-auth-types"));
assert!(!keys.has_auth("nonexistent"));
}
}