use std::io::Cursor;
use crate::io::{BigEndian, ReadBytesExt};
use aes_gcm::aead::Generate;
use structured_zstd::skippable::SkippableFrame;
use super::aad::{AAD_LEN, BlockIdentity, EncryptionContext, FORMAT_VERSION_V1, SuiteId, build};
use super::aead::{TAG_LEN, decrypt_in_place, encrypt_in_place};
use super::error::DecryptError;
use super::key_chain::KeyChain;
/// Skippable-frame variant (offset of the frame magic from
/// `SKIPPABLE_MAGIC_START`) that identifies the metadata frame.
const METADATA_VARIANT: u8 = 0;
/// Skippable-frame variant (offset of the frame magic from
/// `SKIPPABLE_MAGIC_START`) that identifies the body (ciphertext) frame.
const BODY_VARIANT: u8 = 1;
/// First magic number of the skippable-frame range. The frame reader derives a
/// frame's variant as `magic - SKIPPABLE_MAGIC_START` (wrapping) and rejects
/// anything above 15.
const SKIPPABLE_MAGIC_START: u32 = 0x184D_2A50;
/// Consumes one 8-byte skippable-frame header from `reader` and returns its
/// validated `PayloadLen`, leaving the reader positioned on the first payload
/// byte. No payload bytes are read.
///
/// The header is two little-endian `u32`s: the frame magic, then `PayloadLen`.
/// Checks run in the order below and the first failure is reported:
///
/// 1. all eight header bytes must be readable;
/// 2. the magic must fall within the 16 values starting at
///    `SKIPPABLE_MAGIC_START`;
/// 3. when `expected_variant` is `Some`, the magic's offset must equal it;
/// 4. `PayloadLen` must not be below `min_payload`;
/// 5. `PayloadLen` must not exceed `max_payload`.
///
/// Every failure is built through `err_ctor`, so the caller chooses which
/// `DecryptError` variant a bad frame maps to.
fn read_framed_payload_len<R: std::io::Read>(
    reader: &mut R,
    expected_variant: Option<u8>,
    min_payload: u32,
    max_payload: u32,
    err_ctor: fn(&'static str) -> DecryptError,
) -> Result<u32, DecryptError> {
    let mut raw = [0u8; 8];
    if reader.read_exact(&mut raw).is_err() {
        return Err(err_ctor("truncated skippable-frame header"));
    }
    let [m0, m1, m2, m3, l0, l1, l2, l3] = raw;

    // Wrapping subtraction maps magics *below* the range to very large
    // offsets, so one upper-bound test rejects both directions. Going through
    // `u8::try_from` also removes the need for a truncating cast.
    let offset = u32::from_le_bytes([m0, m1, m2, m3]).wrapping_sub(SKIPPABLE_MAGIC_START);
    let variant = match u8::try_from(offset) {
        Ok(v) if v <= 15 => v,
        _ => return Err(err_ctor("magic outside skippable-frame range")),
    };
    if expected_variant.is_some_and(|wanted| wanted != variant) {
        return Err(err_ctor("wrong frame magic / variant"));
    }

    let declared = u32::from_le_bytes([l0, l1, l2, l3]);
    if declared < min_payload {
        return Err(err_ctor("PayloadLen below spec minimum"));
    }
    if declared > max_payload {
        return Err(err_ctor("PayloadLen exceeds cap"));
    }
    Ok(declared)
}
/// Reads one complete skippable frame from `reader` and returns its payload.
///
/// The header is validated by `read_framed_payload_len` with the same
/// `expected_variant`, `min_payload`, `max_payload` and `err_ctor`; see that
/// function for the checks performed. The payload buffer is then allocated at
/// the declared length and filled with `read_exact`.
///
/// Note that the buffer is sized from the header *before* the payload bytes
/// are known to exist, so `max_payload` is the only bound on that allocation.
///
/// # Errors
///
/// Any header failure, or `err_ctor("truncated frame payload")` when the
/// reader runs out before the declared number of bytes has been read.
fn read_framed_payload<R: std::io::Read>(
    reader: &mut R,
    expected_variant: Option<u8>,
    min_payload: u32,
    max_payload: u32,
    err_ctor: fn(&'static str) -> DecryptError,
) -> Result<Vec<u8>, DecryptError> {
    let declared_len =
        read_framed_payload_len(reader, expected_variant, min_payload, max_payload, err_ctor)?;

    let mut buf = vec![0u8; declared_len as usize];
    match reader.read_exact(&mut buf) {
        Ok(()) => Ok(buf),
        Err(_) => Err(err_ctor("truncated frame payload")),
    }
}
/// Exact byte length of a v1 MetadataPayload: 11 bytes of fixed fields
/// (five single bytes, a 4-byte DictID, WindowLog, BlockFlags), a 12-byte
/// nonce and a `TAG_LEN`-byte AEAD tag. The encoder debug-asserts that this
/// adds up to 39, and the decoder rejects any other length.
const METADATA_PAYLOAD_LEN_V1: u32 = 39;
/// Largest accepted body-frame `PayloadLen` (256 MiB). Applied to the
/// plaintext length when encrypting and to the declared body length when
/// parsing or decrypting.
const MAX_BODY_LEN: u32 = 256 * 1024 * 1024;
/// Serialises the v1 MetadataPayload for one encrypted block.
///
/// Byte layout of the returned buffer:
///
/// | offset | field                                   |
/// |--------|-----------------------------------------|
/// | 0      | `ctx.header_byte`                       |
/// | 1      | `ctx.key_epoch`                         |
/// | 2      | `identity.block_type` as a byte         |
/// | 3      | `ctx.suite_id` as a byte                |
/// | 4      | `ctx.compression_type`                  |
/// | 5..9   | `identity.dict_id`, big-endian          |
/// | 9      | `identity.window_log`                   |
/// | 10     | `ctx.block_flags`                       |
/// | 11..23 | `nonce`                                 |
/// | 23..   | `tag`                                   |
///
/// No validation happens here; the values are written exactly as given. In
/// debug builds the final length is asserted to equal
/// `METADATA_PAYLOAD_LEN_V1`.
fn encode_metadata_payload(
    ctx: EncryptionContext,
    identity: &BlockIdentity,
    nonce: &[u8; 12],
    tag: &[u8; TAG_LEN],
) -> Vec<u8> {
    let expected_len = METADATA_PAYLOAD_LEN_V1 as usize;
    let mut payload = Vec::with_capacity(expected_len);

    payload.extend_from_slice(&[
        ctx.header_byte,
        ctx.key_epoch,
        u8::from(identity.block_type),
        ctx.suite_id.as_byte(),
        ctx.compression_type,
    ]);
    payload.extend_from_slice(&identity.dict_id.to_be_bytes());
    payload.extend_from_slice(&[identity.window_log, ctx.block_flags]);
    payload.extend_from_slice(nonce);
    payload.extend_from_slice(tag);

    debug_assert_eq!(
        payload.len(),
        expected_len,
        "v1 MetadataPayload must be exactly 39 bytes"
    );
    payload
}
/// Fields recovered from a v1 MetadataPayload by `decode_metadata_payload`.
///
/// Nothing in here has been authenticated; the values only become trustworthy
/// once the AEAD tag has been verified over the AAD built from them.
struct ParsedMetadata {
// AEAD suite decoded from the suite byte; the same value is stored in `ctx.suite_id`.
suite_id: SuiteId,
// 12-byte nonce stored after the fixed fields.
nonce: [u8; 12],
// AEAD tag stored at the end of the payload.
tag: [u8; TAG_LEN],
// Encryption context rebuilt from the header byte, key epoch, suite,
// compression type and block flags.
ctx: EncryptionContext,
// Raw block-type byte. The decoder does not validate it; callers convert it
// with `BlockType::try_from` when they need the typed value.
block_type_byte: u8,
// DictID, read as a big-endian u32.
dict_id: u32,
// WindowLog byte; the decoder accepts only 0 or 10..=31.
window_log: u8,
}
/// Parses and structurally validates a v1 MetadataPayload.
///
/// Fields are consumed in wire order and each check fires as soon as the
/// field it concerns has been read, so the first violation encountered is the
/// one reported:
///
/// 1. payload length must equal `METADATA_PAYLOAD_LEN_V1`;
/// 2. the header byte's high nibble must equal `FORMAT_VERSION_V1`
///    (otherwise `UnsupportedFormatVersion`);
/// 3. the suite byte must convert to a `SuiteId` (otherwise
///    `UnsupportedSuite`);
/// 4. CompressionType must be one of 0, 1, 3, 4;
/// 5. WindowLog must be 0 or within `10..=31`;
/// 6. a non-zero DictID requires CompressionType 4;
/// 7. a non-zero WindowLog requires CompressionType 3 or 4;
/// 8. BlockFlags must not carry bits outside the known mask.
///
/// The key-epoch and block-type bytes are passed through unchecked. All
/// structural failures other than (2) and (3) are
/// `DecryptError::MalformedMetadataFrame`.
fn decode_metadata_payload(payload: &[u8]) -> Result<ParsedMetadata, DecryptError> {
    // Shorthand for the variant nearly every failure below maps to.
    fn malformed(reason: &'static str) -> DecryptError {
        DecryptError::MalformedMetadataFrame(reason)
    }
    // Pulls the next single byte off the cursor.
    fn next_u8(src: &mut Cursor<&[u8]>) -> Result<u8, DecryptError> {
        src.read_u8()
            .map_err(|_| malformed("truncated MetadataPayload"))
    }

    if payload.len() != METADATA_PAYLOAD_LEN_V1 as usize {
        return Err(malformed("MetadataPayload length != 39 for v1"));
    }
    let mut src = Cursor::new(payload);

    let header_byte = next_u8(&mut src)?;
    if (header_byte >> 4) != FORMAT_VERSION_V1 {
        return Err(DecryptError::UnsupportedFormatVersion { header_byte });
    }
    let key_epoch = next_u8(&mut src)?;
    let block_type_byte = next_u8(&mut src)?;

    let suite_byte = next_u8(&mut src)?;
    let suite_id = match SuiteId::try_from(suite_byte) {
        Ok(suite) => suite,
        Err(unknown) => return Err(DecryptError::UnsupportedSuite { suite_id: unknown }),
    };

    // Classify the compression type once; the two flags drive the
    // cross-field consistency checks further down.
    let compression_type = next_u8(&mut src)?;
    let (dict_capable, zstd_family) = match compression_type {
        0 | 1 => (false, false),
        3 => (false, true),
        4 => (true, true),
        _ => {
            return Err(malformed(
                "CompressionType byte not in spec registry (0, 1, 3, 4)",
            ));
        }
    };

    let dict_id = src
        .read_u32::<BigEndian>()
        .map_err(|_| malformed("truncated DictID"))?;

    let window_log = next_u8(&mut src)?;
    if !matches!(window_log, 0 | 10..=31) {
        return Err(malformed(
            "WindowLog outside valid range (must be 0 or 10..=31)",
        ));
    }
    if dict_id != 0 && !dict_capable {
        return Err(malformed(
            "non-zero DictID with non-ZstdDict CompressionType",
        ));
    }
    if window_log != 0 && !zstd_family {
        return Err(malformed(
            "non-zero WindowLog with non-zstd CompressionType",
        ));
    }

    let block_flags = next_u8(&mut src)?;
    let unknown_bits = block_flags & !crate::table::block::header::block_flags::KNOWN;
    if unknown_bits != 0 {
        return Err(malformed("unknown bits set in BlockFlags"));
    }

    let mut nonce = [0u8; 12];
    if std::io::Read::read_exact(&mut src, &mut nonce).is_err() {
        return Err(malformed("truncated Nonce"));
    }
    let mut tag = [0u8; TAG_LEN];
    if std::io::Read::read_exact(&mut src, &mut tag).is_err() {
        return Err(malformed("truncated AEADTag"));
    }

    let ctx = EncryptionContext {
        header_byte,
        key_epoch,
        suite_id,
        compression_type,
        block_flags,
    };
    Ok(ParsedMetadata {
        suite_id,
        nonce,
        tag,
        ctx,
        block_type_byte,
        dict_id,
        window_log,
    })
}
/// Metadata of an encrypted block as returned by
/// `parse_encrypted_block_metadata`.
///
/// The values are read from the serialised block without using any key, so
/// they have not been authenticated by the AEAD tag.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct EncryptedBlockMetadata {
/// High nibble of the stored header byte.
pub format_version: u8,
/// Key epoch byte selecting the key in the caller's key chain.
pub key_epoch: u8,
/// Raw block-type byte; not validated against the known block types.
pub block_type: u8,
/// AEAD suite decoded from the suite byte.
pub suite_id: SuiteId,
/// Compression type byte (one of 0, 1, 3, 4).
pub compression_type: u8,
/// DictID; non-zero only when `compression_type` is 4.
pub dict_id: u32,
/// WindowLog; 0 or within `10..=31`.
pub window_log: u8,
/// Block flags byte; only bits from the known mask are set.
pub block_flags: u8,
/// Nonce stored in the metadata frame.
pub nonce: [u8; 12],
/// AEAD tag stored in the metadata frame.
pub aead_tag: [u8; TAG_LEN],
/// `PayloadLen` declared by the body frame, i.e. the ciphertext length in bytes.
pub ciphertext_len: usize,
}
pub fn parse_encrypted_block_metadata(
bytes: &[u8],
) -> Result<EncryptedBlockMetadata, DecryptError> {
let mut cursor = Cursor::new(bytes);
let metadata_payload = read_framed_payload(
&mut cursor,
Some(METADATA_VARIANT),
METADATA_PAYLOAD_LEN_V1,
METADATA_PAYLOAD_LEN_V1,
DecryptError::MalformedMetadataFrame,
)?;
let parsed = decode_metadata_payload(&metadata_payload)?;
let ciphertext_len = read_framed_payload_len(
&mut cursor,
Some(BODY_VARIANT),
1,
MAX_BODY_LEN,
DecryptError::MalformedBodyFrame,
)?;
let remaining = u64::try_from(bytes.len())
.unwrap_or(u64::MAX)
.saturating_sub(cursor.position());
if u64::from(ciphertext_len) > remaining {
return Err(DecryptError::MalformedBodyFrame("truncated frame payload"));
}
let ciphertext_len = ciphertext_len as usize;
Ok(EncryptedBlockMetadata {
format_version: parsed.ctx.header_byte >> 4,
key_epoch: parsed.ctx.key_epoch,
block_type: parsed.block_type_byte,
suite_id: parsed.suite_id,
compression_type: parsed.ctx.compression_type,
dict_id: parsed.dict_id,
window_log: parsed.window_log,
block_flags: parsed.ctx.block_flags,
nonce: parsed.nonce,
aead_tag: parsed.tag,
ciphertext_len,
})
}
/// Rebuilds the AAD of a serialised encrypted block.
///
/// Only the metadata frame is read; the body frame is not examined. The block
/// type, DictID and WindowLog come from the stored metadata, while `table_id`
/// is not stored in the block and must be supplied by the caller.
///
/// # Errors
///
/// Any metadata-frame parsing error, or
/// `MalformedMetadataFrame("unknown BlockType byte")` when the stored
/// block-type byte does not convert to a `BlockType`.
pub fn reconstruct_block_aad(bytes: &[u8], table_id: u64) -> Result<[u8; AAD_LEN], DecryptError> {
    let mut reader = Cursor::new(bytes);
    let parsed = read_framed_payload(
        &mut reader,
        Some(METADATA_VARIANT),
        METADATA_PAYLOAD_LEN_V1,
        METADATA_PAYLOAD_LEN_V1,
        DecryptError::MalformedMetadataFrame,
    )
    .and_then(|payload| decode_metadata_payload(&payload))?;

    let Ok(block_type) = crate::table::block::BlockType::try_from(parsed.block_type_byte) else {
        return Err(DecryptError::MalformedMetadataFrame(
            "unknown BlockType byte",
        ));
    };

    Ok(build(
        &parsed.ctx,
        &BlockIdentity {
            table_id,
            block_type,
            dict_id: parsed.dict_id,
            window_log: parsed.window_log,
        },
    ))
}
pub fn encrypt_block(
plaintext: &[u8],
identity: &BlockIdentity,
ctx: &EncryptionContext,
key_chain: &dyn KeyChain,
) -> crate::Result<Vec<u8>> {
if (ctx.header_byte >> 4) != FORMAT_VERSION_V1 {
return Err(crate::Error::Encrypt(
"HeaderByte high nibble does not match FORMAT_VERSION_V1 (spec §4.8)",
));
}
if (ctx.header_byte & 0x0F) != 0 {
return Err(crate::Error::Encrypt(
"HeaderByte low nibble is reserved and must be zero on write (spec §4.8)",
));
}
if ctx.block_flags & !crate::table::block::header::block_flags::KNOWN != 0 {
return Err(crate::Error::Encrypt(
"BlockFlags has bits set outside the known transform mask",
));
}
let key = key_chain.key(ctx.key_epoch).ok_or_else(|| {
log::error!(
"encrypt_block: KeyEpoch {} not present in caller's KeyChain",
ctx.key_epoch,
);
crate::Error::Encrypt("KeyEpoch not present in caller's KeyChain")
})?;
if plaintext.is_empty() {
return Err(crate::Error::Encrypt(
"plaintext must be non-empty per AAD-bound spec (BodyFrame PayloadLen >= 1)",
));
}
if plaintext.len() > MAX_BODY_LEN as usize {
return Err(crate::Error::Encrypt("plaintext exceeds 256 MiB body cap"));
}
if !matches!(ctx.compression_type, 0 | 1 | 3 | 4) {
return Err(crate::Error::Encrypt(
"invalid CompressionType (spec §5.1: must be 0=None, 1=Lz4, 3=Zstd, or 4=ZstdDict)",
));
}
if identity.dict_id != 0 && ctx.compression_type != 4 {
return Err(crate::Error::Encrypt(
"non-zero DictID with non-ZstdDict CompressionType (spec §5.1)",
));
}
if identity.window_log != 0 && !matches!(ctx.compression_type, 3 | 4) {
return Err(crate::Error::Encrypt(
"non-zero WindowLog with non-zstd CompressionType (spec §5.1)",
));
}
if identity.window_log != 0 && !(10..=31).contains(&identity.window_log) {
return Err(crate::Error::Encrypt(
"WindowLog outside valid range (spec §5.1: must be 0 or 10..=31)",
));
}
let nonce: [u8; 12] = <[u8; 12]>::generate();
let aad = build(ctx, identity);
let mut body = plaintext.to_vec();
let tag = encrypt_in_place(ctx.suite_id, key, &nonce, &aad, &mut body)?;
let metadata_payload = encode_metadata_payload(*ctx, identity, &nonce, &tag);
let metadata_frame = SkippableFrame::new(METADATA_VARIANT, metadata_payload)
.map_err(|_| crate::Error::Encrypt("MetadataFrame construction failed"))?;
let body_frame = SkippableFrame::new(BODY_VARIANT, body)
.map_err(|_| crate::Error::Encrypt("BodyFrame construction failed"))?;
let total_size = metadata_frame.serialized_size() + body_frame.serialized_size();
let mut out = Vec::with_capacity(total_size);
metadata_frame
.encode_into(&mut out)
.map_err(|_| crate::Error::Encrypt("MetadataFrame serialisation failed"))?;
body_frame
.encode_into(&mut out)
.map_err(|_| crate::Error::Encrypt("BodyFrame serialisation failed"))?;
Ok(out)
}
/// Result of a successful `decrypt_block` call: the recovered plaintext plus
/// the metadata fields a caller needs to post-process it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DecryptedBlock {
/// Decrypted body bytes.
pub plaintext: Vec<u8>,
/// Compression type byte from the block's metadata (one of 0, 1, 3, 4).
pub compression_type: u8,
/// DictID from the block's metadata.
pub dict_id: u32,
/// WindowLog from the block's metadata.
pub window_log: u8,
/// Block flags byte from the block's metadata.
pub block_flags: u8,
}
pub fn decrypt_block(
bytes: &[u8],
identity: &BlockIdentity,
key_chain: &dyn KeyChain,
) -> Result<DecryptedBlock, DecryptError> {
let mut cursor = Cursor::new(bytes);
let metadata_payload = read_framed_payload(
&mut cursor,
Some(METADATA_VARIANT),
METADATA_PAYLOAD_LEN_V1,
METADATA_PAYLOAD_LEN_V1,
DecryptError::MalformedMetadataFrame,
)?;
let parsed = decode_metadata_payload(&metadata_payload)?;
let mut ciphertext = read_framed_payload(
&mut cursor,
Some(BODY_VARIANT),
1,
MAX_BODY_LEN,
DecryptError::MalformedBodyFrame,
)?;
let pos = cursor.position();
let total = u64::try_from(bytes.len()).unwrap_or(u64::MAX);
if pos != total {
return Err(DecryptError::MalformedBodyFrame(
"unexpected trailing bytes after BodyFrame",
));
}
let block_type = crate::table::block::BlockType::try_from(parsed.block_type_byte)
.map_err(|_| DecryptError::MalformedMetadataFrame("unknown BlockType byte"))?;
let aad_identity = BlockIdentity {
table_id: identity.table_id,
block_type,
dict_id: parsed.dict_id,
window_log: parsed.window_log,
};
let aad = build(&parsed.ctx, &aad_identity);
debug_assert_eq!(aad.len(), AAD_LEN);
let key = key_chain
.key(parsed.ctx.key_epoch)
.ok_or(DecryptError::UnknownKeyEpoch {
key_epoch: parsed.ctx.key_epoch,
})?;
decrypt_in_place(
parsed.suite_id,
key,
&parsed.nonce,
&aad,
&parsed.tag,
&mut ciphertext,
)?;
Ok(DecryptedBlock {
plaintext: ciphertext,
compression_type: parsed.ctx.compression_type,
dict_id: parsed.dict_id,
window_log: parsed.window_log,
block_flags: parsed.ctx.block_flags,
})
}
#[cfg(test)]
#[expect(clippy::unwrap_used, clippy::indexing_slicing, reason = "test code")]
mod tests {
    use super::*;
    use crate::encryption::key_chain::StaticKeyChain;
    use crate::table::block::BlockType;

    const TEST_KEY: [u8; 32] = [0x42; 32];
    const TEST_KEY_OTHER: [u8; 32] = [0x55; 32];

    /// Serialised frame header: 4-byte magic followed by 4-byte PayloadLen.
    const FRAME_HEADER_LEN: usize = 8;
    /// Total serialised size of the metadata frame (header + payload).
    const METADATA_FRAME_LEN: usize = FRAME_HEADER_LEN + METADATA_PAYLOAD_LEN_V1 as usize;
    /// Absolute offset of the WindowLog byte (payload offset 9).
    const WINDOW_LOG_AT: usize = FRAME_HEADER_LEN + 9;
    /// Absolute offset of the BlockFlags byte (payload offset 10).
    const BLOCK_FLAGS_AT: usize = FRAME_HEADER_LEN + 10;

    const FOX: &[u8] = b"the quick brown fox";
    const FORENSIC: &[u8] = b"forensic payload bytes";

    /// Identity of an uncompressed data block in a fixed table.
    fn id() -> BlockIdentity {
        BlockIdentity {
            table_id: 0x1234_5678_9ABC_DEF0,
            block_type: BlockType::Data,
            dict_id: 0,
            window_log: 0,
        }
    }

    /// AES-256-GCM context, key epoch 0, no compression, no flags.
    fn ctx() -> EncryptionContext {
        EncryptionContext::v1(0, SuiteId::Aes256Gcm, 0, 0)
    }

    /// Same as `ctx()` but with the ChaCha20-Poly1305 suite.
    fn chacha_ctx() -> EncryptionContext {
        EncryptionContext::v1(0, SuiteId::ChaCha20Poly1305, 0, 0)
    }

    /// Key chain holding `TEST_KEY` at epoch 0.
    fn chain() -> StaticKeyChain {
        StaticKeyChain::new().with_key(0, TEST_KEY)
    }

    /// Seals `plaintext` with the default identity, context and key chain.
    fn seal(plaintext: &[u8]) -> Vec<u8> {
        encrypt_block(plaintext, &id(), &ctx(), &chain()).unwrap()
    }

    #[test]
    fn parse_metadata_is_key_free_and_matches_seal() {
        let sealed = seal(FORENSIC);

        let meta = parse_encrypted_block_metadata(&sealed).unwrap();
        assert_eq!(meta.format_version, 1);
        assert_eq!(meta.key_epoch, 0);
        assert_eq!(meta.suite_id, SuiteId::Aes256Gcm);
        assert_eq!(meta.block_type, u8::from(BlockType::Data));
        assert_eq!(meta.compression_type, 0);
        assert_eq!(meta.dict_id, 0);
        assert_eq!(meta.window_log, 0);
        assert!(meta.ciphertext_len > 0, "body must carry ciphertext");

        // The suite byte must round-trip for the other suite as well.
        let sealed_cc = encrypt_block(FORENSIC, &id(), &chacha_ctx(), &chain()).unwrap();
        let meta_cc = parse_encrypted_block_metadata(&sealed_cc).unwrap();
        assert_eq!(meta_cc.suite_id, SuiteId::ChaCha20Poly1305);

        // Garbage and a block cut inside the metadata frame are both rejected.
        assert!(parse_encrypted_block_metadata(b"not a frame").is_err());
        assert!(parse_encrypted_block_metadata(&sealed[..10]).is_err());
    }

    #[test]
    fn parse_metadata_rejects_truncated_body() {
        let sealed = seal(FORENSIC);

        // Keep both frame headers and the metadata payload, drop the body.
        let cut = METADATA_FRAME_LEN + FRAME_HEADER_LEN;
        assert!(
            sealed.len() > cut,
            "test setup: sealed block must extend past the body header",
        );

        let err = parse_encrypted_block_metadata(&sealed[..cut])
            .expect_err("truncated body must be rejected");
        assert!(
            matches!(err, DecryptError::MalformedBodyFrame(_)),
            "expected MalformedBodyFrame for a truncated body, got {err:?}",
        );
    }

    #[test]
    fn reconstruct_aad_matches_seal_with_correct_table_id() {
        let sealed = seal(FORENSIC);
        let table_id = id().table_id;

        let expected = build(&ctx(), &id());
        let got = reconstruct_block_aad(&sealed, table_id).unwrap();
        assert_eq!(got.len(), AAD_LEN);
        assert_eq!(
            got, expected,
            "reconstructed AAD must match the sealing AAD"
        );

        let other = reconstruct_block_aad(&sealed, table_id ^ 1).unwrap();
        assert_ne!(got, other, "table_id must affect the reconstructed AAD");

        assert!(reconstruct_block_aad(b"not a frame", 0).is_err());
    }

    #[test]
    fn roundtrip_aes_recovers_plaintext() {
        let sealed = seal(FOX);

        let recovered = decrypt_block(&sealed, &id(), &chain()).unwrap();
        assert_eq!(&recovered.plaintext[..], FOX);
        assert_eq!(recovered.compression_type, 0);
        assert_eq!(recovered.dict_id, 0);
        assert_eq!(recovered.window_log, 0);
    }

    #[test]
    fn roundtrip_chacha_recovers_plaintext() {
        let sealed = encrypt_block(FOX, &id(), &chacha_ctx(), &chain()).unwrap();

        let recovered = decrypt_block(&sealed, &id(), &chain()).unwrap();
        assert_eq!(&recovered.plaintext[..], FOX);
    }

    #[test]
    fn wrong_key_in_chain_surfaces_aead_failure() {
        let sealed = seal(FOX);
        let wrong = StaticKeyChain::new().with_key(0, TEST_KEY_OTHER);

        let err = decrypt_block(&sealed, &id(), &wrong).unwrap_err();
        assert!(
            matches!(err, DecryptError::AeadVerificationFailed),
            "expected AeadVerificationFailed, got {err:?}",
        );
    }

    #[test]
    fn missing_key_epoch_surfaces_unknown_key_epoch() {
        let sealed = seal(FOX);
        let no_epoch_zero = StaticKeyChain::new().with_key(1, TEST_KEY);

        let err = decrypt_block(&sealed, &id(), &no_epoch_zero).unwrap_err();
        assert!(
            matches!(err, DecryptError::UnknownKeyEpoch { key_epoch: 0 }),
            "expected UnknownKeyEpoch {{ key_epoch: 0 }}, got {err:?}",
        );
    }

    #[test]
    fn cross_identity_substitution_surfaces_aead_failure() {
        let sealed = seal(FOX);
        let wrong_id = BlockIdentity {
            table_id: id().table_id ^ 0x1,
            ..id()
        };

        let err = decrypt_block(&sealed, &wrong_id, &chain()).unwrap_err();
        assert!(matches!(err, DecryptError::AeadVerificationFailed));
    }

    #[test]
    fn trailing_bytes_after_body_are_rejected() {
        let mut sealed = seal(FOX);
        assert!(decrypt_block(&sealed, &id(), &chain()).is_ok());

        sealed.extend_from_slice(&[0xAB, 0xCD, 0xEF, 0x01]);
        let err = decrypt_block(&sealed, &id(), &chain()).unwrap_err();
        assert!(
            matches!(err, DecryptError::MalformedBodyFrame(_)),
            "expected MalformedBodyFrame for trailing bytes, got {err:?}",
        );
    }

    #[test]
    fn truncated_input_surfaces_malformed_metadata() {
        let sealed = seal(FOX);

        // Six bytes is not even a full frame header.
        let err = decrypt_block(&sealed[..6], &id(), &chain()).unwrap_err();
        assert!(matches!(err, DecryptError::MalformedMetadataFrame(_)));
    }

    #[test]
    fn encrypt_block_rejects_empty_plaintext() {
        let err = encrypt_block(&[], &id(), &ctx(), &chain()).unwrap_err();
        assert!(
            matches!(err, crate::Error::Encrypt(_)),
            "expected Error::Encrypt for empty plaintext, got {err:?}",
        );
    }

    #[test]
    fn encrypt_block_rejects_unknown_block_flags_bit() {
        let mut c = ctx();
        c.block_flags = 0x10;

        let err = encrypt_block(b"payload", &id(), &c, &chain()).unwrap_err();
        assert!(
            matches!(err, crate::Error::Encrypt(_)),
            "expected Error::Encrypt for unknown BlockFlags bit, got {err:?}",
        );
    }

    #[test]
    fn invalid_window_log_surfaces_malformed_metadata() {
        let mut sealed = seal(FOX);
        sealed[WINDOW_LOG_AT] = 9;

        let err = decrypt_block(&sealed, &id(), &chain()).unwrap_err();
        assert!(
            matches!(err, DecryptError::MalformedMetadataFrame(_)),
            "expected MalformedMetadataFrame for WindowLog=9, got {err:?}",
        );
    }

    #[test]
    fn oversized_body_payload_len_rejected_before_alloc() {
        let mut sealed = seal(FOX);

        // Overwrite the body frame's PayloadLen (it follows the 4-byte magic).
        let body_payload_len_at = METADATA_FRAME_LEN + 4;
        sealed[body_payload_len_at..body_payload_len_at + 4]
            .copy_from_slice(&u32::MAX.to_le_bytes());

        let err = decrypt_block(&sealed, &id(), &chain()).unwrap_err();
        assert!(
            matches!(err, DecryptError::MalformedBodyFrame(_)),
            "expected MalformedBodyFrame for oversized BodyFrame PayloadLen, got {err:?}",
        );
    }

    #[test]
    fn unknown_block_flags_bit_rejected_before_aead() {
        let mut sealed = seal(FOX);
        sealed[BLOCK_FLAGS_AT] |= 0x10;

        let err = decrypt_block(&sealed, &id(), &chain()).unwrap_err();
        assert!(
            matches!(err, DecryptError::MalformedMetadataFrame(_)),
            "expected MalformedMetadataFrame for unknown BlockFlags bit, got {err:?}",
        );
    }
}