Skip to main content

ati/core/
keyring.rs

1use aes_gcm::{
2    aead::{Aead, KeyInit, OsRng},
3    AeadCore, Aes256Gcm, Nonce,
4};
5use std::collections::HashMap;
6use std::path::Path;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::security::{memory, sealed_file};
11
/// Errors produced while loading, decrypting, or encrypting a keyring.
#[derive(Error, Debug)]
pub enum KeyringError {
    /// The sealed session key could not be obtained; wraps the error returned
    /// by `sealed_file::read_and_delete_key`.
    #[error("Failed to read session key: {0}")]
    SealedFile(#[from] sealed_file::SealedFileError),
    /// An I/O error while reading a file from disk. Despite the message, this
    /// variant is also used for the persistent key file and for files named by
    /// `@file:` credential references.
    #[error("Failed to read keyring file: {0}")]
    Io(#[from] std::io::Error),
    /// AES-GCM rejected the input. Also returned when the persistent key is
    /// not valid base64 or is not exactly 32 bytes, and (by `encrypt_keyring`)
    /// when encryption itself fails.
    #[error("Decryption failed — wrong key or corrupted keyring")]
    DecryptionFailed,
    /// The plaintext could not be deserialized into a string-to-string JSON object.
    #[error("Invalid keyring format — expected JSON object")]
    InvalidFormat(#[from] serde_json::Error),
    /// The encrypted blob is shorter than a 12-byte nonce plus a 16-byte GCM tag.
    #[error("Keyring file too small (need at least nonce + tag)")]
    TooSmall,
}
25
/// Holds decrypted API keys in memory.
///
/// Every key value is zeroized on drop. When the keyring was decrypted from an
/// encrypted file, the decrypted JSON buffer is additionally mlock'd
/// (best-effort) at load time and zeroized + munlock'd on drop.
pub struct Keyring {
    /// Key name → secret value.
    keys: HashMap<String, String>,
    /// The decrypted keyring JSON bytes. Kept alive so the buffer that was
    /// mlock'd at load time can be zeroized and munlock'd on drop. Empty for
    /// keyrings that were not produced by decrypting a keyring file.
    _raw_json: Vec<u8>,
    /// Whether keys were loaded from a sealed source (one-shot key).
    /// When true, credential files should be wiped after each use.
    pub ephemeral: bool,
}
35
36impl Keyring {
37    /// Load the keyring: read sealed key file, decrypt keyring.enc, mlock memory.
38    ///
39    /// The session key file is deleted immediately after reading.
40    pub fn load(keyring_path: &Path) -> Result<Self, KeyringError> {
41        // Read and delete the session key
42        let session_key = sealed_file::read_and_delete_key()?;
43
44        // Read encrypted keyring
45        let encrypted = std::fs::read(keyring_path)?;
46
47        // Decrypt
48        let decrypted = decrypt_keyring(&session_key, &encrypted)?;
49
50        // mlock the decrypted bytes (best-effort)
51        if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
52            tracing::warn!("{warning}");
53        }
54        let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
55
56        // Parse JSON into HashMap
57        let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
58
59        Ok(Keyring {
60            keys,
61            _raw_json: decrypted,
62            ephemeral: true,
63        })
64    }
65
66    /// Load from an already-known session key (for testing or orchestrator use).
67    pub fn load_with_key(
68        keyring_path: &Path,
69        session_key: &[u8; 32],
70    ) -> Result<Self, KeyringError> {
71        let encrypted = std::fs::read(keyring_path)?;
72        let decrypted = decrypt_keyring(session_key, &encrypted)?;
73
74        if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
75            tracing::warn!("{warning}");
76        }
77        let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
78
79        let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
80
81        Ok(Keyring {
82            keys,
83            _raw_json: decrypted,
84            ephemeral: true,
85        })
86    }
87
88    /// Get a key by name (e.g. "parallel_api_key").
89    pub fn get(&self, key_name: &str) -> Option<&str> {
90        self.keys.get(key_name).map(|s| s.as_str())
91    }
92
93    /// Check if the keyring contains a specific key.
94    pub fn contains(&self, key_name: &str) -> bool {
95        self.keys.contains_key(key_name)
96    }
97
98    /// List all key names (not values).
99    pub fn key_names(&self) -> Vec<&str> {
100        self.keys.keys().map(|s| s.as_str()).collect()
101    }
102
103    /// Load from a plaintext credentials file (JSON object: {"key_name": "value", ...}).
104    ///
105    /// Used in local mode where `~/.ati/credentials` stores keys as plaintext JSON
106    /// with 0600 permissions (same approach as AWS CLI, gh, Docker, Stripe).
107    ///
108    /// Supports `@file:/path/to/secret` values — the file contents are read and
109    /// used as the credential value. Useful for container platforms that mount
110    /// secrets as individual files (Northflank, Kubernetes, Docker).
111    pub fn load_credentials(path: &Path) -> Result<Self, KeyringError> {
112        let data = std::fs::read(path)?;
113        let mut keys: HashMap<String, String> = serde_json::from_slice(&data)?;
114
115        // Resolve @file: references — read credential value from external file
116        for value in keys.values_mut() {
117            if let Some(file_path) = value.strip_prefix("@file:") {
118                *value = std::fs::read_to_string(file_path.trim())
119                    .map_err(KeyringError::Io)?
120                    .trim()
121                    .to_string();
122            }
123        }
124
125        Ok(Keyring {
126            keys,
127            _raw_json: Vec::new(),
128            ephemeral: false,
129        })
130    }
131
132    /// Load keyring.enc using a persistent key stored alongside the ATI directory.
133    ///
134    /// Looks for `<ati_dir>/.keyring-key` (base64-encoded 32-byte key).
135    /// Unlike the sealed key in `/run/ati/.key`, this key is NOT deleted after reading —
136    /// it's for proxy servers with persistent storage.
137    pub fn load_local(keyring_path: &Path, ati_dir: &Path) -> Result<Self, KeyringError> {
138        let persistent_key_path = ati_dir.join(".keyring-key");
139
140        let contents = std::fs::read_to_string(&persistent_key_path).map_err(KeyringError::Io)?;
141
142        let decoded =
143            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, contents.trim())
144                .map_err(|_| KeyringError::DecryptionFailed)?;
145
146        if decoded.len() != 32 {
147            return Err(KeyringError::DecryptionFailed);
148        }
149
150        let mut key = [0u8; 32];
151        key.copy_from_slice(&decoded);
152
153        let mut kr = Self::load_with_key(keyring_path, &key)?;
154        kr.ephemeral = false;
155        Ok(kr)
156    }
157
158    /// Create a keyring from environment variables with `ATI_KEY_` prefix.
159    ///
160    /// Scans all env vars matching `ATI_KEY_*`, strips the prefix, lowercases the name.
161    /// Example: `ATI_KEY_FINNHUB_API_KEY=abc123` → key name `finnhub_api_key`.
162    pub fn from_env() -> Self {
163        let mut keys = HashMap::new();
164        for (name, value) in std::env::vars() {
165            if let Some(key_name) = name.strip_prefix("ATI_KEY_") {
166                if !value.is_empty() {
167                    keys.insert(key_name.to_lowercase(), value);
168                }
169            }
170        }
171        Keyring {
172            keys,
173            _raw_json: Vec::new(),
174            ephemeral: false,
175        }
176    }
177
178    /// Create an empty keyring (for tools with auth_type = none).
179    pub fn empty() -> Self {
180        Keyring {
181            keys: HashMap::new(),
182            _raw_json: Vec::new(),
183            ephemeral: false,
184        }
185    }
186
187    /// Merge another keyring's keys into this one (other's keys take precedence).
188    pub fn merge(&mut self, other: &Keyring) {
189        for (k, v) in &other.keys {
190            self.keys.insert(k.clone(), v.clone());
191        }
192    }
193
194    /// Number of keys in the keyring.
195    pub fn len(&self) -> usize {
196        self.keys.len()
197    }
198
199    /// Whether the keyring has no keys.
200    pub fn is_empty(&self) -> bool {
201        self.keys.is_empty()
202    }
203}
204
205impl Drop for Keyring {
206    fn drop(&mut self) {
207        // Zeroize all key values
208        for value in self.keys.values_mut() {
209            value.zeroize();
210        }
211        // Save ptr/len before zeroizing — Vec::zeroize() sets len to 0,
212        // which would cause the is_empty() check to skip munlock.
213        let ptr = self._raw_json.as_ptr();
214        let len = self._raw_json.len();
215        // Zeroize raw JSON bytes
216        self._raw_json.zeroize();
217        // Unlock memory (using saved len, not post-zeroize len)
218        if len > 0 {
219            memory::munlock(ptr, len);
220        }
221    }
222}
223
224// --- Encryption / Decryption ---
225
/// AES-256-GCM nonce size (96 bits = 12 bytes).
///
/// An encrypted keyring blob begins with this many nonce bytes, followed by
/// the ciphertext and authentication tag.
const NONCE_SIZE: usize = 12;
228
229/// Decrypt a keyring blob. Format: [12-byte nonce][ciphertext+tag]
230fn decrypt_keyring(session_key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>, KeyringError> {
231    if encrypted.len() < NONCE_SIZE + 16 {
232        // Minimum: nonce (12) + GCM tag (16)
233        return Err(KeyringError::TooSmall);
234    }
235
236    let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
237    let nonce = Nonce::from_slice(nonce_bytes);
238
239    let cipher =
240        Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
241
242    cipher
243        .decrypt(nonce, ciphertext)
244        .map_err(|_| KeyringError::DecryptionFailed)
245}
246
247/// Encrypt a keyring (for keygen tooling / orchestrator).
248/// Returns the encrypted blob: [12-byte nonce][ciphertext+tag]
249pub fn encrypt_keyring(session_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, KeyringError> {
250    let cipher =
251        Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
252
253    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
254
255    let ciphertext = cipher
256        .encrypt(&nonce, plaintext)
257        .map_err(|_| KeyringError::DecryptionFailed)?;
258
259    let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
260    result.extend_from_slice(&nonce);
261    result.extend_from_slice(&ciphertext);
262    Ok(result)
263}
264
265/// Generate a random 256-bit session key.
266pub fn generate_session_key() -> [u8; 32] {
267    let mut key = [0u8; 32];
268    use rand::RngCore;
269    OsRng.fill_bytes(&mut key);
270    key
271}
272
#[cfg(test)]
mod tests {
    //! Unit tests for keyring encryption, loading and merging.
    //!
    //! Error-path tests assert the specific `KeyringError` variant rather than
    //! just `is_err()`, so a failure for the wrong reason is caught.

    use super::*;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let session_key = generate_session_key();
        let plaintext = br#"{"parallel_api_key":"test123","epo_api_key":"test456"}"#;

        let encrypted = encrypt_keyring(&session_key, plaintext).unwrap();
        let decrypted = decrypt_keyring(&session_key, &encrypted).unwrap();

        assert_eq!(decrypted, plaintext);
        // Blob layout: nonce + ciphertext (same length as plaintext) + 16-byte tag.
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + 16);
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_session_key();
        let key2 = generate_session_key();
        let plaintext = br#"{"key":"value"}"#;

        let encrypted = encrypt_keyring(&key1, plaintext).unwrap();
        let result = decrypt_keyring(&key2, &encrypted);

        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = generate_session_key();
        let mut encrypted = encrypt_keyring(&key, br#"{"key":"value"}"#).unwrap();

        // Flip one bit in the last byte (part of the authentication tag).
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        let result = decrypt_keyring(&key, &encrypted);
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_too_small_fails() {
        let key = generate_session_key();

        let result = decrypt_keyring(&key, &[0u8; 10]);
        assert!(matches!(result, Err(KeyringError::TooSmall)));

        // One byte short of nonce + tag is still too small...
        let result = decrypt_keyring(&key, &[0u8; NONCE_SIZE + 15]);
        assert!(matches!(result, Err(KeyringError::TooSmall)));

        // ...while exactly nonce + tag passes the size check and then fails
        // authentication instead.
        let result = decrypt_keyring(&key, &[0u8; NONCE_SIZE + 16]);
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_load_with_key_roundtrip() {
        let dir = tempfile::TempDir::new().unwrap();
        let keyring_path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, br#"{"svc_key":"s3cret"}"#).unwrap();
        std::fs::write(&keyring_path, encrypted).unwrap();

        let kr = Keyring::load_with_key(&keyring_path, &key).unwrap();
        assert_eq!(kr.get("svc_key"), Some("s3cret"));
        assert_eq!(kr.len(), 1);
        assert!(kr.ephemeral);
    }

    #[test]
    fn test_load_with_key_rejects_non_json_plaintext() {
        let dir = tempfile::TempDir::new().unwrap();
        let keyring_path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, b"definitely not json").unwrap();
        std::fs::write(&keyring_path, encrypted).unwrap();

        let result = Keyring::load_with_key(&keyring_path, &key);
        assert!(matches!(result, Err(KeyringError::InvalidFormat(_))));
    }

    #[test]
    fn test_load_local() {
        let dir = tempfile::TempDir::new().unwrap();
        let keyring_path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, br#"{"svc_key":"s3cret"}"#).unwrap();
        std::fs::write(&keyring_path, encrypted).unwrap();

        // Trailing newline: the key file contents are trimmed before decoding.
        let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, key);
        std::fs::write(dir.path().join(".keyring-key"), format!("{encoded}\n")).unwrap();

        let kr = Keyring::load_local(&keyring_path, dir.path()).unwrap();
        assert_eq!(kr.get("svc_key"), Some("s3cret"));
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_local_rejects_bad_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let keyring_path = dir.path().join("keyring.enc");
        std::fs::write(&keyring_path, [0u8; 64]).unwrap();

        // Valid base64, but only 3 bytes once decoded.
        std::fs::write(dir.path().join(".keyring-key"), "AAAA").unwrap();
        let result = Keyring::load_local(&keyring_path, dir.path());
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));

        // Not base64 at all.
        std::fs::write(dir.path().join(".keyring-key"), "!!! not base64 !!!").unwrap();
        let result = Keyring::load_local(&keyring_path, dir.path());
        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_load_credentials() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"my_api_key":"secret123","other":"val"}"#).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.get("other"), Some("val"));
        assert_eq!(kr.get("missing"), None);
        assert!(kr.contains("other"));
        assert!(!kr.contains("missing"));
        assert_eq!(kr.len(), 2);
        assert!(!kr.is_empty());
        assert!(!kr.ephemeral);

        let mut names = kr.key_names();
        names.sort_unstable();
        assert_eq!(names, vec!["my_api_key", "other"]);
    }

    #[test]
    fn test_load_credentials_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "{}").unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.len(), 0);
        assert!(kr.is_empty());
    }

    #[test]
    fn test_load_credentials_invalid_json() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "not json").unwrap();

        let result = Keyring::load_credentials(&creds_path);
        assert!(matches!(result, Err(KeyringError::InvalidFormat(_))));
    }

    #[test]
    fn test_load_credentials_file_reference() {
        let dir = tempfile::TempDir::new().unwrap();
        let secret_path = dir.path().join("token");
        std::fs::write(&secret_path, "from-file-secret\n").unwrap();

        // Build the JSON with serde_json so the path is escaped correctly on
        // every platform.
        let creds = serde_json::json!({
            "token": format!("@file:{}", secret_path.display()),
            "plain": "inline",
        });
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, creds.to_string()).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        // File contents are trimmed, so the trailing newline is gone.
        assert_eq!(kr.get("token"), Some("from-file-secret"));
        assert_eq!(kr.get("plain"), Some("inline"));
    }

    #[test]
    fn test_load_credentials_missing_file_reference() {
        let dir = tempfile::TempDir::new().unwrap();
        let missing = dir.path().join("does-not-exist");
        let creds = serde_json::json!({ "token": format!("@file:{}", missing.display()) });
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, creds.to_string()).unwrap();

        let result = Keyring::load_credentials(&creds_path);
        assert!(matches!(result, Err(KeyringError::Io(_))));
    }

    #[test]
    fn test_from_env_ati_key_prefix() {
        // Set some ATI_KEY_ env vars for the test
        std::env::set_var("ATI_KEY_TEST_API_KEY", "test_value_123");
        std::env::set_var("ATI_KEY_ANOTHER_KEY", "another_val");

        let kr = Keyring::from_env();
        assert_eq!(kr.get("test_api_key"), Some("test_value_123"));
        assert_eq!(kr.get("another_key"), Some("another_val"));
        assert!(!kr.ephemeral);

        // Clean up
        std::env::remove_var("ATI_KEY_TEST_API_KEY");
        std::env::remove_var("ATI_KEY_ANOTHER_KEY");
    }

    #[test]
    fn test_empty() {
        let kr = Keyring::empty();
        assert!(kr.is_empty());
        assert_eq!(kr.len(), 0);
        assert!(kr.key_names().is_empty());
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_merge() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds1 = dir.path().join("c1");
        let creds2 = dir.path().join("c2");
        std::fs::write(&creds1, r#"{"a":"1","b":"2"}"#).unwrap();
        std::fs::write(&creds2, r#"{"b":"overridden","c":"3"}"#).unwrap();

        let mut kr1 = Keyring::load_credentials(&creds1).unwrap();
        let kr2 = Keyring::load_credentials(&creds2).unwrap();
        kr1.merge(&kr2);

        assert_eq!(kr1.get("a"), Some("1"));
        assert_eq!(kr1.get("b"), Some("overridden"));
        assert_eq!(kr1.get("c"), Some("3"));
        assert_eq!(kr1.len(), 3);

        // The source keyring is left untouched.
        assert_eq!(kr2.get("b"), Some("overridden"));
        assert_eq!(kr2.len(), 2);
    }
}