Skip to main content

ati/core/
keyring.rs

1use aes_gcm::{
2    aead::{Aead, KeyInit, OsRng},
3    AeadCore, Aes256Gcm, Nonce,
4};
5use std::collections::HashMap;
6use std::path::Path;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::security::{memory, sealed_file};
11
12#[derive(Error, Debug)]
13pub enum KeyringError {
14    #[error("Failed to read session key: {0}")]
15    SealedFile(#[from] sealed_file::SealedFileError),
16    #[error("Failed to read keyring file: {0}")]
17    Io(#[from] std::io::Error),
18    #[error("Decryption failed — wrong key or corrupted keyring")]
19    DecryptionFailed,
20    #[error("Invalid keyring format — expected JSON object")]
21    InvalidFormat(#[from] serde_json::Error),
22    #[error("Keyring file too small (need at least nonce + tag)")]
23    TooSmall,
24}
25
26/// Holds decrypted API keys in memory. Keys are mlock'd and zeroized on drop.
27pub struct Keyring {
28    keys: HashMap<String, String>,
29    /// Raw bytes pointer and length for mlock/munlock
30    _raw_json: Vec<u8>,
31    /// Whether keys were loaded from a sealed source (one-shot key).
32    /// When true, credential files should be wiped after each use.
33    pub ephemeral: bool,
34}
35
36impl Keyring {
37    /// Load the keyring: read sealed key file, decrypt keyring.enc, mlock memory.
38    ///
39    /// The session key file is deleted immediately after reading.
40    pub fn load(keyring_path: &Path) -> Result<Self, KeyringError> {
41        // Read and delete the session key
42        let session_key = sealed_file::read_and_delete_key()?;
43
44        // Read encrypted keyring
45        let encrypted = std::fs::read(keyring_path)?;
46
47        // Decrypt
48        let decrypted = decrypt_keyring(&session_key, &encrypted)?;
49
50        // mlock the decrypted bytes (best-effort)
51        if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
52            tracing::warn!("{warning}");
53        }
54        let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
55
56        // Parse JSON into HashMap
57        let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
58
59        Ok(Keyring {
60            keys,
61            _raw_json: decrypted,
62            ephemeral: true,
63        })
64    }
65
66    /// Load from an already-known session key (for testing or orchestrator use).
67    pub fn load_with_key(
68        keyring_path: &Path,
69        session_key: &[u8; 32],
70    ) -> Result<Self, KeyringError> {
71        let encrypted = std::fs::read(keyring_path)?;
72        let decrypted = decrypt_keyring(session_key, &encrypted)?;
73
74        if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
75            tracing::warn!("{warning}");
76        }
77        let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
78
79        let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
80
81        Ok(Keyring {
82            keys,
83            _raw_json: decrypted,
84            ephemeral: true,
85        })
86    }
87
88    /// Get a key by name (e.g. "parallel_api_key").
89    pub fn get(&self, key_name: &str) -> Option<&str> {
90        self.keys.get(key_name).map(|s| s.as_str())
91    }
92
93    /// Check if the keyring contains a specific key.
94    pub fn contains(&self, key_name: &str) -> bool {
95        self.keys.contains_key(key_name)
96    }
97
98    /// List all key names (not values).
99    pub fn key_names(&self) -> Vec<&str> {
100        self.keys.keys().map(|s| s.as_str()).collect()
101    }
102
103    /// Load from a plaintext credentials file (JSON object: {"key_name": "value", ...}).
104    ///
105    /// Used in local mode where `~/.ati/credentials` stores keys as plaintext JSON
106    /// with 0600 permissions (same approach as AWS CLI, gh, Docker, Stripe).
107    pub fn load_credentials(path: &Path) -> Result<Self, KeyringError> {
108        let data = std::fs::read(path)?;
109        let keys: HashMap<String, String> = serde_json::from_slice(&data)?;
110        Ok(Keyring {
111            keys,
112            _raw_json: Vec::new(),
113            ephemeral: false,
114        })
115    }
116
117    /// Load keyring.enc using a persistent key stored alongside the ATI directory.
118    ///
119    /// Looks for `<ati_dir>/.keyring-key` (base64-encoded 32-byte key).
120    /// Unlike the sealed key in `/run/ati/.key`, this key is NOT deleted after reading —
121    /// it's for proxy servers with persistent storage.
122    pub fn load_local(keyring_path: &Path, ati_dir: &Path) -> Result<Self, KeyringError> {
123        let persistent_key_path = ati_dir.join(".keyring-key");
124
125        let contents = std::fs::read_to_string(&persistent_key_path).map_err(KeyringError::Io)?;
126
127        let decoded =
128            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, contents.trim())
129                .map_err(|_| KeyringError::DecryptionFailed)?;
130
131        if decoded.len() != 32 {
132            return Err(KeyringError::DecryptionFailed);
133        }
134
135        let mut key = [0u8; 32];
136        key.copy_from_slice(&decoded);
137
138        let mut kr = Self::load_with_key(keyring_path, &key)?;
139        kr.ephemeral = false;
140        Ok(kr)
141    }
142
143    /// Create a keyring from environment variables with `ATI_KEY_` prefix.
144    ///
145    /// Scans all env vars matching `ATI_KEY_*`, strips the prefix, lowercases the name.
146    /// Example: `ATI_KEY_FINNHUB_API_KEY=abc123` → key name `finnhub_api_key`.
147    pub fn from_env() -> Self {
148        let mut keys = HashMap::new();
149        for (name, value) in std::env::vars() {
150            if let Some(key_name) = name.strip_prefix("ATI_KEY_") {
151                if !value.is_empty() {
152                    keys.insert(key_name.to_lowercase(), value);
153                }
154            }
155        }
156        Keyring {
157            keys,
158            _raw_json: Vec::new(),
159            ephemeral: false,
160        }
161    }
162
163    /// Create an empty keyring (for tools with auth_type = none).
164    pub fn empty() -> Self {
165        Keyring {
166            keys: HashMap::new(),
167            _raw_json: Vec::new(),
168            ephemeral: false,
169        }
170    }
171
172    /// Merge another keyring's keys into this one (other's keys take precedence).
173    pub fn merge(&mut self, other: &Keyring) {
174        for (k, v) in &other.keys {
175            self.keys.insert(k.clone(), v.clone());
176        }
177    }
178
179    /// Number of keys in the keyring.
180    pub fn len(&self) -> usize {
181        self.keys.len()
182    }
183
184    /// Whether the keyring has no keys.
185    pub fn is_empty(&self) -> bool {
186        self.keys.is_empty()
187    }
188}
189
190impl Drop for Keyring {
191    fn drop(&mut self) {
192        // Zeroize all key values
193        for value in self.keys.values_mut() {
194            value.zeroize();
195        }
196        // Save ptr/len before zeroizing — Vec::zeroize() sets len to 0,
197        // which would cause the is_empty() check to skip munlock.
198        let ptr = self._raw_json.as_ptr();
199        let len = self._raw_json.len();
200        // Zeroize raw JSON bytes
201        self._raw_json.zeroize();
202        // Unlock memory (using saved len, not post-zeroize len)
203        if len > 0 {
204            memory::munlock(ptr, len);
205        }
206    }
207}
208
209// --- Encryption / Decryption ---
210
211/// AES-256-GCM nonce size (96 bits = 12 bytes)
212const NONCE_SIZE: usize = 12;
213
214/// Decrypt a keyring blob. Format: [12-byte nonce][ciphertext+tag]
215fn decrypt_keyring(session_key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>, KeyringError> {
216    if encrypted.len() < NONCE_SIZE + 16 {
217        // Minimum: nonce (12) + GCM tag (16)
218        return Err(KeyringError::TooSmall);
219    }
220
221    let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
222    let nonce = Nonce::from_slice(nonce_bytes);
223
224    let cipher =
225        Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
226
227    cipher
228        .decrypt(nonce, ciphertext)
229        .map_err(|_| KeyringError::DecryptionFailed)
230}
231
232/// Encrypt a keyring (for keygen tooling / orchestrator).
233/// Returns the encrypted blob: [12-byte nonce][ciphertext+tag]
234pub fn encrypt_keyring(session_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, KeyringError> {
235    let cipher =
236        Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
237
238    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
239
240    let ciphertext = cipher
241        .encrypt(&nonce, plaintext)
242        .map_err(|_| KeyringError::DecryptionFailed)?;
243
244    let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
245    result.extend_from_slice(&nonce);
246    result.extend_from_slice(&ciphertext);
247    Ok(result)
248}
249
250/// Generate a random 256-bit session key.
251pub fn generate_session_key() -> [u8; 32] {
252    let mut key = [0u8; 32];
253    use rand::RngCore;
254    OsRng.fill_bytes(&mut key);
255    key
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypts `plaintext` with `session_key` and writes it to `keyring.enc`
    /// inside `dir`, returning the path of the written file.
    fn write_encrypted_keyring(
        dir: &Path,
        session_key: &[u8; 32],
        plaintext: &[u8],
    ) -> std::path::PathBuf {
        let keyring_path = dir.join("keyring.enc");
        let blob = encrypt_keyring(session_key, plaintext).unwrap();
        std::fs::write(&keyring_path, blob).unwrap();
        keyring_path
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let session_key = generate_session_key();
        let plaintext = br#"{"parallel_api_key":"test123","epo_api_key":"test456"}"#;

        let encrypted = encrypt_keyring(&session_key, plaintext).unwrap();
        // Layout is nonce + ciphertext + 16-byte GCM tag.
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + 16);

        let decrypted = decrypt_keyring(&session_key, &encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_session_key();
        let key2 = generate_session_key();
        let plaintext = br#"{"key":"value"}"#;

        let encrypted = encrypt_keyring(&key1, plaintext).unwrap();
        let result = decrypt_keyring(&key2, &encrypted);

        // Must be an authentication failure, not just "some error".
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_too_small_fails() {
        let key = generate_session_key();

        let result = decrypt_keyring(&key, &[0u8; 10]);
        assert!(matches!(result, Err(KeyringError::TooSmall)));

        // One byte short of nonce + tag is still too small.
        let result = decrypt_keyring(&key, &[0u8; NONCE_SIZE + 16 - 1]);
        assert!(matches!(result, Err(KeyringError::TooSmall)));

        // Exactly nonce + tag passes the size check but fails authentication.
        let result = decrypt_keyring(&key, &[0u8; NONCE_SIZE + 16]);
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_load_with_key_roundtrip() {
        let dir = tempfile::TempDir::new().unwrap();
        let session_key = generate_session_key();
        let keyring_path =
            write_encrypted_keyring(dir.path(), &session_key, br#"{"svc_api_key":"s3cret"}"#);

        let kr = Keyring::load_with_key(&keyring_path, &session_key).unwrap();
        assert_eq!(kr.get("svc_api_key"), Some("s3cret"));
        assert_eq!(kr.len(), 1);
        assert!(kr.ephemeral);
    }

    #[test]
    fn test_load_with_key_rejects_non_json() {
        let dir = tempfile::TempDir::new().unwrap();
        let session_key = generate_session_key();
        let keyring_path = write_encrypted_keyring(dir.path(), &session_key, b"not json");

        let result = Keyring::load_with_key(&keyring_path, &session_key);
        assert!(matches!(result, Err(KeyringError::InvalidFormat(_))));
    }

    #[test]
    fn test_load_local_uses_persistent_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let session_key = generate_session_key();
        let keyring_path =
            write_encrypted_keyring(dir.path(), &session_key, br#"{"svc_api_key":"s3cret"}"#);

        let encoded =
            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, session_key);
        // Trailing newline mimics a key file written by a shell redirect.
        std::fs::write(dir.path().join(".keyring-key"), format!("{encoded}\n")).unwrap();

        let kr = Keyring::load_local(&keyring_path, dir.path()).unwrap();
        assert_eq!(kr.get("svc_api_key"), Some("s3cret"));
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_local_rejects_wrong_key_length() {
        let dir = tempfile::TempDir::new().unwrap();
        let session_key = generate_session_key();
        let keyring_path = write_encrypted_keyring(dir.path(), &session_key, br#"{"k":"v"}"#);

        let encoded =
            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, [0u8; 16]);
        std::fs::write(dir.path().join(".keyring-key"), encoded).unwrap();

        let result = Keyring::load_local(&keyring_path, dir.path());
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_load_credentials() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"my_api_key":"secret123","other":"val"}"#).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.get("other"), Some("val"));
        assert_eq!(kr.get("missing"), None);
        assert_eq!(kr.len(), 2);
        assert!(!kr.is_empty());
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_credentials_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "{}").unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.len(), 0);
        assert!(kr.is_empty());
    }

    #[test]
    fn test_from_env_ati_key_prefix() {
        // Set some ATI_KEY_ env vars for the test
        std::env::set_var("ATI_KEY_TEST_API_KEY", "test_value_123");
        std::env::set_var("ATI_KEY_ANOTHER_KEY", "another_val");
        std::env::set_var("ATI_KEY_EMPTY_VALUE", "");

        let kr = Keyring::from_env();

        // Clean up before asserting so a failure does not leak variables
        // into other tests running in the same process.
        std::env::remove_var("ATI_KEY_TEST_API_KEY");
        std::env::remove_var("ATI_KEY_ANOTHER_KEY");
        std::env::remove_var("ATI_KEY_EMPTY_VALUE");

        assert_eq!(kr.get("test_api_key"), Some("test_value_123"));
        assert_eq!(kr.get("another_key"), Some("another_val"));
        // Empty values are skipped entirely.
        assert!(!kr.contains("empty_value"));
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_merge() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds1 = dir.path().join("c1");
        let creds2 = dir.path().join("c2");
        std::fs::write(&creds1, r#"{"a":"1","b":"2"}"#).unwrap();
        std::fs::write(&creds2, r#"{"b":"overridden","c":"3"}"#).unwrap();

        let mut kr1 = Keyring::load_credentials(&creds1).unwrap();
        let kr2 = Keyring::load_credentials(&creds2).unwrap();
        kr1.merge(&kr2);

        assert_eq!(kr1.get("a"), Some("1"));
        assert_eq!(kr1.get("b"), Some("overridden"));
        assert_eq!(kr1.get("c"), Some("3"));
        assert_eq!(kr1.len(), 3);

        // The source keyring is left untouched by the merge.
        assert_eq!(kr2.get("b"), Some("overridden"));
        assert_eq!(kr2.len(), 2);
    }
}