arbiter_storage/
encryption.rs1use aes_gcm::{
14 Aes256Gcm, Nonce,
15 aead::{Aead, KeyInit},
16};
17use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
18use rand::RngCore;
19
20#[derive(Debug, thiserror::Error)]
22pub enum EncryptionError {
23 #[error("invalid key length: expected 32 bytes (64 hex chars), got {0}")]
24 InvalidKeyLength(usize),
25
26 #[error("invalid hex in encryption key: {0}")]
27 InvalidHex(String),
28
29 #[error("encryption failed: {0}")]
30 EncryptionFailed(String),
31
32 #[error("decryption failed: {0}")]
33 DecryptionFailed(String),
34
35 #[error("invalid ciphertext: {0}")]
36 InvalidCiphertext(String),
37}
38
/// Layout tag written as the first byte of every blob produced by
/// `FieldEncryptor::encrypt_field`. It is the same whether the current or a
/// previous key is in use, so it identifies the blob layout rather than the
/// key. Blobs that do not start with it are read as the legacy
/// `nonce || ciphertext` layout.
const CURRENT_KEY_VERSION: u8 = 0x01;
43
44#[derive(Clone)]
55pub struct FieldEncryptor {
56 cipher: Aes256Gcm,
58 previous_cipher: Option<Aes256Gcm>,
61}
62
63impl FieldEncryptor {
64 pub fn new(key: &[u8; 32]) -> Self {
66 Self {
67 cipher: Aes256Gcm::new(key.into()),
68 previous_cipher: None,
69 }
70 }
71
72 pub fn with_previous_key(mut self, key: &[u8; 32]) -> Self {
76 self.previous_cipher = Some(Aes256Gcm::new(key.into()));
77 self
78 }
79
80 pub fn from_hex_key(hex_key: &str) -> Result<Self, EncryptionError> {
82 let hex_key = hex_key.trim();
83 if hex_key.len() != 64 {
84 return Err(EncryptionError::InvalidKeyLength(hex_key.len()));
85 }
86 let bytes = hex_decode(hex_key)?;
87 let key: [u8; 32] = bytes
88 .try_into()
89 .map_err(|_| EncryptionError::InvalidKeyLength(0))?;
90 Ok(Self::new(&key))
91 }
92
93 pub fn with_previous_hex_key(self, hex_key: &str) -> Result<Self, EncryptionError> {
95 let hex_key = hex_key.trim();
96 if hex_key.len() != 64 {
97 return Err(EncryptionError::InvalidKeyLength(hex_key.len()));
98 }
99 let bytes = hex_decode(hex_key)?;
100 let key: [u8; 32] = bytes
101 .try_into()
102 .map_err(|_| EncryptionError::InvalidKeyLength(0))?;
103 Ok(self.with_previous_key(&key))
104 }
105
106 pub fn from_env() -> Result<Option<Self>, EncryptionError> {
111 match std::env::var("ARBITER_STORAGE_ENCRYPTION_KEY") {
112 Ok(val) if !val.trim().is_empty() => Ok(Some(Self::from_hex_key(&val)?)),
113 _ => Ok(None),
114 }
115 }
116
117 pub fn encrypt_field(&self, plaintext: &str) -> Result<String, EncryptionError> {
121 let mut nonce_bytes = [0u8; 12];
122 rand::thread_rng().fill_bytes(&mut nonce_bytes);
123 let nonce = Nonce::from_slice(&nonce_bytes);
124
125 let ciphertext = self
126 .cipher
127 .encrypt(nonce, plaintext.as_bytes())
128 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
129
130 let mut combined = Vec::with_capacity(1 + 12 + ciphertext.len());
132 combined.push(CURRENT_KEY_VERSION);
133 combined.extend_from_slice(&nonce_bytes);
134 combined.extend_from_slice(&ciphertext);
135
136 Ok(BASE64.encode(&combined))
137 }
138
139 pub fn decrypt_field(&self, encoded: &str) -> Result<String, EncryptionError> {
142 let combined = BASE64
143 .decode(encoded)
144 .map_err(|e| EncryptionError::InvalidCiphertext(e.to_string()))?;
145
146 let (nonce_bytes, ciphertext) =
150 if !combined.is_empty() && combined[0] == CURRENT_KEY_VERSION && combined.len() >= 14 {
151 (&combined[1..13], &combined[13..])
153 } else if combined.len() >= 13 {
154 (&combined[..12], &combined[12..])
156 } else {
157 return Err(EncryptionError::InvalidCiphertext(
158 "ciphertext too short".into(),
159 ));
160 };
161
162 let nonce = Nonce::from_slice(nonce_bytes);
163
164 match self.cipher.decrypt(nonce, ciphertext) {
166 Ok(plaintext) => String::from_utf8(plaintext)
167 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string())),
168 Err(current_err) => {
169 if let Some(ref prev) = self.previous_cipher
171 && let Ok(plaintext) = prev.decrypt(nonce, ciphertext)
172 {
173 return String::from_utf8(plaintext)
174 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()));
175 }
176 Err(EncryptionError::DecryptionFailed(current_err.to_string()))
177 }
178 }
179 }
180
181 pub fn encrypt_string_vec(&self, values: &[String]) -> Result<String, EncryptionError> {
183 let json = serde_json::to_string(values)
184 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
185 self.encrypt_field(&json)
186 }
187
188 pub fn decrypt_string_vec(&self, ciphertext: &str) -> Result<Vec<String>, EncryptionError> {
190 let json = self.decrypt_field(ciphertext)?;
191 serde_json::from_str(&json).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
192 }
193}
194
195fn hex_decode(hex: &str) -> Result<Vec<u8>, EncryptionError> {
197 if !hex.len().is_multiple_of(2) {
198 return Err(EncryptionError::InvalidHex(
199 "odd number of hex characters".into(),
200 ));
201 }
202 (0..hex.len())
203 .step_by(2)
204 .map(|i| {
205 u8::from_str_radix(&hex[i..i + 2], 16)
206 .map_err(|e| EncryptionError::InvalidHex(e.to_string()))
207 })
208 .collect()
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable read by `FieldEncryptor::from_env`.
    const KEY_VAR: &str = "ARBITER_STORAGE_ENCRYPTION_KEY";

    /// Serializes tests that mutate the process environment. The test
    /// harness runs tests on parallel threads, and without this lock the env
    /// tests below can observe each other's `set_var` / `remove_var`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning. The lock guards no
    /// data, so a previous env test that panicked leaves nothing inconsistent.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    fn test_key() -> [u8; 32] {
        let mut key = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut key);
        key
    }

    fn key_to_hex(key: &[u8; 32]) -> String {
        key.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let original = "sensitive session intent: read all financials";
        let encrypted = enc.encrypt_field(original).unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();

        assert_eq!(decrypted, original);
    }

    #[test]
    fn encrypted_bytes_differ_from_plaintext() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let plaintext = "my-secret-intent";
        let encrypted = enc.encrypt_field(plaintext).unwrap();

        assert!(
            !encrypted.contains(plaintext),
            "encrypted output must not contain plaintext substring"
        );
    }

    #[test]
    fn different_encryptions_produce_different_ciphertext() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let plaintext = "deterministic input";
        let ct1 = enc.encrypt_field(plaintext).unwrap();
        let ct2 = enc.encrypt_field(plaintext).unwrap();

        assert_ne!(
            ct1, ct2,
            "two encryptions of the same plaintext must differ (random nonce)"
        );
    }

    #[test]
    fn wrong_key_fails_decryption() {
        let key1 = test_key();
        let key2 = test_key();

        let enc1 = FieldEncryptor::new(&key1);
        let enc2 = FieldEncryptor::new(&key2);

        let encrypted = enc1.encrypt_field("secret").unwrap();
        let result = enc2.decrypt_field(&encrypted);

        assert!(result.is_err(), "decryption with wrong key must fail");
    }

    #[test]
    fn previous_key_decrypts_after_rotation() {
        let old_key = test_key();
        let new_key = test_key();
        let old_ct = FieldEncryptor::new(&old_key)
            .encrypt_field("written before rotation")
            .unwrap();

        let rotated = FieldEncryptor::from_hex_key(&key_to_hex(&new_key))
            .unwrap()
            .with_previous_hex_key(&key_to_hex(&old_key))
            .unwrap();
        assert_eq!(
            rotated.decrypt_field(&old_ct).unwrap(),
            "written before rotation"
        );

        // New writes must use only the current key.
        let fresh = rotated.encrypt_field("written after rotation").unwrap();
        assert!(FieldEncryptor::new(&old_key).decrypt_field(&fresh).is_err());
        assert_eq!(
            FieldEncryptor::new(&new_key).decrypt_field(&fresh).unwrap(),
            "written after rotation"
        );
    }

    #[test]
    fn legacy_unversioned_layout_decrypts() {
        let key = test_key();
        let cipher = Aes256Gcm::new((&key).into());

        let mut nonce_bytes = [0u8; 12];
        rand::thread_rng().fill_bytes(&mut nonce_bytes);
        // Keep the first nonce byte distinct from the version tag so this
        // exercises the plain legacy path.
        nonce_bytes[0] = CURRENT_KEY_VERSION.wrapping_add(1);
        let ciphertext = cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), b"legacy value".as_ref())
            .unwrap();

        let mut raw = nonce_bytes.to_vec();
        raw.extend_from_slice(&ciphertext);

        let enc = FieldEncryptor::new(&key);
        assert_eq!(enc.decrypt_field(&BASE64.encode(&raw)).unwrap(), "legacy value");
    }

    #[test]
    fn missing_env_key_returns_none() {
        let _env = lock_env();
        // SAFETY: every test in this module that touches the environment
        // holds ENV_LOCK, so this mutation does not race with them.
        unsafe { std::env::remove_var(KEY_VAR) };
        let result = FieldEncryptor::from_env().unwrap();
        assert!(result.is_none(), "from_env with no var must return None");
    }

    #[test]
    fn blank_env_key_returns_none() {
        let _env = lock_env();
        // SAFETY: every test in this module that touches the environment
        // holds ENV_LOCK, so these mutations do not race with them.
        unsafe { std::env::set_var(KEY_VAR, "   ") };
        let result = FieldEncryptor::from_env();
        unsafe { std::env::remove_var(KEY_VAR) };

        assert!(result.unwrap().is_none(), "blank key must mean 'not configured'");
    }

    #[test]
    fn encrypt_decrypt_string_vec_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let tools = vec![
            "read_file".to_string(),
            "write_file".to_string(),
            "execute_command".to_string(),
        ];
        let encrypted = enc.encrypt_string_vec(&tools).unwrap();
        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();

        assert_eq!(decrypted, tools);
    }

    #[test]
    fn corrupt_ciphertext_fails() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_field("valid data").unwrap();

        let mut raw = BASE64.decode(&encrypted).unwrap();
        // Byte 13 is the first ciphertext byte, after the version byte and
        // the 12-byte nonce.
        assert!(raw.len() > 13, "encoded field is shorter than its header");
        raw[13] ^= 0xFF;
        let corrupted = BASE64.encode(&raw);

        let result = enc.decrypt_field(&corrupted);
        assert!(
            result.is_err(),
            "corrupted ciphertext must fail AEAD verification"
        );
    }

    #[test]
    fn too_short_ciphertext_is_invalid() {
        let enc = FieldEncryptor::new(&test_key());
        let result = enc.decrypt_field(&BASE64.encode([0u8; 5]));
        assert!(matches!(result, Err(EncryptionError::InvalidCiphertext(_))));
    }

    #[test]
    fn non_base64_input_is_invalid() {
        let enc = FieldEncryptor::new(&test_key());
        let result = enc.decrypt_field("not base64!");
        assert!(matches!(result, Err(EncryptionError::InvalidCiphertext(_))));
    }

    #[test]
    fn from_hex_key_roundtrip() {
        let key = test_key();
        let hex = key_to_hex(&key);

        let enc = FieldEncryptor::from_hex_key(&hex).unwrap();
        let encrypted = enc.encrypt_field("hex key test").unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();

        assert_eq!(decrypted, "hex key test");
    }

    #[test]
    fn from_hex_key_ignores_surrounding_whitespace() {
        let key = test_key();
        let padded = format!("  {}\n", key_to_hex(&key));

        let from_padded = FieldEncryptor::from_hex_key(&padded).unwrap();
        let ct = from_padded.encrypt_field("trimmed").unwrap();
        assert_eq!(FieldEncryptor::new(&key).decrypt_field(&ct).unwrap(), "trimmed");
    }

    #[test]
    fn from_hex_key_invalid_length() {
        let result = FieldEncryptor::from_hex_key("0011aabb");
        assert!(matches!(result, Err(EncryptionError::InvalidKeyLength(8))));
    }

    #[test]
    fn from_hex_key_invalid_chars() {
        let bad = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz";
        let result = FieldEncryptor::from_hex_key(bad);
        assert!(matches!(result, Err(EncryptionError::InvalidHex(_))));
    }

    #[test]
    fn env_key_present_and_valid() {
        let _env = lock_env();
        let key = test_key();
        let hex = key_to_hex(&key);

        // SAFETY: every test in this module that touches the environment
        // holds ENV_LOCK, so these mutations do not race with them.
        unsafe { std::env::set_var(KEY_VAR, &hex) };
        let result = FieldEncryptor::from_env();
        unsafe { std::env::remove_var(KEY_VAR) };

        let enc = result.unwrap().expect("should return Some when key is set");
        let ct = enc.encrypt_field("env test").unwrap();
        let pt = enc.decrypt_field(&ct).unwrap();
        assert_eq!(pt, "env test");
    }

    #[test]
    fn empty_string_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_field("").unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();
        assert_eq!(decrypted, "");
    }

    #[test]
    fn empty_vec_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_string_vec(&[]).unwrap();
        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }
}
388}