Skip to main content

axess_core/session/
crypto.rs

//! AES-256-GCM encryption for session data at rest.
//!
//! Shared between the Valkey and SQLite session stores. Each encrypted payload
//! is a random 12-byte nonce followed by the ciphertext (which carries the GCM
//! authentication tag).
//!
//! # Key rotation
//!
//! [`SessionCrypto`] accepts an optional previous key. On decrypt, the current
//! key is tried first; if it fails, the previous key is attempted. Writes always
//! use the current key, so data encrypted under the previous key moves onto the
//! current key the next time it is encrypted (written) again; reading it alone
//! does not re-encrypt anything.
11
12use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
13use axess_rng::{SecureRng, SystemRng};
14use std::sync::Arc;
15
/// Byte length of the random AES-GCM nonce (96 bits) that prefixes every
/// encrypted payload.
const NONCE_LEN: usize = 96 / 8;
17
18/// Zeroize-on-drop wrapper for a 32-byte AES key.
19#[derive(Clone)]
20pub struct EncryptionKey(pub(crate) [u8; 32]);
21
22impl Drop for EncryptionKey {
23    fn drop(&mut self) {
24        zeroize::Zeroize::zeroize(&mut self.0);
25    }
26}
27
28/// AES-256-GCM encryption configuration with optional key rotation.
29#[derive(Clone)]
30pub struct SessionCrypto {
31    current: Arc<EncryptionKey>,
32    previous: Option<Arc<EncryptionKey>>,
33    rng: Arc<dyn SecureRng>,
34}
35
/// Opaque failure returned by [`SessionCrypto::encrypt`] and
/// [`SessionCrypto::decrypt`].
///
/// It carries no detail about which step failed (short payload, wrong key,
/// tampered data), so callers see only a single uniform error.
#[derive(Debug)]
pub struct CryptoError;

impl std::fmt::Display for CryptoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("session encryption/decryption error")
    }
}

impl std::error::Error for CryptoError {}
40
41impl SessionCrypto {
42    /// Create a new crypto config with the given 32-byte key. Uses
43    /// [`SystemRng`] for nonce generation; tests under DST should use
44    /// [`SessionCrypto::with_rng`] to inject a deterministic RNG.
45    pub fn new(key: [u8; 32]) -> Self {
46        Self {
47            current: Arc::new(EncryptionKey(key)),
48            previous: None,
49            rng: Arc::new(SystemRng),
50        }
51    }
52
53    /// Swap the RNG used for nonce generation. Production code keeps the
54    /// default [`SystemRng`]; DST suites pass a seeded `MockRng` to drive
55    /// nonce sequences deterministically.
56    pub fn with_rng(mut self, rng: Arc<dyn SecureRng>) -> Self {
57        self.rng = rng;
58        self
59    }
60
61    /// Enable key rotation: on decrypt, try the current key first, then this
62    /// previous key as a fallback.
63    pub fn with_previous_key(mut self, key: [u8; 32]) -> Self {
64        self.previous = Some(Arc::new(EncryptionKey(key)));
65        self
66    }
67
68    /// Encrypt plaintext with AES-256-GCM. Returns `nonce || ciphertext`.
69    ///
70    /// Each call generates a random 96-bit (12-byte) nonce. The collision
71    /// probability is negligible (~2^-96 per pair), which is safe for up to
72    /// ~2^32 encryptions per key, far beyond typical session store volumes.
73    /// The `aes-gcm` crate performs constant-time authentication tag comparison
74    /// internally during decryption.
75    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
76        let cipher = Aes256Gcm::new_from_slice(&self.current.0).map_err(|_| CryptoError)?;
77
78        let mut nonce_bytes = [0u8; NONCE_LEN];
79        self.rng.fill_bytes(&mut nonce_bytes);
80        let nonce = Nonce::from_slice(&nonce_bytes);
81
82        let ciphertext = cipher.encrypt(nonce, plaintext).map_err(|_| CryptoError)?;
83
84        let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
85        out.extend_from_slice(&nonce_bytes);
86        out.extend_from_slice(&ciphertext);
87        Ok(out)
88    }
89
90    /// Decrypt `nonce || ciphertext`. Tries the current key first, then the
91    /// previous key (if configured) for key rotation support.
92    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, CryptoError> {
93        if data.len() < NONCE_LEN {
94            return Err(CryptoError);
95        }
96
97        let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);
98        let nonce = Nonce::from_slice(nonce_bytes);
99
100        let cipher = Aes256Gcm::new_from_slice(&self.current.0).map_err(|_| CryptoError)?;
101
102        if let Ok(plaintext) = cipher.decrypt(nonce, ciphertext) {
103            return Ok(plaintext);
104        }
105
106        if let Some(prev) = &self.previous {
107            tracing::warn!(
108                "session decryption failed with current key; trying previous key (rotation fallback)"
109            );
110            let old_cipher = Aes256Gcm::new_from_slice(&prev.0).map_err(|_| CryptoError)?;
111
112            if let Ok(plaintext) = old_cipher.decrypt(nonce, ciphertext) {
113                tracing::debug!("session decrypted with previous (rotated) key");
114                return Ok(plaintext);
115            }
116            tracing::warn!(
117                "session decryption also failed with previous key; possible data corruption or key mismatch"
118            );
119        } else {
120            tracing::warn!(
121                "session decryption failed with current key and no previous key configured"
122            );
123        }
124
125        Err(CryptoError)
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the AES-GCM authentication tag that `aes-gcm` appends to
    /// every ciphertext.
    const TAG_LEN: usize = 16;

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let crypto = SessionCrypto::new([42u8; 32]);
        let plaintext = b"hello session data";
        let encrypted = crypto.encrypt(plaintext).unwrap();
        let decrypted = crypto.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// An empty plaintext still yields a nonce plus a tag and round-trips to
    /// an empty result.
    #[test]
    fn empty_plaintext_roundtrip() {
        let crypto = SessionCrypto::new([42u8; 32]);
        let encrypted = crypto.encrypt(b"").unwrap();
        assert_eq!(encrypted.len(), NONCE_LEN + TAG_LEN);
        assert!(crypto.decrypt(&encrypted).unwrap().is_empty());
    }

    /// Output layout is `nonce || ciphertext || tag`, and the ciphertext
    /// region does not contain the plaintext verbatim.
    #[test]
    fn encrypted_layout_is_nonce_ciphertext_tag() {
        let crypto = SessionCrypto::new([42u8; 32]);
        let plaintext = b"hello session data";
        let encrypted = crypto.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.len(), NONCE_LEN + plaintext.len() + TAG_LEN);
        assert_ne!(
            &encrypted[NONCE_LEN..NONCE_LEN + plaintext.len()],
            &plaintext[..]
        );
    }

    /// Two encryptions of the same plaintext must draw fresh nonces.
    #[test]
    fn encrypt_uses_fresh_nonce_each_call() {
        let crypto = SessionCrypto::new([42u8; 32]);
        let first = crypto.encrypt(b"same").unwrap();
        let second = crypto.encrypt(b"same").unwrap();
        assert_ne!(&first[..NONCE_LEN], &second[..NONCE_LEN]);
    }

    #[test]
    fn decrypt_wrong_key_fails() {
        let crypto1 = SessionCrypto::new([1u8; 32]);
        let crypto2 = SessionCrypto::new([2u8; 32]);
        let encrypted = crypto1.encrypt(b"secret").unwrap();
        assert!(crypto2.decrypt(&encrypted).is_err());
    }

    /// Flipping a bit in the nonce, the ciphertext body or the tag must be
    /// detected by GCM authentication.
    #[test]
    fn tampered_payload_fails() {
        let crypto = SessionCrypto::new([42u8; 32]);
        let encrypted = crypto.encrypt(b"tamper me").unwrap();
        for idx in [0, NONCE_LEN, encrypted.len() - 1] {
            let mut tampered = encrypted.clone();
            tampered[idx] ^= 0x01;
            assert!(
                crypto.decrypt(&tampered).is_err(),
                "bit flip at byte {idx} went undetected"
            );
        }
    }

    /// Dropping the final tag byte must make decryption fail.
    #[test]
    fn truncated_payload_fails() {
        let crypto = SessionCrypto::new([42u8; 32]);
        let encrypted = crypto.encrypt(b"truncate me").unwrap();
        assert!(crypto.decrypt(&encrypted[..encrypted.len() - 1]).is_err());
    }

    #[test]
    fn key_rotation_decrypt_with_previous() {
        let old_key = [1u8; 32];
        let new_key = [2u8; 32];

        let old_crypto = SessionCrypto::new(old_key);
        let encrypted = old_crypto.encrypt(b"rotated data").unwrap();

        // New config with rotation: current = new_key, previous = old_key.
        let new_crypto = SessionCrypto::new(new_key).with_previous_key(old_key);
        let decrypted = new_crypto.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, b"rotated data");
    }

    /// Data encrypted under a key that is neither current nor previous must
    /// still be rejected when rotation is configured.
    #[test]
    fn key_rotation_unknown_key_fails() {
        let stranger = SessionCrypto::new([5u8; 32]);
        let encrypted = stranger.encrypt(b"foreign").unwrap();
        let crypto = SessionCrypto::new([2u8; 32]).with_previous_key([1u8; 32]);
        assert!(crypto.decrypt(&encrypted).is_err());
    }

    #[test]
    fn short_data_fails() {
        let crypto = SessionCrypto::new([42u8; 32]);
        assert!(crypto.decrypt(&[0u8; 5]).is_err());
    }

    /// `decrypt` with a configured previous key still accepts data written
    /// under the current key. Pins against a mutation that consults only the
    /// previous key, which would make freshly written sessions unreadable.
    #[test]
    fn key_rotation_current_key_wins_when_valid() {
        let current = [9u8; 32];
        let prev = [1u8; 32];
        let crypto = SessionCrypto::new(current).with_previous_key(prev);
        let encrypted = crypto.encrypt(b"current-key data").unwrap();
        let decrypted = crypto.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, b"current-key data");
    }

    /// Boundary: a payload of exactly NONCE_LEN (12) bytes holds a nonce but
    /// no ciphertext or tag, so it can never authenticate. It must return an
    /// error rather than succeed or panic.
    #[test]
    fn decrypt_payload_at_nonce_length_boundary_fails() {
        let crypto = SessionCrypto::new([7u8; 32]);
        assert!(crypto.decrypt(&[0u8; NONCE_LEN]).is_err());
    }

    /// Boundary: one byte short of nonce + tag is still too short to be a
    /// valid payload.
    #[test]
    fn decrypt_payload_shorter_than_nonce_plus_tag_fails() {
        let crypto = SessionCrypto::new([7u8; 32]);
        assert!(crypto.decrypt(&[0u8; NONCE_LEN + TAG_LEN - 1]).is_err());
    }

    /// `with_previous_key` must store the supplied key in `previous`,
    /// not overwrite `current`. Pins the assignment direction.
    #[test]
    fn with_previous_key_does_not_replace_current_key() {
        let current = [3u8; 32];
        let prev = [4u8; 32];
        // Configure with prev, then encrypt; encryption uses CURRENT.
        let crypto = SessionCrypto::new(current).with_previous_key(prev);
        let encrypted = crypto.encrypt(b"under-current").unwrap();

        // A standalone instance keyed by `current` (no rotation) must
        // decrypt the ciphertext, proving encryption used `current`.
        let just_current = SessionCrypto::new(current);
        assert_eq!(just_current.decrypt(&encrypted).unwrap(), b"under-current");

        // And an instance keyed only by `prev` must FAIL to decrypt,
        // confirming `with_previous_key` did not move `prev` into the
        // current slot.
        let just_prev = SessionCrypto::new(prev);
        assert!(just_prev.decrypt(&encrypted).is_err());
    }
}