use std::collections::HashSet;
use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce};
use base64::Engine;
use hkdf::Hkdf;
use rand::RngCore;
use sha2::Sha256;
use subtle::ConstantTimeEq;
use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
use zeroize::Zeroizing;
use crate::CryptoError;
use crate::canonicalize::canonicalize_json;
use crate::hash::{
PayloadAadParams, PayloadCipherParams, compute_payload_aad, compute_payload_cipher_hash,
compute_payload_plain_hash, compute_recipients_hash,
};
/// AES-GCM nonce length in bytes (96 bits), used for the payload and for each wrapped DEK.
const NONCE_SIZE: usize = 12;
/// Length of the random salt prefixed to the plaintext and mixed into the plain hash.
const SALT_SIZE: usize = 16;
/// AES-256 key length in bytes; also the expected length of the X25519 `enc` value.
const KEY_SIZE: usize = 32;
/// AES-GCM authentication tag length in bytes.
const TAG_SIZE: usize = 16;
/// Envelope format version written to, and required in, `enc_version`.
const ENC_VERSION: u64 = 1;
/// Value written to, and required in, the envelope's top-level `aead` field.
const AEAD_ALGORITHM: &str = "AES-256-GCM";
/// Values written to, and required in, the envelope's `hpke` object.
const HPKE_MODE: &str = "base";
const HPKE_KEM: &str = "X25519-HKDF-SHA256";
const HPKE_KDF: &str = "HKDF-SHA256";
const HPKE_AEAD: &str = "AES-256-GCM";
/// A recipient for whom the payload's data-encryption key (DEK) is wrapped.
#[derive(Debug, Clone)]
pub struct RecipientKey {
    /// Identifier written to the envelope as `recipient_kid`; must be unique within a single
    /// `encrypt_payload` call.
    pub kid: u32,
    /// The recipient's raw 32-byte X25519 public key.
    pub public_key: [u8; 32],
}
/// Everything produced by [`encrypt_payload`].
#[derive(Debug, Clone)]
pub struct EncryptionResult {
    /// The JSON envelope: version, algorithm identifiers, nonce, ciphertext, tag and the
    /// per-recipient wrapped keys.
    pub payload_encrypted: serde_json::Value,
    /// Random salt that prefixes the encrypted plaintext and feeds the plain hash.
    pub salt: [u8; 16],
    /// Salted hash of the plaintext payload (as computed by `compute_payload_plain_hash`).
    pub payload_plain_hash: [u8; 32],
    /// Hash over nonce, AAD, ciphertext, tag and the recipients hash.
    pub payload_cipher_hash: [u8; 32],
}
pub fn encrypt_payload(
payload: &serde_json::Value,
aad_params: &PayloadAadParams<'_>,
recipient_keys: &[RecipientKey],
) -> Result<EncryptionResult, CryptoError> {
if recipient_keys.is_empty() {
return Err(CryptoError::NoRecipients);
}
validate_unique_recipient_kids(recipient_keys.iter().map(|recipient| recipient.kid))?;
let mut rng = rand::thread_rng();
let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
let mut salt = [0u8; SALT_SIZE];
let mut dek = Zeroizing::new([0u8; KEY_SIZE]);
let mut nonce_bytes = [0u8; NONCE_SIZE];
rng.fill_bytes(&mut salt);
rng.fill_bytes(dek.as_mut());
rng.fill_bytes(&mut nonce_bytes);
let payload_plain_hash = compute_payload_plain_hash(payload, Some(&salt))?;
let updated_aad_params =
PayloadAadParams { payload_plain_hash: &payload_plain_hash, ..*aad_params };
let payload_aad = compute_payload_aad(&updated_aad_params)?;
let canonical = canonicalize_json(payload)?;
let mut plaintext = Vec::with_capacity(SALT_SIZE + canonical.len());
plaintext.extend_from_slice(&salt);
plaintext.extend_from_slice(canonical.as_bytes());
let key = Key::<Aes256Gcm>::from_slice(&*dek);
let cipher = Aes256Gcm::new(key);
let nonce = Nonce::from_slice(&nonce_bytes);
let aead_payload = aes_gcm::aead::Payload { msg: &plaintext, aad: &payload_aad };
let ciphertext_with_tag = cipher
.encrypt(nonce, aead_payload)
.map_err(|e| CryptoError::EncryptionError(e.to_string()))?;
if ciphertext_with_tag.len() < TAG_SIZE {
return Err(CryptoError::EncryptionError(
"AES-GCM output shorter than authentication tag".to_string(),
));
}
let ct_len = ciphertext_with_tag.len() - TAG_SIZE;
let ciphertext = &ciphertext_with_tag[..ct_len];
let tag = &ciphertext_with_tag[ct_len..];
let mut recipients = Vec::with_capacity(recipient_keys.len());
for rk in recipient_keys {
let (enc, wrapped_key) = wrap_dek(&dek, &rk.public_key, &payload_aad)?;
recipients.push(serde_json::json!({
"recipient_kid": rk.kid,
"enc_b64u": b64.encode(&enc),
"ct_b64u": b64.encode(&wrapped_key),
}));
}
recipients.sort_by(|a, b| {
let a_kid = a.get("recipient_kid").and_then(|v| v.as_u64()).unwrap_or(0);
let b_kid = b.get("recipient_kid").and_then(|v| v.as_u64()).unwrap_or(0);
a_kid.cmp(&b_kid)
});
let recipients_hash = compute_recipients_hash(&recipients)?;
let cipher_params = PayloadCipherParams {
nonce: &nonce_bytes,
payload_aad: &payload_aad,
ciphertext,
tag,
recipients_hash: &recipients_hash,
};
let payload_cipher_hash = compute_payload_cipher_hash(Some(&cipher_params));
let payload_encrypted = serde_json::json!({
"enc_version": ENC_VERSION,
"aead": AEAD_ALGORITHM,
"nonce_b64u": b64.encode(nonce_bytes),
"ciphertext_b64u": b64.encode(ciphertext),
"tag_b64u": b64.encode(tag),
"hpke": {
"mode": HPKE_MODE,
"kem": HPKE_KEM,
"kdf": HPKE_KDF,
"aead": HPKE_AEAD
},
"recipients": recipients,
});
Ok(EncryptionResult { payload_encrypted, salt, payload_plain_hash, payload_cipher_hash })
}
pub fn decrypt_payload(
payload_encrypted: &serde_json::Value,
payload_aad: &[u8; 32],
recipient_kid: u32,
recipient_private_key: &[u8; 32],
expected_plain_hash: &[u8; 32],
) -> Result<serde_json::Value, CryptoError> {
let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
let recipients = validate_encryption_envelope(payload_encrypted)?;
let recipient = recipients
.iter()
.find(|r| r.get("recipient_kid").and_then(|v| v.as_u64()) == Some(u64::from(recipient_kid)))
.ok_or(CryptoError::RecipientNotFound(recipient_kid))?;
let enc = b64
.decode(
recipient
.get("enc_b64u")
.and_then(|v| v.as_str())
.ok_or_else(|| CryptoError::DecryptionError("Missing enc_b64u".to_string()))?,
)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
let wrapped_key = b64
.decode(
recipient
.get("ct_b64u")
.and_then(|v| v.as_str())
.ok_or_else(|| CryptoError::DecryptionError("Missing ct_b64u".to_string()))?,
)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
if enc.len() != KEY_SIZE {
return Err(CryptoError::DecryptionError(format!(
"Invalid enc_b64u length: expected {KEY_SIZE}, got {}",
enc.len()
)));
}
let mut enc_arr = [0u8; 32];
enc_arr.copy_from_slice(&enc);
let dek = unwrap_dek(&enc_arr, &wrapped_key, recipient_private_key, payload_aad)?;
let nonce_bytes = b64
.decode(
payload_encrypted
.get("nonce_b64u")
.and_then(|v| v.as_str())
.ok_or_else(|| CryptoError::DecryptionError("Missing nonce_b64u".to_string()))?,
)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
if nonce_bytes.len() != NONCE_SIZE {
return Err(CryptoError::DecryptionError(format!(
"Invalid nonce length: expected {NONCE_SIZE}, got {}",
nonce_bytes.len()
)));
}
let ciphertext = b64
.decode(
payload_encrypted.get("ciphertext_b64u").and_then(|v| v.as_str()).ok_or_else(|| {
CryptoError::DecryptionError("Missing ciphertext_b64u".to_string())
})?,
)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
let tag = b64
.decode(
payload_encrypted
.get("tag_b64u")
.and_then(|v| v.as_str())
.ok_or_else(|| CryptoError::DecryptionError("Missing tag_b64u".to_string()))?,
)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
if tag.len() != TAG_SIZE {
return Err(CryptoError::DecryptionError(format!(
"Invalid tag length: expected {TAG_SIZE}, got {}",
tag.len()
)));
}
let key = Key::<Aes256Gcm>::from_slice(&dek);
let cipher_obj = Aes256Gcm::new(key);
let nonce = Nonce::from_slice(&nonce_bytes);
let mut ct_with_tag = Vec::with_capacity(ciphertext.len() + tag.len());
ct_with_tag.extend_from_slice(&ciphertext);
ct_with_tag.extend_from_slice(&tag);
let aead_payload = aes_gcm::aead::Payload { msg: &ct_with_tag, aad: payload_aad };
let plaintext = cipher_obj
.decrypt(nonce, aead_payload)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
if plaintext.len() < SALT_SIZE {
return Err(CryptoError::DecryptionError("Plaintext too short".to_string()));
}
let salt: [u8; 16] = plaintext[..SALT_SIZE].try_into().map_err(|_| CryptoError::InvalidSalt)?;
let json_bytes = &plaintext[SALT_SIZE..];
let payload: serde_json::Value = serde_json::from_slice(json_bytes)
.map_err(|e| CryptoError::DecryptionError(e.to_string()))?;
let computed_hash = compute_payload_plain_hash(&payload, Some(&salt))?;
if computed_hash.ct_eq(expected_plain_hash).unwrap_u8() == 0 {
return Err(CryptoError::PayloadHashMismatch);
}
Ok(payload)
}
/// Ensures no recipient key id appears more than once.
///
/// Stops at the first repeated id and reports it as an
/// `EncryptionError("duplicate recipient_kid: <kid>")`.
fn validate_unique_recipient_kids(
    recipient_kids: impl IntoIterator<Item = u32>,
) -> Result<(), CryptoError> {
    let mut already_seen = HashSet::new();
    recipient_kids.into_iter().try_for_each(|kid| {
        if already_seen.insert(kid) {
            Ok(())
        } else {
            Err(CryptoError::EncryptionError(format!("duplicate recipient_kid: {kid}")))
        }
    })
}
/// Checks an envelope's header against the single supported suite and returns its recipients.
///
/// The checks run in this order:
/// 1. `enc_version`
/// 2. top-level `aead`
/// 3. the `hpke` fields `mode`, `kem`, `kdf`, `aead`
/// 4. `recipients` is an array whose entries each carry a numeric, non-repeated
///    `recipient_kid`
///
/// # Errors
///
/// Returns a `DecryptionError` describing the first missing, unsupported or duplicated field.
fn validate_encryption_envelope(
    payload_encrypted: &serde_json::Value,
) -> Result<&Vec<serde_json::Value>, CryptoError> {
    let missing = |field: &str| CryptoError::DecryptionError(format!("Missing {field}"));

    match payload_encrypted.get("enc_version").and_then(serde_json::Value::as_u64) {
        Some(ENC_VERSION) => {}
        Some(enc_version) => {
            return Err(CryptoError::DecryptionError(format!(
                "Unsupported enc_version: expected {ENC_VERSION}, got {enc_version}"
            )));
        }
        None => return Err(missing("enc_version")),
    }

    match payload_encrypted.get("aead").and_then(serde_json::Value::as_str) {
        Some(AEAD_ALGORITHM) => {}
        Some(aead) => {
            return Err(CryptoError::DecryptionError(format!(
                "Unsupported aead: expected {AEAD_ALGORITHM}, got {aead}"
            )));
        }
        None => return Err(missing("aead")),
    }

    let hpke = payload_encrypted
        .get("hpke")
        .and_then(serde_json::Value::as_object)
        .ok_or_else(|| missing("hpke"))?;
    let hpke_suite =
        [("mode", HPKE_MODE), ("kem", HPKE_KEM), ("kdf", HPKE_KDF), ("aead", HPKE_AEAD)];
    for (field, expected) in hpke_suite {
        validate_hpke_field(hpke, field, expected)?;
    }

    let recipients = payload_encrypted
        .get("recipients")
        .and_then(serde_json::Value::as_array)
        .ok_or_else(|| missing("recipients"))?;
    let mut seen_kids = HashSet::with_capacity(recipients.len());
    for entry in recipients {
        let kid = entry
            .get("recipient_kid")
            .and_then(serde_json::Value::as_u64)
            .ok_or_else(|| missing("recipient_kid"))?;
        if !seen_kids.insert(kid) {
            return Err(CryptoError::DecryptionError(format!(
                "Duplicate recipient_kid entry: {kid}"
            )));
        }
    }
    Ok(recipients)
}
/// Requires `hpke[field]` to be the string `expected`.
///
/// A missing or non-string value yields `Missing hpke.<field>`; any other string yields
/// `Unsupported hpke.<field>: expected <expected>, got <actual>`.
fn validate_hpke_field(
    hpke: &serde_json::Map<String, serde_json::Value>,
    field: &str,
    expected: &str,
) -> Result<(), CryptoError> {
    match hpke.get(field).and_then(serde_json::Value::as_str) {
        None => Err(CryptoError::DecryptionError(format!("Missing hpke.{field}"))),
        Some(actual) if actual == expected => Ok(()),
        Some(actual) => Err(CryptoError::DecryptionError(format!(
            "Unsupported hpke.{field}: expected {expected}, got {actual}"
        ))),
    }
}
fn wrap_dek(
dek: &[u8; 32],
recipient_public_key: &[u8; 32],
info: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), CryptoError> {
let mut rng = rand::thread_rng();
let ephemeral_secret = EphemeralSecret::random_from_rng(&mut rng);
let ephemeral_public = PublicKey::from(&ephemeral_secret);
let enc = ephemeral_public.as_bytes().to_vec();
let recipient_pk = PublicKey::from(*recipient_public_key);
let shared_secret = ephemeral_secret.diffie_hellman(&recipient_pk);
let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
let mut wrapping_key = Zeroizing::new([0u8; 32]);
hk.expand(info, wrapping_key.as_mut()).map_err(|e| CryptoError::KeyWrapError(e.to_string()))?;
let mut wrap_nonce_bytes = [0u8; NONCE_SIZE];
rng.fill_bytes(&mut wrap_nonce_bytes);
let key = Key::<Aes256Gcm>::from_slice(&*wrapping_key);
let wrap_cipher = Aes256Gcm::new(key);
let wrap_nonce = Nonce::from_slice(&wrap_nonce_bytes);
let wrapped = wrap_cipher
.encrypt(wrap_nonce, dek.as_ref())
.map_err(|e| CryptoError::KeyWrapError(e.to_string()))?;
let mut wrapped_key = Vec::with_capacity(NONCE_SIZE + wrapped.len());
wrapped_key.extend_from_slice(&wrap_nonce_bytes);
wrapped_key.extend_from_slice(&wrapped);
Ok((enc, wrapped_key))
}
/// Recovers a DEK wrapped by `wrap_dek`.
///
/// `enc` is the sender's ephemeral X25519 public key. `wrapped_key` is
/// `nonce || ciphertext || tag`. `info` must equal the HKDF info used when wrapping.
///
/// The returned array is a plain copy; callers should hold it in `Zeroizing`. The intermediate
/// decryption buffer is wiped here.
///
/// # Errors
///
/// Returns `KeyWrapError` in these cases:
/// - `wrapped_key` is shorter than nonce + tag;
/// - `enc` is a low-order point (all-zero shared secret);
/// - HKDF expansion fails;
/// - AES-GCM authentication fails;
/// - the plaintext is not exactly 32 bytes.
fn unwrap_dek(
    enc: &[u8; 32],
    wrapped_key: &[u8],
    recipient_private_key: &[u8; 32],
    info: &[u8],
) -> Result<[u8; 32], CryptoError> {
    // Cheap structural check before any DH/HKDF work on untrusted input.
    if wrapped_key.len() < NONCE_SIZE + TAG_SIZE {
        return Err(CryptoError::KeyWrapError("Wrapped key too short".to_string()));
    }
    let (wrap_nonce_bytes, ciphertext_tag) = wrapped_key.split_at(NONCE_SIZE);

    let ephemeral_pk = PublicKey::from(*enc);
    let recipient_sk = StaticSecret::from(*recipient_private_key);
    let shared_secret = recipient_sk.diffie_hellman(&ephemeral_pk);
    // RFC 7748 §6.1: a low-order `enc` from the envelope forces an all-zero shared secret.
    if shared_secret.as_bytes().ct_eq(&[0u8; 32]).unwrap_u8() == 1 {
        return Err(CryptoError::KeyWrapError(
            "Ephemeral key yields a non-contributory X25519 shared secret".to_string(),
        ));
    }

    let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
    let mut wrapping_key = Zeroizing::new([0u8; KEY_SIZE]);
    hk.expand(info, wrapping_key.as_mut()).map_err(|e| CryptoError::KeyWrapError(e.to_string()))?;

    let wrap_cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&*wrapping_key));
    // The decrypted DEK lives in a heap buffer; wipe it when this function returns.
    let dek_bytes = Zeroizing::new(
        wrap_cipher
            .decrypt(Nonce::from_slice(wrap_nonce_bytes), ciphertext_tag)
            .map_err(|e| CryptoError::KeyWrapError(e.to_string()))?,
    );
    let dek: [u8; KEY_SIZE] = dek_bytes.as_slice().try_into().map_err(|_| {
        CryptoError::KeyWrapError(format!(
            "Invalid unwrapped DEK length: expected {KEY_SIZE}, got {}",
            dek_bytes.len()
        ))
    })?;
    Ok(dek)
}
/// Generates a fresh X25519 key pair from the thread-local RNG.
///
/// Returns `(private_key, public_key)` as raw 32-byte arrays. The private key is returned
/// unprotected, so callers are responsible for storing and wiping it.
pub fn generate_x25519_keypair() -> ([u8; 32], [u8; 32]) {
    let private = StaticSecret::random_from_rng(&mut rand::thread_rng());
    let public_bytes = *PublicKey::from(&private).as_bytes();
    (private.to_bytes(), public_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    const TEST_UUID: &str = "550e8400-e29b-41d4-a716-446655440000";

    /// Plain hash handed to `encrypt_payload`. The function replaces it with the salted hash
    /// it computes, so any value works here.
    const PLACEHOLDER_HASH: [u8; 32] = [0u8; 32];

    fn test_aad_params(plain_hash: &[u8; 32]) -> PayloadAadParams<'_> {
        PayloadAadParams {
            ves_version: 1,
            tenant_id: TEST_UUID,
            store_id: TEST_UUID,
            event_id: TEST_UUID,
            source_agent_id: TEST_UUID,
            agent_key_id: 1,
            entity_type: "order",
            entity_id: "ord_001",
            event_type: "order.created",
            created_at: "2026-02-21T00:00:00Z",
            payload_plain_hash: plain_hash,
        }
    }

    /// Encrypts `payload` for `recipients`, panicking on failure.
    fn encrypt_for(payload: &serde_json::Value, recipients: &[RecipientKey]) -> EncryptionResult {
        encrypt_payload(payload, &test_aad_params(&PLACEHOLDER_HASH), recipients)
            .expect("encryption should succeed")
    }

    /// Rebuilds the AAD a receiver derives from `enc_result`'s plain hash and decrypts
    /// `envelope` (possibly tampered) as recipient `kid`.
    fn decrypt_as(
        envelope: &serde_json::Value,
        enc_result: &EncryptionResult,
        kid: u32,
        private_key: &[u8; 32],
    ) -> Result<serde_json::Value, CryptoError> {
        let aad_params = PayloadAadParams {
            payload_plain_hash: &enc_result.payload_plain_hash,
            ..test_aad_params(&PLACEHOLDER_HASH)
        };
        let payload_aad = compute_payload_aad(&aad_params).expect("AAD computation should succeed");
        decrypt_payload(envelope, &payload_aad, kid, private_key, &enc_result.payload_plain_hash)
    }

    /// A payload encrypted for one freshly generated recipient with kid 1.
    struct Fixture {
        payload: serde_json::Value,
        private_key: [u8; 32],
        enc_result: EncryptionResult,
    }

    impl Fixture {
        fn new(payload: serde_json::Value) -> Self {
            let (private_key, public_key) = generate_x25519_keypair();
            let enc_result = encrypt_for(&payload, &[RecipientKey { kid: 1, public_key }]);
            Self { payload, private_key, enc_result }
        }

        /// A mutable copy of the envelope for tampering.
        fn envelope(&self) -> serde_json::Value {
            self.enc_result.payload_encrypted.clone()
        }

        fn decrypt(&self, envelope: &serde_json::Value) -> Result<serde_json::Value, CryptoError> {
            decrypt_as(envelope, &self.enc_result, 1, &self.private_key)
        }
    }

    fn b64u(bytes: &[u8]) -> serde_json::Value {
        serde_json::Value::String(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes))
    }

    fn unb64u(value: &serde_json::Value) -> Vec<u8> {
        base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(value.as_str().expect("field should be a string"))
            .expect("field should be valid base64url")
    }

    fn enc_array(enc: &[u8]) -> [u8; 32] {
        enc.try_into().expect("enc should be 32 bytes")
    }

    fn expect_decryption_error(result: Result<serde_json::Value, CryptoError>, needle: &str) {
        match result {
            Err(CryptoError::DecryptionError(message)) => {
                assert!(message.contains(needle), "unexpected message: {message}")
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn dek_wrap_unwrap_roundtrip() {
        let (private_key, public_key) = generate_x25519_keypair();
        let dek = [42u8; 32];
        let (enc, wrapped) = wrap_dek(&dek, &public_key, b"test_info").unwrap();
        let recovered = unwrap_dek(&enc_array(&enc), &wrapped, &private_key, b"test_info").unwrap();
        assert_eq!(recovered, dek);
    }

    #[test]
    fn dek_unwrap_wrong_key_fails() {
        let (_, public_key) = generate_x25519_keypair();
        let (wrong_private, _) = generate_x25519_keypair();
        let (enc, wrapped) = wrap_dek(&[42u8; 32], &public_key, b"test_info").unwrap();
        assert!(unwrap_dek(&enc_array(&enc), &wrapped, &wrong_private, b"test_info").is_err());
    }

    #[test]
    fn dek_unwrap_wrong_info_fails() {
        let (private_key, public_key) = generate_x25519_keypair();
        let (enc, wrapped) = wrap_dek(&[42u8; 32], &public_key, b"test_info").unwrap();
        assert!(unwrap_dek(&enc_array(&enc), &wrapped, &private_key, b"other_info").is_err());
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let fixture = Fixture::new(json!({"order_id": "ord_001", "amount": 99.99}));
        let decrypted = fixture.decrypt(&fixture.enc_result.payload_encrypted).unwrap();
        assert_eq!(decrypted, fixture.payload);
    }

    #[test]
    fn encrypt_no_recipients_fails() {
        let result =
            encrypt_payload(&json!({"key": "value"}), &test_aad_params(&PLACEHOLDER_HASH), &[]);
        assert!(matches!(result, Err(CryptoError::NoRecipients)));
    }

    #[test]
    fn encrypt_duplicate_recipient_kids_fail() {
        let (_, public_key) = generate_x25519_keypair();
        let recipients =
            [RecipientKey { kid: 7, public_key }, RecipientKey { kid: 7, public_key }];
        let err = encrypt_payload(
            &json!({"key": "value"}),
            &test_aad_params(&PLACEHOLDER_HASH),
            &recipients,
        )
        .expect_err("duplicate recipient ids should be rejected");
        match err {
            CryptoError::EncryptionError(message) => {
                assert!(message.contains("duplicate recipient_kid"))
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn encrypt_multiple_recipients() {
        let payload = json!({"key": "value"});
        let (priv1, pub1) = generate_x25519_keypair();
        let (priv2, pub2) = generate_x25519_keypair();
        let enc_result = encrypt_for(
            &payload,
            &[RecipientKey { kid: 1, public_key: pub1 }, RecipientKey { kid: 2, public_key: pub2 }],
        );
        for (kid, private_key) in [(1, priv1), (2, priv2)] {
            let decrypted =
                decrypt_as(&enc_result.payload_encrypted, &enc_result, kid, &private_key)
                    .unwrap_or_else(|e| panic!("recipient {kid} failed to decrypt: {e:?}"));
            assert_eq!(decrypted, payload);
        }
    }

    #[test]
    fn encrypt_orders_recipients_by_kid() {
        let recipients: Vec<RecipientKey> = [3u32, 1, 2]
            .into_iter()
            .map(|kid| RecipientKey { kid, public_key: generate_x25519_keypair().1 })
            .collect();
        let enc_result = encrypt_for(&json!({"key": "value"}), &recipients);
        let kids: Vec<u64> = enc_result.payload_encrypted["recipients"]
            .as_array()
            .expect("recipients should be an array")
            .iter()
            .map(|entry| entry["recipient_kid"].as_u64().expect("kid should be numeric"))
            .collect();
        assert_eq!(kids, [1, 2, 3]);
    }

    #[test]
    fn decrypt_wrong_recipient_fails() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let (wrong_private, _) = generate_x25519_keypair();
        let result = decrypt_as(
            &fixture.enc_result.payload_encrypted,
            &fixture.enc_result,
            99,
            &wrong_private,
        );
        assert!(matches!(result, Err(CryptoError::RecipientNotFound(99))));
    }

    #[test]
    fn decrypt_rejects_invalid_ephemeral_key_length() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let mut tampered = fixture.envelope();
        tampered["recipients"][0]["enc_b64u"] = b64u(&[7u8; 31]);
        expect_decryption_error(fixture.decrypt(&tampered), "Invalid enc_b64u length");
    }

    #[test]
    fn decrypt_rejects_invalid_nonce_length() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let mut tampered = fixture.envelope();
        tampered["nonce_b64u"] = b64u(&[1u8; 11]);
        expect_decryption_error(fixture.decrypt(&tampered), "Invalid nonce length");
    }

    #[test]
    fn decrypt_rejects_invalid_tag_length() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let mut tampered = fixture.envelope();
        tampered["tag_b64u"] = b64u(&[9u8; 8]);
        expect_decryption_error(fixture.decrypt(&tampered), "Invalid tag length");
    }

    #[test]
    fn decrypt_rejects_tampered_metadata() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let mut tampered = fixture.envelope();
        tampered["aead"] = serde_json::Value::String("CHACHA20-POLY1305".to_string());
        expect_decryption_error(fixture.decrypt(&tampered), "Unsupported aead");
    }

    #[test]
    fn decrypt_rejects_duplicate_recipient_entries() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let mut tampered = fixture.envelope();
        let duplicate = tampered["recipients"][0].clone();
        tampered["recipients"].as_array_mut().unwrap().push(duplicate);
        expect_decryption_error(fixture.decrypt(&tampered), "Duplicate recipient_kid entry");
    }

    #[test]
    fn decrypt_rejects_tampered_ciphertext() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let mut tampered = fixture.envelope();
        let mut ciphertext = unb64u(&tampered["ciphertext_b64u"]);
        ciphertext[0] ^= 0x01;
        tampered["ciphertext_b64u"] = b64u(&ciphertext);
        assert!(matches!(fixture.decrypt(&tampered), Err(CryptoError::DecryptionError(_))));
    }

    #[test]
    fn decrypt_rejects_mismatched_expected_hash() {
        let fixture = Fixture::new(json!({"key": "value"}));
        let aad_params = PayloadAadParams {
            payload_plain_hash: &fixture.enc_result.payload_plain_hash,
            ..test_aad_params(&PLACEHOLDER_HASH)
        };
        let payload_aad = compute_payload_aad(&aad_params).unwrap();
        let result = decrypt_payload(
            &fixture.enc_result.payload_encrypted,
            &payload_aad,
            1,
            &fixture.private_key,
            &[0xAA; 32],
        );
        assert!(matches!(result, Err(CryptoError::PayloadHashMismatch)));
    }
}