use hmac::{Hmac, KeyInit, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use crate::kdf::hkdf_sha256;
use super::aead::xchacha20_poly1305_decrypt;
use super::errors::{EciesSealedPoeError, EciesSealedPoeErrorCode};
use super::kem::{
mlkem768x25519_decapsulate, x25519_ecdh, x25519_public_key, KemError, MLKEM768X25519_ENC_LENGTH,
};
use super::slots::{
join_kem_ct, slots_to_mac_cbor, Mlkem768X25519Slot, SealedEnvelope, SealedSlots, X25519Slot,
AEAD_XCHACHA20_POLY1305, KEM_MLKEM768X25519, KEM_X25519,
};
use super::wrap::{
CARDANO_POE_HKDF_INFO_KEK, CARDANO_POE_HKDF_INFO_KEK_MLKEM768X25519,
CARDANO_POE_HKDF_INFO_SLOTS_MAC,
};
use super::aead::chacha20_poly1305_decrypt;
/// Fixed all-zero 12-byte nonce used when opening a slot's ChaCha20-Poly1305 `wrap`
/// (each slot's KEK is derived separately, so the nonce is constant here).
const ZERO_NONCE_12: [u8; 12] = [0; 12];
/// Required length of every recipient secret passed in (X25519 private keys and the
/// ML-KEM-768+X25519 seeds are both checked against this).
const X25519_SECRET_KEY_LENGTH: usize = 32;
/// Required length of the envelope's XChaCha20-Poly1305 content nonce.
const NONCE_LENGTH: usize = 24;
/// Required length of a slot's `wrap`: a 32-byte CEK plus a 16-byte Poly1305 tag.
const WRAP_LENGTH: usize = 32 + 16;
/// Required length of the envelope's slots MAC (HMAC-SHA256 output).
const SLOTS_MAC_LENGTH: usize = 32;
/// Reason an unwrap attempt did not produce plaintext.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnwrapFailureReason {
    /// No slot could be opened with any of the supplied recipient secrets.
    WrongRecipientKey,
    /// A CEK was recovered from a slot, but the slots MAC did not verify under it.
    TamperedHeader,
    /// The header verified, but the content AEAD rejected the ciphertext.
    TamperedCiphertext,
}

impl UnwrapFailureReason {
    /// Returns the stable upper-snake-case code for this reason.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::WrongRecipientKey => "WRONG_RECIPIENT_KEY",
            Self::TamperedHeader => "TAMPERED_HEADER",
            Self::TamperedCiphertext => "TAMPERED_CIPHERTEXT",
        }
    }
}

/// Outcome of [`ecies_sealed_poe_unwrap`]-style decryption: plaintext, or why there is none.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UnwrapResult {
    /// Decryption succeeded.
    Matched {
        /// The decrypted content.
        plaintext: Vec<u8>,
    },
    /// Decryption did not succeed.
    NotMatched {
        /// Which stage rejected the envelope.
        reason: UnwrapFailureReason,
    },
}

impl UnwrapResult {
    /// `true` exactly when this is [`UnwrapResult::Matched`].
    #[must_use]
    pub fn matched(&self) -> bool {
        matches!(self, Self::Matched { .. })
    }
}
/// Recipient secrets grouped by KEM, so one bundle can be used against envelopes of either
/// KEM; the half matching `envelope.kem` is picked at unwrap time.
///
/// `Debug` is implemented by hand and only reports how many secrets each list holds: a
/// derived `Debug` would print raw private-key bytes into logs and panic messages.
#[derive(Clone, Default)]
pub struct RecipientKeyBundle {
    /// X25519 private keys, used for `KEM_X25519` envelopes.
    pub x25519_private_keys: Vec<Vec<u8>>,
    /// ML-KEM-768+X25519 secret seeds, used for `KEM_MLKEM768X25519` envelopes.
    pub mlkem768x25519_secret_seeds: Vec<Vec<u8>>,
}

impl std::fmt::Debug for RecipientKeyBundle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecipientKeyBundle")
            .field(
                "x25519_private_keys",
                &format_args!("<{} redacted>", self.x25519_private_keys.len()),
            )
            .field(
                "mlkem768x25519_secret_seeds",
                &format_args!("<{} redacted>", self.mlkem768x25519_secret_seeds.len()),
            )
            .finish()
    }
}

/// Recipient secret(s) to try when unwrapping an envelope.
pub enum UnwrapKeys<'a> {
    /// A single recipient secret.
    Single(&'a [u8]),
    /// Several recipient secrets, tried in order until one verifies.
    Multi(&'a [Vec<u8>]),
    /// A bundle; only the half matching the envelope's KEM is tried.
    Bundle(&'a RecipientKeyBundle),
}
/// Slot-scan counters recorded into [`UnwrapProbe::inner`].
#[derive(Clone, Debug, Default)]
pub struct SlotsAttempted {
    /// Number of slots visited for the most recently tried recipient secret.
    pub count: usize,
    /// Slots visited for each recipient secret, in the order tried (multi-key paths only).
    pub per_priv_counts: Vec<usize>,
}

/// Recipient-secret counter recorded into [`UnwrapProbe::outer`].
#[derive(Clone, Debug, Default)]
pub struct PrivsAttempted {
    /// Number of recipient secrets tried so far (multi-key paths only).
    pub count: usize,
}

/// Optional instrumentation filled in by the unwrap / trial-decrypt entry points.
#[derive(Clone, Debug, Default)]
pub struct UnwrapProbe {
    /// Per-secret slot counts.
    pub inner: SlotsAttempted,
    /// How many recipient secrets were tried.
    pub outer: PrivsAttempted,
}
/// Outcome of [`ecies_sealed_poe_trial_decrypt`].
///
/// `Debug` is implemented by hand so the recovered CEK is never printed; only its length is.
#[derive(Clone, PartialEq, Eq)]
pub enum TrialDecryptResult {
    /// A slot opened and the slots MAC verified under the recovered CEK.
    Match {
        /// Index of the slot the CEK was recovered from.
        slot_idx: usize,
        /// The content-encryption key. Secret: the caller is responsible for wiping it.
        cek: Vec<u8>,
    },
    /// No slot could be opened with any supplied recipient secret.
    NoAeadPass,
    /// At least one slot opened, but the slots MAC never verified.
    AeadPassNoMacMatch,
}

impl std::fmt::Debug for TrialDecryptResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Match { slot_idx, cek } => f
                .debug_struct("Match")
                .field("slot_idx", slot_idx)
                .field("cek", &format_args!("<{} bytes redacted>", cek.len()))
                .finish(),
            Self::NoAeadPass => f.write_str("NoAeadPass"),
            Self::AeadPassNoMacMatch => f.write_str("AeadPassNoMacMatch"),
        }
    }
}
/// Returns the half of `bundle` that matches `envelope.kem`.
///
/// `KEM_X25519` selects the X25519 private keys. Any other identifier, including an
/// unsupported one, selects the ML-KEM-768+X25519 seeds. KEM validity is checked separately.
fn select_bundle_secrets<'k>(
    envelope: &SealedEnvelope,
    bundle: &'k RecipientKeyBundle,
) -> &'k [Vec<u8>] {
    if envelope.kem != KEM_X25519 {
        return &bundle.mlkem768x25519_secret_seeds;
    }
    &bundle.x25519_private_keys
}
/// Validates the envelope header, the slot layout and the caller's recipient secrets before
/// any cryptographic work is done.
///
/// Checks run in this order and the first violation is reported:
/// 1. `scheme == 1`.
/// 2. The AEAD identifier.
/// 3. The KEM identifier.
/// 4. At least one slot.
/// 5. The nonce length, then the slots-MAC length.
/// 6. The slot variant matches `envelope.kem`. `kem` is not covered by the slots MAC, so
///    a mismatch would otherwise feed one KEM's keys to the other.
/// 7. Per-slot field lengths.
/// 8. The length of every supplied recipient secret. `multi_priv_keys` takes precedence
///    over `single_priv_key`.
///
/// # Errors
/// Returns an [`EciesSealedPoeError`] whose code identifies the violated rule.
fn assert_envelope_structure(
    envelope: &SealedEnvelope,
    multi_priv_keys: Option<&[Vec<u8>]>,
    single_priv_key: Option<&[u8]>,
) -> Result<(), EciesSealedPoeError> {
    // X25519 public keys are 32 bytes (RFC 7748). The numeric value equals the secret-key
    // length, but the epk is a public key, so name it as one.
    const X25519_PUBLIC_KEY_LENGTH: usize = 32;

    if envelope.scheme != 1 {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::UnsupportedEncVersion,
            format!(
                "envelope.scheme={} unsupported (expected 1)",
                envelope.scheme
            ),
        ));
    }
    if envelope.aead != AEAD_XCHACHA20_POLY1305 {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::UnsupportedAeadAlg,
            format!(
                "envelope.aead={} unsupported (expected '{AEAD_XCHACHA20_POLY1305}')",
                envelope.aead
            ),
        ));
    }
    if envelope.kem != KEM_X25519 && envelope.kem != KEM_MLKEM768X25519 {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::UnsupportedKemAlg,
            format!(
                "envelope.kem={} unsupported (expected '{KEM_X25519}' or '{KEM_MLKEM768X25519}')",
                envelope.kem
            ),
        ));
    }
    let n = envelope.slots.len();
    if n == 0 {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::EncSlotsEmpty,
            format!("envelope.slots.len()={n} must be >= 1"),
        ));
    }
    if envelope.nonce.len() != NONCE_LENGTH {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::NonceLengthMismatch,
            format!(
                "envelope.nonce MUST be exactly {NONCE_LENGTH} bytes, got {}",
                envelope.nonce.len()
            ),
        ));
    }
    if envelope.slots_mac.len() != SLOTS_MAC_LENGTH {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::EncSlotsMacInvalidLength,
            format!(
                "envelope.slots_mac MUST be exactly {SLOTS_MAC_LENGTH} bytes, got {}",
                envelope.slots_mac.len()
            ),
        ));
    }
    // The declared KEM picks which recipient secrets (e.g. which bundle half) are used, so it
    // must agree with the slot encoding actually present.
    let slots_kem = match &envelope.slots {
        SealedSlots::X25519(_) => KEM_X25519,
        SealedSlots::Mlkem768X25519(_) => KEM_MLKEM768X25519,
    };
    if envelope.kem != slots_kem {
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::UnsupportedKemAlg,
            format!(
                "envelope.kem={} does not match the slot layout ('{slots_kem}')",
                envelope.kem
            ),
        ));
    }
    match &envelope.slots {
        SealedSlots::X25519(slots) => {
            for (i, slot) in slots.iter().enumerate() {
                if slot.epk.len() != X25519_PUBLIC_KEY_LENGTH {
                    return Err(EciesSealedPoeError::new(
                        EciesSealedPoeErrorCode::KemEpkLengthMismatch,
                        format!(
                            "envelope.slots[{i}].epk MUST be exactly {X25519_PUBLIC_KEY_LENGTH} bytes, got {}",
                            slot.epk.len()
                        ),
                    ));
                }
                if slot.wrap.len() != WRAP_LENGTH {
                    return Err(wrap_length_error(i, slot.wrap.len()));
                }
            }
        }
        SealedSlots::Mlkem768X25519(slots) => {
            for (i, slot) in slots.iter().enumerate() {
                let enc = join_kem_ct(&slot.kem_ct);
                if enc.len() != MLKEM768X25519_ENC_LENGTH {
                    return Err(EciesSealedPoeError::new(
                        EciesSealedPoeErrorCode::KemCtLengthMismatch,
                        format!(
                            "envelope.slots[{i}].kem_ct MUST reassemble to exactly {MLKEM768X25519_ENC_LENGTH} bytes, got {}",
                            enc.len()
                        ),
                    ));
                }
                if slot.wrap.len() != WRAP_LENGTH {
                    return Err(wrap_length_error(i, slot.wrap.len()));
                }
            }
        }
    }
    // Both X25519 private keys and ML-KEM-768+X25519 seeds are 32-byte secrets.
    if let Some(keys) = multi_priv_keys {
        for (i, key) in keys.iter().enumerate() {
            if key.len() != X25519_SECRET_KEY_LENGTH {
                return Err(EciesSealedPoeError::new(
                    EciesSealedPoeErrorCode::InvalidRecipientKey,
                    format!(
                        "recipient_secret_keys[{i}] MUST be exactly {X25519_SECRET_KEY_LENGTH} bytes, got {}",
                        key.len()
                    ),
                ));
            }
        }
    } else if let Some(key) = single_priv_key {
        if key.len() != X25519_SECRET_KEY_LENGTH {
            return Err(EciesSealedPoeError::new(
                EciesSealedPoeErrorCode::InvalidRecipientKey,
                format!(
                    "recipient_secret_key MUST be exactly {X25519_SECRET_KEY_LENGTH} bytes, got {}",
                    key.len()
                ),
            ));
        }
    }
    Ok(())
}
/// Builds the `WrapLengthMismatch` error for slot `slot_idx`, whose `wrap` field was `got`
/// bytes instead of `WRAP_LENGTH`.
fn wrap_length_error(slot_idx: usize, got: usize) -> EciesSealedPoeError {
    let message =
        format!("envelope.slots[{slot_idx}].wrap MUST be exactly {WRAP_LENGTH} bytes, got {got}");
    EciesSealedPoeError::new(EciesSealedPoeErrorCode::WrapLengthMismatch, message)
}
/// Tries to open one X25519 slot with `recipient_secret_key`. On success it returns the CEK
/// carried in the slot's `wrap`.
///
/// ECDH and HKDF run even when `live_slot` is false, so every visited slot performs the same
/// key-agreement work; only the final AEAD open is skipped.
///
/// Returns `None` when ECDH rejects the ephemeral key, when the slot is not live, or when
/// `wrap` fails to authenticate under the derived KEK.
///
/// Both intermediate secrets are zeroized before returning: the raw ECDH output and the KEK.
fn try_x25519_slot(
    slot: &X25519Slot,
    recipient_secret_key: &[u8],
    pub_r_local: &[u8],
    live_slot: bool,
) -> Option<Vec<u8>> {
    let mut shared = match x25519_ecdh(recipient_secret_key, &slot.epk) {
        Ok(s) => s,
        // A low-order ephemeral key yields no usable shared secret: treat it as "not ours".
        Err(KemError::X25519LowOrderPoint) => return None,
        Err(_) => return None,
    };
    // HKDF salt = epk || recipient public key.
    let mut salt = Vec::with_capacity(slot.epk.len() + pub_r_local.len());
    salt.extend_from_slice(&slot.epk);
    salt.extend_from_slice(pub_r_local);
    let mut kek = hkdf_sha256(&shared, &salt, CARDANO_POE_HKDF_INFO_KEK, 32)
        .expect("32-byte HKDF output is within the RFC 5869 maximum");
    // The raw ECDH output is key material too; scrub it once the KEK has been derived.
    shared.zeroize();
    if !live_slot {
        kek.zeroize();
        return None;
    }
    let result =
        chacha20_poly1305_decrypt(&kek, &ZERO_NONCE_12, CARDANO_POE_HKDF_INFO_KEK, &slot.wrap).ok();
    kek.zeroize();
    result
}
/// Tries to open one ML-KEM-768+X25519 slot with `recipient_secret_seed`. On success it
/// returns the CEK carried in the slot's `wrap`.
///
/// Decapsulation and HKDF run even when `live_slot` is false, so every visited slot performs
/// the same KEM work; only the final AEAD open is skipped.
///
/// Returns `None` when decapsulation fails, when the slot is not live, or when `wrap` fails
/// to authenticate under the derived KEK. The shared secret and the KEK are zeroized
/// before returning.
fn try_mlkem768x25519_slot(
    slot: &Mlkem768X25519Slot,
    recipient_secret_seed: &[u8],
    live_slot: bool,
) -> Option<Vec<u8>> {
    let enc = join_kem_ct(&slot.kem_ct);
    // `kem_ct` comes from the untrusted envelope. Its length was validated, but any
    // decapsulation error (e.g. a rejected X25519 component) must mean "not our slot",
    // exactly as in the X25519 path, rather than a panic.
    let Ok(mut ss) = mlkem768x25519_decapsulate(recipient_secret_seed, &enc) else {
        return None;
    };
    let mut kek = hkdf_sha256(&ss, &[], CARDANO_POE_HKDF_INFO_KEK_MLKEM768X25519, 32)
        .expect("32-byte HKDF output is within the RFC 5869 maximum");
    ss.zeroize();
    if !live_slot {
        kek.zeroize();
        return None;
    }
    let result = chacha20_poly1305_decrypt(
        &kek,
        &ZERO_NONCE_12,
        CARDANO_POE_HKDF_INFO_KEK_MLKEM768X25519,
        &slot.wrap,
    )
    .ok();
    kek.zeroize();
    result
}
/// Scans the envelope's slots with one recipient secret and returns the first recovered CEK
/// together with the index of the slot it came from.
///
/// * `constant_time_n == false`: the scan stops at the first slot that opens.
/// * `constant_time_n == true`: every slot is visited. Slots after the match still do their
///   KEM/HKDF work, but their AEAD open is skipped.
///
/// When `probe` is given, its `count` is set to the number of slots visited.
fn try_recipient_unwrap_with_idx(
    envelope: &SealedEnvelope,
    recipient_secret_key: &[u8],
    constant_time_n: bool,
    probe: Option<&mut SlotsAttempted>,
) -> Option<(Vec<u8>, usize)> {
    /// Shared scan loop. `attempt(slot, live)` tries one slot; `live` is false once a CEK
    /// has been found. Returns the first hit and the number of slots visited.
    fn scan<S>(
        slots: impl Iterator<Item = S>,
        constant_time_n: bool,
        mut attempt: impl FnMut(S, bool) -> Option<Vec<u8>>,
    ) -> (Option<(Vec<u8>, usize)>, usize) {
        let mut hit: Option<(Vec<u8>, usize)> = None;
        let mut visited = 0usize;
        for (idx, slot) in slots.enumerate() {
            visited = idx + 1;
            let live = hit.is_none();
            let outcome = attempt(slot, live);
            if live {
                hit = outcome.map(|cek| (cek, idx));
            }
            if hit.is_some() && !constant_time_n {
                break;
            }
        }
        (hit, visited)
    }

    let (hit, visited) = match &envelope.slots {
        SealedSlots::X25519(slots) => {
            let pub_r_local =
                x25519_public_key(recipient_secret_key).expect("recipient key length checked");
            scan(slots.iter(), constant_time_n, |slot, live| {
                try_x25519_slot(slot, recipient_secret_key, &pub_r_local, live)
            })
        }
        SealedSlots::Mlkem768X25519(slots) => scan(slots.iter(), constant_time_n, |slot, live| {
            try_mlkem768x25519_slot(slot, recipient_secret_key, live)
        }),
    };
    if let Some(p) = probe {
        p.count = visited;
    }
    hit
}
/// Recomputes the slots MAC under `cek` and compares it with `expected` in constant time.
///
/// The HMAC-SHA256 key is 32 bytes of `HKDF-SHA256(cek, salt = [], info =
/// CARDANO_POE_HKDF_INFO_SLOTS_MAC)`. The MAC input is the slots encoding `slots_cbor`
/// supplied by the caller.
fn slots_mac_matches(cek: &[u8], slots_cbor: &[u8], expected: &[u8]) -> bool {
    let mut mac_key = hkdf_sha256(cek, &[], CARDANO_POE_HKDF_INFO_SLOTS_MAC, 32)
        .expect("32-byte HKDF output is within the RFC 5869 maximum");
    let mut hmac =
        <Hmac<Sha256>>::new_from_slice(&mac_key).expect("HMAC accepts a key of any length");
    // The HMAC state has absorbed the key; our copy is no longer needed.
    mac_key.zeroize();
    hmac.update(slots_cbor);
    let tag = hmac.finalize().into_bytes();
    bool::from(tag.ct_eq(expected))
}
/// Unwraps a sealed envelope and decrypts `ciphertext` with the recovered CEK.
///
/// Flow:
/// 1. The envelope and recipient secrets are validated.
/// 2. Each secret scans the slots for a CEK (see `try_recipient_unwrap_with_idx`).
/// 3. A candidate CEK must verify the slots MAC.
/// 4. The content is opened with XChaCha20-Poly1305, using `envelope.nonce` and
///    AD = `nonce || slots_mac`.
///
/// With several secrets (`Multi` / `Bundle`), they are tried in order and the first one
/// whose CEK verifies the MAC wins. `constant_time_n` only affects the per-secret slot scan.
///
/// Every CEK this function recovers is zeroized before returning: the one that decrypted
/// the content and any that failed the MAC check.
///
/// # Returns
/// * `Matched { plaintext }` on success.
/// * `NotMatched { WrongRecipientKey }` when no slot opened, or when a bundle holds no
///   secrets for the envelope's KEM.
/// * `NotMatched { TamperedHeader }` when a CEK was recovered but the slots MAC never verified.
/// * `NotMatched { TamperedCiphertext }` when the content AEAD rejects `ciphertext`.
///
/// # Errors
/// * An empty `UnwrapKeys::Multi` list yields `InvalidRecipientKey`.
/// * Structural problems with the envelope or the secret lengths are reported by
///   `assert_envelope_structure`.
pub fn ecies_sealed_poe_unwrap(
    envelope: &SealedEnvelope,
    ciphertext: &[u8],
    keys: UnwrapKeys<'_>,
    constant_time_n: bool,
    mut probe: Option<&mut UnwrapProbe>,
) -> Result<UnwrapResult, EciesSealedPoeError> {
    let mut single: Option<&[u8]> = None;
    let mut multi: Option<&[Vec<u8>]> = None;
    let mut is_bundle = false;
    match keys {
        UnwrapKeys::Single(k) => single = Some(k),
        UnwrapKeys::Multi(list) => multi = Some(list),
        UnwrapKeys::Bundle(bundle) => {
            multi = Some(select_bundle_secrets(envelope, bundle));
            is_bundle = true;
        }
    }
    // An empty bundle half just means "no key for this KEM"; an explicit empty list is a
    // caller error.
    if let Some(list) = multi {
        if list.is_empty() {
            if is_bundle {
                return Ok(UnwrapResult::NotMatched {
                    reason: UnwrapFailureReason::WrongRecipientKey,
                });
            }
            return Err(EciesSealedPoeError::new(
                EciesSealedPoeErrorCode::InvalidRecipientKey,
                "recipient_secret_keys MUST be a non-empty list, got length 0",
            ));
        }
    }
    assert_envelope_structure(envelope, multi, single)?;
    let slots_cbor = slots_to_mac_cbor(&envelope.slots);
    let mut matched_cek: Option<Vec<u8>> = None;
    if let Some(recipient_secret_key) = single {
        let mut slots_attempted = SlotsAttempted::default();
        let candidate = try_recipient_unwrap_with_idx(
            envelope,
            recipient_secret_key,
            constant_time_n,
            Some(&mut slots_attempted),
        );
        if let Some(p) = probe.as_deref_mut() {
            p.inner.count = slots_attempted.count;
        }
        match candidate {
            None => {
                return Ok(UnwrapResult::NotMatched {
                    reason: UnwrapFailureReason::WrongRecipientKey,
                });
            }
            Some((mut cek, _)) => {
                if !slots_mac_matches(&cek, &slots_cbor, &envelope.slots_mac) {
                    // A rejected CEK is still key material; wipe it before dropping.
                    cek.zeroize();
                    return Ok(UnwrapResult::NotMatched {
                        reason: UnwrapFailureReason::TamperedHeader,
                    });
                }
                matched_cek = Some(cek);
            }
        }
    } else {
        let keys = multi.expect("exactly one of single/multi is set");
        let mut any_candidate_recovered = false;
        for (k, key) in keys.iter().enumerate() {
            if let Some(p) = probe.as_deref_mut() {
                p.outer.count = k + 1;
            }
            let mut slots_attempted = SlotsAttempted::default();
            let candidate = try_recipient_unwrap_with_idx(
                envelope,
                key,
                constant_time_n,
                Some(&mut slots_attempted),
            );
            if let Some(p) = probe.as_deref_mut() {
                p.inner.count = slots_attempted.count;
                p.inner.per_priv_counts.push(slots_attempted.count);
            }
            let Some((mut cek, _)) = candidate else {
                continue;
            };
            if slots_mac_matches(&cek, &slots_cbor, &envelope.slots_mac) {
                matched_cek = Some(cek);
                break;
            }
            // A rejected CEK is still key material; wipe it before dropping.
            cek.zeroize();
            any_candidate_recovered = true;
        }
        if matched_cek.is_none() {
            return Ok(UnwrapResult::NotMatched {
                reason: if any_candidate_recovered {
                    UnwrapFailureReason::TamperedHeader
                } else {
                    UnwrapFailureReason::WrongRecipientKey
                },
            });
        }
    }
    let mut matched_cek = matched_cek.expect("matched_cek set on every non-early-return path");
    // The content AEAD authenticates nonce || slots_mac as associated data.
    let mut ad_content = Vec::with_capacity(envelope.nonce.len() + envelope.slots_mac.len());
    ad_content.extend_from_slice(&envelope.nonce);
    ad_content.extend_from_slice(&envelope.slots_mac);
    let result =
        match xchacha20_poly1305_decrypt(&matched_cek, &envelope.nonce, &ad_content, ciphertext) {
            Ok(plaintext) => UnwrapResult::Matched { plaintext },
            Err(_) => UnwrapResult::NotMatched {
                reason: UnwrapFailureReason::TamperedCiphertext,
            },
        };
    matched_cek.zeroize();
    Ok(result)
}
/// Recipient secrets to try in [`ecies_sealed_poe_trial_decrypt`].
pub enum TrialDecryptKeys<'k> {
    /// An explicit list of recipient secrets, tried in order.
    Multi(&'k [Vec<u8>]),
    /// A bundle; only the half matching the envelope's KEM is tried.
    Bundle(&'k RecipientKeyBundle),
}
/// Trial-decrypts only the envelope header: it finds which slot (if any) a recipient secret
/// opens and checks the slots MAC. The content ciphertext is not touched.
///
/// Recipient secrets are tried in order, and the first CEK that verifies the slots MAC is
/// returned. Candidate CEKs that fail the MAC check are zeroized before being dropped. The
/// CEK returned in `Match` is handed to the caller, who is responsible for wiping it.
///
/// # Returns
/// * `Match { slot_idx, cek }` on success.
/// * `NoAeadPass` when no slot opened for any secret, or when a bundle holds no secrets
///   for the envelope's KEM.
/// * `AeadPassNoMacMatch` when some slot opened but the slots MAC never verified.
///
/// # Errors
/// * An empty `TrialDecryptKeys::Multi` list yields `InvalidRecipientKey`.
/// * Structural problems with the envelope or the secret lengths are reported by
///   `assert_envelope_structure`.
pub fn ecies_sealed_poe_trial_decrypt(
    envelope: &SealedEnvelope,
    keys: TrialDecryptKeys<'_>,
    constant_time_n: bool,
    mut probe: Option<&mut UnwrapProbe>,
) -> Result<TrialDecryptResult, EciesSealedPoeError> {
    let (recipient_secret_keys, is_bundle): (&[Vec<u8>], bool) = match keys {
        TrialDecryptKeys::Multi(list) => (list, false),
        TrialDecryptKeys::Bundle(bundle) => (select_bundle_secrets(envelope, bundle), true),
    };
    if recipient_secret_keys.is_empty() {
        if is_bundle {
            return Ok(TrialDecryptResult::NoAeadPass);
        }
        return Err(EciesSealedPoeError::new(
            EciesSealedPoeErrorCode::InvalidRecipientKey,
            "recipient_secret_keys MUST be a non-empty list, got length 0",
        ));
    }
    assert_envelope_structure(envelope, Some(recipient_secret_keys), None)?;
    let slots_cbor = slots_to_mac_cbor(&envelope.slots);
    let mut any_candidate_recovered = false;
    for (k, key) in recipient_secret_keys.iter().enumerate() {
        if let Some(p) = probe.as_deref_mut() {
            p.outer.count = k + 1;
        }
        let mut slots_attempted = SlotsAttempted::default();
        let candidate = try_recipient_unwrap_with_idx(
            envelope,
            key,
            constant_time_n,
            Some(&mut slots_attempted),
        );
        if let Some(p) = probe.as_deref_mut() {
            p.inner.count = slots_attempted.count;
            p.inner.per_priv_counts.push(slots_attempted.count);
        }
        let Some((mut cek, slot_idx)) = candidate else {
            continue;
        };
        if slots_mac_matches(&cek, &slots_cbor, &envelope.slots_mac) {
            return Ok(TrialDecryptResult::Match { slot_idx, cek });
        }
        // A rejected CEK is still key material; wipe it before dropping.
        cek.zeroize();
        any_candidate_recovered = true;
    }
    Ok(if any_candidate_recovered {
        TrialDecryptResult::AeadPassNoMacMatch
    } else {
        TrialDecryptResult::NoAeadPass
    })
}