Skip to main content

arbiter_storage/
encryption.rs

1//! Field-level encryption for sensitive session data stored at rest.
2//!
3//! Uses AES-256-GCM (authenticated encryption with associated data) to
4//! encrypt individual fields before they are written to SQLite. Each
5//! encrypted value is prefixed with a random 12-byte nonce, then
6//! base64-encoded for safe storage in TEXT columns.
7//!
8//! Encryption is **optional**: when no key is configured, the storage
9//! layer stores data in plaintext (backward compatible). The key is
10//! loaded from the `ARBITER_STORAGE_ENCRYPTION_KEY` environment variable
11//! as a 64-character hex string (32 bytes).
12
13use aes_gcm::{
14    Aes256Gcm, Nonce,
15    aead::{Aead, KeyInit},
16};
17use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
18use rand::RngCore;
19
20/// Errors from encryption / decryption operations.
21#[derive(Debug, thiserror::Error)]
22pub enum EncryptionError {
23    #[error("invalid key length: expected 32 bytes (64 hex chars), got {0}")]
24    InvalidKeyLength(usize),
25
26    #[error("invalid hex in encryption key: {0}")]
27    InvalidHex(String),
28
29    #[error("encryption failed: {0}")]
30    EncryptionFailed(String),
31
32    #[error("decryption failed: {0}")]
33    DecryptionFailed(String),
34
35    #[error("invalid ciphertext: {0}")]
36    InvalidCiphertext(String),
37}
38
39/// Field-level encryption using AES-256-GCM.
40///
41/// Each encrypted field has the wire format:
42///   `base64(nonce_12_bytes || ciphertext_with_tag)`
43///
44/// A fresh random nonce is generated for every `encrypt_*` call, so
45/// encrypting the same plaintext twice yields different ciphertext.
46#[derive(Clone)]
47pub struct FieldEncryptor {
48    cipher: Aes256Gcm,
49}
50
51impl FieldEncryptor {
52    /// Create from a raw 32-byte key.
53    pub fn new(key: &[u8; 32]) -> Self {
54        Self {
55            cipher: Aes256Gcm::new(key.into()),
56        }
57    }
58
59    /// Create from a hex-encoded key string (64 hex chars = 32 bytes).
60    pub fn from_hex_key(hex_key: &str) -> Result<Self, EncryptionError> {
61        let hex_key = hex_key.trim();
62        if hex_key.len() != 64 {
63            return Err(EncryptionError::InvalidKeyLength(hex_key.len()));
64        }
65        let bytes = hex_decode(hex_key)?;
66        let key: [u8; 32] = bytes
67            .try_into()
68            .map_err(|_| EncryptionError::InvalidKeyLength(0))?;
69        Ok(Self::new(&key))
70    }
71
72    /// Create from the `ARBITER_STORAGE_ENCRYPTION_KEY` environment variable.
73    ///
74    /// Returns `Ok(None)` when the variable is absent or empty (encryption
75    /// disabled). Returns `Err` when the variable is present but malformed.
76    pub fn from_env() -> Result<Option<Self>, EncryptionError> {
77        match std::env::var("ARBITER_STORAGE_ENCRYPTION_KEY") {
78            Ok(val) if !val.trim().is_empty() => Ok(Some(Self::from_hex_key(&val)?)),
79            _ => Ok(None),
80        }
81    }
82
83    /// Encrypt a UTF-8 string field.
84    ///
85    /// Returns a base64-encoded blob containing `nonce || ciphertext`.
86    pub fn encrypt_field(&self, plaintext: &str) -> Result<String, EncryptionError> {
87        let mut nonce_bytes = [0u8; 12];
88        rand::thread_rng().fill_bytes(&mut nonce_bytes);
89        let nonce = Nonce::from_slice(&nonce_bytes);
90
91        let ciphertext = self
92            .cipher
93            .encrypt(nonce, plaintext.as_bytes())
94            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
95
96        // nonce || ciphertext
97        let mut combined = Vec::with_capacity(12 + ciphertext.len());
98        combined.extend_from_slice(&nonce_bytes);
99        combined.extend_from_slice(&ciphertext);
100
101        Ok(BASE64.encode(&combined))
102    }
103
104    /// Decrypt a base64-encoded `nonce || ciphertext` blob back to the
105    /// original UTF-8 string.
106    pub fn decrypt_field(&self, encoded: &str) -> Result<String, EncryptionError> {
107        let combined = BASE64
108            .decode(encoded)
109            .map_err(|e| EncryptionError::InvalidCiphertext(e.to_string()))?;
110
111        if combined.len() < 13 {
112            // 12-byte nonce + at least 1 byte ciphertext
113            return Err(EncryptionError::InvalidCiphertext(
114                "ciphertext too short".into(),
115            ));
116        }
117
118        let (nonce_bytes, ciphertext) = combined.split_at(12);
119        let nonce = Nonce::from_slice(nonce_bytes);
120
121        let plaintext = self
122            .cipher
123            .decrypt(nonce, ciphertext)
124            .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
125
126        String::from_utf8(plaintext).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
127    }
128
129    /// Encrypt a `Vec<String>` by JSON-serializing then encrypting.
130    pub fn encrypt_string_vec(&self, values: &[String]) -> Result<String, EncryptionError> {
131        let json = serde_json::to_string(values)
132            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
133        self.encrypt_field(&json)
134    }
135
136    /// Decrypt back to `Vec<String>`.
137    pub fn decrypt_string_vec(&self, ciphertext: &str) -> Result<Vec<String>, EncryptionError> {
138        let json = self.decrypt_field(ciphertext)?;
139        serde_json::from_str(&json).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
140    }
141}
142
143/// Decode a hex string to bytes (no external hex crate needed).
144fn hex_decode(hex: &str) -> Result<Vec<u8>, EncryptionError> {
145    if !hex.len().is_multiple_of(2) {
146        return Err(EncryptionError::InvalidHex(
147            "odd number of hex characters".into(),
148        ));
149    }
150    (0..hex.len())
151        .step_by(2)
152        .map(|i| {
153            u8::from_str_radix(&hex[i..i + 2], 16)
154                .map_err(|e| EncryptionError::InvalidHex(e.to_string()))
155        })
156        .collect()
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Name of the environment variable exercised by the `from_env` tests.
    const KEY_VAR: &str = "ARBITER_STORAGE_ENCRYPTION_KEY";

    /// Serializes the tests in this module that touch the process
    /// environment. The test harness runs tests on multiple threads by
    /// default, so without this lock one test could remove the variable
    /// while another is reading it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire [`ENV_LOCK`], recovering it if an earlier test panicked while
    /// holding it (the guarded data is `()`, so poisoning is harmless).
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Generate a fresh random 32-byte key.
    fn test_key() -> [u8; 32] {
        let mut key = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut key);
        key
    }

    /// Render a key as 64 lower-case hex characters.
    fn key_to_hex(key: &[u8; 32]) -> String {
        key.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let original = "sensitive session intent: read all financials";
        let encrypted = enc.encrypt_field(original).unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();

        assert_eq!(decrypted, original);
    }

    #[test]
    fn encrypted_bytes_differ_from_plaintext() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let plaintext = "my-secret-intent";
        let encrypted = enc.encrypt_field(plaintext).unwrap();

        // The encrypted output (base64) must not contain the plaintext
        assert!(
            !encrypted.contains(plaintext),
            "encrypted output must not contain plaintext substring"
        );
    }

    #[test]
    fn different_encryptions_produce_different_ciphertext() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let plaintext = "deterministic input";
        let ct1 = enc.encrypt_field(plaintext).unwrap();
        let ct2 = enc.encrypt_field(plaintext).unwrap();

        assert_ne!(
            ct1, ct2,
            "two encryptions of the same plaintext must differ (random nonce)"
        );
    }

    #[test]
    fn wrong_key_fails_decryption() {
        let key1 = test_key();
        let key2 = test_key();

        let enc1 = FieldEncryptor::new(&key1);
        let enc2 = FieldEncryptor::new(&key2);

        let encrypted = enc1.encrypt_field("secret").unwrap();
        let result = enc2.decrypt_field(&encrypted);

        assert!(result.is_err(), "decryption with wrong key must fail");
    }

    #[test]
    fn missing_env_key_returns_none() {
        let _env = lock_env();

        // Ensure the variable is not set
        // SAFETY: ENV_LOCK is held, so no other test in this module reads or
        // writes the variable while it is being modified.
        unsafe { std::env::remove_var(KEY_VAR) };
        let result = FieldEncryptor::from_env().unwrap();
        assert!(result.is_none(), "from_env with no var must return None");
    }

    #[test]
    fn encrypt_decrypt_string_vec_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let tools = vec![
            "read_file".to_string(),
            "write_file".to_string(),
            "execute_command".to_string(),
        ];
        let encrypted = enc.encrypt_string_vec(&tools).unwrap();
        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();

        assert_eq!(decrypted, tools);
    }

    #[test]
    fn corrupt_ciphertext_fails() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_field("valid data").unwrap();

        // Decode, corrupt a byte in the ciphertext portion, re-encode
        let mut raw = BASE64.decode(&encrypted).unwrap();
        // Fail loudly instead of skipping the corruption: index 13 must
        // exist and lie past the 12-byte nonce.
        assert!(
            raw.len() > 13,
            "encrypted blob must extend past the 12-byte nonce"
        );
        // Invert every bit of one ciphertext byte
        raw[13] ^= 0xFF;
        let corrupted = BASE64.encode(&raw);

        let result = enc.decrypt_field(&corrupted);
        assert!(
            result.is_err(),
            "corrupted ciphertext must fail AEAD verification"
        );
    }

    #[test]
    fn from_hex_key_roundtrip() {
        let key = test_key();
        let hex = key_to_hex(&key);

        let enc = FieldEncryptor::from_hex_key(&hex).unwrap();
        let encrypted = enc.encrypt_field("hex key test").unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();

        assert_eq!(decrypted, "hex key test");
    }

    #[test]
    fn from_hex_key_invalid_length() {
        let result = FieldEncryptor::from_hex_key("0011aabb");
        assert!(result.is_err());
    }

    #[test]
    fn from_hex_key_invalid_chars() {
        let bad = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz";
        let result = FieldEncryptor::from_hex_key(bad);
        assert!(result.is_err());
    }

    #[test]
    fn env_key_present_and_valid() {
        let key = test_key();
        let hex = key_to_hex(&key);

        let _env = lock_env();

        // SAFETY: ENV_LOCK is held, so no other test in this module reads or
        // writes the variable while it is being modified.
        unsafe { std::env::set_var(KEY_VAR, &hex) };
        let result = FieldEncryptor::from_env();
        // SAFETY: as above, ENV_LOCK is still held.
        unsafe { std::env::remove_var(KEY_VAR) };

        let enc = result.unwrap().expect("should return Some when key is set");
        let ct = enc.encrypt_field("env test").unwrap();
        let pt = enc.decrypt_field(&ct).unwrap();
        assert_eq!(pt, "env test");
    }

    #[test]
    fn env_key_present_but_malformed_is_error() {
        let _env = lock_env();

        // SAFETY: ENV_LOCK is held, so no other test in this module reads or
        // writes the variable while it is being modified.
        unsafe { std::env::set_var(KEY_VAR, "not-a-hex-key") };
        let result = FieldEncryptor::from_env();
        // SAFETY: as above, ENV_LOCK is still held.
        unsafe { std::env::remove_var(KEY_VAR) };

        assert!(
            result.is_err(),
            "a configured but malformed key must not silently disable encryption"
        );
    }

    #[test]
    fn empty_string_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_field("").unwrap();
        let decrypted = enc.decrypt_field(&encrypted).unwrap();
        assert_eq!(decrypted, "");
    }

    #[test]
    fn empty_vec_roundtrip() {
        let key = test_key();
        let enc = FieldEncryptor::new(&key);

        let encrypted = enc.encrypt_string_vec(&[]).unwrap();
        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }
}
336}