Skip to main content

nodedb_wal/
crypto.rs

1//! WAL payload encryption using AES-256-GCM.
2//!
3//! Design:
4//! - Header stays plaintext (needed for recovery scanning — magic, lsn, tenant_id)
5//! - Payload is encrypted before CRC computation
6//! - CRC covers the ciphertext (detects corruption of encrypted data)
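//! - Nonce derived from LSN (deterministic — nothing extra is stored; recovery re-derives it from
//!   the record's LSN). This requires that no two records encrypted under one key share an LSN.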
8//! - Additional Authenticated Data (AAD) = header bytes (binds ciphertext to its header)
9//!
//! On-disk format for an encrypted record:
//! ```text
//! [header (30B, plaintext)] [ciphertext + auth_tag (payload_len bytes; last 16B are the tag)]
//! ```
//! `payload_len` therefore includes the 16-byte auth tag (see `AUTH_TAG_SIZE`).
15
16use aes_gcm::Aes256Gcm;
17use aes_gcm::aead::{Aead, KeyInit};
18
19use crate::error::{Result, WalError};
20use crate::record::HEADER_SIZE;
21
22/// AES-256-GCM key: exactly 32 bytes.
23#[derive(Clone)]
24pub struct WalEncryptionKey {
25    cipher: Aes256Gcm,
26}
27
28impl WalEncryptionKey {
29    /// Create from a 32-byte key.
30    pub fn from_bytes(key: &[u8; 32]) -> Self {
31        Self {
32            cipher: Aes256Gcm::new(key.into()),
33        }
34    }
35
36    /// Load key from a file (must contain exactly 32 bytes).
37    pub fn from_file(path: &std::path::Path) -> Result<Self> {
38        let key_bytes = std::fs::read(path).map_err(WalError::Io)?;
39        if key_bytes.len() != 32 {
40            return Err(WalError::EncryptionError {
41                detail: format!(
42                    "encryption key must be exactly 32 bytes, got {}",
43                    key_bytes.len()
44                ),
45            });
46        }
47        let mut key = [0u8; 32];
48        key.copy_from_slice(&key_bytes);
49        Ok(Self::from_bytes(&key))
50    }
51
52    /// Encrypt a payload. Returns ciphertext + auth_tag (16 bytes appended).
53    ///
54    /// - `lsn`: used to derive a deterministic nonce
55    /// - `header_bytes`: used as AAD (additional authenticated data)
56    /// - `plaintext`: the payload to encrypt
57    pub fn encrypt(
58        &self,
59        lsn: u64,
60        header_bytes: &[u8; HEADER_SIZE],
61        plaintext: &[u8],
62    ) -> Result<Vec<u8>> {
63        let nonce = lsn_to_nonce(lsn);
64        self.cipher
65            .encrypt(
66                &nonce,
67                aes_gcm::aead::Payload {
68                    msg: plaintext,
69                    aad: header_bytes,
70                },
71            )
72            .map_err(|_| WalError::EncryptionError {
73                detail: "AES-256-GCM encryption failed".into(),
74            })
75    }
76
77    /// Decrypt a payload. Input is ciphertext + auth_tag (16 bytes at end).
78    ///
79    /// - `lsn`: must match the LSN used during encryption
80    /// - `header_bytes`: must match the header used during encryption (AAD)
81    /// - `ciphertext`: the encrypted payload (includes 16-byte auth tag)
82    pub fn decrypt(
83        &self,
84        lsn: u64,
85        header_bytes: &[u8; HEADER_SIZE],
86        ciphertext: &[u8],
87    ) -> Result<Vec<u8>> {
88        let nonce = lsn_to_nonce(lsn);
89        self.cipher
90            .decrypt(
91                &nonce,
92                aes_gcm::aead::Payload {
93                    msg: ciphertext,
94                    aad: header_bytes,
95                },
96            )
97            .map_err(|_| WalError::EncryptionError {
98                detail: "AES-256-GCM decryption failed (corrupted or wrong key)".into(),
99            })
100    }
101}
102
103/// Key ring supporting dual-key reads for seamless key rotation.
104///
105/// During rotation: new writes use the current key, reads try current
106/// then fall back to previous. Once all old data is re-encrypted,
107/// the previous key is removed.
108#[derive(Clone)]
109pub struct KeyRing {
110    current: WalEncryptionKey,
111    previous: Option<WalEncryptionKey>,
112}
113
114impl KeyRing {
115    /// Create a key ring with only the current key.
116    pub fn new(current: WalEncryptionKey) -> Self {
117        Self {
118            current,
119            previous: None,
120        }
121    }
122
123    /// Create a key ring with current + previous key (for rotation).
124    pub fn with_previous(current: WalEncryptionKey, previous: WalEncryptionKey) -> Self {
125        Self {
126            current,
127            previous: Some(previous),
128        }
129    }
130
131    /// Encrypt using the current key.
132    pub fn encrypt(
133        &self,
134        lsn: u64,
135        header_bytes: &[u8; HEADER_SIZE],
136        plaintext: &[u8],
137    ) -> Result<Vec<u8>> {
138        self.current.encrypt(lsn, header_bytes, plaintext)
139    }
140
141    /// Decrypt: try current key first, then previous (if set).
142    ///
143    /// This enables seamless key rotation — old data encrypted with the
144    /// previous key can still be read while new data uses the current key.
145    pub fn decrypt(
146        &self,
147        lsn: u64,
148        header_bytes: &[u8; HEADER_SIZE],
149        ciphertext: &[u8],
150    ) -> Result<Vec<u8>> {
151        match self.current.decrypt(lsn, header_bytes, ciphertext) {
152            Ok(plaintext) => Ok(plaintext),
153            Err(_) if self.previous.is_some() => {
154                // Current key failed — try previous key.
155                if let Some(prev) = self.previous.as_ref() {
156                    prev.decrypt(lsn, header_bytes, ciphertext)
157                } else {
158                    Err(crate::error::WalError::EncryptionError {
159                        detail: "key rotation state inconsistent".into(),
160                    })
161                }
162            }
163            Err(e) => Err(e),
164        }
165    }
166
167    /// Get the current key (for encryption operations).
168    pub fn current(&self) -> &WalEncryptionKey {
169        &self.current
170    }
171
172    /// Whether a previous key is present (rotation in progress).
173    pub fn has_previous(&self) -> bool {
174        self.previous.is_some()
175    }
176
177    /// Remove the previous key (rotation complete).
178    pub fn clear_previous(&mut self) {
179        self.previous = None;
180    }
181}
182
/// Length in bytes of the AES-256-GCM authentication tag appended to every
/// encrypted payload (the full 128-bit GCM tag).
pub const AUTH_TAG_SIZE: usize = 128 / 8;
185
186/// Derive a 12-byte nonce from an LSN.
187///
188/// AES-256-GCM requires a 96-bit (12 byte) nonce. Since LSNs are monotonically
189/// increasing and globally unique, they make ideal deterministic nonces.
190/// We zero-pad the 8-byte LSN to 12 bytes.
191fn lsn_to_nonce(lsn: u64) -> aes_gcm::Nonce<aes_gcm::aead::consts::U12> {
192    let mut nonce_bytes = [0u8; 12];
193    nonce_bytes[..8].copy_from_slice(&lsn.to_le_bytes());
194    nonce_bytes.into()
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    fn test_key() -> WalEncryptionKey {
        WalEncryptionKey::from_bytes(&[0x42u8; 32])
    }

    // Builds a zeroed header with the LSN written at bytes 8..16.
    // NOTE(review): assumes HEADER_SIZE >= 16; the real header layout lives in `record`.
    fn test_header(lsn: u64) -> [u8; HEADER_SIZE] {
        let mut h = [0u8; HEADER_SIZE];
        h[8..16].copy_from_slice(&lsn.to_le_bytes());
        h
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let header = test_header(1);
        let plaintext = b"hello nodedb encryption";

        let ciphertext = key.encrypt(1, &header, plaintext).unwrap();
        assert_ne!(&ciphertext[..plaintext.len()], plaintext);
        assert_eq!(ciphertext.len(), plaintext.len() + AUTH_TAG_SIZE);

        let decrypted = key.decrypt(1, &header, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_key_fails() {
        let key1 = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let key2 = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(1);

        let ciphertext = key1.encrypt(1, &header, b"secret").unwrap();
        assert!(key2.decrypt(1, &header, &ciphertext).is_err());
    }

    #[test]
    fn wrong_lsn_fails() {
        let key = test_key();
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        // Different LSN = different nonce = decryption fails.
        assert!(key.decrypt(2, &header, &ciphertext).is_err());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let key = test_key();
        let header = test_header(1);

        let mut ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        ciphertext[0] ^= 0xFF;
        assert!(key.decrypt(1, &header, &ciphertext).is_err());
    }

    #[test]
    fn tampered_header_fails() {
        let key = test_key();
        let header1 = test_header(1);

        let ciphertext = key.encrypt(1, &header1, b"secret").unwrap();

        // Tamper the AAD (header).
        let mut header2 = header1;
        header2[0] = 0xFF;
        assert!(key.decrypt(1, &header2, &ciphertext).is_err());
    }

    #[test]
    fn empty_payload() {
        let key = test_key();
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"").unwrap();
        assert_eq!(ciphertext.len(), AUTH_TAG_SIZE); // Just the tag.

        let decrypted = key.decrypt(1, &header, &ciphertext).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn short_ciphertext_fails() {
        // Anything shorter than the auth tag can never authenticate.
        let key = test_key();
        let header = test_header(1);
        assert!(key.decrypt(1, &header, &[0u8; AUTH_TAG_SIZE - 1]).is_err());
    }

    #[test]
    fn different_lsns_produce_different_ciphertext() {
        let key = test_key();
        let plaintext = b"same payload";

        let ct1 = key.encrypt(1, &test_header(1), plaintext).unwrap();
        let ct2 = key.encrypt(2, &test_header(2), plaintext).unwrap();
        assert_ne!(ct1, ct2);
    }

    #[test]
    fn keyring_encrypts_with_current_key() {
        let old = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let new = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(3);
        let ring = KeyRing::with_previous(new.clone(), old.clone());

        let ciphertext = ring.encrypt(3, &header, b"fresh").unwrap();
        assert_eq!(new.decrypt(3, &header, &ciphertext).unwrap(), b"fresh");
        assert!(old.decrypt(3, &header, &ciphertext).is_err());
    }

    #[test]
    fn keyring_falls_back_to_previous_key() {
        let old = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let new = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(7);
        // Record written before the rotation, under the old key.
        let ciphertext = old.encrypt(7, &header, b"legacy").unwrap();

        let ring = KeyRing::with_previous(new, old);
        assert!(ring.has_previous());
        assert_eq!(ring.decrypt(7, &header, &ciphertext).unwrap(), b"legacy");
    }

    #[test]
    fn keyring_without_previous_rejects_foreign_ciphertext() {
        let other = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let header = test_header(5);
        let ciphertext = other.encrypt(5, &header, b"x").unwrap();

        let ring = KeyRing::new(test_key());
        assert!(!ring.has_previous());
        assert!(ring.decrypt(5, &header, &ciphertext).is_err());
    }

    #[test]
    fn keyring_clear_previous_drops_fallback() {
        let old = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let header = test_header(9);
        let ciphertext = old.encrypt(9, &header, b"legacy").unwrap();

        let mut ring = KeyRing::with_previous(test_key(), old);
        assert!(ring.decrypt(9, &header, &ciphertext).is_ok());

        ring.clear_previous();
        assert!(!ring.has_previous());
        assert!(ring.decrypt(9, &header, &ciphertext).is_err());
    }
}