Skip to main content

arbiter_storage/
encryption.rs

1//! Field-level encryption for sensitive session data stored at rest.
2//!
3//! Uses AES-256-GCM (authenticated encryption with associated data) to
4//! encrypt individual fields before they are written to SQLite. Each
5//! encrypted value is prefixed with a random 12-byte nonce, then
6//! base64-encoded for safe storage in TEXT columns.
7//!
8//! Encryption is **optional**: when no key is configured, the storage
9//! layer stores data in plaintext (backward compatible). The key is
10//! loaded from the `ARBITER_STORAGE_ENCRYPTION_KEY` environment variable
11//! as a 64-character hex string (32 bytes).
12
13use aes_gcm::{
14    Aes256Gcm, Nonce,
15    aead::{Aead, KeyInit},
16};
17use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
18use rand::RngCore;
19
20/// Errors from encryption / decryption operations.
21#[derive(Debug, thiserror::Error)]
22pub enum EncryptionError {
23    #[error("invalid key length: expected 32 bytes (64 hex chars), got {0}")]
24    InvalidKeyLength(usize),
25
26    #[error("invalid hex in encryption key: {0}")]
27    InvalidHex(String),
28
29    #[error("encryption failed: {0}")]
30    EncryptionFailed(String),
31
32    #[error("decryption failed: {0}")]
33    DecryptionFailed(String),
34
35    #[error("invalid ciphertext: {0}")]
36    InvalidCiphertext(String),
37}
38
39/// Current key version byte. Prepended to every encrypted blob.
40/// When key rotation occurs, a new version can be assigned and the
41/// decryptor tries all known versions.
42const CURRENT_KEY_VERSION: u8 = 1;
43
44/// Field-level encryption using AES-256-GCM.
45///
46/// Each encrypted field has the wire format:
47///   `base64(key_version_1 || nonce_12_bytes || ciphertext_with_tag)`
48///
49/// The 1-byte key version prefix enables future key rotation: the decryptor
50/// can identify which key was used and select the correct one.
51///
52/// A fresh random nonce is generated for every `encrypt_*` call, so
53/// encrypting the same plaintext twice yields different ciphertext.
54#[derive(Clone)]
55pub struct FieldEncryptor {
56    /// Current key for encryption and decryption.
57    cipher: Aes256Gcm,
58    /// Previous key for decrypting old blobs during rotation.
59    /// When set, decrypt_field tries the current key first, then falls back.
60    previous_cipher: Option<Aes256Gcm>,
61}
62
63impl FieldEncryptor {
64    /// Create from a raw 32-byte key.
65    pub fn new(key: &[u8; 32]) -> Self {
66        Self {
67            cipher: Aes256Gcm::new(key.into()),
68            previous_cipher: None,
69        }
70    }
71
72    /// Set a previous key for rotation. During decryption, if the current
73    /// key fails, the previous key is tried. This allows a rolling upgrade
74    /// window where old data is still readable.
75    pub fn with_previous_key(mut self, key: &[u8; 32]) -> Self {
76        self.previous_cipher = Some(Aes256Gcm::new(key.into()));
77        self
78    }
79
80    /// Create from a hex-encoded key string (64 hex chars = 32 bytes).
81    pub fn from_hex_key(hex_key: &str) -> Result<Self, EncryptionError> {
82        let hex_key = hex_key.trim();
83        if hex_key.len() != 64 {
84            return Err(EncryptionError::InvalidKeyLength(hex_key.len()));
85        }
86        let bytes = hex_decode(hex_key)?;
87        let key: [u8; 32] = bytes
88            .try_into()
89            .map_err(|_| EncryptionError::InvalidKeyLength(0))?;
90        Ok(Self::new(&key))
91    }
92
93    /// Create from a hex-encoded previous key for rotation support.
94    pub fn with_previous_hex_key(self, hex_key: &str) -> Result<Self, EncryptionError> {
95        let hex_key = hex_key.trim();
96        if hex_key.len() != 64 {
97            return Err(EncryptionError::InvalidKeyLength(hex_key.len()));
98        }
99        let bytes = hex_decode(hex_key)?;
100        let key: [u8; 32] = bytes
101            .try_into()
102            .map_err(|_| EncryptionError::InvalidKeyLength(0))?;
103        Ok(self.with_previous_key(&key))
104    }
105
106    /// Create from the `ARBITER_STORAGE_ENCRYPTION_KEY` environment variable.
107    ///
108    /// Returns `Ok(None)` when the variable is absent or empty (encryption
109    /// disabled). Returns `Err` when the variable is present but malformed.
110    pub fn from_env() -> Result<Option<Self>, EncryptionError> {
111        match std::env::var("ARBITER_STORAGE_ENCRYPTION_KEY") {
112            Ok(val) if !val.trim().is_empty() => Ok(Some(Self::from_hex_key(&val)?)),
113            _ => Ok(None),
114        }
115    }
116
117    /// Encrypt a UTF-8 string field.
118    ///
119    /// Returns a base64-encoded blob containing `nonce || ciphertext`.
120    pub fn encrypt_field(&self, plaintext: &str) -> Result<String, EncryptionError> {
121        let mut nonce_bytes = [0u8; 12];
122        rand::thread_rng().fill_bytes(&mut nonce_bytes);
123        let nonce = Nonce::from_slice(&nonce_bytes);
124
125        let ciphertext = self
126            .cipher
127            .encrypt(nonce, plaintext.as_bytes())
128            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
129
130        // version || nonce || ciphertext
131        let mut combined = Vec::with_capacity(1 + 12 + ciphertext.len());
132        combined.push(CURRENT_KEY_VERSION);
133        combined.extend_from_slice(&nonce_bytes);
134        combined.extend_from_slice(&ciphertext);
135
136        Ok(BASE64.encode(&combined))
137    }
138
139    /// Decrypt a base64-encoded `nonce || ciphertext` blob back to the
140    /// original UTF-8 string.
141    pub fn decrypt_field(&self, encoded: &str) -> Result<String, EncryptionError> {
142        let combined = BASE64
143            .decode(encoded)
144            .map_err(|e| EncryptionError::InvalidCiphertext(e.to_string()))?;
145
146        // Detect versioned vs legacy format.
147        // Versioned: version_1 || nonce_12 || ciphertext (min 14 bytes)
148        // Legacy:    nonce_12 || ciphertext (min 13 bytes, first byte is random nonce)
149        let (nonce_bytes, ciphertext) =
150            if !combined.is_empty() && combined[0] == CURRENT_KEY_VERSION && combined.len() >= 14 {
151                // Versioned format: skip the version byte.
152                (&combined[1..13], &combined[13..])
153            } else if combined.len() >= 13 {
154                // Legacy format (no version prefix): nonce starts at offset 0.
155                (&combined[..12], &combined[12..])
156            } else {
157                return Err(EncryptionError::InvalidCiphertext(
158                    "ciphertext too short".into(),
159                ));
160            };
161
162        let nonce = Nonce::from_slice(nonce_bytes);
163
164        // Try current key first.
165        match self.cipher.decrypt(nonce, ciphertext) {
166            Ok(plaintext) => String::from_utf8(plaintext)
167                .map_err(|e| EncryptionError::DecryptionFailed(e.to_string())),
168            Err(current_err) => {
169                // If a previous key is configured, try it (key rotation support).
170                if let Some(ref prev) = self.previous_cipher
171                    && let Ok(plaintext) = prev.decrypt(nonce, ciphertext)
172                {
173                    return String::from_utf8(plaintext)
174                        .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()));
175                }
176                Err(EncryptionError::DecryptionFailed(current_err.to_string()))
177            }
178        }
179    }
180
181    /// Encrypt a `Vec<String>` by JSON-serializing then encrypting.
182    pub fn encrypt_string_vec(&self, values: &[String]) -> Result<String, EncryptionError> {
183        let json = serde_json::to_string(values)
184            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
185        self.encrypt_field(&json)
186    }
187
188    /// Decrypt back to `Vec<String>`.
189    pub fn decrypt_string_vec(&self, ciphertext: &str) -> Result<Vec<String>, EncryptionError> {
190        let json = self.decrypt_field(ciphertext)?;
191        serde_json::from_str(&json).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
192    }
193}
194
195/// Decode a hex string to bytes (no external hex crate needed).
196fn hex_decode(hex: &str) -> Result<Vec<u8>, EncryptionError> {
197    if !hex.len().is_multiple_of(2) {
198        return Err(EncryptionError::InvalidHex(
199            "odd number of hex characters".into(),
200        ));
201    }
202    (0..hex.len())
203        .step_by(2)
204        .map(|i| {
205            u8::from_str_radix(&hex[i..i + 2], 16)
206                .map_err(|e| EncryptionError::InvalidHex(e.to_string()))
207        })
208        .collect()
209}
210
211#[cfg(test)]
212mod tests {
213    use super::*;
214
215    fn test_key() -> [u8; 32] {
216        let mut key = [0u8; 32];
217        rand::thread_rng().fill_bytes(&mut key);
218        key
219    }
220
221    fn key_to_hex(key: &[u8; 32]) -> String {
222        key.iter().map(|b| format!("{b:02x}")).collect()
223    }
224
225    #[test]
226    fn encrypt_decrypt_roundtrip() {
227        let key = test_key();
228        let enc = FieldEncryptor::new(&key);
229
230        let original = "sensitive session intent: read all financials";
231        let encrypted = enc.encrypt_field(original).unwrap();
232        let decrypted = enc.decrypt_field(&encrypted).unwrap();
233
234        assert_eq!(decrypted, original);
235    }
236
237    #[test]
238    fn encrypted_bytes_differ_from_plaintext() {
239        let key = test_key();
240        let enc = FieldEncryptor::new(&key);
241
242        let plaintext = "my-secret-intent";
243        let encrypted = enc.encrypt_field(plaintext).unwrap();
244
245        // The encrypted output (base64) must not contain the plaintext
246        assert!(
247            !encrypted.contains(plaintext),
248            "encrypted output must not contain plaintext substring"
249        );
250    }
251
252    #[test]
253    fn different_encryptions_produce_different_ciphertext() {
254        let key = test_key();
255        let enc = FieldEncryptor::new(&key);
256
257        let plaintext = "deterministic input";
258        let ct1 = enc.encrypt_field(plaintext).unwrap();
259        let ct2 = enc.encrypt_field(plaintext).unwrap();
260
261        assert_ne!(
262            ct1, ct2,
263            "two encryptions of the same plaintext must differ (random nonce)"
264        );
265    }
266
267    #[test]
268    fn wrong_key_fails_decryption() {
269        let key1 = test_key();
270        let key2 = test_key();
271
272        let enc1 = FieldEncryptor::new(&key1);
273        let enc2 = FieldEncryptor::new(&key2);
274
275        let encrypted = enc1.encrypt_field("secret").unwrap();
276        let result = enc2.decrypt_field(&encrypted);
277
278        assert!(result.is_err(), "decryption with wrong key must fail");
279    }
280
281    #[test]
282    fn missing_env_key_returns_none() {
283        // Ensure the variable is not set
284        // SAFETY: test-only, single-threaded access to env var
285        unsafe { std::env::remove_var("ARBITER_STORAGE_ENCRYPTION_KEY") };
286        let result = FieldEncryptor::from_env().unwrap();
287        assert!(result.is_none(), "from_env with no var must return None");
288    }
289
290    #[test]
291    fn encrypt_decrypt_string_vec_roundtrip() {
292        let key = test_key();
293        let enc = FieldEncryptor::new(&key);
294
295        let tools = vec![
296            "read_file".to_string(),
297            "write_file".to_string(),
298            "execute_command".to_string(),
299        ];
300        let encrypted = enc.encrypt_string_vec(&tools).unwrap();
301        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();
302
303        assert_eq!(decrypted, tools);
304    }
305
306    #[test]
307    fn corrupt_ciphertext_fails() {
308        let key = test_key();
309        let enc = FieldEncryptor::new(&key);
310
311        let encrypted = enc.encrypt_field("valid data").unwrap();
312
313        // Decode, corrupt a byte in the ciphertext portion, re-encode
314        let mut raw = BASE64.decode(&encrypted).unwrap();
315        if raw.len() > 12 {
316            // Flip a bit in the ciphertext (past the nonce)
317            raw[13] ^= 0xFF;
318        }
319        let corrupted = BASE64.encode(&raw);
320
321        let result = enc.decrypt_field(&corrupted);
322        assert!(
323            result.is_err(),
324            "corrupted ciphertext must fail AEAD verification"
325        );
326    }
327
328    #[test]
329    fn from_hex_key_roundtrip() {
330        let key = test_key();
331        let hex = key_to_hex(&key);
332
333        let enc = FieldEncryptor::from_hex_key(&hex).unwrap();
334        let encrypted = enc.encrypt_field("hex key test").unwrap();
335        let decrypted = enc.decrypt_field(&encrypted).unwrap();
336
337        assert_eq!(decrypted, "hex key test");
338    }
339
340    #[test]
341    fn from_hex_key_invalid_length() {
342        let result = FieldEncryptor::from_hex_key("0011aabb");
343        assert!(result.is_err());
344    }
345
346    #[test]
347    fn from_hex_key_invalid_chars() {
348        let bad = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz";
349        let result = FieldEncryptor::from_hex_key(bad);
350        assert!(result.is_err());
351    }
352
353    #[test]
354    fn env_key_present_and_valid() {
355        let key = test_key();
356        let hex = key_to_hex(&key);
357
358        // SAFETY: test-only, single-threaded access to env var
359        unsafe { std::env::set_var("ARBITER_STORAGE_ENCRYPTION_KEY", &hex) };
360        let result = FieldEncryptor::from_env();
361        unsafe { std::env::remove_var("ARBITER_STORAGE_ENCRYPTION_KEY") };
362
363        let enc = result.unwrap().expect("should return Some when key is set");
364        let ct = enc.encrypt_field("env test").unwrap();
365        let pt = enc.decrypt_field(&ct).unwrap();
366        assert_eq!(pt, "env test");
367    }
368
369    #[test]
370    fn empty_string_roundtrip() {
371        let key = test_key();
372        let enc = FieldEncryptor::new(&key);
373
374        let encrypted = enc.encrypt_field("").unwrap();
375        let decrypted = enc.decrypt_field(&encrypted).unwrap();
376        assert_eq!(decrypted, "");
377    }
378
379    #[test]
380    fn empty_vec_roundtrip() {
381        let key = test_key();
382        let enc = FieldEncryptor::new(&key);
383
384        let encrypted = enc.encrypt_string_vec(&[]).unwrap();
385        let decrypted = enc.decrypt_string_vec(&encrypted).unwrap();
386        assert!(decrypted.is_empty());
387    }
388}