use crate::crypto::pqc::tls_extensions::{NamedGroup, SignatureScheme};
use crate::crypto::pqc::types::*;
/// Locally configured post-quantum TLS negotiation parameters.
///
/// Both lists are kept in preference order: the selection methods on this
/// type walk them front to back and return the first entry the peer also
/// offers.
#[derive(Debug, Clone, PartialEq)]
pub struct PqcTlsExtension {
    /// Key-exchange groups this endpoint is willing to use, most preferred first.
    pub supported_groups: Vec<NamedGroup>,
    /// Signature schemes this endpoint is willing to use, most preferred first.
    pub supported_signatures: Vec<SignatureScheme>,
}
impl PqcTlsExtension {
    /// Creates an extension advertising the default ML-KEM groups and ML-DSA
    /// signature schemes.
    ///
    /// Preference order is ML-KEM-768, ML-KEM-1024, ML-KEM-512 for key
    /// exchange and ML-DSA-65, ML-DSA-87, ML-DSA-44 for signatures.
    pub fn new() -> Self {
        Self {
            supported_groups: vec![
                NamedGroup::MlKem768,
                NamedGroup::MlKem1024,
                NamedGroup::MlKem512,
            ],
            supported_signatures: vec![
                SignatureScheme::MlDsa65,
                SignatureScheme::MlDsa87,
                SignatureScheme::MlDsa44,
            ],
        }
    }

    /// Creates a PQC-only extension; currently identical to [`Self::new`].
    pub fn pqc_only() -> Self {
        Self::new()
    }

    /// Returns the configured key-exchange groups in preference order.
    pub fn supported_groups(&self) -> &[NamedGroup] {
        &self.supported_groups
    }

    /// Returns the configured signature schemes in preference order.
    pub fn supported_signatures(&self) -> &[SignatureScheme] {
        &self.supported_signatures
    }

    /// Picks the first locally configured group that also appears in
    /// `peer_groups`.
    ///
    /// The local list decides the order; the position of an entry in
    /// `peer_groups` is irrelevant. Returns `None` when nothing is shared.
    pub fn select_group(&self, peer_groups: &[NamedGroup]) -> Option<NamedGroup> {
        self.supported_groups
            .iter()
            .copied()
            .find(|group| peer_groups.contains(group))
    }

    /// Picks the first locally configured signature scheme that also appears
    /// in `peer_schemes`.
    ///
    /// The local list decides the order; the position of an entry in
    /// `peer_schemes` is irrelevant. Returns `None` when nothing is shared.
    pub fn select_signature(&self, peer_schemes: &[SignatureScheme]) -> Option<SignatureScheme> {
        self.supported_signatures
            .iter()
            .copied()
            .find(|scheme| peer_schemes.contains(scheme))
    }

    /// Reports whether `group` is in the local group list.
    pub fn supports_group(&self, group: NamedGroup) -> bool {
        self.supported_groups.contains(&group)
    }

    /// Reports whether `scheme` is in the local signature-scheme list.
    pub fn supports_signature(&self, scheme: SignatureScheme) -> bool {
        self.supported_signatures.contains(&scheme)
    }

    /// Negotiates a key-exchange group, accepting only groups for which
    /// `is_pqc()` is true.
    ///
    /// Returns [`NegotiationResult::Selected`] with the first local group that
    /// is post-quantum and offered by the peer, or
    /// [`NegotiationResult::Failed`] when there is none. Because the fields of
    /// this type are public, the local list may contain non-PQC entries; those
    /// are never selected here.
    pub fn negotiate_group(&self, peer_groups: &[NamedGroup]) -> NegotiationResult<NamedGroup> {
        // Filter lazily while scanning the local list: this yields the same
        // choice as pre-filtering the peer list, without a temporary Vec.
        self.supported_groups
            .iter()
            .copied()
            .find(|group| group.is_pqc() && peer_groups.contains(group))
            .map_or(NegotiationResult::Failed, NegotiationResult::Selected)
    }

    /// Negotiates a signature scheme, accepting only schemes for which
    /// `is_pqc()` is true.
    ///
    /// Returns [`NegotiationResult::Selected`] with the first local scheme
    /// that is post-quantum and offered by the peer, or
    /// [`NegotiationResult::Failed`] when there is none.
    pub fn negotiate_signature(
        &self,
        peer_schemes: &[SignatureScheme],
    ) -> NegotiationResult<SignatureScheme> {
        // Same single-pass approach as `negotiate_group`.
        self.supported_signatures
            .iter()
            .copied()
            .find(|scheme| scheme.is_pqc() && peer_schemes.contains(scheme))
            .map_or(NegotiationResult::Failed, NegotiationResult::Selected)
    }
}
/// Outcome of negotiating a single parameter with a peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NegotiationResult<T> {
/// Negotiation succeeded; carries the chosen value.
Selected(T),
/// Negotiation produced no value.
Failed,
}
impl<T> NegotiationResult<T> {
    /// Returns `true` for [`NegotiationResult::Selected`] and `false` for
    /// [`NegotiationResult::Failed`].
    pub fn is_success(&self) -> bool {
        match self {
            Self::Selected(_) => true,
            Self::Failed => false,
        }
    }

    /// Borrows the selected value, or returns `None` when negotiation failed.
    pub fn value(&self) -> Option<&T> {
        if let Self::Selected(chosen) = self {
            Some(chosen)
        } else {
            None
        }
    }
}
impl Default for PqcTlsExtension {
fn default() -> Self {
Self::new()
}
}
/// Encoding and decoding of group and signature-scheme lists.
///
/// Every list is serialized as a 2-byte big-endian payload length followed by
/// one 2-byte code point per entry.
pub mod wire_format {
    use super::*;

    /// Width in bytes of the big-endian length prefix.
    const LENGTH_PREFIX_LEN: usize = 2;
    /// Width in bytes of one encoded group or signature-scheme code point.
    const CODE_POINT_LEN: usize = 2;

    /// Total buffer sizes describing why a list failed framing validation.
    struct FramingError {
        expected: usize,
        actual: usize,
    }

    /// Validates the length prefix of `data` and returns the payload after it.
    ///
    /// Fails when the prefix is missing, when the declared length does not
    /// match the remaining bytes exactly, or when the declared length is not a
    /// whole number of code points.
    fn list_payload(data: &[u8]) -> Result<&[u8], FramingError> {
        let actual = data.len();
        let (declared, payload) = match data {
            [hi, lo, payload @ ..] => (usize::from(u16::from_be_bytes([*hi, *lo])), payload),
            _ => {
                return Err(FramingError {
                    expected: LENGTH_PREFIX_LEN,
                    actual,
                })
            }
        };
        if payload.len() != declared {
            return Err(FramingError {
                expected: LENGTH_PREFIX_LEN + declared,
                actual,
            });
        }
        // A dangling byte cannot be part of any code point. Reject it rather
        // than silently dropping it; `expected` reports the size the buffer
        // would have without the partial entry.
        let dangling = declared % CODE_POINT_LEN;
        if dangling != 0 {
            return Err(FramingError {
                expected: LENGTH_PREFIX_LEN + declared - dangling,
                actual,
            });
        }
        Ok(payload)
    }

    /// Serializes `groups` as a length-prefixed list, preserving their order.
    ///
    /// The length prefix is 16 bits wide, so the encoded payload must not
    /// exceed `u16::MAX` bytes. Debug builds assert this; release builds wrap
    /// the prefix, producing a list that decoders will reject.
    pub fn encode_supported_groups(groups: &[NamedGroup]) -> Vec<u8> {
        let payload_len = groups.len() * CODE_POINT_LEN;
        debug_assert!(
            payload_len <= usize::from(u16::MAX),
            "group list does not fit a 16-bit length prefix"
        );
        let mut encoded = Vec::with_capacity(LENGTH_PREFIX_LEN + payload_len);
        // Truncating cast is guarded by the assertion above.
        encoded.extend_from_slice(&(payload_len as u16).to_be_bytes());
        for group in groups {
            encoded.extend_from_slice(&group.to_bytes());
        }
        encoded
    }

    /// Parses a length-prefixed group list.
    ///
    /// Code points that `NamedGroup::from_bytes` rejects are skipped rather
    /// than treated as an error, so a peer advertising values this build does
    /// not recognise does not fail the whole list.
    ///
    /// # Errors
    ///
    /// Returns [`PqcError::InvalidKeySize`] when `data` is shorter than the
    /// length prefix, when the declared length does not match the remaining
    /// bytes, or when the declared length is odd.
    pub fn decode_supported_groups(data: &[u8]) -> Result<Vec<NamedGroup>, PqcError> {
        let payload = list_payload(data).map_err(|FramingError { expected, actual }| {
            PqcError::InvalidKeySize { expected, actual }
        })?;
        Ok(payload
            .chunks_exact(CODE_POINT_LEN)
            .filter_map(|code_point| NamedGroup::from_bytes(code_point).ok())
            .collect())
    }

    /// Serializes `schemes` as a length-prefixed list, preserving their order.
    ///
    /// The length prefix is 16 bits wide, so the encoded payload must not
    /// exceed `u16::MAX` bytes. Debug builds assert this; release builds wrap
    /// the prefix, producing a list that decoders will reject.
    pub fn encode_signature_schemes(schemes: &[SignatureScheme]) -> Vec<u8> {
        let payload_len = schemes.len() * CODE_POINT_LEN;
        debug_assert!(
            payload_len <= usize::from(u16::MAX),
            "signature scheme list does not fit a 16-bit length prefix"
        );
        let mut encoded = Vec::with_capacity(LENGTH_PREFIX_LEN + payload_len);
        // Truncating cast is guarded by the assertion above.
        encoded.extend_from_slice(&(payload_len as u16).to_be_bytes());
        for scheme in schemes {
            encoded.extend_from_slice(&scheme.to_bytes());
        }
        encoded
    }

    /// Parses a length-prefixed signature-scheme list.
    ///
    /// Code points that `SignatureScheme::from_bytes` rejects are skipped
    /// rather than treated as an error.
    ///
    /// # Errors
    ///
    /// Returns [`PqcError::InvalidSignatureSize`] when `data` is shorter than
    /// the length prefix, when the declared length does not match the
    /// remaining bytes, or when the declared length is odd.
    pub fn decode_signature_schemes(data: &[u8]) -> Result<Vec<SignatureScheme>, PqcError> {
        let payload = list_payload(data).map_err(|FramingError { expected, actual }| {
            PqcError::InvalidSignatureSize { expected, actual }
        })?;
        Ok(payload
            .chunks_exact(CODE_POINT_LEN)
            .filter_map(|code_point| SignatureScheme::from_bytes(code_point).ok())
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults put ML-KEM-768 and ML-DSA-65 first and include the larger
    /// parameter sets.
    #[test]
    fn test_pqc_extension_default_pure_pqc() {
        let ext = PqcTlsExtension::new();
        assert!(ext.supported_groups()[0].is_pqc());
        assert!(ext.supported_signatures()[0].is_pqc());
        assert_eq!(ext.supported_groups()[0], NamedGroup::MlKem768);
        assert_eq!(ext.supported_signatures()[0], SignatureScheme::MlDsa65);
        assert!(ext.supports_group(NamedGroup::MlKem768));
        assert!(ext.supports_group(NamedGroup::MlKem1024));
        assert!(ext.supports_signature(SignatureScheme::MlDsa65));
        assert!(ext.supports_signature(SignatureScheme::MlDsa87));
    }

    /// `pqc_only` builds the same lists as `new`.
    #[test]
    fn test_pqc_extension_pqc_only_same_as_new() {
        let ext1 = PqcTlsExtension::new();
        let ext2 = PqcTlsExtension::pqc_only();
        assert_eq!(ext1.supported_groups, ext2.supported_groups);
        assert_eq!(ext1.supported_signatures, ext2.supported_signatures);
    }

    #[test]
    fn test_negotiation_both_support_pure_pqc() {
        let ext = PqcTlsExtension::new();
        let peer_groups = vec![NamedGroup::MlKem768, NamedGroup::MlKem1024];
        let result = ext.negotiate_group(&peer_groups);
        assert!(result.is_success());
        assert_eq!(result.value(), Some(&NamedGroup::MlKem768));
    }

    /// An empty peer offer cannot be negotiated.
    #[test]
    fn test_negotiation_fails_no_pqc() {
        let ext = PqcTlsExtension::new();
        let peer_groups: Vec<NamedGroup> = vec![];
        let result = ext.negotiate_group(&peer_groups);
        assert!(!result.is_success());
        assert_eq!(result.value(), None);
    }

    /// The local preference order wins even when the peer lists its entries
    /// in a different order.
    #[test]
    fn test_negotiation_prefers_local_order() {
        let ext = PqcTlsExtension::new();
        let peer_groups = vec![NamedGroup::MlKem1024, NamedGroup::MlKem768];
        let result = ext.negotiate_group(&peer_groups);
        assert_eq!(result.value(), Some(&NamedGroup::MlKem768));

        let peer_schemes = vec![SignatureScheme::MlDsa87, SignatureScheme::MlDsa65];
        let result = ext.negotiate_signature(&peer_schemes);
        assert_eq!(result.value(), Some(&SignatureScheme::MlDsa65));
    }

    #[test]
    fn test_negotiation_signature_pure_pqc() {
        let ext = PqcTlsExtension::new();
        let peer_schemes = vec![SignatureScheme::MlDsa65, SignatureScheme::MlDsa87];
        let result = ext.negotiate_signature(&peer_schemes);
        assert!(result.is_success());
        assert_eq!(result.value(), Some(&SignatureScheme::MlDsa65));
    }

    /// An empty peer offer cannot be negotiated for signatures either.
    #[test]
    fn test_negotiation_signature_fails_empty_peer() {
        let ext = PqcTlsExtension::new();
        let peer_schemes: Vec<SignatureScheme> = vec![];
        let result = ext.negotiate_signature(&peer_schemes);
        assert!(!result.is_success());
        assert_eq!(result.value(), None);
    }

    #[test]
    fn test_wire_format_encoding_pure_pqc() {
        use wire_format::*;
        let groups = vec![NamedGroup::MlKem768, NamedGroup::MlKem1024];
        let encoded = encode_supported_groups(&groups);
        assert_eq!(encoded.len(), 2 + 4);
        let decoded = decode_supported_groups(&encoded).unwrap();
        assert_eq!(decoded, groups);
    }

    #[test]
    fn test_wire_format_signature_schemes() {
        use wire_format::*;
        let schemes = vec![SignatureScheme::MlDsa65, SignatureScheme::MlDsa87];
        let encoded = encode_signature_schemes(&schemes);
        assert_eq!(encoded.len(), 2 + 4);
        let decoded = decode_signature_schemes(&encoded).unwrap();
        assert_eq!(decoded, schemes);
    }

    /// Buffers without a full length prefix, or whose declared length does
    /// not match the bytes present, are rejected by both decoders.
    #[test]
    fn test_wire_format_rejects_bad_framing() {
        use wire_format::*;
        assert!(decode_supported_groups(&[]).is_err());
        assert!(decode_supported_groups(&[0x00]).is_err());
        assert!(decode_supported_groups(&[0x00, 0x04, 0x00, 0x00]).is_err());
        assert!(decode_signature_schemes(&[]).is_err());
        assert!(decode_signature_schemes(&[0x00]).is_err());
        assert!(decode_signature_schemes(&[0x00, 0x04, 0x00, 0x00]).is_err());
    }

    /// An empty list encodes to a bare zero length prefix and decodes back to
    /// an empty list.
    #[test]
    fn test_wire_format_empty_list_round_trip() {
        use wire_format::*;
        let groups: Vec<NamedGroup> = Vec::new();
        let encoded = encode_supported_groups(&groups);
        assert_eq!(encoded, vec![0u8, 0u8]);
        assert!(decode_supported_groups(&encoded).unwrap().is_empty());

        let schemes: Vec<SignatureScheme> = Vec::new();
        let encoded = encode_signature_schemes(&schemes);
        assert_eq!(encoded, vec![0u8, 0u8]);
        assert!(decode_signature_schemes(&encoded).unwrap().is_empty());
    }
}