use crate::crypto::aes_gcm::{aes256_gcm_decrypt, aes256_gcm_encrypt};
use crate::crypto::os_random;
/// Four-byte tag written at the start of every frame produced by `encrypt_page`.
pub const FRAME_MAGIC: [u8; 4] = [b'R', b'D', b'E', b'P'];
/// Frame format version written by `encrypt_page`; `decrypt_page` rejects any other value.
pub const FRAME_VERSION: u8 = 1;
/// Bytes a frame adds on top of the plaintext: magic (4) + version (1) + nonce (12) +
/// the 16 extra bytes the AES-GCM output carries beyond the plaintext (its tag).
pub const FRAME_OVERHEAD: usize = FRAME_MAGIC.len() + 1 + 12 + 16;
/// Failure modes of `encrypt_page`, `decrypt_page`, and frame parsing.
#[derive(Debug)]
pub enum PageEncryptionError {
    /// The frame does not begin with `FRAME_MAGIC`.
    InvalidMagic,
    /// The frame's version byte (carried here) is not `FRAME_VERSION`.
    UnsupportedVersion(u8),
    /// The input is shorter than `FRAME_OVERHEAD` bytes.
    Truncated,
    /// AES-GCM decryption was rejected. The key, the page id, or the frame bytes differ from
    /// what was used at encryption time. Carries the cipher's error text.
    KeyMismatch(String),
    /// The OS random source failed while producing a nonce; carries its error text.
    RandomFailure(String),
}

impl std::fmt::Display for PageEncryptionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PageEncryptionError::InvalidMagic => write!(
                f,
                "encrypted page: bad magic — page not produced by encrypt_page"
            ),
            PageEncryptionError::UnsupportedVersion(v) => {
                write!(f, "encrypted page: unsupported version {v}")
            }
            PageEncryptionError::Truncated => write!(f, "encrypted page: truncated frame"),
            PageEncryptionError::KeyMismatch(detail) => {
                write!(f, "encrypted page: key mismatch or tampering ({detail})")
            }
            PageEncryptionError::RandomFailure(detail) => {
                write!(f, "encrypted page: nonce generation failed ({detail})")
            }
        }
    }
}

// Marker impl so the error can be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for PageEncryptionError {}
/// Encrypts one page with AES-256-GCM and wraps the result in a self-describing frame.
///
/// Frame layout: `FRAME_MAGIC` (4 bytes), then `FRAME_VERSION` (1 byte), then the nonce
/// (12 bytes), then the output of `aes256_gcm_encrypt`.
///
/// The little-endian `page_id` is passed as associated data, so a frame only decrypts under
/// the page id it was written for. The magic/version header is *not* part of the associated
/// data.
///
/// A fresh 12-byte nonce is drawn from `os_random::fill_bytes` on every call.
/// NOTE(review): with random 96-bit nonces, NIST SP 800-38D limits how many messages may be
/// encrypted under one key (2^32); key-rotation policy is outside this function.
///
/// # Errors
///
/// Returns `PageEncryptionError::RandomFailure` if the random source fails.
pub fn encrypt_page(
    key: &[u8; 32],
    page_id: u64,
    plaintext: &[u8],
) -> Result<Vec<u8>, PageEncryptionError> {
    let mut iv = [0u8; 12];
    if let Err(detail) = os_random::fill_bytes(&mut iv) {
        return Err(PageEncryptionError::RandomFailure(detail));
    }

    let page_binding = page_id.to_le_bytes();
    let sealed = aes256_gcm_encrypt(key, &iv, &page_binding, plaintext);

    // Capacity is exact when the cipher output is plaintext plus a 16-byte tag.
    let mut frame = Vec::with_capacity(FRAME_OVERHEAD + plaintext.len());
    frame.extend_from_slice(&FRAME_MAGIC);
    frame.extend_from_slice(&[FRAME_VERSION]);
    frame.extend_from_slice(&iv);
    frame.extend_from_slice(&sealed);
    Ok(frame)
}
/// Parses a frame produced by `encrypt_page` and returns the decrypted page bytes.
///
/// `page_id` must match the id used at encryption time. It is re-supplied as associated data,
/// so a frame copied to a different page slot fails authentication.
///
/// # Errors
///
/// Checks are made in this order:
/// - `Truncated` if `frame` is shorter than `FRAME_OVERHEAD`;
/// - `InvalidMagic` if it does not start with `FRAME_MAGIC`;
/// - `UnsupportedVersion` if the version byte is not `FRAME_VERSION`;
/// - `KeyMismatch` if AES-GCM decryption fails (wrong key, wrong page id, or modified bytes).
pub fn decrypt_page(
    key: &[u8; 32],
    page_id: u64,
    frame: &[u8],
) -> Result<Vec<u8>, PageEncryptionError> {
    // The length check comes first so the fixed-size splits below cannot panic.
    if frame.len() < FRAME_OVERHEAD {
        return Err(PageEncryptionError::Truncated);
    }

    let (magic, after_magic) = frame.split_at(FRAME_MAGIC.len());
    if *magic != FRAME_MAGIC {
        return Err(PageEncryptionError::InvalidMagic);
    }

    let version = after_magic[0];
    if version != FRAME_VERSION {
        return Err(PageEncryptionError::UnsupportedVersion(version));
    }

    let (nonce_bytes, sealed) = after_magic[1..].split_at(12);
    let mut nonce = [0u8; 12];
    nonce.copy_from_slice(nonce_bytes);

    aes256_gcm_decrypt(key, &nonce, &page_id.to_le_bytes(), sealed)
        .map_err(PageEncryptionError::KeyMismatch)
}
/// Cheap header sniff: `true` when `bytes` begins with `FRAME_MAGIC` and is at least
/// `FRAME_OVERHEAD` bytes long.
///
/// This does not check the version byte or authenticate anything. Use `decrypt_page` for that.
pub fn is_encrypted_frame(bytes: &[u8]) -> bool {
    let has_magic = bytes.starts_with(&FRAME_MAGIC);
    let long_enough = bytes.len() >= FRAME_OVERHEAD;
    has_magic && long_enough
}
/// Parses a 256-bit AES key from text.
///
/// Surrounding whitespace is trimmed first. If exactly 64 ASCII hex digits (either case)
/// remain, they are read as hex. Otherwise the text is passed to `decode_base64`, and the
/// decoded payload must be exactly 32 bytes.
///
/// # Errors
///
/// Returns a human-readable message if the text is neither valid hex nor valid base64, or if
/// the base64 payload does not decode to exactly 32 bytes.
pub fn parse_key(raw: &str) -> Result<[u8; 32], String> {
    /// Converts 64 pre-validated hex digits into 32 bytes, two digits per byte.
    fn from_hex(digits: &str) -> Result<[u8; 32], String> {
        let mut key = [0u8; 32];
        for (i, slot) in key.iter_mut().enumerate() {
            let start = i * 2;
            *slot = u8::from_str_radix(&digits[start..start + 2], 16)
                .map_err(|err| format!("invalid hex key byte {i}: {err}"))?;
        }
        Ok(key)
    }

    let candidate = raw.trim();
    // All-ASCII is guaranteed when every byte is a hex digit, so byte slicing in `from_hex`
    // always lands on char boundaries.
    let is_hex = candidate.len() == 64 && candidate.bytes().all(|b| b.is_ascii_hexdigit());
    if is_hex {
        return from_hex(candidate);
    }

    let decoded = decode_base64(candidate)
        .map_err(|err| format!("key is neither 64-hex nor base64 (decode error: {err})"))?;
    match decoded.len() {
        32 => {
            let mut key = [0u8; 32];
            key.copy_from_slice(&decoded);
            Ok(key)
        }
        n => Err(format!(
            "decoded key is {} bytes; AES-256-GCM requires exactly 32",
            n
        )),
    }
}
/// Decodes standard-alphabet (RFC 4648 §4) base64 into raw bytes.
///
/// It is lenient where pasted key material needs it:
/// - ASCII whitespace anywhere is ignored, so line-wrapped input works.
/// - Trailing `=` padding is optional.
///
/// Once a `=` has been seen, any further data character is rejected. Otherwise two
/// concatenated encodings (`"QQ==QQ=="`) would silently decode to unrelated bytes.
///
/// # Errors
///
/// Returns a message for either of two conditions:
/// - A character is outside the base64 alphabet, or appears after padding. The message gives
///   its byte offset in `s`, the caller's original input rather than a filtered copy.
/// - The final group has an impossible length (a single dangling character).
fn decode_base64(s: &str) -> Result<Vec<u8>, String> {
    /// Maps one alphabet character to its 6-bit value.
    fn sextet(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    // Validate and translate in a single pass. Offsets stay relative to `s` so error
    // messages point at the character the caller actually supplied.
    let mut values = Vec::with_capacity(s.len());
    let mut padding_started = false;
    for (pos, c) in s.bytes().enumerate() {
        if c.is_ascii_whitespace() {
            continue;
        }
        if c == b'=' {
            padding_started = true;
            continue;
        }
        if padding_started {
            return Err(format!("invalid base64 char at {pos}: data after '=' padding"));
        }
        values.push(sextet(c).ok_or_else(|| format!("invalid base64 char at {pos}"))?);
    }

    let mut out = Vec::with_capacity(values.len() * 3 / 4);
    let mut groups = values.chunks_exact(4);
    for g in &mut groups {
        out.push((g[0] << 2) | (g[1] >> 4));
        out.push(((g[1] & 0x0F) << 4) | (g[2] >> 2));
        out.push(((g[2] & 0x03) << 6) | g[3]);
    }
    // A partial final group of 2 or 3 characters carries 1 or 2 bytes. A lone character
    // cannot encode a whole byte.
    match groups.remainder() {
        [] => {}
        &[a, b] => out.push((a << 2) | (b >> 4)),
        &[a, b, c] => {
            out.push((a << 2) | (b >> 4));
            out.push(((b & 0x0F) << 4) | (c >> 2));
        }
        rest => return Err(format!("invalid base64 length remainder {}", rest.len())),
    }
    Ok(out)
}
/// Loads the page-encryption key named `RED_ENCRYPTION_KEY` via
/// `crate::utils::env_with_file_fallback`. That helper presumably consults the environment
/// and then a file-based fallback; see it for the exact lookup rules.
///
/// Returns `Ok(None)` when the helper finds no value, `Ok(Some(key))` when the value parses
/// with `parse_key`, and `Err` with `parse_key`'s message otherwise.
pub fn key_from_env() -> Result<Option<[u8; 32]>, String> {
    crate::utils::env_with_file_fallback("RED_ENCRYPTION_KEY")
        .map(|raw| parse_key(&raw))
        .transpose()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 32-byte test key `[0, 1, 2, ..., 31]`.
    fn key() -> [u8; 32] {
        let mut k = [0u8; 32];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    /// Standard-alphabet base64 encoder used only to build `parse_key` fixtures without a
    /// dependency. When `pad` is true the final group is filled out with `=`.
    fn encode_base64(raw: &[u8], pad: bool) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut out = String::with_capacity((raw.len() + 2) / 3 * 4);
        for chunk in raw.chunks(3) {
            let n = (u32::from(chunk[0]) << 16)
                | (u32::from(chunk.get(1).copied().unwrap_or(0)) << 8)
                | u32::from(chunk.get(2).copied().unwrap_or(0));
            // A group of k input bytes yields k + 1 significant output characters.
            let significant = chunk.len() + 1;
            for slot in 0..4 {
                if slot < significant {
                    let sextet = (n >> (18 - 6 * slot)) & 0x3F;
                    out.push(char::from(ALPHABET[sextet as usize]));
                } else if pad {
                    out.push('=');
                }
            }
        }
        out
    }

    #[test]
    fn round_trips_plaintext() {
        let plaintext = b"page bytes that will be encrypted";
        let frame = encrypt_page(&key(), 7, plaintext).unwrap();
        assert_eq!(frame.len(), FRAME_OVERHEAD + plaintext.len());
        assert!(is_encrypted_frame(&frame));
        let recovered = decrypt_page(&key(), 7, &frame).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn nonce_is_random_per_call() {
        let plaintext = b"same payload, different nonce";
        let f1 = encrypt_page(&key(), 1, plaintext).unwrap();
        let f2 = encrypt_page(&key(), 1, plaintext).unwrap();
        assert_ne!(f1, f2);
    }

    #[test]
    fn page_id_binding_catches_swapped_pages() {
        let plaintext = b"page 1 contents";
        let frame = encrypt_page(&key(), 1, plaintext).unwrap();
        let err = decrypt_page(&key(), 2, &frame).unwrap_err();
        assert!(
            matches!(err, PageEncryptionError::KeyMismatch(_)),
            "got {err:?}"
        );
    }

    #[test]
    fn wrong_key_fails_closed() {
        let plaintext = b"sensitive";
        let frame = encrypt_page(&key(), 5, plaintext).unwrap();
        let mut wrong = key();
        wrong[0] ^= 0xff;
        let err = decrypt_page(&wrong, 5, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::KeyMismatch(_)));
    }

    #[test]
    fn tampered_ciphertext_fails_closed() {
        let mut frame = encrypt_page(&key(), 3, b"integrity matters").unwrap();
        let last = frame.len() - 1;
        frame[last] ^= 0x01;
        let err = decrypt_page(&key(), 3, &frame).unwrap_err();
        assert!(
            matches!(err, PageEncryptionError::KeyMismatch(_)),
            "got {err:?}"
        );
    }

    #[test]
    fn bad_magic_returns_typed_error() {
        let mut frame = encrypt_page(&key(), 0, b"x").unwrap();
        frame[0] ^= 0xff;
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::InvalidMagic));
    }

    #[test]
    fn unsupported_version_is_typed() {
        let mut frame = encrypt_page(&key(), 0, b"x").unwrap();
        frame[4] = 0xFE;
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::UnsupportedVersion(0xFE)));
    }

    #[test]
    fn truncated_frame_is_typed() {
        let frame = vec![0u8; FRAME_OVERHEAD - 1];
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEncryptionError::Truncated));
    }

    #[test]
    fn is_encrypted_frame_rejects_short_or_foreign_bytes() {
        assert!(!is_encrypted_frame(b"RDEP"));
        assert!(!is_encrypted_frame(&[0u8; 64]));
    }

    #[test]
    fn parse_key_accepts_hex() {
        let hex = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20";
        let key = parse_key(hex).unwrap();
        assert_eq!(key[0], 0x01);
        assert_eq!(key[31], 0x20);
    }

    #[test]
    fn parse_key_accepts_hex_with_whitespace() {
        let hex = " 0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20\n";
        assert!(parse_key(hex).is_ok());
    }

    #[test]
    fn parse_key_rejects_wrong_length() {
        assert!(parse_key("ab").is_err());
        // 64 non-hex characters fall through to base64, which decodes to 48 bytes.
        assert!(parse_key("zz".repeat(32).as_str()).is_err());
    }

    #[test]
    fn parse_key_accepts_base64() {
        let encoded = encode_base64(&[0xAB; 32], false);
        let key = parse_key(&encoded).unwrap();
        assert_eq!(key, [0xABu8; 32]);
    }

    #[test]
    fn parse_key_accepts_padded_base64() {
        let encoded = encode_base64(&[0x5A; 32], true);
        assert!(encoded.ends_with('='));
        assert_eq!(parse_key(&encoded).unwrap(), [0x5Au8; 32]);
    }
}