arbiter_storage/
encryption.rs1use aes_gcm::{
14 Aes256Gcm, Nonce,
15 aead::{Aead, KeyInit},
16};
17use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
18use rand::RngCore;
19
20#[derive(Debug, thiserror::Error)]
22pub enum EncryptionError {
23 #[error("invalid key length: expected 32 bytes (64 hex chars), got {0}")]
24 InvalidKeyLength(usize),
25
26 #[error("invalid hex in encryption key: {0}")]
27 InvalidHex(String),
28
29 #[error("encryption failed: {0}")]
30 EncryptionFailed(String),
31
32 #[error("decryption failed: {0}")]
33 DecryptionFailed(String),
34
35 #[error("invalid ciphertext: {0}")]
36 InvalidCiphertext(String),
37}
38
39#[derive(Clone)]
47pub struct FieldEncryptor {
48 cipher: Aes256Gcm,
49}
50
51impl FieldEncryptor {
52 pub fn new(key: &[u8; 32]) -> Self {
54 Self {
55 cipher: Aes256Gcm::new(key.into()),
56 }
57 }
58
59 pub fn from_hex_key(hex_key: &str) -> Result<Self, EncryptionError> {
61 let hex_key = hex_key.trim();
62 if hex_key.len() != 64 {
63 return Err(EncryptionError::InvalidKeyLength(hex_key.len()));
64 }
65 let bytes = hex_decode(hex_key)?;
66 let key: [u8; 32] = bytes
67 .try_into()
68 .map_err(|_| EncryptionError::InvalidKeyLength(0))?;
69 Ok(Self::new(&key))
70 }
71
72 pub fn from_env() -> Result<Option<Self>, EncryptionError> {
77 match std::env::var("ARBITER_STORAGE_ENCRYPTION_KEY") {
78 Ok(val) if !val.trim().is_empty() => Ok(Some(Self::from_hex_key(&val)?)),
79 _ => Ok(None),
80 }
81 }
82
83 pub fn encrypt_field(&self, plaintext: &str) -> Result<String, EncryptionError> {
87 let mut nonce_bytes = [0u8; 12];
88 rand::thread_rng().fill_bytes(&mut nonce_bytes);
89 let nonce = Nonce::from_slice(&nonce_bytes);
90
91 let ciphertext = self
92 .cipher
93 .encrypt(nonce, plaintext.as_bytes())
94 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
95
96 let mut combined = Vec::with_capacity(12 + ciphertext.len());
98 combined.extend_from_slice(&nonce_bytes);
99 combined.extend_from_slice(&ciphertext);
100
101 Ok(BASE64.encode(&combined))
102 }
103
104 pub fn decrypt_field(&self, encoded: &str) -> Result<String, EncryptionError> {
107 let combined = BASE64
108 .decode(encoded)
109 .map_err(|e| EncryptionError::InvalidCiphertext(e.to_string()))?;
110
111 if combined.len() < 13 {
112 return Err(EncryptionError::InvalidCiphertext(
114 "ciphertext too short".into(),
115 ));
116 }
117
118 let (nonce_bytes, ciphertext) = combined.split_at(12);
119 let nonce = Nonce::from_slice(nonce_bytes);
120
121 let plaintext = self
122 .cipher
123 .decrypt(nonce, ciphertext)
124 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
125
126 String::from_utf8(plaintext).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
127 }
128
129 pub fn encrypt_string_vec(&self, values: &[String]) -> Result<String, EncryptionError> {
131 let json = serde_json::to_string(values)
132 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
133 self.encrypt_field(&json)
134 }
135
136 pub fn decrypt_string_vec(&self, ciphertext: &str) -> Result<Vec<String>, EncryptionError> {
138 let json = self.decrypt_field(ciphertext)?;
139 serde_json::from_str(&json).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
140 }
141}
142
143fn hex_decode(hex: &str) -> Result<Vec<u8>, EncryptionError> {
145 if !hex.len().is_multiple_of(2) {
146 return Err(EncryptionError::InvalidHex(
147 "odd number of hex characters".into(),
148 ));
149 }
150 (0..hex.len())
151 .step_by(2)
152 .map(|i| {
153 u8::from_str_radix(&hex[i..i + 2], 16)
154 .map_err(|e| EncryptionError::InvalidHex(e.to_string()))
155 })
156 .collect()
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests that read or modify the process environment.
    ///
    /// The test harness runs tests on several threads, and the environment is
    /// shared by the whole process, so without this lock one test could
    /// remove the key variable while another is in the middle of using it.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that one failed
    /// environment test does not cascade into the others.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Generates a fresh random 32-byte key.
    fn test_key() -> [u8; 32] {
        let mut key = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut key);
        key
    }

    /// Renders a key as 64 lower-case hex characters.
    fn key_to_hex(key: &[u8; 32]) -> String {
        key.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let original = "sensitive session intent: read all financials";
        let encrypted = enc.encrypt_field(original).unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();

        assert_eq!(decrypted, original);
    }

    #[test]
    fn encrypted_bytes_differ_from_plaintext() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let plaintext = "my-secret-intent";
        let encrypted = enc.encrypt_field(plaintext).unwrap();

        assert!(
            !encrypted.contains(plaintext),
            "encrypted output must not contain plaintext substring"
        );
    }

    #[test]
    fn different_encryptions_produce_different_ciphertext() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let plaintext = "deterministic input";
        let ct1 = enc.encrypt_field(plaintext).unwrap();
        let ct2 = enc.encrypt_field(plaintext).unwrap();

        assert_ne!(
            ct1, ct2,
            "two encryptions of the same plaintext must differ (random nonce)"
        );
    }

    #[test]
    fn wrong_key_fails_decryption() {
        let key1 = test_key();
        let key2 = test_key();

        let enc1 = FieldEncryptor::new(&key1);
        let enc2 = FieldEncryptor::new(&key2);

        let encrypted = enc1.encrypt_field("secret").unwrap();
        let result = enc2.decrypt_field(&encrypted);

        assert!(result.is_err(), "decryption with wrong key must fail");
    }

    #[test]
    fn missing_env_key_returns_none() {
        let _env = lock_env();

        // SAFETY: the environment is only touched while ENV_LOCK is held, and
        // the tests in this module that use the environment all take that lock.
        unsafe { std::env::remove_var("ARBITER_STORAGE_ENCRYPTION_KEY") };
        let result = FieldEncryptor::from_env().unwrap();
        assert!(result.is_none(), "from_env with no var must return None");
    }

    #[test]
    fn encrypt_decrypt_string_vec_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let tools = vec![
            "read_file".to_string(),
            "write_file".to_string(),
            "execute_command".to_string(),
        ];
        let encrypted = enc.encrypt_string_vec(&tools).unwrap();
        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();

        assert_eq!(decrypted, tools);
    }

    #[test]
    fn corrupt_ciphertext_fails() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_field("valid data").unwrap();

        let mut raw = BASE64.decode(&encrypted).unwrap();
        // Flip a byte that lies after the 12-byte nonce. A checked access makes
        // the test fail loudly if the payload is ever too short to corrupt,
        // instead of silently skipping the corruption or indexing out of range.
        let target = raw
            .get_mut(13)
            .expect("encrypted payload must extend past the nonce");
        *target ^= 0xFF;
        let corrupted = BASE64.encode(&raw);

        let result = enc.decrypt_field(&corrupted);
        assert!(
            result.is_err(),
            "corrupted ciphertext must fail AEAD verification"
        );
    }

    #[test]
    fn from_hex_key_roundtrip() {
        let key = test_key();
        let hex = key_to_hex(&key);

        let enc = FieldEncryptor::from_hex_key(&hex).unwrap();
        let encrypted = enc.encrypt_field("hex key test").unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();

        assert_eq!(decrypted, "hex key test");
    }

    #[test]
    fn from_hex_key_invalid_length() {
        let result = FieldEncryptor::from_hex_key("0011aabb");
        assert!(result.is_err());
    }

    #[test]
    fn from_hex_key_invalid_chars() {
        let bad = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz";
        let result = FieldEncryptor::from_hex_key(bad);
        assert!(result.is_err());
    }

    #[test]
    fn env_key_present_and_valid() {
        let key = test_key();
        let hex = key_to_hex(&key);

        let result = {
            let _env = lock_env();

            // SAFETY: the environment is only touched while ENV_LOCK is held,
            // and the tests in this module that use the environment all take
            // that lock.
            unsafe { std::env::set_var("ARBITER_STORAGE_ENCRYPTION_KEY", &hex) };
            let result = FieldEncryptor::from_env();
            // SAFETY: still holding ENV_LOCK, as above.
            unsafe { std::env::remove_var("ARBITER_STORAGE_ENCRYPTION_KEY") };
            result
        };

        let enc = result.unwrap().expect("should return Some when key is set");
        let ct = enc.encrypt_field("env test").unwrap();
        let pt = enc.decrypt_field(&ct).unwrap();
        assert_eq!(pt, "env test");
    }

    #[test]
    fn empty_string_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_field("").unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();
        assert_eq!(decrypted, "");
    }

    #[test]
    fn empty_vec_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_string_vec(&[]).unwrap();
        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }
}