use std::io::Cursor;
use crate::io::{BigEndian, ReadBytesExt};
use aes_gcm::aead::Generate;
use structured_zstd::skippable::SkippableFrame;
use super::aad::{AAD_LEN, BlockIdentity, EncryptionContext, FORMAT_VERSION_V1, SuiteId, build};
use super::aead::{TAG_LEN, decrypt_in_place, encrypt_in_place};
use super::error::DecryptError;
use super::key_chain::KeyChain;
/// Skippable-frame variant (offset from [`SKIPPABLE_MAGIC_START`]) that marks
/// the MetadataFrame.
const METADATA_VARIANT: u8 = 0;
/// Skippable-frame variant (offset from [`SKIPPABLE_MAGIC_START`]) that marks
/// the BodyFrame carrying the ciphertext.
const BODY_VARIANT: u8 = 1;
/// First of the sixteen zstd skippable-frame magic numbers
/// (`0x184D_2A50..=0x184D_2A5F`); a frame's variant is its magic minus this value.
const SKIPPABLE_MAGIC_START: u32 = 0x184D_2A50;
/// Reads and validates one 8-byte skippable-frame header from `reader` and
/// returns the frame's `PayloadLen`.
///
/// The header is a little-endian `u32` magic followed by a little-endian
/// `u32` payload length. The magic must be one of the sixteen values starting
/// at [`SKIPPABLE_MAGIC_START`]. Its offset from that start is the frame
/// variant, which must equal `expected_variant` when one is supplied. The
/// length must lie in `min_payload..=max_payload`.
///
/// Only the header is consumed; the payload bytes remain in `reader`.
///
/// # Errors
///
/// Every failure is reported as `err_ctor(reason)`. The checks run in this
/// order: truncated header, magic out of range, wrong variant, length below
/// `min_payload`, length above `max_payload`.
fn read_framed_payload_len<R: std::io::Read>(
    reader: &mut R,
    expected_variant: Option<u8>,
    min_payload: u32,
    max_payload: u32,
    err_ctor: fn(&'static str) -> DecryptError,
) -> Result<u32, DecryptError> {
    let mut raw = [0u8; 8];
    if reader.read_exact(&mut raw).is_err() {
        return Err(err_ctor("truncated skippable-frame header"));
    }
    let [m0, m1, m2, m3, l0, l1, l2, l3] = raw;

    // Offset into the skippable-magic window. A magic below the window start,
    // or more than 15 past it, is not a skippable frame.
    let variant = u32::from_le_bytes([m0, m1, m2, m3])
        .checked_sub(SKIPPABLE_MAGIC_START)
        .filter(|offset| *offset <= 15)
        .and_then(|offset| u8::try_from(offset).ok());
    let Some(variant) = variant else {
        return Err(err_ctor("magic outside skippable-frame range"));
    };
    if expected_variant.is_some_and(|wanted| wanted != variant) {
        return Err(err_ctor("wrong frame magic / variant"));
    }

    match u32::from_le_bytes([l0, l1, l2, l3]) {
        len if len < min_payload => Err(err_ctor("PayloadLen below spec minimum")),
        len if len > max_payload => Err(err_ctor("PayloadLen exceeds cap")),
        len => Ok(len),
    }
}
/// Reads one complete skippable frame (header and payload) from `reader` and
/// returns the payload bytes.
///
/// Header validation (magic range, optional expected variant, and the
/// `min_payload..=max_payload` length bounds) is delegated to
/// `read_framed_payload_len`. Every failure is reported through `err_ctor`.
///
/// `PayloadLen` comes straight from untrusted input; for the BodyFrame it may
/// be as large as `MAX_BODY_LEN` (256 MiB). The buffer is therefore not sized
/// up front from that field. It starts with a capped reservation and grows
/// only as payload bytes actually arrive. A forged header over a short input
/// thus fails with "truncated frame payload" without first allocating the
/// advertised size.
///
/// # Errors
///
/// Any header error from `read_framed_payload_len`, or
/// `err_ctor("truncated frame payload")` when the reader ends or fails before
/// `PayloadLen` bytes have been read.
fn read_framed_payload<R: std::io::Read>(
    reader: &mut R,
    expected_variant: Option<u8>,
    min_payload: u32,
    max_payload: u32,
    err_ctor: fn(&'static str) -> DecryptError,
) -> Result<Vec<u8>, DecryptError> {
    // Largest buffer reserved before any payload byte has been observed.
    // Payloads at or below this size are still read with a single allocation.
    const INITIAL_RESERVE: usize = 64 * 1024;

    let payload_len =
        read_framed_payload_len(reader, expected_variant, min_payload, max_payload, err_ctor)?;
    let expected = payload_len as usize;
    let mut payload = Vec::with_capacity(expected.min(INITIAL_RESERVE));

    // `take` stops the read at exactly the advertised length, so bytes of any
    // following frame stay in `reader`.
    let mut limited = std::io::Read::take(&mut *reader, u64::from(payload_len));
    let received = std::io::Read::read_to_end(&mut limited, &mut payload)
        .map_err(|_| err_ctor("truncated frame payload"))?;
    if received != expected {
        return Err(err_ctor("truncated frame payload"));
    }
    Ok(payload)
}
/// Exact `PayloadLen` of a v1 MetadataFrame: eleven bytes of scalar fields,
/// a 12-byte nonce and the AEAD tag.
const METADATA_PAYLOAD_LEN_V1: u32 = 39;
/// Upper bound (256 MiB) on a BodyFrame's `PayloadLen`. `encrypt_block` also
/// applies it to the plaintext length.
const MAX_BODY_LEN: u32 = 256 * 1024 * 1024;
/// Serialises a v1 MetadataPayload (39 bytes).
///
/// Wire layout, in order:
///
/// | size      | field                      |
/// |-----------|----------------------------|
/// | 1         | HeaderByte                 |
/// | 1         | KeyEpoch                   |
/// | 1         | BlockType                  |
/// | 1         | SuiteId                    |
/// | 1         | CompressionType            |
/// | 4         | DictID (big-endian)        |
/// | 1         | WindowLog                  |
/// | 1         | BlockFlags                 |
/// | 12        | Nonce                      |
/// | `TAG_LEN` | AEAD tag                   |
fn encode_metadata_payload(
    ctx: EncryptionContext,
    identity: &BlockIdentity,
    nonce: &[u8; 12],
    tag: &[u8; TAG_LEN],
) -> Vec<u8> {
    let head = [
        ctx.header_byte,
        ctx.key_epoch,
        u8::from(identity.block_type),
        ctx.suite_id.as_byte(),
        ctx.compression_type,
    ];
    let dict_id = identity.dict_id.to_be_bytes();
    let tail = [identity.window_log, ctx.block_flags];

    let mut encoded = Vec::with_capacity(METADATA_PAYLOAD_LEN_V1 as usize);
    for segment in [&head[..], &dict_id[..], &tail[..], &nonce[..], &tag[..]] {
        encoded.extend_from_slice(segment);
    }
    debug_assert_eq!(
        encoded.len(),
        METADATA_PAYLOAD_LEN_V1 as usize,
        "v1 MetadataPayload must be exactly 39 bytes"
    );
    encoded
}
/// Fields decoded and validated from a v1 MetadataPayload.
struct ParsedMetadata {
    /// HeaderByte, KeyEpoch, SuiteId, CompressionType and BlockFlags, grouped
    /// as the encryption context that feeds the AAD.
    ctx: EncryptionContext,
    /// AEAD suite; the decoder stores the same value in `ctx.suite_id`.
    suite_id: SuiteId,
    /// Raw BlockType byte. It is not validated by the decoder; callers convert
    /// it to a `BlockType` themselves.
    block_type_byte: u8,
    /// Dictionary id (non-zero only for CompressionType 4).
    dict_id: u32,
    /// Window log (0 or `10..=31`; non-zero only for CompressionType 3 or 4).
    window_log: u8,
    /// 12-byte AEAD nonce.
    nonce: [u8; 12],
    /// Detached AEAD tag.
    tag: [u8; TAG_LEN],
}
/// Decodes and validates a v1 MetadataPayload.
///
/// Field order: HeaderByte, KeyEpoch, BlockType, SuiteId, CompressionType,
/// DictID (`u32` big-endian), WindowLog, BlockFlags, 12-byte Nonce, AEAD tag.
///
/// Checks, in precedence order:
/// 1. the payload is exactly [`METADATA_PAYLOAD_LEN_V1`] bytes;
/// 2. the HeaderByte high nibble equals `FORMAT_VERSION_V1` (the low nibble is
///    not inspected);
/// 3. the SuiteId byte converts to a `SuiteId`;
/// 4. CompressionType is 0, 1, 3 or 4;
/// 5. WindowLog is 0 or in `10..=31`;
/// 6. DictID is zero unless CompressionType is 4;
/// 7. WindowLog is zero unless CompressionType is 3 or 4;
/// 8. BlockFlags has no bits outside the known mask.
///
/// The BlockType byte is returned raw and not validated here.
///
/// # Errors
///
/// `UnsupportedFormatVersion` for a wrong version nibble, `UnsupportedSuite`
/// for an unrecognised suite byte, and `MalformedMetadataFrame` for every
/// other violation.
fn decode_metadata_payload(payload: &[u8]) -> Result<ParsedMetadata, DecryptError> {
    if payload.len() != METADATA_PAYLOAD_LEN_V1 as usize {
        return Err(DecryptError::MalformedMetadataFrame(
            "MetadataPayload length != 39 for v1",
        ));
    }

    // Scalar fields in wire order. The length check above guarantees these
    // first eleven bytes are present.
    let mut fields = Cursor::new(payload);
    let next_byte = |c: &mut Cursor<&[u8]>| {
        c.read_u8()
            .map_err(|_| DecryptError::MalformedMetadataFrame("truncated MetadataPayload"))
    };
    let header_byte = next_byte(&mut fields)?;
    let key_epoch = next_byte(&mut fields)?;
    let block_type_byte = next_byte(&mut fields)?;
    let suite_byte = next_byte(&mut fields)?;
    let compression_type = next_byte(&mut fields)?;
    let dict_id = fields
        .read_u32::<BigEndian>()
        .map_err(|_| DecryptError::MalformedMetadataFrame("truncated DictID"))?;
    let window_log = next_byte(&mut fields)?;
    let block_flags = next_byte(&mut fields)?;

    if (header_byte >> 4) != FORMAT_VERSION_V1 {
        return Err(DecryptError::UnsupportedFormatVersion { header_byte });
    }
    let suite_id = SuiteId::try_from(suite_byte)
        .map_err(|rejected| DecryptError::UnsupportedSuite { suite_id: rejected })?;

    let zstd_family = matches!(compression_type, 3 | 4);
    let violation = if !matches!(compression_type, 0 | 1 | 3 | 4) {
        Some("CompressionType byte not in spec registry (0, 1, 3, 4)")
    } else if window_log != 0 && !(10..=31).contains(&window_log) {
        Some("WindowLog outside valid range (must be 0 or 10..=31)")
    } else if dict_id != 0 && compression_type != 4 {
        Some("non-zero DictID with non-ZstdDict CompressionType")
    } else if window_log != 0 && !zstd_family {
        Some("non-zero WindowLog with non-zstd CompressionType")
    } else if (block_flags & !crate::table::block::header::block_flags::KNOWN) != 0 {
        Some("unknown bits set in BlockFlags")
    } else {
        None
    };
    if let Some(reason) = violation {
        return Err(DecryptError::MalformedMetadataFrame(reason));
    }

    let mut nonce = [0u8; 12];
    std::io::Read::read_exact(&mut fields, &mut nonce)
        .map_err(|_| DecryptError::MalformedMetadataFrame("truncated Nonce"))?;
    let mut tag = [0u8; TAG_LEN];
    std::io::Read::read_exact(&mut fields, &mut tag)
        .map_err(|_| DecryptError::MalformedMetadataFrame("truncated AEADTag"))?;

    let ctx = EncryptionContext {
        header_byte,
        key_epoch,
        suite_id,
        compression_type,
        block_flags,
    };
    Ok(ParsedMetadata {
        suite_id,
        nonce,
        tag,
        ctx,
        block_type_byte,
        dict_id,
        window_log,
    })
}
/// Header-level description of an encrypted block, as returned by
/// [`parse_encrypted_block_metadata`]. The body is not decrypted.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct EncryptedBlockMetadata {
    /// High nibble of the MetadataFrame's HeaderByte.
    pub format_version: u8,
    /// Epoch of the key the block was encrypted under.
    pub key_epoch: u8,
    /// Raw BlockType byte, not checked against the known block types.
    pub block_type: u8,
    /// AEAD suite identifier.
    pub suite_id: SuiteId,
    /// CompressionType byte (one of 0, 1, 3, 4).
    pub compression_type: u8,
    /// Dictionary id; non-zero only when `compression_type` is 4.
    pub dict_id: u32,
    /// Window log: 0 or in `10..=31`, and non-zero only when `compression_type` is 3 or 4.
    pub window_log: u8,
    /// BlockFlags byte, with no bits set outside the known mask.
    pub block_flags: u8,
    /// 12-byte AEAD nonce.
    pub nonce: [u8; 12],
    /// AEAD authentication tag, stored in the MetadataFrame rather than the body.
    pub aead_tag: [u8; TAG_LEN],
    /// BodyFrame `PayloadLen`, i.e. the ciphertext length in bytes.
    pub ciphertext_len: usize,
}
pub fn parse_encrypted_block_metadata(
bytes: &[u8],
) -> Result<EncryptedBlockMetadata, DecryptError> {
let mut cursor = Cursor::new(bytes);
let metadata_payload = read_framed_payload(
&mut cursor,
Some(METADATA_VARIANT),
METADATA_PAYLOAD_LEN_V1,
METADATA_PAYLOAD_LEN_V1,
DecryptError::MalformedMetadataFrame,
)?;
let parsed = decode_metadata_payload(&metadata_payload)?;
let ciphertext_len = read_framed_payload_len(
&mut cursor,
Some(BODY_VARIANT),
1,
MAX_BODY_LEN,
DecryptError::MalformedBodyFrame,
)?;
let remaining = u64::try_from(bytes.len())
.unwrap_or(u64::MAX)
.saturating_sub(cursor.position());
if u64::from(ciphertext_len) > remaining {
return Err(DecryptError::MalformedBodyFrame("truncated frame payload"));
}
let ciphertext_len = ciphertext_len as usize;
Ok(EncryptedBlockMetadata {
format_version: parsed.ctx.header_byte >> 4,
key_epoch: parsed.ctx.key_epoch,
block_type: parsed.block_type_byte,
suite_id: parsed.suite_id,
compression_type: parsed.ctx.compression_type,
dict_id: parsed.dict_id,
window_log: parsed.window_log,
block_flags: parsed.ctx.block_flags,
nonce: parsed.nonce,
aead_tag: parsed.tag,
ciphertext_len,
})
}
/// Rebuilds the AAD that authenticates an encrypted block.
///
/// The AAD is built from the fields stored in the block's MetadataFrame plus
/// the caller-supplied `table_id`. Only the leading MetadataFrame of `bytes`
/// is read; the BodyFrame is neither parsed nor checked.
///
/// # Errors
///
/// Any MetadataFrame framing or payload error, or
/// `MalformedMetadataFrame("unknown BlockType byte")` when the stored BlockType
/// byte does not convert to a `BlockType`.
pub fn reconstruct_block_aad(bytes: &[u8], table_id: u64) -> Result<[u8; AAD_LEN], DecryptError> {
    let mut reader = Cursor::new(bytes);
    let raw_metadata = read_framed_payload(
        &mut reader,
        Some(METADATA_VARIANT),
        METADATA_PAYLOAD_LEN_V1,
        METADATA_PAYLOAD_LEN_V1,
        DecryptError::MalformedMetadataFrame,
    )?;
    let meta = decode_metadata_payload(&raw_metadata)?;

    let identity = crate::table::block::BlockType::try_from(meta.block_type_byte)
        .map(|block_type| BlockIdentity {
            table_id,
            block_type,
            dict_id: meta.dict_id,
            window_log: meta.window_log,
        })
        .map_err(|_| DecryptError::MalformedMetadataFrame("unknown BlockType byte"))?;
    Ok(build(&meta.ctx, &identity))
}
pub fn encrypt_block(
plaintext: &[u8],
identity: &BlockIdentity,
ctx: &EncryptionContext,
key_chain: &dyn KeyChain,
) -> crate::Result<Vec<u8>> {
if (ctx.header_byte >> 4) != FORMAT_VERSION_V1 {
return Err(crate::Error::Encrypt(
"HeaderByte high nibble does not match FORMAT_VERSION_V1 (spec §4.8)",
));
}
if (ctx.header_byte & 0x0F) != 0 {
return Err(crate::Error::Encrypt(
"HeaderByte low nibble is reserved and must be zero on write (spec §4.8)",
));
}
if ctx.block_flags & !crate::table::block::header::block_flags::KNOWN != 0 {
return Err(crate::Error::Encrypt(
"BlockFlags has bits set outside the known transform mask",
));
}
let key = key_chain.key(ctx.key_epoch).ok_or_else(|| {
log::error!(
"encrypt_block: KeyEpoch {} not present in caller's KeyChain",
ctx.key_epoch,
);
crate::Error::Encrypt("KeyEpoch not present in caller's KeyChain")
})?;
if plaintext.is_empty() {
return Err(crate::Error::Encrypt(
"plaintext must be non-empty per AAD-bound spec (BodyFrame PayloadLen >= 1)",
));
}
if plaintext.len() > MAX_BODY_LEN as usize {
return Err(crate::Error::Encrypt("plaintext exceeds 256 MiB body cap"));
}
if !matches!(ctx.compression_type, 0 | 1 | 3 | 4) {
return Err(crate::Error::Encrypt(
"invalid CompressionType (spec §5.1: must be 0=None, 1=Lz4, 3=Zstd, or 4=ZstdDict)",
));
}
if identity.dict_id != 0 && ctx.compression_type != 4 {
return Err(crate::Error::Encrypt(
"non-zero DictID with non-ZstdDict CompressionType (spec §5.1)",
));
}
if identity.window_log != 0 && !matches!(ctx.compression_type, 3 | 4) {
return Err(crate::Error::Encrypt(
"non-zero WindowLog with non-zstd CompressionType (spec §5.1)",
));
}
if identity.window_log != 0 && !(10..=31).contains(&identity.window_log) {
return Err(crate::Error::Encrypt(
"WindowLog outside valid range (spec §5.1: must be 0 or 10..=31)",
));
}
let nonce: [u8; 12] = <[u8; 12]>::generate();
let aad = build(ctx, identity);
let mut body = plaintext.to_vec();
let tag = encrypt_in_place(ctx.suite_id, key, &nonce, &aad, &mut body)?;
let metadata_payload = encode_metadata_payload(*ctx, identity, &nonce, &tag);
let metadata_frame = SkippableFrame::new(METADATA_VARIANT, metadata_payload)
.map_err(|_| crate::Error::Encrypt("MetadataFrame construction failed"))?;
let body_frame = SkippableFrame::new(BODY_VARIANT, body)
.map_err(|_| crate::Error::Encrypt("BodyFrame construction failed"))?;
let total_size = metadata_frame.serialized_size() + body_frame.serialized_size();
let mut out = Vec::with_capacity(total_size);
metadata_frame
.encode_into(&mut out)
.map_err(|_| crate::Error::Encrypt("MetadataFrame serialisation failed"))?;
body_frame
.encode_into(&mut out)
.map_err(|_| crate::Error::Encrypt("BodyFrame serialisation failed"))?;
Ok(out)
}
/// Output of [`decrypt_block`]: the authenticated plaintext together with the
/// compression parameters recorded in the block's MetadataFrame.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DecryptedBlock {
    /// Decrypted BodyFrame payload.
    /// NOTE(review): presumably still compressed according to `compression_type`;
    /// confirm with the caller.
    pub plaintext: Vec<u8>,
    /// CompressionType byte from the MetadataFrame (0, 1, 3 or 4).
    pub compression_type: u8,
    /// DictID from the MetadataFrame.
    pub dict_id: u32,
    /// WindowLog from the MetadataFrame.
    pub window_log: u8,
    /// BlockFlags byte from the MetadataFrame.
    pub block_flags: u8,
}
/// Authenticates and decrypts a block produced by [`encrypt_block`].
///
/// `bytes` must contain exactly one MetadataFrame followed by one BodyFrame,
/// with nothing after it. The AAD is rebuilt from `identity.table_id` plus the
/// BlockType, DictID, WindowLog and encryption context stored in the
/// MetadataFrame. The other fields of `identity` are not consulted.
///
/// NOTE(review): `DecryptedBlock` carries no BlockType. A caller that must
/// enforce an expected block type therefore has to check it separately;
/// confirm this is intended.
///
/// # Errors
///
/// - `MalformedMetadataFrame`, `UnsupportedFormatVersion` or `UnsupportedSuite`
///   for a bad MetadataFrame, including an unknown BlockType byte;
/// - `MalformedBodyFrame` for a bad or truncated BodyFrame, or trailing bytes;
/// - `UnknownKeyEpoch` when `key_chain` has no key for the stored epoch;
/// - any error from `decrypt_in_place`, presumably including tag-verification
///   failure.
pub fn decrypt_block(
    bytes: &[u8],
    identity: &BlockIdentity,
    key_chain: &dyn KeyChain,
) -> Result<DecryptedBlock, DecryptError> {
    let mut reader = Cursor::new(bytes);
    let raw_metadata = read_framed_payload(
        &mut reader,
        Some(METADATA_VARIANT),
        METADATA_PAYLOAD_LEN_V1,
        METADATA_PAYLOAD_LEN_V1,
        DecryptError::MalformedMetadataFrame,
    )?;
    let meta = decode_metadata_payload(&raw_metadata)?;
    let mut body = read_framed_payload(
        &mut reader,
        Some(BODY_VARIANT),
        1,
        MAX_BODY_LEN,
        DecryptError::MalformedBodyFrame,
    )?;

    // The two frames must account for every input byte.
    let fully_consumed =
        u64::try_from(bytes.len()).is_ok_and(|total| total == reader.position());
    if !fully_consumed {
        return Err(DecryptError::MalformedBodyFrame(
            "unexpected trailing bytes after BodyFrame",
        ));
    }

    let block_type = crate::table::block::BlockType::try_from(meta.block_type_byte)
        .map_err(|_| DecryptError::MalformedMetadataFrame("unknown BlockType byte"))?;
    let aad = build(
        &meta.ctx,
        &BlockIdentity {
            table_id: identity.table_id,
            block_type,
            dict_id: meta.dict_id,
            window_log: meta.window_log,
        },
    );
    debug_assert_eq!(aad.len(), AAD_LEN);

    let epoch = meta.ctx.key_epoch;
    let key = key_chain
        .key(epoch)
        .ok_or(DecryptError::UnknownKeyEpoch { key_epoch: epoch })?;
    decrypt_in_place(meta.suite_id, key, &meta.nonce, &aad, &meta.tag, &mut body)?;

    Ok(DecryptedBlock {
        plaintext: body,
        compression_type: meta.ctx.compression_type,
        dict_id: meta.dict_id,
        window_log: meta.window_log,
        block_flags: meta.ctx.block_flags,
    })
}
// Unit tests for this module live in the out-of-line `tests` submodule file.
#[cfg(test)]
#[expect(clippy::unwrap_used, clippy::indexing_slicing, reason = "test code")]
mod tests;