Skip to main content

hermes_agent_cli_core/
auth.rs

1use anyhow::{Context, Result};
2use directories::ProjectDirs;
3use serde::{Deserialize, Serialize};
4use std::fs;
5use std::path::PathBuf;
6
7/// Credentials storage for auth providers
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct AuthStore {
10    #[serde(default)]
11    pub credentials: Vec<ProviderCredentials>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ProviderCredentials {
16    pub provider: String,
17    pub api_key: String,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub base_url: Option<String>,
20}
21
22impl AuthStore {
23    /// Load auth store from disk
24    pub fn load() -> Result<Self> {
25        let path = Self::auth_path();
26        if !path.exists() {
27            return Ok(AuthStore::default());
28        }
29        let content = fs::read_to_string(&path)
30            .with_context(|| format!("failed to read auth store from {:?}", path))?;
31        let store: AuthStore = serde_yaml::from_str(&content)
32            .with_context(|| format!("failed to parse auth store from {:?}", path))?;
33        Ok(store)
34    }
35
36    /// Save auth store to disk securely
37    /// Uses atomic write: temp file → set permissions → rename
38    pub fn save(&self) -> Result<()> {
39        let path = Self::auth_path();
40        if let Some(parent) = path.parent() {
41            fs::create_dir_all(parent)
42                .with_context(|| format!("failed to create auth directory {:?}", parent))?;
43        }
44        let content = serde_yaml::to_string(self).context("failed to serialize auth store")?;
45
46        // Atomic write: write to temp file first
47        let temp_path = path.with_extension("tmp");
48        fs::write(&temp_path, &content)
49            .with_context(|| format!("failed to write auth store to temp {:?}", temp_path))?;
50
51        // Set restrictive permissions on Unix (owner only)
52        #[cfg(unix)]
53        {
54            use std::os::unix::fs::PermissionsExt;
55            let mut perms = fs::Permissions::mode(0o600);
56            fs::set_permissions(&temp_path, perms)?;
57        }
58
59        // Windows: mark file as hidden and system to provide some protection
60        #[cfg(windows)]
61        {
62            use std::process::Command;
63            // Set file as hidden
64            let _ = Command::new("attrib").arg("+H").arg(&temp_path).output();
65        }
66
67        // Atomic rename (overwrites existing file)
68        fs::rename(&temp_path, &path)
69            .with_context(|| format!("failed to rename temp auth store to {:?}", path))?;
70
71        Ok(())
72    }
73
74    /// Get auth path
75    pub fn auth_path() -> PathBuf {
76        if let Ok(home) = std::env::var("HERMES_HOME") {
77            return PathBuf::from(home).join("credentials.yaml");
78        }
79        if let Ok(profile) = std::env::var("HERMES_PROFILE") {
80            if let Some(proj_dirs) =
81                ProjectDirs::from("ai", "hermes", &format!("hermes-{}", profile))
82            {
83                return proj_dirs.config_dir().join("credentials.yaml");
84            }
85        }
86        if let Some(proj_dirs) = ProjectDirs::from("ai", "hermes", "hermes-cli") {
87            return proj_dirs.config_dir().join("credentials.yaml");
88        }
89        if let Ok(home) = std::env::var("USERPROFILE") {
90            return PathBuf::from(home).join(".hermes").join("credentials.yaml");
91        }
92        PathBuf::from(".hermes").join("credentials.yaml")
93    }
94
95    /// Add credentials for a provider
96    pub fn add(&mut self, provider: &str, api_key: &str, base_url: Option<&str>) {
97        // Remove existing credentials for this provider
98        self.credentials.retain(|c| c.provider != provider);
99
100        // Add new credentials
101        self.credentials.push(ProviderCredentials {
102            provider: provider.to_string(),
103            api_key: api_key.to_string(),
104            base_url: base_url.map(|s| s.to_string()),
105        });
106    }
107
108    /// List all credentials (with API keys masked)
109    pub fn list(&self) -> Vec<(String, String, Option<String>)> {
110        self.credentials
111            .iter()
112            .map(|c| (c.provider.clone(), mask_key(&c.api_key), c.base_url.clone()))
113            .collect()
114    }
115
116    /// Get credentials for a provider
117    pub fn get(&self, provider: &str) -> Option<&ProviderCredentials> {
118        self.credentials.iter().find(|c| c.provider == provider)
119    }
120
121    /// Remove credentials for a provider
122    pub fn remove(&mut self, provider: &str) -> bool {
123        let len = self.credentials.len();
124        self.credentials.retain(|c| c.provider != provider);
125        self.credentials.len() < len
126    }
127
128    /// Clear all credentials
129    pub fn reset(&mut self) {
130        self.credentials.clear();
131    }
132}
133
/// Mask an API key for display, showing only its first 4 and last 4 characters.
///
/// Keys of 8 characters or fewer are replaced entirely with `*` (one per
/// character), because showing 4 + 4 characters would reveal the whole key.
///
/// Lengths are counted in `char`s rather than bytes. As a result, keys that
/// contain non-ASCII characters are masked instead of panicking on a slice
/// that splits a UTF-8 code point.
fn mask_key(key: &str) -> String {
    const VISIBLE: usize = 4;

    let chars: Vec<char> = key.chars().collect();
    if chars.len() <= 2 * VISIBLE {
        return "*".repeat(chars.len());
    }
    let head: String = chars[..VISIBLE].iter().collect();
    let tail: String = chars[chars.len() - VISIBLE..].iter().collect();
    format!("{}...{}", head, tail)
}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a store that holds a single `openai` entry with the given key.
    fn store_with_openai(key: &str) -> AuthStore {
        let mut store = AuthStore::default();
        store.add("openai", key, None);
        store
    }

    #[test]
    fn test_mask_key() {
        // Keys of 8 characters or fewer are hidden completely.
        for (input, expected) in [("short", "*****"), ("12345678", "********")] {
            assert_eq!(mask_key(input), expected);
        }
        // Longer keys keep a 4-character prefix and suffix around an ellipsis.
        assert_eq!(mask_key("sk-1234567890abcdef"), "sk-1...cdef");
    }

    #[test]
    fn test_auth_store_add_get() {
        let store = store_with_openai("sk-test123");
        let creds = store
            .get("openai")
            .expect("openai credentials should be present after add");
        assert_eq!(creds.api_key, "sk-test123");
    }

    #[test]
    fn test_auth_store_remove() {
        let mut store = store_with_openai("sk-test123");
        let removed = store.remove("openai");
        assert!(removed);
        assert!(store.get("openai").is_none());
    }

    #[test]
    fn test_auth_store_list_masked() {
        let listing = store_with_openai("sk-1234567890abcdef").list();
        assert_eq!(listing.len(), 1);
        let (provider, masked, _) = &listing[0];
        assert_eq!(provider.as_str(), "openai");
        assert_eq!(masked.as_str(), "sk-1...cdef"); // the raw key must not appear
    }
}