1use aes_gcm::{
2 aead::{Aead, KeyInit, OsRng},
3 AeadCore, Aes256Gcm, Nonce,
4};
5use std::collections::HashMap;
6use std::path::Path;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::security::{memory, sealed_file};
11
12#[derive(Error, Debug)]
13pub enum KeyringError {
14 #[error("Failed to read session key: {0}")]
15 SealedFile(#[from] sealed_file::SealedFileError),
16 #[error("Failed to read keyring file: {0}")]
17 Io(#[from] std::io::Error),
18 #[error("Decryption failed — wrong key or corrupted keyring")]
19 DecryptionFailed,
20 #[error("Invalid keyring format — expected JSON object")]
21 InvalidFormat(#[from] serde_json::Error),
22 #[error("Keyring file too small (need at least nonce + tag)")]
23 TooSmall,
24}
25
/// An in-memory set of named secrets.
///
/// Deliberately has no `Debug`/`Clone` derive in this definition, so secret
/// values are not printed or copied implicitly.
pub struct Keyring {
    /// Secret values indexed by key name.
    keys: HashMap<String, String>,
    /// The decrypted JSON buffer the keys were parsed from; empty when the
    /// keyring did not come from an encrypted file.
    _raw_json: Vec<u8>,
    /// Set to `true` by the session-key loaders and `false` by every other
    /// constructor in this module.
    // NOTE(review): presumably marks keyrings whose key is single-use — confirm with callers.
    pub ephemeral: bool,
}
35
36impl Keyring {
37 pub fn load(keyring_path: &Path) -> Result<Self, KeyringError> {
41 let session_key = sealed_file::read_and_delete_key()?;
43
44 let encrypted = std::fs::read(keyring_path)?;
46
47 let decrypted = decrypt_keyring(&session_key, &encrypted)?;
49
50 if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
52 tracing::warn!("{warning}");
53 }
54 let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
55
56 let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
58
59 Ok(Keyring {
60 keys,
61 _raw_json: decrypted,
62 ephemeral: true,
63 })
64 }
65
66 pub fn load_with_key(
68 keyring_path: &Path,
69 session_key: &[u8; 32],
70 ) -> Result<Self, KeyringError> {
71 let encrypted = std::fs::read(keyring_path)?;
72 let decrypted = decrypt_keyring(session_key, &encrypted)?;
73
74 if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
75 tracing::warn!("{warning}");
76 }
77 let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
78
79 let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
80
81 Ok(Keyring {
82 keys,
83 _raw_json: decrypted,
84 ephemeral: true,
85 })
86 }
87
88 pub fn get(&self, key_name: &str) -> Option<&str> {
90 self.keys.get(key_name).map(|s| s.as_str())
91 }
92
93 pub fn contains(&self, key_name: &str) -> bool {
95 self.keys.contains_key(key_name)
96 }
97
98 pub fn key_names(&self) -> Vec<&str> {
100 self.keys.keys().map(|s| s.as_str()).collect()
101 }
102
103 pub fn load_credentials(path: &Path) -> Result<Self, KeyringError> {
108 let data = std::fs::read(path)?;
109 let keys: HashMap<String, String> = serde_json::from_slice(&data)?;
110 Ok(Keyring {
111 keys,
112 _raw_json: Vec::new(),
113 ephemeral: false,
114 })
115 }
116
117 pub fn load_local(keyring_path: &Path, ati_dir: &Path) -> Result<Self, KeyringError> {
123 let persistent_key_path = ati_dir.join(".keyring-key");
124
125 let contents = std::fs::read_to_string(&persistent_key_path).map_err(KeyringError::Io)?;
126
127 let decoded =
128 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, contents.trim())
129 .map_err(|_| KeyringError::DecryptionFailed)?;
130
131 if decoded.len() != 32 {
132 return Err(KeyringError::DecryptionFailed);
133 }
134
135 let mut key = [0u8; 32];
136 key.copy_from_slice(&decoded);
137
138 let mut kr = Self::load_with_key(keyring_path, &key)?;
139 kr.ephemeral = false;
140 Ok(kr)
141 }
142
143 pub fn from_env() -> Self {
148 let mut keys = HashMap::new();
149 for (name, value) in std::env::vars() {
150 if let Some(key_name) = name.strip_prefix("ATI_KEY_") {
151 if !value.is_empty() {
152 keys.insert(key_name.to_lowercase(), value);
153 }
154 }
155 }
156 Keyring {
157 keys,
158 _raw_json: Vec::new(),
159 ephemeral: false,
160 }
161 }
162
163 pub fn empty() -> Self {
165 Keyring {
166 keys: HashMap::new(),
167 _raw_json: Vec::new(),
168 ephemeral: false,
169 }
170 }
171
172 pub fn merge(&mut self, other: &Keyring) {
174 for (k, v) in &other.keys {
175 self.keys.insert(k.clone(), v.clone());
176 }
177 }
178
179 pub fn len(&self) -> usize {
181 self.keys.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.keys.is_empty()
187 }
188}
189
190impl Drop for Keyring {
191 fn drop(&mut self) {
192 for value in self.keys.values_mut() {
194 value.zeroize();
195 }
196 let ptr = self._raw_json.as_ptr();
199 let len = self._raw_json.len();
200 self._raw_json.zeroize();
202 if len > 0 {
204 memory::munlock(ptr, len);
205 }
206 }
207}
208
/// Length in bytes of the nonce stored at the start of an encrypted keyring
/// blob (a 96-bit nonce).
const NONCE_SIZE: usize = 96 / 8;
213
214fn decrypt_keyring(session_key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>, KeyringError> {
216 if encrypted.len() < NONCE_SIZE + 16 {
217 return Err(KeyringError::TooSmall);
219 }
220
221 let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
222 let nonce = Nonce::from_slice(nonce_bytes);
223
224 let cipher =
225 Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
226
227 cipher
228 .decrypt(nonce, ciphertext)
229 .map_err(|_| KeyringError::DecryptionFailed)
230}
231
232pub fn encrypt_keyring(session_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, KeyringError> {
235 let cipher =
236 Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
237
238 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
239
240 let ciphertext = cipher
241 .encrypt(&nonce, plaintext)
242 .map_err(|_| KeyringError::DecryptionFailed)?;
243
244 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
245 result.extend_from_slice(&nonce);
246 result.extend_from_slice(&ciphertext);
247 Ok(result)
248}
249
250pub fn generate_session_key() -> [u8; 32] {
252 let mut key = [0u8; 32];
253 use rand::RngCore;
254 OsRng.fill_bytes(&mut key);
255 key
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypting then decrypting with the same key returns the plaintext.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let session_key = generate_session_key();
        let plaintext = br#"{"parallel_api_key":"test123","epo_api_key":"test456"}"#;

        let encrypted = encrypt_keyring(&session_key, plaintext).unwrap();
        let decrypted = decrypt_keyring(&session_key, &encrypted).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    /// Decrypting with a different key is rejected as a decryption failure,
    /// not as some other error.
    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_session_key();
        let key2 = generate_session_key();
        let plaintext = br#"{"key":"value"}"#;

        let encrypted = encrypt_keyring(&key1, plaintext).unwrap();
        let result = decrypt_keyring(&key2, &encrypted);

        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    /// Input shorter than nonce + tag is rejected before the cipher runs.
    #[test]
    fn test_too_small_fails() {
        let key = generate_session_key();
        let result = decrypt_keyring(&key, &[0u8; 10]);
        assert!(matches!(result, Err(KeyringError::TooSmall)));
    }

    /// The length check sits exactly at nonce + 16 bytes: one byte less is
    /// `TooSmall`, while the minimum length reaches the cipher.
    #[test]
    fn test_too_small_boundary() {
        let key = generate_session_key();

        let just_short = decrypt_keyring(&key, &[0u8; NONCE_SIZE + 15]);
        assert!(matches!(just_short, Err(KeyringError::TooSmall)));

        let minimal = decrypt_keyring(&key, &[0u8; NONCE_SIZE + 16]);
        assert!(matches!(minimal, Err(KeyringError::DecryptionFailed)));
    }

    /// A keyring written with `encrypt_keyring` loads through `load_with_key`.
    #[test]
    fn test_load_with_key_roundtrip() {
        let dir = tempfile::TempDir::new().unwrap();
        let keyring_path = dir.path().join("keyring.enc");
        let session_key = generate_session_key();

        let encrypted = encrypt_keyring(&session_key, br#"{"my_api_key":"secret123"}"#).unwrap();
        std::fs::write(&keyring_path, encrypted).unwrap();

        let kr = Keyring::load_with_key(&keyring_path, &session_key).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.len(), 1);
        assert!(kr.ephemeral);
    }

    /// Plaintext that decrypts correctly but is not JSON yields `InvalidFormat`.
    #[test]
    fn test_load_with_key_invalid_json() {
        let dir = tempfile::TempDir::new().unwrap();
        let keyring_path = dir.path().join("keyring.enc");
        let session_key = generate_session_key();

        let encrypted = encrypt_keyring(&session_key, b"not json").unwrap();
        std::fs::write(&keyring_path, encrypted).unwrap();

        let result = Keyring::load_with_key(&keyring_path, &session_key);
        assert!(matches!(result, Err(KeyringError::InvalidFormat(_))));
    }

    #[test]
    fn test_load_credentials() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"my_api_key":"secret123","other":"val"}"#).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.get("other"), Some("val"));
        assert_eq!(kr.len(), 2);
        assert!(!kr.is_empty());
    }

    #[test]
    fn test_load_credentials_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "{}").unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.len(), 0);
        assert!(kr.is_empty());
    }

    #[test]
    fn test_from_env_ati_key_prefix() {
        std::env::set_var("ATI_KEY_TEST_API_KEY", "test_value_123");
        std::env::set_var("ATI_KEY_ANOTHER_KEY", "another_val");

        let kr = Keyring::from_env();

        // Clean up before asserting so a failed assertion cannot leak the
        // variables into other tests running in this process.
        std::env::remove_var("ATI_KEY_TEST_API_KEY");
        std::env::remove_var("ATI_KEY_ANOTHER_KEY");

        assert_eq!(kr.get("test_api_key"), Some("test_value_123"));
        assert_eq!(kr.get("another_key"), Some("another_val"));
    }

    #[test]
    fn test_merge() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds1 = dir.path().join("c1");
        let creds2 = dir.path().join("c2");
        std::fs::write(&creds1, r#"{"a":"1","b":"2"}"#).unwrap();
        std::fs::write(&creds2, r#"{"b":"overridden","c":"3"}"#).unwrap();

        let mut kr1 = Keyring::load_credentials(&creds1).unwrap();
        let kr2 = Keyring::load_credentials(&creds2).unwrap();
        kr1.merge(&kr2);

        assert_eq!(kr1.get("a"), Some("1"));
        assert_eq!(kr1.get("b"), Some("overridden"));
        assert_eq!(kr1.get("c"), Some("3"));
        assert_eq!(kr1.len(), 3);
    }
}