1use aes_gcm::Aes256Gcm;
18use aes_gcm::aead::{Aead, KeyInit};
19
20use crate::error::{Result, WalError};
21use crate::record::HEADER_SIZE;
22
23#[derive(Clone)]
29pub struct WalEncryptionKey {
30 cipher: Aes256Gcm,
31 epoch: [u8; 4],
34}
35
36impl WalEncryptionKey {
37 pub fn from_bytes(key: &[u8; 32]) -> Self {
39 let mut epoch = [0u8; 4];
40 getrandom::fill(&mut epoch).expect("getrandom failed");
41 Self {
42 cipher: Aes256Gcm::new(key.into()),
43 epoch,
44 }
45 }
46
47 pub fn from_file(path: &std::path::Path) -> Result<Self> {
49 let key_bytes = std::fs::read(path).map_err(WalError::Io)?;
50 if key_bytes.len() != 32 {
51 return Err(WalError::EncryptionError {
52 detail: format!(
53 "encryption key must be exactly 32 bytes, got {}",
54 key_bytes.len()
55 ),
56 });
57 }
58 let mut key = [0u8; 32];
59 key.copy_from_slice(&key_bytes);
60 Ok(Self::from_bytes(&key))
61 }
62
63 pub fn encrypt(
69 &self,
70 lsn: u64,
71 header_bytes: &[u8; HEADER_SIZE],
72 plaintext: &[u8],
73 ) -> Result<Vec<u8>> {
74 let nonce = lsn_to_nonce(&self.epoch, lsn);
75 self.cipher
76 .encrypt(
77 &nonce,
78 aes_gcm::aead::Payload {
79 msg: plaintext,
80 aad: header_bytes,
81 },
82 )
83 .map_err(|_| WalError::EncryptionError {
84 detail: "AES-256-GCM encryption failed".into(),
85 })
86 }
87
88 pub fn epoch(&self) -> &[u8; 4] {
90 &self.epoch
91 }
92
93 pub fn decrypt(
100 &self,
101 epoch: &[u8; 4],
102 lsn: u64,
103 header_bytes: &[u8; HEADER_SIZE],
104 ciphertext: &[u8],
105 ) -> Result<Vec<u8>> {
106 let nonce = lsn_to_nonce(epoch, lsn);
107 self.cipher
108 .decrypt(
109 &nonce,
110 aes_gcm::aead::Payload {
111 msg: ciphertext,
112 aad: header_bytes,
113 },
114 )
115 .map_err(|_| WalError::EncryptionError {
116 detail: "AES-256-GCM decryption failed (corrupted or wrong key)".into(),
117 })
118 }
119}
120
121#[derive(Clone)]
127pub struct KeyRing {
128 current: WalEncryptionKey,
129 previous: Option<WalEncryptionKey>,
130}
131
132impl KeyRing {
133 pub fn new(current: WalEncryptionKey) -> Self {
135 Self {
136 current,
137 previous: None,
138 }
139 }
140
141 pub fn with_previous(current: WalEncryptionKey, previous: WalEncryptionKey) -> Self {
143 Self {
144 current,
145 previous: Some(previous),
146 }
147 }
148
149 pub fn encrypt(
151 &self,
152 lsn: u64,
153 header_bytes: &[u8; HEADER_SIZE],
154 plaintext: &[u8],
155 ) -> Result<Vec<u8>> {
156 self.current.encrypt(lsn, header_bytes, plaintext)
157 }
158
159 pub fn decrypt(
165 &self,
166 epoch: &[u8; 4],
167 lsn: u64,
168 header_bytes: &[u8; HEADER_SIZE],
169 ciphertext: &[u8],
170 ) -> Result<Vec<u8>> {
171 match (
172 self.current.decrypt(epoch, lsn, header_bytes, ciphertext),
173 self.previous.as_ref(),
174 ) {
175 (Ok(plaintext), _) => Ok(plaintext),
176 (Err(_), Some(prev)) => prev.decrypt(epoch, lsn, header_bytes, ciphertext),
177 (Err(e), None) => Err(e),
178 }
179 }
180
181 pub fn current(&self) -> &WalEncryptionKey {
183 &self.current
184 }
185
186 pub fn has_previous(&self) -> bool {
188 self.previous.is_some()
189 }
190
191 pub fn clear_previous(&mut self) {
193 self.previous = None;
194 }
195}
196
/// Size in bytes of the AES-256-GCM authentication tag appended to every
/// encrypted payload (a full 128-bit tag).
pub const AUTH_TAG_SIZE: usize = 128 / 8;
199
200fn lsn_to_nonce(epoch: &[u8; 4], lsn: u64) -> aes_gcm::Nonce<aes_gcm::aead::consts::U12> {
207 let mut nonce_bytes = [0u8; 12];
208 nonce_bytes[..4].copy_from_slice(epoch);
209 nonce_bytes[4..12].copy_from_slice(&lsn.to_le_bytes());
210 nonce_bytes.into()
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    fn test_key() -> WalEncryptionKey {
        WalEncryptionKey::from_bytes(&[0x42u8; 32])
    }

    /// Header filled with zeros except for `lsn` (little-endian) at bytes 8..16.
    fn test_header(lsn: u64) -> [u8; HEADER_SIZE] {
        let mut h = [0u8; HEADER_SIZE];
        h[8..16].copy_from_slice(&lsn.to_le_bytes());
        h
    }

    /// Writes `bytes` to a per-process, per-test temp file and returns its path.
    fn write_temp_key(tag: &str, bytes: &[u8]) -> std::path::PathBuf {
        let path = std::env::temp_dir()
            .join(format!("wal_enc_key_{}_{}", std::process::id(), tag));
        std::fs::write(&path, bytes).unwrap();
        path
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let epoch = *key.epoch();
        let header = test_header(1);
        let plaintext = b"hello nodedb encryption";

        let ciphertext = key.encrypt(1, &header, plaintext).unwrap();
        assert_ne!(&ciphertext[..plaintext.len()], plaintext);
        assert_eq!(ciphertext.len(), plaintext.len() + AUTH_TAG_SIZE);

        let decrypted = key.decrypt(&epoch, 1, &header, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_key_fails() {
        let key1 = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let epoch1 = *key1.epoch();
        let key2 = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(1);

        let ciphertext = key1.encrypt(1, &header, b"secret").unwrap();
        assert!(key2.decrypt(&epoch1, 1, &header, &ciphertext).is_err());
    }

    #[test]
    fn wrong_lsn_fails() {
        let key = test_key();
        let epoch = *key.epoch();
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        assert!(key.decrypt(&epoch, 2, &header, &ciphertext).is_err());
    }

    #[test]
    fn wrong_epoch_fails() {
        let key = test_key();
        let mut epoch = *key.epoch();
        epoch[0] ^= 0xFF;
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        assert!(key.decrypt(&epoch, 1, &header, &ciphertext).is_err());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let key = test_key();
        let epoch = *key.epoch();
        let header = test_header(1);

        let mut ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        ciphertext[0] ^= 0xFF;
        assert!(key.decrypt(&epoch, 1, &header, &ciphertext).is_err());
    }

    #[test]
    fn tampered_header_fails() {
        let key = test_key();
        let epoch = *key.epoch();
        let header1 = test_header(1);

        let ciphertext = key.encrypt(1, &header1, b"secret").unwrap();

        // The header is authenticated as AAD, so flipping any byte must be caught.
        let mut header2 = header1;
        header2[0] = 0xFF;
        assert!(key.decrypt(&epoch, 1, &header2, &ciphertext).is_err());
    }

    #[test]
    fn empty_payload() {
        let key = test_key();
        let epoch = *key.epoch();
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"").unwrap();
        assert_eq!(ciphertext.len(), AUTH_TAG_SIZE);

        let decrypted = key.decrypt(&epoch, 1, &header, &ciphertext).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn different_lsns_produce_different_ciphertext() {
        let key = test_key();
        let plaintext = b"same payload";

        let ct1 = key.encrypt(1, &test_header(1), plaintext).unwrap();
        let ct2 = key.encrypt(2, &test_header(2), plaintext).unwrap();
        assert_ne!(ct1, ct2);
    }

    #[test]
    fn same_lsn_different_wal_lifetimes_produce_different_ciphertext() {
        // Two instances of the same key bytes model two WAL lifetimes; their
        // random epochs collide with probability 2^-32, which is negligible here.
        let key_bytes = [0x42u8; 32];
        let key1 = WalEncryptionKey::from_bytes(&key_bytes);
        let key2 = WalEncryptionKey::from_bytes(&key_bytes);
        let header = test_header(1);
        let pt = b"same plaintext in two wal lifetimes";

        let ct1 = key1.encrypt(1, &header, pt).unwrap();
        let ct2 = key2.encrypt(1, &header, pt).unwrap();

        assert_ne!(
            ct1, ct2,
            "nonce reuse: same (key_bytes, lsn) must not produce identical ciphertext across WAL lifetimes"
        );
    }

    #[test]
    fn keyring_encrypts_with_current_key() {
        let ring = KeyRing::new(test_key());
        assert!(!ring.has_previous());
        let epoch = *ring.current().epoch();
        let header = test_header(3);

        let ct = ring.encrypt(3, &header, b"fresh record").unwrap();
        assert_eq!(
            ring.current().decrypt(&epoch, 3, &header, &ct).unwrap(),
            b"fresh record"
        );
        assert_eq!(ring.decrypt(&epoch, 3, &header, &ct).unwrap(), b"fresh record");
    }

    #[test]
    fn keyring_falls_back_to_previous_key() {
        let old = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let old_epoch = *old.epoch();
        let header = test_header(7);
        let ct = old.encrypt(7, &header, b"pre-rotation").unwrap();

        let ring = KeyRing::with_previous(WalEncryptionKey::from_bytes(&[0x02; 32]), old);
        assert!(ring.has_previous());
        assert_eq!(
            ring.decrypt(&old_epoch, 7, &header, &ct).unwrap(),
            b"pre-rotation"
        );
    }

    #[test]
    fn keyring_rejects_old_ciphertext_after_clear_previous() {
        let old = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let old_epoch = *old.epoch();
        let header = test_header(7);
        let ct = old.encrypt(7, &header, b"pre-rotation").unwrap();

        let mut ring = KeyRing::with_previous(WalEncryptionKey::from_bytes(&[0x02; 32]), old);
        ring.clear_previous();
        assert!(!ring.has_previous());
        assert!(ring.decrypt(&old_epoch, 7, &header, &ct).is_err());
    }

    #[test]
    fn from_file_accepts_32_byte_key() {
        let path = write_temp_key("ok", &[0x42u8; 32]);
        let loaded = WalEncryptionKey::from_file(&path);
        std::fs::remove_file(&path).unwrap();
        let Ok(loaded) = loaded else {
            panic!("a 32-byte key file must load");
        };

        // Same key bytes as `test_key`, so it must decrypt that key's output.
        let reference = test_key();
        let header = test_header(5);
        let ct = reference.encrypt(5, &header, b"from disk").unwrap();
        assert_eq!(
            loaded.decrypt(reference.epoch(), 5, &header, &ct).unwrap(),
            b"from disk"
        );
    }

    #[test]
    fn from_file_rejects_wrong_length() {
        // 33 bytes models the common "key file with trailing newline" mistake.
        let path = write_temp_key("long", &[0x42u8; 33]);
        let result = WalEncryptionKey::from_file(&path);
        std::fs::remove_file(&path).unwrap();
        assert!(matches!(result, Err(WalError::EncryptionError { .. })));
    }
}