use std::sync::Arc;
use rustls::crypto::CryptoProvider;
use rustls::pki_types::{AlgorithmIdentifier, InvalidSignature, SignatureVerificationAlgorithm};
use super::MlDsaOperations;
use super::config::PqcConfig;
use super::ml_dsa::MlDsa65;
use super::types::PqcError;
/// DER-encoded `AlgorithmIdentifier` contents for ML-DSA-65, in the form that
/// rustls-pki-types' `AlgorithmIdentifier` expects (the outer SEQUENCE tag and
/// length are omitted; only the inner fields are present).
///
/// Encodes `OBJECT IDENTIFIER 2.16.840.1.101.3.4.3.18` (id-ml-dsa-65) with the
/// parameters field absent. In the NIST sigAlgs arc, 17 (`0x11`) is
/// id-ml-dsa-44 and 19 (`0x13`) is id-ml-dsa-87. Only 18 (`0x12`) identifies
/// ML-DSA-65, so the final byte must be `0x12` or rustls will pair this
/// verifier with the wrong certificate keys.
const ML_DSA_65_OID: &[u8] = &[
    0x06, 0x09, // OBJECT IDENTIFIER, 9 content bytes
    0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, // 2.16.840.1.101.3.4.3 (sigAlgs)
    0x12, // 18 = id-ml-dsa-65
];
/// rustls signature-verification algorithm for ML-DSA-65.
///
/// Stateless unit type; the actual verification is delegated to the crate's
/// `MlDsa65` implementation.
#[derive(Debug)]
pub struct MlDsa65Verifier;

impl SignatureVerificationAlgorithm for MlDsa65Verifier {
    /// Verifies `signature` over `message` with the raw ML-DSA-65 `public_key`.
    ///
    /// Every failure is collapsed into `InvalidSignature`:
    /// - the key bytes fail to decode;
    /// - the signature bytes fail to decode;
    /// - the backend reports an error;
    /// - the signature simply does not verify.
    fn verify_signature(
        &self,
        public_key: &[u8],
        message: &[u8],
        signature: &[u8],
    ) -> Result<(), InvalidSignature> {
        use super::types::{MlDsaPublicKey, MlDsaSignature};

        let Ok(key) = MlDsaPublicKey::from_bytes(public_key) else {
            return Err(InvalidSignature);
        };
        let Ok(sig) = MlDsaSignature::from_bytes(signature) else {
            return Err(InvalidSignature);
        };

        // Treat a backend error the same as a negative verification result.
        let verified = MlDsa65::new().verify(&key, message, &sig).unwrap_or(false);
        if verified {
            Ok(())
        } else {
            Err(InvalidSignature)
        }
    }

    /// Algorithm identifier matched against the certificate's SPKI algorithm.
    fn public_key_alg_id(&self) -> AlgorithmIdentifier {
        AlgorithmIdentifier::from_slice(ML_DSA_65_OID)
    }

    /// Algorithm identifier matched against the signature algorithm.
    /// ML-DSA uses the same OID for the key and for the signature.
    fn signature_alg_id(&self) -> AlgorithmIdentifier {
        AlgorithmIdentifier::from_slice(ML_DSA_65_OID)
    }

    // NOTE(review): this reports FIPS status unconditionally. Confirm that the
    // `MlDsa65` backend is actually a FIPS-validated implementation before
    // relying on `CryptoProvider::fips()`.
    fn fips(&self) -> bool {
        true
    }
}
/// Shared instance handed to rustls (the verifier is a stateless unit struct).
static ML_DSA_65_VERIFIER: MlDsa65Verifier = MlDsa65Verifier;

/// TLS signature scheme under which ML-DSA-65 signatures are negotiated.
const ML_DSA_65_SCHEME: rustls::SignatureScheme = rustls::SignatureScheme::ML_DSA_65;

/// `'static` list of verification algorithms, as `WebPkiSupportedAlgorithms` stores them.
type StaticVerifierSlice = &'static [&'static dyn SignatureVerificationAlgorithm];

/// Every verification algorithm this provider supports (ML-DSA-65 only).
static ML_DSA_65_ALGORITHMS: StaticVerifierSlice = &[&ML_DSA_65_VERIFIER];

/// TLS signature scheme → verifier mapping (ML-DSA-65 only).
static ML_DSA_65_MAPPINGS: &[(rustls::SignatureScheme, StaticVerifierSlice)] =
    &[(ML_DSA_65_SCHEME, &[&ML_DSA_65_VERIFIER])];
/// Builds a rustls `CryptoProvider` configured according to `config`.
///
/// # Errors
///
/// Propagates any `PqcError` raised while assembling the provider, e.g. when
/// no PQC algorithm is enabled or no ML-KEM key-exchange group is available.
pub fn create_crypto_provider(config: &PqcConfig) -> Result<Arc<CryptoProvider>, PqcError> {
    let provider = create_pqc_provider(config)?;
    Ok(provider)
}
/// Starts from the aws-lc-rs default provider and narrows it to PQC algorithms.
///
/// - With ML-KEM enabled, only ML-KEM-bearing key-exchange groups (pure or
///   hybrid) are kept. They come from the default provider if it has any;
///   otherwise they are taken from `rustls_post_quantum::provider()`.
/// - With ML-DSA enabled, the signature-verification set is *replaced* by
///   ML-DSA-65 alone. Peers presenting classical (e.g. ECDSA/RSA)
///   certificates will then fail verification.
///
/// # Errors
///
/// Returns `PqcError::CryptoError` in either of these cases:
/// - neither ML-KEM nor ML-DSA is enabled;
/// - ML-KEM is enabled but neither provider offers an ML-KEM group.
fn create_pqc_provider(config: &PqcConfig) -> Result<Arc<CryptoProvider>, PqcError> {
    let any_enabled = config.ml_kem_enabled || config.ml_dsa_enabled;
    if !any_enabled {
        return Err(PqcError::CryptoError(
            "At least one PQC algorithm must be enabled".to_string(),
        ));
    }

    // Keeps only the ML-KEM-bearing groups from a provider's kx list, in order.
    fn mlkem_only(
        groups: &[&'static dyn rustls::crypto::SupportedKxGroup],
    ) -> Vec<&'static dyn rustls::crypto::SupportedKxGroup> {
        groups
            .iter()
            .copied()
            .filter(|kx| is_mlkem_kx_group(kx.name()))
            .collect()
    }

    let mut provider = rustls::crypto::aws_lc_rs::default_provider();

    if config.ml_kem_enabled {
        let mut selected = mlkem_only(&provider.kx_groups);
        // Only consult the post-quantum provider when the default has nothing usable.
        if selected.is_empty() {
            selected = mlkem_only(&rustls_post_quantum::provider().kx_groups);
        }
        if selected.is_empty() {
            return Err(PqcError::CryptoError(
                "No ML-KEM key exchange groups available".to_string(),
            ));
        }
        provider.kx_groups = selected;
    }

    if config.ml_dsa_enabled {
        provider.signature_verification_algorithms = rustls::crypto::WebPkiSupportedAlgorithms {
            all: ML_DSA_65_ALGORITHMS,
            mapping: ML_DSA_65_MAPPINGS,
        };
    }

    Ok(Arc::new(provider))
}
/// Returns `true` for the standalone (non-hybrid) ML-KEM TLS groups:
/// MLKEM512 (`0x0200`), MLKEM768 (`0x0201`) and MLKEM1024 (`0x0202`).
fn is_pure_pqc_kx_group(group: rustls::NamedGroup) -> bool {
    // Compare on the raw codepoint so the result does not depend on whether
    // rustls exposes a named variant or `NamedGroup::Unknown(_)` for it.
    matches!(u16::from(group), 0x0200..=0x0202)
}
/// Returns `true` if `group` includes ML-KEM in either of two forms:
/// - a pure ML-KEM group (see `is_pure_pqc_kx_group`);
/// - one of the hybrid ECDHE + ML-KEM groups listed below.
fn is_mlkem_kx_group(group: rustls::NamedGroup) -> bool {
    // Hybrid ECDHE + ML-KEM codepoints.
    const HYBRID_MLKEM_GROUPS: [u16; 3] = [
        0x11EB, // SecP256r1MLKEM768
        0x11EC, // X25519MLKEM768
        0x11ED, // SecP384r1MLKEM1024
    ];
    is_pure_pqc_kx_group(group) || HYBRID_MLKEM_GROUPS.contains(&u16::from(group))
}
/// Key-exchange groups this module counts as post-quantum.
///
/// ML-KEM (pure or hybrid) is currently the only accepted family.
fn is_pqc_kx_group(named_group: rustls::NamedGroup) -> bool {
    is_mlkem_kx_group(named_group)
}
/// Public check for whether a TLS named group is post-quantum.
///
/// Returns `true` for the pure ML-KEM groups and for the hybrid
/// ECDHE + ML-KEM groups.
pub fn is_pqc_group(named_group: rustls::NamedGroup) -> bool {
    is_pqc_kx_group(named_group)
}
/// Ensures the key-exchange group negotiated for a connection includes ML-KEM.
///
/// # Errors
///
/// Returns `PqcError::NegotiationFailed` naming the group when it is not a
/// pure or hybrid ML-KEM group.
pub fn validate_negotiated_group(negotiated_group: rustls::NamedGroup) -> Result<(), PqcError> {
    if is_pqc_kx_group(negotiated_group) {
        Ok(())
    } else {
        Err(PqcError::NegotiationFailed(format!(
            "ML-KEM key exchange required, but negotiated {:?}",
            negotiated_group
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // IANA codepoints for the pure ML-KEM groups (MLKEM512/768/1024).
    const PURE_MLKEM: [u16; 3] = [0x0200, 0x0201, 0x0202];
    // IANA codepoints for the hybrid groups
    // (SecP256r1MLKEM768, X25519MLKEM768, SecP384r1MLKEM1024).
    const HYBRID_MLKEM: [u16; 3] = [0x11EB, 0x11EC, 0x11ED];

    fn group(code: u16) -> rustls::NamedGroup {
        rustls::NamedGroup::Unknown(code)
    }

    #[test]
    fn test_create_pqc_provider() {
        let config = PqcConfig::builder()
            .ml_kem(true)
            .ml_dsa(true)
            .build()
            .expect("Failed to build config");
        let provider = create_pqc_provider(&config).expect("Provider creation should succeed");

        assert!(
            !provider.kx_groups.is_empty(),
            "Provider should offer at least one ML-KEM group"
        );
        for kx in provider.kx_groups.iter() {
            assert!(
                is_pqc_kx_group(kx.name()),
                "Provider should only have ML-KEM groups, found {:?}",
                kx.name()
            );
        }

        // With ML-DSA enabled, verification is restricted to ML-DSA-65.
        let algs = &provider.signature_verification_algorithms;
        assert_eq!(algs.all.len(), 1);
        assert_eq!(algs.mapping.len(), 1);
        assert_eq!(algs.mapping[0].0, rustls::SignatureScheme::ML_DSA_65);
    }

    #[test]
    fn test_requires_algorithms() {
        let config = PqcConfig::builder().ml_kem(false).ml_dsa(false).build();
        assert!(config.is_ok(), "Config should succeed with PQC forced on");
        let config = config.unwrap();
        assert!(config.ml_kem_enabled, "ML-KEM must be enabled");
        assert!(config.ml_dsa_enabled, "ML-DSA must be enabled");
    }

    #[test]
    fn test_validate_negotiated_group() {
        for classical in [
            rustls::NamedGroup::X25519,
            rustls::NamedGroup::secp256r1,
            rustls::NamedGroup::secp384r1,
        ] {
            assert!(
                validate_negotiated_group(classical).is_err(),
                "{classical:?} should be rejected"
            );
        }
        for code in PURE_MLKEM.iter().chain(HYBRID_MLKEM.iter()).copied() {
            assert!(
                validate_negotiated_group(group(code)).is_ok(),
                "group {code:#06x} contains ML-KEM and should be accepted"
            );
        }
    }

    #[test]
    fn test_is_pure_pqc_kx_group() {
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::X25519));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::secp256r1));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::secp384r1));
        for code in PURE_MLKEM {
            assert!(is_pure_pqc_kx_group(group(code)), "{code:#06x} is pure ML-KEM");
        }
        for code in HYBRID_MLKEM {
            assert!(!is_pure_pqc_kx_group(group(code)), "{code:#06x} is hybrid");
        }
    }

    #[test]
    fn test_is_mlkem_kx_group() {
        for code in PURE_MLKEM.iter().chain(HYBRID_MLKEM.iter()).copied() {
            assert!(is_mlkem_kx_group(group(code)), "{code:#06x} contains ML-KEM");
        }
        assert!(!is_mlkem_kx_group(rustls::NamedGroup::X25519));
        assert!(!is_mlkem_kx_group(rustls::NamedGroup::secp256r1));
    }

    #[test]
    fn test_pqc_predicates_agree() {
        let samples = PURE_MLKEM
            .iter()
            .chain(HYBRID_MLKEM.iter())
            .copied()
            .map(group)
            .chain([rustls::NamedGroup::X25519, rustls::NamedGroup::secp256r1]);
        for g in samples {
            assert_eq!(is_pqc_kx_group(g), is_mlkem_kx_group(g), "{g:?}");
            assert_eq!(is_pqc_group(g), is_mlkem_kx_group(g), "{g:?}");
        }
    }
}