Skip to main content

agent_tools_interface/core/
keyring.rs

1use aes_gcm::{
2    aead::{Aead, KeyInit, OsRng},
3    AeadCore, Aes256Gcm, Nonce,
4};
5use std::collections::HashMap;
6use std::path::Path;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::security::{memory, sealed_file};
11
12#[derive(Error, Debug)]
13pub enum KeyringError {
14    #[error("Failed to read session key: {0}")]
15    SealedFile(#[from] sealed_file::SealedFileError),
16    #[error("Failed to read keyring file: {0}")]
17    Io(#[from] std::io::Error),
18    #[error("Decryption failed — wrong key or corrupted keyring")]
19    DecryptionFailed,
20    #[error("Invalid keyring format — expected JSON object")]
21    InvalidFormat(#[from] serde_json::Error),
22    #[error("Keyring file too small (need at least nonce + tag)")]
23    TooSmall,
24}
25
/// Holds decrypted API keys in memory.
///
/// Every key value is zeroized when the keyring is dropped. When the keyring was
/// decrypted from an encrypted keyring file, the decrypted JSON buffer is additionally
/// mlock'd (best-effort) and passed to `memory::madvise_dontdump`. The parsed `String`
/// values in `keys` are separate heap allocations and are NOT themselves mlock'd.
pub struct Keyring {
    /// Key name → secret value.
    keys: HashMap<String, String>,
    /// The decrypted JSON plaintext that `keys` was parsed from. Retained so the
    /// locked region can be zeroized and munlock'd on drop. Empty for keyrings that
    /// were not decrypted from an encrypted file (credentials file, env, empty).
    _raw_json: Vec<u8>,
    /// Whether keys were loaded from a sealed source (one-shot key).
    /// When true, credential files should be wiped after each use.
    pub ephemeral: bool,
}
35
36impl Keyring {
37    /// Load the keyring: read sealed key file, decrypt keyring.enc, mlock memory.
38    ///
39    /// The session key file is deleted immediately after reading.
40    pub fn load(keyring_path: &Path) -> Result<Self, KeyringError> {
41        // Read and delete the session key
42        let session_key = sealed_file::read_and_delete_key()?;
43
44        // Read encrypted keyring
45        let encrypted = std::fs::read(keyring_path)?;
46
47        // Decrypt
48        let decrypted = decrypt_keyring(&session_key, &encrypted)?;
49
50        // mlock the decrypted bytes (best-effort)
51        if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
52            eprintln!("Warning: {warning}");
53        }
54        let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
55
56        // Parse JSON into HashMap
57        let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
58
59        Ok(Keyring {
60            keys,
61            _raw_json: decrypted,
62            ephemeral: true,
63        })
64    }
65
66    /// Load from an already-known session key (for testing or orchestrator use).
67    pub fn load_with_key(
68        keyring_path: &Path,
69        session_key: &[u8; 32],
70    ) -> Result<Self, KeyringError> {
71        let encrypted = std::fs::read(keyring_path)?;
72        let decrypted = decrypt_keyring(session_key, &encrypted)?;
73
74        if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
75            eprintln!("Warning: {warning}");
76        }
77        let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
78
79        let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
80
81        Ok(Keyring {
82            keys,
83            _raw_json: decrypted,
84            ephemeral: true,
85        })
86    }
87
88    /// Get a key by name (e.g. "parallel_api_key").
89    pub fn get(&self, key_name: &str) -> Option<&str> {
90        self.keys.get(key_name).map(|s| s.as_str())
91    }
92
93    /// Check if the keyring contains a specific key.
94    pub fn contains(&self, key_name: &str) -> bool {
95        self.keys.contains_key(key_name)
96    }
97
98    /// List all key names (not values).
99    pub fn key_names(&self) -> Vec<&str> {
100        self.keys.keys().map(|s| s.as_str()).collect()
101    }
102
103    /// Load from a plaintext credentials file (JSON object: {"key_name": "value", ...}).
104    ///
105    /// Used in local mode where `~/.ati/credentials` stores keys as plaintext JSON
106    /// with 0600 permissions (same approach as AWS CLI, gh, Docker, Stripe).
107    pub fn load_credentials(path: &Path) -> Result<Self, KeyringError> {
108        let data = std::fs::read(path)?;
109        let keys: HashMap<String, String> = serde_json::from_slice(&data)?;
110        Ok(Keyring {
111            keys,
112            _raw_json: Vec::new(),
113            ephemeral: false,
114        })
115    }
116
117    /// Load keyring.enc using a persistent key stored alongside the ATI directory.
118    ///
119    /// Looks for `<ati_dir>/.keyring-key` (base64-encoded 32-byte key).
120    /// Unlike the sealed key in `/run/ati/.key`, this key is NOT deleted after reading —
121    /// it's for proxy servers with persistent storage.
122    pub fn load_local(keyring_path: &Path, ati_dir: &Path) -> Result<Self, KeyringError> {
123        let persistent_key_path = ati_dir.join(".keyring-key");
124
125        let contents =
126            std::fs::read_to_string(&persistent_key_path).map_err(|e| KeyringError::Io(e))?;
127
128        let decoded =
129            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, contents.trim())
130                .map_err(|_| KeyringError::DecryptionFailed)?;
131
132        if decoded.len() != 32 {
133            return Err(KeyringError::DecryptionFailed);
134        }
135
136        let mut key = [0u8; 32];
137        key.copy_from_slice(&decoded);
138
139        let mut kr = Self::load_with_key(keyring_path, &key)?;
140        kr.ephemeral = false;
141        Ok(kr)
142    }
143
144    /// Create a keyring from environment variables with `ATI_KEY_` prefix.
145    ///
146    /// Scans all env vars matching `ATI_KEY_*`, strips the prefix, lowercases the name.
147    /// Example: `ATI_KEY_FINNHUB_API_KEY=abc123` → key name `finnhub_api_key`.
148    pub fn from_env() -> Self {
149        let mut keys = HashMap::new();
150        for (name, value) in std::env::vars() {
151            if let Some(key_name) = name.strip_prefix("ATI_KEY_") {
152                if !value.is_empty() {
153                    keys.insert(key_name.to_lowercase(), value);
154                }
155            }
156        }
157        Keyring {
158            keys,
159            _raw_json: Vec::new(),
160            ephemeral: false,
161        }
162    }
163
164    /// Create an empty keyring (for tools with auth_type = none).
165    pub fn empty() -> Self {
166        Keyring {
167            keys: HashMap::new(),
168            _raw_json: Vec::new(),
169            ephemeral: false,
170        }
171    }
172
173    /// Merge another keyring's keys into this one (other's keys take precedence).
174    pub fn merge(&mut self, other: &Keyring) {
175        for (k, v) in &other.keys {
176            self.keys.insert(k.clone(), v.clone());
177        }
178    }
179
180    /// Number of keys in the keyring.
181    pub fn len(&self) -> usize {
182        self.keys.len()
183    }
184
185    /// Whether the keyring has no keys.
186    pub fn is_empty(&self) -> bool {
187        self.keys.is_empty()
188    }
189}
190
191impl Drop for Keyring {
192    fn drop(&mut self) {
193        // Zeroize all key values
194        for value in self.keys.values_mut() {
195            value.zeroize();
196        }
197        // Save ptr/len before zeroizing — Vec::zeroize() sets len to 0,
198        // which would cause the is_empty() check to skip munlock.
199        let ptr = self._raw_json.as_ptr();
200        let len = self._raw_json.len();
201        // Zeroize raw JSON bytes
202        self._raw_json.zeroize();
203        // Unlock memory (using saved len, not post-zeroize len)
204        if len > 0 {
205            memory::munlock(ptr, len);
206        }
207    }
208}
209
210// --- Encryption / Decryption ---
211
/// Length in bytes of the AES-256-GCM nonce that prefixes every keyring blob (96 bits).
const NONCE_SIZE: usize = 96 / 8;
214
215/// Decrypt a keyring blob. Format: [12-byte nonce][ciphertext+tag]
216fn decrypt_keyring(session_key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>, KeyringError> {
217    if encrypted.len() < NONCE_SIZE + 16 {
218        // Minimum: nonce (12) + GCM tag (16)
219        return Err(KeyringError::TooSmall);
220    }
221
222    let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
223    let nonce = Nonce::from_slice(nonce_bytes);
224
225    let cipher =
226        Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
227
228    cipher
229        .decrypt(nonce, ciphertext)
230        .map_err(|_| KeyringError::DecryptionFailed)
231}
232
233/// Encrypt a keyring (for keygen tooling / orchestrator).
234/// Returns the encrypted blob: [12-byte nonce][ciphertext+tag]
235pub fn encrypt_keyring(session_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, KeyringError> {
236    let cipher =
237        Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
238
239    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
240
241    let ciphertext = cipher
242        .encrypt(&nonce, plaintext)
243        .map_err(|_| KeyringError::DecryptionFailed)?;
244
245    let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
246    result.extend_from_slice(&nonce);
247    result.extend_from_slice(&ciphertext);
248    Ok(result)
249}
250
251/// Generate a random 256-bit session key.
252pub fn generate_session_key() -> [u8; 32] {
253    let mut key = [0u8; 32];
254    use rand::RngCore;
255    OsRng.fill_bytes(&mut key);
256    key
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let session_key = generate_session_key();
        let plaintext = br#"{"parallel_api_key":"test123","epo_api_key":"test456"}"#;

        let encrypted = encrypt_keyring(&session_key, plaintext).unwrap();
        let decrypted = decrypt_keyring(&session_key, &encrypted).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        // The smallest valid blob is exactly nonce + tag.
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, b"").unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + 16);
        assert_eq!(decrypt_keyring(&key, &encrypted).unwrap(), b"");
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_session_key();
        let key2 = generate_session_key();
        let plaintext = br#"{"key":"value"}"#;

        let encrypted = encrypt_keyring(&key1, plaintext).unwrap();
        let result = decrypt_keyring(&key2, &encrypted);

        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = generate_session_key();
        let mut encrypted = encrypt_keyring(&key, br#"{"k":"v"}"#).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(matches!(
            decrypt_keyring(&key, &encrypted),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_too_small_fails() {
        let key = generate_session_key();
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; 10]),
            Err(KeyringError::TooSmall)
        ));
        // One byte short of nonce + tag is still rejected as too small.
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; NONCE_SIZE + 15]),
            Err(KeyringError::TooSmall)
        ));
    }

    #[test]
    fn test_load_with_key_roundtrip() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, br#"{"svc_api_key":"s3cret"}"#).unwrap();
        std::fs::write(&path, &encrypted).unwrap();

        let kr = Keyring::load_with_key(&path, &key).unwrap();
        assert_eq!(kr.get("svc_api_key"), Some("s3cret"));
        assert!(kr.contains("svc_api_key"));
        assert_eq!(kr.key_names(), vec!["svc_api_key"]);
        assert!(kr.ephemeral);
    }

    #[test]
    fn test_load_with_key_rejects_non_object_json() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, br#"["not","an","object"]"#).unwrap();
        std::fs::write(&path, &encrypted).unwrap();

        assert!(matches!(
            Keyring::load_with_key(&path, &key),
            Err(KeyringError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_load_credentials() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"my_api_key":"secret123","other":"val"}"#).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.get("other"), Some("val"));
        assert_eq!(kr.len(), 2);
        assert!(!kr.is_empty());
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_credentials_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "{}").unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.len(), 0);
        assert!(kr.is_empty());
    }

    #[test]
    fn test_load_credentials_errors() {
        let dir = tempfile::TempDir::new().unwrap();

        let missing = dir.path().join("does_not_exist");
        assert!(matches!(
            Keyring::load_credentials(&missing),
            Err(KeyringError::Io(_))
        ));

        let bad = dir.path().join("bad");
        std::fs::write(&bad, "not json").unwrap();
        assert!(matches!(
            Keyring::load_credentials(&bad),
            Err(KeyringError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_from_env_ati_key_prefix() {
        // Set some ATI_KEY_ env vars for the test
        std::env::set_var("ATI_KEY_TEST_API_KEY", "test_value_123");
        std::env::set_var("ATI_KEY_ANOTHER_KEY", "another_val");
        std::env::set_var("ATI_KEY_EMPTY_TEST_KEY", "");

        let kr = Keyring::from_env();

        // Clean up before asserting so a failing assertion doesn't leak vars.
        std::env::remove_var("ATI_KEY_TEST_API_KEY");
        std::env::remove_var("ATI_KEY_ANOTHER_KEY");
        std::env::remove_var("ATI_KEY_EMPTY_TEST_KEY");

        assert_eq!(kr.get("test_api_key"), Some("test_value_123"));
        assert_eq!(kr.get("another_key"), Some("another_val"));
        // Empty values are skipped entirely.
        assert!(!kr.contains("empty_test_key"));
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_empty_keyring() {
        let kr = Keyring::empty();
        assert!(kr.is_empty());
        assert_eq!(kr.len(), 0);
        assert_eq!(kr.get("anything"), None);
    }

    #[test]
    fn test_merge() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds1 = dir.path().join("c1");
        let creds2 = dir.path().join("c2");
        std::fs::write(&creds1, r#"{"a":"1","b":"2"}"#).unwrap();
        std::fs::write(&creds2, r#"{"b":"overridden","c":"3"}"#).unwrap();

        let mut kr1 = Keyring::load_credentials(&creds1).unwrap();
        let kr2 = Keyring::load_credentials(&creds2).unwrap();
        kr1.merge(&kr2);

        assert_eq!(kr1.get("a"), Some("1"));
        assert_eq!(kr1.get("b"), Some("overridden"));
        assert_eq!(kr1.get("c"), Some("3"));
        assert_eq!(kr1.len(), 3);
        // The source keyring is left untouched.
        assert_eq!(kr2.get("b"), Some("overridden"));
        assert_eq!(kr2.len(), 2);
    }
}