#[cfg(feature = "alloc")]
use alloc::format;
use alloc::sync::Arc;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use zeroize::Zeroizing;
use crate::error::{
HpkeError,
SecurityValidation,
};
use crate::hpke_session::{
HpkeReceiverContext,
HpkeSenderContext,
};
use crate::providers::traits::*;
use crate::security::CryptoRng;
use crate::security::constant_time::constant_time_eq;
use crate::types::*;
/// Secrets returned by [`key_schedule`].
///
/// Each field is wrapped in `Zeroizing`.
pub struct KeyScheduleSecrets {
/// Output of the `"key"` labeled expansion (`cipher_suite.aead.key_len()` bytes requested).
pub key: Zeroizing<Vec<u8>>,
/// Output of the `"base_nonce"` labeled expansion (`cipher_suite.aead.nonce_len()` bytes requested).
pub nonce: Zeroizing<Vec<u8>>,
/// Output of the `"exp"` labeled expansion (`kdf.extract_len()` bytes requested).
pub exporter_secret: Zeroizing<Vec<u8>>,
}
/// Result of `parse_receiver_encapsulated_key`: the main KEM ciphertext, the optional
/// auth KEM ciphertext, and the cipher suite's KEM algorithm.
type ParsedReceiverEncapsulatedKey = (Vec<u8>, Option<Vec<u8>>, HpkeKem);
/// Describes why an encapsulated key does not have the expected byte layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncapsulatedKeyLayoutError {
    /// Fewer bytes were supplied than the minimum expected.
    TooShort {
        /// Number of bytes that were supplied.
        got: usize,
        /// Minimum number of bytes expected.
        min_expected: usize,
    },
    /// The PSK commitment suffix does not have the expected length.
    InvalidPskCommitmentSuffix {
        /// Length of the suffix that was found.
        got_suffix_len: usize,
        /// Length the suffix was expected to have.
        expected_suffix_len: usize,
    },
}

impl core::fmt::Display for EncapsulatedKeyLayoutError {
    /// Renders a one-line, human-readable description of the layout problem.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::TooShort {
                got: actual,
                min_expected: needed,
            } => f.write_fmt(format_args!(
                "encapsulated key too short: got {} bytes, need at least {}",
                actual, needed
            )),
            Self::InvalidPskCommitmentSuffix {
                got_suffix_len: actual,
                expected_suffix_len: wanted,
            } => f.write_fmt(format_args!(
                "invalid PSK commitment suffix length: got {}, expected {}",
                actual, wanted
            )),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for EncapsulatedKeyLayoutError {}
/// Confirms that `provider` advertises support for `cipher_suite`.
///
/// # Errors
///
/// Returns [`HpkeError::ConfigError`] (setting `"cipher_suite"`) naming the provider and the
/// KEM/KDF/AEAD triple when the suite is not supported.
pub(crate) fn ensure_cipher_suite_supported<P: HpkeCryptoProvider + ?Sized>(
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
) -> Result<(), HpkeError> {
    if provider
        .supported_algorithms()
        .supports_cipher_suite(cipher_suite)
    {
        return Ok(());
    }

    let cause = format!(
        "HPKE provider `{}` does not support KEM {:?} + KDF {:?} + AEAD {:?}",
        provider.name(),
        cipher_suite.kem,
        cipher_suite.kdf,
        cipher_suite.aead
    );
    Err(HpkeError::ConfigError {
        setting: "cipher_suite".into(),
        cause,
    })
}
/// Checks that `sender_pk` is the public key the provider derives from `sender_sk`.
///
/// The length is compared first; the byte comparison then goes through `constant_time_eq`.
///
/// # Errors
///
/// Propagates any error from `derive_public_key`, and returns a security error when the
/// lengths differ (`KeyLength`) or the bytes differ (`ConstantTimeComparison`).
fn verify_sender_keypair_binding<P: HpkeCryptoProvider + ?Sized>(
    provider: &P,
    kem: HpkeKem,
    sender_sk: &lib_q_core::KemSecretKey,
    sender_pk: &lib_q_core::KemPublicKey,
) -> Result<(), HpkeError> {
    let derived_pk = provider.derive_public_key(kem, sender_sk.as_bytes())?;
    let claimed_pk = sender_pk.as_bytes();

    let failure = if derived_pk.len() != claimed_pk.len() {
        Some((
            SecurityValidation::KeyLength,
            "derived sender public key length does not match sender_pk",
        ))
    } else if !constant_time_eq(derived_pk.as_slice(), claimed_pk) {
        Some((
            SecurityValidation::ConstantTimeComparison,
            "sender public key does not correspond to sender secret key",
        ))
    } else {
        None
    };

    match failure {
        Some((validation, message)) => Err(HpkeError::security_error(validation, message)),
        None => Ok(()),
    }
}
/// Validates that the PSK inputs are consistent with `mode`.
///
/// * `Base` and `Auth` reject any supplied `psk` or `psk_id`.
/// * `Psk` and `AuthPsk` require both, and additionally reject an empty `psk`.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] with a mode-specific message on any violation.
fn validate_psk_parameters(
    mode: HpkeMode,
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
) -> Result<(), HpkeError> {
    // Each mode either requires the PSK pair or forbids it; the message is the one
    // reported when that expectation is violated.
    let (psk_required, violation) = match mode {
        HpkeMode::Base => (false, "Base mode does not support PSK parameters"),
        HpkeMode::Psk => (true, "PSK mode requires both PSK and PSK ID"),
        HpkeMode::Auth => (false, "Auth mode does not support PSK parameters"),
        HpkeMode::AuthPsk => (true, "AuthPSK mode requires both PSK and PSK ID"),
    };

    match (psk_required, psk, psk_id) {
        (false, None, None) => Ok(()),
        (true, Some(secret), Some(_)) => {
            if secret.is_empty() {
                Err(HpkeError::CryptoError("PSK cannot be empty".into()))
            } else {
                Ok(())
            }
        }
        _ => Err(HpkeError::CryptoError(violation.into())),
    }
}
/// Returns `true` when a PSK commitment suffix is part of the encapsulated key, i.e. the
/// mode is `Psk` or `AuthPsk` and the wire format is `LibQCommitmentSuffix`.
fn psk_commitment_suffix_enabled(mode: HpkeMode, psk_wire_format: HpkePskWireFormat) -> bool {
    let mode_uses_psk = matches!(mode, HpkeMode::Psk | HpkeMode::AuthPsk);
    let format_has_suffix = matches!(psk_wire_format, HpkePskWireFormat::LibQCommitmentSuffix);
    mode_uses_psk && format_has_suffix
}
/// Length in bytes of a PSK commitment for `cipher_suite`: the KDF's extract output length.
pub fn psk_commitment_len(cipher_suite: &HpkeCipherSuite) -> usize {
    let kdf = cipher_suite.kdf;
    kdf.extract_len()
}
/// Derives the PSK commitment for a KEM ciphertext.
///
/// The commitment is `labeled_extract` with an empty salt, the label `"psk_commitment"`, and
/// the input `psk || psk_id || enc_kem`. The concatenation buffer is held in `Zeroizing`.
///
/// NOTE(review): the three inputs are concatenated without length prefixes, so different
/// `(psk, psk_id)` splits of the same bytes yield the same input — confirm this is acceptable
/// for the intended threat model before relying on the commitment for domain separation.
///
/// # Errors
///
/// Propagates errors from `create_suite_id` and `labeled_extract`.
pub fn derive_psk_commitment<P: HpkeCryptoProvider + ?Sized>(
    psk: &[u8],
    psk_id: &[u8],
    enc_kem: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
) -> Result<Zeroizing<Vec<u8>>, HpkeError> {
    let suite_id = create_suite_id(cipher_suite)?;

    let total_len = psk.len() + psk_id.len() + enc_kem.len();
    let mut commitment_input = Zeroizing::new(Vec::with_capacity(total_len));
    for part in [psk, psk_id, enc_kem] {
        commitment_input.extend_from_slice(part);
    }

    labeled_extract(
        cipher_suite.kdf,
        b"",
        &suite_id,
        "psk_commitment",
        commitment_input.as_slice(),
        provider,
    )
}
/// Selects the main KEM ciphertext from the sender-side bytes that precede the PSK commitment.
///
/// * `Psk`: the prefix must be exactly one KEM ciphertext and is returned whole.
/// * `AuthPsk`: the prefix must hold at least one KEM ciphertext; the first
///   `kem.enc_len()` bytes are returned.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] on a length mismatch or for `Base`/`Auth` modes.
fn main_kem_ciphertext_for_psk_commitment<'a>(
    mode: HpkeMode,
    cipher_suite: &HpkeCipherSuite,
    kem_and_auth_prefix: &'a [u8],
) -> Result<&'a [u8], HpkeError> {
    let kem_enc_len = cipher_suite.kem.enc_len();
    let prefix_len = kem_and_auth_prefix.len();

    match mode {
        HpkeMode::Psk if prefix_len == kem_enc_len => Ok(kem_and_auth_prefix),
        HpkeMode::Psk => Err(HpkeError::CryptoError(format!(
            "Invalid PSK mode KEM prefix length: expected {} bytes, got {}",
            kem_enc_len, prefix_len
        ))),
        HpkeMode::AuthPsk if prefix_len >= kem_enc_len => Ok(&kem_and_auth_prefix[..kem_enc_len]),
        HpkeMode::AuthPsk => Err(HpkeError::CryptoError(format!(
            "Invalid AuthPsk KEM prefix: expected at least {} bytes before PSK commitment, got {}",
            kem_enc_len, prefix_len
        ))),
        HpkeMode::Base | HpkeMode::Auth => Err(HpkeError::CryptoError(
            "PSK commitment derivation requires Psk or AuthPsk mode".into(),
        )),
    }
}
/// Recomputes the PSK commitment locally and compares it with the sender's via
/// `constant_time_eq`.
///
/// # Errors
///
/// Returns [`HpkeError::InconsistentPsk`] when the commitments differ, or propagates the
/// derivation error.
fn verify_psk_commitment<P: HpkeCryptoProvider + ?Sized>(
    psk: &[u8],
    psk_id: &[u8],
    enc_kem: &[u8],
    sender_commitment: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
) -> Result<(), HpkeError> {
    let expected = derive_psk_commitment(psk, psk_id, enc_kem, cipher_suite, provider)?;
    if constant_time_eq(expected.as_slice(), sender_commitment) {
        Ok(())
    } else {
        Err(HpkeError::InconsistentPsk)
    }
}
/// Appends the commitment bytes to the end of `encapsulated_key` and returns the result.
fn append_psk_commitment(mut encapsulated_key: Vec<u8>, commitment: Zeroizing<Vec<u8>>) -> Vec<u8> {
    encapsulated_key.extend(commitment.iter().copied());
    encapsulated_key
}
/// Pieces of a received encapsulated key, as produced by
/// `split_encapsulated_key_for_receiver`.
struct EncapsulatedKeyParts {
/// Main KEM ciphertext.
main: Vec<u8>,
/// Auth KEM ciphertext; set for `Auth` and `AuthPsk` modes only.
auth: Option<Vec<u8>>,
/// Trailing PSK commitment; set only for PSK modes with a non-zero commitment length.
psk_commitment: Option<Vec<u8>>,
}
/// Splits a received encapsulated key into main ciphertext, auth ciphertext and PSK commitment.
///
/// Layout by mode (`commitment_len` may be 0, in which case no commitment is reported):
///
/// * `Base`: the whole input is the main ciphertext.
/// * `Psk`: `main (kem_enc_len) || commitment (commitment_len)`.
/// * `Auth`: `main (kem_enc_len) || auth (remaining bytes)`.
/// * `AuthPsk`: `main (kem_enc_len) || auth (remaining bytes) || commitment (commitment_len)`.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] when the input is shorter than the mode's minimum, or
/// when the `Psk` KEM part is not exactly `kem_enc_len` bytes.
fn split_encapsulated_key_for_receiver(
    encapsulated_key: &[u8],
    mode: HpkeMode,
    kem_enc_len: usize,
    commitment_len: usize,
) -> Result<EncapsulatedKeyParts, HpkeError> {
    let total_len = encapsulated_key.len();
    // A commitment is only reported when one is actually expected on the wire.
    let owned_commitment = |suffix: &[u8]| (commitment_len > 0).then(|| suffix.to_vec());

    match mode {
        HpkeMode::Base => Ok(EncapsulatedKeyParts {
            main: encapsulated_key.to_vec(),
            auth: None,
            psk_commitment: None,
        }),
        HpkeMode::Psk => {
            let min_len = kem_enc_len + commitment_len;
            if total_len < min_len {
                return Err(HpkeError::CryptoError(format!(
                    "Invalid PSK mode encapsulated key size: {} bytes (expected at least {} bytes)",
                    total_len, min_len
                )));
            }
            let boundary = total_len - commitment_len;
            let kem_part = &encapsulated_key[..boundary];
            let suffix = &encapsulated_key[boundary..];
            if kem_part.len() != kem_enc_len {
                return Err(HpkeError::CryptoError(format!(
                    "Invalid PSK mode KEM ciphertext size: {} bytes (expected {} bytes)",
                    kem_part.len(),
                    kem_enc_len
                )));
            }
            Ok(EncapsulatedKeyParts {
                main: kem_part.to_vec(),
                auth: None,
                psk_commitment: owned_commitment(suffix),
            })
        }
        HpkeMode::Auth => {
            let min_len = kem_enc_len * 2;
            if total_len < min_len {
                return Err(HpkeError::CryptoError(format!(
                    "Invalid Auth mode encapsulated key size: {} bytes (expected at least {} bytes)",
                    total_len, min_len
                )));
            }
            let main_part = &encapsulated_key[..kem_enc_len];
            let auth_part = &encapsulated_key[kem_enc_len..];
            Ok(EncapsulatedKeyParts {
                main: main_part.to_vec(),
                auth: Some(auth_part.to_vec()),
                psk_commitment: None,
            })
        }
        HpkeMode::AuthPsk => {
            let min_len = kem_enc_len * 2 + commitment_len;
            if total_len < min_len {
                return Err(HpkeError::CryptoError(format!(
                    "Invalid AuthPsk mode encapsulated key size: {} bytes (expected at least {} bytes)",
                    total_len, min_len
                )));
            }
            let boundary = total_len - commitment_len;
            let suffix = &encapsulated_key[boundary..];
            let main_part = &encapsulated_key[..kem_enc_len];
            let auth_part = &encapsulated_key[kem_enc_len..boundary];
            Ok(EncapsulatedKeyParts {
                main: main_part.to_vec(),
                auth: Some(auth_part.to_vec()),
                psk_commitment: owned_commitment(suffix),
            })
        }
    }
}
#[allow(clippy::too_many_arguments)] fn parse_receiver_encapsulated_key<P: HpkeCryptoProvider + ?Sized>(
encapsulated_key: &[u8],
mode: HpkeMode,
cipher_suite: &HpkeCipherSuite,
sender_pk: Option<&lib_q_core::KemPublicKey>,
psk: Option<&[u8]>,
psk_id: Option<&[u8]>,
psk_wire_format: HpkePskWireFormat,
provider: &P,
) -> Result<ParsedReceiverEncapsulatedKey, HpkeError> {
let commitment_len = if psk_commitment_suffix_enabled(mode, psk_wire_format) {
psk_commitment_len(cipher_suite)
} else {
0
};
let kem_algorithm = cipher_suite.kem;
let kem_enc_len = kem_algorithm.enc_len();
match mode {
HpkeMode::Auth | HpkeMode::AuthPsk => {
let sender_pk = sender_pk.ok_or_else(|| {
HpkeError::CryptoError("Auth and AuthPSK modes require sender public key".into())
})?;
let sender_len = sender_pk.as_bytes().len();
let expected_pk = kem_algorithm.public_key_len();
if sender_len != expected_pk {
return Err(HpkeError::CryptoError(format!(
"Invalid sender public key size: {} bytes (expected {} for this cipher suite's KEM)",
sender_len, expected_pk
)));
}
}
HpkeMode::Psk => {
if encapsulated_key.len() < commitment_len {
return Err(HpkeError::CryptoError(format!(
"Invalid PSK mode encapsulated key size: {} bytes (expected at least {} bytes)",
encapsulated_key.len(),
commitment_len
)));
}
let kem_wire_len = encapsulated_key.len() - commitment_len;
if kem_wire_len != kem_enc_len {
return Err(HpkeError::CryptoError(format!(
"Invalid PSK mode KEM ciphertext size: {} bytes (expected {} bytes for this cipher suite)",
kem_wire_len, kem_enc_len
)));
}
}
HpkeMode::Base => {
if encapsulated_key.len() != kem_enc_len {
return Err(HpkeError::CryptoError(format!(
"Invalid Base mode encapsulated key size: {} bytes (expected {} bytes)",
encapsulated_key.len(),
kem_enc_len
)));
}
}
}
let parts =
split_encapsulated_key_for_receiver(encapsulated_key, mode, kem_enc_len, commitment_len)?;
if let Some(sender_commitment) = parts.psk_commitment {
let psk = psk.ok_or(HpkeError::InconsistentPsk)?;
let psk_id = psk_id.ok_or(HpkeError::InconsistentPsk)?;
verify_psk_commitment(
psk,
psk_id,
parts.main.as_slice(),
&sender_commitment,
cipher_suite,
provider,
)?;
}
if parts.main.len() != kem_enc_len {
return Err(HpkeError::CryptoError(format!(
"Internal error: parsed KEM ciphertext length {} does not match cipher suite (expected {})",
parts.main.len(),
kem_enc_len
)));
}
Ok((parts.main, parts.auth, kem_algorithm))
}
/// Appends a PSK commitment to the sender's encapsulated key when the mode and wire format
/// call for one; otherwise returns `encapsulated_key` unchanged.
///
/// # Errors
///
/// Propagates the KEM-prefix and derivation errors, and returns
/// [`HpkeError::InconsistentPsk`] when `psk` or `psk_id` is missing.
fn attach_psk_commitment_to_encapsulated_key<P: HpkeCryptoProvider + ?Sized>(
    encapsulated_key: Vec<u8>,
    mode: HpkeMode,
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
    psk_wire_format: HpkePskWireFormat,
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
) -> Result<Vec<u8>, HpkeError> {
    if !psk_commitment_suffix_enabled(mode, psk_wire_format) {
        return Ok(encapsulated_key);
    }

    // The borrow of `encapsulated_key` ends with this block, before the key is moved below.
    let commitment = {
        let enc_kem = main_kem_ciphertext_for_psk_commitment(
            mode,
            cipher_suite,
            encapsulated_key.as_slice(),
        )?;
        let psk = psk.ok_or(HpkeError::InconsistentPsk)?;
        let psk_id = psk_id.ok_or(HpkeError::InconsistentPsk)?;
        derive_psk_commitment(psk, psk_id, enc_kem, cipher_suite, provider)?
    };

    Ok(append_psk_commitment(encapsulated_key, commitment))
}
/// Creates a sender context in `Base` mode.
///
/// Delegates to [`setup_sender_with_mode`] with no PSK, no sender key pair and the default
/// PSK wire format.
///
/// # Errors
///
/// Propagates every error from [`setup_sender_with_mode`].
pub fn setup_sender<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    recipient_pk: &lib_q_core::KemPublicKey,
    info: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    rng: &mut dyn CryptoRng,
    hpke_crypto: Arc<dyn HpkeCryptoProvider + Send + Sync>,
) -> Result<HpkeSenderContext, HpkeError> {
    let mode = HpkeMode::Base;
    let wire_format = HpkePskWireFormat::default();
    setup_sender_with_mode(
        kem_ctx,
        recipient_pk,
        info,
        cipher_suite,
        provider,
        rng,
        mode,
        None, // psk
        None, // psk_id
        None, // sender_sk
        None, // sender_pk
        wire_format,
        hpke_crypto,
    )
}
/// Creates a sender context for the given HPKE `mode`.
///
/// Steps, in order:
/// 1. Validate PSK inputs, cipher-suite support, and that sender keys match the mode.
/// 2. Validate the recipient public key length and the KEM context.
/// 3. For auth modes, validate sender key lengths and that `sender_pk` belongs to `sender_sk`.
/// 4. Encapsulate to the recipient; for auth modes also run `auth_encapsulate` and append
///    its secret to the shared secret.
/// 5. Run the key schedule.
/// 6. Build the wire key: `main || auth (if any) || PSK commitment (if enabled)`.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] for invalid parameters or key sizes, and propagates
/// provider, key-binding, key-schedule and commitment errors.
// The parameter list is the full HPKE setup input set.
#[allow(clippy::too_many_arguments)]
pub fn setup_sender_with_mode<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    recipient_pk: &lib_q_core::KemPublicKey,
    info: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    rng: &mut dyn CryptoRng,
    mode: HpkeMode,
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
    sender_sk: Option<&lib_q_core::KemSecretKey>,
    sender_pk: Option<&lib_q_core::KemPublicKey>,
    psk_wire_format: HpkePskWireFormat,
    hpke_crypto: Arc<dyn HpkeCryptoProvider + Send + Sync>,
) -> Result<HpkeSenderContext, HpkeError> {
    validate_psk_parameters(mode, psk, psk_id)?;
    ensure_cipher_suite_supported(cipher_suite, provider)?;

    // Resolve the sender key pair once: `Some` exactly for the auth modes. Later steps
    // pattern-match on this instead of re-checking the mode and calling `expect`.
    let sender_keys = match (mode, sender_sk, sender_pk) {
        (HpkeMode::Base | HpkeMode::Psk, None, None) => None,
        (HpkeMode::Base | HpkeMode::Psk, _, _) => {
            return Err(HpkeError::CryptoError(
                "Base and PSK modes do not support sender authentication".into(),
            ));
        }
        (HpkeMode::Auth | HpkeMode::AuthPsk, Some(sk), Some(pk)) => Some((sk, pk)),
        (HpkeMode::Auth | HpkeMode::AuthPsk, _, _) => {
            return Err(HpkeError::CryptoError(
                "Auth and AuthPSK modes require sender key pair".into(),
            ));
        }
    };

    let kem_algorithm = cipher_suite.kem;
    let expected_pk_len = kem_algorithm.public_key_len();
    let pk_size = recipient_pk.as_bytes().len();
    if pk_size != expected_pk_len {
        return Err(HpkeError::CryptoError(format!(
            "Invalid recipient public key size: {} bytes (expected {} for this cipher suite's KEM)",
            pk_size, expected_pk_len
        )));
    }
    validate_kem_context_for_algorithm(kem_ctx, kem_algorithm)?;

    if let Some((sender_sk, sender_pk)) = sender_keys {
        let sender_pk_size = sender_pk.as_bytes().len();
        if sender_pk_size != expected_pk_len {
            return Err(HpkeError::CryptoError(format!(
                "Invalid sender public key size: {} bytes (expected {})",
                sender_pk_size, expected_pk_len
            )));
        }
        let expected_sk_len = kem_algorithm.secret_key_len();
        let sender_sk_size = sender_sk.as_bytes().len();
        if sender_sk_size != expected_sk_len {
            return Err(HpkeError::CryptoError(format!(
                "Invalid sender secret key size: {} bytes (expected {})",
                sender_sk_size, expected_sk_len
            )));
        }
        verify_sender_keypair_binding(provider, kem_algorithm, sender_sk, sender_pk)?;
    }

    let (mut kem_encapsulated_key, mut shared_secret) =
        provider.encapsulate(kem_algorithm, recipient_pk.as_bytes(), rng)?;

    let auth_encapsulated_key = match sender_keys {
        Some((sender_sk, _)) => {
            let (auth_encapsulated_key, auth_kem_secret) = provider.auth_encapsulate(
                kem_algorithm,
                sender_sk.as_bytes(),
                recipient_pk.as_bytes(),
                rng,
            )?;
            // Auth modes feed `main secret || auth secret` into the key schedule.
            shared_secret.extend_from_slice(auth_kem_secret.as_slice());
            Some(auth_encapsulated_key)
        }
        None => None,
    };

    let schedule = key_schedule(
        mode,
        shared_secret.as_slice(),
        info,
        psk,
        psk_id,
        cipher_suite,
        provider,
    )?;

    // Extend the owned ciphertext in place; no copy of the main ciphertext is needed.
    if let Some(auth_encap) = auth_encapsulated_key {
        kem_encapsulated_key.extend_from_slice(&auth_encap);
    }
    let final_encapsulated_key = attach_psk_commitment_to_encapsulated_key(
        kem_encapsulated_key,
        mode,
        psk,
        psk_id,
        psk_wire_format,
        cipher_suite,
        provider,
    )?;

    Ok(HpkeSenderContext {
        shared_secret,
        exporter_secret: schedule.exporter_secret,
        key: schedule.key,
        nonce: schedule.nonce,
        cipher_suite: *cipher_suite,
        aead: cipher_suite.aead,
        encapsulated_key: final_encapsulated_key,
        sequence_number: 0,
        max_sequence_number: u32::MAX - 1,
        state: HpkeContextState::Active,
        hpke_crypto,
    })
}
/// Creates a receiver context in `Base` mode.
///
/// Delegates to [`setup_receiver_with_mode`] with no PSK, no sender public key and the
/// default PSK wire format.
///
/// # Errors
///
/// Propagates every error from [`setup_receiver_with_mode`].
pub fn setup_receiver<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    encapsulated_key: &[u8],
    recipient_sk: &lib_q_core::KemSecretKey,
    info: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    hpke_crypto: Arc<dyn HpkeCryptoProvider + Send + Sync>,
) -> Result<HpkeReceiverContext, HpkeError> {
    let mode = HpkeMode::Base;
    let wire_format = HpkePskWireFormat::default();
    setup_receiver_with_mode(
        kem_ctx,
        encapsulated_key,
        recipient_sk,
        info,
        cipher_suite,
        provider,
        mode,
        None, // psk
        None, // psk_id
        None, // sender_pk
        wire_format,
        hpke_crypto,
    )
}
/// Creates a receiver context for the given HPKE `mode`.
///
/// Steps, in order:
/// 1. Validate PSK inputs, cipher-suite support, and that `sender_pk` matches the mode.
/// 2. Parse the encapsulated key (including PSK commitment verification when enabled).
/// 3. Validate the KEM context, then decapsulate the main ciphertext.
/// 4. For auth modes, run `auth_decapsulate` and append its secret to the shared secret.
/// 5. Run the key schedule.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] for invalid parameters, and propagates parsing,
/// provider and key-schedule errors. This function does not panic on missing auth inputs;
/// they are reported as errors.
// The parameter list is the full HPKE setup input set.
#[allow(clippy::too_many_arguments)]
pub fn setup_receiver_with_mode<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    encapsulated_key: &[u8],
    recipient_sk: &lib_q_core::KemSecretKey,
    info: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    mode: HpkeMode,
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
    sender_pk: Option<&lib_q_core::KemPublicKey>,
    psk_wire_format: HpkePskWireFormat,
    hpke_crypto: Arc<dyn HpkeCryptoProvider + Send + Sync>,
) -> Result<HpkeReceiverContext, HpkeError> {
    validate_psk_parameters(mode, psk, psk_id)?;
    ensure_cipher_suite_supported(cipher_suite, provider)?;

    // `Some` exactly for the auth modes, so the auth step below needs no `unwrap`.
    let auth_sender_pk = match (mode, sender_pk) {
        (HpkeMode::Base | HpkeMode::Psk, None) => None,
        (HpkeMode::Base | HpkeMode::Psk, Some(_)) => {
            return Err(HpkeError::CryptoError(
                "Base and PSK modes do not support sender authentication".into(),
            ));
        }
        (HpkeMode::Auth | HpkeMode::AuthPsk, Some(pk)) => Some(pk),
        (HpkeMode::Auth | HpkeMode::AuthPsk, None) => {
            return Err(HpkeError::CryptoError(
                "Auth and AuthPSK modes require sender public key".into(),
            ));
        }
    };

    let (main_encapsulated_key, auth_encapsulated_key, kem_algorithm) =
        parse_receiver_encapsulated_key(
            encapsulated_key,
            mode,
            cipher_suite,
            sender_pk,
            psk,
            psk_id,
            psk_wire_format,
            provider,
        )?;
    validate_kem_context_for_algorithm(kem_ctx, kem_algorithm)?;

    let mut shared_secret = provider.decapsulate(
        kem_algorithm,
        recipient_sk.as_bytes(),
        &main_encapsulated_key,
    )?;

    if let Some(sender_pk) = auth_sender_pk {
        // The parser yields an auth ciphertext for auth modes; treat its absence as an
        // error rather than a panic.
        let auth_encap = auth_encapsulated_key.ok_or_else(|| {
            HpkeError::CryptoError(
                "Auth and AuthPSK modes require an auth KEM ciphertext in the encapsulated key"
                    .into(),
            )
        })?;
        let auth_kem_secret = provider.auth_decapsulate(
            kem_algorithm,
            &auth_encap,
            recipient_sk.as_bytes(),
            sender_pk.as_bytes(),
        )?;
        // Auth modes feed `main secret || auth secret` into the key schedule.
        shared_secret.extend_from_slice(auth_kem_secret.as_slice());
    }

    let schedule = key_schedule(
        mode,
        shared_secret.as_slice(),
        info,
        psk,
        psk_id,
        cipher_suite,
        provider,
    )?;

    Ok(HpkeReceiverContext {
        shared_secret,
        exporter_secret: schedule.exporter_secret,
        key: schedule.key,
        nonce: schedule.nonce,
        cipher_suite: *cipher_suite,
        aead: cipher_suite.aead,
        sequence_number: 0,
        max_sequence_number: u32::MAX - 1,
        state: HpkeContextState::Active,
        hpke_crypto,
    })
}
/// Single-shot encryption in `Base` mode.
///
/// Delegates to [`seal_with_mode`] with no PSK, no sender key pair and the default PSK wire
/// format. Returns `(encapsulated key, ciphertext)`.
///
/// # Errors
///
/// Propagates every error from [`seal_with_mode`].
// The parameter list is the full single-shot HPKE input set.
#[allow(clippy::too_many_arguments)]
pub fn seal<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    recipient_pk: &lib_q_core::KemPublicKey,
    info: &[u8],
    aad: &[u8],
    plaintext: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    rng: &mut dyn CryptoRng,
) -> Result<(Vec<u8>, Vec<u8>), HpkeError> {
    let mode = HpkeMode::Base;
    let wire_format = HpkePskWireFormat::default();
    seal_with_mode(
        kem_ctx,
        recipient_pk,
        info,
        aad,
        plaintext,
        cipher_suite,
        provider,
        rng,
        mode,
        None, // psk
        None, // psk_id
        None, // sender_sk
        None, // sender_pk
        wire_format,
    )
}
/// Single-shot encryption for the given HPKE `mode`.
///
/// Steps, in order:
/// 1. Check that PSK inputs and sender keys match the mode.
/// 2. Check cipher-suite support, recipient public key length, and the KEM context.
/// 3. For auth modes, validate sender key lengths and that `sender_pk` belongs to `sender_sk`.
/// 4. Encapsulate to the recipient; for auth modes also run `auth_encapsulate` and append
///    its secret to the shared secret.
/// 5. Run the key schedule and seal `plaintext` with sequence number 0.
/// 6. Build the wire key: `main || auth (if any) || PSK commitment (if enabled)`.
///
/// Returns `(encapsulated key, ciphertext)`.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] for invalid parameters or key sizes, and propagates
/// provider, key-binding, key-schedule, AEAD and commitment errors.
// The parameter list is the full single-shot HPKE input set.
#[allow(clippy::too_many_arguments)]
pub fn seal_with_mode<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    recipient_pk: &lib_q_core::KemPublicKey,
    info: &[u8],
    aad: &[u8],
    plaintext: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    rng: &mut dyn CryptoRng,
    mode: HpkeMode,
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
    sender_sk: Option<&lib_q_core::KemSecretKey>,
    sender_pk: Option<&lib_q_core::KemPublicKey>,
    psk_wire_format: HpkePskWireFormat,
) -> Result<(Vec<u8>, Vec<u8>), HpkeError> {
    // Validate the mode's inputs and resolve the sender key pair in one place: `Some`
    // exactly for the auth modes, so later steps need no `expect`.
    let sender_keys = match mode {
        HpkeMode::Base => {
            if psk.is_some() || psk_id.is_some() || sender_sk.is_some() || sender_pk.is_some() {
                return Err(HpkeError::CryptoError(
                    "Base mode does not support PSK or sender authentication".into(),
                ));
            }
            None
        }
        HpkeMode::Psk => {
            if psk.is_none() || psk_id.is_none() {
                return Err(HpkeError::CryptoError(
                    "PSK mode requires both PSK and PSK ID".into(),
                ));
            }
            if sender_sk.is_some() || sender_pk.is_some() {
                return Err(HpkeError::CryptoError(
                    "PSK mode does not support sender authentication".into(),
                ));
            }
            None
        }
        HpkeMode::Auth => {
            let (Some(sk), Some(pk)) = (sender_sk, sender_pk) else {
                return Err(HpkeError::CryptoError(
                    "Auth mode requires sender key pair".into(),
                ));
            };
            if psk.is_some() || psk_id.is_some() {
                return Err(HpkeError::CryptoError(
                    "Auth mode does not support PSK".into(),
                ));
            }
            Some((sk, pk))
        }
        HpkeMode::AuthPsk => {
            let (Some(_), Some(_), Some(sk), Some(pk)) = (psk, psk_id, sender_sk, sender_pk)
            else {
                return Err(HpkeError::CryptoError(
                    "AuthPSK mode requires PSK, PSK ID, and sender key pair".into(),
                ));
            };
            Some((sk, pk))
        }
    };

    ensure_cipher_suite_supported(cipher_suite, provider)?;

    let kem_algorithm = cipher_suite.kem;
    let expected_pk_len = kem_algorithm.public_key_len();
    let pk_size = recipient_pk.as_bytes().len();
    if pk_size != expected_pk_len {
        return Err(HpkeError::CryptoError(format!(
            "Invalid recipient public key size: {} bytes (expected {} for this cipher suite's KEM)",
            pk_size, expected_pk_len
        )));
    }
    validate_kem_context_for_algorithm(kem_ctx, kem_algorithm)?;

    if let Some((sender_sk, sender_pk)) = sender_keys {
        let expected_sk_len = kem_algorithm.secret_key_len();
        let sender_sk_size = sender_sk.as_bytes().len();
        if sender_sk_size != expected_sk_len {
            return Err(HpkeError::CryptoError(format!(
                "Invalid sender secret key size: {} bytes (expected {})",
                sender_sk_size, expected_sk_len
            )));
        }
        let sender_pk_size = sender_pk.as_bytes().len();
        if sender_pk_size != expected_pk_len {
            return Err(HpkeError::CryptoError(format!(
                "Invalid sender public key size: {} bytes (expected {})",
                sender_pk_size, expected_pk_len
            )));
        }
        verify_sender_keypair_binding(provider, kem_algorithm, sender_sk, sender_pk)?;
    }

    let (mut kem_encapsulated_key, mut shared_secret) =
        provider.encapsulate(kem_algorithm, recipient_pk.as_bytes(), rng)?;

    let auth_encapsulated_key = match sender_keys {
        Some((sender_sk, _)) => {
            let (auth_encapsulated_key, auth_kem_secret) = provider.auth_encapsulate(
                kem_algorithm,
                sender_sk.as_bytes(),
                recipient_pk.as_bytes(),
                rng,
            )?;
            // Auth modes feed `main secret || auth secret` into the key schedule.
            shared_secret.extend_from_slice(auth_kem_secret.as_slice());
            Some(auth_encapsulated_key)
        }
        None => None,
    };

    let schedule = key_schedule(
        mode,
        shared_secret.as_slice(),
        info,
        psk,
        psk_id,
        cipher_suite,
        provider,
    )?;

    // Single-shot sealing always uses sequence number 0.
    let ciphertext = seal_message(
        cipher_suite.aead,
        schedule.key.as_slice(),
        schedule.nonce.as_slice(),
        0,
        aad,
        plaintext,
        provider,
    )?;

    // Extend the owned ciphertext in place; no copy of the main ciphertext is needed.
    if let Some(auth_encap) = auth_encapsulated_key {
        kem_encapsulated_key.extend_from_slice(&auth_encap);
    }
    let final_encapsulated_key = attach_psk_commitment_to_encapsulated_key(
        kem_encapsulated_key,
        mode,
        psk,
        psk_id,
        psk_wire_format,
        cipher_suite,
        provider,
    )?;

    Ok((final_encapsulated_key, ciphertext))
}
/// Single-shot decryption in `Base` mode.
///
/// Delegates to [`open_with_mode`] with no PSK, no sender public key and the default PSK
/// wire format.
///
/// # Errors
///
/// Propagates every error from [`open_with_mode`].
// The parameter list is the full single-shot HPKE input set.
#[allow(clippy::too_many_arguments)]
pub fn open<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    encapsulated_key: &[u8],
    recipient_sk: &lib_q_core::KemSecretKey,
    info: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    hpke_crypto: Arc<dyn HpkeCryptoProvider + Send + Sync>,
) -> Result<Vec<u8>, HpkeError> {
    let mode = HpkeMode::Base;
    let wire_format = HpkePskWireFormat::default();
    open_with_mode(
        kem_ctx,
        encapsulated_key,
        recipient_sk,
        info,
        aad,
        ciphertext,
        cipher_suite,
        provider,
        mode,
        None, // psk
        None, // psk_id
        None, // sender_pk
        wire_format,
        hpke_crypto,
    )
}
/// Single-shot decryption for the given HPKE `mode`.
///
/// Builds a receiver context with [`setup_receiver_with_mode`], then opens `ciphertext`
/// with sequence number 0 using the context's key, base nonce and crypto provider.
///
/// # Errors
///
/// Propagates receiver-setup errors and the AEAD open error.
// The parameter list is the full single-shot HPKE input set.
#[allow(clippy::too_many_arguments)]
pub fn open_with_mode<P: HpkeCryptoProvider + ?Sized>(
    kem_ctx: &mut lib_q_core::KemContext,
    encapsulated_key: &[u8],
    recipient_sk: &lib_q_core::KemSecretKey,
    info: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
    mode: HpkeMode,
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
    sender_pk: Option<&lib_q_core::KemPublicKey>,
    psk_wire_format: HpkePskWireFormat,
    hpke_crypto: Arc<dyn HpkeCryptoProvider + Send + Sync>,
) -> Result<Vec<u8>, HpkeError> {
    // `hpke_crypto` is not used again here, so it is moved into the context rather than
    // cloned; the AEAD call below reads it back from the context.
    let receiver_ctx = setup_receiver_with_mode(
        kem_ctx,
        encapsulated_key,
        recipient_sk,
        info,
        cipher_suite,
        provider,
        mode,
        psk,
        psk_id,
        sender_pk,
        psk_wire_format,
        hpke_crypto,
    )?;

    // Single-shot opening always uses sequence number 0.
    open_message(
        receiver_ctx.aead,
        receiver_ctx.key.as_slice(),
        receiver_ctx.nonce.as_slice(),
        0,
        aad,
        ciphertext,
        receiver_ctx.hpke_crypto.as_ref(),
    )
}
/// Encrypts `plaintext` with the per-message nonce derived from `base_nonce` and
/// `sequence_number`.
///
/// # Errors
///
/// Returns whatever error the provider's `seal` reports.
pub fn seal_message<P: HpkeCryptoProvider + ?Sized>(
    aead: HpkeAead,
    key: &[u8],
    base_nonce: &[u8],
    sequence_number: u32,
    aad: &[u8],
    plaintext: &[u8],
    provider: &P,
) -> Result<Vec<u8>, HpkeError> {
    provider.seal(
        aead,
        key,
        compute_nonce(base_nonce, sequence_number).as_slice(),
        aad,
        plaintext,
    )
}
/// Decrypts `ciphertext` with the per-message nonce derived from `base_nonce` and
/// `sequence_number`.
///
/// # Errors
///
/// Returns whatever error the provider's `open` reports.
pub fn open_message<P: HpkeCryptoProvider + ?Sized>(
    aead: HpkeAead,
    key: &[u8],
    base_nonce: &[u8],
    sequence_number: u32,
    aad: &[u8],
    ciphertext: &[u8],
    provider: &P,
) -> Result<Vec<u8>, HpkeError> {
    provider.open(
        aead,
        key,
        compute_nonce(base_nonce, sequence_number).as_slice(),
        aad,
        ciphertext,
    )
}
/// Exports `length` bytes of secret material from `exporter_secret`.
///
/// The output is `labeled_expand` with label `"sec"` and `exporter_context` as info. The
/// requested length is capped at `255 * Nh`, where `Nh` is the KDF's extract length.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] when the bound overflows or `length` exceeds it, and
/// propagates suite-id and expansion errors.
pub fn export<P: HpkeCryptoProvider + ?Sized>(
    exporter_secret: &[u8],
    exporter_context: &[u8],
    length: usize,
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
) -> Result<Vec<u8>, HpkeError> {
    let hash_len = cipher_suite.kdf.extract_len();
    let Some(max_l) = hash_len.checked_mul(255) else {
        return Err(HpkeError::CryptoError(
            "export length bound (255*Nh) overflowed".into(),
        ));
    };
    if length > max_l {
        return Err(HpkeError::CryptoError(format!(
            "export length {length} exceeds RFC 9180 maximum 255*Nh ({max_l})"
        )));
    }

    let suite_id = create_suite_id(cipher_suite)?;
    labeled_expand(
        cipher_suite.kdf,
        exporter_secret,
        &suite_id,
        "sec",
        exporter_context,
        length,
        provider,
    )
    // Hand the caller a plain `Vec`; the `Zeroizing` wrapper is dropped here.
    .map(|okm| okm.to_vec())
}
/// Derives the per-message nonce: `base_nonce` XOR the sequence number, with the sequence
/// number treated as a big-endian integer aligned to the end of the nonce.
///
/// Nonce bytes beyond the counter's width are left unchanged; an empty `base_nonce` yields
/// an empty nonce.
fn compute_nonce(base_nonce: &[u8], sequence_number: u32) -> Zeroizing<Vec<u8>> {
    let mut nonce = Zeroizing::new(base_nonce.to_vec());
    let counter = u64::from(sequence_number).to_be_bytes();

    // Walk both buffers from their last byte so the least-significant counter byte lands
    // on the last nonce byte; `zip` stops at the shorter of the two.
    for (nonce_byte, counter_byte) in nonce.iter_mut().rev().zip(counter.iter().rev()) {
        *nonce_byte ^= *counter_byte;
    }
    nonce
}
/// Builds the suite identifier: the ASCII bytes `HPKE` followed by
/// `cipher_suite.identifier()`.
///
/// # Errors
///
/// Currently always returns `Ok`.
pub fn create_suite_id(cipher_suite: &HpkeCipherSuite) -> Result<Vec<u8>, HpkeError> {
    let mut suite_id = b"HPKE".to_vec();
    suite_id.extend_from_slice(&cipher_suite.identifier());
    Ok(suite_id)
}
/// Derives the AEAD key, base nonce and exporter secret from the KEM shared secret.
///
/// Steps:
/// 1. Validate the PSK inputs for `mode`.
/// 2. Hash `psk_id` and `info` with `labeled_extract` (labels `"psk_id_hash"`,
///    `"info_hash"`, empty salt); a missing `psk_id` is treated as empty.
/// 3. Build the context `mode byte || psk_id_hash || info_hash`.
/// 4. Extract `secret` with `shared_secret` as salt and the PSK (or empty) as input.
/// 5. Expand `"key"`, `"base_nonce"` and `"exp"` from `secret` over the context.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] for invalid PSK inputs or an unexpected hash length,
/// and propagates extract/expand errors.
pub fn key_schedule<P: HpkeCryptoProvider + ?Sized>(
    mode: HpkeMode,
    shared_secret: &[u8],
    info: &[u8],
    psk: Option<&[u8]>,
    psk_id: Option<&[u8]>,
    cipher_suite: &HpkeCipherSuite,
    provider: &P,
) -> Result<KeyScheduleSecrets, HpkeError> {
    validate_psk_parameters(mode, psk, psk_id)?;

    let suite_id = create_suite_id(cipher_suite)?;
    let kdf = cipher_suite.kdf;
    let n_h = kdf.extract_len();

    let psk_id_hash = labeled_extract(
        kdf,
        b"",
        &suite_id,
        "psk_id_hash",
        psk_id.unwrap_or_default(),
        provider,
    )?;
    if psk_id_hash.len() != n_h {
        return Err(HpkeError::CryptoError(format!(
            "internal HPKE error: psk_id_hash length {} (expected {})",
            psk_id_hash.len(),
            n_h
        )));
    }

    let info_hash = labeled_extract(kdf, b"", &suite_id, "info_hash", info, provider)?;
    if info_hash.len() != n_h {
        return Err(HpkeError::CryptoError(format!(
            "internal HPKE error: info_hash length {} (expected {})",
            info_hash.len(),
            n_h
        )));
    }

    let mut context = Vec::with_capacity(1 + 2 * n_h);
    context.push(mode.as_u8());
    context.extend_from_slice(psk_id_hash.as_slice());
    context.extend_from_slice(info_hash.as_slice());

    let secret = labeled_extract(
        kdf,
        shared_secret,
        &suite_id,
        "secret",
        psk.unwrap_or_default(),
        provider,
    )?;

    // All three outputs share the PRK, suite id and context; only label and length vary.
    let expand = |label: &str, out_len: usize| {
        labeled_expand(
            kdf,
            secret.as_slice(),
            &suite_id,
            label,
            context.as_slice(),
            out_len,
            provider,
        )
    };

    Ok(KeyScheduleSecrets {
        key: expand("key", cipher_suite.aead.key_len())?,
        nonce: expand("base_nonce", cipher_suite.aead.nonce_len())?,
        exporter_secret: expand("exp", n_h)?,
    })
}
/// Runs the provider's `extract` over the labeled input
/// `"HPKE-v1" || suite_id || label || ikm`, with `salt` passed through unchanged.
///
/// Both the labeled input buffer and the returned PRK are held in `Zeroizing`.
///
/// # Errors
///
/// Propagates the provider's `extract` error.
pub fn labeled_extract<P: HpkeCryptoProvider + ?Sized>(
    kdf: HpkeKdf,
    salt: &[u8],
    suite_id: &[u8],
    label: &str,
    ikm: &[u8],
    provider: &P,
) -> Result<Zeroizing<Vec<u8>>, HpkeError> {
    let parts: [&[u8]; 4] = [b"HPKE-v1", suite_id, label.as_bytes(), ikm];
    let total_len: usize = parts.iter().map(|part| part.len()).sum();

    let mut labeled_ikm = Zeroizing::new(Vec::with_capacity(total_len));
    for part in parts {
        labeled_ikm.extend_from_slice(part);
    }

    let prk = provider.extract(kdf, salt, labeled_ikm.as_slice())?;
    Ok(Zeroizing::new(prk))
}
/// Runs the provider's `expand` over the labeled info
/// `be16(length) || "HPKE-v1" || suite_id || label || info`, requesting `length` bytes.
///
/// Both the labeled info buffer and the returned output are held in `Zeroizing`.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] when `length` does not fit in 16 bits, and propagates
/// the provider's `expand` error.
pub fn labeled_expand<P: HpkeCryptoProvider + ?Sized>(
    kdf: HpkeKdf,
    prk: &[u8],
    suite_id: &[u8],
    label: &str,
    info: &[u8],
    length: usize,
    provider: &P,
) -> Result<Zeroizing<Vec<u8>>, HpkeError> {
    let Ok(length_u16) = u16::try_from(length) else {
        return Err(HpkeError::CryptoError(
            "LabeledExpand length L must fit in 16 bits (RFC 9180 I2OSP(L, 2))".into(),
        ));
    };
    let length_prefix = length_u16.to_be_bytes();

    let parts: [&[u8]; 5] = [&length_prefix, b"HPKE-v1", suite_id, label.as_bytes(), info];
    let total_len: usize = parts.iter().map(|part| part.len()).sum();

    let mut labeled_info = Zeroizing::new(Vec::with_capacity(total_len));
    for part in parts {
        labeled_info.extend_from_slice(part);
    }

    let okm = provider.expand(kdf, prk, labeled_info.as_slice(), length)?;
    Ok(Zeroizing::new(okm))
}
/// Checks that `kem_ctx` can service the HPKE KEM `kem_algorithm`.
///
/// The HPKE KEM is mapped to its `lib_q_core` algorithm, whose category and security level
/// are checked, and then a throwaway key pair is generated through `kem_ctx` as a probe.
/// Only `NotImplemented` and `InvalidState` probe failures are reported; any other probe
/// error is deliberately ignored.
///
/// # Errors
///
/// Returns [`HpkeError::CryptoError`] for a non-KEM category, an unsupported security
/// level, a missing or non-implementing provider, or an invalid context state.
fn validate_kem_context_for_algorithm(
    kem_ctx: &mut lib_q_core::KemContext,
    kem_algorithm: HpkeKem,
) -> Result<(), HpkeError> {
    let core_algorithm = match kem_algorithm {
        HpkeKem::MlKem512 => lib_q_core::Algorithm::MlKem512,
        HpkeKem::MlKem768 => lib_q_core::Algorithm::MlKem768,
        HpkeKem::MlKem1024 => lib_q_core::Algorithm::MlKem1024,
    };

    if core_algorithm.category() != lib_q_core::AlgorithmCategory::Kem {
        return Err(HpkeError::CryptoError(format!(
            "Invalid algorithm category for HPKE: expected KEM, got {:?}",
            core_algorithm.category()
        )));
    }

    let security_level = core_algorithm.security_level();
    if !matches!(security_level, 1 | 3 | 4) {
        return Err(HpkeError::CryptoError(format!(
            "Unsupported security level for HPKE: {} (expected 1, 3, or 4)",
            security_level
        )));
    }

    // Probe the context: the generated key pair itself is discarded.
    match kem_ctx.generate_keypair(core_algorithm, None) {
        Err(lib_q_core::Error::NotImplemented { feature }) => {
            if feature.contains("no provider configured") {
                Err(HpkeError::CryptoError(
                    "KEM context must have a cryptographic provider configured".into(),
                ))
            } else {
                Err(HpkeError::CryptoError(format!(
                    "KEM algorithm {:?} is not implemented by the configured provider: {}",
                    kem_algorithm, feature
                )))
            }
        }
        Err(lib_q_core::Error::InvalidState { operation, reason }) => {
            Err(HpkeError::CryptoError(format!(
                "KEM context in invalid state for {}: {}",
                operation, reason
            )))
        }
        // Success and every other error kind both count as "context usable".
        Ok(_) | Err(_) => Ok(()),
    }
}
#[cfg(test)]
mod psk_commitment_tests {
    use alloc::vec;

    use super::*;
    use crate::providers::post_quantum::PostQuantumProvider;

    /// Cipher suite shared by all tests in this module.
    fn test_cipher_suite() -> HpkeCipherSuite {
        HpkeCipherSuite::new(
            HpkeKem::MlKem512,
            HpkeKdf::HkdfShake256,
            HpkeAead::Saturnin256,
        )
    }

    /// Same inputs must give the same commitment, of the advertised length.
    #[test]
    fn derive_psk_commitment_is_deterministic() {
        let provider = PostQuantumProvider::new();
        let suite = test_cipher_suite();
        let enc_kem = vec![0x42u8; HpkeKem::MlKem512.enc_len()];
        let derive = || {
            derive_psk_commitment(b"test-psk", b"test-id", &enc_kem, &suite, &provider).unwrap()
        };

        let first = derive();
        let second = derive();

        assert_eq!(first.as_slice(), second.as_slice());
        assert_eq!(first.len(), psk_commitment_len(&suite));
    }

    /// Changing one byte of the KEM ciphertext must change the commitment.
    #[test]
    fn derive_psk_commitment_differs_across_kem_ciphertexts() {
        let provider = PostQuantumProvider::new();
        let suite = test_cipher_suite();
        let zero_enc = vec![0u8; HpkeKem::MlKem512.enc_len()];
        let mut flipped_enc = zero_enc.clone();
        flipped_enc[0] = 1;

        let commitment_zero =
            derive_psk_commitment(b"test-psk", b"test-id", &zero_enc, &suite, &provider).unwrap();
        let commitment_flipped =
            derive_psk_commitment(b"test-psk", b"test-id", &flipped_enc, &suite, &provider)
                .unwrap();

        assert_ne!(commitment_zero.as_slice(), commitment_flipped.as_slice());
    }

    /// A commitment made with a different PSK must be rejected as inconsistent.
    #[test]
    fn verify_psk_commitment_rejects_mismatch() {
        let provider = PostQuantumProvider::new();
        let suite = test_cipher_suite();
        let enc_kem = vec![0x11u8; HpkeKem::MlKem512.enc_len()];
        let sender_commitment =
            derive_psk_commitment(b"sender-psk", b"id", &enc_kem, &suite, &provider).unwrap();

        let outcome = verify_psk_commitment(
            b"receiver-psk",
            b"id",
            &enc_kem,
            &sender_commitment,
            &suite,
            &provider,
        );

        assert!(matches!(outcome, Err(HpkeError::InconsistentPsk)));
    }

    /// In `Psk` mode the trailing `commitment_len` bytes are returned as the commitment.
    #[test]
    fn split_encapsulated_key_extracts_psk_commitment_suffix() {
        let kem_enc_len = HpkeKem::MlKem512.enc_len();
        let commitment_len = HpkeKdf::HkdfShake256.extract_len();
        // 0xAB fills the KEM part, 0xCD fills the commitment suffix.
        let mut wire = vec![0xAB; kem_enc_len];
        wire.resize(kem_enc_len + commitment_len, 0xCD);

        let parts =
            split_encapsulated_key_for_receiver(&wire, HpkeMode::Psk, kem_enc_len, commitment_len)
                .unwrap();

        assert_eq!(parts.main.len(), kem_enc_len);
        assert!(parts.auth.is_none());
        assert_eq!(parts.psk_commitment.as_deref(), Some(&wire[kem_enc_len..]));
    }
}