agent_tools_interface/core/
keyring.rs1use aes_gcm::{
2 aead::{Aead, KeyInit, OsRng},
3 AeadCore, Aes256Gcm, Nonce,
4};
5use std::collections::HashMap;
6use std::path::Path;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::security::{memory, sealed_file};
11
12#[derive(Error, Debug)]
13pub enum KeyringError {
14 #[error("Failed to read session key: {0}")]
15 SealedFile(#[from] sealed_file::SealedFileError),
16 #[error("Failed to read keyring file: {0}")]
17 Io(#[from] std::io::Error),
18 #[error("Decryption failed — wrong key or corrupted keyring")]
19 DecryptionFailed,
20 #[error("Invalid keyring format — expected JSON object")]
21 InvalidFormat(#[from] serde_json::Error),
22 #[error("Keyring file too small (need at least nonce + tag)")]
23 TooSmall,
24}
25
/// An in-memory set of named secrets (e.g. API keys), looked up by key name.
///
/// The type derives neither `Debug` nor `Clone`, which keeps secret values from being
/// printed or silently duplicated. Secret values and the retained plaintext are scrubbed
/// when the keyring is dropped.
pub struct Keyring {
    /// Secret values indexed by key name.
    keys: HashMap<String, String>,
    /// The decrypted JSON plaintext, kept only so it can be zeroized and unlocked on drop.
    /// Empty for keyrings that were not decrypted from an encrypted file.
    _raw_json: Vec<u8>,
    /// `true` for keyrings produced by `load` / `load_with_key`; `false` for those from
    /// `load_local`, `load_credentials`, `from_env` and `empty`.
    pub ephemeral: bool,
}
35
36impl Keyring {
37 pub fn load(keyring_path: &Path) -> Result<Self, KeyringError> {
41 let session_key = sealed_file::read_and_delete_key()?;
43
44 let encrypted = std::fs::read(keyring_path)?;
46
47 let decrypted = decrypt_keyring(&session_key, &encrypted)?;
49
50 if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
52 eprintln!("Warning: {warning}");
53 }
54 let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
55
56 let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
58
59 Ok(Keyring {
60 keys,
61 _raw_json: decrypted,
62 ephemeral: true,
63 })
64 }
65
66 pub fn load_with_key(
68 keyring_path: &Path,
69 session_key: &[u8; 32],
70 ) -> Result<Self, KeyringError> {
71 let encrypted = std::fs::read(keyring_path)?;
72 let decrypted = decrypt_keyring(session_key, &encrypted)?;
73
74 if let Err(warning) = memory::mlock(decrypted.as_ptr(), decrypted.len()) {
75 eprintln!("Warning: {warning}");
76 }
77 let _ = memory::madvise_dontdump(decrypted.as_ptr(), decrypted.len());
78
79 let keys: HashMap<String, String> = serde_json::from_slice(&decrypted)?;
80
81 Ok(Keyring {
82 keys,
83 _raw_json: decrypted,
84 ephemeral: true,
85 })
86 }
87
88 pub fn get(&self, key_name: &str) -> Option<&str> {
90 self.keys.get(key_name).map(|s| s.as_str())
91 }
92
93 pub fn contains(&self, key_name: &str) -> bool {
95 self.keys.contains_key(key_name)
96 }
97
98 pub fn key_names(&self) -> Vec<&str> {
100 self.keys.keys().map(|s| s.as_str()).collect()
101 }
102
103 pub fn load_credentials(path: &Path) -> Result<Self, KeyringError> {
108 let data = std::fs::read(path)?;
109 let keys: HashMap<String, String> = serde_json::from_slice(&data)?;
110 Ok(Keyring {
111 keys,
112 _raw_json: Vec::new(),
113 ephemeral: false,
114 })
115 }
116
117 pub fn load_local(keyring_path: &Path, ati_dir: &Path) -> Result<Self, KeyringError> {
123 let persistent_key_path = ati_dir.join(".keyring-key");
124
125 let contents =
126 std::fs::read_to_string(&persistent_key_path).map_err(|e| KeyringError::Io(e))?;
127
128 let decoded =
129 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, contents.trim())
130 .map_err(|_| KeyringError::DecryptionFailed)?;
131
132 if decoded.len() != 32 {
133 return Err(KeyringError::DecryptionFailed);
134 }
135
136 let mut key = [0u8; 32];
137 key.copy_from_slice(&decoded);
138
139 let mut kr = Self::load_with_key(keyring_path, &key)?;
140 kr.ephemeral = false;
141 Ok(kr)
142 }
143
144 pub fn from_env() -> Self {
149 let mut keys = HashMap::new();
150 for (name, value) in std::env::vars() {
151 if let Some(key_name) = name.strip_prefix("ATI_KEY_") {
152 if !value.is_empty() {
153 keys.insert(key_name.to_lowercase(), value);
154 }
155 }
156 }
157 Keyring {
158 keys,
159 _raw_json: Vec::new(),
160 ephemeral: false,
161 }
162 }
163
164 pub fn empty() -> Self {
166 Keyring {
167 keys: HashMap::new(),
168 _raw_json: Vec::new(),
169 ephemeral: false,
170 }
171 }
172
173 pub fn merge(&mut self, other: &Keyring) {
175 for (k, v) in &other.keys {
176 self.keys.insert(k.clone(), v.clone());
177 }
178 }
179
180 pub fn len(&self) -> usize {
182 self.keys.len()
183 }
184
185 pub fn is_empty(&self) -> bool {
187 self.keys.is_empty()
188 }
189}
190
191impl Drop for Keyring {
192 fn drop(&mut self) {
193 for value in self.keys.values_mut() {
195 value.zeroize();
196 }
197 let ptr = self._raw_json.as_ptr();
200 let len = self._raw_json.len();
201 self._raw_json.zeroize();
203 if len > 0 {
205 memory::munlock(ptr, len);
206 }
207 }
208}
209
210const NONCE_SIZE: usize = 12;
214
215fn decrypt_keyring(session_key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>, KeyringError> {
217 if encrypted.len() < NONCE_SIZE + 16 {
218 return Err(KeyringError::TooSmall);
220 }
221
222 let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
223 let nonce = Nonce::from_slice(nonce_bytes);
224
225 let cipher =
226 Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
227
228 cipher
229 .decrypt(nonce, ciphertext)
230 .map_err(|_| KeyringError::DecryptionFailed)
231}
232
233pub fn encrypt_keyring(session_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, KeyringError> {
236 let cipher =
237 Aes256Gcm::new_from_slice(session_key).map_err(|_| KeyringError::DecryptionFailed)?;
238
239 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
240
241 let ciphertext = cipher
242 .encrypt(&nonce, plaintext)
243 .map_err(|_| KeyringError::DecryptionFailed)?;
244
245 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
246 result.extend_from_slice(&nonce);
247 result.extend_from_slice(&ciphertext);
248 Ok(result)
249}
250
251pub fn generate_session_key() -> [u8; 32] {
253 let mut key = [0u8; 32];
254 use rand::RngCore;
255 OsRng.fill_bytes(&mut key);
256 key
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `json` encrypted under `key` to `<dir>/keyring.enc` and returns its path.
    fn write_encrypted(dir: &Path, key: &[u8; 32], json: &[u8]) -> std::path::PathBuf {
        let path = dir.join("keyring.enc");
        std::fs::write(&path, encrypt_keyring(key, json).unwrap()).unwrap();
        path
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let session_key = generate_session_key();
        let plaintext = br#"{"parallel_api_key":"test123","epo_api_key":"test456"}"#;

        let encrypted = encrypt_keyring(&session_key, plaintext).unwrap();
        let decrypted = decrypt_keyring(&session_key, &encrypted).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypted_layout_is_nonce_ciphertext_tag() {
        let key = generate_session_key();
        let plaintext = b"{}";
        let encrypted = encrypt_keyring(&key, plaintext).unwrap();
        // nonce + same-length ciphertext + 16-byte GCM tag
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + 16);
    }

    #[test]
    fn test_fresh_nonce_per_encryption() {
        let key = generate_session_key();
        let a = encrypt_keyring(&key, b"{}").unwrap();
        let b = encrypt_keyring(&key, b"{}").unwrap();
        assert_ne!(&a[..NONCE_SIZE], &b[..NONCE_SIZE]);
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = generate_session_key();
        let key2 = generate_session_key();
        let plaintext = br#"{"key":"value"}"#;

        let encrypted = encrypt_keyring(&key1, plaintext).unwrap();
        let result = decrypt_keyring(&key2, &encrypted);

        assert!(matches!(result, Err(KeyringError::DecryptionFailed)));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = generate_session_key();
        let mut encrypted = encrypt_keyring(&key, br#"{"k":"v"}"#).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        assert!(matches!(
            decrypt_keyring(&key, &encrypted),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_too_small_fails() {
        let key = generate_session_key();
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; 10]),
            Err(KeyringError::TooSmall)
        ));
        // One byte short of nonce + tag is still rejected before decryption is attempted.
        assert!(matches!(
            decrypt_keyring(&key, &[0u8; NONCE_SIZE + 15]),
            Err(KeyringError::TooSmall)
        ));
    }

    #[test]
    fn test_load_with_key_roundtrip() {
        let dir = tempfile::TempDir::new().unwrap();
        let key = generate_session_key();
        let path = write_encrypted(dir.path(), &key, br#"{"api_key":"s3cret"}"#);

        let kr = Keyring::load_with_key(&path, &key).unwrap();
        assert_eq!(kr.get("api_key"), Some("s3cret"));
        assert!(kr.contains("api_key"));
        assert_eq!(kr.key_names(), vec!["api_key"]);
        assert!(kr.ephemeral);
    }

    #[test]
    fn test_load_with_key_rejects_non_object_json() {
        let dir = tempfile::TempDir::new().unwrap();
        let key = generate_session_key();
        let path = write_encrypted(dir.path(), &key, br#"["not","an","object"]"#);

        assert!(matches!(
            Keyring::load_with_key(&path, &key),
            Err(KeyringError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_load_local_uses_persistent_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let key = generate_session_key();
        let path = write_encrypted(dir.path(), &key, br#"{"k":"v"}"#);
        let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, key);
        // Trailing newline must be tolerated (the loader trims).
        std::fs::write(dir.path().join(".keyring-key"), format!("{encoded}\n")).unwrap();

        let kr = Keyring::load_local(&path, dir.path()).unwrap();
        assert_eq!(kr.get("k"), Some("v"));
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_local_rejects_wrong_length_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = write_encrypted(dir.path(), &generate_session_key(), b"{}");
        let short =
            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, [0u8; 16]);
        std::fs::write(dir.path().join(".keyring-key"), short).unwrap();

        assert!(matches!(
            Keyring::load_local(&path, dir.path()),
            Err(KeyringError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_load_credentials() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, r#"{"my_api_key":"secret123","other":"val"}"#).unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.get("my_api_key"), Some("secret123"));
        assert_eq!(kr.get("other"), Some("val"));
        assert_eq!(kr.len(), 2);
        assert!(!kr.is_empty());
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_load_credentials_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds_path = dir.path().join("credentials");
        std::fs::write(&creds_path, "{}").unwrap();

        let kr = Keyring::load_credentials(&creds_path).unwrap();
        assert_eq!(kr.len(), 0);
        assert!(kr.is_empty());
    }

    #[test]
    fn test_from_env_ati_key_prefix() {
        // Unique names: tests run in parallel and share the process environment.
        std::env::set_var("ATI_KEY_TEST_API_KEY", "test_value_123");
        std::env::set_var("ATI_KEY_ANOTHER_KEY", "another_val");
        std::env::set_var("ATI_KEY_EMPTY_TEST_KEY", "");

        let kr = Keyring::from_env();
        assert_eq!(kr.get("test_api_key"), Some("test_value_123"));
        assert_eq!(kr.get("another_key"), Some("another_val"));
        assert_eq!(kr.get("empty_test_key"), None);

        std::env::remove_var("ATI_KEY_TEST_API_KEY");
        std::env::remove_var("ATI_KEY_ANOTHER_KEY");
        std::env::remove_var("ATI_KEY_EMPTY_TEST_KEY");
    }

    #[test]
    fn test_empty() {
        let kr = Keyring::empty();
        assert!(kr.is_empty());
        assert_eq!(kr.get("anything"), None);
        assert!(!kr.ephemeral);
    }

    #[test]
    fn test_merge() {
        let dir = tempfile::TempDir::new().unwrap();
        let creds1 = dir.path().join("c1");
        let creds2 = dir.path().join("c2");
        std::fs::write(&creds1, r#"{"a":"1","b":"2"}"#).unwrap();
        std::fs::write(&creds2, r#"{"b":"overridden","c":"3"}"#).unwrap();

        let mut kr1 = Keyring::load_credentials(&creds1).unwrap();
        let kr2 = Keyring::load_credentials(&creds2).unwrap();
        kr1.merge(&kr2);

        assert_eq!(kr1.get("a"), Some("1"));
        assert_eq!(kr1.get("b"), Some("overridden"));
        assert_eq!(kr1.get("c"), Some("3"));
        assert_eq!(kr1.len(), 3);
        // The source keyring is left untouched.
        assert_eq!(kr2.get("b"), Some("overridden"));
        assert_eq!(kr2.len(), 2);
    }
}