use crate::aes_gcm::{aes256_gcm_decrypt, aes256_gcm_encrypt};
use crate::os_random;
use crate::params::{NONCE_SIZE, PAGE_ENVELOPE_OVERHEAD};
/// Failure modes of [`encrypt_page`] and [`decrypt_page`].
#[derive(Debug)]
pub enum PageEnvelopeError {
    /// The frame handed to [`decrypt_page`] is shorter than `PAGE_ENVELOPE_OVERHEAD` bytes.
    Truncated,
    /// `aes256_gcm_decrypt` rejected the frame (wrong key, different page id, or
    /// altered bytes all map here). Carries the underlying error text.
    KeyMismatch(String),
    /// The OS random source failed while generating a nonce. Carries the
    /// underlying error text.
    RandomFailure(String),
}

impl std::fmt::Display for PageEnvelopeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All messages start with "encrypted page: "; variants carrying a detail
        // string append it in parentheses.
        match self {
            PageEnvelopeError::Truncated => write!(f, "encrypted page: truncated frame"),
            PageEnvelopeError::KeyMismatch(inner) => {
                write!(f, "encrypted page: key mismatch or tampering ({})", inner)
            }
            PageEnvelopeError::RandomFailure(inner) => {
                write!(f, "encrypted page: nonce generation failed ({})", inner)
            }
        }
    }
}

impl std::error::Error for PageEnvelopeError {}
/// Seals one page with AES-256-GCM and returns the resulting frame.
///
/// The frame is the freshly generated nonce (`NONCE_SIZE` bytes) followed by
/// the output of `aes256_gcm_encrypt`. The page id, as 4 little-endian bytes,
/// is passed as associated data, so the frame only decrypts under the same
/// `page_id` in [`decrypt_page`]. The returned buffer is preallocated to
/// `PAGE_ENVELOPE_OVERHEAD + plaintext.len()` bytes.
///
/// # Errors
///
/// Returns [`PageEnvelopeError::RandomFailure`] if `os_random::fill_bytes`
/// cannot fill the nonce.
///
/// NOTE(review): a new random nonce is drawn on every call. If `NONCE_SIZE` is
/// the standard 12-byte GCM nonce, NIST SP 800-38D limits random-nonce use to
/// 2^32 encryptions per key. Confirm that key rotation keeps page writes under
/// that bound.
pub fn encrypt_page(
    key: &[u8; 32],
    page_id: u32,
    plaintext: &[u8],
) -> Result<Vec<u8>, PageEnvelopeError> {
    let nonce = {
        let mut fresh = [0u8; NONCE_SIZE];
        os_random::fill_bytes(&mut fresh).map_err(PageEnvelopeError::RandomFailure)?;
        fresh
    };
    let sealed = aes256_gcm_encrypt(key, &nonce, &page_id.to_le_bytes(), plaintext);

    let mut frame = Vec::with_capacity(PAGE_ENVELOPE_OVERHEAD + plaintext.len());
    frame.extend_from_slice(&nonce);
    frame.extend_from_slice(&sealed);
    Ok(frame)
}
/// Opens a frame produced by [`encrypt_page`] and returns the page plaintext.
///
/// The first `NONCE_SIZE` bytes of `frame` are the nonce. The rest is passed to
/// `aes256_gcm_decrypt`, with the little-endian bytes of `page_id` as
/// associated data.
///
/// # Errors
///
/// * [`PageEnvelopeError::Truncated`] if `frame` is shorter than
///   `PAGE_ENVELOPE_OVERHEAD` bytes.
/// * [`PageEnvelopeError::KeyMismatch`] if `aes256_gcm_decrypt` fails. A wrong
///   key, a `page_id` different from the one used to encrypt, or altered bytes
///   all map to this one variant.
pub fn decrypt_page(
    key: &[u8; 32],
    page_id: u32,
    frame: &[u8],
) -> Result<Vec<u8>, PageEnvelopeError> {
    if frame.len() < PAGE_ENVELOPE_OVERHEAD {
        return Err(PageEnvelopeError::Truncated);
    }
    // NOTE(review): assumes PAGE_ENVELOPE_OVERHEAD >= NONCE_SIZE (TODO confirm in
    // params). Otherwise split_at could panic on a frame that passed the check above.
    let (nonce_bytes, sealed) = frame.split_at(NONCE_SIZE);
    let mut nonce = [0u8; NONCE_SIZE];
    nonce.copy_from_slice(nonce_bytes);
    aes256_gcm_decrypt(key, &nonce, &page_id.to_le_bytes(), sealed)
        .map_err(PageEnvelopeError::KeyMismatch)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test key whose bytes are 0, 1, ..., 31.
    fn key() -> [u8; 32] {
        let mut k = [0u8; 32];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    #[test]
    fn round_trips_plaintext() {
        let plaintext = b"page bytes that will be encrypted";
        let frame = encrypt_page(&key(), 7, plaintext).unwrap();
        assert_eq!(frame.len(), PAGE_ENVELOPE_OVERHEAD + plaintext.len());
        let recovered = decrypt_page(&key(), 7, &frame).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trips_empty_plaintext_at_exact_overhead_boundary() {
        // The smallest valid frame is exactly PAGE_ENVELOPE_OVERHEAD bytes; the
        // `<` length check in decrypt_page must accept it, not report Truncated.
        let frame = encrypt_page(&key(), 3, b"").unwrap();
        assert_eq!(frame.len(), PAGE_ENVELOPE_OVERHEAD);
        let recovered = decrypt_page(&key(), 3, &frame).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn nonce_is_random_per_call() {
        let plaintext = b"same payload, different nonce";
        let f1 = encrypt_page(&key(), 1, plaintext).unwrap();
        let f2 = encrypt_page(&key(), 1, plaintext).unwrap();
        assert_ne!(f1, f2);
    }

    #[test]
    fn page_id_binding_catches_swapped_pages() {
        let plaintext = b"page 1 contents";
        let frame = encrypt_page(&key(), 1, plaintext).unwrap();
        let err = decrypt_page(&key(), 2, &frame).unwrap_err();
        assert!(
            matches!(err, PageEnvelopeError::KeyMismatch(_)),
            "got {err:?}"
        );
    }

    #[test]
    fn wrong_key_fails_closed() {
        let plaintext = b"sensitive";
        let frame = encrypt_page(&key(), 5, plaintext).unwrap();
        let mut wrong = key();
        wrong[0] ^= 0xff;
        let err = decrypt_page(&wrong, 5, &frame).unwrap_err();
        assert!(matches!(err, PageEnvelopeError::KeyMismatch(_)), "got {err:?}");
    }

    #[test]
    fn truncated_frame_is_typed() {
        let frame = vec![0u8; PAGE_ENVELOPE_OVERHEAD - 1];
        let err = decrypt_page(&key(), 0, &frame).unwrap_err();
        assert!(matches!(err, PageEnvelopeError::Truncated), "got {err:?}");

        let err = decrypt_page(&key(), 0, &[]).unwrap_err();
        assert!(matches!(err, PageEnvelopeError::Truncated), "got {err:?}");
    }

    #[test]
    fn tampered_tag_fails() {
        let mut bad = encrypt_page(&key(), 9, b"abc").unwrap();
        let last = bad.len() - 1;
        bad[last] ^= 1;
        // Tampering must surface as an authentication failure, not some other variant.
        let err = decrypt_page(&key(), 9, &bad).unwrap_err();
        assert!(matches!(err, PageEnvelopeError::KeyMismatch(_)), "got {err:?}");
    }

    #[test]
    fn tampered_nonce_fails() {
        let mut bad = encrypt_page(&key(), 9, b"abc").unwrap();
        // The first byte of the frame is part of the nonce prefix.
        bad[0] ^= 1;
        let err = decrypt_page(&key(), 9, &bad).unwrap_err();
        assert!(matches!(err, PageEnvelopeError::KeyMismatch(_)), "got {err:?}");
    }

    #[test]
    fn error_display_is_specific_to_failure_class() {
        assert_eq!(
            PageEnvelopeError::Truncated.to_string(),
            "encrypted page: truncated frame"
        );
        assert_eq!(
            PageEnvelopeError::KeyMismatch("bad tag".to_string()).to_string(),
            "encrypted page: key mismatch or tampering (bad tag)"
        );
        assert_eq!(
            PageEnvelopeError::RandomFailure("no entropy".to_string()).to_string(),
            "encrypted page: nonce generation failed (no entropy)"
        );
    }
}