use ring::aead::{self, AES_256_GCM, Aad, LessSafeKey, Nonce, UnboundKey};
use crate::error::{DbError, DbResult};
/// Length in bytes of the AES-256-GCM authentication tag returned by
/// `PageCipher::encrypt_page` and required by `PageCipher::decrypt_page`.
pub const TAG_LEN: usize = 16;
/// Size in bytes of one page; `PageCipher` debug-asserts that every page it
/// encrypts or decrypts has exactly this length.
pub const PAGE_SIZE: usize = 4096;
/// Authenticated page encryption using AES-256-GCM.
///
/// Pages are encrypted in place; the authentication tag is returned separately
/// rather than appended to the page body.
pub struct PageCipher {
    key: LessSafeKey,
}
impl PageCipher {
    /// Builds a cipher from a raw 256-bit AES key.
    ///
    /// # Errors
    /// Returns `DbError::EncryptionError` if `ring` rejects the key material.
    pub fn new(key_bytes: &[u8; 32]) -> DbResult<Self> {
        let unbound = UnboundKey::new(&AES_256_GCM, key_bytes)
            .map_err(|e| DbError::EncryptionError(format!("invalid key: {e}")))?;
        Ok(Self {
            key: LessSafeKey::new(unbound),
        })
    }

    /// Reads a 256-bit key from the environment variable `var`.
    ///
    /// The value must be 64 hex digits (upper- or lower-case). Leading and
    /// trailing whitespace is ignored.
    ///
    /// # Errors
    /// Returns `DbError::EncryptionError` in any of these cases:
    /// - the variable is unset or not valid Unicode;
    /// - the trimmed value is not exactly 64 bytes long;
    /// - the value contains anything other than ASCII hex digits.
    pub fn key_from_env(var: &str) -> DbResult<[u8; 32]> {
        /// Value of a single ASCII hex digit, or `None` for any other byte.
        fn hex_digit(b: u8) -> Option<u8> {
            match b {
                b'0'..=b'9' => Some(b - b'0'),
                b'a'..=b'f' => Some(b - b'a' + 10),
                b'A'..=b'F' => Some(b - b'A' + 10),
                _ => None,
            }
        }

        let raw = std::env::var(var)
            .map_err(|e| DbError::EncryptionError(format!("env var {var}: {e}")))?;
        let hex = raw.trim();
        if hex.len() != 64 {
            return Err(DbError::EncryptionError(format!(
                "env var {var}: expected 64 hex chars, got {}",
                hex.len()
            )));
        }
        // Decode byte-wise instead of slicing the `&str`, for two reasons:
        // - slicing at an even byte offset panics if it falls inside a
        //   multi-byte UTF-8 character;
        // - `u8::from_str_radix` would silently accept a leading '+'
        //   (e.g. "+f" -> 0x0f).
        let mut key = [0u8; 32];
        for (i, (out, pair)) in key
            .iter_mut()
            .zip(hex.as_bytes().chunks_exact(2))
            .enumerate()
        {
            *out = match (hex_digit(pair[0]), hex_digit(pair[1])) {
                (Some(hi), Some(lo)) => (hi << 4) | lo,
                _ => {
                    return Err(DbError::EncryptionError(format!(
                        "env var {var}: invalid hex: non-hex character in key byte {i}"
                    )))
                }
            };
        }
        Ok(key)
    }

    /// Encrypts `page` in place and returns its authentication tag.
    ///
    /// The nonce is derived only from `(file_id, page_number)`. The ciphertext
    /// and tag are therefore bound to that location: decrypting with a
    /// different file id or page number fails authentication.
    ///
    /// SECURITY NOTE(review): the nonce is deterministic. Encrypting different
    /// contents for the same `(file_id, page_number)` under the same key
    /// reuses an AES-GCM nonce. That exposes the XOR of the two plaintexts and
    /// allows tag forgery. Confirm that callers never rewrite a page under the
    /// same key, or switch to a per-write nonce stored alongside the tag.
    ///
    /// # Errors
    /// Returns `DbError::EncryptionError` if sealing fails.
    pub fn encrypt_page(
        &self,
        file_id: u32,
        page_number: u64,
        page: &mut [u8],
    ) -> DbResult<[u8; TAG_LEN]> {
        debug_assert_eq!(page.len(), PAGE_SIZE);
        let nonce = make_page_nonce(file_id, page_number);
        let tag = self
            .key
            .seal_in_place_separate_tag(nonce, Aad::empty(), page)
            .map_err(|_| DbError::EncryptionError("encrypt failed".into()))?;
        let mut tag_bytes = [0u8; TAG_LEN];
        tag_bytes.copy_from_slice(tag.as_ref());
        Ok(tag_bytes)
    }

    /// Verifies `tag` and decrypts `page` in place.
    ///
    /// `file_id` and `page_number` must match the values passed to
    /// `encrypt_page`. Decryption runs on a scratch copy, so `page` is only
    /// overwritten after authentication succeeds. On error it still holds
    /// the ciphertext.
    ///
    /// # Errors
    /// Returns `DbError::EncryptionError` if the tag does not authenticate the
    /// ciphertext under this key and location (wrong key, wrong file id or
    /// page number, or tampered data or tag).
    pub fn decrypt_page(
        &self,
        file_id: u32,
        page_number: u64,
        page: &mut [u8],
        tag: &[u8; TAG_LEN],
    ) -> DbResult<()> {
        debug_assert_eq!(page.len(), PAGE_SIZE);
        // `open_in_place` expects the ciphertext followed by the tag in one buffer.
        let mut buf = Vec::with_capacity(page.len() + TAG_LEN);
        buf.extend_from_slice(page);
        buf.extend_from_slice(tag);
        let nonce = make_page_nonce(file_id, page_number);
        let plaintext = self
            .key
            .open_in_place(nonce, Aad::empty(), &mut buf)
            .map_err(|_| DbError::EncryptionError("decryption/authentication failed".into()))?;
        // `plaintext` is exactly `page.len()` bytes. Copying the whole slice
        // rather than `[..PAGE_SIZE]` keeps release builds, where the
        // debug_assert is compiled out, from panicking on other page lengths.
        page.copy_from_slice(plaintext);
        Ok(())
    }
}
/// Derives the 96-bit AES-GCM nonce for a page from its location.
///
/// Layout: bytes 0..4 hold `file_id` and bytes 4..12 hold `page_number`, both
/// little-endian. The result depends only on its inputs. It is therefore
/// unique per key only if each `(file_id, page_number)` is sealed with at most
/// one plaintext under that key. `assume_unique_for_key` takes this on trust.
fn make_page_nonce(file_id: u32, page_number: u64) -> Nonce {
    let mut raw = [0u8; aead::NONCE_LEN];
    let (file_part, page_part) = raw.split_at_mut(4);
    file_part.copy_from_slice(&file_id.to_le_bytes());
    page_part[..8].copy_from_slice(&page_number.to_le_bytes());
    Nonce::assume_unique_for_key(raw)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt_round_trip() {
        let key = [0xABu8; 32];
        let cipher = PageCipher::new(&key).unwrap();
        let original = [42u8; PAGE_SIZE];
        let mut page = original;
        let tag = cipher.encrypt_page(1, 0, &mut page).unwrap();
        assert_ne!(page, original);
        cipher.decrypt_page(1, 0, &mut page, &tag).unwrap();
        assert_eq!(page, original);
    }

    #[test]
    fn test_wrong_key_fails() {
        let cipher1 = PageCipher::new(&[0xAA; 32]).unwrap();
        let cipher2 = PageCipher::new(&[0xBB; 32]).unwrap();
        let mut page = [1u8; PAGE_SIZE];
        let tag = cipher1.encrypt_page(1, 0, &mut page).unwrap();
        assert!(cipher2.decrypt_page(1, 0, &mut page, &tag).is_err());
    }

    #[test]
    fn test_wrong_nonce_fails() {
        let cipher = PageCipher::new(&[0xCC; 32]).unwrap();
        let mut page = [7u8; PAGE_SIZE];
        let tag = cipher.encrypt_page(1, 0, &mut page).unwrap();
        assert!(cipher.decrypt_page(1, 1, &mut page, &tag).is_err());
    }

    #[test]
    fn test_wrong_file_id_fails() {
        let cipher = PageCipher::new(&[0xCC; 32]).unwrap();
        let mut page = [7u8; PAGE_SIZE];
        let tag = cipher.encrypt_page(1, 0, &mut page).unwrap();
        assert!(cipher.decrypt_page(2, 0, &mut page, &tag).is_err());
    }

    #[test]
    fn test_tampered_data_fails() {
        let cipher = PageCipher::new(&[0xDD; 32]).unwrap();
        let mut page = [9u8; PAGE_SIZE];
        let tag = cipher.encrypt_page(1, 0, &mut page).unwrap();
        page[0] ^= 0xFF;
        assert!(cipher.decrypt_page(1, 0, &mut page, &tag).is_err());
    }

    #[test]
    fn test_tampered_tag_fails() {
        let cipher = PageCipher::new(&[0xEE; 32]).unwrap();
        let mut page = [3u8; PAGE_SIZE];
        let mut tag = cipher.encrypt_page(1, 0, &mut page).unwrap();
        tag[TAG_LEN - 1] ^= 0x01;
        assert!(cipher.decrypt_page(1, 0, &mut page, &tag).is_err());
    }

    #[test]
    fn test_failed_decrypt_leaves_ciphertext_intact() {
        let cipher = PageCipher::new(&[0x11; 32]).unwrap();
        let mut page = [5u8; PAGE_SIZE];
        let tag = cipher.encrypt_page(1, 0, &mut page).unwrap();
        let ciphertext = page;
        assert!(cipher.decrypt_page(1, 1, &mut page, &tag).is_err());
        assert_eq!(page, ciphertext);
        // The untouched ciphertext still decrypts at the correct location.
        cipher.decrypt_page(1, 0, &mut page, &tag).unwrap();
        assert_eq!(page, [5u8; PAGE_SIZE]);
    }

    #[test]
    fn test_key_from_env_missing_var_fails() {
        assert!(PageCipher::key_from_env("PAGE_CIPHER_TEST_SURELY_UNSET_VAR_7F3A").is_err());
    }
}