#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use crate::log_crypto_operation_error;
use crate::primitives::aead::aes_gcm::AesGcm256;
use crate::primitives::aead::{AeadCipher, TAG_LEN};
use crate::primitives::kdf::hkdf::hkdf;
use crate::primitives::kem::ml_kem::{
MlKem, MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, MlKemSecurityLevel,
};
use crate::unified_api::logging::op;
use subtle::ConstantTimeEq;
use thiserror::Error;
use zeroize::Zeroizing;
/// Errors returned by the PQ-only (ML-KEM + HKDF + AES-256-GCM) API in this module.
///
/// Every variant carries a human-readable message string.
#[non_exhaustive]
#[derive(Debug, Clone, Error)]
pub enum PqOnlyError {
/// ML-KEM encapsulation failed while encrypting.
#[error("KEM error: {0}")]
KemError(String),
/// AES-256-GCM initialisation or sealing failed while encrypting.
#[error("Encryption error: {0}")]
EncryptionError(String),
/// Decryption failed. The decryption path in this module reports every
/// failure through this variant with one fixed message.
#[error("Decryption error: {0}")]
DecryptionError(String),
/// Building the HKDF `info` value or running HKDF failed.
#[error("Key derivation error: {0}")]
KdfError(String),
/// Caller-supplied key material was rejected during validation.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// ML-KEM key-pair generation failed.
#[error("Key generation error: {0}")]
KeyGenError(String),
}
/// Public (encapsulation) key for PQ-only encryption.
///
/// Wraps an ML-KEM public key together with the security level it was
/// constructed for.
#[derive(Clone)]
pub struct PqOnlyPublicKey {
/// ML-KEM public key handed to `MlKem::encapsulate` during encryption.
ml_kem_pk: MlKemPublicKey,
/// ML-KEM parameter set recorded when the key was built.
security_level: MlKemSecurityLevel,
}
impl std::fmt::Debug for PqOnlyPublicKey {
    /// Prints the security level and the public-key length (`pk_len`); the
    /// key bytes themselves are not written out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pk_len = self.ml_kem_pk.as_bytes().len();
        let mut out = f.debug_struct("PqOnlyPublicKey");
        out.field("security_level", &self.security_level);
        out.field("pk_len", &pk_len);
        out.finish()
    }
}
impl PqOnlyPublicKey {
    /// Builds a public key from raw ML-KEM public-key bytes.
    ///
    /// # Errors
    ///
    /// Returns [`PqOnlyError::InvalidInput`] when `MlKemPublicKey::new`
    /// rejects `pk_bytes` for the given `level`.
    pub fn from_bytes(level: MlKemSecurityLevel, pk_bytes: &[u8]) -> Result<Self, PqOnlyError> {
        match MlKemPublicKey::new(level, pk_bytes.to_vec()) {
            Ok(ml_kem_pk) => Ok(Self { ml_kem_pk, security_level: level }),
            Err(e) => Err(PqOnlyError::InvalidInput(format!("Invalid ML-KEM public key: {e}"))),
        }
    }

    /// Returns the ML-KEM security level this key was built for.
    #[must_use]
    pub fn security_level(&self) -> MlKemSecurityLevel {
        let Self { security_level, .. } = self;
        *security_level
    }

    /// Returns the raw ML-KEM public-key bytes.
    #[must_use]
    pub fn ml_kem_pk_bytes(&self) -> &[u8] {
        let Self { ml_kem_pk, .. } = self;
        ml_kem_pk.as_bytes()
    }

    /// Borrows the wrapped ML-KEM public key.
    #[must_use]
    pub fn ml_kem_pk(&self) -> &MlKemPublicKey {
        let Self { ml_kem_pk, .. } = self;
        ml_kem_pk
    }
}
/// Secret (decapsulation) key for PQ-only encryption.
///
/// Holds the ML-KEM secret key, a copy of the matching public-key bytes, and
/// the security level. Its `Debug` output redacts the secret.
pub struct PqOnlySecretKey {
/// ML-KEM secret key handed to `MlKem::decapsulate` during decryption.
ml_kem_sk: MlKemSecretKey,
/// Recipient public-key bytes, fed into the HKDF `info` value on decryption.
ml_kem_pk_bytes: Vec<u8>,
/// ML-KEM parameter set recorded when the key was built.
security_level: MlKemSecurityLevel,
}
impl std::fmt::Debug for PqOnlySecretKey {
    /// Prints the security level and a fixed `[REDACTED]` placeholder in
    /// place of the secret key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PqOnlySecretKey");
        out.field("security_level", &self.security_level);
        out.field("sk", &"[REDACTED]");
        out.finish()
    }
}
impl ConstantTimeEq for PqOnlySecretKey {
    /// Compares two secret keys by delegating to `ct_eq` on the wrapped
    /// ML-KEM secret keys.
    ///
    /// Only the secret-key field takes part; the cached public-key bytes and
    /// the security level are not consulted here.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let Self { ml_kem_sk: theirs, .. } = other;
        self.ml_kem_sk.ct_eq(theirs)
    }
}
impl PqOnlySecretKey {
    /// Builds a secret key from raw ML-KEM secret-key bytes plus the matching
    /// public-key bytes.
    ///
    /// The supplied public key is validated twice: its length must equal
    /// `level.public_key_size()`, and its bytes must equal the public key
    /// embedded in the secret key.
    ///
    /// # Errors
    ///
    /// Returns [`PqOnlyError::InvalidInput`] when the secret key is rejected
    /// by `MlKemSecretKey::new`, when `pk_bytes` has the wrong length, when
    /// the embedded public key cannot be read from the secret key, or when
    /// the embedded and supplied public keys differ.
    pub fn from_bytes(
        level: MlKemSecurityLevel,
        sk_bytes: &[u8],
        pk_bytes: &[u8],
    ) -> Result<Self, PqOnlyError> {
        let ml_kem_sk = MlKemSecretKey::new(level, sk_bytes.to_vec())
            .map_err(|e| PqOnlyError::InvalidInput(format!("Invalid ML-KEM secret key: {e}")))?;

        let expected_pk_len = level.public_key_size();
        if pk_bytes.len() != expected_pk_len {
            return Err(PqOnlyError::InvalidInput(format!(
                "ML-KEM public key length {} does not match expected {} for {}",
                pk_bytes.len(),
                expected_pk_len,
                level.name()
            )));
        }

        let embedded = ml_kem_sk
            .embedded_public_key_bytes()
            .map_err(|e| PqOnlyError::InvalidInput(format!("ML-KEM SK does not embed PK: {e}")))?;

        // `ConstantTimeEq` comes from the file-level import; `bool::from` is
        // the conversion `subtle` provides for turning a `Choice` into a branch.
        if !bool::from(embedded.ct_eq(pk_bytes)) {
            return Err(PqOnlyError::InvalidInput(
                "supplied ML-KEM PK does not match SK-embedded PK \
                 (SK is authoritative; mismatched metadata is rejected)"
                    .to_string(),
            ));
        }

        Ok(Self { ml_kem_sk, ml_kem_pk_bytes: pk_bytes.to_vec(), security_level: level })
    }

    /// Builds a secret key from raw ML-KEM secret-key bytes alone, taking the
    /// public-key bytes from the copy embedded in the secret key.
    ///
    /// # Errors
    ///
    /// Returns [`PqOnlyError::InvalidInput`] when the secret key is rejected
    /// by `MlKemSecretKey::new` or when the embedded public key cannot be
    /// read from it.
    pub fn from_sk_bytes(level: MlKemSecurityLevel, sk_bytes: &[u8]) -> Result<Self, PqOnlyError> {
        let ml_kem_sk = MlKemSecretKey::new(level, sk_bytes.to_vec())
            .map_err(|e| PqOnlyError::InvalidInput(format!("Invalid ML-KEM secret key: {e}")))?;
        let pk_bytes = ml_kem_sk
            .embedded_public_key_bytes()
            .map_err(|e| PqOnlyError::InvalidInput(format!("ML-KEM SK does not embed PK: {e}")))?
            .to_vec();
        Ok(Self { ml_kem_sk, ml_kem_pk_bytes: pk_bytes, security_level: level })
    }

    /// Returns the public-key bytes stored alongside this secret key.
    #[must_use]
    pub fn recipient_pk_bytes(&self) -> &[u8] {
        &self.ml_kem_pk_bytes
    }

    /// Returns the ML-KEM security level this key was built for.
    #[must_use]
    pub fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }

    /// Returns the raw secret-key bytes as exposed by the wrapped ML-KEM key.
    ///
    /// The returned slice is secret material; avoid copying or logging it.
    #[must_use]
    pub fn expose_secret(&self) -> &[u8] {
        self.ml_kem_sk.expose_secret()
    }

    /// Borrows the wrapped ML-KEM secret key.
    #[must_use]
    pub fn ml_kem_sk(&self) -> &MlKemSecretKey {
        &self.ml_kem_sk
    }
}
/// Builds the HKDF `info` value for PQ-only encryption from the recipient
/// public key and the KEM ciphertext, using the `PqOnlyEncryption` label.
///
/// Any error from the underlying helper is converted to
/// [`PqOnlyError::KdfError`] carrying that error's display text.
fn pq_only_encryption_info(
    recipient_pk: &[u8],
    kem_ciphertext: &[u8],
) -> Result<Vec<u8>, PqOnlyError> {
    use crate::types::domains::{hkdf_kem_info_with_pk, HkdfKemLabel};

    hkdf_kem_info_with_pk(HkdfKemLabel::PqOnlyEncryption, recipient_pk, kem_ciphertext)
        .map_err(|err| PqOnlyError::KdfError(err.to_string()))
}
/// Generates a PQ-only key pair at the default level, ML-KEM-768.
///
/// # Errors
///
/// Propagates any error from [`generate_pq_keypair_with_level`].
#[must_use = "generated keypair must be stored or used"]
pub fn generate_pq_keypair() -> Result<(PqOnlyPublicKey, PqOnlySecretKey), PqOnlyError> {
    let default_level = MlKemSecurityLevel::MlKem768;
    generate_pq_keypair_with_level(default_level)
}
/// Generates a PQ-only key pair at the requested ML-KEM security level.
///
/// The secret key keeps a copy of the freshly generated public-key bytes.
///
/// # Errors
///
/// Returns [`PqOnlyError::KeyGenError`] when ML-KEM key generation fails.
#[must_use = "generated keypair must be stored or used"]
pub fn generate_pq_keypair_with_level(
    level: MlKemSecurityLevel,
) -> Result<(PqOnlyPublicKey, PqOnlySecretKey), PqOnlyError> {
    let (ml_kem_pk, ml_kem_sk) = match MlKem::generate_keypair(level) {
        Ok(pair) => pair,
        Err(e) => return Err(PqOnlyError::KeyGenError(format!("ML-KEM keygen failed: {e}"))),
    };

    // Copy the public-key bytes before the key itself moves into the wrapper.
    let ml_kem_pk_bytes = ml_kem_pk.as_bytes().to_vec();
    let public = PqOnlyPublicKey { ml_kem_pk, security_level: level };
    let secret = PqOnlySecretKey { ml_kem_sk, ml_kem_pk_bytes, security_level: level };
    Ok((public, secret))
}
/// Output of PQ-only encryption: the ML-KEM ciphertext plus the AES-256-GCM
/// ciphertext, nonce and authentication tag.
///
/// All four parts are needed to decrypt. None of them is secret, so the type
/// derives `Debug` and `Clone` for diagnostics and convenient reuse.
#[derive(Debug, Clone)]
pub struct PqOnlyCiphertext {
    /// ML-KEM ciphertext encapsulating the shared secret.
    ml_kem_ciphertext: Vec<u8>,
    /// AES-256-GCM ciphertext of the plaintext.
    symmetric_ciphertext: Vec<u8>,
    /// AES-GCM nonce generated for this message.
    nonce: [u8; 12],
    /// AES-GCM authentication tag.
    tag: [u8; TAG_LEN],
}
impl PqOnlyCiphertext {
    /// Returns the ML-KEM ciphertext bytes.
    #[must_use]
    pub fn ml_kem_ciphertext(&self) -> &[u8] {
        self.ml_kem_ciphertext.as_slice()
    }

    /// Returns the AES-256-GCM ciphertext bytes.
    #[must_use]
    pub fn symmetric_ciphertext(&self) -> &[u8] {
        self.symmetric_ciphertext.as_slice()
    }

    /// Returns the AES-GCM nonce.
    #[must_use]
    pub fn nonce(&self) -> &[u8; 12] {
        let Self { nonce, .. } = self;
        nonce
    }

    /// Returns the AES-GCM authentication tag.
    #[must_use]
    pub fn tag(&self) -> &[u8; TAG_LEN] {
        let Self { tag, .. } = self;
        tag
    }

    /// Consumes the value and returns its parts in the order
    /// `(ml_kem_ciphertext, symmetric_ciphertext, nonce, tag)`.
    #[must_use]
    pub fn into_parts(self) -> (Vec<u8>, Vec<u8>, [u8; 12], [u8; TAG_LEN]) {
        let Self { ml_kem_ciphertext, symmetric_ciphertext, nonce, tag } = self;
        (ml_kem_ciphertext, symmetric_ciphertext, nonce, tag)
    }
}
/// Encrypts `plaintext` to `pk` with empty associated data.
///
/// # Errors
///
/// Propagates any error from [`encrypt_pq_only_with_aad`].
pub fn encrypt_pq_only(
    pk: &PqOnlyPublicKey,
    plaintext: &[u8],
) -> Result<PqOnlyCiphertext, PqOnlyError> {
    let no_aad: &[u8] = &[];
    encrypt_pq_only_with_aad(pk, plaintext, no_aad)
}
/// Encrypts `plaintext` to `pk`, binding `aad` as associated data.
///
/// Steps: ML-KEM encapsulation to `pk`; HKDF over the shared secret with an
/// `info` value built from the recipient public key and the KEM ciphertext,
/// producing a 32-byte key; AES-256-GCM encryption under a freshly generated
/// nonce.
///
/// # Errors
///
/// * [`PqOnlyError::KemError`] when encapsulation fails.
/// * [`PqOnlyError::KdfError`] when building the `info` value or HKDF fails.
/// * [`PqOnlyError::EncryptionError`] when AES-256-GCM setup or sealing fails.
pub fn encrypt_pq_only_with_aad(
    pk: &PqOnlyPublicKey,
    plaintext: &[u8],
    aad: &[u8],
) -> Result<PqOnlyCiphertext, PqOnlyError> {
    let Ok((shared_secret, kem_ct)) = MlKem::encapsulate(pk.ml_kem_pk()) else {
        log_crypto_operation_error!(op::PQ_ONLY_ENCRYPT, "ML-KEM encapsulation failed");
        return Err(PqOnlyError::KemError("encapsulation failed".to_string()));
    };

    let Ok(info) = pq_only_encryption_info(pk.ml_kem_pk_bytes(), kem_ct.as_bytes()) else {
        return Err(PqOnlyError::KdfError("KDF info construction failed".to_string()));
    };

    let Ok(hkdf_result) = hkdf(shared_secret.expose_secret(), None, Some(&info), 32) else {
        log_crypto_operation_error!(op::PQ_ONLY_ENCRYPT, "HKDF failed");
        return Err(PqOnlyError::KdfError("KDF failed".to_string()));
    };

    let Ok(cipher) = AesGcm256::new(hkdf_result.expose_secret()) else {
        log_crypto_operation_error!(op::PQ_ONLY_ENCRYPT, "AES-256 init failed");
        return Err(PqOnlyError::EncryptionError("encryption failed".to_string()));
    };

    let nonce = AesGcm256::generate_nonce();
    let Ok((symmetric_ciphertext, tag)) = cipher.encrypt(&nonce, plaintext, Some(aad)) else {
        log_crypto_operation_error!(op::PQ_ONLY_ENCRYPT, "AES-GCM seal failed");
        return Err(PqOnlyError::EncryptionError("encryption failed".to_string()));
    };

    let ml_kem_ciphertext = kem_ct.into_bytes();
    Ok(PqOnlyCiphertext { ml_kem_ciphertext, symmetric_ciphertext, nonce, tag })
}
/// Decrypts a PQ-only ciphertext that was produced with empty associated data.
///
/// # Errors
///
/// Propagates any error from [`decrypt_pq_only_with_aad`].
pub fn decrypt_pq_only(
    sk: &PqOnlySecretKey,
    kem_ciphertext: &[u8],
    symmetric_ciphertext: &[u8],
    nonce: &[u8; 12],
    tag: &[u8; TAG_LEN],
) -> Result<Zeroizing<Vec<u8>>, PqOnlyError> {
    let no_aad: &[u8] = &[];
    decrypt_pq_only_with_aad(sk, kem_ciphertext, symmetric_ciphertext, nonce, tag, no_aad)
}
/// Decrypts a PQ-only ciphertext, verifying `aad` as associated data.
///
/// Steps: parse the KEM ciphertext at the secret key's security level;
/// ML-KEM decapsulation; HKDF over the shared secret with an `info` value
/// built from the stored recipient public key and the KEM ciphertext,
/// producing a 32-byte key; AES-256-GCM decryption and tag verification.
///
/// # Errors
///
/// Every failure is returned as [`PqOnlyError::DecryptionError`] with the
/// same fixed message, so the caller cannot tell which step failed. Most
/// steps also emit a log entry describing the cause.
pub fn decrypt_pq_only_with_aad(
    sk: &PqOnlySecretKey,
    kem_ciphertext: &[u8],
    symmetric_ciphertext: &[u8],
    nonce: &[u8; 12],
    tag: &[u8; TAG_LEN],
    aad: &[u8],
) -> Result<Zeroizing<Vec<u8>>, PqOnlyError> {
    // One indistinguishable error for every failure path.
    let uniform_failure = || PqOnlyError::DecryptionError("decryption failed".to_string());

    let Ok(ct) = MlKemCiphertext::new(sk.security_level(), kem_ciphertext.to_vec()) else {
        log_crypto_operation_error!(op::PQ_ONLY_DECRYPT, "invalid ML-KEM ciphertext");
        return Err(uniform_failure());
    };

    let Ok(shared_secret) = MlKem::decapsulate(sk.ml_kem_sk(), &ct) else {
        log_crypto_operation_error!(op::PQ_ONLY_DECRYPT, "ML-KEM decapsulation failed");
        return Err(uniform_failure());
    };

    let Ok(info) = pq_only_encryption_info(sk.recipient_pk_bytes(), kem_ciphertext) else {
        return Err(uniform_failure());
    };

    let Ok(hkdf_result) = hkdf(shared_secret.expose_secret(), None, Some(&info), 32) else {
        log_crypto_operation_error!(op::PQ_ONLY_DECRYPT, "HKDF failed");
        return Err(uniform_failure());
    };

    let Ok(cipher) = AesGcm256::new(hkdf_result.expose_secret()) else {
        log_crypto_operation_error!(op::PQ_ONLY_DECRYPT, "AES-256 init failed");
        return Err(uniform_failure());
    };

    let Ok(plaintext) = cipher.decrypt(nonce, symmetric_ciphertext, tag, Some(aad)) else {
        log_crypto_operation_error!(op::PQ_ONLY_DECRYPT, "AEAD authentication failed");
        return Err(uniform_failure());
    };
    Ok(plaintext)
}
#[cfg(test)]
#[expect(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test/bench scaffolding: lints suppressed for this module"
)]
mod tests {
    use super::*;

    /// Every ML-KEM parameter set exercised by the level-sweeping tests.
    const ALL_LEVELS: [MlKemSecurityLevel; 3] = [
        MlKemSecurityLevel::MlKem512,
        MlKemSecurityLevel::MlKem768,
        MlKemSecurityLevel::MlKem1024,
    ];

    /// Decrypts `ct` with `sk` using empty associated data.
    fn open(
        sk: &PqOnlySecretKey,
        ct: &PqOnlyCiphertext,
    ) -> Result<Zeroizing<Vec<u8>>, PqOnlyError> {
        decrypt_pq_only(
            sk,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            ct.tag(),
        )
    }

    /// Decrypts `ct` with `sk`, verifying `aad` as associated data.
    fn open_with_aad(
        sk: &PqOnlySecretKey,
        ct: &PqOnlyCiphertext,
        aad: &[u8],
    ) -> Result<Zeroizing<Vec<u8>>, PqOnlyError> {
        decrypt_pq_only_with_aad(
            sk,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            ct.tag(),
            aad,
        )
    }

    #[test]
    fn test_generate_pq_keypair_default_succeeds() {
        let (pk, sk) = generate_pq_keypair().expect("keygen should succeed");
        assert_eq!(pk.security_level(), MlKemSecurityLevel::MlKem768);
        assert_eq!(sk.security_level(), MlKemSecurityLevel::MlKem768);
    }

    #[test]
    fn test_generate_pq_keypair_all_levels_succeeds() {
        for level in ALL_LEVELS {
            let (pk, sk) = generate_pq_keypair_with_level(level).expect("keygen should succeed");
            assert_eq!(pk.security_level(), level);
            assert_eq!(sk.security_level(), level);
        }
    }

    #[test]
    fn test_encrypt_decrypt_pq_only_roundtrip_768() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let plaintext = b"PQ-only roundtrip test data";
        let ct = encrypt_pq_only(&pk, plaintext).expect("encrypt should succeed");
        let decrypted = open(&sk, &ct).expect("decrypt should succeed");
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_pq_only_all_levels_roundtrip() {
        for level in ALL_LEVELS {
            let (pk, sk) = generate_pq_keypair_with_level(level).unwrap();
            let plaintext = b"Test all security levels";
            let ct = encrypt_pq_only(&pk, plaintext).expect("encrypt should succeed");
            let decrypted = open(&sk, &ct).expect("decrypt should succeed");
            assert_eq!(decrypted.as_slice(), plaintext.as_slice());
        }
    }

    #[test]
    fn test_encrypt_pq_only_empty_data_succeeds() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"").expect("empty data should encrypt");
        let decrypted = open(&sk, &ct).expect("empty data should decrypt");
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_decrypt_pq_only_with_aad_roundtrip() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let plaintext = b"payload bound to a header";
        let aad = b"header-v1";
        let ct = encrypt_pq_only_with_aad(&pk, plaintext, aad).expect("encrypt should succeed");
        let decrypted = open_with_aad(&sk, &ct, aad).expect("decrypt should succeed");
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
    }

    #[test]
    fn test_decrypt_pq_only_with_aad_mismatch_fails() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only_with_aad(&pk, b"secret", b"header-v1").unwrap();
        let result = open_with_aad(&sk, &ct, b"header-v2");
        assert!(
            matches!(result, Err(PqOnlyError::DecryptionError(_))),
            "Mismatched AAD should fail authentication"
        );
    }

    #[test]
    fn test_decrypt_pq_only_tampered_tag_fails() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"secret").unwrap();
        // Flip every bit of the tag so it can no longer authenticate.
        let mut bad_tag = *ct.tag();
        bad_tag.iter_mut().for_each(|byte| *byte ^= 0xFF);
        let result = decrypt_pq_only(
            &sk,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            &bad_tag,
        );
        assert!(
            matches!(result, Err(PqOnlyError::DecryptionError(_))),
            "Tampered tag should fail authentication"
        );
    }

    #[test]
    fn test_decrypt_pq_only_wrong_key_fails() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let (_pk2, sk2) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"secret").unwrap();
        let result = open(&sk2, &ct);
        assert!(result.is_err(), "Wrong key should fail");
    }

    #[test]
    fn test_encrypt_pq_only_different_ciphertexts() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let ct1 = encrypt_pq_only(&pk, b"same data").unwrap();
        let ct2 = encrypt_pq_only(&pk, b"same data").unwrap();
        assert_ne!(
            ct1.ml_kem_ciphertext(),
            ct2.ml_kem_ciphertext(),
            "Random KEM should produce different ciphertexts"
        );
    }

    #[test]
    fn test_pq_only_public_key_debug_no_leak() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let debug = format!("{pk:?}");
        assert!(debug.contains("PqOnlyPublicKey"));
        assert!(debug.contains("MlKem768"));
        assert!(debug.contains("pk_len"));
    }

    #[test]
    fn test_pq_only_secret_key_debug_redacted() {
        let (_pk, sk) = generate_pq_keypair().unwrap();
        let debug = format!("{sk:?}");
        assert!(debug.contains("PqOnlySecretKey"));
        assert!(debug.contains("REDACTED"));
    }

    #[test]
    fn test_pq_only_public_key_from_bytes_wrong_length_fails() {
        let result = PqOnlyPublicKey::from_bytes(MlKemSecurityLevel::MlKem768, &[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_pq_only_secret_key_from_bytes_wrong_length_fails() {
        let pk_len = MlKemSecurityLevel::MlKem768.public_key_size();
        let pk_bytes = vec![0u8; pk_len];
        let result =
            PqOnlySecretKey::from_bytes(MlKemSecurityLevel::MlKem768, &[0u8; 10], &pk_bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_pq_only_secret_key_from_bytes_wrong_pk_length_fails() {
        let (_pk, sk) = generate_pq_keypair().unwrap();
        let sk_bytes = sk.expose_secret().to_vec();
        let result = PqOnlySecretKey::from_bytes(sk.security_level(), &sk_bytes, &[0u8; 10]);
        assert!(matches!(result, Err(PqOnlyError::InvalidInput(_))));
    }

    #[test]
    fn test_pq_only_secret_key_from_bytes_mismatched_pk_fails() {
        // A correctly sized public key from a different key pair must be rejected.
        let (_pk1, sk1) = generate_pq_keypair().unwrap();
        let (pk2, _sk2) = generate_pq_keypair().unwrap();
        let result = PqOnlySecretKey::from_bytes(
            sk1.security_level(),
            sk1.expose_secret(),
            pk2.ml_kem_pk_bytes(),
        );
        assert!(matches!(result, Err(PqOnlyError::InvalidInput(_))));
    }

    #[test]
    fn test_pq_only_public_key_from_bytes_roundtrip() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let bytes = pk.ml_kem_pk_bytes().to_vec();
        let pk2 = PqOnlyPublicKey::from_bytes(pk.security_level(), &bytes).unwrap();
        assert_eq!(pk2.security_level(), pk.security_level());
        assert_eq!(pk2.ml_kem_pk_bytes(), pk.ml_kem_pk_bytes());
    }
}