#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::format;
#[cfg(feature = "alloc")]
use alloc::string::ToString;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use lib_q_aead::create_aead;
use lib_q_core::{
Aead as CoreAead,
AeadKey,
Algorithm,
Hash as CoreHash,
KemOperations,
Nonce,
};
use lib_q_hash::digest::Digest;
use lib_q_hash::{
HashAlgorithm,
create_hash,
};
use lib_q_kem::LibQKemProvider;
use zeroize::Zeroizing;
use crate::error::{
AeadOperation,
HpkeError,
};
use crate::kdf::hkdf::HkdfImpl;
use crate::providers::traits::*;
use crate::security::CryptoRng;
use crate::types::*;
/// HPKE crypto provider backed by the lib-q post-quantum primitives.
///
/// KEM operations are delegated to `lib_q_kem::LibQKemProvider`, key
/// derivation to `HkdfImpl`, and payload encryption to
/// `lib_q_aead::create_aead`. The type holds no state; each operation builds
/// the backend objects it needs on demand.
pub struct PostQuantumProvider;

impl Default for PostQuantumProvider {
    /// Same as [`PostQuantumProvider::new`].
    fn default() -> Self {
        PostQuantumProvider::new()
    }
}
impl PostQuantumProvider {
pub fn new() -> Self {
Self
}
fn hpke_kem_to_algorithm(kem: HpkeKem) -> Result<Algorithm, HpkeError> {
match kem {
HpkeKem::MlKem512 => Ok(Algorithm::MlKem512),
HpkeKem::MlKem768 => Ok(Algorithm::MlKem768),
HpkeKem::MlKem1024 => Ok(Algorithm::MlKem1024),
}
}
fn create_kem_provider() -> Result<LibQKemProvider, HpkeError> {
LibQKemProvider::new()
.map_err(|e| HpkeError::CryptoError(format!("Failed to create KEM provider: {}", e)))
}
fn create_hash_instance(kdf: HpkeKdf) -> Result<Box<dyn CoreHash>, HpkeError> {
let algorithm = match kdf {
HpkeKdf::HkdfShake128 => HashAlgorithm::Shake128,
HpkeKdf::HkdfShake256 => HashAlgorithm::Shake256,
HpkeKdf::HkdfSha3_256 => HashAlgorithm::Sha3_256,
HpkeKdf::HkdfSha3_512 => HashAlgorithm::Sha3_512,
};
create_hash(algorithm)
.map_err(|e| HpkeError::CryptoError(format!("Failed to create hash instance: {}", e)))
}
fn create_aead_instance(aead: HpkeAead) -> Result<Box<dyn CoreAead>, HpkeError> {
let algorithm = match aead {
HpkeAead::Saturnin256 => Algorithm::Saturnin,
HpkeAead::Shake256 => Algorithm::Shake256Aead,
HpkeAead::DuplexSpongeAead => {
#[cfg(feature = "duplex-sponge-aead")]
{
Algorithm::DuplexSpongeAead
}
#[cfg(not(feature = "duplex-sponge-aead"))]
{
return Err(HpkeError::feature_not_enabled(
"duplex-sponge-aead (enable lib-q-hpke feature duplex-sponge-aead)",
));
}
}
HpkeAead::Export => return Err(HpkeError::not_implemented("Export-only AEAD")),
};
let aead_instance: Box<dyn CoreAead> = create_aead(algorithm).map_err(|e| {
HpkeError::CryptoError(format!("Failed to create AEAD instance: {}", e))
})?;
Ok(aead_instance)
}
}
impl KemProvider for PostQuantumProvider {
    /// Generates a fresh key pair for `kem`.
    ///
    /// Returns `(public_key, secret_key)`. The secret key is wrapped in
    /// [`Zeroizing`] so its buffer is wiped on drop.
    ///
    /// NOTE(review): `_rng` is not used. The backend is called with `None` as
    /// its RNG argument, so randomness comes from `lib_q_kem` itself.
    fn generate_keypair(
        &self,
        kem: HpkeKem,
        _rng: &mut dyn CryptoRng,
    ) -> Result<(Vec<u8>, Zeroizing<Vec<u8>>), HpkeError> {
        let provider = Self::create_kem_provider()?;
        let algorithm = Self::hpke_kem_to_algorithm(kem)?;
        let keypair = provider
            .generate_keypair(algorithm, None)
            .map_err(|e| HpkeError::CryptoError(format!("KEM key generation failed: {}", e)))?;
        Ok((
            keypair.public_key().as_bytes().to_vec(),
            Zeroizing::new(keypair.secret_key().as_bytes().to_vec()),
        ))
    }

    /// Encapsulates a fresh shared secret to `public_key`.
    ///
    /// Returns `(kem_ciphertext, shared_secret)`. As with key generation,
    /// `_rng` is ignored and the backend is called with `None`.
    fn encapsulate(
        &self,
        kem: HpkeKem,
        public_key: &[u8],
        _rng: &mut dyn CryptoRng,
    ) -> Result<(Vec<u8>, Zeroizing<Vec<u8>>), HpkeError> {
        let provider = Self::create_kem_provider()?;
        let algorithm = Self::hpke_kem_to_algorithm(kem)?;
        let pk = lib_q_core::KemPublicKey::new(public_key.to_vec());
        let (ct, ss) = provider
            .encapsulate(algorithm, &pk, None)
            .map_err(|e| HpkeError::CryptoError(format!("KEM encapsulation failed: {}", e)))?;
        Ok((ct, Zeroizing::new(ss)))
    }

    /// Recovers the shared secret carried by `ciphertext` using `secret_key`.
    fn decapsulate(
        &self,
        kem: HpkeKem,
        secret_key: &[u8],
        ciphertext: &[u8],
    ) -> Result<Zeroizing<Vec<u8>>, HpkeError> {
        let provider = Self::create_kem_provider()?;
        let algorithm = Self::hpke_kem_to_algorithm(kem)?;
        let sk = lib_q_core::KemSecretKey::new(secret_key.to_vec());
        let ss = provider
            .decapsulate(algorithm, &sk, ciphertext)
            .map_err(|e| HpkeError::CryptoError(format!("KEM decapsulation failed: {}", e)))?;
        Ok(Zeroizing::new(ss))
    }

    /// Checks that `key` has the exact length `kem` expects for a secret key
    /// (`is_secret == true`) or a public key.
    ///
    /// Also rejects keys made entirely of zero bytes.
    fn validate_key(&self, kem: HpkeKem, key: &[u8], is_secret: bool) -> Result<(), HpkeError> {
        let expected_len = if is_secret {
            kem.secret_key_len()
        } else {
            kem.public_key_len()
        };
        Self::check_len("key", key, expected_len)?;
        if key.iter().all(|&b| b == 0) {
            return Err(HpkeError::CryptoError(
                "Key material cannot be all zeros".to_string(),
            ));
        }
        Ok(())
    }

    /// Derives the public key that corresponds to `secret_key`.
    fn derive_public_key(&self, kem: HpkeKem, secret_key: &[u8]) -> Result<Vec<u8>, HpkeError> {
        let provider = Self::create_kem_provider()?;
        let algorithm = Self::hpke_kem_to_algorithm(kem)?;
        let secret_key_obj = lib_q_core::KemSecretKey::new(secret_key.to_vec());
        let public_key_obj = provider
            .derive_public_key(algorithm, &secret_key_obj)
            .map_err(|e| HpkeError::CryptoError(format!("Failed to derive public key: {}", e)))?;
        Ok(public_key_obj.as_bytes().to_vec())
    }

    /// Reports whether `kem` is usable.
    ///
    /// All ML-KEM parameter sets require the `ml-kem` feature and a positive
    /// `is_ml_kem_available()` check.
    fn supports_kem(&self, kem: HpkeKem) -> bool {
        match kem {
            HpkeKem::MlKem512 | HpkeKem::MlKem768 | HpkeKem::MlKem1024 => {
                #[cfg(feature = "ml-kem")]
                {
                    crate::kem::ml_kem::is_ml_kem_available()
                }
                #[cfg(not(feature = "ml-kem"))]
                {
                    false
                }
            }
        }
    }

    /// Authenticated encapsulation (auth mode).
    ///
    /// Encapsulates a fresh shared secret to `recipient_pk`, then appends a
    /// tag `SHA3-256(shared_secret || sender_pk || kem_ciphertext)` of
    /// `AUTH_TAG_LEN` bytes. Here `sender_pk` is derived from `sender_sk`.
    /// Returns `(kem_ciphertext || tag, shared_secret)`.
    ///
    /// # Errors
    /// - `invalid_input` if `sender_sk` or `recipient_pk` has the wrong length.
    /// - `CryptoError` if key derivation or encapsulation fails.
    ///
    /// # Security
    /// NOTE(review): every tag input is either public or produced by this
    /// encapsulation, and `sender_sk` only feeds `derive_public_key`. Any
    /// party that can encapsulate to `recipient_pk` can therefore build a
    /// valid tag for an arbitrary `sender_pk`. The tag does not prove
    /// possession of `sender_sk`, so do not rely on it for sender
    /// authentication without an additional mechanism (e.g. a signature).
    fn auth_encapsulate(
        &self,
        kem: HpkeKem,
        sender_sk: &[u8],
        recipient_pk: &[u8],
        _rng: &mut dyn CryptoRng,
    ) -> Result<(Vec<u8>, Zeroizing<Vec<u8>>), HpkeError> {
        Self::check_len("sender_sk", sender_sk, kem.secret_key_len())?;
        Self::check_len("recipient_pk", recipient_pk, kem.public_key_len())?;

        let sender_pk = self.derive_public_key(kem, sender_sk)?;
        let recipient_pk_obj = lib_q_core::KemPublicKey::new(recipient_pk.to_vec());
        let provider = Self::create_kem_provider()?;
        let algorithm = Self::hpke_kem_to_algorithm(kem)?;
        let (mut authenticated_enc, shared_secret) = provider
            .encapsulate(algorithm, &recipient_pk_obj, None)
            .map_err(|e| HpkeError::CryptoError(format!("AuthEncap failed: {}", e)))?;
        let shared_secret = Zeroizing::new(shared_secret);

        let auth_tag =
            self.create_auth_tag(shared_secret.as_slice(), &sender_pk, &authenticated_enc)?;
        // Wire format: KEM ciphertext followed by the tag.
        authenticated_enc.extend_from_slice(&auth_tag);
        Ok((authenticated_enc, shared_secret))
    }

    /// Authenticated decapsulation: the inverse of `auth_encapsulate`.
    ///
    /// Splits `encapsulated_key` into the KEM ciphertext and the trailing tag.
    /// It then decapsulates with `recipient_sk` and checks the tag against
    /// `sender_pk`. See `auth_encapsulate` for the limits of what this tag
    /// proves.
    ///
    /// # Errors
    /// - `invalid_input` on any length mismatch.
    /// - `CryptoError` for an all-zero `sender_pk`, a decapsulation failure,
    ///   or a tag mismatch.
    fn auth_decapsulate(
        &self,
        kem: HpkeKem,
        encapsulated_key: &[u8],
        recipient_sk: &[u8],
        sender_pk: &[u8],
    ) -> Result<Zeroizing<Vec<u8>>, HpkeError> {
        Self::check_len(
            "encapsulated_key",
            encapsulated_key,
            kem.enc_len() + Self::AUTH_TAG_LEN,
        )?;
        Self::check_len("recipient_sk", recipient_sk, kem.secret_key_len())?;
        Self::check_len("sender_pk", sender_pk, kem.public_key_len())?;
        if sender_pk.iter().all(|&b| b == 0) {
            return Err(HpkeError::CryptoError(
                "Invalid sender public key: all zeros".into(),
            ));
        }

        // The exact-length check above guarantees the tag occupies the final
        // AUTH_TAG_LEN bytes, so this subtraction cannot underflow.
        let (kem_ciphertext, auth_tag) =
            encapsulated_key.split_at(encapsulated_key.len() - Self::AUTH_TAG_LEN);

        let recipient_sk_obj = lib_q_core::KemSecretKey::new(recipient_sk.to_vec());
        let provider = Self::create_kem_provider()?;
        let algorithm = Self::hpke_kem_to_algorithm(kem)?;
        let shared_secret = provider
            .decapsulate(algorithm, &recipient_sk_obj, kem_ciphertext)
            .map_err(|e| HpkeError::CryptoError(format!("AuthDecap failed: {}", e)))?;
        let shared_secret = Zeroizing::new(shared_secret);

        self.verify_auth_tag(shared_secret.as_slice(), sender_pk, kem_ciphertext, auth_tag)?;
        Ok(shared_secret)
    }
}

impl PostQuantumProvider {
    /// Length in bytes of the tag appended by `auth_encapsulate`. This is the
    /// SHA3-256 output size.
    const AUTH_TAG_LEN: usize = 32;

    /// Returns an `invalid_input` error naming `field` unless `value` is
    /// exactly `expected` bytes long.
    fn check_len(field: &'static str, value: &[u8], expected: usize) -> Result<(), HpkeError> {
        if value.len() != expected {
            return Err(HpkeError::invalid_input(
                field,
                format!("{} bytes", value.len()),
                format!("{} bytes", expected),
            ));
        }
        Ok(())
    }

    /// Computes `SHA3-256(shared_secret || sender_pk || encapsulated_key)`.
    fn create_auth_tag(
        &self,
        shared_secret: &[u8],
        sender_pk: &[u8],
        encapsulated_key: &[u8],
    ) -> Result<Vec<u8>, HpkeError> {
        // The buffer holds a copy of the shared secret, so it is wiped on drop.
        let mut auth_input: Zeroizing<Vec<u8>> = Zeroizing::new(Vec::with_capacity(
            shared_secret.len() + sender_pk.len() + encapsulated_key.len(),
        ));
        auth_input.extend_from_slice(shared_secret);
        auth_input.extend_from_slice(sender_pk);
        auth_input.extend_from_slice(encapsulated_key);
        let auth_tag = lib_q_hash::Sha3_256::digest(auth_input.as_slice());
        Ok(auth_tag.to_vec())
    }

    /// Recomputes the expected tag and compares it with `auth_tag`.
    ///
    /// The comparison runs in constant time, so the check does not reveal
    /// how many leading bytes of a forged tag were correct.
    fn verify_auth_tag(
        &self,
        shared_secret: &[u8],
        sender_pk: &[u8],
        encapsulated_key: &[u8],
        auth_tag: &[u8],
    ) -> Result<(), HpkeError> {
        if auth_tag.is_empty() {
            return Err(HpkeError::CryptoError(
                "Invalid authentication tag: empty tag".into(),
            ));
        }
        if auth_tag.len() != Self::AUTH_TAG_LEN {
            return Err(HpkeError::CryptoError(
                "Invalid authentication tag: wrong length".into(),
            ));
        }
        let expected_auth_tag = self.create_auth_tag(shared_secret, sender_pk, encapsulated_key)?;
        if !Self::constant_time_eq(auth_tag, &expected_auth_tag) {
            return Err(HpkeError::CryptoError(
                "Authentication failed: invalid authentication tag".into(),
            ));
        }
        Ok(())
    }

    /// Byte-slice equality with no early exit on the first mismatching byte.
    ///
    /// The running time is intended to be independent of where the inputs
    /// differ.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        // Lengths are public (fixed tag size), so returning early here leaks
        // nothing secret.
        if a.len() != b.len() {
            return false;
        }
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        // black_box discourages the optimizer from rewriting the fold into a
        // short-circuiting comparison.
        core::hint::black_box(diff) == 0
    }
}
impl KdfProvider for PostQuantumProvider {
    /// Extract step: delegates to `HkdfImpl::extract` for the selected `kdf`.
    fn extract(&self, kdf: HpkeKdf, salt: &[u8], ikm: &[u8]) -> Result<Vec<u8>, HpkeError> {
        HkdfImpl::new(kdf).extract(salt, ikm)
    }

    /// Expand step: delegates to `HkdfImpl::expand`, which produces
    /// `output_len` bytes from `prk` and `info`.
    fn expand(
        &self,
        kdf: HpkeKdf,
        prk: &[u8],
        info: &[u8],
        output_len: usize,
    ) -> Result<Vec<u8>, HpkeError> {
        HkdfImpl::new(kdf).expand(prk, info, output_len)
    }

    /// A KDF is supported exactly when its underlying `lib_q_hash` function
    /// can be instantiated. `create_hash_instance` matches every `HpkeKdf`
    /// variant exhaustively.
    fn supports_kdf(&self, kdf: HpkeKdf) -> bool {
        Self::create_hash_instance(kdf).is_ok()
    }
}
impl AeadProvider for PostQuantumProvider {
    /// Encrypts `plaintext` under `key`/`nonce`, authenticating `aad`.
    ///
    /// The key and nonce are validated first. The export-only mode always
    /// fails here, because it carries no payload cipher.
    fn seal(
        &self,
        aead: HpkeAead,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        plaintext: &[u8],
    ) -> Result<Vec<u8>, HpkeError> {
        <Self as AeadProvider>::validate_key(self, aead, key)?;
        self.validate_nonce(aead, nonce)?;
        if matches!(aead, HpkeAead::Export) {
            return Err(HpkeError::aead_error(
                HpkeAead::Export,
                AeadOperation::Seal,
                "Export-only AEAD (RFC 9180): no payload encryption; use HPKE export()",
            ));
        }
        let cipher = Self::create_aead_instance(aead)?;
        let cipher_key = AeadKey::new(key.to_vec());
        let cipher_nonce = Nonce::new(nonce.to_vec());
        let sealed = cipher.encrypt(&cipher_key, &cipher_nonce, plaintext, Some(aad));
        sealed.map_err(|err| HpkeError::CryptoError(format!("AEAD encryption failed: {}", err)))
    }

    /// Decrypts and authenticates `ciphertext` (with `aad`) under `key`/`nonce`.
    ///
    /// The key and nonce are validated first. The export-only mode always
    /// fails here.
    fn open(
        &self,
        aead: HpkeAead,
        key: &[u8],
        nonce: &[u8],
        aad: &[u8],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, HpkeError> {
        <Self as AeadProvider>::validate_key(self, aead, key)?;
        self.validate_nonce(aead, nonce)?;
        if matches!(aead, HpkeAead::Export) {
            return Err(HpkeError::aead_error(
                HpkeAead::Export,
                AeadOperation::Open,
                "Export-only AEAD (RFC 9180): no payload decryption; use HPKE export()",
            ));
        }
        let cipher = Self::create_aead_instance(aead)?;
        let cipher_key = AeadKey::new(key.to_vec());
        let cipher_nonce = Nonce::new(nonce.to_vec());
        let opened = cipher.decrypt(&cipher_key, &cipher_nonce, ciphertext, Some(aad));
        opened.map_err(|err| HpkeError::CryptoError(format!("AEAD decryption failed: {}", err)))
    }

    /// Requires `key` to be exactly `aead.key_len()` bytes.
    ///
    /// A non-empty key must also contain at least one non-zero byte.
    fn validate_key(&self, aead: HpkeAead, key: &[u8]) -> Result<(), HpkeError> {
        let required = aead.key_len();
        if key.len() != required {
            return Err(HpkeError::invalid_input(
                "key",
                format!("{} bytes", key.len()),
                format!("{} bytes", required),
            ));
        }
        // Empty keys (e.g. a mode whose key_len is 0) skip the zero check.
        let has_nonzero_byte = key.iter().any(|&byte| byte != 0);
        if !key.is_empty() && !has_nonzero_byte {
            return Err(HpkeError::CryptoError(
                "Key material cannot be all zeros".to_string(),
            ));
        }
        Ok(())
    }

    /// Requires `nonce` to be exactly `aead.nonce_len()` bytes.
    fn validate_nonce(&self, aead: HpkeAead, nonce: &[u8]) -> Result<(), HpkeError> {
        let required = aead.nonce_len();
        if nonce.len() == required {
            return Ok(());
        }
        Err(HpkeError::invalid_input(
            "nonce",
            format!("{} bytes", nonce.len()),
            format!("{} bytes", required),
        ))
    }

    /// Export-only mode is always available. Every other AEAD is supported
    /// when its cipher can be instantiated.
    fn supports_aead(&self, aead: HpkeAead) -> bool {
        matches!(aead, HpkeAead::Export) || Self::create_aead_instance(aead).is_ok()
    }
}
impl HpkeCryptoProvider for PostQuantumProvider {
fn name(&self) -> &'static str {
"PostQuantumProvider"
}
fn supported_algorithms(&self) -> SupportedAlgorithms {
let mut kems = Vec::new();
let mut kdfs = Vec::new();
let mut aeads = Vec::new();
for kem in [HpkeKem::MlKem512, HpkeKem::MlKem768, HpkeKem::MlKem1024] {
if self.supports_kem(kem) {
kems.push(kem);
}
}
for kdf in [
HpkeKdf::HkdfShake128,
HpkeKdf::HkdfShake256,
HpkeKdf::HkdfSha3_256,
HpkeKdf::HkdfSha3_512,
] {
if self.supports_kdf(kdf) {
kdfs.push(kdf);
}
}
for aead in [
HpkeAead::Saturnin256,
HpkeAead::Shake256,
HpkeAead::DuplexSpongeAead,
HpkeAead::Export,
] {
if self.supports_aead(aead) {
aeads.push(aead);
}
}
SupportedAlgorithms::new(kems, kdfs, aeads)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The provider reports its fixed name.
    #[test]
    fn test_provider_creation() {
        let reported = PostQuantumProvider::new().name();
        assert_eq!(reported, "PostQuantumProvider");
    }

    /// At least one algorithm family must be non-empty.
    #[test]
    fn test_supported_algorithms() {
        let algorithms = PostQuantumProvider::new().supported_algorithms();
        let nothing_supported = algorithms.kems.is_empty() &&
            algorithms.kdfs.is_empty() &&
            algorithms.aeads.is_empty();
        assert!(!nothing_supported);
    }

    /// All ML-KEM parameter sets share one feature gate, so they must agree.
    #[test]
    fn test_kem_support() {
        let provider = PostQuantumProvider::new();
        let flags = [HpkeKem::MlKem512, HpkeKem::MlKem768, HpkeKem::MlKem1024]
            .map(|kem| provider.supports_kem(kem));
        for pair in flags.windows(2) {
            assert_eq!(pair[0], pair[1]);
        }
    }

    /// All KDFs are backed by lib_q_hash and are expected to agree.
    #[test]
    fn test_kdf_support() {
        let provider = PostQuantumProvider::new();
        let flags = [
            HpkeKdf::HkdfShake128,
            HpkeKdf::HkdfShake256,
            HpkeKdf::HkdfSha3_256,
            HpkeKdf::HkdfSha3_512,
        ]
        .map(|kdf| provider.supports_kdf(kdf));
        for pair in flags.windows(2) {
            assert_eq!(pair[0], pair[1]);
        }
    }

    /// AEAD availability tracks the crate features.
    #[test]
    fn test_aead_support() {
        let provider = PostQuantumProvider::new();
        let [saturnin_ok, shake256_ok, duplex_ok, export_ok] = [
            HpkeAead::Saturnin256,
            HpkeAead::Shake256,
            HpkeAead::DuplexSpongeAead,
            HpkeAead::Export,
        ]
        .map(|aead| provider.supports_aead(aead));

        // Export-only mode needs no cipher and is always reported as supported.
        assert!(export_ok);

        #[cfg(feature = "saturnin")]
        assert!(saturnin_ok);
        #[cfg(not(feature = "saturnin"))]
        assert!(!saturnin_ok);

        assert!(
            shake256_ok,
            "SHAKE256 AEAD should be supported after migration to lib-q-aead"
        );

        #[cfg(feature = "duplex-sponge-aead")]
        assert!(
            duplex_ok,
            "Duplex-sponge AEAD should be supported when feature is enabled"
        );
        #[cfg(not(feature = "duplex-sponge-aead"))]
        assert!(!duplex_ok);
    }
}