1use aes_gcm::Aes256Gcm;
17use aes_gcm::aead::{Aead, KeyInit};
18
19use crate::error::{Result, WalError};
20use crate::record::HEADER_SIZE;
21
22#[derive(Clone)]
24pub struct WalEncryptionKey {
25 cipher: Aes256Gcm,
26}
27
28impl WalEncryptionKey {
29 pub fn from_bytes(key: &[u8; 32]) -> Self {
31 Self {
32 cipher: Aes256Gcm::new(key.into()),
33 }
34 }
35
36 pub fn from_file(path: &std::path::Path) -> Result<Self> {
38 let key_bytes = std::fs::read(path).map_err(WalError::Io)?;
39 if key_bytes.len() != 32 {
40 return Err(WalError::EncryptionError {
41 detail: format!(
42 "encryption key must be exactly 32 bytes, got {}",
43 key_bytes.len()
44 ),
45 });
46 }
47 let mut key = [0u8; 32];
48 key.copy_from_slice(&key_bytes);
49 Ok(Self::from_bytes(&key))
50 }
51
52 pub fn encrypt(
58 &self,
59 lsn: u64,
60 header_bytes: &[u8; HEADER_SIZE],
61 plaintext: &[u8],
62 ) -> Result<Vec<u8>> {
63 let nonce = lsn_to_nonce(lsn);
64 self.cipher
65 .encrypt(
66 &nonce,
67 aes_gcm::aead::Payload {
68 msg: plaintext,
69 aad: header_bytes,
70 },
71 )
72 .map_err(|_| WalError::EncryptionError {
73 detail: "AES-256-GCM encryption failed".into(),
74 })
75 }
76
77 pub fn decrypt(
83 &self,
84 lsn: u64,
85 header_bytes: &[u8; HEADER_SIZE],
86 ciphertext: &[u8],
87 ) -> Result<Vec<u8>> {
88 let nonce = lsn_to_nonce(lsn);
89 self.cipher
90 .decrypt(
91 &nonce,
92 aes_gcm::aead::Payload {
93 msg: ciphertext,
94 aad: header_bytes,
95 },
96 )
97 .map_err(|_| WalError::EncryptionError {
98 detail: "AES-256-GCM decryption failed (corrupted or wrong key)".into(),
99 })
100 }
101}
102
103#[derive(Clone)]
109pub struct KeyRing {
110 current: WalEncryptionKey,
111 previous: Option<WalEncryptionKey>,
112}
113
114impl KeyRing {
115 pub fn new(current: WalEncryptionKey) -> Self {
117 Self {
118 current,
119 previous: None,
120 }
121 }
122
123 pub fn with_previous(current: WalEncryptionKey, previous: WalEncryptionKey) -> Self {
125 Self {
126 current,
127 previous: Some(previous),
128 }
129 }
130
131 pub fn encrypt(
133 &self,
134 lsn: u64,
135 header_bytes: &[u8; HEADER_SIZE],
136 plaintext: &[u8],
137 ) -> Result<Vec<u8>> {
138 self.current.encrypt(lsn, header_bytes, plaintext)
139 }
140
141 pub fn decrypt(
146 &self,
147 lsn: u64,
148 header_bytes: &[u8; HEADER_SIZE],
149 ciphertext: &[u8],
150 ) -> Result<Vec<u8>> {
151 match self.current.decrypt(lsn, header_bytes, ciphertext) {
152 Ok(plaintext) => Ok(plaintext),
153 Err(_) if self.previous.is_some() => {
154 if let Some(prev) = self.previous.as_ref() {
156 prev.decrypt(lsn, header_bytes, ciphertext)
157 } else {
158 Err(crate::error::WalError::EncryptionError {
159 detail: "key rotation state inconsistent".into(),
160 })
161 }
162 }
163 Err(e) => Err(e),
164 }
165 }
166
167 pub fn current(&self) -> &WalEncryptionKey {
169 &self.current
170 }
171
172 pub fn has_previous(&self) -> bool {
174 self.previous.is_some()
175 }
176
177 pub fn clear_previous(&mut self) {
179 self.previous = None;
180 }
181}
182
/// Length in bytes of the AES-GCM authentication tag (128 bits).
///
/// Every ciphertext produced by `WalEncryptionKey::encrypt` is exactly this
/// many bytes longer than its plaintext.
pub const AUTH_TAG_SIZE: usize = 128 / 8;
185
186fn lsn_to_nonce(lsn: u64) -> aes_gcm::Nonce<aes_gcm::aead::consts::U12> {
192 let mut nonce_bytes = [0u8; 12];
193 nonce_bytes[..8].copy_from_slice(&lsn.to_le_bytes());
194 nonce_bytes.into()
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    fn test_key() -> WalEncryptionKey {
        WalEncryptionKey::from_bytes(&[0x42u8; 32])
    }

    /// Builds a zeroed header with `lsn` stored little-endian at bytes 8..16.
    // NOTE(review): assumes HEADER_SIZE >= 16; this is a stand-in, not the real
    // header serialization.
    fn test_header(lsn: u64) -> [u8; HEADER_SIZE] {
        let mut h = [0u8; HEADER_SIZE];
        h[8..16].copy_from_slice(&lsn.to_le_bytes());
        h
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = test_key();
        let header = test_header(1);
        let plaintext = b"hello nodedb encryption";

        let ciphertext = key.encrypt(1, &header, plaintext).unwrap();
        assert_ne!(&ciphertext[..plaintext.len()], plaintext);
        assert_eq!(ciphertext.len(), plaintext.len() + AUTH_TAG_SIZE);

        let decrypted = key.decrypt(1, &header, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_key_fails() {
        let key1 = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let key2 = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(1);

        let ciphertext = key1.encrypt(1, &header, b"secret").unwrap();
        assert!(key2.decrypt(1, &header, &ciphertext).is_err());
    }

    #[test]
    fn wrong_lsn_fails() {
        let key = test_key();
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        assert!(key.decrypt(2, &header, &ciphertext).is_err());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let key = test_key();
        let header = test_header(1);

        let mut ciphertext = key.encrypt(1, &header, b"secret").unwrap();
        ciphertext[0] ^= 0xFF;
        assert!(key.decrypt(1, &header, &ciphertext).is_err());
    }

    #[test]
    fn tampered_header_fails() {
        let key = test_key();
        let header1 = test_header(1);

        let ciphertext = key.encrypt(1, &header1, b"secret").unwrap();

        // The header is authenticated as AAD, so flipping any byte must fail.
        let mut header2 = header1;
        header2[0] = 0xFF;
        assert!(key.decrypt(1, &header2, &ciphertext).is_err());
    }

    #[test]
    fn empty_payload() {
        let key = test_key();
        let header = test_header(1);

        let ciphertext = key.encrypt(1, &header, b"").unwrap();
        assert_eq!(ciphertext.len(), AUTH_TAG_SIZE);

        let decrypted = key.decrypt(1, &header, &ciphertext).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn different_lsns_produce_different_ciphertext() {
        let key = test_key();
        let plaintext = b"same payload";

        let ct1 = key.encrypt(1, &test_header(1), plaintext).unwrap();
        let ct2 = key.encrypt(2, &test_header(2), plaintext).unwrap();
        assert_ne!(ct1, ct2);
    }

    #[test]
    fn from_file_requires_exactly_32_bytes() {
        let dir = std::env::temp_dir();
        let good = dir.join(format!("wal_key_ok_{}.bin", std::process::id()));
        let bad = dir.join(format!("wal_key_bad_{}.bin", std::process::id()));
        std::fs::write(&good, &[0x42u8; 32][..]).unwrap();
        std::fs::write(&bad, &[0x42u8; 33][..]).unwrap();

        let good_result = WalEncryptionKey::from_file(&good);
        let bad_result = WalEncryptionKey::from_file(&bad);
        let _ = std::fs::remove_file(&good);
        let _ = std::fs::remove_file(&bad);

        // The file holds the same bytes as `test_key()`, so it must open that key's output.
        let header = test_header(1);
        let ciphertext = test_key().encrypt(1, &header, b"from file").unwrap();
        let loaded = good_result.unwrap();
        assert_eq!(loaded.decrypt(1, &header, &ciphertext).unwrap(), b"from file");

        assert!(bad_result.is_err());
    }

    #[test]
    fn keyring_encrypts_with_current_key() {
        let old_key = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let new_key = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let ring = KeyRing::with_previous(new_key.clone(), old_key.clone());
        let header = test_header(7);

        let ciphertext = ring.encrypt(7, &header, b"rotated").unwrap();
        assert_eq!(new_key.decrypt(7, &header, &ciphertext).unwrap(), b"rotated");
        assert!(old_key.decrypt(7, &header, &ciphertext).is_err());
    }

    #[test]
    fn keyring_falls_back_to_previous_key() {
        let old_key = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let new_key = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(3);
        let ciphertext = old_key.encrypt(3, &header, b"pre-rotation").unwrap();

        let ring = KeyRing::with_previous(new_key, old_key);
        assert!(ring.has_previous());
        assert_eq!(ring.decrypt(3, &header, &ciphertext).unwrap(), b"pre-rotation");
    }

    #[test]
    fn keyring_without_previous_rejects_old_ciphertext() {
        let old_key = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let new_key = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(3);
        let ciphertext = old_key.encrypt(3, &header, b"pre-rotation").unwrap();

        let ring = KeyRing::new(new_key);
        assert!(!ring.has_previous());
        assert!(ring.decrypt(3, &header, &ciphertext).is_err());
    }

    #[test]
    fn keyring_clear_previous_disables_fallback() {
        let old_key = WalEncryptionKey::from_bytes(&[0x01; 32]);
        let new_key = WalEncryptionKey::from_bytes(&[0x02; 32]);
        let header = test_header(5);
        let ciphertext = old_key.encrypt(5, &header, b"legacy").unwrap();

        let mut ring = KeyRing::with_previous(new_key, old_key);
        assert!(ring.decrypt(5, &header, &ciphertext).is_ok());

        ring.clear_previous();
        assert!(!ring.has_previous());
        assert!(ring.decrypt(5, &header, &ciphertext).is_err());
    }

    #[test]
    fn keyring_fails_when_neither_key_matches() {
        let stranger = WalEncryptionKey::from_bytes(&[0x09; 32]);
        let header = test_header(4);
        let ciphertext = stranger.encrypt(4, &header, b"foreign").unwrap();

        let ring = KeyRing::with_previous(
            WalEncryptionKey::from_bytes(&[0x02; 32]),
            WalEncryptionKey::from_bytes(&[0x01; 32]),
        );
        assert!(ring.decrypt(4, &header, &ciphertext).is_err());
    }
}