use std::sync::Arc;
use rustls::crypto::CryptoProvider;
use rustls::pki_types::{AlgorithmIdentifier, InvalidSignature, SignatureVerificationAlgorithm};
use super::MlDsaOperations;
use super::config::PqcConfig;
use super::ml_dsa::MlDsa65;
use super::types::PqcError;
/// DER encoding of the ML-DSA-65 object identifier `2.16.840.1.101.3.4.3.18`
/// (`id-ml-dsa-65` in the NIST `sigAlgs` arc), used as the body of the rustls
/// [`AlgorithmIdentifier`] for both the public key and the signature.
///
/// Only the OBJECT IDENTIFIER TLV is present; there is no parameters field.
/// The final arc is 18 (`0x12`). Arc 17 (`0x11`) would be `id-ml-dsa-44` instead.
const ML_DSA_65_OID: &[u8] = &[
    0x06, 0x09, // OBJECT IDENTIFIER, 9 content bytes
    0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, // 2.16.840.1.101.3.4.3 (sigAlgs)
    0x12, // 18 = id-ml-dsa-65
];
/// rustls/webpki signature verification algorithm backed by this crate's ML-DSA-65
/// implementation ([`MlDsa65`]).
#[derive(Debug)]
pub struct MlDsa65Verifier;

impl SignatureVerificationAlgorithm for MlDsa65Verifier {
    /// Checks `signature` over `message` against `public_key`.
    ///
    /// Returns `Ok(())` only when both byte strings decode and `MlDsa65::verify`
    /// reports `Ok(true)`. A decode failure, a verification error and `Ok(false)`
    /// all map to [`InvalidSignature`].
    fn verify_signature(
        &self,
        public_key: &[u8],
        message: &[u8],
        signature: &[u8],
    ) -> Result<(), InvalidSignature> {
        use super::types::{MlDsaPublicKey, MlDsaSignature};

        let key = MlDsaPublicKey::from_bytes(public_key).map_err(|_| InvalidSignature)?;
        let sig = MlDsaSignature::from_bytes(signature).map_err(|_| InvalidSignature)?;

        // A backend error is reported exactly like a failed check.
        let verified = matches!(MlDsa65::new().verify(&key, message, &sig), Ok(true));
        if verified {
            Ok(())
        } else {
            Err(InvalidSignature)
        }
    }

    /// Algorithm identifier expected on ML-DSA-65 subject public keys.
    fn public_key_alg_id(&self) -> AlgorithmIdentifier {
        AlgorithmIdentifier::from_slice(ML_DSA_65_OID)
    }

    /// Algorithm identifier expected on ML-DSA-65 signatures (same as the key's).
    fn signature_alg_id(&self) -> AlgorithmIdentifier {
        self.public_key_alg_id()
    }

    /// Reports this algorithm as FIPS-approved.
    // NOTE(review): this unconditionally returns `true`. Confirm that `MlDsa65` is
    // really a FIPS-validated backend; otherwise rustls' FIPS reporting is misleading.
    fn fips(&self) -> bool {
        true
    }
}
/// Single shared verifier instance referenced by the rustls algorithm tables below.
static ML_DSA_65_VERIFIER: MlDsa65Verifier = MlDsa65Verifier;

/// TLS `SignatureScheme` served by [`ML_DSA_65_VERIFIER`].
const ML_DSA_65_SCHEME: rustls::SignatureScheme = rustls::SignatureScheme::ML_DSA_65;

/// Every signature verification algorithm this provider installs: ML-DSA-65 only.
static ML_DSA_65_ALGORITHMS: &[&'static dyn SignatureVerificationAlgorithm] =
    &[&ML_DSA_65_VERIFIER];

/// Mapping from TLS signature scheme to verification algorithms.
///
/// Reuses [`ML_DSA_65_ALGORITHMS`] rather than repeating the verifier list, so the
/// mapping and the full algorithm list cannot drift apart.
static ML_DSA_65_MAPPINGS: &[(
    rustls::SignatureScheme,
    &'static [&'static dyn SignatureVerificationAlgorithm],
)] = &[(ML_DSA_65_SCHEME, ML_DSA_65_ALGORITHMS)];
/// Public entry point for building the PQC rustls [`CryptoProvider`] described by `config`.
///
/// # Errors
///
/// Propagates the [`PqcError`] returned by `create_pqc_provider`, for example when no
/// PQC algorithm is enabled in `config`.
pub fn create_crypto_provider(config: &PqcConfig) -> Result<Arc<CryptoProvider>, PqcError> {
    let provider = create_pqc_provider(config)?;
    Ok(provider)
}
/// Builds an aws-lc-rs based rustls [`CryptoProvider`] restricted to post-quantum algorithms.
///
/// Starting from `rustls::crypto::aws_lc_rs::default_provider()`:
/// * with ML-KEM enabled, the key-exchange groups are replaced by pure ML-KEM-768 and
///   ML-KEM-1024 only (no classical or hybrid groups);
/// * with ML-DSA enabled, the signature verification algorithms are replaced by
///   ML-DSA-65 only, so classical certificate signatures can no longer be verified.
///
/// A family that is disabled keeps the aws-lc-rs defaults.
///
/// # Errors
///
/// Returns [`PqcError::CryptoError`] if neither ML-KEM nor ML-DSA is enabled.
fn create_pqc_provider(config: &PqcConfig) -> Result<Arc<CryptoProvider>, PqcError> {
    use rustls::crypto::aws_lc_rs::{default_provider, kx_group};

    if !(config.ml_kem_enabled || config.ml_dsa_enabled) {
        return Err(PqcError::CryptoError(
            "At least one PQC algorithm must be enabled".to_string(),
        ));
    }

    let mut pqc_provider = default_provider();

    if config.ml_kem_enabled {
        pqc_provider.kx_groups = vec![kx_group::MLKEM768, kx_group::MLKEM1024];
    }

    if config.ml_dsa_enabled {
        pqc_provider.signature_verification_algorithms =
            rustls::crypto::WebPkiSupportedAlgorithms {
                all: ML_DSA_65_ALGORITHMS,
                mapping: ML_DSA_65_MAPPINGS,
            };
    }

    Ok(Arc::new(pqc_provider))
}
/// Reports whether `group` is one of the pure (non-hybrid) ML-KEM TLS groups:
/// ML-KEM-512 (`0x0200`), ML-KEM-768 (`0x0201`) or ML-KEM-1024 (`0x0202`).
///
/// The check is done on the raw `u16` codepoint of `group`.
fn is_pure_pqc_kx_group(group: rustls::NamedGroup) -> bool {
    const PURE_ML_KEM_CODEPOINTS: [u16; 3] = [
        0x0200, // ML-KEM-512
        0x0201, // ML-KEM-768
        0x0202, // ML-KEM-1024
    ];
    PURE_ML_KEM_CODEPOINTS.contains(&u16::from(group))
}

/// Reports whether `group` is an acceptable post-quantum key-exchange group.
///
/// Currently identical to `is_pure_pqc_kx_group`, so hybrid groups such as
/// X25519MLKEM768 (`0x11EC`) are not accepted.
fn is_pqc_kx_group(group: rustls::NamedGroup) -> bool {
    is_pure_pqc_kx_group(group)
}

/// Public form of `is_pqc_kx_group`: `true` only for pure ML-KEM groups.
pub fn is_pqc_group(group: rustls::NamedGroup) -> bool {
    is_pqc_kx_group(group)
}
/// Checks that the key-exchange group negotiated for a connection is a pure ML-KEM group.
///
/// # Errors
///
/// Returns [`PqcError::NegotiationFailed`], naming `negotiated_group`, when it is not
/// accepted by `is_pqc_kx_group`. Both classical and hybrid groups are rejected.
pub fn validate_negotiated_group(negotiated_group: rustls::NamedGroup) -> Result<(), PqcError> {
    if is_pqc_kx_group(negotiated_group) {
        Ok(())
    } else {
        Err(PqcError::NegotiationFailed(format!(
            "ML-KEM key exchange required, but negotiated {:?}",
            negotiated_group
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with both ML-KEM and ML-DSA requested.
    fn full_pqc_config() -> PqcConfig {
        PqcConfig::builder()
            .ml_kem(true)
            .ml_dsa(true)
            .build()
            .expect("Failed to build config")
    }

    #[test]
    fn test_create_pqc_provider() {
        let config = full_pqc_config();
        // Creation must not fail with both algorithm families enabled; an error is a test failure.
        let Ok(provider) = create_pqc_provider(&config) else {
            panic!("provider creation must succeed with ML-KEM and ML-DSA enabled");
        };

        assert!(
            !provider.kx_groups.is_empty(),
            "Provider should offer at least one ML-KEM group"
        );
        for group in provider.kx_groups.iter() {
            assert!(
                is_pure_pqc_kx_group(group.name()),
                "Provider should only have pure ML-KEM groups, found {:?}",
                group.name()
            );
        }

        let algs = &provider.signature_verification_algorithms;
        assert!(
            algs.mapping
                .iter()
                .any(|(scheme, _)| *scheme == ML_DSA_65_SCHEME),
            "ML-DSA-65 must be mapped as a TLS signature scheme"
        );
    }

    #[test]
    fn test_create_pqc_provider_rejects_no_algorithms() {
        let mut config = full_pqc_config();
        config.ml_kem_enabled = false;
        config.ml_dsa_enabled = false;
        let result = create_pqc_provider(&config);
        assert!(
            matches!(result, Err(PqcError::CryptoError(_))),
            "Provider creation should fail when no PQC algorithm is enabled"
        );
    }

    #[test]
    fn test_requires_algorithms() {
        let config = PqcConfig::builder().ml_kem(false).ml_dsa(false).build();
        assert!(config.is_ok(), "Config should succeed with PQC forced on");
        let config = config.unwrap();
        assert!(config.ml_kem_enabled, "ML-KEM must be enabled");
        assert!(config.ml_dsa_enabled, "ML-DSA must be enabled");
    }

    #[test]
    fn test_validate_negotiated_group() {
        let result = validate_negotiated_group(rustls::NamedGroup::X25519);
        assert!(result.is_err(), "X25519 should be rejected");
        assert!(
            matches!(result, Err(PqcError::NegotiationFailed(_))),
            "Rejection should be reported as NegotiationFailed"
        );
        let result = validate_negotiated_group(rustls::NamedGroup::Unknown(0x0200));
        assert!(result.is_ok(), "ML-KEM-512 should be accepted");
        let result = validate_negotiated_group(rustls::NamedGroup::Unknown(0x0201));
        assert!(result.is_ok(), "ML-KEM-768 should be accepted");
        let result = validate_negotiated_group(rustls::NamedGroup::Unknown(0x0202));
        assert!(result.is_ok(), "ML-KEM-1024 should be accepted");
        let result = validate_negotiated_group(rustls::NamedGroup::Unknown(0x11EC));
        assert!(
            result.is_err(),
            "X25519MLKEM768 should be rejected (hybrid)"
        );
        let result = validate_negotiated_group(rustls::NamedGroup::Unknown(0x11EB));
        assert!(
            result.is_err(),
            "SecP256r1MLKEM768 should be rejected (hybrid)"
        );
    }

    #[test]
    fn test_is_pure_pqc_kx_group() {
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::X25519));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::secp256r1));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::secp384r1));
        assert!(is_pure_pqc_kx_group(rustls::NamedGroup::Unknown(0x0200)));
        assert!(is_pure_pqc_kx_group(rustls::NamedGroup::Unknown(0x0201)));
        assert!(is_pure_pqc_kx_group(rustls::NamedGroup::Unknown(0x0202)));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::Unknown(0x11EB)));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::Unknown(0x11EC)));
        assert!(!is_pure_pqc_kx_group(rustls::NamedGroup::Unknown(0x11ED)));
    }

    #[test]
    fn test_is_pqc_kx_group() {
        assert!(is_pqc_kx_group(rustls::NamedGroup::Unknown(0x0200)));
        assert!(is_pqc_kx_group(rustls::NamedGroup::Unknown(0x0201)));
        assert!(is_pqc_kx_group(rustls::NamedGroup::Unknown(0x0202)));
        assert!(!is_pqc_kx_group(rustls::NamedGroup::Unknown(0x11EC)));
        assert!(!is_pqc_kx_group(rustls::NamedGroup::Unknown(0x11EB)));
        assert!(!is_pqc_kx_group(rustls::NamedGroup::Unknown(0x11ED)));
        assert!(!is_pqc_kx_group(rustls::NamedGroup::X25519));
        assert!(!is_pqc_kx_group(rustls::NamedGroup::secp256r1));
    }
}