use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use bytes::Bytes;
use md5::{Digest as Md5Digest, Md5};
use rand::RngCore;
use thiserror::Error;
use crate::kms::{KmsBackend, KmsError, WrappedDek};
/// Frame magic of legacy single-key objects (S4E1).
pub const SSE_MAGIC_V1: &[u8; 4] = b"S4E1";
/// Frame magic of keyring objects that carry a 16-bit key id (S4E2).
pub const SSE_MAGIC_V2: &[u8; 4] = b"S4E2";
/// Frame magic of SSE-C (customer-supplied key) objects (S4E3).
pub const SSE_MAGIC_V3: &[u8; 4] = b"S4E3";
/// Frame magic of SSE-KMS objects carrying a KMS-wrapped data key (S4E4).
pub const SSE_MAGIC_V4: &[u8; 4] = b"S4E4";
/// Backwards-compatible alias for the original frame magic.
pub const SSE_MAGIC: &[u8; 4] = SSE_MAGIC_V1;

/// Fixed header size of S4E1/S4E2 frames:
/// magic(4) + algo(1) + 3 header bytes (reserved, or key id + reserved) + nonce + tag.
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + NONCE_LEN + TAG_LEN;
/// Fixed header size of S4E3 frames: magic(4) + algo(1) + key MD5 + nonce + tag.
pub const SSE_HEADER_BYTES_V3: usize = 4 + 1 + KEY_MD5_LEN + NONCE_LEN + TAG_LEN;
/// Algorithm tag stored right after the magic; the only value this build knows.
pub const ALGO_AES_256_GCM: u8 = 1;

/// AES-GCM nonce length in bytes.
const NONCE_LEN: usize = 12;
/// AES-GCM authentication tag length in bytes.
const TAG_LEN: usize = 16;
/// AES-256 key length in bytes.
const KEY_LEN: usize = 32;
/// MD5 digest length in bytes (SSE-C key fingerprint).
const KEY_MD5_LEN: usize = 16;

/// The only SSE-C algorithm header value accepted.
pub const SSE_C_ALGORITHM: &str = "AES256";
/// Errors produced while loading SSE keys, parsing SSE-C headers, and
/// encrypting/decrypting S4E1–S4E4 frames.
#[derive(Debug, Error)]
pub enum SseError {
    /// The key file at `path` could not be read.
    #[error("SSE key file {path:?}: {source}")]
    KeyFileIo {
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// Key material was neither 32 raw bytes, 64 hex digits, nor base64 of 32 bytes.
    /// NOTE(review): `from_bytes` fills `got` with the length of the raw input,
    /// not a post-parse length as the message wording suggests.
    #[error(
        "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
    )]
    BadKeyLength { got: usize },
    /// The body is shorter than the fixed header of its frame format.
    /// NOTE(review): the message always quotes `SSE_HEADER_BYTES`, although S4E3
    /// frames are rejected against the larger `SSE_HEADER_BYTES_V3`.
    #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
    TooShort { got: usize },
    /// The first four bytes are not one of the known frame magics.
    #[error("SSE bad magic: expected S4E1/S4E2/S4E3/S4E4, got {got:?}")]
    BadMagic { got: [u8; 4] },
    /// The algorithm byte after the magic is not `ALGO_AES_256_GCM`.
    #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
    UnsupportedAlgo { tag: u8 },
    /// An S4E2 frame names a key id that the keyring does not hold.
    #[error(
        "SSE key_id {id} (S4E2 frame) not present in keyring; rotation history likely incomplete"
    )]
    KeyNotInKeyring { id: u16 },
    /// AES-GCM authentication failed (wrong key, or header/ciphertext modified).
    #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
    DecryptFailed,
    /// The SSE-C key MD5 supplied for decryption differs from the one stored in the S4E3 frame.
    #[error("SSE-C key MD5 fingerprint mismatch — client supplied a different key than PUT")]
    WrongCustomerKey,
    /// SSE-C request headers failed validation; `reason` names the failing check.
    #[error("SSE-C customer-key headers invalid: {reason}")]
    InvalidCustomerKey { reason: &'static str },
    /// The SSE-C algorithm header is not exactly `SSE_C_ALGORITHM`.
    #[error("SSE-C algorithm {algo:?} unsupported (only {SSE_C_ALGORITHM:?} is allowed)")]
    CustomerKeyAlgorithmUnsupported { algo: String },
    /// An S4E3 frame was decrypted with a source other than a customer key.
    /// NOTE(review): `decrypt` also returns this for `SseSource::Kms`, although
    /// the message says "got Keyring".
    #[error("S4E3 frame requires SseSource::CustomerKey; got Keyring")]
    CustomerKeyRequired,
    /// An S4E1/S4E2 frame was decrypted with a source other than a keyring.
    /// NOTE(review): `decrypt` also returns this for `SseSource::Kms`, although
    /// the message mentions only CustomerKey.
    #[error("S4E1/S4E2 frame stored without SSE-C; SseSource::CustomerKey is unexpected")]
    CustomerKeyUnexpected,
    /// `decrypt` was called on an S4E4 frame; such frames go through `decrypt_with_kms`.
    #[error(
        "S4E4 (SSE-KMS) body requires async decrypt — call decrypt_with_kms() instead of decrypt()"
    )]
    KmsAsyncRequired,
    /// The S4E4 frame is shorter than the smallest possible frame (`min` bytes).
    #[error("S4E4 frame too short ({got} bytes; need at least {min})")]
    KmsFrameTooShort { got: usize, min: usize },
    /// An S4E4 length field points past the end of the frame; `what` names the field.
    #[error("S4E4 frame field length out of bounds: {what}")]
    KmsFrameFieldOob { what: &'static str },
    /// The S4E4 key id bytes are not valid UTF-8.
    #[error("S4E4 key_id is not valid UTF-8")]
    KmsKeyIdNotUtf8,
    /// A supplied wrapped DEK's key id disagrees with the key id stored in the frame.
    /// NOTE(review): not constructed in the code visible here — confirm usage elsewhere.
    #[error(
        "S4E4 SseSource::Kms wrapped DEK key_id {supplied:?} doesn't match frame key_id {stored:?}"
    )]
    KmsWrappedDekMismatch {
        supplied: String,
        stored: String,
    },
    /// An S4E4 frame requires `SseSource::Kms`.
    /// NOTE(review): not constructed in the code visible here.
    #[error("S4E4 frame requires SseSource::Kms")]
    KmsRequired,
    /// Error from the KMS backend (converted via `From<KmsError>`); `decrypt_with_kms`
    /// also uses it when the backend returns a DEK of the wrong length.
    #[error("KMS unwrap: {0}")]
    KmsBackend(#[from] KmsError),
}
/// A 256-bit AES key for server-managed SSE (S4E1/S4E2 frames).
///
/// `Debug` is implemented by hand so the key bytes never appear in logs.
pub struct SseKey {
    pub bytes: [u8; 32],
}

impl SseKey {
    /// Reads the file at `path` and parses its contents with [`SseKey::from_bytes`].
    ///
    /// # Errors
    /// [`SseError::KeyFileIo`] if the file cannot be read; otherwise whatever
    /// [`SseKey::from_bytes`] returns.
    pub fn from_path(path: &Path) -> Result<Self, SseError> {
        match std::fs::read(path) {
            Ok(contents) => Self::from_bytes(&contents),
            Err(source) => Err(SseError::KeyFileIo {
                path: path.to_path_buf(),
                source,
            }),
        }
    }

    /// Parses key material, trying in order:
    /// 1. exactly 32 raw bytes;
    /// 2. 64 hex digits (surrounding whitespace ignored);
    /// 3. standard base64 that decodes to 32 bytes (surrounding whitespace ignored).
    ///
    /// # Errors
    /// [`SseError::BadKeyLength`] carrying the raw input length when nothing matches.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
        let bad_length = || SseError::BadKeyLength { got: bytes.len() };

        if let Ok(raw) = <[u8; KEY_LEN]>::try_from(bytes) {
            return Ok(Self { bytes: raw });
        }

        // Non-UTF-8 input degrades to "", which neither textual encoding accepts.
        let text = match std::str::from_utf8(bytes) {
            Ok(s) => s.trim(),
            Err(_) => "",
        };

        let is_hex = text.len() == KEY_LEN * 2 && text.bytes().all(|c| c.is_ascii_hexdigit());
        if is_hex {
            let mut parsed = [0u8; KEY_LEN];
            for (slot, pair) in parsed.iter_mut().zip(text.as_bytes().chunks_exact(2)) {
                let hi = char::from(pair[0]).to_digit(16).ok_or_else(bad_length)?;
                let lo = char::from(pair[1]).to_digit(16).ok_or_else(bad_length)?;
                *slot = (hi * 16 + lo) as u8;
            }
            return Ok(Self { bytes: parsed });
        }

        match base64::Engine::decode(&base64::engine::general_purpose::STANDARD, text.as_bytes()) {
            Ok(decoded) if decoded.len() == KEY_LEN => {
                let mut parsed = [0u8; KEY_LEN];
                parsed.copy_from_slice(&decoded);
                Ok(Self { bytes: parsed })
            }
            _ => Err(bad_length()),
        }
    }

    /// Reinterprets the key bytes as an aes-gcm key reference (no copy).
    fn as_aes_key(&self) -> &Key<Aes256Gcm> {
        Key::<Aes256Gcm>::from_slice(&self.bytes)
    }
}

impl std::fmt::Debug for SseKey {
    /// Prints only the key length; the key itself is always shown as `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SseKey");
        out.field("len", &KEY_LEN);
        out.field("key", &"<redacted>");
        out.finish()
    }
}
/// A set of [`SseKey`]s addressed by 16-bit id, one of which is the *active* key.
///
/// The active key encrypts new S4E2 objects; the other keys stay available so objects
/// written before a rotation can still be decrypted.
#[derive(Clone)]
pub struct SseKeyring {
    active: u16,
    keys: HashMap<u16, Arc<SseKey>>,
}

impl SseKeyring {
    /// Creates a keyring holding only `key`, registered under `active` and marked active.
    pub fn new(active: u16, key: Arc<SseKey>) -> Self {
        let mut keys = HashMap::new();
        keys.insert(active, key);
        Self { active, keys }
    }

    /// Registers `key` under `id`, replacing any key already stored under that id
    /// (including the active one). The active id itself is not changed.
    pub fn add(&mut self, id: u16, key: Arc<SseKey>) {
        self.keys.insert(id, key);
    }

    /// Returns the active key id together with its key.
    ///
    /// # Panics
    /// Only if the active id were missing, which cannot happen: `new` inserts it and
    /// no method removes keys.
    pub fn active(&self) -> (u16, &SseKey) {
        let id = self.active;
        let key = self
            .keys
            .get(&id)
            .expect("active key id must be present in keyring (constructor invariant)");
        (id, key.as_ref())
    }

    /// Looks up the key registered under `id`.
    pub fn get(&self, id: u16) -> Option<&SseKey> {
        self.keys.get(&id).map(Arc::as_ref)
    }
}

impl std::fmt::Debug for SseKeyring {
    /// Prints the active id, key count and key ids (never key material).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order varies between runs; sort so logs are reproducible.
        let mut key_ids: Vec<u16> = self.keys.keys().copied().collect();
        key_ids.sort_unstable();
        f.debug_struct("SseKeyring")
            .field("active", &self.active)
            .field("key_count", &self.keys.len())
            .field("key_ids", &key_ids)
            .finish()
    }
}

/// Reference-counted handle to a shared [`SseKeyring`].
pub type SharedSseKeyring = Arc<SseKeyring>;
pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
let cipher = Aes256Gcm::new(key.as_aes_key());
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let mut aad = [0u8; 8];
aad[..4].copy_from_slice(SSE_MAGIC_V1);
aad[4] = ALGO_AES_256_GCM;
let ct_with_tag = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad: &aad,
},
)
.expect("aes-gcm encrypt cannot fail with a 32-byte key");
debug_assert!(ct_with_tag.len() >= TAG_LEN);
let split = ct_with_tag.len() - TAG_LEN;
let (ct, tag) = ct_with_tag.split_at(split);
let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
out.extend_from_slice(SSE_MAGIC_V1);
out.push(ALGO_AES_256_GCM);
out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(tag);
out.extend_from_slice(ct);
Bytes::from(out)
}
/// Encrypts `plaintext` into an S4E2 frame using the keyring's active key.
///
/// Frame layout: `S4E2` ‖ algo(1) ‖ key_id(2, big-endian) ‖ reserved 0 ‖ nonce(12) ‖
/// tag(16) ‖ ciphertext. The key id is also bound into the AAD (`aad_v2`), so a
/// rewritten key id makes authentication fail on decrypt.
pub fn encrypt_v2(plaintext: &[u8], keyring: &SseKeyring) -> Bytes {
    let (active_id, active_key) = keyring.active();

    let mut iv = [0u8; NONCE_LEN];
    rand::rngs::OsRng.fill_bytes(&mut iv);

    let header_aad = aad_v2(active_id);
    let sealed = Aes256Gcm::new(active_key.as_aes_key())
        .encrypt(
            Nonce::from_slice(&iv),
            Payload {
                msg: plaintext,
                aad: &header_aad,
            },
        )
        .expect("aes-gcm encrypt cannot fail with a 32-byte key");
    debug_assert!(sealed.len() >= TAG_LEN);

    // aes-gcm appends the tag; the frame stores it ahead of the ciphertext.
    let (body, tag) = sealed.split_at(sealed.len() - TAG_LEN);
    let [id_hi, id_lo] = active_id.to_be_bytes();

    let mut frame = Vec::with_capacity(SSE_HEADER_BYTES + body.len());
    frame.extend_from_slice(SSE_MAGIC_V2);
    frame.extend_from_slice(&[ALGO_AES_256_GCM, id_hi, id_lo, 0u8]);
    frame.extend_from_slice(&iv);
    frame.extend_from_slice(tag);
    frame.extend_from_slice(body);
    Bytes::from(frame)
}
/// Builds the 8-byte AAD shared by S4E1 and S4E2: `magic` ‖ algo ‖ three tail bytes.
fn aad_fixed8(magic: &[u8; 4], tail: [u8; 3]) -> [u8; 8] {
    let [m0, m1, m2, m3] = *magic;
    let [t0, t1, t2] = tail;
    [m0, m1, m2, m3, ALGO_AES_256_GCM, t0, t1, t2]
}

/// AAD for S4E1 frames: `S4E1` ‖ algo ‖ three zero bytes.
fn aad_v1() -> [u8; 8] {
    aad_fixed8(SSE_MAGIC_V1, [0; 3])
}

/// AAD for S4E2 frames: `S4E2` ‖ algo ‖ key_id (big-endian) ‖ zero byte.
fn aad_v2(key_id: u16) -> [u8; 8] {
    let [hi, lo] = key_id.to_be_bytes();
    aad_fixed8(SSE_MAGIC_V2, [hi, lo, 0])
}

/// AAD for S4E3 frames: `S4E3` ‖ algo ‖ key MD5 fingerprint.
fn aad_v3(key_md5: &[u8; KEY_MD5_LEN]) -> [u8; 4 + 1 + KEY_MD5_LEN] {
    let mut out = [0u8; 4 + 1 + KEY_MD5_LEN];
    let (head, fingerprint) = out.split_at_mut(5);
    head[..4].copy_from_slice(SSE_MAGIC_V3);
    head[4] = ALGO_AES_256_GCM;
    fingerprint.copy_from_slice(key_md5);
    out
}
/// An SSE-C customer key together with the MD5 fingerprint stored alongside it.
///
/// `Debug` is implemented by hand: the key is redacted and only the fingerprint is
/// printed, as lowercase hex.
#[derive(Clone)]
pub struct CustomerKeyMaterial {
    pub key: [u8; KEY_LEN],
    pub key_md5: [u8; KEY_MD5_LEN],
}

impl std::fmt::Debug for CustomerKeyMaterial {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomerKeyMaterial")
            .field("key", &"<redacted>")
            .field("key_md5_hex", &hex_lower(&self.key_md5))
            .finish()
    }
}

/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        // Index the digit table directly instead of `format!`-ing each byte,
        // which allocated a temporary String per byte.
        s.push(char::from(DIGITS[usize::from(b >> 4)]));
        s.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    s
}
#[derive(Debug, Clone, Copy)]
pub enum SseSource<'a> {
Keyring(&'a SseKeyring),
CustomerKey {
key: &'a [u8; KEY_LEN],
key_md5: &'a [u8; KEY_MD5_LEN],
},
Kms {
dek: &'a [u8; KEY_LEN],
wrapped: &'a WrappedDek,
},
}
impl<'a> From<&'a SseKeyring> for SseSource<'a> {
fn from(kr: &'a SseKeyring) -> Self {
SseSource::Keyring(kr)
}
}
impl<'a> From<&'a Arc<SseKeyring>> for SseSource<'a> {
fn from(kr: &'a Arc<SseKeyring>) -> Self {
SseSource::Keyring(kr.as_ref())
}
}
impl<'a> From<&'a CustomerKeyMaterial> for SseSource<'a> {
fn from(m: &'a CustomerKeyMaterial) -> Self {
SseSource::CustomerKey {
key: &m.key,
key_md5: &m.key_md5,
}
}
}
/// Validates the three SSE-C request headers and returns the decoded key material.
///
/// Checks, in order:
/// - the algorithm is exactly `SSE_C_ALGORITHM`;
/// - the key decodes from base64 to 32 bytes;
/// - the MD5 decodes from base64 to 16 bytes;
/// - that MD5 matches the MD5 of the key, compared in constant time.
///
/// Surrounding whitespace in both base64 values is ignored.
///
/// # Errors
/// [`SseError::CustomerKeyAlgorithmUnsupported`] for a wrong algorithm; otherwise
/// [`SseError::InvalidCustomerKey`] naming the first failing check.
pub fn parse_customer_key_headers(
    algorithm: &str,
    key_base64: &str,
    key_md5_base64: &str,
) -> Result<CustomerKeyMaterial, SseError> {
    use base64::Engine as _;
    let engine = &base64::engine::general_purpose::STANDARD;

    if algorithm != SSE_C_ALGORITHM {
        return Err(SseError::CustomerKeyAlgorithmUnsupported {
            algo: algorithm.to_owned(),
        });
    }

    let Ok(raw_key) = engine.decode(key_base64.trim().as_bytes()) else {
        return Err(SseError::InvalidCustomerKey {
            reason: "base64 decode of key",
        });
    };
    let key: [u8; KEY_LEN] =
        raw_key
            .as_slice()
            .try_into()
            .map_err(|_| SseError::InvalidCustomerKey {
                reason: "key length (must be 32 bytes after base64 decode)",
            })?;

    let Ok(raw_md5) = engine.decode(key_md5_base64.trim().as_bytes()) else {
        return Err(SseError::InvalidCustomerKey {
            reason: "base64 decode of key MD5",
        });
    };
    let claimed_md5: [u8; KEY_MD5_LEN] =
        raw_md5
            .as_slice()
            .try_into()
            .map_err(|_| SseError::InvalidCustomerKey {
                reason: "key MD5 length (must be 16 bytes after base64 decode)",
            })?;

    let key_md5 = compute_key_md5(&key);
    if !constant_time_eq(&key_md5, &claimed_md5) {
        return Err(SseError::InvalidCustomerKey {
            reason: "supplied MD5 does not match MD5 of supplied key",
        });
    }
    Ok(CustomerKeyMaterial { key, key_md5 })
}
/// Returns the 16-byte MD5 digest of `key`, as used for SSE-C key fingerprints.
pub fn compute_key_md5(key: &[u8]) -> [u8; KEY_MD5_LEN] {
    let digest = Md5::digest(key);
    let mut fingerprint = [0u8; KEY_MD5_LEN];
    fingerprint.copy_from_slice(&digest);
    fingerprint
}
/// Compares two byte slices without an early exit on the first differing byte.
///
/// Slices of different lengths are unequal immediately; only the contents are
/// compared in a data-independent loop (XOR-accumulate, then test once).
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
pub fn encrypt_with_source(plaintext: &[u8], source: SseSource<'_>) -> Bytes {
match source {
SseSource::Keyring(kr) => encrypt_v2(plaintext, kr),
SseSource::CustomerKey { key, key_md5 } => encrypt_v3(plaintext, key, key_md5),
SseSource::Kms { dek, wrapped } => encrypt_v4(plaintext, dek, wrapped),
}
}
/// Encrypts `plaintext` into an S4E3 (SSE-C) frame under the customer `key`.
///
/// Frame layout: `S4E3` ‖ algo(1) ‖ key_md5(16) ‖ nonce(12) ‖ tag(16) ‖ ciphertext.
/// The fingerprint is bound into the AAD, so altering it makes authentication fail.
fn encrypt_v3(
    plaintext: &[u8],
    key: &[u8; KEY_LEN],
    key_md5: &[u8; KEY_MD5_LEN],
) -> Bytes {
    let mut iv = [0u8; NONCE_LEN];
    rand::rngs::OsRng.fill_bytes(&mut iv);

    // The AAD is exactly magic ‖ algo ‖ key_md5, i.e. the frame prefix before the nonce.
    let header_aad = aad_v3(key_md5);
    let sealed = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key))
        .encrypt(
            Nonce::from_slice(&iv),
            Payload {
                msg: plaintext,
                aad: &header_aad,
            },
        )
        .expect("aes-gcm encrypt cannot fail with a 32-byte key");
    debug_assert!(sealed.len() >= TAG_LEN);

    let (body, tag) = sealed.split_at(sealed.len() - TAG_LEN);
    let mut frame = Vec::with_capacity(SSE_HEADER_BYTES_V3 + body.len());
    frame.extend_from_slice(&header_aad);
    frame.extend_from_slice(&iv);
    frame.extend_from_slice(tag);
    frame.extend_from_slice(body);
    Bytes::from(frame)
}
/// Synchronously decrypts an S4E1, S4E2 or S4E3 frame.
///
/// `source` must fit the frame: a keyring for S4E1/S4E2, a customer key for S4E3.
/// S4E4 (SSE-KMS) frames are never decrypted here; use `decrypt_with_kms`.
///
/// # Errors
/// - [`SseError::TooShort`] if `body` is shorter than its frame header.
/// - [`SseError::CustomerKeyUnexpected`] / [`SseError::CustomerKeyRequired`] when
///   `source` does not fit the frame.
/// - [`SseError::KmsAsyncRequired`] for S4E4 frames.
/// - [`SseError::BadMagic`] for unknown input.
/// - Otherwise, whatever the per-version decryptor returns.
pub fn decrypt<'a, S: Into<SseSource<'a>>>(body: &[u8], source: S) -> Result<Bytes, SseError> {
    let source: SseSource<'a> = source.into();
    if body.len() < SSE_HEADER_BYTES {
        return Err(SseError::TooShort { got: body.len() });
    }
    let magic: [u8; 4] = [body[0], body[1], body[2], body[3]];

    if magic == *SSE_MAGIC_V1 || magic == *SSE_MAGIC_V2 {
        let SseSource::Keyring(keyring) = source else {
            return Err(SseError::CustomerKeyUnexpected);
        };
        return if magic == *SSE_MAGIC_V1 {
            decrypt_v1_with_keyring(body, keyring)
        } else {
            decrypt_v2_with_keyring(body, keyring)
        };
    }

    if magic == *SSE_MAGIC_V3 {
        // S4E3 carries a 16-byte fingerprint, so its header is longer than S4E1/S4E2.
        if body.len() < SSE_HEADER_BYTES_V3 {
            return Err(SseError::TooShort { got: body.len() });
        }
        let SseSource::CustomerKey { key, key_md5 } = source else {
            return Err(SseError::CustomerKeyRequired);
        };
        return decrypt_v3(body, key, key_md5);
    }

    if magic == *SSE_MAGIC_V4 {
        return Err(SseError::KmsAsyncRequired);
    }
    Err(SseError::BadMagic { got: magic })
}
/// Decrypts an S4E3 frame with the customer `key`.
///
/// The caller guarantees `body.len() >= SSE_HEADER_BYTES_V3`. The supplied MD5 is
/// compared (constant time) with the fingerprint stored in the frame before any
/// decryption is attempted.
///
/// # Errors
/// - [`SseError::UnsupportedAlgo`] for a foreign algorithm byte.
/// - [`SseError::WrongCustomerKey`] on a fingerprint mismatch.
/// - [`SseError::DecryptFailed`] if AES-GCM authentication fails.
fn decrypt_v3(
    body: &[u8],
    key: &[u8; KEY_LEN],
    supplied_md5: &[u8; KEY_MD5_LEN],
) -> Result<Bytes, SseError> {
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }

    const MD5_OFF: usize = 5;
    const NONCE_OFF: usize = MD5_OFF + KEY_MD5_LEN;
    const TAG_OFF: usize = NONCE_OFF + NONCE_LEN;

    let stored_md5: &[u8; KEY_MD5_LEN] = body[MD5_OFF..NONCE_OFF]
        .try_into()
        .expect("slice length equals KEY_MD5_LEN");
    if !constant_time_eq(supplied_md5, stored_md5) {
        return Err(SseError::WrongCustomerKey);
    }

    let nonce = Nonce::from_slice(&body[NONCE_OFF..TAG_OFF]);
    let tag = &body[TAG_OFF..SSE_HEADER_BYTES_V3];
    let ciphertext = &body[SSE_HEADER_BYTES_V3..];

    // aes-gcm expects ciphertext ‖ tag; the frame stores the tag first.
    let mut sealed = Vec::with_capacity(ciphertext.len() + TAG_LEN);
    sealed.extend_from_slice(ciphertext);
    sealed.extend_from_slice(tag);

    let aad = aad_v3(stored_md5);
    Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key))
        .decrypt(
            nonce,
            Payload {
                msg: &sealed,
                aad: &aad,
            },
        )
        .map(Bytes::from)
        .map_err(|_| SseError::DecryptFailed)
}
/// Builds the S4E4 AAD, which is byte-for-byte the frame header that precedes the nonce:
/// `S4E4` ‖ algo ‖ key_id_len(1) ‖ key_id ‖ wrapped_len(4, big-endian) ‖ wrapped_dek.
///
/// Lengths are narrowed with `as`; callers are expected to have bounded them
/// (`encrypt_v4` asserts both bounds, and the parser reads them from 1- and 4-byte fields).
fn aad_v4(key_id: &[u8], wrapped_dek: &[u8]) -> Vec<u8> {
    let key_id_len = key_id.len() as u8;
    let wrapped_len = (wrapped_dek.len() as u32).to_be_bytes();
    let parts: [&[u8]; 5] = [
        SSE_MAGIC_V4,
        &[ALGO_AES_256_GCM, key_id_len],
        key_id,
        &wrapped_len,
        wrapped_dek,
    ];
    parts.concat()
}
/// Encrypts `plaintext` into an S4E4 (SSE-KMS) frame under the plaintext data key `dek`.
///
/// Frame layout: `S4E4` ‖ algo(1) ‖ key_id_len(1) ‖ key_id ‖ wrapped_len(4, BE) ‖
/// wrapped_dek ‖ nonce(12) ‖ tag(16) ‖ ciphertext. Everything before the nonce is
/// also the AAD.
///
/// # Panics
/// If `wrapped.key_id` is empty or longer than 255 bytes, or if the wrapped DEK is
/// longer than `u32::MAX` bytes (neither fits its length field).
fn encrypt_v4(plaintext: &[u8], dek: &[u8; KEY_LEN], wrapped: &WrappedDek) -> Bytes {
    let key_id = wrapped.key_id.as_bytes();
    assert!(
        (1..=u8::MAX as usize).contains(&key_id.len()),
        "S4E4 key_id must be 1..=255 bytes (got {})",
        key_id.len()
    );
    assert!(
        u32::try_from(wrapped.ciphertext.len()).is_ok(),
        "S4E4 wrapped_dek longer than u32::MAX",
    );

    let mut iv = [0u8; NONCE_LEN];
    rand::rngs::OsRng.fill_bytes(&mut iv);

    // The AAD equals the frame header up to the nonce, so it doubles as the start of
    // the output buffer once encryption is done.
    let header = aad_v4(key_id, &wrapped.ciphertext);
    let sealed = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dek))
        .encrypt(
            Nonce::from_slice(&iv),
            Payload {
                msg: plaintext,
                aad: &header,
            },
        )
        .expect("aes-gcm encrypt cannot fail with a 32-byte key");
    debug_assert!(sealed.len() >= TAG_LEN);

    let (body, tag) = sealed.split_at(sealed.len() - TAG_LEN);
    let mut frame = header;
    frame.reserve_exact(NONCE_LEN + TAG_LEN + body.len());
    frame.extend_from_slice(&iv);
    frame.extend_from_slice(tag);
    frame.extend_from_slice(body);
    Bytes::from(frame)
}
/// Borrowed view of the fields of an S4E4 (SSE-KMS) frame, as split out by
/// [`parse_s4e4_header`]. Every slice points into the original frame buffer.
#[derive(Debug)]
pub struct S4E4Header<'a> {
    /// KMS key identifier (1-byte length prefix in the frame).
    pub key_id: &'a str,
    /// KMS-wrapped data key (4-byte big-endian length prefix in the frame).
    pub wrapped_dek: &'a [u8],
    /// AES-GCM nonce, `NONCE_LEN` bytes.
    pub nonce: &'a [u8],
    /// AES-GCM authentication tag, `TAG_LEN` bytes.
    pub tag: &'a [u8],
    /// Every byte after the tag.
    pub ciphertext: &'a [u8],
}

/// Splits an S4E4 frame into its fields without decrypting anything.
///
/// Layout: `S4E4` ‖ algo(1) ‖ key_id_len(1) ‖ key_id ‖ wrapped_len(4, BE) ‖
/// wrapped_dek ‖ nonce(12) ‖ tag(16) ‖ ciphertext.
///
/// The length fields are read from the (untrusted) frame. Every bound is therefore
/// checked without arithmetic that can overflow, and malformed input yields an error
/// rather than a panic.
///
/// # Errors
/// - [`SseError::KmsFrameTooShort`] if `body` cannot hold even an empty frame.
/// - [`SseError::BadMagic`] / [`SseError::UnsupportedAlgo`] for a foreign header.
/// - [`SseError::KmsFrameFieldOob`] when a length field runs past the end of `body`.
/// - [`SseError::KmsKeyIdNotUtf8`] when the key id is not UTF-8.
pub fn parse_s4e4_header(body: &[u8]) -> Result<S4E4Header<'_>, SseError> {
    // Smallest well-formed frame: empty key_id, wrapped DEK and ciphertext.
    const S4E4_MIN: usize = 4 + 1 + 1 + 4 + NONCE_LEN + TAG_LEN;
    if body.len() < S4E4_MIN {
        return Err(SseError::KmsFrameTooShort {
            got: body.len(),
            min: S4E4_MIN,
        });
    }
    let magic = &body[..4];
    if magic != SSE_MAGIC_V4 {
        let mut got = [0u8; 4];
        got.copy_from_slice(magic);
        return Err(SseError::BadMagic { got });
    }
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }

    let key_id_off: usize = 6;
    // One-byte length: key_id_end is at most 6 + 255, so this cannot overflow.
    let key_id_end = key_id_off + usize::from(body[5]);
    // The key id must be followed by the 4-byte wrapped-DEK length. `saturating_sub`
    // yields 0 when key_id_end already exceeds the body, which also fails the check.
    if body.len().saturating_sub(key_id_end) < 4 {
        return Err(SseError::KmsFrameFieldOob { what: "key_id" });
    }
    let key_id = std::str::from_utf8(&body[key_id_off..key_id_end])
        .map_err(|_| SseError::KmsKeyIdNotUtf8)?;

    let wrapped_len_off = key_id_end;
    let wrapped_off = wrapped_len_off + 4;
    let wrapped_dek_len = u32::from_be_bytes([
        body[wrapped_len_off],
        body[wrapped_len_off + 1],
        body[wrapped_len_off + 2],
        body[wrapped_len_off + 3],
    ]);
    let wrapped_end = usize::try_from(wrapped_dek_len)
        .ok()
        .and_then(|len| wrapped_off.checked_add(len))
        .ok_or(SseError::KmsFrameFieldOob {
            what: "wrapped_dek_len",
        })?;
    // Compare against the remaining length instead of computing
    // `wrapped_end + NONCE_LEN + TAG_LEN`. That sum can wrap on 32-bit targets for a
    // huge declared length, which would let the check pass and the slicing panic.
    if body.len().saturating_sub(wrapped_end) < NONCE_LEN + TAG_LEN {
        return Err(SseError::KmsFrameFieldOob { what: "wrapped_dek" });
    }

    let wrapped_dek = &body[wrapped_off..wrapped_end];
    let nonce_off = wrapped_end;
    let tag_off = nonce_off + NONCE_LEN;
    let ct_off = tag_off + TAG_LEN;
    let nonce = &body[nonce_off..tag_off];
    let tag = &body[tag_off..ct_off];
    let ciphertext = &body[ct_off..];
    Ok(S4E4Header {
        key_id,
        wrapped_dek,
        nonce,
        tag,
        ciphertext,
    })
}
/// Decrypts an S4E4 (SSE-KMS) frame, unwrapping its data key through `kms`.
///
/// The frame is parsed first, then the wrapped DEK is handed to the backend. The
/// returned plaintext DEK must be exactly `KEY_LEN` bytes.
///
/// # Errors
/// - Anything from [`parse_s4e4_header`].
/// - [`SseError::KmsBackend`] for backend failures or a DEK of the wrong length.
/// - [`SseError::DecryptFailed`] if AES-GCM authentication fails.
pub async fn decrypt_with_kms(
    body: &[u8],
    kms: &dyn KmsBackend,
) -> Result<Bytes, SseError> {
    let header = parse_s4e4_header(body)?;
    let request = WrappedDek {
        key_id: header.key_id.to_owned(),
        ciphertext: header.wrapped_dek.to_vec(),
    };
    let dek = kms.decrypt_dek(&request).await?;
    if dek.len() != KEY_LEN {
        let message = format!("KMS returned {} byte DEK; expected {KEY_LEN}", dek.len());
        return Err(SseError::KmsBackend(KmsError::BackendUnavailable { message }));
    }

    let aad = aad_v4(header.key_id.as_bytes(), header.wrapped_dek);
    // aes-gcm expects ciphertext ‖ tag; the frame stores the tag first.
    let mut sealed = Vec::with_capacity(header.ciphertext.len() + TAG_LEN);
    sealed.extend_from_slice(header.ciphertext);
    sealed.extend_from_slice(header.tag);

    Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&dek))
        .decrypt(
            Nonce::from_slice(header.nonce),
            Payload {
                msg: &sealed,
                aad: &aad,
            },
        )
        .map(Bytes::from)
        .map_err(|_| SseError::DecryptFailed)
}
fn decrypt_v1_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
let algo = body[4];
if algo != ALGO_AES_256_GCM {
return Err(SseError::UnsupportedAlgo { tag: algo });
}
let mut nonce_bytes = [0u8; NONCE_LEN];
nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
let mut tag_bytes = [0u8; TAG_LEN];
tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
let ct = &body[SSE_HEADER_BYTES..];
let aad = aad_v1();
let nonce = Nonce::from_slice(&nonce_bytes);
let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
ct_with_tag.extend_from_slice(ct);
ct_with_tag.extend_from_slice(&tag_bytes);
let (active_id, _active_key) = keyring.active();
let mut ids: Vec<u16> = keyring.keys.keys().copied().collect();
ids.sort_by_key(|id| if *id == active_id { 0 } else { 1 });
for id in ids {
let key = keyring.get(id).expect("id came from keyring iteration");
let cipher = Aes256Gcm::new(key.as_aes_key());
if let Ok(plain) = cipher.decrypt(
nonce,
Payload {
msg: &ct_with_tag,
aad: &aad,
},
) {
return Ok(Bytes::from(plain));
}
}
Err(SseError::DecryptFailed)
}
/// Decrypts an S4E2 frame with the keyring key named by its embedded key id.
///
/// The caller guarantees `body.len() >= SSE_HEADER_BYTES`.
///
/// # Errors
/// - [`SseError::UnsupportedAlgo`] for a foreign algorithm byte.
/// - [`SseError::KeyNotInKeyring`] if the key id is unknown.
/// - [`SseError::DecryptFailed`] if AES-GCM authentication fails.
fn decrypt_v2_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }
    let key_id = u16::from_be_bytes([body[5], body[6]]);
    let Some(key) = keyring.get(key_id) else {
        return Err(SseError::KeyNotInKeyring { id: key_id });
    };

    // Header layout: magic(4) algo(1) key_id(2) reserved(1) nonce(12) tag(16).
    const NONCE_OFF: usize = 8;
    const TAG_OFF: usize = NONCE_OFF + NONCE_LEN;
    let (header, ciphertext) = body.split_at(SSE_HEADER_BYTES);

    let mut sealed = Vec::with_capacity(ciphertext.len() + TAG_LEN);
    sealed.extend_from_slice(ciphertext);
    sealed.extend_from_slice(&header[TAG_OFF..]);

    let aad = aad_v2(key_id);
    Aes256Gcm::new(key.as_aes_key())
        .decrypt(
            Nonce::from_slice(&header[NONCE_OFF..TAG_OFF]),
            Payload {
                msg: &sealed,
                aad: &aad,
            },
        )
        .map(Bytes::from)
        .map_err(|_| SseError::DecryptFailed)
}
/// Returns `true` when `body` is at least `SSE_HEADER_BYTES` long and starts with one
/// of the known frame magics (S4E1–S4E4).
pub fn looks_encrypted(body: &[u8]) -> bool {
    peek_magic(body).is_some()
}

/// Returns the frame magic of `body` (`"S4E1"` … `"S4E4"`). Returns `None` when `body`
/// is shorter than `SSE_HEADER_BYTES` or does not start with a known magic.
pub fn peek_magic(body: &[u8]) -> Option<&'static str> {
    const KNOWN: [(&[u8; 4], &str); 4] = [
        (SSE_MAGIC_V1, "S4E1"),
        (SSE_MAGIC_V2, "S4E2"),
        (SSE_MAGIC_V3, "S4E3"),
        (SSE_MAGIC_V4, "S4E4"),
    ];
    if body.len() < SSE_HEADER_BYTES {
        return None;
    }
    let prefix = &body[..4];
    KNOWN
        .iter()
        .find(|(magic, _)| prefix == magic.as_slice())
        .map(|&(_, name)| name)
}
/// Reference-counted handle to a single [`SseKey`].
pub type SharedSseKey = Arc<SseKey>;
#[cfg(test)]
mod tests {
use super::*;
fn key32(seed: u8) -> Arc<SseKey> {
Arc::new(SseKey::from_bytes(&[seed; 32]).unwrap())
}
fn keyring_single(seed: u8) -> SseKeyring {
SseKeyring::new(1, key32(seed))
}
#[test]
fn roundtrip_basic_v1() {
let k = SseKey::from_bytes(&[7u8; 32]).unwrap();
let pt = b"the quick brown fox jumps over the lazy dog";
let ct = encrypt(&k, pt);
assert!(looks_encrypted(&ct));
assert_eq!(&ct[..4], SSE_MAGIC_V1);
assert_eq!(ct[4], ALGO_AES_256_GCM);
assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
let kr = SseKeyring::new(1, Arc::new(k));
let pt2 = decrypt(&ct, &kr).unwrap();
assert_eq!(pt2.as_ref(), pt);
}
#[test]
fn s4e2_roundtrip_active_key() {
let kr = keyring_single(7);
let pt = b"S4E2 active-key roundtrip";
let ct = encrypt_v2(pt, &kr);
assert_eq!(&ct[..4], SSE_MAGIC_V2);
assert_eq!(ct[4], ALGO_AES_256_GCM);
assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE");
assert_eq!(ct[7], 0, "reserved byte");
assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
assert!(looks_encrypted(&ct));
let pt2 = decrypt(&ct, &kr).unwrap();
assert_eq!(pt2.as_ref(), pt);
}
#[test]
fn decrypt_s4e1_via_active_only_keyring() {
let k_arc = key32(11);
let legacy_ct = encrypt(&k_arc, b"v0.4 vintage object");
assert_eq!(&legacy_ct[..4], SSE_MAGIC_V1);
let kr = SseKeyring::new(1, Arc::clone(&k_arc));
let plain = decrypt(&legacy_ct, &kr).unwrap();
assert_eq!(plain.as_ref(), b"v0.4 vintage object");
}
#[test]
fn decrypt_s4e2_under_old_key_after_rotation() {
let k1 = key32(1);
let k2 = key32(2);
let mut kr_old = SseKeyring::new(1, Arc::clone(&k1));
let ct = encrypt_v2(b"old-rotation object", &kr_old);
assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
kr_old.add(2, Arc::clone(&k2));
let mut kr_new = SseKeyring::new(2, Arc::clone(&k2));
kr_new.add(1, Arc::clone(&k1));
let plain = decrypt(&ct, &kr_new).unwrap();
assert_eq!(plain.as_ref(), b"old-rotation object");
let new_ct = encrypt_v2(b"new-rotation object", &kr_new);
assert_eq!(u16::from_be_bytes([new_ct[5], new_ct[6]]), 2);
let plain_new = decrypt(&new_ct, &kr_new).unwrap();
assert_eq!(plain_new.as_ref(), b"new-rotation object");
}
#[test]
fn s4e2_unknown_key_id_errors() {
let kr = keyring_single(3); let kr_other = SseKeyring::new(99, key32(3));
let ct = encrypt_v2(b"x", &kr_other); let err = decrypt(&ct, &kr).unwrap_err();
assert!(
matches!(err, SseError::KeyNotInKeyring { id: 99 }),
"got {err:?}"
);
}
#[test]
fn s4e2_tampered_key_id_fails_auth() {
let kr = SseKeyring::new(1, key32(4));
let mut kr_with_2 = kr.clone();
kr_with_2.add(2, key32(5)); let mut ct = encrypt_v2(b"do not flip my key id", &kr).to_vec();
assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
ct[5] = 0;
ct[6] = 2;
let err = decrypt(&ct, &kr_with_2).unwrap_err();
assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[test]
fn s4e2_tampered_ciphertext_fails() {
let kr = SseKeyring::new(7, key32(9));
let mut ct = encrypt_v2(b"secret message v2", &kr).to_vec();
let last = ct.len() - 1;
ct[last] ^= 0x01;
let err = decrypt(&ct, &kr).unwrap_err();
assert!(matches!(err, SseError::DecryptFailed));
}
#[test]
fn s4e2_tampered_algo_byte_fails() {
let kr = SseKeyring::new(1, key32(2));
let mut ct = encrypt_v2(b"hi", &kr).to_vec();
ct[4] = 99;
let err = decrypt(&ct, &kr).unwrap_err();
assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
}
#[test]
fn wrong_key_fails_v1_via_keyring() {
let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
let ct = encrypt(&k1, b"secret");
let kr_wrong = SseKeyring::new(1, Arc::new(SseKey::from_bytes(&[2u8; 32]).unwrap()));
let err = decrypt(&ct, &kr_wrong).unwrap_err();
assert!(matches!(err, SseError::DecryptFailed));
}
#[test]
fn rejects_short_body() {
let kr = SseKeyring::new(1, key32(1));
let err = decrypt(b"short", &kr).unwrap_err();
assert!(matches!(err, SseError::TooShort { got: 5 }));
}
#[test]
fn looks_encrypted_passthrough_returns_false() {
let f2 = b"S4F2\x01\x00\x00\x00........................................";
assert!(!looks_encrypted(f2));
assert!(!looks_encrypted(b""));
}
#[test]
fn looks_encrypted_detects_both_v1_and_v2() {
let kr = SseKeyring::new(1, key32(8));
let v1 = encrypt(&SseKey::from_bytes(&[8u8; 32]).unwrap(), b"x");
let v2 = encrypt_v2(b"x", &kr);
assert!(looks_encrypted(&v1));
assert!(looks_encrypted(&v2));
}
#[test]
fn key_from_hex_string() {
let bad =
SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
.unwrap_err();
assert!(matches!(bad, SseError::BadKeyLength { .. }));
let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
}
#[test]
fn encrypt_v2_uses_random_nonce() {
let kr = SseKeyring::new(1, key32(3));
let pt = b"deterministic input";
let a = encrypt_v2(pt, &kr);
let b = encrypt_v2(pt, &kr);
assert_ne!(a, b, "nonce must be random per-call");
}
#[test]
fn keyring_active_and_get() {
let k1 = key32(1);
let k2 = key32(2);
let mut kr = SseKeyring::new(1, Arc::clone(&k1));
kr.add(2, Arc::clone(&k2));
let (id, active) = kr.active();
assert_eq!(id, 1);
assert_eq!(active.bytes, [1u8; 32]);
assert!(kr.get(2).is_some());
assert!(kr.get(3).is_none());
}
use base64::Engine as _;
fn cust_key(seed: u8) -> CustomerKeyMaterial {
let key = [seed; KEY_LEN];
let key_md5 = compute_key_md5(&key);
CustomerKeyMaterial { key, key_md5 }
}
#[test]
fn s4e3_roundtrip_happy_path() {
let m = cust_key(42);
let pt = b"top-secret SSE-C payload";
let ct = encrypt_with_source(
pt,
SseSource::CustomerKey {
key: &m.key,
key_md5: &m.key_md5,
},
);
assert_eq!(&ct[..4], SSE_MAGIC_V3);
assert_eq!(ct[4], ALGO_AES_256_GCM);
assert_eq!(&ct[5..5 + KEY_MD5_LEN], &m.key_md5);
assert_eq!(ct.len(), SSE_HEADER_BYTES_V3 + pt.len());
assert!(looks_encrypted(&ct));
let plain = decrypt(
&ct,
SseSource::CustomerKey {
key: &m.key,
key_md5: &m.key_md5,
},
)
.unwrap();
assert_eq!(plain.as_ref(), pt);
let plain2 = decrypt(&ct, &m).unwrap();
assert_eq!(plain2.as_ref(), pt);
}
#[test]
fn s4e3_wrong_key_yields_wrong_customer_key_error() {
let m = cust_key(1);
let other = cust_key(2);
let ct = encrypt_with_source(b"payload", (&m).into());
let err = decrypt(
&ct,
SseSource::CustomerKey {
key: &other.key,
key_md5: &other.key_md5,
},
)
.unwrap_err();
assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
}
#[test]
fn s4e3_tampered_stored_md5_is_caught() {
let m = cust_key(7);
let mut ct = encrypt_with_source(b"victim payload", (&m).into()).to_vec();
ct[5] ^= 0x55;
let err = decrypt(
&ct,
SseSource::CustomerKey {
key: &m.key,
key_md5: &m.key_md5,
},
)
.unwrap_err();
assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
}
#[test]
fn s4e3_tampered_md5_with_matching_supplied_md5_fails_aead() {
let m = cust_key(3);
let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
ct[5] ^= 0xFF;
let mut bogus_md5 = m.key_md5;
bogus_md5[0] ^= 0xFF;
let err = decrypt(
&ct,
SseSource::CustomerKey {
key: &m.key,
key_md5: &bogus_md5,
},
)
.unwrap_err();
assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[test]
fn s4e3_tampered_ciphertext_fails_aead() {
let m = cust_key(8);
let mut ct = encrypt_with_source(b"sealed message", (&m).into()).to_vec();
let last = ct.len() - 1;
ct[last] ^= 0x01;
let err = decrypt(&ct, &m).unwrap_err();
assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[test]
fn s4e3_tampered_algo_byte_rejected() {
let m = cust_key(9);
let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
ct[4] = 99;
let err = decrypt(&ct, &m).unwrap_err();
assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
}
#[test]
fn s4e3_uses_random_nonce() {
let m = cust_key(10);
let a = encrypt_with_source(b"deterministic input", (&m).into());
let b = encrypt_with_source(b"deterministic input", (&m).into());
assert_ne!(a, b, "nonce must be random per-call");
}
#[test]
fn parse_customer_key_headers_happy_path() {
let key = [11u8; KEY_LEN];
let md5 = compute_key_md5(&key);
let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
let md5_b64 = base64::engine::general_purpose::STANDARD.encode(md5);
let m = parse_customer_key_headers("AES256", &key_b64, &md5_b64).unwrap();
assert_eq!(m.key, key);
assert_eq!(m.key_md5, md5);
}
#[test]
fn parse_customer_key_headers_rejects_wrong_algorithm() {
let key = [1u8; KEY_LEN];
let md5 = compute_key_md5(&key);
let kb = base64::engine::general_purpose::STANDARD.encode(key);
let mb = base64::engine::general_purpose::STANDARD.encode(md5);
let err = parse_customer_key_headers("AES128", &kb, &mb).unwrap_err();
assert!(
matches!(err, SseError::CustomerKeyAlgorithmUnsupported { ref algo } if algo == "AES128"),
"got {err:?}"
);
let err2 = parse_customer_key_headers("aes256", &kb, &mb).unwrap_err();
assert!(
matches!(err2, SseError::CustomerKeyAlgorithmUnsupported { .. }),
"got {err2:?}"
);
}
#[test]
fn parse_customer_key_headers_rejects_wrong_key_length() {
let short_key = vec![5u8; 16]; let md5 = compute_key_md5(&short_key);
let kb = base64::engine::general_purpose::STANDARD.encode(&short_key);
let mb = base64::engine::general_purpose::STANDARD.encode(md5);
let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
assert!(
matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("key length")),
"got {err:?}"
);
}
#[test]
fn parse_customer_key_headers_rejects_wrong_md5_length() {
let key = [3u8; KEY_LEN];
let kb = base64::engine::general_purpose::STANDARD.encode(key);
let bad_md5 = vec![0u8; 15];
let mb = base64::engine::general_purpose::STANDARD.encode(bad_md5);
let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
assert!(
matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 length")),
"got {err:?}"
);
}
#[test]
fn parse_customer_key_headers_rejects_md5_mismatch() {
let key = [4u8; KEY_LEN];
let other = [5u8; KEY_LEN];
let kb = base64::engine::general_purpose::STANDARD.encode(key);
let wrong_md5 = compute_key_md5(&other);
let mb = base64::engine::general_purpose::STANDARD.encode(wrong_md5);
let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
assert!(
matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 does not match")),
"got {err:?}"
);
}
#[test]
fn parse_customer_key_headers_rejects_bad_base64() {
let valid_key = [0u8; KEY_LEN];
let md5 = compute_key_md5(&valid_key);
let mb = base64::engine::general_purpose::STANDARD.encode(md5);
let err = parse_customer_key_headers("AES256", "!!!not-base64!!!", &mb).unwrap_err();
assert!(
matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
"got {err:?}"
);
let kb = base64::engine::general_purpose::STANDARD.encode(valid_key);
let err2 = parse_customer_key_headers("AES256", &kb, "??not-base64??").unwrap_err();
assert!(
matches!(err2, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
"got {err2:?}"
);
}
#[test]
fn parse_customer_key_headers_trims_whitespace() {
    // Both headers arrive wrapped in stray spaces, tabs and newlines. Parsing must
    // still succeed and yield the original key.
    let engine = &base64::engine::general_purpose::STANDARD;
    let key = [12u8; KEY_LEN];
    let padded_key = [" ", engine.encode(key).as_str(), "\n"].concat();
    let padded_md5 = ["\t", engine.encode(compute_key_md5(&key)).as_str(), " "].concat();
    let material = parse_customer_key_headers("AES256", &padded_key, &padded_md5).unwrap();
    assert_eq!(material.key, key);
}
#[test]
fn back_compat_decrypt_s4e1_with_keyring_source() {
    // A legacy S4E1 blob must open both through the `&SseKeyring` shorthand and
    // through an explicit `SseSource::Keyring`.
    let expected = b"v0.4 vintage object";
    let key = key32(33);
    let legacy_blob = encrypt(&key, expected);
    let keyring = SseKeyring::new(1, Arc::clone(&key));

    let via_ref = decrypt(&legacy_blob, &keyring).unwrap();
    assert_eq!(via_ref.as_ref(), expected);

    let via_source = decrypt(&legacy_blob, SseSource::Keyring(&keyring)).unwrap();
    assert_eq!(via_source.as_ref(), expected);
}
#[test]
fn back_compat_decrypt_s4e2_with_keyring_source() {
    // Two paths must round-trip the input: the dedicated `encrypt_v2` entry point,
    // and `encrypt_with_source` fed a keyring. The second must also emit an S4E2 frame.
    let msg = b"v0.5 #29 object";
    let keyring = keyring_single(34);

    let via_v2 = encrypt_v2(msg, &keyring);
    assert_eq!(decrypt(&via_v2, &keyring).unwrap().as_ref(), msg);

    let via_source = encrypt_with_source(msg, SseSource::Keyring(&keyring));
    assert_eq!(&via_source[..4], SSE_MAGIC_V2);
    assert_eq!(decrypt(&via_source, &keyring).unwrap().as_ref(), msg);
}
#[test]
fn s4e2_blob_with_customer_key_source_is_rejected() {
    // A keyring-encrypted (S4E2) object must not be openable with SSE-C key material.
    let keyring = keyring_single(50);
    let blob = encrypt_v2(b"server-managed object", &keyring);
    let material = cust_key(99);
    let source = SseSource::CustomerKey {
        key: &material.key,
        key_md5: &material.key_md5,
    };
    let err = decrypt(&blob, source).unwrap_err();
    assert!(matches!(err, SseError::CustomerKeyUnexpected), "got {err:?}");
}
#[test]
fn s4e3_blob_with_keyring_source_is_rejected() {
    // A customer-key (S4E3) object cannot be opened with a keyring, even one built
    // from the same seed argument as the customer key.
    let material = cust_key(60);
    let blob = encrypt_with_source(b"customer-key object", (&material).into());
    let keyring = keyring_single(60);
    let err = decrypt(&blob, &keyring).unwrap_err();
    assert!(matches!(err, SseError::CustomerKeyRequired), "got {err:?}");
}
#[test]
fn looks_encrypted_detects_s4e3() {
    // Even a one-byte customer-key payload must be recognised as an encrypted frame.
    let material = cust_key(13);
    let blob = encrypt_with_source(b"x", (&material).into());
    let recognised = looks_encrypted(&blob);
    assert!(recognised);
}
#[test]
fn s4e3_rejects_short_body() {
    // Builds an S4E3-tagged body exactly SSE_HEADER_BYTES long. That is below
    // SSE_HEADER_BYTES_V3, because the S4E3 header additionally carries a
    // KEY_MD5_LEN-byte fingerprint. Decrypting it must report TooShort.
    let mut frame: Vec<u8> = SSE_MAGIC_V3.to_vec();
    frame.push(ALGO_AES_256_GCM);
    frame.resize(SSE_HEADER_BYTES, 0);
    assert_eq!(frame.len(), SSE_HEADER_BYTES);

    let material = cust_key(1);
    let source = SseSource::CustomerKey {
        key: &material.key,
        key_md5: &material.key_md5,
    };
    let err = decrypt(&frame, source).unwrap_err();
    assert!(matches!(err, SseError::TooShort { .. }), "got {err:?}");
}
#[test]
fn customer_key_material_debug_redacts_key() {
    // `Debug` output must mention redaction and must not leak the raw key bytes.
    let material = cust_key(99);
    let rendered = format!("{material:?}");
    let raw_bytes = format!("{:?}", material.key.as_slice());
    assert!(rendered.contains("redacted"));
    assert!(!rendered.contains(&raw_bytes));
}
#[test]
fn constant_time_eq_basic() {
    // Equal inputs (including two empty slices) compare equal. A single differing
    // byte or a length mismatch must not.
    let cases: [(&[u8], &[u8], bool); 4] = [
        (b"abc", b"abc", true),
        (b"abc", b"abd", false),
        (b"abc", b"abcd", false),
        (b"", b"", true),
    ];
    for &(lhs, rhs, expected) in &cases {
        assert_eq!(constant_time_eq(lhs, rhs), expected);
    }
}
#[test]
fn compute_key_md5_known_vector() {
    // RFC 1321 test suite: MD5("") = d41d8cd98f00b204e9800998ecf8427e.
    let digest = compute_key_md5(b"");
    assert_eq!(hex_lower(&digest), "d41d8cd98f00b204e9800998ecf8427e");
}
use crate::kms::{KmsBackend, LocalKms};
use std::collections::HashMap;
use std::path::PathBuf;
/// Builds a `LocalKms` whose KEK table is exactly the given `(key_id, kek)` pairs.
/// A later duplicate id overwrites an earlier one. The directory argument is a
/// fixed placeholder (`/tmp/none`); presumably these tests never touch it — confirm.
fn local_kms_with(key_ids: &[(&str, [u8; 32])]) -> LocalKms {
    let keks: HashMap<String, [u8; 32]> = key_ids
        .iter()
        .map(|&(id, kek)| (id.to_string(), kek))
        .collect();
    LocalKms::from_keks(PathBuf::from("/tmp/none"), keks)
}
#[tokio::test]
async fn s4e4_roundtrip_via_local_kms() {
    // Envelope-encrypt under KEK "alpha", check the leading S4E4 header fields
    // (magic | algo | key_id_len | key_id), then decrypt back through the KMS.
    let kms = local_kms_with(&[("alpha", [42u8; 32])]);
    let (dek_bytes, wrapped) = kms.generate_dek("alpha").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_bytes);

    let payload = b"SSE-KMS envelope payload across the S4E4 frame";
    let frame = encrypt_with_source(
        payload,
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    );

    assert_eq!(&frame[..4], SSE_MAGIC_V4);
    assert_eq!(frame[4], ALGO_AES_256_GCM);
    let id_len = usize::from(frame[5]);
    assert_eq!(id_len, "alpha".len());
    assert_eq!(&frame[6..6 + id_len], b"alpha");
    assert!(looks_encrypted(&frame));
    assert_eq!(peek_magic(&frame), Some("S4E4"));

    let recovered = decrypt_with_kms(&frame, &kms).await.unwrap();
    assert_eq!(recovered.as_ref(), payload);
}
#[tokio::test]
async fn s4e4_tampered_key_id_fails_aead() {
    // Overwrite the first key_id byte ("alpha" -> "blpha"). The frame now names a
    // KEK this KMS does not hold, so decryption must fail inside the KMS backend.
    let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
    let (dek_bytes, wrapped) = kms.generate_dek("alpha").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_bytes);
    let mut frame = encrypt_with_source(
        b"do not redirect",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();

    // magic(4) + algo(1) + key_id_len(1) puts the key_id at offset 6.
    const KEY_ID_OFFSET: usize = 6;
    frame[KEY_ID_OFFSET] = b'b';

    let err = decrypt_with_kms(&frame, &kms).await.unwrap_err();
    let rejected_by_kms = matches!(
        err,
        SseError::KmsBackend(
            crate::kms::KmsError::UnwrapFailed { .. } | crate::kms::KmsError::KeyNotFound { .. }
        )
    );
    assert!(rejected_by_kms, "got {err:?}");
}
#[tokio::test]
async fn s4e4_tampered_key_id_to_real_other_id_still_fails() {
    // Redirect the frame's key_id from "alpha" to "bravo", a KEK this KMS *does*
    // hold. Both ids are 5 bytes, so key_id_len and every later offset stay valid.
    //
    // The previous version rewrote the id to "beta_", which is not registered
    // (only "beta" was). That merely re-tested the KeyNotFound path. Here the
    // lookup succeeds, so the rejection must come from unwrapping the DEK under
    // the wrong KEK (or, failing that, from AEAD), never from a missing key.
    let kms = local_kms_with(&[("alpha", [1u8; 32]), ("bravo", [2u8; 32])]);
    let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_vec);
    let mut ct = encrypt_with_source(
        b"redirect attempt",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();

    // magic(4) + algo(1) + key_id_len(1) puts the key_id at offset 6.
    let key_id_off = 6;
    let key_id_len = ct[5] as usize;
    let key_id = key_id_off..key_id_off + key_id_len;
    // Guard the offset arithmetic: make sure we are rewriting the real key_id.
    assert_eq!(&ct[key_id.clone()], b"alpha");
    ct[key_id].copy_from_slice(b"bravo");

    let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
    assert!(
        matches!(
            err,
            SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
                | SseError::DecryptFailed
        ),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_tampered_wrapped_dek_fails_unwrap() {
    // Corrupt one byte in the middle of the wrapped DEK carried in the frame. The KMS
    // must refuse to unwrap it.
    let kms = local_kms_with(&[("k", [3u8; 32])]);
    let (dek_bytes, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_bytes);
    let mut frame = encrypt_with_source(
        b"target body",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();

    // Layout up to the wrapped DEK:
    // magic(4) | algo(1) | key_id_len(1) | key_id | wrapped_len(4) | wrapped ...
    let key_id_end = 6 + usize::from(frame[5]);
    let wrapped_start = key_id_end + 4;
    let victim = wrapped_start + wrapped.ciphertext.len() / 2;
    frame[victim] ^= 0xFF;

    let err = decrypt_with_kms(&frame, &kms).await.unwrap_err();
    assert!(
        matches!(err, SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })),
        "got {err:?}"
    );
}
#[tokio::test]
async fn s4e4_tampered_ciphertext_fails_aead() {
    // Flip the low bit of the frame's final byte; the test expects this to surface
    // as an AEAD authentication failure (`DecryptFailed`).
    let kms = local_kms_with(&[("k", [4u8; 32])]);
    let (dek_bytes, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_bytes);
    let mut frame = encrypt_with_source(
        b"sealed body",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();

    *frame.last_mut().expect("encrypted frame is never empty") ^= 0x01;

    let err = decrypt_with_kms(&frame, &kms).await.unwrap_err();
    assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
}
#[tokio::test]
async fn s4e4_uses_random_nonce_and_dek_per_put() {
    // Two PUTs of identical plaintext must never yield identical frames. This is
    // checked along both axes the test name promises:
    //   * the KMS mints a distinct DEK per call;
    //   * with the DEK *held fixed*, each encryption still picks a fresh nonce.
    // The old assertion compared frames under two different DEKs only, so it would
    // have passed with a constant nonce. Identical frames under one DEK would mean
    // AES-GCM nonce reuse.
    let kms = local_kms_with(&[("k", [5u8; 32])]);
    let (dek1_vec, wrapped1) = kms.generate_dek("k").await.unwrap();
    let (dek2_vec, wrapped2) = kms.generate_dek("k").await.unwrap();
    let mut dek1 = [0u8; 32];
    dek1.copy_from_slice(&dek1_vec);
    let mut dek2 = [0u8; 32];
    dek2.copy_from_slice(&dek2_vec);
    assert_ne!(dek1, dek2, "KMS must mint a fresh DEK per call");

    let pt = b"deterministic input";
    let a = encrypt_with_source(
        pt,
        SseSource::Kms {
            dek: &dek1,
            wrapped: &wrapped1,
        },
    );
    let b = encrypt_with_source(
        pt,
        SseSource::Kms {
            dek: &dek2,
            wrapped: &wrapped2,
        },
    );
    assert_ne!(a, b);

    // Same DEK, same wrapped DEK, same plaintext: only the nonce can make these differ.
    let a_again = encrypt_with_source(
        pt,
        SseSource::Kms {
            dek: &dek1,
            wrapped: &wrapped1,
        },
    );
    assert_ne!(a, a_again, "nonce must be fresh for every encryption under one DEK");

    let plain_a = decrypt_with_kms(&a, &kms).await.unwrap();
    let plain_b = decrypt_with_kms(&b, &kms).await.unwrap();
    let plain_a_again = decrypt_with_kms(&a_again, &kms).await.unwrap();
    assert_eq!(plain_a.as_ref(), pt);
    assert_eq!(plain_b.as_ref(), pt);
    assert_eq!(plain_a_again.as_ref(), pt);
}
#[tokio::test]
async fn s4e4_sync_decrypt_returns_kms_async_required() {
    // The synchronous `decrypt` takes no KMS handle. Given an S4E4 frame and a
    // keyring source, it must answer `KmsAsyncRequired`.
    let kms = local_kms_with(&[("k", [6u8; 32])]);
    let (dek_bytes, wrapped) = kms.generate_dek("k").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_bytes);
    let frame = encrypt_with_source(
        b"async only",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    );

    let unrelated_keyring = SseKeyring::new(1, key32(0));
    let err = decrypt(&frame, &unrelated_keyring).unwrap_err();
    assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
}
#[test]
fn back_compat_s4e1_e2_e3_still_decrypt_via_sync() {
    // Every pre-KMS frame version (S4E1, S4E2, S4E3) must still round-trip through
    // the synchronous `decrypt`.
    let key = key32(7);
    let keyring = SseKeyring::new(1, Arc::clone(&key));

    let s4e1 = encrypt(&key, b"v0.4 vintage");
    assert_eq!(decrypt(&s4e1, &keyring).unwrap().as_ref(), b"v0.4 vintage");

    let s4e2 = encrypt_v2(b"v0.5 #29 vintage", &keyring);
    assert_eq!(decrypt(&s4e2, &keyring).unwrap().as_ref(), b"v0.5 #29 vintage");

    let customer = cust_key(7);
    let s4e3 = encrypt_with_source(b"v0.5 #27 vintage", (&customer).into());
    assert_eq!(decrypt(&s4e3, &customer).unwrap().as_ref(), b"v0.5 #27 vintage");
}
#[test]
fn peek_magic_distinguishes_all_variants() {
    // Real S4E1/S4E2/S4E3 frames and a hand-built S4E4 prefix each report their own
    // magic. An unknown tag, a 5-byte input and an all-zero buffer report none.
    let key = key32(9);
    let keyring = SseKeyring::new(1, Arc::clone(&key));
    let customer = cust_key(9);

    let v1 = encrypt(&key, b"x");
    let v2 = encrypt_v2(b"x", &keyring);
    let v3 = encrypt_with_source(b"x", (&customer).into());
    let mut v4_stub = SSE_MAGIC_V4.to_vec();
    v4_stub.extend_from_slice(&[0u8; 40]);

    let cases: [(&[u8], Option<&str>); 7] = [
        (&v1, Some("S4E1")),
        (&v2, Some("S4E2")),
        (&v3, Some("S4E3")),
        (&v4_stub, Some("S4E4")),
        (b"NOPE", None),
        (b"short", None),
        (&[0u8; 100], None),
    ];
    for (input, want) in cases {
        assert_eq!(peek_magic(input), want);
    }
}
#[tokio::test]
async fn s4e4_truncated_frame_errors_cleanly() {
    // Magic, algo byte 0x01 and key_id_len = 5, but only two key_id bytes follow.
    // The parser must report a short frame.
    let kms = local_kms_with(&[("k", [1u8; 32])]);
    let truncated_frame = b"S4E4\x01\x05hi";
    let outcome = decrypt_with_kms(truncated_frame, &kms).await;
    let err = outcome.unwrap_err();
    assert!(matches!(err, SseError::KmsFrameTooShort { .. }), "got {err:?}");
}
#[tokio::test]
async fn s4e4_oob_key_id_len_errors() {
    // key_id_len claims 200 bytes but only 50 zero bytes follow the header. The
    // field bounds check must reject it.
    let mut frame: Vec<u8> = SSE_MAGIC_V4.to_vec();
    frame.push(ALGO_AES_256_GCM);
    frame.push(200u8);
    frame.resize(frame.len() + 50, 0);

    let kms = local_kms_with(&[("k", [1u8; 32])]);
    let err = decrypt_with_kms(&frame, &kms).await.unwrap_err();
    assert!(matches!(err, SseError::KmsFrameFieldOob { .. }), "got {err:?}");
}
#[tokio::test]
async fn s4e4_via_keyring_source_into_sync_decrypt_is_kms_async_required() {
let kms = local_kms_with(&[("k", [9u8; 32])]);
let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
let mut dek = [0u8; 32];
dek.copy_from_slice(&dek_vec);
let ct = encrypt_with_source(
b"x",
SseSource::Kms {
dek: &dek,
wrapped: &wrapped,
},
);
let m = cust_key(1);
let err = decrypt(&ct, &m).unwrap_err();
assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
}
#[test]
fn s4e4_looks_encrypted_passthrough_returns_false_for_synthetic() {
    // "S4F4" is one byte off the S4E4 magic. Such a near-miss prefix must not be
    // treated as any SSE frame, however much payload follows it.
    //
    // The local was previously named `not_s4e4`. Its borrow `&not_s4e4` had been
    // mangled into `¬_s4e4` (an HTML `&not` entity got decoded), which does not
    // compile. The rename avoids that character sequence altogether. The test
    // never awaits, so it runs as a plain `#[test]` without a tokio runtime.
    let mut near_miss = Vec::with_capacity(64);
    near_miss.extend_from_slice(b"S4F4");
    near_miss.extend_from_slice(&[0u8; 60]);
    assert!(!looks_encrypted(&near_miss));
    assert_eq!(peek_magic(&near_miss), None);
}
#[tokio::test]
async fn s4e4_aad_length_prefix_prevents_byte_shifting() {
    // Shrink the declared wrapped-DEK length (a big-endian u32 right after the key_id)
    // by one. This shifts the boundary between the wrapped DEK and the bytes after it.
    // Whichever layer notices (frame parsing, KMS unwrap, or AEAD), decryption
    // must fail.
    let kms = local_kms_with(&[("kk", [11u8; 32])]);
    let (dek_bytes, wrapped) = kms.generate_dek("kk").await.unwrap();
    let mut dek = [0u8; 32];
    dek.copy_from_slice(&dek_bytes);
    let mut frame = encrypt_with_source(
        b"length-shift defense",
        SseSource::Kms {
            dek: &dek,
            wrapped: &wrapped,
        },
    )
    .to_vec();

    let len_field = {
        let start = 6 + usize::from(frame[5]);
        start..start + 4
    };
    let mut len_bytes = [0u8; 4];
    len_bytes.copy_from_slice(&frame[len_field.clone()]);
    let declared = u32::from_be_bytes(len_bytes);
    frame[len_field].copy_from_slice(&(declared - 1).to_be_bytes());

    let err = decrypt_with_kms(&frame, &kms).await.unwrap_err();
    let rejected = matches!(
        err,
        SseError::KmsBackend(_)
            | SseError::DecryptFailed
            | SseError::KmsFrameFieldOob { .. }
            | SseError::KmsFrameTooShort { .. }
    );
    assert!(rejected, "got {err:?}");
}
}