1use aes_gcm::{
2 aead::{Aead, KeyInit, OsRng},
3 AeadCore, Aes256Gcm, Nonce,
4};
5use std::collections::HashMap;
6use std::path::Path;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::security::{memory, sealed_file};
11
/// Errors produced while loading, decrypting, or encrypting a [`Keyring`].
#[derive(Error, Debug)]
pub enum KeyringError {
    /// The session key could not be obtained from the sealed key file.
    #[error("Failed to read session key: {0}")]
    SealedFile(#[from] sealed_file::SealedFileError),
    /// An I/O error. Despite the Display text, this is returned for every file
    /// this module reads, not only the encrypted keyring. That covers:
    /// - the credentials file,
    /// - `@file:` references inside it,
    /// - the persistent `.keyring-key` file.
    #[error("Failed to read keyring file: {0}")]
    Io(#[from] std::io::Error),
    /// AES-GCM authentication failed (wrong key, truncated or tampered data).
    /// Also used in two other places:
    /// - `Keyring::load_local`, when the persistent key is not valid base64 or
    ///   does not decode to exactly 32 bytes;
    /// - `encrypt_keyring`, on any encryption failure.
    #[error("Decryption failed — wrong key or corrupted keyring")]
    DecryptionFailed,
    /// The (decrypted) content did not parse as a JSON object whose values are
    /// all strings. The underlying `serde_json` error is kept as the source but
    /// is not interpolated into the Display text.
    /// NOTE(review): serde_json messages can echo fragments of the input, so
    /// keep it out of the message if the input may contain secrets.
    #[error("Invalid keyring format — expected JSON object")]
    InvalidFormat(#[from] serde_json::Error),
    /// The encrypted blob is shorter than a nonce plus an authentication tag.
    #[error("Keyring file too small (need at least nonce + tag)")]
    TooSmall,
}
25
/// A set of named secrets (`name -> value`).
///
/// A keyring is loaded from one of three sources:
/// - an encrypted keyring file,
/// - a plaintext credentials file,
/// - `ATI_KEY_*` environment variables.
///
/// Secret values are zeroized when the keyring is dropped. The `Debug`
/// implementation lists key names only and never prints values.
pub struct Keyring {
    /// Secret values keyed by name.
    keys: HashMap<String, String>,
    /// Decrypted JSON plaintext the keys were parsed from. It is retained so it
    /// can be zeroized and unlocked on drop. Empty for keyrings that were not
    /// decrypted from a file.
    _raw_json: Vec<u8>,
    /// Set to `true` by `Keyring::load` / `Keyring::load_with_key` and `false`
    /// by every other constructor.
    /// NOTE(review): presumably marks keyrings sourced from a one-shot session
    /// key — confirm how callers use it.
    pub ephemeral: bool,
}

/// Redacting `Debug`: shows the (sorted) key names and the `ephemeral` flag,
/// never the secret values or the raw plaintext buffer.
impl std::fmt::Debug for Keyring {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        // HashMap order is unspecified; sort so output is deterministic.
        names.sort_unstable();
        f.debug_struct("Keyring")
            .field("keys", &names)
            .field("ephemeral", &self.ephemeral)
            .finish_non_exhaustive()
    }
}
35
36impl Keyring {
37 pub fn load(keyring_path: &Path) -> Result<Self, KeyringError> {
41 let session_key = sealed_file::read_and_delete_key()?;
43
44 let encrypted = std::fs::read(keyring_path)?;
46
47 let decrypted = decrypt_keyring(&session_key, &encrypted)?;
49
50 if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
52 tracing::warn!("{warning}");
53 }
54 let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
55
56 let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
58
59 Ok(Keyring {
60 keys,
61 _raw_json: decrypted,
62 ephemeral: true,
63 })
64 }
65
66 pub fn load_with_key(
68 keyring_path: &Path,
69 session_key: &[u8; 32],
70 ) -> Result<Self, KeyringError> {
71 let encrypted = std::fs::read(keyring_path)?;
72 let decrypted = decrypt_keyring(session_key, &encrypted)?;
73
74 if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
75 tracing::warn!("{warning}");
76 }
77 let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
78
79 let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
80
81 Ok(Keyring {
82 keys,
83 _raw_json: decrypted,
84 ephemeral: true,
85 })
86 }
87
88 pub fn get(&self, key_name: &str) -> Option<&str> {
90 self.keys.get(key_name).map(|s| s.as_str())
91 }
92
93 pub fn contains(&self, key_name: &str) -> bool {
95 self.keys.contains_key(key_name)
96 }
97
98 pub fn key_names(&self) -> Vec<&str> {
100 self.keys.keys().map(|s| s.as_str()).collect()
101 }
102
103 pub fn load_credentials(path: &Path) -> Result<Self, KeyringError> {
112 let data = std::fs::read(path)?;
113 let mut keys: HashMap<String, String> = serde_json::from_slice(&data)?;
114
115 for value in keys.values_mut() {
117 if let Some(file_path) = value.strip_prefix("@file:") {
118 *value = std::fs::read_to_string(file_path.trim())
119 .map_err(KeyringError::Io)?
120 .trim()
121 .to_string();
122 }
123 }
124
125 Ok(Keyring {
126 keys,
127 _raw_json: Vec::new(),
128 ephemeral: false,
129 })
130 }
131
132 pub fn load_local(keyring_path: &Path, ati_dir: &Path) -> Result<Self, KeyringError> {
138 let persistent_key_path = ati_dir.join(".keyring-key");
139
140 let contents = std::fs::read_to_string(&persistent_key_path).map_err(KeyringError::Io)?;
141
142 let decoded =
143 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, contents.trim())
144 .map_err(|_| KeyringError::DecryptionFailed)?;
145
146 if decoded.len() != 32 {
147 return Err(KeyringError::DecryptionFailed);
148 }
149
150 let mut key = [0u8; 32];
151 key.copy_from_slice(&decoded);
152
153 let mut kr = Self::load_with_key(keyring_path, &key)?;
154 kr.ephemeral = false;
155 Ok(kr)
156 }
157
158 pub fn from_env() -> Self {
163 let mut keys = HashMap::new();
164 for (name, value) in std::env::vars() {
165 if let Some(key_name) = name.strip_prefix("ATI_KEY_") {
166 if !value.is_empty() {
167 keys.insert(key_name.to_lowercase(), value);
168 }
169 }
170 }
171 Keyring {
172 keys,
173 _raw_json: Vec::new(),
174 ephemeral: false,
175 }
176 }
177
178 pub fn empty() -> Self {
180 Keyring {
181 keys: HashMap::new(),
182 _raw_json: Vec::new(),
183 ephemeral: false,
184 }
185 }
186
187 pub fn merge(&mut self, other: &Keyring) {
189 for (k, v) in &other.keys {
190 self.keys.insert(k.clone(), v.clone());
191 }
192 }
193
194 pub fn len(&self) -> usize {
196 self.keys.len()
197 }
198
199 pub fn is_empty(&self) -> bool {
201 self.keys.is_empty()
202 }
203}
204
205impl Drop for Keyring {
206 fn drop(&mut self) {
207 for value in self.keys.values_mut() {
209 value.zeroize();
210 }
211 let ptr = self._raw_json.as_ptr();
214 let len = self._raw_json.len();
215 self._raw_json.zeroize();
217 if len > 0 {
219 memory::munlock(ptr, len);
220 }
221 }
222}
223
224const NONCE_SIZE: usize = 12;
228
229fn decrypt_keyring(session_key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>, KeyringError> {
231 if encrypted.len() < NONCE_SIZE + 16 {
232 return Err(KeyringError::TooSmall);
234 }
235
236 let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
237 let nonce = Nonce::from_slice(nonce_bytes);
238
239 let cipher =
240 Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
241
242 cipher
243 .decrypt(nonce, ciphertext)
244 .map_err(|_| KeyringError::DecryptionFailed)
245}
246
247pub fn encrypt_keyring(session_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, KeyringError> {
250 let cipher =
251 Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
252
253 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
254
255 let ciphertext = cipher
256 .encrypt(&nonce, plaintext)
257 .map_err(|_| KeyringError::DecryptionFailed)?;
258
259 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
260 result.extend_from_slice(&nonce);
261 result.extend_from_slice(&ciphertext);
262 Ok(result)
263}
264
265pub fn generate_session_key() -> [u8; 32] {
267 let mut key = [0u8; 32];
268 use rand::RngCore;
269 OsRng.fill_bytes(&mut key);
270 key
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let session_key = generate_session_key();
        let plaintext = br#"{"parallel_api_key":"test123","epo_api_key":"test456"}"#;

        let encrypted = encrypt_keyring(&session_key, plaintext).unwrap();
        let decrypted = decrypt_keyring(&session_key, &encrypted).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        // The smallest valid blob is exactly nonce (12) + tag (16) bytes.
        let key = generate_session_key();
        let encrypted = encrypt_keyring(&key, b"").unwrap();
        assert_eq!(encrypted.len(), 12 + 16);
        assert!(decrypt_keyring(&key, &encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_encryption_uses_fresh_nonce() {
        let key = generate_session_key();
        let first = encrypt_keyring(&key, b"same").unwrap();
        let second = encrypt_keyring(&key, b"same").unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_session_key();
        let key2 = generate_session_key();
        let plaintext = br#"{"key":"value"}"#;

        let encrypted = encrypt_keyring(&key1, plaintext).unwrap();
        let result = decrypt_keyring(&key2, &encrypted);

        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = generate_session_key();
        let mut encrypted = encrypt_keyring(&key, br#"{"k":"v"}"#).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        assert!(matches!(
            decrypt_keyring(&key, &encrypted),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_too_small_fails() {
        let key = generate_session_key();
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; 10]),
            Err(KeyringError::TooSmall)
        ));
        // One byte short of nonce + tag is still too small...
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; 27]),
            Err(KeyringError::TooSmall)
        ));
        // ...but exactly nonce + tag passes the size check and fails authentication.
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; 28]),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_load_with_key_roundtrip() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        std::fs::write(&path, encrypt_keyring(&key, br#"{"api":"s3cret"}"#).unwrap()).unwrap();

        let kr = Keyring::load_with_key(&path, &key).unwrap();
        assert_eq!(kr.get("api"), Some("s3cret"));
        assert!(kr.ephemeral);

        assert!(matches!(
            Keyring::load_with_key(&path, &generate_session_key()),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_load_with_key_rejects_non_json_plaintext() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("keyring.enc");
        let key = generate_session_key();
        std::fs::write(&path, encrypt_keyring(&key, b"not json").unwrap()).unwrap();

        assert!(matches!(
            Keyring::load_with_key(&path, &key),
            Err(KeyringError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_load_local_uses_persistent_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let key = generate_session_key();
        let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, key);
        std::fs::write(dir.path().join(".keyring-key"), format!("{encoded}\n")).unwrap();
        let path = dir.path().join("keyring.enc");
        std::fs::write(&path, encrypt_keyring(&key, br#"{"a":"1"}"#).unwrap()).unwrap();

        let kr = Keyring::load_local(&path, dir.path()).unwrap();
        assert_eq!(kr.get("a"), Some("1"));
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_local_rejects_short_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let encoded =
            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, [0u8; 16]);
        std::fs::write(dir.path().join(".keyring-key"), encoded).unwrap();
        let path = dir.path().join("keyring.enc");
        std::fs::write(&path, [0u8; 64]).unwrap();

        assert!(matches!(
            Keyring::load_local(&path, dir.path()),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_load_credentials() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"my_api_key":"secret123","other":"val"}"#).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.get("other"), Some("val"));
        assert_eq!(kr.len(), 2);
        assert!(!kr.is_empty());
    }

    #[test]
    fn test_load_credentials_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "{}").unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.len(), 0);
        assert!(kr.is_empty());
    }

    #[test]
    fn test_load_credentials_file_reference() {
        let dir = tempfile::TempDir::new().unwrap();
        let secret_path = dir.path().join("secret.txt");
        std::fs::write(&secret_path, "  from-file-secret\n").unwrap();
        let creds_path = dir.path().join("credentials");
        // serde_json builds the document so the path is JSON-escaped correctly.
        let creds = serde_json::json!({
            "file_key": format!("@file: {} ", secret_path.display()),
            "inline_key": "inline",
        });
        std::fs::write(&creds_path, creds.to_string()).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("file_key"), Some("from-file-secret"));
        assert_eq!(kr.get("inline_key"), Some("inline"));
    }

    #[test]
    fn test_load_credentials_missing_file_reference() {
        let dir = tempfile::TempDir::new().unwrap();
        let missing = dir.path().join("does-not-exist");
        let creds_path = dir.path().join("credentials");
        let creds = serde_json::json!({ "k": format!("@file:{}", missing.display()) });
        std::fs::write(&creds_path, creds.to_string()).unwrap();

        assert!(matches!(
            Keyring::load_credentials(&creds_path),
            Err(KeyringError::Io(_))
        ));
    }

    #[test]
    fn test_load_credentials_rejects_non_string_values() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"k": 5}"#).unwrap();

        assert!(matches!(
            Keyring::load_credentials(&creds_path),
            Err(KeyringError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_from_env_ati_key_prefix() {
        std::env::set_var("ATI_KEY_TEST_API_KEY", "test_value_123");
        std::env::set_var("ATI_KEY_ANOTHER_KEY", "another_val");
        std::env::set_var("ATI_KEY_REVIEW_EMPTY_VALUE", "");

        let kr = Keyring::from_env();
        assert_eq!(kr.get("test_api_key"), Some("test_value_123"));
        assert_eq!(kr.get("another_key"), Some("another_val"));
        // Empty values are skipped entirely.
        assert!(!kr.contains("review_empty_value"));

        std::env::remove_var("ATI_KEY_TEST_API_KEY");
        std::env::remove_var("ATI_KEY_ANOTHER_KEY");
        std::env::remove_var("ATI_KEY_REVIEW_EMPTY_VALUE");
    }

    #[test]
    fn test_merge() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds1 = dir.path().join("c1");
        let creds2 = dir.path().join("c2");
        std::fs::write(&creds1, r#"{"a":"1","b":"2"}"#).unwrap();
        std::fs::write(&creds2, r#"{"b":"overridden","c":"3"}"#).unwrap();

        let mut kr1 = Keyring::load_credentials(&creds1).unwrap();
        let kr2 = Keyring::load_credentials(&creds2).unwrap();
        kr1.merge(&kr2);

        assert_eq!(kr1.get("a"), Some("1"));
        assert_eq!(kr1.get("b"), Some("overridden"));
        assert_eq!(kr1.get("c"), Some("3"));
        assert_eq!(kr1.len(), 3);
        // The source keyring is left untouched.
        assert_eq!(kr2.get("b"), Some("overridden"));
        assert_eq!(kr2.len(), 2);
    }
}