use crate::crypto::pqc::{
config::PqcConfig,
negotiation::{NegotiationResult, PqcNegotiator, filter_algorithms, order_by_preference},
tls_extensions::{NamedGroup, SignatureScheme},
types::*,
};
use std::sync::Arc;
use tracing::{debug, info, warn};
/// Handshake-side driver for PQC algorithm negotiation.
///
/// Decodes the raw `u16` group / signature-scheme codepoints carried in the
/// ClientHello and ServerHello into [`NamedGroup`] / [`SignatureScheme`] values
/// and hands them to a [`PqcNegotiator`], which produces the final
/// [`NegotiationResult`].
#[derive(Debug, Clone)]
pub struct PqcHandshakeExtension {
    /// Negotiator built from a clone of the `PqcConfig` passed to `new`; it
    /// receives the decoded client and server algorithm lists.
    negotiator: PqcNegotiator,
}
impl PqcHandshakeExtension {
pub fn new(config: Arc<PqcConfig>) -> Self {
let negotiator = PqcNegotiator::new((*config).clone());
Self { negotiator }
}
pub fn process_client_hello(
&mut self,
supported_groups: &[u16],
signature_schemes: &[u16],
) -> PqcResult<()> {
debug!("Processing ClientHello for PQC negotiation");
let client_groups: Vec<NamedGroup> = supported_groups
.iter()
.filter_map(|&code| NamedGroup::from_u16(code))
.collect();
let client_signatures: Vec<SignatureScheme> = signature_schemes
.iter()
.filter_map(|&code| SignatureScheme::from_u16(code))
.collect();
debug!(
"Client supports {} groups and {} signatures",
client_groups.len(),
client_signatures.len()
);
self.negotiator
.set_client_algorithms(client_groups, client_signatures);
Ok(())
}
pub fn process_server_hello(
&mut self,
server_groups: &[u16],
server_signatures: &[u16],
) -> PqcResult<NegotiationResult> {
debug!("Processing ServerHello for PQC negotiation");
let groups: Vec<NamedGroup> = server_groups
.iter()
.filter_map(|&code| NamedGroup::from_u16(code))
.collect();
let signatures: Vec<SignatureScheme> = server_signatures
.iter()
.filter_map(|&code| SignatureScheme::from_u16(code))
.collect();
self.negotiator.set_server_algorithms(groups, signatures);
let result = self.negotiator.negotiate();
if self.negotiator.should_fail(&result) {
warn!("Negotiation failed - no PQC algorithms: {}", result.reason);
return Err(PqcError::NegotiationFailed(result.reason));
}
info!("PQC negotiation successful: {}", result.reason);
Ok(result)
}
pub fn get_client_algorithms(&self) -> (Vec<u16>, Vec<u16>) {
let all_groups = Self::all_supported_groups();
let all_signatures = Self::all_supported_signatures();
let (mut groups, mut signatures) = filter_algorithms(&all_groups, &all_signatures);
order_by_preference(&mut groups, &mut signatures);
let group_codes: Vec<u16> = groups.iter().map(|g| g.to_u16()).collect();
let sig_codes: Vec<u16> = signatures.iter().map(|s| s.to_u16()).collect();
(group_codes, sig_codes)
}
fn all_supported_groups() -> Vec<NamedGroup> {
vec![
NamedGroup::MlKem768, NamedGroup::MlKem1024, NamedGroup::MlKem512, ]
}
fn all_supported_signatures() -> Vec<SignatureScheme> {
vec![
SignatureScheme::MlDsa65, SignatureScheme::MlDsa87, SignatureScheme::MlDsa44, ]
}
}
/// Extension trait for attaching a shared [`PqcConfig`] to a server-side
/// configuration value.
///
/// No implementations are visible in this file; the exact effect of applying
/// the config is defined by each implementor.
pub trait PqcServerConfig {
    /// Consumes `self` and returns a `Self` that carries `config`
    /// (builder-style; semantics are implementor-defined).
    fn with_pqc_config(self, config: Arc<PqcConfig>) -> Self;
}
/// Extension trait for attaching a shared [`PqcConfig`] to a client-side
/// configuration value.
///
/// No implementations are visible in this file; the exact effect of applying
/// the config is defined by each implementor.
pub trait PqcClientConfig {
    /// Consumes `self` and returns a `Self` that carries `config`
    /// (builder-style; semantics are implementor-defined).
    fn with_pqc_config(self, config: Arc<PqcConfig>) -> Self;
}
/// Records the outcome of PQC negotiation for a handshake.
///
/// `Default` yields a not-yet-started state with no algorithms selected.
#[derive(Debug, Clone, Default)]
pub struct PqcHandshakeState {
    /// `true` once a [`NegotiationResult`] has been applied via `update_from_result`.
    pub started: bool,
    /// Negotiated key-exchange group, if one was selected.
    pub key_exchange: Option<NamedGroup>,
    /// Negotiated signature scheme, if one was selected.
    pub signature_scheme: Option<SignatureScheme>,
    /// Whether the negotiation result reported that PQC was used.
    pub used_pqc: bool,
    /// The negotiation result's human-readable `reason`, once available.
    pub result_message: Option<String>,
}
impl PqcHandshakeState {
    /// Returns an empty state: not started, no algorithms, `used_pqc == false`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces every field with the data from `result` and marks the state started.
    pub fn update_from_result(&mut self, result: &NegotiationResult) {
        *self = Self {
            started: true,
            key_exchange: result.key_exchange,
            signature_scheme: result.signature_scheme,
            used_pqc: result.used_pqc,
            result_message: Some(result.reason.clone()),
        };
    }

    /// Reports whether the recorded negotiation used PQC.
    pub fn is_pqc(&self) -> bool {
        self.used_pqc
    }

    /// Human-readable summary of the selected algorithms, using their `Display` forms.
    pub fn selected_algorithms(&self) -> String {
        match (&self.key_exchange, &self.signature_scheme) {
            (None, None) => String::from("No algorithms selected"),
            (Some(group), None) => format!("{} (no signature)", group),
            (None, Some(scheme)) => format!("(no key exchange) + {}", scheme),
            (Some(group), Some(scheme)) => format!("{} + {}", group, scheme),
        }
    }
}
pub fn requires_larger_packets(state: &PqcHandshakeState) -> bool {
state.used_pqc
}
/// Rough upper-bound estimate, in bytes, of the handshake size for `state`.
///
/// The estimate is a fixed 4096-byte baseline plus the PQC material for each
/// selected algorithm:
/// - key exchange: ML-KEM encapsulation key (client share) + ciphertext
///   (server share), per FIPS 203 Table 3;
/// - signature: one ML-DSA signature, per FIPS 204 Table 2.
///
/// An algorithm that is `None` contributes nothing.
pub fn estimate_handshake_size(state: &PqcHandshakeState) -> usize {
    // Allowance for the classical parts of the handshake (headers, extensions, etc.).
    const BASE_HANDSHAKE_BYTES: usize = 4096;

    // Encapsulation key + ciphertext. ML-KEM-1024 previously used 3168, which is
    // its *decapsulation* key size and is never sent on the wire; 1568 + 1568 is
    // consistent with the other two entries.
    let kem_bytes = state.key_exchange.map_or(0, |group| match group {
        NamedGroup::MlKem512 => 800 + 768,
        NamedGroup::MlKem768 => 1184 + 1088,
        NamedGroup::MlKem1024 => 1568 + 1568,
    });

    let signature_bytes = state.signature_scheme.map_or(0, |scheme| match scheme {
        SignatureScheme::MlDsa44 => 2420,
        SignatureScheme::MlDsa65 => 3309,
        SignatureScheme::MlDsa87 => 4627,
    });

    BASE_HANDSHAKE_BYTES + kem_bytes + signature_bytes
}
// Unit tests for the PQC handshake extension, handshake state, and size helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly built extension has not yet recorded any client groups.
    // Reads the negotiator's field directly, so that field must be visible here.
    #[test]
    fn test_handshake_extension_creation() {
        let config = Arc::new(PqcConfig::default());
        let extension = PqcHandshakeExtension::new(config);
        assert!(extension.negotiator.client_groups.is_empty());
    }

    // Every known codepoint passed to process_client_hello is decoded and stored.
    #[test]
    fn test_process_client_hello_pure_pqc() {
        let config = Arc::new(PqcConfig::default());
        let mut extension = PqcHandshakeExtension::new(config);
        let groups = vec![
            NamedGroup::MlKem768.to_u16(),
            NamedGroup::MlKem1024.to_u16(),
        ];
        let signatures = vec![
            SignatureScheme::MlDsa65.to_u16(),
            SignatureScheme::MlDsa87.to_u16(),
        ];
        extension
            .process_client_hello(&groups, &signatures)
            .unwrap();
        assert_eq!(extension.negotiator.client_groups.len(), 2);
        assert_eq!(extension.negotiator.client_signatures.len(), 2);
    }

    // The advertised client lists contain only PQC algorithms, and the
    // preferred entries come first.
    #[test]
    fn test_get_client_algorithms_pure_pqc_only() {
        let config = Arc::new(PqcConfig::builder().build().unwrap());
        let extension = PqcHandshakeExtension::new(config);
        let (groups, signatures) = extension.get_client_algorithms();
        // Codepoints that fail to decode are skipped rather than failing the test.
        for &group_code in &groups {
            if let Some(group) = NamedGroup::from_u16(group_code) {
                assert!(group.is_pqc(), "Expected pure PQC group, got {:?}", group);
            }
        }
        for &sig_code in &signatures {
            if let Some(sig) = SignatureScheme::from_u16(sig_code) {
                assert!(sig.is_pqc(), "Expected pure PQC signature, got {:?}", sig);
            }
        }
        // NOTE(review): 0x0201 / 0x0905 are presumably the ML-KEM-768 / ML-DSA-65
        // codepoints (IANA TLS registries) — confirm against `to_u16`; comparing to
        // `NamedGroup::MlKem768.to_u16()` etc. would make the intent explicit.
        // Indexing [0] also panics, rather than failing cleanly, if either list is empty.
        assert_eq!(groups[0], 0x0201);
        assert_eq!(signatures[0], 0x0905);
    }

    // update_from_result copies the negotiated algorithms and flags into the state.
    #[test]
    fn test_handshake_state_pure_pqc() {
        let mut state = PqcHandshakeState::new();
        assert!(!state.started);
        assert!(!state.is_pqc());
        let result = NegotiationResult {
            key_exchange: Some(NamedGroup::MlKem768),
            signature_scheme: Some(SignatureScheme::MlDsa65),
            used_pqc: true,
            reason: "Test negotiation".to_string(),
        };
        state.update_from_result(&result);
        assert!(state.started);
        assert!(state.is_pqc());
        assert_eq!(state.key_exchange, Some(NamedGroup::MlKem768));
        assert_eq!(state.signature_scheme, Some(SignatureScheme::MlDsa65));
    }

    // requires_larger_packets mirrors the used_pqc flag.
    #[test]
    fn test_requires_larger_packets() {
        let mut state = PqcHandshakeState::new();
        assert!(!requires_larger_packets(&state));
        state.used_pqc = true;
        assert!(requires_larger_packets(&state));
    }

    // 4096 is the baseline. 2272 and 3309 are the increments estimate_handshake_size
    // adds for ML-KEM-768 and ML-DSA-65 respectively.
    #[test]
    fn test_estimate_handshake_size_pure_pqc() {
        let mut state = PqcHandshakeState::new();
        assert_eq!(estimate_handshake_size(&state), 4096);
        state.key_exchange = Some(NamedGroup::MlKem768);
        assert_eq!(estimate_handshake_size(&state), 4096 + 2272);
        state.signature_scheme = Some(SignatureScheme::MlDsa65);
        assert_eq!(estimate_handshake_size(&state), 4096 + 2272 + 3309);
    }

    // Relies on the Display impls of NamedGroup / SignatureScheme producing
    // "ML-KEM-768" and "ML-DSA-65".
    #[test]
    fn test_selected_algorithms_display_pure_pqc() {
        let mut state = PqcHandshakeState::new();
        assert_eq!(state.selected_algorithms(), "No algorithms selected");
        state.key_exchange = Some(NamedGroup::MlKem768);
        assert_eq!(state.selected_algorithms(), "ML-KEM-768 (no signature)");
        state.signature_scheme = Some(SignatureScheme::MlDsa65);
        assert_eq!(state.selected_algorithms(), "ML-KEM-768 + ML-DSA-65");
    }
}