use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use bytes::Bytes;
use md5::{Digest as Md5Digest, Md5};
use rand::RngCore;
use thiserror::Error;
use crate::kms::{KmsBackend, KmsError, WrappedDek};
// Frame magics identifying the SSE wire/storage format version.
// V1: legacy unkeyed, V2: keyring key-id, V3: SSE-C (customer key),
// V4: SSE-KMS (wrapped DEK), V5/V6: chunked (4-byte / 8-byte salt).
pub const SSE_MAGIC_V1: &[u8; 4] = b"S4E1";
pub const SSE_MAGIC_V2: &[u8; 4] = b"S4E2";
pub const SSE_MAGIC_V3: &[u8; 4] = b"S4E3";
pub const SSE_MAGIC_V4: &[u8; 4] = b"S4E4";
pub const SSE_MAGIC_V5: &[u8; 4] = b"S4E5";
pub const SSE_MAGIC_V6: &[u8; 4] = b"S4E6";
// Historical alias; still points at the legacy V1 magic.
pub const SSE_MAGIC: &[u8; 4] = SSE_MAGIC_V1;
// V1/V2 header: magic(4) | algo(1) | key_id+pad(3) | nonce(12) | tag(16).
// V3 header:    magic(4) | algo(1) | key MD5(16)   | nonce(12) | tag(16).
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + 12 + 16; pub const SSE_HEADER_BYTES_V3: usize = 4 + 1 + KEY_MD5_LEN + 12 + 16; pub const ALGO_AES_256_GCM: u8 = 1;
// AES-256-GCM geometry (96-bit nonce, 128-bit tag, 256-bit key).
const NONCE_LEN: usize = 12;
const TAG_LEN: usize = 16;
const KEY_LEN: usize = 32;
const KEY_MD5_LEN: usize = 16;
// The only algorithm value accepted in SSE-C request headers.
pub const SSE_C_ALGORITHM: &str = "AES256";
/// All error conditions produced by the SSE framing / crypto layer.
///
/// Variants group by frame version: S4E1/S4E2 (keyring), S4E3 (SSE-C),
/// S4E4 (SSE-KMS), S4E5/S4E6 (chunked). The `#[error]` strings double as
/// the user-facing messages, so they must not change casually.
#[derive(Debug, Error)]
pub enum SseError {
#[error("SSE key file {path:?}: {source}")]
KeyFileIo {
path: std::path::PathBuf,
source: std::io::Error,
},
#[error(
"SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
)]
BadKeyLength { got: usize },
#[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
TooShort { got: usize },
#[error("SSE bad magic: expected S4E1/S4E2/S4E3/S4E4/S4E5/S4E6, got {got:?}")]
BadMagic { got: [u8; 4] },
#[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
UnsupportedAlgo { tag: u8 },
#[error(
"SSE key_id {id} (S4E2 frame) not present in keyring; rotation history likely incomplete"
)]
KeyNotInKeyring { id: u16 },
// Deliberately does not say *which* of key/tag was wrong (no oracle).
#[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
DecryptFailed,
#[error("SSE-C key MD5 fingerprint mismatch — client supplied a different key than PUT")]
WrongCustomerKey,
#[error("SSE-C customer-key headers invalid: {reason}")]
InvalidCustomerKey { reason: &'static str },
#[error("SSE-C algorithm {algo:?} unsupported (only {SSE_C_ALGORITHM:?} is allowed)")]
CustomerKeyAlgorithmUnsupported { algo: String },
#[error("S4E3 frame requires SseSource::CustomerKey; got Keyring")]
CustomerKeyRequired,
#[error("S4E1/S4E2 frame stored without SSE-C; SseSource::CustomerKey is unexpected")]
CustomerKeyUnexpected,
// S4E4 needs a KMS network round-trip, which the sync path cannot do.
#[error(
"S4E4 (SSE-KMS) body requires async decrypt — call decrypt_with_kms() instead of decrypt()"
)]
KmsAsyncRequired,
#[error("S4E4 frame too short ({got} bytes; need at least {min})")]
KmsFrameTooShort { got: usize, min: usize },
#[error("S4E4 frame field length out of bounds: {what}")]
KmsFrameFieldOob { what: &'static str },
#[error("S4E4 key_id is not valid UTF-8")]
KmsKeyIdNotUtf8,
#[error(
"S4E4 SseSource::Kms wrapped DEK key_id {supplied:?} doesn't match frame key_id {stored:?}"
)]
KmsWrappedDekMismatch { supplied: String, stored: String },
#[error("S4E4 frame requires SseSource::Kms")]
KmsRequired,
#[error("KMS unwrap: {0}")]
KmsBackend(#[from] KmsError),
#[error(
"S4E5 chunk {chunk_index} auth tag verify failed (key mismatch or chunk tampered with)"
)]
ChunkAuthFailed { chunk_index: u32 },
#[error("S4E5 chunk_size must be > 0 (got 0)")]
ChunkSizeInvalid,
#[error("S4E5 frame truncated: {what}")]
ChunkFrameTruncated { what: &'static str },
#[error("S4E6 chunk_count {got} exceeds 24-bit max ({max}) — pick a larger --sse-chunk-size")]
ChunkCountTooLarge { got: u32, max: u32 },
#[error("S4E5/S4E6 chunked frame declares an over-large size: {details}")]
ChunkFrameTooLarge { details: &'static str },
}
// Default cap on declared plaintext size for chunked frames: 5 GiB.
pub const DEFAULT_MAX_BODY_BYTES: usize = 5 * 1024 * 1024 * 1024;
/// A raw 256-bit AES key. `Debug` is implemented manually to redact it.
pub struct SseKey {
// Raw key material; never logged or displayed.
pub bytes: [u8; 32],
}
impl SseKey {
/// Loads a key from `path`, accepting any format `from_bytes` accepts.
///
/// # Errors
/// `KeyFileIo` on read failure, or whatever `from_bytes` returns.
pub fn from_path(path: &Path) -> Result<Self, SseError> {
let raw = std::fs::read(path).map_err(|source| SseError::KeyFileIo {
path: path.to_path_buf(),
source,
})?;
Self::from_bytes(&raw)
}
/// Parses key material, trying three encodings in order:
/// 1. exactly 32 raw bytes;
/// 2. 64 hex characters (after UTF-8 decode + trim);
/// 3. standard base64 decoding to exactly 32 bytes.
///
/// # Errors
/// `BadKeyLength` when none of the formats matches.
pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
if bytes.len() == KEY_LEN {
let mut k = [0u8; KEY_LEN];
k.copy_from_slice(bytes);
return Ok(Self { bytes: k });
}
// Non-UTF-8 input degrades to "" and falls through to the final error.
let s = std::str::from_utf8(bytes).unwrap_or("").trim();
if s.len() == KEY_LEN * 2 && s.chars().all(|c| c.is_ascii_hexdigit()) {
let mut k = [0u8; KEY_LEN];
for (i, k_byte) in k.iter_mut().enumerate() {
// Cannot actually fail: every char was checked as a hex digit.
*k_byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
.map_err(|_| SseError::BadKeyLength { got: bytes.len() })?;
}
return Ok(Self { bytes: k });
}
if let Ok(decoded) =
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s.as_bytes())
&& decoded.len() == KEY_LEN
{
let mut k = [0u8; KEY_LEN];
k.copy_from_slice(&decoded);
return Ok(Self { bytes: k });
}
Err(SseError::BadKeyLength { got: bytes.len() })
}
/// Zero-copy view of the raw bytes as an AES-256-GCM key.
fn as_aes_key(&self) -> &Key<Aes256Gcm> {
Key::<Aes256Gcm>::from_slice(&self.bytes)
}
}
impl std::fmt::Debug for SseKey {
    /// Redacts the key material; only the (constant) length is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SseKey");
        out.field("len", &KEY_LEN);
        out.field("key", &"<redacted>");
        out.finish()
    }
}
/// A set of SSE keys indexed by 16-bit id, with one designated "active" key
/// used for new writes; older ids remain available for decrypting rotated
/// objects.
#[derive(Clone)]
pub struct SseKeyring {
// Id of the key used for new encryptions; must exist in `keys`.
active: u16,
keys: HashMap<u16, Arc<SseKey>>,
}
impl SseKeyring {
    /// Builds a keyring whose only entry is the active key, establishing the
    /// invariant that `active` is always present in `keys`.
    pub fn new(active: u16, key: Arc<SseKey>) -> Self {
        Self {
            active,
            keys: HashMap::from([(active, key)]),
        }
    }
    /// Registers (or replaces) a key under `id`, typically an older rotation.
    pub fn add(&mut self, id: u16, key: Arc<SseKey>) {
        self.keys.insert(id, key);
    }
    /// Returns the active key id and key.
    pub fn active(&self) -> (u16, &SseKey) {
        let key = self
            .keys
            .get(&self.active)
            .expect("active key id must be present in keyring (constructor invariant)");
        (self.active, key.as_ref())
    }
    /// Looks up a key by id; `None` if that rotation is unknown.
    pub fn get(&self, id: u16) -> Option<&SseKey> {
        self.keys.get(&id).map(|key| key.as_ref())
    }
}
impl std::fmt::Debug for SseKeyring {
    /// Prints ids and counts only — never key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ids: Vec<_> = self.keys.keys().collect();
        f.debug_struct("SseKeyring")
            .field("active", &self.active)
            .field("key_count", &self.keys.len())
            .field("key_ids", &ids)
            .finish()
    }
}
/// Cheaply-clonable shared handle to a keyring.
pub type SharedSseKeyring = Arc<SseKeyring>;
pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
let cipher = Aes256Gcm::new(key.as_aes_key());
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let mut aad = [0u8; 8];
aad[..4].copy_from_slice(SSE_MAGIC_V1);
aad[4] = ALGO_AES_256_GCM;
let ct_with_tag = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad: &aad,
},
)
.expect("aes-gcm encrypt cannot fail with a 32-byte key");
debug_assert!(ct_with_tag.len() >= TAG_LEN);
let split = ct_with_tag.len() - TAG_LEN;
let (ct, tag) = ct_with_tag.split_at(split);
let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
out.extend_from_slice(SSE_MAGIC_V1);
out.push(ALGO_AES_256_GCM);
out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(tag);
out.extend_from_slice(ct);
Bytes::from(out)
}
/// Encrypts `plaintext` into an S4E2 frame using the keyring's active key:
/// `magic(4) | algo(1) | key_id(2 BE) | pad(1) | nonce(12) | tag(16) | ct`.
pub fn encrypt_v2(plaintext: &[u8], keyring: &SseKeyring) -> Bytes {
let (key_id, key) = keyring.active();
let cipher = Aes256Gcm::new(key.as_aes_key());
// Fresh random 96-bit nonce per object.
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
// key_id is bound into the AAD, so relabelling a frame with a different
// key id breaks authentication.
let aad = aad_v2(key_id);
let ct_with_tag = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad: &aad,
},
)
.expect("aes-gcm encrypt cannot fail with a 32-byte key");
debug_assert!(ct_with_tag.len() >= TAG_LEN);
// aes-gcm appends the tag; the frame stores tag before ciphertext.
let split = ct_with_tag.len() - TAG_LEN;
let (ct, tag) = ct_with_tag.split_at(split);
let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
out.extend_from_slice(SSE_MAGIC_V2);
out.push(ALGO_AES_256_GCM);
// 2-byte key id plus 1 pad byte occupy V1's reserved region.
out.extend_from_slice(&key_id.to_be_bytes()); out.push(0u8); out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(tag);
out.extend_from_slice(ct);
Bytes::from(out)
}
/// AAD for S4E1 frames: magic | algo | three zero bytes.
fn aad_v1() -> [u8; 8] {
    let mut out = [0u8; 8];
    out[0..4].copy_from_slice(SSE_MAGIC_V1);
    out[4] = ALGO_AES_256_GCM;
    out
}
/// AAD for S4E2 frames: magic | algo | key_id (BE) | one zero pad byte.
fn aad_v2(key_id: u16) -> [u8; 8] {
    let id = key_id.to_be_bytes();
    let mut out = [0u8; 8];
    out[0..4].copy_from_slice(SSE_MAGIC_V2);
    out[4] = ALGO_AES_256_GCM;
    out[5] = id[0];
    out[6] = id[1];
    out[7] = 0;
    out
}
/// AAD for S4E3 frames: magic | algo | 16-byte customer-key MD5.
fn aad_v3(key_md5: &[u8; KEY_MD5_LEN]) -> [u8; 4 + 1 + KEY_MD5_LEN] {
    let mut out = [0u8; 4 + 1 + KEY_MD5_LEN];
    out[0..4].copy_from_slice(SSE_MAGIC_V3);
    out[4] = ALGO_AES_256_GCM;
    out[5..].copy_from_slice(key_md5);
    out
}
/// Validated SSE-C key material: the raw 256-bit key plus its MD5
/// fingerprint (stored in S4E3 frames to detect key mismatch on GET).
#[derive(Clone)]
pub struct CustomerKeyMaterial {
// Raw customer-supplied key; redacted in Debug output.
pub key: [u8; KEY_LEN],
// MD5 of `key`, verified against the client-supplied header.
pub key_md5: [u8; KEY_MD5_LEN],
}
impl std::fmt::Debug for CustomerKeyMaterial {
    /// Shows only the MD5 fingerprint; the key itself is redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let md5_hex = hex_lower(&self.key_md5);
        f.debug_struct("CustomerKeyMaterial")
            .field("key", &"<redacted>")
            .field("key_md5_hex", &md5_hex)
            .finish()
    }
}
/// Lowercase hex encoding of `bytes` — no separators, no prefix.
fn hex_lower(bytes: &[u8]) -> String {
    // Nibble lookup table; the original used `format!` which allocated a
    // fresh String per byte.
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        s.push(HEX[(b >> 4) as usize] as char);
        s.push(HEX[(b & 0x0f) as usize] as char);
    }
    s
}
/// Where the decryption/encryption key comes from, matching the frame kinds:
/// keyring (S4E1/S4E2/S4E5/S4E6), SSE-C customer key (S4E3), or a KMS
/// data-encryption key (S4E4).
#[derive(Debug, Clone, Copy)]
pub enum SseSource<'a> {
Keyring(&'a SseKeyring),
CustomerKey {
key: &'a [u8; KEY_LEN],
key_md5: &'a [u8; KEY_MD5_LEN],
},
Kms {
// Plaintext DEK used for the AES operation.
dek: &'a [u8; KEY_LEN],
// KMS-wrapped form, embedded into the frame for later unwrap.
wrapped: &'a WrappedDek,
},
}
impl<'a> From<&'a SseKeyring> for SseSource<'a> {
fn from(kr: &'a SseKeyring) -> Self {
SseSource::Keyring(kr)
}
}
impl<'a> From<&'a Arc<SseKeyring>> for SseSource<'a> {
fn from(kr: &'a Arc<SseKeyring>) -> Self {
SseSource::Keyring(kr.as_ref())
}
}
impl<'a> From<&'a CustomerKeyMaterial> for SseSource<'a> {
fn from(m: &'a CustomerKeyMaterial) -> Self {
SseSource::CustomerKey {
key: &m.key,
key_md5: &m.key_md5,
}
}
}
/// Validates the three SSE-C request headers and derives key material.
///
/// Checks, in order: the algorithm header is exactly `AES256`; the key
/// base64-decodes to 32 bytes; the supplied MD5 base64-decodes to 16 bytes;
/// and the supplied MD5 matches the MD5 of the decoded key.
///
/// # Errors
/// `CustomerKeyAlgorithmUnsupported`, or `InvalidCustomerKey` with a
/// static reason string describing which check failed.
pub fn parse_customer_key_headers(
algorithm: &str,
key_base64: &str,
key_md5_base64: &str,
) -> Result<CustomerKeyMaterial, SseError> {
use base64::Engine as _;
if algorithm != SSE_C_ALGORITHM {
return Err(SseError::CustomerKeyAlgorithmUnsupported {
algo: algorithm.to_string(),
});
}
let key_bytes = base64::engine::general_purpose::STANDARD
.decode(key_base64.trim().as_bytes())
.map_err(|_| SseError::InvalidCustomerKey {
reason: "base64 decode of key",
})?;
if key_bytes.len() != KEY_LEN {
return Err(SseError::InvalidCustomerKey {
reason: "key length (must be 32 bytes after base64 decode)",
});
}
let supplied_md5 = base64::engine::general_purpose::STANDARD
.decode(key_md5_base64.trim().as_bytes())
.map_err(|_| SseError::InvalidCustomerKey {
reason: "base64 decode of key MD5",
})?;
if supplied_md5.len() != KEY_MD5_LEN {
return Err(SseError::InvalidCustomerKey {
reason: "key MD5 length (must be 16 bytes after base64 decode)",
});
}
// Constant-time compare avoids leaking how many fingerprint bytes matched.
let actual_md5 = compute_key_md5(&key_bytes);
if !constant_time_eq(&actual_md5, &supplied_md5) {
return Err(SseError::InvalidCustomerKey {
reason: "supplied MD5 does not match MD5 of supplied key",
});
}
let mut key = [0u8; KEY_LEN];
key.copy_from_slice(&key_bytes);
// Store the locally-computed MD5 (== supplied at this point).
let mut key_md5 = [0u8; KEY_MD5_LEN];
key_md5.copy_from_slice(&actual_md5);
Ok(CustomerKeyMaterial { key, key_md5 })
}
/// MD5 fingerprint of `key` — used only as an integrity/identity check on
/// SSE-C keys, not for any cryptographic protection.
pub fn compute_key_md5(key: &[u8]) -> [u8; KEY_MD5_LEN] {
    let mut hasher = Md5::new();
    hasher.update(key);
    let digest = hasher.finalize();
    let mut out = [0u8; KEY_MD5_LEN];
    out.copy_from_slice(&digest);
    out
}
/// Compares two byte slices in time independent of where they differ.
///
/// Mismatched lengths return `false` immediately — length is not secret.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR all byte differences together; zero iff every byte matched.
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
pub fn encrypt_with_source(plaintext: &[u8], source: SseSource<'_>) -> Bytes {
match source {
SseSource::Keyring(kr) => encrypt_v2(plaintext, kr),
SseSource::CustomerKey { key, key_md5 } => encrypt_v3(plaintext, key, key_md5),
SseSource::Kms { dek, wrapped } => encrypt_v4(plaintext, dek, wrapped),
}
}
/// Encrypts `plaintext` into an S4E3 (SSE-C) frame:
/// `magic(4) | algo(1) | key_md5(16) | nonce(12) | tag(16) | ciphertext`.
/// The key MD5 stored in the header lets a later GET detect a wrong key
/// before attempting decryption.
fn encrypt_v3(plaintext: &[u8], key: &[u8; KEY_LEN], key_md5: &[u8; KEY_MD5_LEN]) -> Bytes {
let aes_key = Key::<Aes256Gcm>::from_slice(key);
let cipher = Aes256Gcm::new(aes_key);
// Fresh random 96-bit nonce per object.
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
// The fingerprint is bound into the AAD as well as stored in the header.
let aad = aad_v3(key_md5);
let ct_with_tag = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad: &aad,
},
)
.expect("aes-gcm encrypt cannot fail with a 32-byte key");
debug_assert!(ct_with_tag.len() >= TAG_LEN);
// aes-gcm appends the tag; the frame stores tag before ciphertext.
let split = ct_with_tag.len() - TAG_LEN;
let (ct, tag) = ct_with_tag.split_at(split);
let mut out = Vec::with_capacity(SSE_HEADER_BYTES_V3 + ct.len());
out.extend_from_slice(SSE_MAGIC_V3);
out.push(ALGO_AES_256_GCM);
out.extend_from_slice(key_md5);
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(tag);
out.extend_from_slice(ct);
Bytes::from(out)
}
/// Decrypts any supported SSE frame, dispatching on the 4-byte magic.
///
/// Source requirements per frame: S4E1/S4E2 and S4E5/S4E6 need
/// `SseSource::Keyring`; S4E3 needs `SseSource::CustomerKey`; S4E4 always
/// fails with `KmsAsyncRequired` (use the async `decrypt_with_kms`).
///
/// # Errors
/// `TooShort` / `BadMagic` for malformed input; source-mismatch variants
/// (`CustomerKeyRequired` / `CustomerKeyUnexpected` / `KmsAsyncRequired`);
/// plus whatever the per-version decryptor returns.
pub fn decrypt<'a, S: Into<SseSource<'a>>>(body: &[u8], source: S) -> Result<Bytes, SseError> {
let source = source.into();
// SSE_HEADER_BYTES (36) is the smallest any valid frame can be.
if body.len() < SSE_HEADER_BYTES {
return Err(SseError::TooShort { got: body.len() });
}
let mut magic = [0u8; 4];
magic.copy_from_slice(&body[..4]);
match &magic {
m if m == SSE_MAGIC_V1 || m == SSE_MAGIC_V2 => {
let keyring = match source {
SseSource::Keyring(kr) => kr,
SseSource::CustomerKey { .. } => return Err(SseError::CustomerKeyUnexpected),
SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
};
if m == SSE_MAGIC_V1 {
decrypt_v1_with_keyring(body, keyring)
} else {
decrypt_v2_with_keyring(body, keyring)
}
}
m if m == SSE_MAGIC_V3 => {
// V3's header is larger (16-byte MD5), so re-check the length.
if body.len() < SSE_HEADER_BYTES_V3 {
return Err(SseError::TooShort { got: body.len() });
}
let (key, key_md5) = match source {
SseSource::CustomerKey { key, key_md5 } => (key, key_md5),
SseSource::Keyring(_) => return Err(SseError::CustomerKeyRequired),
SseSource::Kms { .. } => return Err(SseError::CustomerKeyRequired),
};
decrypt_v3(body, key, key_md5)
}
m if m == SSE_MAGIC_V4 => {
// KMS unwrap needs an async round-trip; this sync API cannot do it.
Err(SseError::KmsAsyncRequired)
}
m if m == SSE_MAGIC_V5 || m == SSE_MAGIC_V6 => {
let keyring = match source {
SseSource::Keyring(kr) => kr,
SseSource::CustomerKey { .. } => {
return Err(SseError::CustomerKeyUnexpected);
}
SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
};
decrypt_chunked_buffered_default(body, keyring)
}
_ => Err(SseError::BadMagic { got: magic }),
}
}
/// Decrypts an S4E3 (SSE-C) frame with the caller-supplied customer key.
///
/// Caller (`decrypt`) has already verified the magic and that
/// `body.len() >= SSE_HEADER_BYTES_V3`, so the fixed-offset indexing here
/// cannot panic.
fn decrypt_v3(
body: &[u8],
key: &[u8; KEY_LEN],
supplied_md5: &[u8; KEY_MD5_LEN],
) -> Result<Bytes, SseError> {
let algo = body[4];
if algo != ALGO_AES_256_GCM {
return Err(SseError::UnsupportedAlgo { tag: algo });
}
// Fast pre-check: compare the stored key fingerprint before doing any
// AES work, giving a precise WrongCustomerKey error.
let mut stored_md5 = [0u8; KEY_MD5_LEN];
stored_md5.copy_from_slice(&body[5..5 + KEY_MD5_LEN]);
if !constant_time_eq(supplied_md5, &stored_md5) {
return Err(SseError::WrongCustomerKey);
}
let nonce_off = 5 + KEY_MD5_LEN;
let tag_off = nonce_off + NONCE_LEN;
let mut nonce_bytes = [0u8; NONCE_LEN];
nonce_bytes.copy_from_slice(&body[nonce_off..nonce_off + NONCE_LEN]);
let mut tag_bytes = [0u8; TAG_LEN];
tag_bytes.copy_from_slice(&body[tag_off..tag_off + TAG_LEN]);
let ct = &body[SSE_HEADER_BYTES_V3..];
// AAD uses the stored fingerprint (== supplied at this point).
let aad = aad_v3(&stored_md5);
let nonce = Nonce::from_slice(&nonce_bytes);
// aes-gcm expects ciphertext || tag as one buffer.
let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
ct_with_tag.extend_from_slice(ct);
ct_with_tag.extend_from_slice(&tag_bytes);
let aes_key = Key::<Aes256Gcm>::from_slice(key);
let cipher = Aes256Gcm::new(aes_key);
let plain = cipher
.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
)
.map_err(|_| SseError::DecryptFailed)?;
Ok(Bytes::from(plain))
}
/// AAD for S4E4 frames: magic | algo | key_id_len | key_id |
/// wrapped_len (u32 BE) | wrapped_dek — i.e. the entire KMS header, so
/// swapping key ids or wrapped DEKs between frames breaks authentication.
fn aad_v4(key_id: &[u8], wrapped_dek: &[u8]) -> Vec<u8> {
    let capacity = 4 + 1 + 1 + key_id.len() + 4 + wrapped_dek.len();
    let mut out = Vec::with_capacity(capacity);
    out.extend_from_slice(SSE_MAGIC_V4);
    out.push(ALGO_AES_256_GCM);
    out.push(key_id.len() as u8);
    out.extend_from_slice(key_id);
    out.extend_from_slice(&(wrapped_dek.len() as u32).to_be_bytes());
    out.extend_from_slice(wrapped_dek);
    out
}
/// Encrypts `plaintext` into an S4E4 (SSE-KMS) frame:
/// `magic(4) | algo(1) | key_id_len(1) | key_id | wrapped_len(4 BE) |
/// wrapped_dek | nonce(12) | tag(16) | ciphertext`.
///
/// # Panics
/// Asserts the key_id is 1..=255 bytes and the wrapped DEK fits in a u32.
/// These are producer-side inputs (our own KMS wrap), so violating them is
/// a programming error, not a request error.
fn encrypt_v4(plaintext: &[u8], dek: &[u8; KEY_LEN], wrapped: &WrappedDek) -> Bytes {
assert!(
!wrapped.key_id.is_empty() && wrapped.key_id.len() <= u8::MAX as usize,
"S4E4 key_id must be 1..=255 bytes (got {})",
wrapped.key_id.len()
);
assert!(
wrapped.ciphertext.len() <= u32::MAX as usize,
"S4E4 wrapped_dek longer than u32::MAX",
);
let aes_key = Key::<Aes256Gcm>::from_slice(dek);
let cipher = Aes256Gcm::new(aes_key);
// Fresh random 96-bit nonce per object.
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
// Whole KMS header (key id + wrapped DEK) is authenticated via AAD.
let aad = aad_v4(wrapped.key_id.as_bytes(), &wrapped.ciphertext);
let ct_with_tag = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad: &aad,
},
)
.expect("aes-gcm encrypt cannot fail with a 32-byte key");
debug_assert!(ct_with_tag.len() >= TAG_LEN);
// aes-gcm appends the tag; the frame stores tag before ciphertext.
let split = ct_with_tag.len() - TAG_LEN;
let (ct, tag) = ct_with_tag.split_at(split);
let key_id_bytes = wrapped.key_id.as_bytes();
let mut out = Vec::with_capacity(
4 + 1
+ 1
+ key_id_bytes.len()
+ 4
+ wrapped.ciphertext.len()
+ NONCE_LEN
+ TAG_LEN
+ ct.len(),
);
out.extend_from_slice(SSE_MAGIC_V4);
out.push(ALGO_AES_256_GCM);
out.push(key_id_bytes.len() as u8);
out.extend_from_slice(key_id_bytes);
out.extend_from_slice(&(wrapped.ciphertext.len() as u32).to_be_bytes());
out.extend_from_slice(&wrapped.ciphertext);
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(tag);
out.extend_from_slice(ct);
Bytes::from(out)
}
/// Borrowed view of a parsed S4E4 frame; every field points into the
/// original body buffer (zero-copy).
#[derive(Debug)]
pub struct S4E4Header<'a> {
pub key_id: &'a str,
pub wrapped_dek: &'a [u8],
pub nonce: &'a [u8],
pub tag: &'a [u8],
pub ciphertext: &'a [u8],
}
/// Parses an S4E4 (SSE-KMS) frame, returning borrowed views of each field.
///
/// Frame layout:
/// `magic(4) | algo(1) | key_id_len(1) | key_id | wrapped_len(4 BE) |
/// wrapped_dek | nonce(12) | tag(16) | ciphertext`.
///
/// # Errors
/// Only structural errors (`KmsFrameTooShort`, `BadMagic`,
/// `UnsupportedAlgo`, `KmsFrameFieldOob`, `KmsKeyIdNotUtf8`) — the
/// cryptographic check happens later in `decrypt_with_kms`.
pub fn parse_s4e4_header(body: &[u8]) -> Result<S4E4Header<'_>, SseError> {
    // Smallest frame: empty key_id, empty wrapped_dek, empty ciphertext.
    const S4E4_MIN: usize = 4 + 1 + 1 + 4 + NONCE_LEN + TAG_LEN;
    if body.len() < S4E4_MIN {
        return Err(SseError::KmsFrameTooShort {
            got: body.len(),
            min: S4E4_MIN,
        });
    }
    let magic = &body[..4];
    if magic != SSE_MAGIC_V4 {
        let mut got = [0u8; 4];
        got.copy_from_slice(magic);
        return Err(SseError::BadMagic { got });
    }
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }
    let key_id_len = body[5] as usize;
    let key_id_off: usize = 6;
    let key_id_end = key_id_off
        .checked_add(key_id_len)
        .ok_or(SseError::KmsFrameFieldOob { what: "key_id_len" })?;
    // key_id_end <= 6 + 255, so `+ 4` cannot overflow here.
    if key_id_end + 4 > body.len() {
        return Err(SseError::KmsFrameFieldOob { what: "key_id" });
    }
    let key_id = std::str::from_utf8(&body[key_id_off..key_id_end])
        .map_err(|_| SseError::KmsKeyIdNotUtf8)?;
    let wrapped_len_off = key_id_end;
    let wrapped_dek_len = u32::from_be_bytes([
        body[wrapped_len_off],
        body[wrapped_len_off + 1],
        body[wrapped_len_off + 2],
        body[wrapped_len_off + 3],
    ]) as usize;
    let wrapped_off = wrapped_len_off + 4;
    let wrapped_end =
        wrapped_off
            .checked_add(wrapped_dek_len)
            .ok_or(SseError::KmsFrameFieldOob {
                what: "wrapped_dek_len",
            })?;
    // Checked add: a hostile wrapped_dek_len near usize::MAX must not wrap
    // the trailer bound around and slip past the length check (the previous
    // unchecked `wrapped_end + NONCE_LEN + TAG_LEN` could, in release mode).
    let trailer_end = wrapped_end
        .checked_add(NONCE_LEN + TAG_LEN)
        .ok_or(SseError::KmsFrameFieldOob {
            what: "wrapped_dek",
        })?;
    if trailer_end > body.len() {
        return Err(SseError::KmsFrameFieldOob {
            what: "wrapped_dek",
        });
    }
    let wrapped_dek = &body[wrapped_off..wrapped_end];
    let nonce_off = wrapped_end;
    let tag_off = nonce_off + NONCE_LEN;
    let ct_off = tag_off + TAG_LEN;
    let nonce = &body[nonce_off..nonce_off + NONCE_LEN];
    let tag = &body[tag_off..tag_off + TAG_LEN];
    let ciphertext = &body[ct_off..];
    Ok(S4E4Header {
        key_id,
        wrapped_dek,
        nonce,
        tag,
        ciphertext,
    })
}
/// Decrypts an S4E4 (SSE-KMS) frame: parses the header, asks the KMS
/// backend to unwrap the embedded DEK, then AES-GCM-decrypts the payload.
///
/// # Errors
/// Structural errors from `parse_s4e4_header`; `KmsBackend` when the unwrap
/// fails or returns a DEK of the wrong length; `DecryptFailed` when
/// authentication fails.
pub async fn decrypt_with_kms(body: &[u8], kms: &dyn KmsBackend) -> Result<Bytes, SseError> {
let hdr = parse_s4e4_header(body)?;
// Rebuild the WrappedDek the backend expects from the frame's fields.
let wrapped = WrappedDek {
key_id: hdr.key_id.to_string(),
ciphertext: hdr.wrapped_dek.to_vec(),
};
let dek_vec = kms.decrypt_dek(&wrapped).await?;
// A wrong-sized DEK is treated as a backend fault, not a frame fault.
if dek_vec.len() != KEY_LEN {
return Err(SseError::KmsBackend(KmsError::BackendUnavailable {
message: format!(
"KMS returned {} byte DEK; expected {KEY_LEN}",
dek_vec.len()
),
}));
}
let mut dek = [0u8; KEY_LEN];
dek.copy_from_slice(&dek_vec);
// AAD must match what encrypt_v4 bound: key id + wrapped DEK.
let aad = aad_v4(hdr.key_id.as_bytes(), hdr.wrapped_dek);
let aes_key = Key::<Aes256Gcm>::from_slice(&dek);
let cipher = Aes256Gcm::new(aes_key);
let nonce = Nonce::from_slice(hdr.nonce);
// aes-gcm expects ciphertext || tag as one buffer.
let mut ct_with_tag = Vec::with_capacity(hdr.ciphertext.len() + TAG_LEN);
ct_with_tag.extend_from_slice(hdr.ciphertext);
ct_with_tag.extend_from_slice(hdr.tag);
let plain = cipher
.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
)
.map_err(|_| SseError::DecryptFailed)?;
Ok(Bytes::from(plain))
}
/// Decrypts a legacy S4E1 frame.
///
/// V1 frames carry no key id, so every key in the keyring is tried —
/// active key first — until one authenticates. Caller (`decrypt`) has
/// already verified the magic and `body.len() >= SSE_HEADER_BYTES`.
fn decrypt_v1_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
let algo = body[4];
if algo != ALGO_AES_256_GCM {
return Err(SseError::UnsupportedAlgo { tag: algo });
}
// Header: magic(4) | algo(1) | 3 reserved bytes, then nonce and tag.
let mut nonce_bytes = [0u8; NONCE_LEN];
nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
let mut tag_bytes = [0u8; TAG_LEN];
tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
let ct = &body[SSE_HEADER_BYTES..];
let aad = aad_v1();
let nonce = Nonce::from_slice(&nonce_bytes);
// aes-gcm expects ciphertext || tag as one buffer; build it once and
// reuse it for every key attempt.
let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
ct_with_tag.extend_from_slice(ct);
ct_with_tag.extend_from_slice(&tag_bytes);
let (active_id, _active_key) = keyring.active();
// Order only affects performance: the active key is the most likely hit.
let mut ids: Vec<u16> = keyring.keys.keys().copied().collect();
ids.sort_by_key(|id| if *id == active_id { 0 } else { 1 });
for id in ids {
let key = keyring.get(id).expect("id came from keyring iteration");
let cipher = Aes256Gcm::new(key.as_aes_key());
if let Ok(plain) = cipher.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
) {
return Ok(Bytes::from(plain));
}
}
// No key authenticated: wrong keyring or tampered ciphertext.
Err(SseError::DecryptFailed)
}
/// Decrypts an S4E2 frame using the key id stored in the header.
///
/// Caller (`decrypt`) has already verified the magic and
/// `body.len() >= SSE_HEADER_BYTES`.
fn decrypt_v2_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
let algo = body[4];
if algo != ALGO_AES_256_GCM {
return Err(SseError::UnsupportedAlgo { tag: algo });
}
// Bytes 5-6 hold the key id; byte 7 is padding.
let key_id = u16::from_be_bytes([body[5], body[6]]);
let key = keyring
.get(key_id)
.ok_or(SseError::KeyNotInKeyring { id: key_id })?;
let mut nonce_bytes = [0u8; NONCE_LEN];
nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
let mut tag_bytes = [0u8; TAG_LEN];
tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
let ct = &body[SSE_HEADER_BYTES..];
// AAD includes the key id, so a relabelled frame fails authentication.
let aad = aad_v2(key_id);
let nonce = Nonce::from_slice(&nonce_bytes);
// aes-gcm expects ciphertext || tag as one buffer.
let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
ct_with_tag.extend_from_slice(ct);
ct_with_tag.extend_from_slice(&tag_bytes);
let cipher = Aes256Gcm::new(key.as_aes_key());
let plain = cipher
.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
)
.map_err(|_| SseError::DecryptFailed)?;
Ok(Bytes::from(plain))
}
pub fn looks_encrypted(body: &[u8]) -> bool {
if body.len() < SSE_HEADER_BYTES {
return false;
}
let m = &body[..4];
m == SSE_MAGIC_V1
|| m == SSE_MAGIC_V2
|| m == SSE_MAGIC_V3
|| m == SSE_MAGIC_V4
|| m == SSE_MAGIC_V5
|| m == SSE_MAGIC_V6
}
/// Returns the frame-version name ("S4E1".."S4E6") if `body` is long enough
/// and starts with a known magic; `None` otherwise.
pub fn peek_magic(body: &[u8]) -> Option<&'static str> {
    if body.len() < SSE_HEADER_BYTES {
        return None;
    }
    let magic = &body[..4];
    if magic == SSE_MAGIC_V1 {
        Some("S4E1")
    } else if magic == SSE_MAGIC_V2 {
        Some("S4E2")
    } else if magic == SSE_MAGIC_V3 {
        Some("S4E3")
    } else if magic == SSE_MAGIC_V4 {
        Some("S4E4")
    } else if magic == SSE_MAGIC_V5 {
        Some("S4E5")
    } else if magic == SSE_MAGIC_V6 {
        Some("S4E6")
    } else {
        None
    }
}
/// Cheaply-clonable shared handle to a single key.
pub type SharedSseKey = Arc<SseKey>;
// S4E5 header: magic(4) | algo(1) | key_id(2) | pad(1) | chunk_size(4) |
// chunk_count(4) | salt(4).
pub const S4E5_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 4;
// Each chunk carries one GCM tag ahead of its ciphertext.
pub const S4E5_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
// S4E6 header: same layout but with an 8-byte salt.
pub const S4E6_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 8;
pub const S4E6_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
// S4E6 nonces carry the chunk index in only 24 bits, capping chunk_count.
pub const S4E6_MAX_CHUNK_COUNT: u32 = (1u32 << 24) - 1;
// Fixed markers embedded in the per-chunk nonces (see nonce_v5 / nonce_v6).
const S4E5_NONCE_TAG: [u8; 4] = [b'E', b'5', 0, 0];
const S4E6_NONCE_PREFIX: u8 = b'E';
/// Which chunked frame format a body uses (they differ only in salt width).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChunkedVariant {
    V5,
    V6,
}
impl ChunkedVariant {
    /// Size of the fixed header that precedes the first chunk.
    fn header_bytes(self) -> usize {
        match self {
            Self::V5 => S4E5_HEADER_BYTES,
            Self::V6 => S4E6_HEADER_BYTES,
        }
    }
}
/// Per-chunk AAD for S4E5 frames:
/// magic | algo | chunk_index (BE) | total_chunks (BE) | key_id (BE) | salt.
/// Binding index and total prevents chunk reorder/truncation attacks.
fn aad_v5(
    chunk_index: u32,
    total_chunks: u32,
    key_id: u16,
    salt: &[u8; 4],
) -> [u8; 4 + 1 + 4 + 4 + 2 + 4] {
    let mut out = [0u8; 4 + 1 + 4 + 4 + 2 + 4];
    out[0..4].copy_from_slice(SSE_MAGIC_V5);
    out[4] = ALGO_AES_256_GCM;
    out[5..9].copy_from_slice(&chunk_index.to_be_bytes());
    out[9..13].copy_from_slice(&total_chunks.to_be_bytes());
    out[13..15].copy_from_slice(&key_id.to_be_bytes());
    out[15..].copy_from_slice(salt);
    out
}
/// Per-chunk AAD for S4E6 frames — identical to `aad_v5` except for the
/// magic and the wider (8-byte) salt.
fn aad_v6(
    chunk_index: u32,
    total_chunks: u32,
    key_id: u16,
    salt: &[u8; 8],
) -> [u8; 4 + 1 + 4 + 4 + 2 + 8] {
    let mut out = [0u8; 4 + 1 + 4 + 4 + 2 + 8];
    out[0..4].copy_from_slice(SSE_MAGIC_V6);
    out[4] = ALGO_AES_256_GCM;
    out[5..9].copy_from_slice(&chunk_index.to_be_bytes());
    out[9..13].copy_from_slice(&total_chunks.to_be_bytes());
    out[13..15].copy_from_slice(&key_id.to_be_bytes());
    out[15..].copy_from_slice(salt);
    out
}
/// S4E5 chunk nonce: fixed tag(4) | salt(4) | chunk_index (u32 BE) —
/// unique per (object salt, chunk) pair.
fn nonce_v5(salt: &[u8; 4], chunk_index: u32) -> [u8; NONCE_LEN] {
    let mut out = [0u8; NONCE_LEN];
    out[0..4].copy_from_slice(&S4E5_NONCE_TAG);
    out[4..8].copy_from_slice(salt);
    out[8..].copy_from_slice(&chunk_index.to_be_bytes());
    out
}
/// S4E6 chunk nonce: prefix(1) | salt(8) | low 24 bits of chunk_index (BE).
/// Only 3 index bytes fit, hence the 24-bit chunk-count cap.
fn nonce_v6(salt: &[u8; 8], chunk_index: u32) -> [u8; NONCE_LEN] {
    debug_assert!(
        chunk_index <= S4E6_MAX_CHUNK_COUNT,
        "S4E6 chunk_index {chunk_index} exceeds 24-bit cap (caller MUST validate)",
    );
    let idx = chunk_index.to_be_bytes();
    let mut out = [0u8; NONCE_LEN];
    out[0] = S4E6_NONCE_PREFIX;
    out[1..9].copy_from_slice(salt);
    // Drop the (validated-zero) high byte of the index.
    out[9] = idx[1];
    out[10] = idx[2];
    out[11] = idx[3];
    out
}
/// Encrypts `plaintext` as a chunked frame using the keyring's active key.
///
/// NOTE(review): despite the `_v2` name, this emits the S4E6 layout —
/// presumably the name is historical; confirm before renaming.
///
/// Frame: S4E6 header, then per chunk `tag(16) | ciphertext` where every
/// chunk is `chunk_size` plaintext bytes except possibly the last. Empty
/// plaintext is encoded as one empty chunk so the frame still carries at
/// least one authentication tag.
///
/// # Errors
/// `ChunkSizeInvalid` for a zero chunk size; `ChunkCountTooLarge` when the
/// count exceeds the 24-bit nonce budget.
pub fn encrypt_v2_chunked(
plaintext: &[u8],
keyring: &SseKeyring,
chunk_size: usize,
) -> Result<Bytes, SseError> {
if chunk_size == 0 {
return Err(SseError::ChunkSizeInvalid);
}
let (key_id, key) = keyring.active();
let cipher = Aes256Gcm::new(key.as_aes_key());
// Random per-object salt feeds both the AAD and every chunk nonce.
let mut salt = [0u8; 8];
rand::rngs::OsRng.fill_bytes(&mut salt);
let chunk_count_usize = if plaintext.is_empty() {
1
} else {
plaintext.len().div_ceil(chunk_size)
};
// Saturate to u32::MAX so the cap check below still fires on overflow.
let chunk_count: u32 = u32::try_from(chunk_count_usize).unwrap_or(u32::MAX);
if chunk_count > S4E6_MAX_CHUNK_COUNT {
return Err(SseError::ChunkCountTooLarge {
got: chunk_count,
max: S4E6_MAX_CHUNK_COUNT,
});
}
let mut out = Vec::with_capacity(
S4E6_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E6_PER_CHUNK_OVERHEAD),
);
out.extend_from_slice(SSE_MAGIC_V6);
out.push(ALGO_AES_256_GCM);
out.extend_from_slice(&key_id.to_be_bytes());
// Pad byte, then geometry fields and salt.
out.push(0u8); out.extend_from_slice(&(chunk_size as u32).to_be_bytes());
out.extend_from_slice(&chunk_count.to_be_bytes());
out.extend_from_slice(&salt);
for i in 0..chunk_count {
let off = (i as usize).saturating_mul(chunk_size);
let end = off.saturating_add(chunk_size).min(plaintext.len());
// The single chunk of an empty object gets an empty plaintext slice.
let chunk_pt: &[u8] = if off >= plaintext.len() {
&[]
} else {
&plaintext[off..end]
};
let nonce_bytes = nonce_v6(&salt, i);
let nonce = Nonce::from_slice(&nonce_bytes);
// AAD binds index, total count, key id, and salt: chunks cannot be
// reordered, dropped, or spliced between objects undetected.
let aad = aad_v6(i, chunk_count, key_id, &salt);
let ct_with_tag = cipher
.encrypt(
nonce,
Payload {
msg: chunk_pt,
aad: &aad,
},
)
.expect("aes-gcm encrypt cannot fail with a 32-byte key");
debug_assert!(ct_with_tag.len() >= TAG_LEN);
let split = ct_with_tag.len() - TAG_LEN;
let (ct, tag) = ct_with_tag.split_at(split);
// Tag stored before ciphertext, mirroring the non-chunked frames.
out.extend_from_slice(tag);
out.extend_from_slice(ct);
crate::metrics::record_sse_streaming_chunk("encrypt");
}
Ok(Bytes::from(out))
}
/// Per-object salt, width depending on frame variant.
#[derive(Debug, Clone, Copy)]
enum ChunkedSalt {
V5([u8; 4]),
V6([u8; 8]),
}
/// Validated, owned view of an S4E5/S4E6 header (internal to this module).
#[derive(Debug, Clone, Copy)]
struct ChunkedHeader {
#[allow(dead_code)]
variant: ChunkedVariant,
key_id: u16,
chunk_size: u32,
chunk_count: u32,
salt: ChunkedSalt,
// Byte offset of the first chunk (== the variant's header size).
chunks_offset: usize,
}
/// Public, borrowed view of an S4E6 header (see `parse_s4e6_header`).
#[derive(Debug, Clone, Copy)]
pub struct S4E6Header<'a> {
pub key_id: u16,
pub chunk_size: u32,
pub chunk_count: u32,
pub salt: &'a [u8; 8],
}
/// Parses and validates an S4E6 header only (no chunk walking, no size
/// budgeting — see `parse_chunked_header` for the full validation).
///
/// # Errors
/// `ChunkFrameTruncated`, `BadMagic`, `UnsupportedAlgo`,
/// `ChunkSizeInvalid`, or `ChunkCountTooLarge`.
pub fn parse_s4e6_header(blob: &[u8]) -> Result<S4E6Header<'_>, SseError> {
if blob.len() < S4E6_HEADER_BYTES {
return Err(SseError::ChunkFrameTruncated { what: "header" });
}
if &blob[..4] != SSE_MAGIC_V6 {
let mut got = [0u8; 4];
got.copy_from_slice(&blob[..4]);
return Err(SseError::BadMagic { got });
}
let algo = blob[4];
if algo != ALGO_AES_256_GCM {
return Err(SseError::UnsupportedAlgo { tag: algo });
}
let key_id = u16::from_be_bytes([blob[5], blob[6]]);
// blob[7] is a pad byte.
let chunk_size = u32::from_be_bytes([blob[8], blob[9], blob[10], blob[11]]);
let chunk_count = u32::from_be_bytes([blob[12], blob[13], blob[14], blob[15]]);
if chunk_size == 0 {
return Err(SseError::ChunkSizeInvalid);
}
if chunk_count == 0 {
return Err(SseError::ChunkFrameTruncated {
what: "chunk_count == 0",
});
}
// S4E6 nonces only hold 24 bits of chunk index.
if chunk_count > S4E6_MAX_CHUNK_COUNT {
return Err(SseError::ChunkCountTooLarge {
got: chunk_count,
max: S4E6_MAX_CHUNK_COUNT,
});
}
// Length was verified above, so this 8-byte conversion cannot fail.
let salt: &[u8; 8] = (&blob[16..24]).try_into().expect("8B salt slice");
Ok(S4E6Header {
key_id,
chunk_size,
chunk_count,
salt,
})
}
/// Validates an S4E5/S4E6 frame header and its declared geometry.
///
/// Rejects zero chunk_size/chunk_count, V6 counts past the 24-bit cap,
/// bodies longer than the declared geometry allows, and declared plaintext
/// larger than `max_body_bytes`. Does not touch the chunks themselves.
fn parse_chunked_header(body: &[u8], max_body_bytes: usize) -> Result<ChunkedHeader, SseError> {
if body.len() < 4 {
return Err(SseError::ChunkFrameTruncated { what: "magic" });
}
let magic = &body[..4];
let variant = if magic == SSE_MAGIC_V5 {
ChunkedVariant::V5
} else if magic == SSE_MAGIC_V6 {
ChunkedVariant::V6
} else {
let mut got = [0u8; 4];
got.copy_from_slice(magic);
return Err(SseError::BadMagic { got });
};
let header_bytes = variant.header_bytes();
if body.len() < header_bytes {
return Err(SseError::ChunkFrameTruncated { what: "header" });
}
let algo = body[4];
if algo != ALGO_AES_256_GCM {
return Err(SseError::UnsupportedAlgo { tag: algo });
}
let key_id = u16::from_be_bytes([body[5], body[6]]);
// body[7] is a pad byte; geometry fields follow.
let chunk_size = u32::from_be_bytes([body[8], body[9], body[10], body[11]]);
let chunk_count = u32::from_be_bytes([body[12], body[13], body[14], body[15]]);
if chunk_size == 0 {
return Err(SseError::ChunkSizeInvalid);
}
if chunk_count == 0 {
return Err(SseError::ChunkFrameTruncated {
what: "chunk_count == 0",
});
}
let salt = match variant {
ChunkedVariant::V5 => {
let mut s = [0u8; 4];
s.copy_from_slice(&body[16..20]);
ChunkedSalt::V5(s)
}
ChunkedVariant::V6 => {
// Only V6 packs the chunk index into 24 nonce bits, hence the cap.
if chunk_count > S4E6_MAX_CHUNK_COUNT {
return Err(SseError::ChunkCountTooLarge {
got: chunk_count,
max: S4E6_MAX_CHUNK_COUNT,
});
}
let mut s = [0u8; 8];
s.copy_from_slice(&body[16..24]);
ChunkedSalt::V6(s)
}
};
// Geometry math in u64 with checked ops throughout: these fields are
// attacker-controlled and must not overflow.
let chunk_size_u64 = chunk_size as u64;
let chunk_count_u64 = chunk_count as u64;
let expected_plain_size =
chunk_size_u64
.checked_mul(chunk_count_u64)
.ok_or(SseError::ChunkFrameTooLarge {
details: "chunk_size * chunk_count overflows u64",
})?;
// V5 and V6 share the same per-chunk overhead (one GCM tag).
let per_chunk_overhead = S4E5_PER_CHUNK_OVERHEAD as u64; let total_tag_overhead =
per_chunk_overhead
.checked_mul(chunk_count_u64)
.ok_or(SseError::ChunkFrameTooLarge {
details: "tag_len * chunk_count overflows u64",
})?;
let max_total = expected_plain_size
.checked_add(total_tag_overhead)
.and_then(|t| t.checked_add(header_bytes as u64))
.ok_or(SseError::ChunkFrameTooLarge {
details: "header + plaintext + tag overhead overflows u64",
})?;
// A body longer than the geometry permits means smuggled trailing data.
if (body.len() as u64) > max_total {
return Err(SseError::ChunkFrameTruncated {
what: "trailing bytes past declared chunk geometry",
});
}
if expected_plain_size > max_body_bytes as u64 {
return Err(SseError::ChunkFrameTooLarge {
details: "declared plaintext exceeds gateway max_body_bytes",
});
}
Ok(ChunkedHeader {
variant,
key_id,
chunk_size,
chunk_count,
salt,
chunks_offset: header_bytes,
})
}
/// Decrypts and authenticates one chunk of a chunked frame.
///
/// The salt variant selects both the nonce derivation and the AAD layout,
/// so a V5 chunk can never authenticate under a V6 header (or vice versa).
fn decrypt_chunked_chunk(
cipher: &Aes256Gcm,
chunk_index: u32,
chunk_count: u32,
key_id: u16,
salt: &ChunkedSalt,
tag: &[u8; TAG_LEN],
ct: &[u8],
) -> Result<Bytes, SseError> {
let nonce_bytes = match salt {
ChunkedSalt::V5(s) => nonce_v5(s, chunk_index),
ChunkedSalt::V6(s) => nonce_v6(s, chunk_index),
};
let nonce = Nonce::from_slice(&nonce_bytes);
// aes-gcm expects ciphertext || tag as one buffer.
let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
ct_with_tag.extend_from_slice(ct);
ct_with_tag.extend_from_slice(tag);
let result = match salt {
ChunkedSalt::V5(s) => {
let aad = aad_v5(chunk_index, chunk_count, key_id, s);
cipher.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
)
}
ChunkedSalt::V6(s) => {
let aad = aad_v6(chunk_index, chunk_count, key_id, s);
cipher.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
)
}
};
// Report which chunk failed — useful for diagnosing partial corruption.
result
.map(Bytes::from)
.map_err(|_| SseError::ChunkAuthFailed { chunk_index })
}
/// Walks every chunk of a chunked frame in order, decrypting each and
/// passing the plaintext to `emit`. Shared by the buffered decrypt path.
///
/// # Errors
/// Header/geometry errors from `parse_chunked_header`, `KeyNotInKeyring`,
/// `ChunkFrameTruncated` on any layout violation, `ChunkAuthFailed` on a
/// bad tag, or whatever `emit` returns.
fn walk_chunked<F: FnMut(Bytes) -> Result<(), SseError>>(
body: &[u8],
keyring: &SseKeyring,
max_body_bytes: usize,
mut emit: F,
) -> Result<(), SseError> {
let hdr = parse_chunked_header(body, max_body_bytes)?;
let key = keyring
.get(hdr.key_id)
.ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
let cipher = Aes256Gcm::new(key.as_aes_key());
let mut cursor = hdr.chunks_offset;
let chunk_size = hdr.chunk_size as usize;
for i in 0..hdr.chunk_count {
// Each chunk starts with its 16-byte tag.
if cursor + TAG_LEN > body.len() {
return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
}
let tag_off = cursor;
let ct_off = tag_off + TAG_LEN;
let is_last = i + 1 == hdr.chunk_count;
// All chunks are full-size except the last, which takes whatever
// remains (and must not exceed chunk_size).
let ct_len = if is_last {
if ct_off > body.len() {
return Err(SseError::ChunkFrameTruncated {
what: "final chunk ciphertext",
});
}
let remaining = body.len() - ct_off;
if remaining > chunk_size {
return Err(SseError::ChunkFrameTruncated {
what: "trailing bytes after final chunk",
});
}
remaining
} else {
chunk_size
};
let ct_end = ct_off + ct_len;
if ct_end > body.len() {
return Err(SseError::ChunkFrameTruncated {
what: "chunk ciphertext",
});
}
let mut tag = [0u8; TAG_LEN];
tag.copy_from_slice(&body[tag_off..ct_off]);
let ct = &body[ct_off..ct_end];
let plain =
decrypt_chunked_chunk(&cipher, i, hdr.chunk_count, hdr.key_id, &hdr.salt, &tag, ct)?;
crate::metrics::record_sse_streaming_chunk("decrypt");
emit(plain)?;
cursor = ct_end;
}
// The declared chunk_count must account for every byte of the body.
if cursor != body.len() {
return Err(SseError::ChunkFrameTruncated {
what: "trailing bytes after declared chunk_count",
});
}
Ok(())
}
/// Decrypt an entire S4E5/S4E6 chunked frame into one contiguous buffer.
///
/// Fix: the output capacity was computed with an unchecked
/// `chunk_size * chunk_count` multiply of attacker-influenced header
/// fields — that can overflow (panicking in debug builds) or request a
/// huge allocation before `walk_chunked` re-validates the frame. Use a
/// saturating multiply and cap the preallocation at `max_body_bytes`;
/// `Vec::extend_from_slice` grows past the hint if ever needed, so the
/// output bytes are unchanged.
pub fn decrypt_chunked_buffered(
    body: &[u8],
    keyring: &SseKeyring,
    max_body_bytes: usize,
) -> Result<Bytes, SseError> {
    let hdr = parse_chunked_header(body, max_body_bytes)?;
    let cap = (hdr.chunk_size as usize)
        .saturating_mul(hdr.chunk_count as usize)
        .min(max_body_bytes);
    let mut out = Vec::with_capacity(cap);
    walk_chunked(body, keyring, max_body_bytes, |chunk| {
        out.extend_from_slice(&chunk);
        Ok(())
    })?;
    Ok(Bytes::from(out))
}
/// Buffered S4E5/S4E6 decrypt using the gateway-wide default body cap.
///
/// Convenience wrapper over [`decrypt_chunked_buffered`] with
/// `DEFAULT_MAX_BODY_BYTES`.
pub fn decrypt_chunked_buffered_default(
    body: &[u8],
    keyring: &SseKeyring,
) -> Result<Bytes, SseError> {
    decrypt_chunked_buffered(body, keyring, DEFAULT_MAX_BODY_BYTES)
}
/// Decrypt an S4E5/S4E6 chunked frame as a stream, one plaintext chunk per
/// item.
///
/// Header/key resolution happens eagerly; a setup failure yields a
/// single-item error stream. Note the declared-plaintext cap is not applied
/// here (`usize::MAX`) — `body` is already fully resident in memory, and the
/// buffered path is where the gateway limit is enforced.
pub fn decrypt_chunked_stream(
    body: bytes::Bytes,
    keyring: &SseKeyring,
) -> impl futures::Stream<Item = Result<Bytes, SseError>> + 'static {
    use futures::stream::{self, StreamExt};
    let setup = parse_chunked_header(&body, usize::MAX).and_then(|hdr| {
        let key = keyring
            .get(hdr.key_id)
            .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
        Ok((hdr, Aes256Gcm::new(key.as_aes_key())))
    });
    match setup {
        // Surface the setup error as the stream's only item.
        Err(e) => stream::iter([Err(e)]).left_stream(),
        Ok((hdr, cipher)) => {
            let start = hdr.chunks_offset;
            stream::try_unfold(
                ChunkedDecryptState {
                    body,
                    cipher,
                    cursor: start,
                    next_index: 0,
                    hdr,
                },
                decrypt_next_chunk,
            )
            .right_stream()
        }
    }
}
/// Unfold state for the streaming chunked decryptor
/// (threaded through `decrypt_next_chunk` by `stream::try_unfold`).
struct ChunkedDecryptState {
    // Full encrypted frame, header included.
    body: bytes::Bytes,
    // AES-256-GCM cipher already keyed from the frame's key_id.
    cipher: Aes256Gcm,
    // Parsed frame header (chunk_size, chunk_count, salt, offsets).
    hdr: ChunkedHeader,
    // Byte offset of the next chunk's tag within `body`.
    cursor: usize,
    // Index of the next chunk to decrypt, in [0, hdr.chunk_count].
    next_index: u32,
}
/// `try_unfold` step: decrypt the chunk at `state.next_index`, or finish.
///
/// Mirrors the framing rules of `walk_chunked`: tag-first layout, fixed
/// `chunk_size` ciphertext for all but the final chunk, and an exact-length
/// check once the declared chunk count has been consumed.
async fn decrypt_next_chunk(
    mut state: ChunkedDecryptState,
) -> Result<Option<(Bytes, ChunkedDecryptState)>, SseError> {
    // All chunks emitted: the cursor must have consumed the body exactly.
    if state.next_index >= state.hdr.chunk_count {
        return if state.cursor == state.body.len() {
            Ok(None)
        } else {
            Err(SseError::ChunkFrameTruncated {
                what: "trailing bytes after declared chunk_count",
            })
        };
    }
    let idx = state.next_index;
    let body_len = state.body.len();
    let chunk_size = state.hdr.chunk_size as usize;
    if state.cursor + TAG_LEN > body_len {
        return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
    }
    let tag_start = state.cursor;
    let ct_start = tag_start + TAG_LEN;
    let ct_len = if idx + 1 == state.hdr.chunk_count {
        if ct_start > body_len {
            return Err(SseError::ChunkFrameTruncated {
                what: "final chunk ciphertext",
            });
        }
        let rest = body_len - ct_start;
        // Final chunk may be short, never longer than chunk_size.
        if rest > chunk_size {
            return Err(SseError::ChunkFrameTruncated {
                what: "trailing bytes after final chunk",
            });
        }
        rest
    } else {
        chunk_size
    };
    let ct_end = ct_start + ct_len;
    if ct_end > body_len {
        return Err(SseError::ChunkFrameTruncated {
            what: "chunk ciphertext",
        });
    }
    let mut tag = [0u8; TAG_LEN];
    tag.copy_from_slice(&state.body[tag_start..ct_start]);
    let plain = decrypt_chunked_chunk(
        &state.cipher,
        idx,
        state.hdr.chunk_count,
        state.hdr.key_id,
        &state.hdr.salt,
        &tag,
        &state.body[ct_start..ct_end],
    )?;
    crate::metrics::record_sse_streaming_chunk("decrypt");
    state.cursor = ct_end;
    state.next_index += 1;
    Ok(Some((plain, state)))
}
#[cfg(test)]
/// Test-only encoder for the legacy S4E5 chunked frame, used to produce
/// back-compat fixtures (new PUTs emit S4E6).
///
/// Layout: magic, algo tag, key_id (BE), reserved 0, chunk_size (BE u32),
/// chunk_count (BE u32), 4-byte salt, then per chunk: tag || ciphertext.
fn encrypt_v2_chunked_s4e5_for_test(
    plaintext: &[u8],
    keyring: &SseKeyring,
    chunk_size: usize,
) -> Result<Bytes, SseError> {
    if chunk_size == 0 {
        return Err(SseError::ChunkSizeInvalid);
    }
    let (key_id, key) = keyring.active();
    let cipher = Aes256Gcm::new(key.as_aes_key());
    let mut salt = [0u8; 4];
    rand::rngs::OsRng.fill_bytes(&mut salt);
    // Empty bodies still get one (empty) chunk so the frame carries a tag.
    let chunk_count: u32 = if plaintext.is_empty() {
        1
    } else {
        plaintext
            .len()
            .div_ceil(chunk_size)
            .try_into()
            .expect("chunk_count overflows u32")
    };
    let mut frame = Vec::with_capacity(
        S4E5_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E5_PER_CHUNK_OVERHEAD),
    );
    frame.extend_from_slice(SSE_MAGIC_V5);
    frame.push(ALGO_AES_256_GCM);
    frame.extend_from_slice(&key_id.to_be_bytes());
    frame.push(0u8); // reserved byte
    frame.extend_from_slice(&(chunk_size as u32).to_be_bytes());
    frame.extend_from_slice(&chunk_count.to_be_bytes());
    frame.extend_from_slice(&salt);
    for idx in 0..chunk_count {
        let start = (idx as usize).saturating_mul(chunk_size);
        let stop = start.saturating_add(chunk_size).min(plaintext.len());
        let pt_slice: &[u8] = if start >= plaintext.len() {
            &[]
        } else {
            &plaintext[start..stop]
        };
        let nonce_bytes = nonce_v5(&salt, idx);
        let aad = aad_v5(idx, chunk_count, key_id, &salt);
        let sealed = cipher
            .encrypt(
                Nonce::from_slice(&nonce_bytes),
                Payload {
                    msg: pt_slice,
                    aad: &aad,
                },
            )
            .expect("aes-gcm encrypt cannot fail with a 32-byte key");
        // aes-gcm returns ciphertext||tag; the frame stores tag first.
        let (ct, tag) = sealed.split_at(sealed.len() - TAG_LEN);
        frame.extend_from_slice(tag);
        frame.extend_from_slice(ct);
    }
    Ok(Bytes::from(frame))
}
#[cfg(test)]
mod tests {
use super::*;
// --- shared fixtures --------------------------------------------------
// 32-byte key whose bytes are all `seed`.
fn key32(seed: u8) -> Arc<SseKey> {
    Arc::new(SseKey::from_bytes(&[seed; 32]).unwrap())
}
// Keyring holding a single key under id 1.
fn keyring_single(seed: u8) -> SseKeyring {
    SseKeyring::new(1, key32(seed))
}
// --- S4E1 / S4E2 whole-object frame tests -----------------------------
#[test]
fn roundtrip_basic_v1() {
    let k = SseKey::from_bytes(&[7u8; 32]).unwrap();
    let pt = b"the quick brown fox jumps over the lazy dog";
    let ct = encrypt(&k, pt);
    assert!(looks_encrypted(&ct));
    assert_eq!(&ct[..4], SSE_MAGIC_V1);
    assert_eq!(ct[4], ALGO_AES_256_GCM);
    // V1 frame adds exactly the fixed header overhead.
    assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
    let kr = SseKeyring::new(1, Arc::new(k));
    let pt2 = decrypt(&ct, &kr).unwrap();
    assert_eq!(pt2.as_ref(), pt);
}
#[test]
fn s4e2_roundtrip_active_key() {
    let kr = keyring_single(7);
    let pt = b"S4E2 active-key roundtrip";
    let ct = encrypt_v2(pt, &kr);
    assert_eq!(&ct[..4], SSE_MAGIC_V2);
    assert_eq!(ct[4], ALGO_AES_256_GCM);
    // Bytes 5..7 carry the big-endian key id; byte 7 is reserved.
    assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE");
    assert_eq!(ct[7], 0, "reserved byte");
    assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
    assert!(looks_encrypted(&ct));
    let pt2 = decrypt(&ct, &kr).unwrap();
    assert_eq!(pt2.as_ref(), pt);
}
#[test]
fn decrypt_s4e1_via_active_only_keyring() {
    // Legacy S4E1 objects must stay readable through the keyring API.
    let k_arc = key32(11);
    let legacy_ct = encrypt(&k_arc, b"v0.4 vintage object");
    assert_eq!(&legacy_ct[..4], SSE_MAGIC_V1);
    let kr = SseKeyring::new(1, Arc::clone(&k_arc));
    let plain = decrypt(&legacy_ct, &kr).unwrap();
    assert_eq!(plain.as_ref(), b"v0.4 vintage object");
}
#[test]
fn decrypt_s4e2_under_old_key_after_rotation() {
    // Objects written under key 1 remain readable after the active key
    // rotates to 2, as long as key 1 stays in the ring.
    let k1 = key32(1);
    let k2 = key32(2);
    let mut kr_old = SseKeyring::new(1, Arc::clone(&k1));
    let ct = encrypt_v2(b"old-rotation object", &kr_old);
    assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
    kr_old.add(2, Arc::clone(&k2));
    let mut kr_new = SseKeyring::new(2, Arc::clone(&k2));
    kr_new.add(1, Arc::clone(&k1));
    let plain = decrypt(&ct, &kr_new).unwrap();
    assert_eq!(plain.as_ref(), b"old-rotation object");
    // New PUTs use the new active id.
    let new_ct = encrypt_v2(b"new-rotation object", &kr_new);
    assert_eq!(u16::from_be_bytes([new_ct[5], new_ct[6]]), 2);
    let plain_new = decrypt(&new_ct, &kr_new).unwrap();
    assert_eq!(plain_new.as_ref(), b"new-rotation object");
}
#[test]
fn s4e2_unknown_key_id_errors() {
    let kr = keyring_single(3); let kr_other = SseKeyring::new(99, key32(3));
    let ct = encrypt_v2(b"x", &kr_other); let err = decrypt(&ct, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::KeyNotInKeyring { id: 99 }),
        "got {err:?}"
    );
}
#[test]
fn s4e2_tampered_key_id_fails_auth() {
    // The key id is bound into the AAD: redirecting the frame to another
    // (real) key id must fail authentication, not decrypt under key 2.
    let kr = SseKeyring::new(1, key32(4));
    let mut kr_with_2 = kr.clone();
    kr_with_2.add(2, key32(5)); let mut ct = encrypt_v2(b"do not flip my key id", &kr).to_vec();
    assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
    ct[5] = 0;
    ct[6] = 2;
    let err = decrypt(&ct, &kr_with_2).unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[test]
fn s4e2_tampered_ciphertext_fails() {
    let kr = SseKeyring::new(7, key32(9));
    let mut ct = encrypt_v2(b"secret message v2", &kr).to_vec();
    let last = ct.len() - 1;
    ct[last] ^= 0x01;
    let err = decrypt(&ct, &kr).unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed));
}
#[test]
fn s4e2_tampered_algo_byte_fails() {
    let kr = SseKeyring::new(1, key32(2));
    let mut ct = encrypt_v2(b"hi", &kr).to_vec();
    ct[4] = 99;
    let err = decrypt(&ct, &kr).unwrap_err();
    assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
}
#[test]
fn wrong_key_fails_v1_via_keyring() {
    let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
    let ct = encrypt(&k1, b"secret");
    let kr_wrong = SseKeyring::new(1, Arc::new(SseKey::from_bytes(&[2u8; 32]).unwrap()));
    let err = decrypt(&ct, &kr_wrong).unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed));
}
#[test]
fn rejects_short_body() {
    let kr = SseKeyring::new(1, key32(1));
    let err = decrypt(b"short", &kr).unwrap_err();
    assert!(matches!(err, SseError::TooShort { got: 5 }));
}
#[test]
fn looks_encrypted_passthrough_returns_false() {
    // Near-miss magic ("S4F2") and empty input are plaintext passthrough.
    let f2 = b"S4F2\x01\x00\x00\x00........................................";
    assert!(!looks_encrypted(f2));
    assert!(!looks_encrypted(b""));
}
#[test]
fn looks_encrypted_detects_both_v1_and_v2() {
    let kr = SseKeyring::new(1, key32(8));
    let v1 = encrypt(&SseKey::from_bytes(&[8u8; 32]).unwrap(), b"x");
    let v2 = encrypt_v2(b"x", &kr);
    assert!(looks_encrypted(&v1));
    assert!(looks_encrypted(&v2));
}
#[test]
fn key_from_hex_string() {
    // 62 hex chars → 31 bytes → rejected; 64 hex chars → 32 bytes → ok.
    let bad =
        SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
            .unwrap_err();
    assert!(matches!(bad, SseError::BadKeyLength { .. }));
    let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
    let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
}
#[test]
fn encrypt_v2_uses_random_nonce() {
    let kr = SseKeyring::new(1, key32(3));
    let pt = b"deterministic input";
    let a = encrypt_v2(pt, &kr);
    let b = encrypt_v2(pt, &kr);
    assert_ne!(a, b, "nonce must be random per-call");
}
#[test]
fn keyring_active_and_get() {
    let k1 = key32(1);
    let k2 = key32(2);
    let mut kr = SseKeyring::new(1, Arc::clone(&k1));
    kr.add(2, Arc::clone(&k2));
    let (id, active) = kr.active();
    assert_eq!(id, 1);
    assert_eq!(active.bytes, [1u8; 32]);
    assert!(kr.get(2).is_some());
    assert!(kr.get(3).is_none());
}
use base64::Engine as _;
// SSE-C fixture: a key of repeated `seed` bytes plus its MD5 fingerprint.
fn cust_key(seed: u8) -> CustomerKeyMaterial {
    let key = [seed; KEY_LEN];
    let key_md5 = compute_key_md5(&key);
    CustomerKeyMaterial { key, key_md5 }
}
// --- S4E3 (SSE-C customer key) frame + header-parsing tests -----------
#[test]
fn s4e3_roundtrip_happy_path() {
    let m = cust_key(42);
    let pt = b"top-secret SSE-C payload";
    let ct = encrypt_with_source(
        pt,
        SseSource::CustomerKey {
            key: &m.key,
            key_md5: &m.key_md5,
        },
    );
    assert_eq!(&ct[..4], SSE_MAGIC_V3);
    assert_eq!(ct[4], ALGO_AES_256_GCM);
    // S4E3 stores the key's MD5 fingerprint after the algo byte.
    assert_eq!(&ct[5..5 + KEY_MD5_LEN], &m.key_md5);
    assert_eq!(ct.len(), SSE_HEADER_BYTES_V3 + pt.len());
    assert!(looks_encrypted(&ct));
    let plain = decrypt(
        &ct,
        SseSource::CustomerKey {
            key: &m.key,
            key_md5: &m.key_md5,
        },
    )
    .unwrap();
    assert_eq!(plain.as_ref(), pt);
    // Decrypt also accepts the material directly (via Into<SseSource>).
    let plain2 = decrypt(&ct, &m).unwrap();
    assert_eq!(plain2.as_ref(), pt);
}
#[test]
fn s4e3_wrong_key_yields_wrong_customer_key_error() {
    let m = cust_key(1);
    let other = cust_key(2);
    let ct = encrypt_with_source(b"payload", (&m).into());
    let err = decrypt(
        &ct,
        SseSource::CustomerKey {
            key: &other.key,
            key_md5: &other.key_md5,
        },
    )
    .unwrap_err();
    assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
}
#[test]
fn s4e3_tampered_stored_md5_is_caught() {
    // Flipping a bit of the stored fingerprint mismatches the supplied MD5.
    let m = cust_key(7);
    let mut ct = encrypt_with_source(b"victim payload", (&m).into()).to_vec();
    ct[5] ^= 0x55;
    let err = decrypt(
        &ct,
        SseSource::CustomerKey {
            key: &m.key,
            key_md5: &m.key_md5,
        },
    )
    .unwrap_err();
    assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
}
#[test]
fn s4e3_tampered_md5_with_matching_supplied_md5_fails_aead() {
    // Even when the attacker makes the supplied MD5 match the tampered
    // stored one, AEAD authentication still fails.
    let m = cust_key(3);
    let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
    ct[5] ^= 0xFF;
    let mut bogus_md5 = m.key_md5;
    bogus_md5[0] ^= 0xFF;
    let err = decrypt(
        &ct,
        SseSource::CustomerKey {
            key: &m.key,
            key_md5: &bogus_md5,
        },
    )
    .unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[test]
fn s4e3_tampered_ciphertext_fails_aead() {
    let m = cust_key(8);
    let mut ct = encrypt_with_source(b"sealed message", (&m).into()).to_vec();
    let last = ct.len() - 1;
    ct[last] ^= 0x01;
    let err = decrypt(&ct, &m).unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[test]
fn s4e3_tampered_algo_byte_rejected() {
    let m = cust_key(9);
    let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
    ct[4] = 99;
    let err = decrypt(&ct, &m).unwrap_err();
    assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
}
#[test]
fn s4e3_uses_random_nonce() {
    let m = cust_key(10);
    let a = encrypt_with_source(b"deterministic input", (&m).into());
    let b = encrypt_with_source(b"deterministic input", (&m).into());
    assert_ne!(a, b, "nonce must be random per-call");
}
#[test]
fn parse_customer_key_headers_happy_path() {
    let key = [11u8; KEY_LEN];
    let md5 = compute_key_md5(&key);
    let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
    let md5_b64 = base64::engine::general_purpose::STANDARD.encode(md5);
    let m = parse_customer_key_headers("AES256", &key_b64, &md5_b64).unwrap();
    assert_eq!(m.key, key);
    assert_eq!(m.key_md5, md5);
}
#[test]
fn parse_customer_key_headers_rejects_wrong_algorithm() {
    let key = [1u8; KEY_LEN];
    let md5 = compute_key_md5(&key);
    let kb = base64::engine::general_purpose::STANDARD.encode(key);
    let mb = base64::engine::general_purpose::STANDARD.encode(md5);
    let err = parse_customer_key_headers("AES128", &kb, &mb).unwrap_err();
    assert!(
        matches!(err, SseError::CustomerKeyAlgorithmUnsupported { ref algo } if algo == "AES128"),
        "got {err:?}"
    );
    // Algorithm comparison is case-sensitive, matching the S3 header spec.
    let err2 = parse_customer_key_headers("aes256", &kb, &mb).unwrap_err();
    assert!(
        matches!(err2, SseError::CustomerKeyAlgorithmUnsupported { .. }),
        "got {err2:?}"
    );
}
#[test]
fn parse_customer_key_headers_rejects_wrong_key_length() {
    let short_key = vec![5u8; 16]; let md5 = compute_key_md5(&short_key);
    let kb = base64::engine::general_purpose::STANDARD.encode(&short_key);
    let mb = base64::engine::general_purpose::STANDARD.encode(md5);
    let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
    assert!(
        matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("key length")),
        "got {err:?}"
    );
}
#[test]
fn parse_customer_key_headers_rejects_wrong_md5_length() {
    let key = [3u8; KEY_LEN];
    let kb = base64::engine::general_purpose::STANDARD.encode(key);
    let bad_md5 = vec![0u8; 15];
    let mb = base64::engine::general_purpose::STANDARD.encode(bad_md5);
    let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
    assert!(
        matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 length")),
        "got {err:?}"
    );
}
#[test]
fn parse_customer_key_headers_rejects_md5_mismatch() {
    let key = [4u8; KEY_LEN];
    let other = [5u8; KEY_LEN];
    let kb = base64::engine::general_purpose::STANDARD.encode(key);
    let wrong_md5 = compute_key_md5(&other);
    let mb = base64::engine::general_purpose::STANDARD.encode(wrong_md5);
    let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
    assert!(
        matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 does not match")),
        "got {err:?}"
    );
}
#[test]
fn parse_customer_key_headers_rejects_bad_base64() {
    let valid_key = [0u8; KEY_LEN];
    let md5 = compute_key_md5(&valid_key);
    let mb = base64::engine::general_purpose::STANDARD.encode(md5);
    let err = parse_customer_key_headers("AES256", "!!!not-base64!!!", &mb).unwrap_err();
    assert!(
        matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
        "got {err:?}"
    );
    let kb = base64::engine::general_purpose::STANDARD.encode(valid_key);
    let err2 = parse_customer_key_headers("AES256", &kb, "??not-base64??").unwrap_err();
    assert!(
        matches!(err2, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
        "got {err2:?}"
    );
}
#[test]
fn parse_customer_key_headers_trims_whitespace() {
    // Leading/trailing whitespace in the header values is tolerated.
    let key = [12u8; KEY_LEN];
    let md5 = compute_key_md5(&key);
    let kb = format!(
        " {}\n",
        base64::engine::general_purpose::STANDARD.encode(key)
    );
    let mb = format!(
        "\t{} ",
        base64::engine::general_purpose::STANDARD.encode(md5)
    );
    let m = parse_customer_key_headers("AES256", &kb, &mb).unwrap();
    assert_eq!(m.key, key);
}
#[test]
fn back_compat_decrypt_s4e1_with_keyring_source() {
    let k = key32(33);
    let legacy_ct = encrypt(&k, b"v0.4 vintage object");
    let kr = SseKeyring::new(1, Arc::clone(&k));
    let plain = decrypt(&legacy_ct, &kr).unwrap();
    assert_eq!(plain.as_ref(), b"v0.4 vintage object");
    // Explicit SseSource::Keyring behaves identically to the &kr shorthand.
    let plain2 = decrypt(&legacy_ct, SseSource::Keyring(&kr)).unwrap();
    assert_eq!(plain2.as_ref(), b"v0.4 vintage object");
}
#[test]
fn back_compat_decrypt_s4e2_with_keyring_source() {
    let kr = keyring_single(34);
    let ct = encrypt_v2(b"v0.5 #29 object", &kr);
    let plain = decrypt(&ct, &kr).unwrap();
    assert_eq!(plain.as_ref(), b"v0.5 #29 object");
    // encrypt_with_source with a keyring still emits an S4E2 frame.
    let ct2 = encrypt_with_source(b"v0.5 #29 object", SseSource::Keyring(&kr));
    assert_eq!(&ct2[..4], SSE_MAGIC_V2);
    let plain2 = decrypt(&ct2, &kr).unwrap();
    assert_eq!(plain2.as_ref(), b"v0.5 #29 object");
}
#[test]
fn s4e2_blob_with_customer_key_source_is_rejected() {
    // A server-managed (S4E2) object must not decrypt via SSE-C headers.
    let kr = keyring_single(50);
    let ct = encrypt_v2(b"server-managed object", &kr);
    let m = cust_key(99);
    let err = decrypt(
        &ct,
        SseSource::CustomerKey {
            key: &m.key,
            key_md5: &m.key_md5,
        },
    )
    .unwrap_err();
    assert!(
        matches!(err, SseError::CustomerKeyUnexpected),
        "got {err:?}"
    );
}
#[test]
fn s4e3_blob_with_keyring_source_is_rejected() {
    // An SSE-C (S4E3) object requires the customer key; keyring won't do.
    let m = cust_key(60);
    let ct = encrypt_with_source(b"customer-key object", (&m).into());
    let kr = keyring_single(60);
    let err = decrypt(&ct, &kr).unwrap_err();
    assert!(matches!(err, SseError::CustomerKeyRequired), "got {err:?}");
}
#[test]
fn looks_encrypted_detects_s4e3() {
    let m = cust_key(13);
    let ct = encrypt_with_source(b"x", (&m).into());
    assert!(looks_encrypted(&ct));
}
#[test]
fn s4e3_rejects_short_body() {
    // An S4E3 frame padded only to the (smaller) V1 header size is still
    // too short for the V3 layout.
    let mut short = Vec::new();
    short.extend_from_slice(SSE_MAGIC_V3);
    short.push(ALGO_AES_256_GCM);
    short.extend_from_slice(&[0u8; SSE_HEADER_BYTES - 5]);
    assert_eq!(short.len(), SSE_HEADER_BYTES);
    let m = cust_key(1);
    let err = decrypt(
        &short,
        SseSource::CustomerKey {
            key: &m.key,
            key_md5: &m.key_md5,
        },
    )
    .unwrap_err();
    assert!(matches!(err, SseError::TooShort { .. }), "got {err:?}");
}
#[test]
fn customer_key_material_debug_redacts_key() {
    // Debug output must never leak raw key bytes.
    let m = cust_key(99);
    let s = format!("{m:?}");
    assert!(s.contains("redacted"));
    assert!(!s.contains(&format!("{:?}", m.key.as_slice())));
}
#[test]
fn constant_time_eq_basic() {
    assert!(constant_time_eq(b"abc", b"abc"));
    assert!(!constant_time_eq(b"abc", b"abd"));
    assert!(!constant_time_eq(b"abc", b"abcd"));
    assert!(constant_time_eq(b"", b""));
}
#[test]
fn compute_key_md5_known_vector() {
    // MD5("") — the standard RFC 1321 empty-input vector.
    let got = compute_key_md5(b"");
    let expected_hex = "d41d8cd98f00b204e9800998ecf8427e";
    assert_eq!(hex_lower(&got), expected_hex);
}
use crate::kms::{KmsBackend, LocalKms};
use std::collections::HashMap;
use std::path::PathBuf;
// LocalKms fixture seeded from (key_id, KEK) pairs. The path argument is a
// dummy — NOTE(review): presumably from_keks with preloaded KEKs never reads
// the path; confirm against LocalKms.
fn local_kms_with(key_ids: &[(&str, [u8; 32])]) -> LocalKms {
    let mut keks: HashMap<String, [u8; 32]> = HashMap::new();
    for (id, k) in key_ids {
        keks.insert((*id).to_string(), *k);
    }
    LocalKms::from_keks(PathBuf::from("/tmp/none"), keks)
}
// --- S4E4 (SSE-KMS envelope) frame tests ------------------------------
#[tokio::test]
async fn s4e4_roundtrip_via_local_kms() {
    let kms = local_kms_with(&[("alpha", [42u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let pt = b"SSE-KMS envelope payload across the S4E4 frame";
    let ct = encrypt_with_source(
        pt,
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    );
    assert_eq!(&ct[..4], SSE_MAGIC_V4);
    assert_eq!(ct[4], ALGO_AES_256_GCM);
    // Byte 5 is the key-id length; the id string follows inline.
    let key_id_len = ct[5] as usize;
    assert_eq!(key_id_len, "alpha".len());
    assert_eq!(&ct[6..6 + key_id_len], b"alpha");
    assert!(looks_encrypted(&ct));
    assert_eq!(peek_magic(&ct), Some("S4E4"));
    let plain = decrypt_with_kms(&ct, &kms).await.unwrap();
    assert_eq!(plain.as_ref(), pt);
}
#[tokio::test]
async fn s4e4_tampered_key_id_fails_aead() {
    // Corrupting the embedded key id makes unwrap (or lookup) fail.
    let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let mut ct = encrypt_with_source(
        b"do not redirect",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();
    let key_id_off = 6;
    ct[key_id_off] = b'b';
    let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
    assert!(
        matches!(
            err,
            SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
                | SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
        ),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_tampered_key_id_to_real_other_id_still_fails() {
    let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let mut ct = encrypt_with_source(
        b"redirect attempt",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();
    let key_id_off = 6;
    // Overwrite "alpha" (5 bytes) with "beta_" — not a registered id.
    ct[key_id_off..key_id_off + 5].copy_from_slice(b"beta_");
    let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
    assert!(
        matches!(
            err,
            SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
        ),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_tampered_wrapped_dek_fails_unwrap() {
    let kms = local_kms_with(&[("k", [3u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let mut ct = encrypt_with_source(
        b"target body",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();
    // Flip a byte in the middle of the wrapped-DEK blob.
    let key_id_len = ct[5] as usize;
    let wrapped_len_off = 6 + key_id_len;
    let wrapped_off = wrapped_len_off + 4;
    let mid = wrapped_off + (wrapped.ciphertext.len() / 2);
    ct[mid] ^= 0xFF;
    let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
    assert!(
        matches!(
            err,
            SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
        ),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_tampered_ciphertext_fails_aead() {
    let kms = local_kms_with(&[("k", [4u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let mut ct = encrypt_with_source(
        b"sealed body",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();
    let last = ct.len() - 1;
    ct[last] ^= 0x01;
    let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[tokio::test]
async fn s4e4_uses_random_nonce_and_dek_per_put() {
    let kms = local_kms_with(&[("k", [5u8; 32])]);
    let (dek1_vec, wrapped1) = kms.generate_dek("k").await.unwrap();
    let (dek2_vec, wrapped2) = kms.generate_dek("k").await.unwrap();
    let mut dek1 = [0u8; 32];
    dek1.copy_from_slice(&dek1_vec);
    let mut dek2 = [0u8; 32];
    dek2.copy_from_slice(&dek2_vec);
    let pt = b"deterministic input";
    let a = encrypt_with_source(
        pt,
        SseSource::Kms {
            dek: &dek1,
            wrapped: &wrapped1,
        },
    );
    let b = encrypt_with_source(
        pt,
        SseSource::Kms {
            dek: &dek2,
            wrapped: &wrapped2,
        },
    );
    assert_ne!(a, b);
    let plain_a = decrypt_with_kms(&a, &kms).await.unwrap();
    let plain_b = decrypt_with_kms(&b, &kms).await.unwrap();
    assert_eq!(plain_a.as_ref(), pt);
    assert_eq!(plain_b.as_ref(), pt);
}
#[tokio::test]
async fn s4e4_sync_decrypt_returns_kms_async_required() {
    // The sync decrypt path cannot unwrap a DEK; it must signal the caller
    // to use the async KMS path instead.
    let kms = local_kms_with(&[("k", [6u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let ct = encrypt_with_source(
        b"async only",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    );
    let kr = SseKeyring::new(1, key32(0));
    let err = decrypt(&ct, &kr).unwrap_err();
    assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
}
#[test]
fn back_compat_s4e1_e2_e3_still_decrypt_via_sync() {
    let k = key32(7);
    let v1 = encrypt(&k, b"v0.4 vintage");
    let kr = SseKeyring::new(1, Arc::clone(&k));
    assert_eq!(decrypt(&v1, &kr).unwrap().as_ref(), b"v0.4 vintage");
    let v2 = encrypt_v2(b"v0.5 #29 vintage", &kr);
    assert_eq!(decrypt(&v2, &kr).unwrap().as_ref(), b"v0.5 #29 vintage");
    let m = cust_key(7);
    let v3 = encrypt_with_source(b"v0.5 #27 vintage", (&m).into());
    assert_eq!(decrypt(&v3, &m).unwrap().as_ref(), b"v0.5 #27 vintage");
}
#[test]
fn peek_magic_distinguishes_all_variants() {
    let k = key32(9);
    let v1 = encrypt(&k, b"x");
    assert_eq!(peek_magic(&v1), Some("S4E1"));
    let kr = SseKeyring::new(1, Arc::clone(&k));
    let v2 = encrypt_v2(b"x", &kr);
    assert_eq!(peek_magic(&v2), Some("S4E2"));
    let m = cust_key(9);
    let v3 = encrypt_with_source(b"x", (&m).into());
    assert_eq!(peek_magic(&v3), Some("S4E3"));
    // Synthetic S4E4 frame: peek only inspects the magic bytes.
    let mut v4 = Vec::new();
    v4.extend_from_slice(SSE_MAGIC_V4);
    v4.extend_from_slice(&[0u8; 40]);
    assert_eq!(peek_magic(&v4), Some("S4E4"));
    assert!(peek_magic(b"NOPE").is_none());
    assert!(peek_magic(b"short").is_none());
    assert!(peek_magic(&[0u8; 100]).is_none());
}
#[tokio::test]
async fn s4e4_truncated_frame_errors_cleanly() {
    let truncated = b"S4E4\x01\x05hi";
    let kms = local_kms_with(&[("k", [1u8; 32])]);
    let err = decrypt_with_kms(truncated, &kms).await.unwrap_err();
    assert!(
        matches!(err, SseError::KmsFrameTooShort { .. }),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_oob_key_id_len_errors() {
    // Declared key-id length (200) exceeds the remaining frame bytes (50).
    let mut body = Vec::new();
    body.extend_from_slice(SSE_MAGIC_V4);
    body.push(ALGO_AES_256_GCM);
    body.push(200u8); body.extend_from_slice(&[0u8; 50]);
    let kms = local_kms_with(&[("k", [1u8; 32])]);
    let err = decrypt_with_kms(&body, &kms).await.unwrap_err();
    assert!(
        matches!(err, SseError::KmsFrameFieldOob { .. }),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_via_keyring_source_into_sync_decrypt_is_kms_async_required() {
    let kms = local_kms_with(&[("k", [9u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let ct = encrypt_with_source(
        b"x",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    );
    let m = cust_key(1);
    let err = decrypt(&ct, &m).unwrap_err();
    assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
}
#[tokio::test]
async fn s4e4_looks_encrypted_passthrough_returns_false_for_synthetic() {
    // A frame whose magic is close to, but not exactly, "S4E4" must be
    // treated as plaintext passthrough by both probes.
    //
    // Fix: the borrows had been mojibake-corrupted — "&not_s4e4" was
    // HTML-entity-decoded into "¬_s4e4" (U+00AC), which does not compile.
    let mut not_s4e4 = Vec::new();
    not_s4e4.extend_from_slice(b"S4F4");
    not_s4e4.extend_from_slice(&[0u8; 60]);
    assert!(!looks_encrypted(&not_s4e4));
    assert_eq!(peek_magic(&not_s4e4), None);
}
#[tokio::test]
async fn s4e4_aad_length_prefix_prevents_byte_shifting() {
    // Shrinking the wrapped-DEK length field shifts all subsequent fields;
    // any of the listed errors is an acceptable rejection, but silent
    // success would be a vulnerability.
    let kms = local_kms_with(&[("kk", [11u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("kk").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let mut ct = encrypt_with_source(
        b"length-shift defense",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();
    let key_id_len = ct[5] as usize;
    let wrapped_len_off = 6 + key_id_len;
    let original_len = u32::from_be_bytes([
        ct[wrapped_len_off],
        ct[wrapped_len_off + 1],
        ct[wrapped_len_off + 2],
        ct[wrapped_len_off + 3],
    ]);
    let new_len = (original_len - 1).to_be_bytes();
    ct[wrapped_len_off..wrapped_len_off + 4].copy_from_slice(&new_len);
    let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
    assert!(
        matches!(
            err,
            SseError::KmsBackend(_)
                | SseError::DecryptFailed
                | SseError::KmsFrameFieldOob { .. }
                | SseError::KmsFrameTooShort { .. }
        ),
        "got {err:?}"
    );
}
use futures::StreamExt;
// Drain a decrypt stream into a Vec of chunks, stopping at the first error.
async fn collect_chunks(
    s: impl futures::Stream<Item = Result<Bytes, SseError>>,
) -> Result<Vec<Bytes>, SseError> {
    let mut out = Vec::new();
    let mut s = std::pin::pin!(s);
    while let Some(item) = s.next().await {
        out.push(item?);
    }
    Ok(out)
}
#[test]
fn s4e6_encrypt_layout_10mb_at_1mib() {
let kr = keyring_single(0x42);
let chunk_size = 1024 * 1024;
let pt_len = 10 * 1024 * 1024;
let pt = vec![0xAB_u8; pt_len];
let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).expect("encrypt ok");
assert_eq!(&ct[..4], SSE_MAGIC_V6, "new PUTs emit S4E6 (v0.8.1 #57)");
assert_eq!(ct[4], ALGO_AES_256_GCM);
assert_eq!(
u16::from_be_bytes([ct[5], ct[6]]),
1,
"key_id BE = active id"
);
assert_eq!(ct[7], 0, "reserved must be 0");
assert_eq!(
u32::from_be_bytes([ct[8], ct[9], ct[10], ct[11]]),
chunk_size as u32,
"chunk_size BE",
);
assert_eq!(
u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
10,
"chunk_count BE — 10 MiB / 1 MiB = 10 (no remainder)",
);
assert_eq!(&ct[16..24].len(), &8, "S4E6 salt slot is 8 bytes");
assert_ne!(
&ct[16..24],
&[0u8; 8],
"S4E6 salt must be random, not zeros"
);
assert_eq!(
ct.len(),
S4E6_HEADER_BYTES + 10 * S4E6_PER_CHUNK_OVERHEAD + pt_len,
"total = header (24) + 10 tags + plaintext",
);
assert!(looks_encrypted(&ct), "looks_encrypted must accept S4E6");
assert_eq!(peek_magic(&ct), Some("S4E6"));
}
#[tokio::test]
async fn s4e6_decrypt_chunked_stream_byte_equal() {
let kr = keyring_single(0x55);
let pt: Vec<u8> = (0..(10 * 1024 * 1024_u32))
.map(|i| (i & 0xFF) as u8)
.collect();
let ct = encrypt_v2_chunked(&pt, &kr, 1024 * 1024).unwrap();
assert_eq!(&ct[..4], SSE_MAGIC_V6, "new emit is S4E6");
let stream = decrypt_chunked_stream(ct, &kr);
let chunks = collect_chunks(stream).await.expect("stream ok");
assert_eq!(chunks.len(), 10, "10 chunks expected for 10 MiB / 1 MiB");
let mut joined = Vec::with_capacity(pt.len());
for c in chunks {
joined.extend_from_slice(&c);
}
assert_eq!(joined.len(), pt.len(), "byte length matches");
assert_eq!(joined, pt, "byte-equal round-trip");
}
#[tokio::test]
async fn s4e6_single_chunk_for_small_object() {
let kr = keyring_single(0x77);
let pt = b"tiny payload, smaller than chunk_size";
let ct = encrypt_v2_chunked(pt, &kr, 1024 * 1024).unwrap();
assert_eq!(
u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
1,
"small plaintext = single chunk",
);
let stream = decrypt_chunked_stream(ct, &kr);
let chunks = collect_chunks(stream).await.expect("stream ok");
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].as_ref(), pt);
}
#[tokio::test]
async fn s4e6_tampered_chunk_n_reports_chunk_index() {
let kr = keyring_single(0x91);
let chunk_size = 1024;
let pt = vec![0xCD_u8; chunk_size * 8]; let mut ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap().to_vec();
let target = S4E6_HEADER_BYTES + 3 * (TAG_LEN + chunk_size) + TAG_LEN;
ct[target] ^= 0x42;
let stream = decrypt_chunked_stream(bytes::Bytes::from(ct), &kr);
let mut s = std::pin::pin!(stream);
for expected_i in 0..3_u32 {
let item = s.next().await.expect("yield");
item.unwrap_or_else(|e| panic!("chunk {expected_i}: {e:?}"));
}
let err = s.next().await.expect("yield error").unwrap_err();
assert!(
matches!(err, SseError::ChunkAuthFailed { chunk_index: 3 }),
"got {err:?}",
);
}
#[tokio::test]
async fn s4e5_back_compat_s4e2_blob_rejected_with_clear_error() {
let kr = keyring_single(0x12);
let s4e2 = encrypt_v2(b"a v2 blob, not chunked", &kr);
let stream = decrypt_chunked_stream(s4e2, &kr);
let result = collect_chunks(stream).await;
let err = result.unwrap_err();
assert!(matches!(err, SseError::BadMagic { .. }), "got {err:?}");
}
#[test]
fn s4e6_salt_randomness_smoke() {
let kr = keyring_single(0x33);
let mut salts = std::collections::HashSet::new();
let n = 1024;
for _ in 0..n {
let ct = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
let mut salt = [0u8; 8];
salt.copy_from_slice(&ct[16..24]);
salts.insert(salt);
}
assert!(
salts.len() > n / 2,
"expected most of the {n} salts to be unique (got {} unique)",
salts.len(),
);
}
#[test]
fn s4e6_chunk_size_zero_invalid() {
let kr = keyring_single(0x66);
let err = encrypt_v2_chunked(b"hi", &kr, 0).unwrap_err();
assert!(matches!(err, SseError::ChunkSizeInvalid));
}
#[tokio::test]
async fn s4e6_truncated_body_reports_frame_truncated() {
let kr = keyring_single(0xA1);
let chunk_size = 256;
let pt = vec![0u8; chunk_size * 4];
let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
let trunc = S4E6_HEADER_BYTES + 2 * (TAG_LEN + chunk_size) + 8;
let truncated = bytes::Bytes::copy_from_slice(&ct[..trunc]);
let stream = decrypt_chunked_stream(truncated, &kr);
let result = collect_chunks(stream).await;
let err = result.unwrap_err();
assert!(
matches!(err, SseError::ChunkFrameTruncated { .. }),
"got {err:?}",
);
}
#[test]
fn s4e6_decrypt_buffered_round_trip_via_top_level_decrypt() {
    // The top-level `decrypt` entry point must transparently round-trip an
    // S4E6 chunked blob (the odd chunk size 13 forces a ragged final chunk).
    let kr = keyring_single(0xDE);
    let pt = b"buffered sync decrypt path".repeat(32);
    let round_tripped =
        decrypt(&encrypt_v2_chunked(&pt, &kr, 13).unwrap(), &kr).expect("buffered S4E6 decrypt ok");
    assert_eq!(round_tripped.as_ref(), pt.as_slice());
}
#[tokio::test]
async fn s4e6_unknown_key_id_in_frame_errors() {
    // Encrypt under key id 7, then read with a keyring holding the same key
    // bytes under id 1: both the buffered and streaming paths must report
    // KeyNotInKeyring { id: 7 } rather than a decryption failure.
    let kr_put = SseKeyring::new(7, key32(0xCC));
    let kr_get = keyring_single(0xCC);
    let ct = encrypt_v2_chunked(b"orphan key", &kr_put, 64).unwrap();

    let err = decrypt(&ct, &kr_get).unwrap_err();
    assert!(
        matches!(err, SseError::KeyNotInKeyring { id: 7 }),
        "got {err:?}"
    );

    let result = collect_chunks(decrypt_chunked_stream(ct, &kr_get)).await;
    assert!(
        matches!(result, Err(SseError::KeyNotInKeyring { id: 7 })),
        "got {result:?}",
    );
}
#[tokio::test]
async fn s4e6_final_chunk_smaller_than_chunk_size() {
let kr = keyring_single(0xEF);
let chunk_size = 100;
let pt: Vec<u8> = (0..250_u32).map(|i| i as u8).collect();
let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
assert_eq!(
u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
3,
"ceil(250/100) = 3 chunks",
);
assert_eq!(ct.len(), S4E6_HEADER_BYTES + 48 + 250);
let stream = decrypt_chunked_stream(ct, &kr);
let chunks = collect_chunks(stream).await.expect("stream ok");
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0].len(), 100);
assert_eq!(chunks[1].len(), 100);
assert_eq!(chunks[2].len(), 50, "final chunk is the remainder");
let joined: Vec<u8> = chunks.iter().flat_map(|c| c.iter().copied()).collect();
assert_eq!(joined, pt);
}
#[test]
fn s4e6_back_compat_read_s4e5_blob() {
    // Objects written in the older S4E5 format must remain readable through
    // both the buffered `decrypt` path and the chunked streaming path.
    let kr = keyring_single(0x57);
    let pt = b"v0.8.0 vintage chunked SSE-S4 object".repeat(64);
    let s4e5 = encrypt_v2_chunked_s4e5_for_test(&pt, &kr, 91).unwrap();
    assert_eq!(&s4e5[..4], SSE_MAGIC_V5, "fixture must be S4E5");
    assert_eq!(peek_magic(&s4e5), Some("S4E5"));

    // Buffered path.
    let plain_sync = decrypt(&s4e5, &kr).expect("sync S4E5 decrypt ok");
    assert_eq!(plain_sync.as_ref(), pt.as_slice());

    // Streaming path, driven to completion on a local executor.
    let collected =
        futures::executor::block_on(collect_chunks(decrypt_chunked_stream(s4e5.clone(), &kr)))
            .expect("stream S4E5 decrypt ok");
    let joined: Vec<u8> = collected.iter().flat_map(|c| c.iter().copied()).collect();
    assert_eq!(joined, pt, "S4E5 streaming round-trip byte-equal");
}
#[test]
fn s4e6_layout_24_bytes_header() {
    // The S4E6 header is exactly 4 bytes larger than S4E5's (the salt grew),
    // and per-chunk overhead is still just the GCM tag.
    assert_eq!(S4E6_HEADER_BYTES, S4E5_HEADER_BYTES + 4);
    assert_eq!(S4E6_HEADER_BYTES, 24);
    assert_eq!(S4E6_PER_CHUNK_OVERHEAD, TAG_LEN);
}
#[test]
fn s4e6_parse_header_round_trip() {
    // A parsed header must echo the PUT parameters; then exercise the two
    // header-level failure modes (wrong magic, body shorter than a header).
    let kr = keyring_single(0xAB);
    let chunk_size = 256;
    let ct = encrypt_v2_chunked(&vec![1u8; 7 * chunk_size], &kr, chunk_size).unwrap();

    let hdr = parse_s4e6_header(&ct).expect("parse ok");
    assert_eq!(hdr.key_id, 1);
    assert_eq!(hdr.chunk_size, chunk_size as u32);
    assert_eq!(hdr.chunk_count, 7);
    assert_eq!(hdr.salt.len(), 8);

    // An S4E2 frame, even one long enough to hold a header, is BadMagic.
    let bogus = b"S4E2\x01\x00\x00\x00........................";
    let err = parse_s4e6_header(bogus).unwrap_err();
    assert!(matches!(err, SseError::BadMagic { .. }), "got {err:?}");

    // A body shorter than the fixed header is a truncated frame.
    let err2 = parse_s4e6_header(&ct[..10]).unwrap_err();
    assert!(
        matches!(err2, SseError::ChunkFrameTruncated { .. }),
        "got {err2:?}"
    );
}
#[test]
fn s4e6_salt_uniqueness_smoke_16k() {
    // 16k PUTs: the full 8-byte salt must never collide at this scale, while
    // a hypothetical 4-byte salt (what S4E5 effectively shipped) is expected
    // to collide visibly — the motivation for widening the salt in S4E6.
    //
    // NOTE: renamed from `s4e6_salt_uniqueness_smoke_16m`; the test has
    // always run 16_384 iterations (16k), as its own log line says.
    let kr = keyring_single(0xA6);
    let n = 16384_usize;
    let mut salts = std::collections::HashSet::with_capacity(n);
    let mut top4_seen = std::collections::HashSet::with_capacity(n);
    let mut collisions_top4 = 0usize;
    for _ in 0..n {
        let ct = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
        // The salt occupies header bytes 16..24.
        let mut salt = [0u8; 8];
        salt.copy_from_slice(&ct[16..24]);
        salts.insert(salt);
        // Simulate a 4-byte salt by truncating, and count its collisions.
        let mut top4 = [0u8; 4];
        top4.copy_from_slice(&salt[..4]);
        if !top4_seen.insert(top4) {
            collisions_top4 += 1;
        }
    }
    assert_eq!(
        salts.len(),
        n,
        "all 8-byte salts must be unique across {n} PUTs (got {} unique)",
        salts.len(),
    );
    eprintln!(
        "s4e6_salt_uniqueness_smoke_16k: 16k PUTs, full 8B salts \
         all unique ({}/{}), simulated 4B-truncated salt yielded \
         {} collisions (this is what S4E5 would have shipped)",
        salts.len(),
        n,
        collisions_top4,
    );
}
#[test]
fn s4e6_max_chunks_24bit() {
    // The chunk count is carried in 24 bits, so the cap is 2^24 - 1; both the
    // encryptor and the header parser must enforce it.
    assert_eq!(S4E6_MAX_CHUNK_COUNT, (1u32 << 24) - 1);
    assert_eq!(S4E6_MAX_CHUNK_COUNT, 16_777_215);

    let kr = keyring_single(0xC4);
    // The same over-cap error is expected from both enforcement points.
    let over_cap = |e: &SseError| {
        matches!(
            e,
            SseError::ChunkCountTooLarge {
                got: 16_777_216,
                max: 16_777_215
            }
        )
    };

    // One chunk over the cap: rejected at encrypt time.
    let pt = vec![0u8; (S4E6_MAX_CHUNK_COUNT as usize) + 1];
    let err = encrypt_v2_chunked(&pt, &kr, 1).unwrap_err();
    assert!(over_cap(&err), "got {err:?}",);

    // Well under the cap: encrypt succeeds and the count round-trips.
    let ct = encrypt_v2_chunked(&vec![0u8; 1023], &kr, 1).expect("under-cap PUT must succeed");
    assert_eq!(parse_s4e6_header(&ct).unwrap().chunk_count, 1023);

    // Forge an over-cap count into header bytes 12..16: rejected at parse time.
    let mut tampered = ct.to_vec();
    tampered[12..16].copy_from_slice(&(S4E6_MAX_CHUNK_COUNT + 1).to_be_bytes());
    let err2 = parse_s4e6_header(&tampered).unwrap_err();
    assert!(over_cap(&err2), "got {err2:?}",);
}
#[test]
fn s4e6_nonce_v6_layout() {
    // 12-byte nonce layout: 'E' tag byte, then the 8-byte salt, then the
    // chunk index in the trailing 3 bytes, big-endian.
    let salt = [0xAA_u8; 8];

    let n0 = nonce_v6(&salt, 0);
    assert_eq!(n0[0], b'E');
    assert_eq!(&n0[1..9], &salt);
    assert_eq!(&n0[9..12], &[0, 0, 0]);

    // Counter bytes across the 24-bit range, including the maximum.
    assert_eq!(&nonce_v6(&salt, 1)[9..12], &[0, 0, 1]);
    assert_eq!(&nonce_v6(&salt, 0x123456)[9..12], &[0x12, 0x34, 0x56]);
    assert_eq!(
        &nonce_v6(&salt, S4E6_MAX_CHUNK_COUNT)[9..12],
        &[0xFF, 0xFF, 0xFF]
    );
}
#[tokio::test]
async fn s4e6_tampered_salt_byte_fails_aead() {
    // Flipping one bit inside the 8-byte salt (header bytes 16..24) must make
    // the very first chunk fail authentication.
    let kr = keyring_single(0xB6);
    let plaintext = b"salt-in-aad coverage".repeat(64);
    let mut body = encrypt_v2_chunked(&plaintext, &kr, 128).unwrap().to_vec();
    body[20] ^= 0x01; // one bit, inside the salt
    let err = decrypt(&body, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkAuthFailed { chunk_index: 0 }),
        "got {err:?}",
    );
}
/// Builds a syntactically valid S4E6 header (key_id 1, all-zero salt) with
/// the given chunk geometry, for feeding hostile sizes to the parsers.
fn synth_s4e6_header(chunk_size: u32, chunk_count: u32) -> Vec<u8> {
    let mut blob = Vec::with_capacity(S4E6_HEADER_BYTES);
    for part in [
        &SSE_MAGIC_V6[..],
        &[ALGO_AES_256_GCM],
        &1_u16.to_be_bytes(), // key_id = 1
        &[0],                 // single zero byte between key_id and chunk_size (as in the S4E5 fixtures)
        &chunk_size.to_be_bytes(),
        &chunk_count.to_be_bytes(),
        &[0u8; 8], // 8-byte salt, all zero
    ] {
        blob.extend_from_slice(part);
    }
    debug_assert_eq!(blob.len(), S4E6_HEADER_BYTES);
    blob
}
#[test]
fn s4e6_header_claims_huge_size_rejected_pre_alloc() {
    // A header declaring ~100 GiB (1 GiB chunks x 100) must be rejected by the
    // size guard before any buffer is allocated — under both the default cap
    // and an explicit, much tighter one.
    let kr = keyring_single(0x01);
    let mut blob = synth_s4e6_header(1 << 30, 100);
    blob.extend_from_slice(&[0u8; 100]); // token body past the header

    let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge (declared 100 GiB > 5 GiB cap), got {err:?}",
    );

    let err2 = decrypt_chunked_buffered(&blob, &kr, 1024 * 1024).unwrap_err();
    assert!(
        matches!(err2, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge under tighter cap, got {err2:?}",
    );
}
#[test]
fn s4e6_header_chunk_size_x_chunk_count_overflows_u64() {
    // chunk_size = chunk_count = u32::MAX makes the declared total overflow
    // the u64 size computation; both the buffered and streaming parsers must
    // refuse it rather than wrap around.
    let kr = keyring_single(0x02);
    // Hand-rolled S4E5 header: magic, algo, key_id=1, zero byte, size, count,
    // 4-byte salt.
    let mut blob = Vec::with_capacity(S4E5_HEADER_BYTES);
    for part in [
        &SSE_MAGIC_V5[..],
        &[ALGO_AES_256_GCM],
        &1_u16.to_be_bytes(),
        &[0],
        &u32::MAX.to_be_bytes(),
        &u32::MAX.to_be_bytes(),
        &[0u8; 4],
    ] {
        blob.extend_from_slice(part);
    }
    debug_assert_eq!(blob.len(), S4E5_HEADER_BYTES);

    let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge for u64 overflow, got {err:?}",
    );
    let direct = parse_chunked_header(&blob, usize::MAX).unwrap_err();
    assert!(
        matches!(direct, SseError::ChunkFrameTooLarge { .. }),
        "streaming path: expected ChunkFrameTooLarge, got {direct:?}",
    );
}
#[test]
fn s4e6_header_within_max_body_bytes_passes() {
    // ~100 MiB declared — comfortably under the cap — so the size guard lets
    // the body through; the zero-filled chunks then fail AEAD on chunk 0,
    // which proves we got past the guard.
    let kr = keyring_single(0x03);
    let chunk_size: u32 = 1024 * 1024;
    let chunk_count: u32 = 100;
    let mut blob = synth_s4e6_header(chunk_size, chunk_count);
    let body_len = (chunk_count as usize) * (S4E6_PER_CHUNK_OVERHEAD + chunk_size as usize);
    blob.resize(blob.len() + body_len, 0);
    let err = decrypt_chunked_buffered(&blob, &kr, DEFAULT_MAX_BODY_BYTES).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkAuthFailed { chunk_index: 0 }),
        "expected ChunkAuthFailed (guard let it through), got {err:?}",
    );
}
#[test]
fn s4e6_header_exceeds_max_body_bytes_rejected() {
    let kr = keyring_single(0x04);

    // Case A: ~6 GiB declared against the default cap — the header alone is
    // enough for rejection; no body bytes are required.
    let blob = synth_s4e6_header(1024 * 1024, 6000);
    let err = decrypt_chunked_buffered(&blob, &kr, DEFAULT_MAX_BODY_BYTES).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge (6 GiB declared > 5 GiB cap), got {err:?}",
    );

    // Case B: ~100 MiB declared (with a matching zero body) against a 1 MiB
    // cap — still rejected, proving the caller's cap is honored.
    let chunk_size_b: u32 = 1024 * 1024;
    let chunk_count_b: u32 = 100;
    let mut blob_b = synth_s4e6_header(chunk_size_b, chunk_count_b);
    let pad_b = (chunk_count_b as usize) * (S4E6_PER_CHUNK_OVERHEAD + chunk_size_b as usize);
    blob_b.resize(blob_b.len() + pad_b, 0);
    let err_b = decrypt_chunked_buffered(&blob_b, &kr, 1024 * 1024).unwrap_err();
    assert!(
        matches!(err_b, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge (cap < declared), got {err_b:?}",
    );
}
#[test]
fn s4e6_random_header_never_panics() {
    use rand::{Rng, SeedableRng, rngs::StdRng};
    // Deterministic fuzz smoke: 100k random short bodies (a quarter seeded
    // with a genuine S4E5/S4E6 magic so parsing proceeds past byte 4) run
    // through the header parser under a rotating set of size caps. Errors are
    // fine; panics are not.
    let mut rng = StdRng::seed_from_u64(0xC0FF_EE64_6464_64DE);
    let caps = [
        0_usize,
        1024,
        1024 * 1024,
        DEFAULT_MAX_BODY_BYTES,
        usize::MAX,
    ];
    for (_, max_cap) in (0..100_000).zip(caps.iter().copied().cycle()) {
        let body_len = rng.gen_range(0..=256_usize);
        let mut body = vec![0u8; body_len];
        rng.fill(body.as_mut_slice());
        if body_len >= 4 && rng.gen_bool(0.25) {
            let magic = if rng.gen_bool(0.5) {
                SSE_MAGIC_V5
            } else {
                SSE_MAGIC_V6
            };
            body[..4].copy_from_slice(magic);
        }
        let _ = parse_chunked_header(&body, max_cap);
    }
}
#[test]
fn s4e5_extreme_overflow_chunk_count_u32_max() {
    // Same u32::MAX x u32::MAX geometry as the S4E6 overflow test, exercised
    // through the buffered default-cap path only.
    let kr = keyring_single(0x05);
    let mut blob = Vec::with_capacity(S4E5_HEADER_BYTES);
    for part in [
        &SSE_MAGIC_V5[..],
        &[ALGO_AES_256_GCM],
        &1_u16.to_be_bytes(),
        &[0],
        &u32::MAX.to_be_bytes(),
        &u32::MAX.to_be_bytes(),
        &[0u8; 4],
    ] {
        blob.extend_from_slice(part);
    }
    let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge for extreme overflow, got {err:?}",
    );
}
}