axess_core/session/
crypto.rs1use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
13use axess_rng::{SecureRng, SystemRng};
14use std::sync::Arc;
15
/// Length in bytes of the nonce generated for every encryption and stored as
/// the prefix of each encrypted payload.
const NONCE_LEN: usize = 12;
17
/// A 32-byte symmetric key used for session encryption.
///
/// `Debug` is intentionally not derived here, so the raw key bytes are not
/// printed through formatting by accident.
#[derive(Clone)]
pub struct EncryptionKey(pub(crate) [u8; 32]);
21
22impl Drop for EncryptionKey {
23 fn drop(&mut self) {
24 zeroize::Zeroize::zeroize(&mut self.0);
25 }
26}
27
28#[derive(Clone)]
30pub struct SessionCrypto {
31 current: Arc<EncryptionKey>,
32 previous: Option<Arc<EncryptionKey>>,
33 rng: Arc<dyn SecureRng>,
34}
35
36#[derive(Debug, thiserror::Error)]
38#[error("session encryption/decryption error")]
39pub struct CryptoError;
40
41impl SessionCrypto {
42 pub fn new(key: [u8; 32]) -> Self {
46 Self {
47 current: Arc::new(EncryptionKey(key)),
48 previous: None,
49 rng: Arc::new(SystemRng),
50 }
51 }
52
53 pub fn with_rng(mut self, rng: Arc<dyn SecureRng>) -> Self {
57 self.rng = rng;
58 self
59 }
60
61 pub fn with_previous_key(mut self, key: [u8; 32]) -> Self {
64 self.previous = Some(Arc::new(EncryptionKey(key)));
65 self
66 }
67
68 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
76 let cipher = Aes256Gcm::new_from_slice(&self.current.0).map_err(|_| CryptoError)?;
77
78 let mut nonce_bytes = [0u8; NONCE_LEN];
79 self.rng.fill_bytes(&mut nonce_bytes);
80 let nonce = Nonce::from_slice(&nonce_bytes);
81
82 let ciphertext = cipher.encrypt(nonce, plaintext).map_err(|_| CryptoError)?;
83
84 let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
85 out.extend_from_slice(&nonce_bytes);
86 out.extend_from_slice(&ciphertext);
87 Ok(out)
88 }
89
90 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, CryptoError> {
93 if data.len() < NONCE_LEN {
94 return Err(CryptoError);
95 }
96
97 let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);
98 let nonce = Nonce::from_slice(nonce_bytes);
99
100 let cipher = Aes256Gcm::new_from_slice(&self.current.0).map_err(|_| CryptoError)?;
101
102 if let Ok(plaintext) = cipher.decrypt(nonce, ciphertext) {
103 return Ok(plaintext);
104 }
105
106 if let Some(prev) = &self.previous {
107 tracing::warn!(
108 "session decryption failed with current key; trying previous key (rotation fallback)"
109 );
110 let old_cipher = Aes256Gcm::new_from_slice(&prev.0).map_err(|_| CryptoError)?;
111
112 if let Ok(plaintext) = old_cipher.decrypt(nonce, ciphertext) {
113 tracing::debug!("session decrypted with previous (rotated) key");
114 return Ok(plaintext);
115 }
116 tracing::warn!(
117 "session decryption also failed with previous key; possible data corruption or key mismatch"
118 );
119 } else {
120 tracing::warn!(
121 "session decryption failed with current key and no previous key configured"
122 );
123 }
124
125 Err(CryptoError)
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypting and then decrypting with the same key returns the input.
    #[test]
    fn encrypt_decrypt_roundtrip() {
        let message = b"hello session data";
        let session = SessionCrypto::new([42u8; 32]);

        let sealed = session.encrypt(message).unwrap();
        let opened = session.decrypt(&sealed).unwrap();

        assert_eq!(opened, message);
    }

    /// A payload sealed under one key cannot be opened with a different key.
    #[test]
    fn decrypt_wrong_key_fails() {
        let sealed = SessionCrypto::new([1u8; 32]).encrypt(b"secret").unwrap();
        let other = SessionCrypto::new([2u8; 32]);

        assert!(other.decrypt(&sealed).is_err());
    }

    /// Data sealed under the old key still opens once that key has been
    /// registered as the previous key.
    #[test]
    fn key_rotation_decrypt_with_previous() {
        let (retired, active) = ([1u8; 32], [2u8; 32]);

        let sealed = SessionCrypto::new(retired)
            .encrypt(b"rotated data")
            .unwrap();

        let rotated = SessionCrypto::new(active).with_previous_key(retired);
        assert_eq!(rotated.decrypt(&sealed).unwrap(), b"rotated data");
    }

    /// Input shorter than a nonce is rejected.
    #[test]
    fn short_data_fails() {
        let too_short = [0u8; 5];
        let session = SessionCrypto::new([42u8; 32]);

        assert!(session.decrypt(&too_short).is_err());
    }

    /// With a previous key configured, data sealed under the current key
    /// still round-trips.
    #[test]
    fn key_rotation_current_key_wins_when_valid() {
        let session = SessionCrypto::new([9u8; 32]).with_previous_key([1u8; 32]);

        let sealed = session.encrypt(b"current-key data").unwrap();

        assert_eq!(session.decrypt(&sealed).unwrap(), b"current-key data");
    }

    /// Input that is exactly one nonce long, with nothing after it, is rejected.
    #[test]
    fn decrypt_payload_at_nonce_length_boundary_fails() {
        let nonce_only = [0u8; NONCE_LEN];
        let session = SessionCrypto::new([7u8; 32]);

        assert!(session.decrypt(&nonce_only).is_err());
    }

    /// Registering a previous key leaves encryption on the current key: the
    /// output opens with the current key alone and not with the previous key
    /// alone.
    #[test]
    fn with_previous_key_does_not_replace_current_key() {
        let (active, retired) = ([3u8; 32], [4u8; 32]);

        let sealed = SessionCrypto::new(active)
            .with_previous_key(retired)
            .encrypt(b"under-current")
            .unwrap();

        let active_only = SessionCrypto::new(active);
        assert_eq!(active_only.decrypt(&sealed).unwrap(), b"under-current");

        let retired_only = SessionCrypto::new(retired);
        assert!(retired_only.decrypt(&sealed).is_err());
    }
}