1use crate::crypto::aes_gcm::{aes256_gcm_decrypt, aes256_gcm_encrypt};
41use crate::crypto::os_random;
42
/// Four-byte tag that opens every encrypted page frame (`"RDEP"` in ASCII).
pub const FRAME_MAGIC: [u8; 4] = [b'R', b'D', b'E', b'P'];

/// Frame-format version byte written right after [`FRAME_MAGIC`]; [`decrypt_page`]
/// accepts no other value.
pub const FRAME_VERSION: u8 = 0x01;

/// Bytes a frame adds on top of the plaintext:
/// magic (4) + version (1) + AES-GCM nonce (12) + GCM authentication tag (16).
pub const FRAME_OVERHEAD: usize = FRAME_MAGIC.len() + 1 + 12 + 16;
/// Failures reported by [`encrypt_page`] and [`decrypt_page`].
#[derive(Debug)]
pub enum PageEncryptionError {
    /// The frame does not begin with [`FRAME_MAGIC`].
    InvalidMagic,
    /// The frame's version byte (carried here) is not [`FRAME_VERSION`].
    UnsupportedVersion(u8),
    /// The frame is shorter than [`FRAME_OVERHEAD`] bytes.
    Truncated,
    /// AES-GCM decryption rejected the frame (wrong key, a different `page_id`, or
    /// modified bytes). Carries the decryptor's error text.
    KeyMismatch(String),
    /// The OS random source could not supply a nonce. Carries its error text.
    RandomFailure(String),
}

impl std::fmt::Display for PageEncryptionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => {
                write!(f, "encrypted page: bad magic — page not produced by encrypt_page")
            }
            Self::UnsupportedVersion(version) => {
                write!(f, "encrypted page: unsupported version {version}")
            }
            Self::Truncated => write!(f, "encrypted page: truncated frame"),
            Self::KeyMismatch(reason) => {
                write!(f, "encrypted page: key mismatch or tampering ({reason})")
            }
            Self::RandomFailure(reason) => {
                write!(f, "encrypted page: nonce generation failed ({reason})")
            }
        }
    }
}

/// Lets the error flow through `Box<dyn Error>` and `?`; there is no wrapped source error.
impl std::error::Error for PageEncryptionError {}
83
84pub fn encrypt_page(
87 key: &[u8; 32],
88 page_id: u64,
89 plaintext: &[u8],
90) -> Result<Vec<u8>, PageEncryptionError> {
91 let mut nonce = [0u8; 12];
92 os_random::fill_bytes(&mut nonce).map_err(PageEncryptionError::RandomFailure)?;
93 let aad = page_id.to_le_bytes();
94 let ciphertext = aes256_gcm_encrypt(key, &nonce, &aad, plaintext);
95
96 let mut out = Vec::with_capacity(FRAME_OVERHEAD + plaintext.len());
97 out.extend_from_slice(&FRAME_MAGIC);
98 out.push(FRAME_VERSION);
99 out.extend_from_slice(&nonce);
100 out.extend_from_slice(&ciphertext);
101 Ok(out)
102}
103
104pub fn decrypt_page(
110 key: &[u8; 32],
111 page_id: u64,
112 frame: &[u8],
113) -> Result<Vec<u8>, PageEncryptionError> {
114 if frame.len() < FRAME_OVERHEAD {
115 return Err(PageEncryptionError::Truncated);
116 }
117 if frame[0..4] != FRAME_MAGIC {
118 return Err(PageEncryptionError::InvalidMagic);
119 }
120 let version = frame[4];
121 if version != FRAME_VERSION {
122 return Err(PageEncryptionError::UnsupportedVersion(version));
123 }
124 let mut nonce = [0u8; 12];
125 nonce.copy_from_slice(&frame[5..17]);
126 let aad = page_id.to_le_bytes();
127 aes256_gcm_decrypt(key, &nonce, &aad, &frame[17..]).map_err(PageEncryptionError::KeyMismatch)
128}
129
130pub fn is_encrypted_frame(bytes: &[u8]) -> bool {
135 bytes.len() >= FRAME_OVERHEAD && bytes[0..4] == FRAME_MAGIC
136}
137
138pub fn parse_key(raw: &str) -> Result<[u8; 32], String> {
142 let trimmed = raw.trim();
143 if trimmed.len() == 64 && trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
145 let mut out = [0u8; 32];
146 for (i, byte) in out.iter_mut().enumerate() {
147 *byte = u8::from_str_radix(&trimmed[i * 2..i * 2 + 2], 16)
148 .map_err(|err| format!("invalid hex key byte {i}: {err}"))?;
149 }
150 return Ok(out);
151 }
152 let decoded = decode_base64(trimmed)
156 .map_err(|err| format!("key is neither 64-hex nor base64 (decode error: {err})"))?;
157 if decoded.len() != 32 {
158 return Err(format!(
159 "decoded key is {} bytes; AES-256-GCM requires exactly 32",
160 decoded.len()
161 ));
162 }
163 let mut out = [0u8; 32];
164 out.copy_from_slice(&decoded);
165 Ok(out)
166}
167
/// Decodes standard-alphabet base64 (RFC 4648 §4: `A-Z a-z 0-9 + /`).
///
/// ASCII whitespace is ignored anywhere in the input. Trailing `=` padding is optional, so
/// both padded and unpadded encodings are accepted.
///
/// # Errors
///
/// Returns a message when:
/// - a character outside the alphabet appears. The reported position is the byte offset in
///   `s` as given, with whitespace counted.
/// - alphabet characters follow `=` padding. Padding is only valid at the end, and skipping
///   it would silently splice two encodings into one byte string.
/// - the number of alphabet characters leaves a remainder of 1 modulo 4, which no encoder
///   produces.
fn decode_base64(s: &str) -> Result<Vec<u8>, String> {
    /// Maps one base64 alphabet character to its 6-bit value.
    fn val(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    // Translate every significant character to its 6-bit value first. Positions stay
    // relative to `s`, so error messages point at the caller's actual input.
    let mut sextets = Vec::with_capacity(s.len());
    let mut padding_seen = false;
    for (pos, c) in s.bytes().enumerate() {
        if c.is_ascii_whitespace() {
            continue;
        }
        if c == b'=' {
            padding_seen = true;
            continue;
        }
        let v = val(c).ok_or_else(|| format!("invalid base64 char at {pos}"))?;
        if padding_seen {
            return Err(format!("unexpected base64 data after padding at {pos}"));
        }
        sextets.push(v);
    }

    // Each group of 4 sextets carries 3 bytes; a 2- or 3-sextet tail carries 1 or 2 bytes.
    let groups = sextets.chunks_exact(4);
    let tail = groups.remainder();
    let mut out = Vec::with_capacity(sextets.len() * 3 / 4);
    for g in groups {
        out.push((g[0] << 2) | (g[1] >> 4));
        out.push(((g[1] & 0x0F) << 4) | (g[2] >> 2));
        out.push(((g[2] & 0x03) << 6) | g[3]);
    }
    match *tail {
        [] => {}
        [a, b] => out.push((a << 2) | (b >> 4)),
        [a, b, c] => {
            out.push((a << 2) | (b >> 4));
            out.push(((b & 0x0F) << 4) | (c >> 2));
        }
        _ => return Err(format!("invalid base64 length remainder {}", tail.len())),
    }
    Ok(out)
}
214
215pub fn key_from_env() -> Result<Option<[u8; 32]>, String> {
221 match crate::utils::env_with_file_fallback("RED_ENCRYPTION_KEY") {
222 Some(raw) => parse_key(&raw).map(Some),
223 None => Ok(None),
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test key: bytes `0, 1, ..., 31`.
    fn key() -> [u8; 32] {
        let mut k = [0u8; 32];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    /// Standard-alphabet base64 encoder that emits no `=` padding, so key-parsing tests
    /// need no external crate. Each `k`-byte input chunk (k <= 3) becomes `k + 1` chars.
    fn encode_base64_unpadded(raw: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut out = String::new();
        for chunk in raw.chunks(3) {
            // Pack the chunk big-endian into the low 24 bits of a u32.
            let n = chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (i, &b)| acc | (u32::from(b) << (16 - 8 * i)));
            for j in 0..=chunk.len() {
                out.push(char::from(ALPHABET[((n >> (18 - 6 * j)) & 0x3F) as usize]));
            }
        }
        out
    }

    #[test]
    fn round_trips_plaintext() {
        let plaintext = b"page bytes that will be encrypted";
        let frame = encrypt_page(&key(), 7, plaintext).unwrap();
        assert_eq!(frame.len(), FRAME_OVERHEAD + plaintext.len());
        assert!(is_encrypted_frame(&frame));
        let recovered = decrypt_page(&key(), 7, &frame).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trips_empty_plaintext() {
        let frame = encrypt_page(&key(), 3, b"").unwrap();
        assert_eq!(frame.len(), FRAME_OVERHEAD);
        assert!(decrypt_page(&key(), 3, &frame).unwrap().is_empty());
    }

    #[test]
    fn nonce_is_random_per_call() {
        let plaintext = b"same payload, different nonce";
        let f1 = encrypt_page(&key(), 1, plaintext).unwrap();
        let f2 = encrypt_page(&key(), 1, plaintext).unwrap();
        assert_ne!(f1, f2);
        // The nonce field itself (bytes 5..17) must differ, not just the ciphertext.
        assert_ne!(&f1[5..17], &f2[5..17]);
    }

    #[test]
    fn page_id_binding_catches_swapped_pages() {
        let plaintext = b"page 1 contents";
        let frame = encrypt_page(&key(), 1, plaintext).unwrap();
        let err = decrypt_page(&key(), 2, &frame).unwrap_err();
        assert!(
            matches!(err, PageEncryptionError::KeyMismatch(_)),
            "got {err:?}"
        );
    }

    #[test]
    fn wrong_key_fails_closed() {
        let plaintext = b"sensitive";
        let frame = encrypt_page(&key(), 5, plaintext).unwrap();
        let mut wrong = key();
        wrong[0] ^= 0xff;
        let err = decrypt_page(&wrong, 5, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::KeyMismatch(_)));
    }

    #[test]
    fn tampered_ciphertext_fails_closed() {
        let mut frame = encrypt_page(&key(), 9, b"integrity matters").unwrap();
        let last = frame.len() - 1;
        frame[last] ^= 0x01;
        let err = decrypt_page(&key(), 9, &frame).unwrap_err();
        assert!(
            matches!(err, PageEncryptionError::KeyMismatch(_)),
            "got {err:?}"
        );
    }

    #[test]
    fn bad_magic_returns_typed_error() {
        let mut frame = encrypt_page(&key(), 0, b"x").unwrap();
        frame[0] ^= 0xff;
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::InvalidMagic));
    }

    #[test]
    fn unsupported_version_is_typed() {
        let mut frame = encrypt_page(&key(), 0, b"x").unwrap();
        frame[4] = 0xFE;
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::UnsupportedVersion(0xFE)));
    }

    #[test]
    fn truncated_frame_is_typed() {
        let frame = vec![0u8; FRAME_OVERHEAD - 1];
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::Truncated));
    }

    #[test]
    fn is_encrypted_frame_rejects_short_or_foreign_input() {
        assert!(!is_encrypted_frame(&FRAME_MAGIC));
        assert!(!is_encrypted_frame(&[0u8; FRAME_OVERHEAD]));
        let mut padded = FRAME_MAGIC.to_vec();
        padded.resize(FRAME_OVERHEAD, 0);
        assert!(is_encrypted_frame(&padded));
    }

    #[test]
    fn parse_key_accepts_hex() {
        let hex = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20";
        let key = parse_key(hex).unwrap();
        assert_eq!(key[0], 0x01);
        assert_eq!(key[31], 0x20);
    }

    #[test]
    fn parse_key_accepts_hex_with_whitespace() {
        let hex = " 0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20\n";
        assert!(parse_key(hex).is_ok());
    }

    #[test]
    fn parse_key_rejects_wrong_length() {
        assert!(parse_key("ab").is_err());
        assert!(parse_key(&"zz".repeat(32)).is_err());
    }

    #[test]
    fn parse_key_accepts_base64() {
        let unpadded = encode_base64_unpadded(&[0xAB; 32]);
        assert_eq!(parse_key(&unpadded).unwrap(), [0xABu8; 32]);
        // Canonical encoders emit one `=` for a 32-byte payload; that must parse too.
        let padded = format!("{unpadded}=");
        assert_eq!(parse_key(&padded).unwrap(), [0xABu8; 32]);
    }

    #[test]
    fn parse_key_rejects_base64_of_wrong_length() {
        let err = parse_key(&encode_base64_unpadded(&[0x11; 16])).unwrap_err();
        assert!(err.contains("decoded key is 16 bytes"), "got {err}");
    }
}