#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use crate::primitives::aead::aes_gcm::AesGcm256;
use crate::primitives::aead::{AeadCipher, TAG_LEN};
use crate::primitives::kdf::hkdf::hkdf;
use crate::primitives::kem::ml_kem::{
MlKem, MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, MlKemSecurityLevel,
};
use subtle::ConstantTimeEq;
use thiserror::Error;
use zeroize::Zeroizing;
/// Errors returned by the PQ-only (ML-KEM + HKDF + AES-256-GCM) encryption API in this module.
///
/// Every variant carries a human-readable description of what failed.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PqOnlyError {
    /// ML-KEM encapsulation or decapsulation failed, or a KEM ciphertext was rejected as
    /// malformed.
    #[error("KEM error: {0}")]
    KemError(String),
    /// AES-256-GCM initialization or encryption failed while encrypting.
    #[error("Encryption error: {0}")]
    EncryptionError(String),
    /// AES-256-GCM initialization failed while decrypting, or the ciphertext failed
    /// authentication. Authentication failures use a deliberately generic message.
    #[error("Decryption error: {0}")]
    DecryptionError(String),
    /// HKDF failed to derive the symmetric key from the ML-KEM shared secret.
    #[error("Key derivation error: {0}")]
    KdfError(String),
    /// Caller-supplied key bytes were rejected, e.g. because their length does not match the
    /// requested security level.
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// ML-KEM key pair generation failed.
    #[error("Key generation error: {0}")]
    KeyGenError(String),
}
/// Public key for PQ-only encryption: an ML-KEM public key tagged with its security level.
///
/// Obtain one from [`generate_pq_keypair`], [`generate_pq_keypair_with_level`] or
/// [`PqOnlyPublicKey::from_bytes`], then pass it to [`encrypt_pq_only`].
#[derive(Clone)]
pub struct PqOnlyPublicKey {
    /// Underlying ML-KEM encapsulation key.
    ml_kem_pk: MlKemPublicKey,
    /// ML-KEM parameter set the key was created for.
    security_level: MlKemSecurityLevel,
}
impl std::fmt::Debug for PqOnlyPublicKey {
    /// Prints the security level and the encoded key length, not the key bytes themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pk_len = self.ml_kem_pk.as_bytes().len();
        let mut out = f.debug_struct("PqOnlyPublicKey");
        out.field("security_level", &self.security_level);
        out.field("pk_len", &pk_len);
        out.finish()
    }
}
impl PqOnlyPublicKey {
    /// Rebuilds a public key from its raw ML-KEM encoding.
    ///
    /// `pk_bytes` is copied; `level` must be the parameter set the bytes were produced for.
    ///
    /// # Errors
    ///
    /// Returns [`PqOnlyError::InvalidInput`] if `MlKemPublicKey::new` rejects the bytes for
    /// `level` (for example, because their length does not match).
    pub fn from_bytes(level: MlKemSecurityLevel, pk_bytes: &[u8]) -> Result<Self, PqOnlyError> {
        match MlKemPublicKey::new(level, pk_bytes.to_vec()) {
            Ok(ml_kem_pk) => Ok(Self { ml_kem_pk, security_level: level }),
            Err(e) => Err(PqOnlyError::InvalidInput(format!("Invalid ML-KEM public key: {e}"))),
        }
    }

    /// Returns the ML-KEM parameter set this key belongs to.
    #[must_use]
    pub fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }

    /// Returns the raw ML-KEM public key bytes, e.g. for serialization.
    #[must_use]
    pub fn ml_kem_pk_bytes(&self) -> &[u8] {
        self.ml_kem_pk().as_bytes()
    }

    /// Borrows the underlying ML-KEM public key.
    #[must_use]
    pub fn ml_kem_pk(&self) -> &MlKemPublicKey {
        &self.ml_kem_pk
    }
}
/// Secret key for PQ-only decryption: an ML-KEM secret key tagged with its security level.
///
/// This type does not implement `Clone`, and its `Debug` output redacts the key material.
/// Compare keys with `ConstantTimeEq` rather than byte-wise equality.
pub struct PqOnlySecretKey {
    /// Underlying ML-KEM decapsulation key.
    ml_kem_sk: MlKemSecretKey,
    /// ML-KEM parameter set the key was created for.
    security_level: MlKemSecurityLevel,
}
impl std::fmt::Debug for PqOnlySecretKey {
    /// Prints only the security level; the secret key material is always shown as `[REDACTED]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redacted = "[REDACTED]";
        let mut out = f.debug_struct("PqOnlySecretKey");
        out.field("security_level", &self.security_level).field("sk", &redacted);
        out.finish()
    }
}
impl ConstantTimeEq for PqOnlySecretKey {
    /// Compares the wrapped ML-KEM secret keys by delegating to the inner key's `ct_eq`.
    ///
    /// `security_level` is not compared separately.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let (ours, theirs) = (&self.ml_kem_sk, &other.ml_kem_sk);
        ours.ct_eq(theirs)
    }
}
impl PqOnlySecretKey {
    /// Rebuilds a secret key from its raw ML-KEM encoding.
    ///
    /// The bytes are copied into a new buffer handed to `MlKemSecretKey::new`. The caller
    /// remains responsible for wiping its own copy of `sk_bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`PqOnlyError::InvalidInput`] if `MlKemSecretKey::new` rejects the bytes for
    /// `level` (for example, because their length does not match).
    pub fn from_bytes(level: MlKemSecurityLevel, sk_bytes: &[u8]) -> Result<Self, PqOnlyError> {
        let owned = sk_bytes.to_vec();
        match MlKemSecretKey::new(level, owned) {
            Ok(ml_kem_sk) => Ok(Self { ml_kem_sk, security_level: level }),
            Err(e) => Err(PqOnlyError::InvalidInput(format!("Invalid ML-KEM secret key: {e}"))),
        }
    }

    /// Returns the ML-KEM parameter set this key belongs to.
    #[must_use]
    pub fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }

    /// Returns the raw ML-KEM secret key bytes.
    ///
    /// This exposes secret key material: do not log it or copy it into buffers that are not
    /// wiped afterwards.
    #[must_use]
    pub fn ml_kem_sk_bytes(&self) -> &[u8] {
        self.ml_kem_sk().as_bytes()
    }

    /// Borrows the underlying ML-KEM secret key.
    #[must_use]
    pub fn ml_kem_sk(&self) -> &MlKemSecretKey {
        &self.ml_kem_sk
    }
}
/// Generates a fresh PQ-only key pair at the default security level, ML-KEM-768.
///
/// # Errors
///
/// Returns [`PqOnlyError::KeyGenError`] if ML-KEM key generation fails.
pub fn generate_pq_keypair() -> Result<(PqOnlyPublicKey, PqOnlySecretKey), PqOnlyError> {
    let default_level = MlKemSecurityLevel::MlKem768;
    generate_pq_keypair_with_level(default_level)
}
/// Generates a fresh PQ-only key pair for the given ML-KEM security level.
///
/// Both returned keys are tagged with `level`.
///
/// # Errors
///
/// Returns [`PqOnlyError::KeyGenError`] if ML-KEM key generation fails.
pub fn generate_pq_keypair_with_level(
    level: MlKemSecurityLevel,
) -> Result<(PqOnlyPublicKey, PqOnlySecretKey), PqOnlyError> {
    MlKem::generate_keypair(level)
        .map(|(ml_kem_pk, ml_kem_sk)| {
            let public = PqOnlyPublicKey { ml_kem_pk, security_level: level };
            let secret = PqOnlySecretKey { ml_kem_sk, security_level: level };
            (public, secret)
        })
        .map_err(|e| PqOnlyError::KeyGenError(format!("ML-KEM keygen failed: {e}")))
}
pub struct PqOnlyCiphertext {
ml_kem_ciphertext: Vec<u8>,
symmetric_ciphertext: Vec<u8>,
nonce: [u8; 12],
tag: [u8; TAG_LEN],
}
impl PqOnlyCiphertext {
    /// Borrows the ML-KEM encapsulation ciphertext.
    #[must_use]
    pub fn ml_kem_ciphertext(&self) -> &[u8] {
        self.ml_kem_ciphertext.as_slice()
    }

    /// Borrows the AES-256-GCM ciphertext (without the tag).
    #[must_use]
    pub fn symmetric_ciphertext(&self) -> &[u8] {
        self.symmetric_ciphertext.as_slice()
    }

    /// Borrows the 12-byte AES-GCM nonce.
    #[must_use]
    pub fn nonce(&self) -> &[u8; 12] {
        &self.nonce
    }

    /// Borrows the AES-GCM authentication tag.
    #[must_use]
    pub fn tag(&self) -> &[u8; TAG_LEN] {
        &self.tag
    }

    /// Consumes the ciphertext, returning `(ml_kem_ciphertext, symmetric_ciphertext, nonce, tag)`.
    #[must_use]
    pub fn into_parts(self) -> (Vec<u8>, Vec<u8>, [u8; 12], [u8; TAG_LEN]) {
        let Self { ml_kem_ciphertext, symmetric_ciphertext, nonce, tag } = self;
        (ml_kem_ciphertext, symmetric_ciphertext, nonce, tag)
    }
}
pub fn encrypt_pq_only(
pk: &PqOnlyPublicKey,
plaintext: &[u8],
) -> Result<PqOnlyCiphertext, PqOnlyError> {
let (shared_secret, kem_ct) = MlKem::encapsulate(pk.ml_kem_pk())
.map_err(|e| PqOnlyError::KemError(format!("ML-KEM encapsulation failed: {e}")))?;
let hkdf_result = hkdf(
shared_secret.as_bytes(),
None,
Some(crate::types::domains::PQ_ONLY_ENCRYPTION_INFO),
32,
)
.map_err(|e| PqOnlyError::KdfError(format!("HKDF failed: {e}")))?;
let cipher = AesGcm256::new(hkdf_result.key())
.map_err(|e| PqOnlyError::EncryptionError(format!("AES-GCM init failed: {e}")))?;
let nonce = AesGcm256::generate_nonce();
let (ciphertext, tag) = cipher
.encrypt(&nonce, plaintext, None)
.map_err(|e| PqOnlyError::EncryptionError(format!("AES-GCM encrypt failed: {e}")))?;
Ok(PqOnlyCiphertext {
ml_kem_ciphertext: kem_ct.into_bytes(),
symmetric_ciphertext: ciphertext,
nonce,
tag,
})
}
pub fn decrypt_pq_only(
sk: &PqOnlySecretKey,
kem_ciphertext: &[u8],
symmetric_ciphertext: &[u8],
nonce: &[u8; 12],
tag: &[u8; TAG_LEN],
) -> Result<Zeroizing<Vec<u8>>, PqOnlyError> {
let ct = MlKemCiphertext::new(sk.security_level(), kem_ciphertext.to_vec())
.map_err(|e| PqOnlyError::KemError(format!("Invalid ML-KEM ciphertext: {e}")))?;
let shared_secret = MlKem::decapsulate(sk.ml_kem_sk(), &ct)
.map_err(|e| PqOnlyError::KemError(format!("ML-KEM decapsulation failed: {e}")))?;
let hkdf_result = hkdf(
shared_secret.as_bytes(),
None,
Some(crate::types::domains::PQ_ONLY_ENCRYPTION_INFO),
32,
)
.map_err(|e| PqOnlyError::KdfError(format!("HKDF failed: {e}")))?;
let cipher = AesGcm256::new(hkdf_result.key())
.map_err(|e| PqOnlyError::DecryptionError(format!("AES-GCM init failed: {e}")))?;
cipher.decrypt(nonce, symmetric_ciphertext, tag, None).map_err(|_aead_err| {
PqOnlyError::DecryptionError("decryption failed".to_string())
})
}
#[cfg(test)]
// Tests panic on failure by design, so unwrap/expect/indexing are acceptable here.
#[allow(
    clippy::panic,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::indexing_slicing,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_pq_keypair_default_succeeds() {
        let (pk, sk) = generate_pq_keypair().expect("keygen should succeed");
        assert_eq!(pk.security_level(), MlKemSecurityLevel::MlKem768);
        assert_eq!(sk.security_level(), MlKemSecurityLevel::MlKem768);
    }

    #[test]
    fn test_generate_pq_keypair_all_levels_succeeds() {
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let (pk, sk) = generate_pq_keypair_with_level(level).expect("keygen should succeed");
            assert_eq!(pk.security_level(), level);
            assert_eq!(sk.security_level(), level);
        }
    }

    #[test]
    fn test_encrypt_decrypt_pq_only_roundtrip_768() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let plaintext = b"PQ-only roundtrip test data";
        let ct = encrypt_pq_only(&pk, plaintext).expect("encrypt should succeed");
        let decrypted = decrypt_pq_only(
            &sk,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            ct.tag(),
        )
        .expect("decrypt should succeed");
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_pq_only_all_levels_roundtrip() {
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let (pk, sk) = generate_pq_keypair_with_level(level).unwrap();
            let plaintext = b"Test all security levels";
            let ct = encrypt_pq_only(&pk, plaintext).expect("encrypt should succeed");
            let decrypted = decrypt_pq_only(
                &sk,
                ct.ml_kem_ciphertext(),
                ct.symmetric_ciphertext(),
                ct.nonce(),
                ct.tag(),
            )
            .expect("decrypt should succeed");
            assert_eq!(decrypted.as_slice(), plaintext.as_slice());
        }
    }

    #[test]
    fn test_encrypt_pq_only_empty_data_succeeds() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"").expect("empty data should encrypt");
        let decrypted = decrypt_pq_only(
            &sk,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            ct.tag(),
        )
        .expect("empty data should decrypt");
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_decrypt_pq_only_wrong_key_fails() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let (_pk2, sk2) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"secret").unwrap();
        let result = decrypt_pq_only(
            &sk2,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            ct.tag(),
        );
        assert!(result.is_err(), "Wrong key should fail");
    }

    #[test]
    fn test_decrypt_pq_only_tampered_symmetric_ciphertext_fails() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"integrity matters").unwrap();
        let (kem_ct, mut sym_ct, nonce, tag) = ct.into_parts();
        sym_ct[0] ^= 0x01;
        let result = decrypt_pq_only(&sk, &kem_ct, &sym_ct, &nonce, &tag);
        assert!(
            matches!(&result, Err(PqOnlyError::DecryptionError(msg)) if msg == "decryption failed"),
            "Tampered ciphertext must fail authentication with a generic error"
        );
    }

    #[test]
    fn test_decrypt_pq_only_tampered_tag_fails() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"integrity matters").unwrap();
        let (kem_ct, sym_ct, nonce, mut tag) = ct.into_parts();
        tag[0] ^= 0x01;
        let result = decrypt_pq_only(&sk, &kem_ct, &sym_ct, &nonce, &tag);
        assert!(
            matches!(&result, Err(PqOnlyError::DecryptionError(_))),
            "Tampered tag must fail authentication"
        );
    }

    #[test]
    fn test_decrypt_pq_only_tampered_kem_ciphertext_fails() {
        let (pk, sk) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk, b"integrity matters").unwrap();
        let (mut kem_ct, sym_ct, nonce, tag) = ct.into_parts();
        kem_ct[0] ^= 0x01;
        let result = decrypt_pq_only(&sk, &kem_ct, &sym_ct, &nonce, &tag);
        assert!(result.is_err(), "Tampered KEM ciphertext should fail");
    }

    #[test]
    fn test_decrypt_pq_only_mismatched_level_fails() {
        let (pk512, _sk512) =
            generate_pq_keypair_with_level(MlKemSecurityLevel::MlKem512).unwrap();
        let (_pk768, sk768) = generate_pq_keypair().unwrap();
        let ct = encrypt_pq_only(&pk512, b"level mismatch").unwrap();
        let result = decrypt_pq_only(
            &sk768,
            ct.ml_kem_ciphertext(),
            ct.symmetric_ciphertext(),
            ct.nonce(),
            ct.tag(),
        );
        assert!(result.is_err(), "Ciphertext for another level should fail");
    }

    #[test]
    fn test_encrypt_pq_only_different_ciphertexts() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let ct1 = encrypt_pq_only(&pk, b"same data").unwrap();
        let ct2 = encrypt_pq_only(&pk, b"same data").unwrap();
        assert_ne!(
            ct1.ml_kem_ciphertext(),
            ct2.ml_kem_ciphertext(),
            "Random KEM should produce different ciphertexts"
        );
    }

    #[test]
    fn test_pq_only_public_key_debug_no_leak() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let debug = format!("{:?}", pk);
        assert!(debug.contains("PqOnlyPublicKey"));
        assert!(debug.contains("MlKem768"));
    }

    #[test]
    fn test_pq_only_secret_key_debug_redacted() {
        let (_pk, sk) = generate_pq_keypair().unwrap();
        let debug = format!("{:?}", sk);
        assert!(debug.contains("REDACTED"));
    }

    #[test]
    fn test_pq_only_public_key_from_bytes_wrong_length_fails() {
        let result = PqOnlyPublicKey::from_bytes(MlKemSecurityLevel::MlKem768, &[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_pq_only_secret_key_from_bytes_wrong_length_fails() {
        let result = PqOnlySecretKey::from_bytes(MlKemSecurityLevel::MlKem768, &[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_pq_only_public_key_from_bytes_roundtrip() {
        let (pk, _sk) = generate_pq_keypair().unwrap();
        let bytes = pk.ml_kem_pk_bytes().to_vec();
        let pk2 = PqOnlyPublicKey::from_bytes(pk.security_level(), &bytes).unwrap();
        assert_eq!(pk2.security_level(), pk.security_level());
        assert_eq!(pk2.ml_kem_pk_bytes(), pk.ml_kem_pk_bytes());
    }

    #[test]
    fn test_pq_only_secret_key_from_bytes_roundtrip_ct_eq() {
        let (_pk, sk) = generate_pq_keypair().unwrap();
        let sk2 = PqOnlySecretKey::from_bytes(sk.security_level(), sk.ml_kem_sk_bytes()).unwrap();
        assert_eq!(sk2.security_level(), sk.security_level());
        assert!(bool::from(sk.ct_eq(&sk2)), "Round-tripped key should compare equal");
        let (_pk3, sk3) = generate_pq_keypair().unwrap();
        assert!(!bool::from(sk.ct_eq(&sk3)), "Independent keys should differ");
    }
}