use anyhow::{Context, Result};
use directories::ProjectDirs;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
/// Collection of per-provider API credentials persisted as YAML.
///
/// Any field missing from the serialized form falls back to its `Default` value,
/// so a document without a `credentials` key loads as an empty store.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AuthStore {
    /// Stored credentials. `AuthStore::add` replaces any existing entry for the same
    /// provider, so entries created through it are unique per provider.
    pub credentials: Vec<ProviderCredentials>,
}
/// API credentials for a single provider.
///
/// `Debug` is implemented by hand so the secret `api_key` is never printed by
/// `{:?}` formatting (including when an enclosing `AuthStore` is debug-printed).
#[derive(Clone, Serialize, Deserialize)]
pub struct ProviderCredentials {
    /// Provider identifier used as the lookup key by `AuthStore`.
    pub provider: String,
    /// Secret API key, serialized verbatim.
    pub api_key: String,
    /// Optional endpoint override; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
}

impl std::fmt::Debug for ProviderCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderCredentials")
            .field("provider", &self.provider)
            // Never expose the secret through debug output / logs.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .finish()
    }
}
impl AuthStore {
pub fn load() -> Result<Self> {
let path = Self::auth_path();
if !path.exists() {
return Ok(AuthStore::default());
}
let content = fs::read_to_string(&path)
.with_context(|| format!("failed to read auth store from {:?}", path))?;
let store: AuthStore = serde_yaml::from_str(&content)
.with_context(|| format!("failed to parse auth store from {:?}", path))?;
Ok(store)
}
pub fn save(&self) -> Result<()> {
let path = Self::auth_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create auth directory {:?}", parent))?;
}
let content = serde_yaml::to_string(self).context("failed to serialize auth store")?;
let temp_path = path.with_extension("tmp");
fs::write(&temp_path, &content)
.with_context(|| format!("failed to write auth store to temp {:?}", temp_path))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::Permissions::mode(0o600);
fs::set_permissions(&temp_path, perms)?;
}
#[cfg(windows)]
{
use std::process::Command;
let _ = Command::new("attrib").arg("+H").arg(&temp_path).output();
}
fs::rename(&temp_path, &path)
.with_context(|| format!("failed to rename temp auth store to {:?}", path))?;
Ok(())
}
pub fn auth_path() -> PathBuf {
if let Ok(home) = std::env::var("HERMES_HOME") {
return PathBuf::from(home).join("credentials.yaml");
}
if let Ok(profile) = std::env::var("HERMES_PROFILE") {
if let Some(proj_dirs) =
ProjectDirs::from("ai", "hermes", &format!("hermes-{}", profile))
{
return proj_dirs.config_dir().join("credentials.yaml");
}
}
if let Some(proj_dirs) = ProjectDirs::from("ai", "hermes", "hermes-cli") {
return proj_dirs.config_dir().join("credentials.yaml");
}
if let Ok(home) = std::env::var("USERPROFILE") {
return PathBuf::from(home).join(".hermes").join("credentials.yaml");
}
PathBuf::from(".hermes").join("credentials.yaml")
}
pub fn add(&mut self, provider: &str, api_key: &str, base_url: Option<&str>) {
self.credentials.retain(|c| c.provider != provider);
self.credentials.push(ProviderCredentials {
provider: provider.to_string(),
api_key: api_key.to_string(),
base_url: base_url.map(|s| s.to_string()),
});
}
pub fn list(&self) -> Vec<(String, String, Option<String>)> {
self.credentials
.iter()
.map(|c| (c.provider.clone(), mask_key(&c.api_key), c.base_url.clone()))
.collect()
}
pub fn get(&self, provider: &str) -> Option<&ProviderCredentials> {
self.credentials.iter().find(|c| c.provider == provider)
}
pub fn remove(&mut self, provider: &str) -> bool {
let len = self.credentials.len();
self.credentials.retain(|c| c.provider != provider);
self.credentials.len() < len
}
pub fn reset(&mut self) {
self.credentials.clear();
}
}
/// Masks an API key for display, keeping only its first and last four characters.
///
/// Keys of eight characters or fewer are fully masked (one `*` per character), since
/// showing four characters from each end would reveal the whole key.
///
/// Works on `char`s rather than bytes, so non-ASCII keys never panic on a char
/// boundary and the mask length matches the visible character count.
fn mask_key(key: &str) -> String {
    let char_count = key.chars().count();
    if char_count <= 8 {
        return "*".repeat(char_count);
    }
    let head: String = key.chars().take(4).collect();
    let tail: String = key.chars().skip(char_count - 4).collect();
    format!("{}...{}", head, tail)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mask_key() {
        assert_eq!(mask_key("short"), "*****");
        assert_eq!(mask_key("12345678"), "********");
        assert_eq!(mask_key("sk-1234567890abcdef"), "sk-1...cdef");
    }

    #[test]
    fn test_auth_store_add_get() {
        let mut store = AuthStore::default();
        store.add("openai", "sk-test123", None);
        assert!(store.get("openai").is_some());
        assert_eq!(store.get("openai").unwrap().api_key, "sk-test123");
    }

    #[test]
    fn test_auth_store_add_replaces_existing_provider() {
        let mut store = AuthStore::default();
        store.add("openai", "sk-old", None);
        store.add("openai", "sk-new", Some("https://example.com/v1"));
        assert_eq!(store.credentials.len(), 1);
        let creds = store.get("openai").unwrap();
        assert_eq!(creds.api_key, "sk-new");
        assert_eq!(creds.base_url.as_deref(), Some("https://example.com/v1"));
    }

    #[test]
    fn test_auth_store_remove() {
        let mut store = AuthStore::default();
        store.add("openai", "sk-test123", None);
        assert!(store.remove("openai"));
        assert!(store.get("openai").is_none());
        // Removing a provider that is not present reports that nothing changed.
        assert!(!store.remove("openai"));
    }

    #[test]
    fn test_auth_store_reset() {
        let mut store = AuthStore::default();
        store.add("openai", "sk-test123", None);
        store.add("anthropic", "sk-other456", None);
        store.reset();
        assert!(store.credentials.is_empty());
        assert!(store.get("openai").is_none());
    }

    #[test]
    fn test_auth_store_list_masked() {
        let mut store = AuthStore::default();
        store.add("openai", "sk-1234567890abcdef", None);
        let list = store.list();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].0, "openai");
        assert_eq!(list[0].1, "sk-1...cdef");
        assert_eq!(list[0].2, None);
    }

    #[test]
    fn test_auth_store_list_keeps_base_url() {
        let mut store = AuthStore::default();
        store.add("local", "sk-1234567890abcdef", Some("http://localhost:8080"));
        let list = store.list();
        assert_eq!(list[0].2.as_deref(), Some("http://localhost:8080"));
    }
}