Skip to main content

nodedb_wal/
crypto.rs

1//! WAL payload encryption using AES-256-GCM.
2//!
3//! Design:
4//! - Header stays plaintext (needed for recovery scanning — magic, lsn, tenant_id)
5//! - Payload is encrypted before CRC computation
6//! - CRC covers the ciphertext (detects corruption of encrypted data)
7//! - Nonce = `[4-byte random epoch][8-byte LSN]` — epoch is generated per WAL
8//!   lifetime to prevent nonce reuse after snapshot restore or WAL truncation
9//! - Additional Authenticated Data (AAD) = header bytes (binds ciphertext to its header)
10//!
11//! On-disk format for encrypted payload:
12//! ```text
13//! [header(30B plaintext)] [ciphertext(payload_len - 16 bytes)] [auth_tag(16B)]
14//! ```
15//! `payload_len` counts the ciphertext plus the trailing 16-byte auth tag.
16
17use aes_gcm::Aes256Gcm;
18use aes_gcm::aead::{Aead, KeyInit};
19
20use crate::error::{Result, WalError};
21use crate::record::HEADER_SIZE;
22
23/// AES-256-GCM key with a random per-lifetime epoch for nonce disambiguation.
24///
25/// The epoch is generated randomly at construction time. Each WAL lifetime
26/// (process start, snapshot restore, segment creation) gets a fresh epoch,
27/// ensuring that nonces are never reused even if LSNs restart from 1.
28#[derive(Clone)]
29pub struct WalEncryptionKey {
30    cipher: Aes256Gcm,
31    /// Random 4-byte epoch: occupies the high 4 bytes of the 12-byte nonce.
32    /// Disambiguates nonces across WAL lifetimes with the same key.
33    epoch: [u8; 4],
34}
35
36impl WalEncryptionKey {
37    /// Create from a 32-byte key with a fresh random epoch.
38    pub fn from_bytes(key: &[u8; 32]) -> Self {
39        let mut epoch = [0u8; 4];
40        getrandom::fill(&mut epoch).expect("getrandom failed");
41        Self {
42            cipher: Aes256Gcm::new(key.into()),
43            epoch,
44        }
45    }
46
47    /// Load key from a file (must contain exactly 32 bytes).
48    pub fn from_file(path: &std::path::Path) -> Result<Self> {
49        let key_bytes = std::fs::read(path).map_err(WalError::Io)?;
50        if key_bytes.len() != 32 {
51            return Err(WalError::EncryptionError {
52                detail: format!(
53                    "encryption key must be exactly 32 bytes, got {}",
54                    key_bytes.len()
55                ),
56            });
57        }
58        let mut key = [0u8; 32];
59        key.copy_from_slice(&key_bytes);
60        Ok(Self::from_bytes(&key))
61    }
62
63    /// Encrypt a payload. Returns ciphertext + auth_tag (16 bytes appended).
64    ///
65    /// - `lsn`: used to derive a deterministic nonce
66    /// - `header_bytes`: used as AAD (additional authenticated data)
67    /// - `plaintext`: the payload to encrypt
68    pub fn encrypt(
69        &self,
70        lsn: u64,
71        header_bytes: &[u8; HEADER_SIZE],
72        plaintext: &[u8],
73    ) -> Result<Vec<u8>> {
74        let nonce = lsn_to_nonce(&self.epoch, lsn);
75        self.cipher
76            .encrypt(
77                &nonce,
78                aes_gcm::aead::Payload {
79                    msg: plaintext,
80                    aad: header_bytes,
81                },
82            )
83            .map_err(|_| WalError::EncryptionError {
84                detail: "AES-256-GCM encryption failed".into(),
85            })
86    }
87
88    /// The random epoch for this key instance.
89    pub fn epoch(&self) -> &[u8; 4] {
90        &self.epoch
91    }
92
93    /// Decrypt a payload. Input is ciphertext + auth_tag (16 bytes at end).
94    ///
95    /// - `epoch`: the epoch that was used during encryption (from the segment header)
96    /// - `lsn`: must match the LSN used during encryption
97    /// - `header_bytes`: must match the header used during encryption (AAD)
98    /// - `ciphertext`: the encrypted payload (includes 16-byte auth tag)
99    pub fn decrypt(
100        &self,
101        epoch: &[u8; 4],
102        lsn: u64,
103        header_bytes: &[u8; HEADER_SIZE],
104        ciphertext: &[u8],
105    ) -> Result<Vec<u8>> {
106        let nonce = lsn_to_nonce(epoch, lsn);
107        self.cipher
108            .decrypt(
109                &nonce,
110                aes_gcm::aead::Payload {
111                    msg: ciphertext,
112                    aad: header_bytes,
113                },
114            )
115            .map_err(|_| WalError::EncryptionError {
116                detail: "AES-256-GCM decryption failed (corrupted or wrong key)".into(),
117            })
118    }
119}
120
121/// Key ring supporting dual-key reads for seamless key rotation.
122///
123/// During rotation: new writes use the current key, reads try current
124/// then fall back to previous. Once all old data is re-encrypted,
125/// the previous key is removed.
126#[derive(Clone)]
127pub struct KeyRing {
128    current: WalEncryptionKey,
129    previous: Option<WalEncryptionKey>,
130}
131
132impl KeyRing {
133    /// Create a key ring with only the current key.
134    pub fn new(current: WalEncryptionKey) -> Self {
135        Self {
136            current,
137            previous: None,
138        }
139    }
140
141    /// Create a key ring with current + previous key (for rotation).
142    pub fn with_previous(current: WalEncryptionKey, previous: WalEncryptionKey) -> Self {
143        Self {
144            current,
145            previous: Some(previous),
146        }
147    }
148
149    /// Encrypt using the current key.
150    pub fn encrypt(
151        &self,
152        lsn: u64,
153        header_bytes: &[u8; HEADER_SIZE],
154        plaintext: &[u8],
155    ) -> Result<Vec<u8>> {
156        self.current.encrypt(lsn, header_bytes, plaintext)
157    }
158
159    /// Decrypt: try current key first, then previous (if set).
160    ///
161    /// `epoch` is the encryption epoch stored in the WAL segment header.
162    /// This enables seamless key rotation — old data encrypted with the
163    /// previous key can still be read while new data uses the current key.
164    pub fn decrypt(
165        &self,
166        epoch: &[u8; 4],
167        lsn: u64,
168        header_bytes: &[u8; HEADER_SIZE],
169        ciphertext: &[u8],
170    ) -> Result<Vec<u8>> {
171        match (
172            self.current.decrypt(epoch, lsn, header_bytes, ciphertext),
173            self.previous.as_ref(),
174        ) {
175            (Ok(plaintext), _) => Ok(plaintext),
176            (Err(_), Some(prev)) => prev.decrypt(epoch, lsn, header_bytes, ciphertext),
177            (Err(e), None) => Err(e),
178        }
179    }
180
181    /// Get the current key (for encryption operations).
182    pub fn current(&self) -> &WalEncryptionKey {
183        &self.current
184    }
185
186    /// Whether a previous key is present (rotation in progress).
187    pub fn has_previous(&self) -> bool {
188        self.previous.is_some()
189    }
190
191    /// Remove the previous key (rotation complete).
192    pub fn clear_previous(&mut self) {
193        self.previous = None;
194    }
195}
196
/// Length in bytes of the AES-256-GCM authentication tag (a 128-bit tag).
pub const AUTH_TAG_SIZE: usize = 128 / 8;
199
200/// Derive a 12-byte nonce from an epoch and LSN.
201///
202/// AES-256-GCM requires a 96-bit (12 byte) nonce that must never repeat
203/// for the same key. Layout: `[4-byte random epoch][8-byte LSN]`.
204/// The epoch is generated randomly per WAL lifetime, so even if LSNs
205/// restart from 1 after a snapshot restore, the nonces remain unique.
206fn lsn_to_nonce(epoch: &[u8; 4], lsn: u64) -> aes_gcm::Nonce<aes_gcm::aead::consts::U12> {
207    let mut nonce_bytes = [0u8; 12];
208    nonce_bytes[..4].copy_from_slice(epoch);
209    nonce_bytes[4..12].copy_from_slice(&lsn.to_le_bytes());
210    nonce_bytes.into()
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Key built from fixed bytes; the epoch is still random on every call.
    fn test_key() -> WalEncryptionKey {
        let key_bytes = [0x42u8; 32];
        WalEncryptionKey::from_bytes(&key_bytes)
    }

    /// All-zero header with `lsn` written little-endian at bytes 8..16.
    fn test_header(lsn: u64) -> [u8; HEADER_SIZE] {
        let mut header = [0u8; HEADER_SIZE];
        header[8..16].copy_from_slice(&lsn.to_le_bytes());
        header
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let header = test_header(1);
        let message = b"hello nodedb encryption";

        let sealed = key.encrypt(1, &header, message).unwrap();
        assert_eq!(sealed.len(), message.len() + AUTH_TAG_SIZE);
        assert_ne!(&sealed[..message.len()], message);

        let opened = key.decrypt(key.epoch(), 1, &header, &sealed).unwrap();
        assert_eq!(opened, message);
    }

    #[test]
    fn wrong_key_fails() {
        let writer = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let reader = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(1);

        let sealed = writer.encrypt(1, &header, b"secret").unwrap();
        let outcome = reader.decrypt(writer.epoch(), 1, &header, &sealed);
        assert!(outcome.is_err());
    }

    #[test]
    fn wrong_lsn_fails() {
        let key = test_key();
        let header = test_header(1);

        let sealed = key.encrypt(1, &header, b"secret").unwrap();
        // A different LSN yields a different nonce, so authentication fails.
        let outcome = key.decrypt(key.epoch(), 2, &header, &sealed);
        assert!(outcome.is_err());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let key = test_key();
        let header = test_header(1);

        let mut sealed = key.encrypt(1, &header, b"secret").unwrap();
        sealed[0] ^= 0xFF;
        let outcome = key.decrypt(key.epoch(), 1, &header, &sealed);
        assert!(outcome.is_err());
    }

    #[test]
    fn tampered_header_fails() {
        let key = test_key();
        let original_header = test_header(1);

        let sealed = key.encrypt(1, &original_header, b"secret").unwrap();

        // Change one byte of the AAD (header) before decrypting.
        let mut forged_header = original_header;
        forged_header[0] = 0xFF;
        let outcome = key.decrypt(key.epoch(), 1, &forged_header, &sealed);
        assert!(outcome.is_err());
    }

    #[test]
    fn empty_payload() {
        let key = test_key();
        let header = test_header(1);

        let sealed = key.encrypt(1, &header, b"").unwrap();
        // With no plaintext, the output is the auth tag alone.
        assert_eq!(sealed.len(), AUTH_TAG_SIZE);

        let opened = key.decrypt(key.epoch(), 1, &header, &sealed).unwrap();
        assert!(opened.is_empty());
    }

    #[test]
    fn different_lsns_produce_different_ciphertext() {
        let key = test_key();
        let message = b"same payload";

        let first = key.encrypt(1, &test_header(1), message).unwrap();
        let second = key.encrypt(2, &test_header(2), message).unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn same_lsn_different_wal_lifetimes_produce_different_ciphertext() {
        // Two WAL lifetimes are modelled by two WalEncryptionKey instances
        // built from identical key bytes: each draws its own random epoch.
        // Scenario: write LSN=1, wipe the WAL, restart with the same key,
        // write LSN=1 again. The resulting ciphertexts have to differ.
        let key_bytes = [0x42u8; 32];
        let first_lifetime = WalEncryptionKey::from_bytes(&key_bytes);
        let second_lifetime = WalEncryptionKey::from_bytes(&key_bytes);
        let header = test_header(1);
        let message = b"same plaintext in two wal lifetimes";

        let ct1 = first_lifetime.encrypt(1, &header, message).unwrap();
        let ct2 = second_lifetime.encrypt(1, &header, message).unwrap();

        // SPEC: distinct epochs must give distinct ciphertext even when the
        // key bytes and the LSN are the same.
        assert_ne!(
            ct1, ct2,
            "nonce reuse: same (key_bytes, lsn) must not produce identical ciphertext across WAL lifetimes"
        );
    }
}