#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use crate::tls::{TlsError, TlsMode};
use rustls::crypto::CryptoProvider;
use std::mem;
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, Zeroizing};
/// Selects how post-quantum key exchange is provided for a TLS connection.
///
/// The value is interpreted together with a [`TlsMode`]: any combination with
/// [`TlsMode::Classic`] is treated as classical by [`get_kex_info`] regardless of
/// the variant chosen here.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PqKexMode {
    /// Use the hybrid `X25519MLKEM768` group shipped with the rustls aws-lc-rs
    /// provider; ML-KEM groups are moved to the front of the preference list.
    RustlsPq,
    /// Use the crate's own X25519 + ML-KEM-768 hybrid KEM (`crate::hybrid`).
    CustomHybrid,
    /// Classical key exchange, reported as X25519 (ECDHE) by [`get_kex_info`].
    Classical,
}
/// Human-readable summary of a key-exchange configuration, as produced by
/// [`get_kex_info`].
///
/// All `*_size` fields are byte lengths. For hybrid methods they describe the
/// combined value (X25519 component plus ML-KEM-768 component).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KexInfo {
    /// Name of the key-exchange method, e.g. `"X25519MLKEM768"`.
    pub method: String,
    /// Descriptive security level, e.g. `"Hybrid (Post-Quantum + Classical)"`.
    pub security_level: String,
    /// `true` when the method includes a post-quantum component.
    pub is_pq_secure: bool,
    /// Public key (key share) size in bytes.
    pub pk_size: usize,
    /// Secret key size in bytes.
    pub sk_size: usize,
    /// Ciphertext / response key-share size in bytes.
    pub ct_size: usize,
    /// Shared secret size in bytes.
    pub ss_size: usize,
}
/// Builds the rustls [`CryptoProvider`] for the given TLS and key-exchange modes.
///
/// Every mode starts from `rustls::crypto::aws_lc_rs::default_provider()`:
///
/// * [`TlsMode::Hybrid`] / [`TlsMode::Pq`] with [`PqKexMode::RustlsPq`]: the
///   key-exchange groups are stably reordered so that every group whose name
///   contains `MLKEM` is offered before the classical groups. Classical groups
///   stay in the list as a fallback, in their original relative order.
/// * [`PqKexMode::CustomHybrid`], [`TlsMode::Classic`] or
///   [`PqKexMode::Classical`]: the default provider is returned unchanged.
///
/// NOTE(review): depending on the rustls version and enabled features, the
/// default aws-lc-rs group list may itself contain (or prefer) `X25519MLKEM768`;
/// the classical modes do not remove post-quantum groups — confirm this is the
/// intended behavior.
///
/// # Errors
///
/// Currently never returns an error; the `Result` leaves room for fallible
/// provider construction without an API change.
pub fn get_kex_provider(mode: TlsMode, kex_mode: PqKexMode) -> Result<CryptoProvider, TlsError> {
    let mut provider = rustls::crypto::aws_lc_rs::default_provider();
    match (mode, kex_mode) {
        (TlsMode::Hybrid | TlsMode::Pq, PqKexMode::RustlsPq) => {
            // Key is `false` for ML-KEM groups, so they sort first. Matching on the
            // `Debug` name catches every ML-KEM variant rustls exposes; the cached
            // sort builds each name once instead of on every comparison, and is
            // stable, so relative order within each class is preserved.
            provider
                .kx_groups
                .sort_by_cached_key(|group| !format!("{:?}", group.name()).contains("MLKEM"));
        }
        (TlsMode::Hybrid | TlsMode::Pq, PqKexMode::CustomHybrid)
        | (TlsMode::Classic, _)
        | (_, PqKexMode::Classical) => {}
    }
    Ok(provider)
}
/// Describes the key exchange selected by `mode` and `kex_mode`.
///
/// Any combination involving [`TlsMode::Classic`] or [`PqKexMode::Classical`] is
/// reported as classical X25519 ECDHE. Every other combination is reported as the
/// X25519 + ML-KEM-768 hybrid, labelled by whether rustls' native group or the
/// crate's custom hybrid is in use. All sizes are in bytes.
#[must_use]
pub fn get_kex_info(mode: TlsMode, kex_mode: PqKexMode) -> KexInfo {
    // Component sizes in bytes; hybrid sizes are the X25519 part plus the ML-KEM part.
    const X25519_LEN: usize = 32;
    const MLKEM768_PK_LEN: usize = 1184;
    const MLKEM768_SK_LEN: usize = 2400;
    const MLKEM768_CT_LEN: usize = 1088;
    const MLKEM768_SS_LEN: usize = 32;

    // `None` selects the classical description; `Some(label)` names the hybrid flavour.
    let hybrid_label = match (mode, kex_mode) {
        (TlsMode::Classic, _) | (_, PqKexMode::Classical) => None,
        (TlsMode::Hybrid | TlsMode::Pq, PqKexMode::RustlsPq) => Some("X25519MLKEM768"),
        (TlsMode::Hybrid | TlsMode::Pq, PqKexMode::CustomHybrid) => {
            Some("Custom Hybrid (X25519 + ML-KEM-768)")
        }
    };

    match hybrid_label {
        Some(label) => KexInfo {
            method: label.to_string(),
            security_level: "Hybrid (Post-Quantum + Classical)".to_string(),
            is_pq_secure: true,
            pk_size: X25519_LEN + MLKEM768_PK_LEN,
            sk_size: X25519_LEN + MLKEM768_SK_LEN,
            ct_size: X25519_LEN + MLKEM768_CT_LEN,
            ss_size: X25519_LEN + MLKEM768_SS_LEN,
        },
        None => KexInfo {
            method: "X25519 (ECDHE)".to_string(),
            security_level: "Classical (128-bit security)".to_string(),
            is_pq_secure: false,
            pk_size: X25519_LEN,
            sk_size: X25519_LEN,
            ct_size: X25519_LEN,
            ss_size: X25519_LEN,
        },
    }
}
/// Reports whether native post-quantum key exchange ([`PqKexMode::RustlsPq`])
/// can be selected.
///
/// Always `true` here: [`get_kex_provider`] uses the aws-lc-rs provider
/// without any feature gate in this module.
#[must_use]
pub fn is_pq_available() -> bool {
    true
}
/// Reports whether the crate's custom hybrid KEM ([`PqKexMode::CustomHybrid`])
/// can be selected.
///
/// Always `true` here: the `perform_hybrid_*` helpers in this module call
/// `crate::hybrid` without any feature gate.
#[must_use]
pub fn is_custom_hybrid_available() -> bool {
    true
}
pub struct SecureSharedSecret {
secret: Vec<u8>,
}
impl SecureSharedSecret {
#[must_use]
pub fn new(secret: Vec<u8>) -> Self {
Self { secret }
}
#[must_use]
pub fn secret_ref(&self) -> &[u8] {
&self.secret
}
}
impl AsRef<[u8]> for SecureSharedSecret {
fn as_ref(&self) -> &[u8] {
&self.secret
}
}
impl SecureSharedSecret {
#[must_use]
pub fn into_inner(mut self) -> Zeroizing<Vec<u8>> {
Zeroizing::new(mem::take(&mut self.secret))
}
#[must_use]
pub fn into_inner_raw(mut self) -> Vec<u8> {
mem::take(&mut self.secret)
}
}
impl std::fmt::Debug for SecureSharedSecret {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SecureSharedSecret").field("data", &"[REDACTED]").finish()
}
}
impl Drop for SecureSharedSecret {
fn drop(&mut self) {
self.secret.zeroize();
}
}
impl Zeroize for SecureSharedSecret {
fn zeroize(&mut self) {
self.secret.zeroize();
}
}
impl ConstantTimeEq for SecureSharedSecret {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
self.secret.ct_eq(&other.secret)
}
}
pub fn perform_hybrid_keygen()
-> Result<(crate::hybrid::HybridKemPublicKey, crate::hybrid::HybridKemSecretKey), TlsError> {
crate::hybrid::kem_generate_keypair().map_err(|e| TlsError::KeyExchange {
message: format!("Hybrid keygen failed: {}", e),
method: "X25519MLKEM768".to_string(),
operation: Some("keygen".to_string()),
code: crate::tls::error::ErrorCode::KeyExchangeFailed,
context: Box::default(),
recovery: Box::new(crate::tls::error::RecoveryHint::NoRecovery),
})
}
pub fn perform_hybrid_encapsulate(
pk: &crate::hybrid::HybridKemPublicKey,
) -> Result<crate::hybrid::EncapsulatedKey, TlsError> {
crate::hybrid::encapsulate(pk).map_err(|e| TlsError::KeyExchange {
message: format!("Hybrid encapsulation failed: {}", e),
method: "X25519MLKEM768".to_string(),
operation: Some("encapsulate".to_string()),
code: crate::tls::error::ErrorCode::EncapsulationFailed,
context: Box::default(),
recovery: Box::new(crate::tls::error::RecoveryHint::NoRecovery),
})
}
pub fn perform_hybrid_decapsulate_secure(
sk: &crate::hybrid::HybridKemSecretKey,
ct: &crate::hybrid::EncapsulatedKey,
) -> Result<SecureSharedSecret, TlsError> {
let secret = crate::hybrid::decapsulate(sk, ct).map_err(|e| TlsError::KeyExchange {
message: format!("Hybrid decapsulation failed: {}", e),
method: "X25519MLKEM768".to_string(),
operation: Some("decapsulate".to_string()),
code: crate::tls::error::ErrorCode::DecapsulationFailed,
context: Box::default(),
recovery: Box::new(crate::tls::error::RecoveryHint::NoRecovery),
})?;
Ok(SecureSharedSecret::new((*secret).clone()))
}
pub fn perform_hybrid_decapsulate(
sk: &crate::hybrid::HybridKemSecretKey,
ct: &crate::hybrid::EncapsulatedKey,
) -> Result<Zeroizing<Vec<u8>>, TlsError> {
let secure_secret = perform_hybrid_decapsulate_secure(sk, ct)?;
Ok(Zeroizing::new(secure_secret.into_inner_raw()))
}
#[cfg(test)]
// Tests may unwrap/expect freely: a failure there is a test failure, not a runtime error.
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Debug names of the provider's key-exchange groups, in preference order.
    fn kx_group_names(provider: &CryptoProvider) -> Vec<String> {
        provider.kx_groups.iter().map(|g| format!("{:?}", g.name())).collect()
    }

    #[test]
    fn test_kex_info_hybrid_is_correct() {
        let info = get_kex_info(TlsMode::Hybrid, PqKexMode::RustlsPq);
        assert_eq!(info.method, "X25519MLKEM768");
        assert!(info.is_pq_secure);
        assert_eq!(info.ss_size, 64);
    }

    #[test]
    fn test_kex_info_classical_is_correct() {
        let info = get_kex_info(TlsMode::Classic, PqKexMode::Classical);
        assert_eq!(info.method, "X25519 (ECDHE)");
        assert!(!info.is_pq_secure);
        assert_eq!(info.ss_size, 32);
    }

    #[test]
    fn test_pq_availability_is_correct() {
        assert!(is_pq_available());
    }

    #[test]
    fn test_custom_hybrid_availability_is_correct() {
        assert!(is_custom_hybrid_available());
    }

    #[test]
    fn test_hybrid_key_exchange_roundtrip() {
        let (pk, sk) = perform_hybrid_keygen().expect("Failed to generate keypair");
        let enc = perform_hybrid_encapsulate(&pk).expect("Failed to encapsulate");
        let secure_ss =
            perform_hybrid_decapsulate_secure(&sk, &enc).expect("Failed to decapsulate");
        assert_eq!(secure_ss.secret.as_slice(), enc.shared_secret());
        assert_eq!(secure_ss.secret.len(), 64);
        let ss = perform_hybrid_decapsulate(&sk, &enc).expect("Failed to decapsulate");
        assert_eq!(ss.as_slice(), enc.shared_secret());
        assert_eq!(ss.len(), 64);
    }

    #[test]
    fn test_get_kex_provider_succeeds() {
        let provider = get_kex_provider(TlsMode::Hybrid, PqKexMode::RustlsPq);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_get_kex_provider_classical_succeeds() {
        let provider = get_kex_provider(TlsMode::Classic, PqKexMode::Classical);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_get_kex_provider_custom_hybrid_succeeds() {
        let provider = get_kex_provider(TlsMode::Hybrid, PqKexMode::CustomHybrid);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_get_kex_provider_pq_rustls_succeeds() {
        let provider = get_kex_provider(TlsMode::Pq, PqKexMode::RustlsPq);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_get_kex_provider_pq_custom_hybrid_succeeds() {
        let provider = get_kex_provider(TlsMode::Pq, PqKexMode::CustomHybrid);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_get_kex_provider_classic_with_rustls_pq_succeeds() {
        let provider = get_kex_provider(TlsMode::Classic, PqKexMode::RustlsPq);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_pq_groups_preferred_in_hybrid_mode_is_correct() {
        let provider = get_kex_provider(TlsMode::Hybrid, PqKexMode::RustlsPq)
            .expect("Provider should be available");
        let group_names = kx_group_names(&provider);
        let last_pq = group_names
            .iter()
            .rposition(|n| n.contains("MLKEM"))
            .expect("Provider must contain at least one MLKEM group");
        let first_classical = group_names
            .iter()
            .position(|n| !n.contains("MLKEM"))
            .expect("Provider must contain classical groups too");
        // All PQ groups precede all classical groups iff the last PQ index is
        // smaller than the first classical index.
        assert!(
            last_pq < first_classical,
            "PQ groups must be sorted before classical groups, got: {group_names:?}"
        );
    }

    #[test]
    fn test_pq_groups_preferred_in_pq_mode_is_correct() {
        let provider = get_kex_provider(TlsMode::Pq, PqKexMode::RustlsPq)
            .expect("Provider should be available");
        let group_names = kx_group_names(&provider);
        let first_mlkem = group_names.iter().position(|n| n.contains("MLKEM"));
        assert!(
            first_mlkem == Some(0),
            "First group in PQ mode must be an MLKEM group, got: {group_names:?}"
        );
    }

    #[test]
    fn test_native_pq_groups_are_available_succeeds() {
        let provider = get_kex_provider(TlsMode::Hybrid, PqKexMode::RustlsPq)
            .expect("Provider should be available");
        let group_names = kx_group_names(&provider);
        // Exact-name match: a substring check for "X25519" would be satisfied by
        // "X25519MLKEM768" alone and so could never detect a missing fallback.
        let has_group = |name: &str| group_names.iter().any(|n| n == name);
        assert!(has_group("X25519MLKEM768"), "Missing X25519MLKEM768 in {group_names:?}");
        assert!(has_group("X25519"), "Missing X25519 classical fallback in {group_names:?}");
    }

    #[test]
    fn test_kex_info_custom_hybrid_is_correct() {
        let info = get_kex_info(TlsMode::Hybrid, PqKexMode::CustomHybrid);
        assert!(info.method.contains("Custom Hybrid"));
        assert!(info.is_pq_secure);
        assert_eq!(info.ss_size, 64);
        assert_eq!(info.pk_size, 32 + 1184);
    }

    #[test]
    fn test_kex_info_pq_mode_is_correct() {
        let info = get_kex_info(TlsMode::Pq, PqKexMode::RustlsPq);
        assert!(info.is_pq_secure);
        assert_eq!(info.method, "X25519MLKEM768");
    }

    #[test]
    fn test_kex_info_classic_overrides_kex_mode_is_correct() {
        let info = get_kex_info(TlsMode::Classic, PqKexMode::RustlsPq);
        assert!(!info.is_pq_secure);
        assert!(info.method.contains("X25519"));
    }

    #[test]
    fn test_secure_shared_secret_new_and_ref_is_correct() {
        let secret = SecureSharedSecret::new(vec![1, 2, 3, 4]);
        assert_eq!(secret.secret_ref(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_secure_shared_secret_as_ref_returns_correct_slice_succeeds() {
        let secret = SecureSharedSecret::new(vec![5, 6, 7]);
        let slice: &[u8] = secret.as_ref();
        assert_eq!(slice, &[5, 6, 7]);
    }

    #[test]
    fn test_secure_shared_secret_into_inner_returns_correct_value_succeeds() {
        let secret = SecureSharedSecret::new(vec![10, 20, 30]);
        let zeroizing = secret.into_inner();
        assert_eq!(zeroizing.as_slice(), &[10, 20, 30]);
    }

    #[test]
    fn test_secure_shared_secret_into_inner_raw_returns_correct_value_succeeds() {
        let secret = SecureSharedSecret::new(vec![40, 50, 60]);
        let raw = secret.into_inner_raw();
        assert_eq!(raw, vec![40, 50, 60]);
    }

    #[test]
    fn test_secure_shared_secret_zeroize_succeeds() {
        let mut secret = SecureSharedSecret::new(vec![1, 2, 3, 4, 5]);
        secret.zeroize();
        assert!(secret.secret_ref().iter().all(|&b| b == 0));
    }

    #[test]
    fn test_secure_shared_secret_drop_zeroizes_succeeds() {
        // Smoke test only: dropping must not panic. Memory contents after drop
        // cannot be inspected without `unsafe`, which this module denies.
        let secret = SecureSharedSecret::new(vec![99; 64]);
        drop(secret);
    }

    #[test]
    fn test_secure_shared_secret_ct_eq_is_correct() {
        let a = SecureSharedSecret::new(vec![1, 2, 3]);
        let same = SecureSharedSecret::new(vec![1, 2, 3]);
        let different = SecureSharedSecret::new(vec![1, 2, 4]);
        let shorter = SecureSharedSecret::new(vec![1, 2]);
        assert!(bool::from(a.ct_eq(&same)));
        assert!(!bool::from(a.ct_eq(&different)));
        assert!(!bool::from(a.ct_eq(&shorter)));
    }

    #[test]
    fn test_secure_shared_secret_debug_redacts_succeeds() {
        let secret = SecureSharedSecret::new(vec![0xAB; 8]);
        let debug_str = format!("{secret:?}");
        assert_eq!(debug_str, "SecureSharedSecret { data: \"[REDACTED]\" }");
        assert!(!debug_str.contains("171"));
    }

    #[test]
    fn test_pq_kex_mode_eq_is_correct() {
        assert_eq!(PqKexMode::RustlsPq, PqKexMode::RustlsPq);
        assert_eq!(PqKexMode::Classical, PqKexMode::Classical);
        assert_ne!(PqKexMode::RustlsPq, PqKexMode::Classical);
        assert_ne!(PqKexMode::CustomHybrid, PqKexMode::RustlsPq);
    }

    #[test]
    fn test_pq_kex_mode_debug_produces_expected_output_succeeds() {
        let debug_str = format!("{:?}", PqKexMode::CustomHybrid);
        assert!(debug_str.contains("CustomHybrid"));
    }
}