use alloc::boxed::Box;
use alloc::vec::Vec;
use core::fmt;
use aws_lc_rs::rand::SystemRandom;
use aws_lc_rs::{agreement, kem};
use pki_types::FipsStatus;
use rustls::crypto::GetRandomFailed;
use rustls::crypto::kx::{
ActiveKeyExchange, CompletedKeyExchange, Hybrid, HybridLayout, NamedGroup, SharedSecret,
StartedKeyExchange, SupportedKxGroup,
};
use rustls::error::{Error, PeerMisbehaved};
/// Hybrid key exchange pairing X25519 with ML-KEM-768.
///
/// On the wire the post-quantum component precedes the classical one
/// (`post_quantum_first: true`). The client's post-quantum part is an
/// `MLKEM768_ENCAP_LEN`-byte encapsulation key. The server's post-quantum part
/// is an `MLKEM768_CIPHERTEXT_LEN`-byte ciphertext. The X25519 part is
/// `X25519_LEN` bytes in both directions.
pub static X25519MLKEM768: &dyn SupportedKxGroup = &Hybrid {
    name: NamedGroup::X25519MLKEM768,
    layout: HybridLayout {
        post_quantum_first: true,
        classical_share_len: X25519_LEN,
        post_quantum_client_share_len: MLKEM768_ENCAP_LEN,
        post_quantum_server_share_len: MLKEM768_CIPHERTEXT_LEN,
    },
    classical: X25519,
    post_quantum: MLKEM768,
};
/// Hybrid key exchange pairing secp256r1 (P-256) ECDH with ML-KEM-768.
///
/// Unlike `X25519MLKEM768`, the classical component comes first on the wire
/// (`post_quantum_first: false`). The classical part is a `SECP256R1_LEN`-byte
/// point. The post-quantum part is an encapsulation key from the client and a
/// ciphertext from the server.
pub static SECP256R1MLKEM768: &dyn SupportedKxGroup = &Hybrid {
    name: NamedGroup::secp256r1MLKEM768,
    layout: HybridLayout {
        post_quantum_first: false,
        classical_share_len: SECP256R1_LEN,
        post_quantum_client_share_len: MLKEM768_ENCAP_LEN,
        post_quantum_server_share_len: MLKEM768_CIPHERTEXT_LEN,
    },
    classical: SECP256R1,
    post_quantum: MLKEM768,
};
/// Marker type implementing standalone (non-hybrid) ML-KEM-768 key exchange.
#[derive(Debug)]
pub(crate) struct MlKem768;

/// ML-KEM-768 used on its own, identified as `NamedGroup::MLKEM768`.
pub static MLKEM768: &dyn SupportedKxGroup = &MlKem768;
impl SupportedKxGroup for MlKem768 {
    /// Always `NamedGroup::MLKEM768`.
    fn name(&self) -> NamedGroup {
        NamedGroup::MLKEM768
    }

    /// Defers to the provider-wide FIPS status reported by `super::fips()`.
    fn fips(&self) -> FipsStatus {
        super::fips()
    }

    /// Starts an exchange by generating a fresh ML-KEM-768 decapsulation key.
    /// The serialized encapsulation key is offered as our share.
    ///
    /// # Errors
    /// Returns `Error::General` if key generation fails, or if the
    /// encapsulation key cannot be obtained or serialized.
    fn start(&self) -> Result<StartedKeyExchange, Error> {
        let private = kem::DecapsulationKey::generate(&kem::ML_KEM_768)
            .map_err(|_| Error::General("key generation failed".into()))?;

        // Keep an owned copy of the public half so that `pub_key()` can later
        // return a borrowed slice of it.
        let public = private
            .encapsulation_key()
            .and_then(|ek| ek.key_bytes())
            .map_err(|_| Error::General("encaps failed".into()))?;

        let pending = Active {
            encaps_key_bytes: Vec::from(public.as_ref()),
            decaps_key: Box::new(private),
        };
        Ok(StartedKeyExchange::Single(Box::new(pending)))
    }

    /// Completes an exchange in one step against the peer's `client_share`.
    /// The share is parsed as an ML-KEM-768 encapsulation key and encapsulated
    /// to. The resulting ciphertext becomes our public share.
    ///
    /// # Errors
    /// Returns `PeerMisbehaved::InvalidKeyShare` (as `Error`) if the share is
    /// rejected or encapsulation fails.
    fn start_and_complete(&self, client_share: &[u8]) -> Result<CompletedKeyExchange, Error> {
        let peer_key = kem::EncapsulationKey::new(&kem::ML_KEM_768, client_share)
            .map_err(|_| PeerMisbehaved::InvalidKeyShare)?;
        let (ct, secret) = peer_key
            .encapsulate()
            .map_err(|_| PeerMisbehaved::InvalidKeyShare)?;

        Ok(CompletedKeyExchange {
            secret: SharedSecret::from(secret.as_ref()),
            pub_key: Vec::from(ct.as_ref()),
            group: self.name(),
        })
    }
}
/// In-progress standalone ML-KEM-768 exchange, created by `MlKem768::start`.
struct Active {
    // Private (decapsulation) key. `complete` uses it to recover the shared
    // secret from the peer's ciphertext.
    decaps_key: Box<kem::DecapsulationKey<kem::AlgorithmId>>,
    // Serialized encapsulation key, returned as-is by `pub_key`.
    encaps_key_bytes: Vec<u8>,
}
impl ActiveKeyExchange for Active {
    /// Always `NamedGroup::MLKEM768`.
    fn group(&self) -> NamedGroup {
        NamedGroup::MLKEM768
    }

    /// The serialized encapsulation key captured when the exchange started.
    fn pub_key(&self) -> &[u8] {
        &self.encaps_key_bytes
    }

    /// Decapsulates the peer's ciphertext with our private key.
    ///
    /// # Errors
    /// Returns `PeerMisbehaved::InvalidKeyShare` (as `Error`) if decapsulation
    /// fails.
    fn complete(self: Box<Self>, peer_pub_key: &[u8]) -> Result<SharedSecret, Error> {
        match self.decaps_key.decapsulate(peer_pub_key.into()) {
            Ok(secret) => Ok(SharedSecret::from(secret.as_ref())),
            Err(_) => Err(PeerMisbehaved::InvalidKeyShare.into()),
        }
    }
}
/// Size in bytes of an X25519 public key share.
const X25519_LEN: usize = 32;
/// Size in bytes of an uncompressed secp256r1 point: the `0x04` tag byte
/// followed by two 32-byte coordinates.
const SECP256R1_LEN: usize = 1 + 2 * 32;
/// Size in bytes of an ML-KEM-768 encapsulation key. This is the client's
/// post-quantum share in the hybrid layouts.
const MLKEM768_ENCAP_LEN: usize = 1184;
/// Size in bytes of an ML-KEM-768 ciphertext. This is the server's
/// post-quantum share in the hybrid layouts.
const MLKEM768_CIPHERTEXT_LEN: usize = 1088;
/// A classical (EC)DH key-exchange group backed by an `aws_lc_rs::agreement`
/// algorithm.
struct KxGroup {
    /// TLS identifier for this group.
    name: NamedGroup,
    /// Whether this group reports the provider's FIPS status. When false, the
    /// group always reports `FipsStatus::Unvalidated`.
    fips_allowed: bool,
    /// The aws-lc-rs agreement algorithm that performs the exchange.
    agreement_algorithm: &'static agreement::Algorithm,
    /// Quick syntactic check applied to the peer's share before key agreement.
    pub_key_validator: fn(point: &[u8]) -> bool,
}
impl SupportedKxGroup for KxGroup {
    /// The configured TLS group identifier.
    fn name(&self) -> NamedGroup {
        self.name
    }

    /// Groups marked `fips_allowed` defer to `super::fips()`. All others are
    /// `FipsStatus::Unvalidated`.
    fn fips(&self) -> FipsStatus {
        if self.fips_allowed {
            super::fips()
        } else {
            FipsStatus::Unvalidated
        }
    }

    /// Generates an ephemeral private key and its public share.
    ///
    /// # Errors
    /// Both key generation and public-key derivation failures are reported as
    /// `GetRandomFailed`.
    fn start(&self) -> Result<StartedKeyExchange, Error> {
        let algorithm = self.agreement_algorithm;
        let priv_key = agreement::EphemeralPrivateKey::generate(algorithm, &SystemRandom::new())
            .map_err(|_| GetRandomFailed)?;
        let pub_key = priv_key
            .compute_public_key()
            .map_err(|_| GetRandomFailed)?;

        let exchange = KeyExchange {
            name: self.name,
            agreement_algorithm: algorithm,
            pub_key_validator: self.pub_key_validator,
            priv_key,
            pub_key,
        };
        Ok(StartedKeyExchange::Single(Box::new(exchange)))
    }
}
/// Renders a `KxGroup` as the `Debug` form of its `NamedGroup` only (e.g.
/// `X25519`). The formatter and its flags are forwarded unchanged.
impl fmt::Debug for KxGroup {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.name, f)
    }
}
pub static X25519: &dyn SupportedKxGroup = &KxGroup {
name: NamedGroup::X25519,
agreement_algorithm: &agreement::X25519,
fips_allowed: false,
pub_key_validator: |point: &[u8]| point.len() == 32,
};
/// ECDH over NIST P-256 (secp256r1).
///
/// FIPS-allowed. A peer share must begin with the uncompressed-point tag.
pub static SECP256R1: &dyn SupportedKxGroup = &KxGroup {
    agreement_algorithm: &agreement::ECDH_P256,
    name: NamedGroup::secp256r1,
    pub_key_validator: uncompressed_point,
    fips_allowed: true,
};
/// ECDH over NIST P-384 (secp384r1).
///
/// FIPS-allowed. A peer share must begin with the uncompressed-point tag.
pub static SECP384R1: &dyn SupportedKxGroup = &KxGroup {
    agreement_algorithm: &agreement::ECDH_P384,
    name: NamedGroup::secp384r1,
    pub_key_validator: uncompressed_point,
    fips_allowed: true,
};
/// Returns `true` when `point` starts with the `0x04` uncompressed-point tag.
///
/// Only the leading byte is checked. Length and on-curve checks are left to
/// the agreement step that later parses the key.
fn uncompressed_point(point: &[u8]) -> bool {
    point.starts_with(&[0x04])
}
/// In-progress classical (EC)DH exchange, created by `KxGroup::start`.
struct KeyExchange {
    // Group identifier copied from the originating `KxGroup`.
    name: NamedGroup,
    // Algorithm used to interpret the peer's public key in `complete`.
    agreement_algorithm: &'static agreement::Algorithm,
    // Our ephemeral private key. `complete` moves it into the agreement call.
    priv_key: agreement::EphemeralPrivateKey,
    // Our public share, exposed via `pub_key`.
    pub_key: agreement::PublicKey,
    // Syntactic check run on the peer's share before key agreement.
    pub_key_validator: fn(&[u8]) -> bool,
}
impl ActiveKeyExchange for KeyExchange {
    /// The group this exchange was started for.
    fn group(&self) -> NamedGroup {
        self.name
    }

    /// Our public key share.
    fn pub_key(&self) -> &[u8] {
        self.pub_key.as_ref()
    }

    /// Validates the peer's share, then performs the key agreement.
    ///
    /// # Errors
    /// Returns `PeerMisbehaved::InvalidKeyShare` (as `Error`) if the validator
    /// rejects `peer` or the agreement fails.
    fn complete(self: Box<Self>, peer: &[u8]) -> Result<SharedSecret, Error> {
        let validate = self.pub_key_validator;
        if validate(peer) {
            let peer_key = agreement::UnparsedPublicKey::new(self.agreement_algorithm, peer);
            super::ring_shim::agree_ephemeral(self.priv_key, &peer_key)
                .map_err(|_| PeerMisbehaved::InvalidKeyShare.into())
        } else {
            Err(PeerMisbehaved::InvalidKeyShare.into())
        }
    }
}
#[cfg(test)]
mod tests {
    use std::format;

    /// A `KxGroup` debug-prints as nothing more than its group name.
    #[test]
    fn kxgroup_fmt_yields_name() {
        let rendered = format!("{:?}", super::X25519);
        assert_eq!(rendered, "X25519");
    }
}
#[cfg(bench)]
mod benchmarks {
    /// Each iteration runs one full exchange for `kxg`: start it, then
    /// complete it against our own public share.
    fn bench_any(b: &mut test::Bencher, kxg: &dyn super::SupportedKxGroup) {
        b.iter(|| {
            let active = kxg.start().unwrap().into_single();
            // Copy the share out first, because `complete` consumes `active`.
            let own_share = active.pub_key().to_vec();
            let secret = active.complete(&own_share).unwrap();
            test::black_box(secret);
        });
    }

    #[bench]
    fn bench_x25519(b: &mut test::Bencher) {
        bench_any(b, super::X25519);
    }

    #[bench]
    fn bench_ecdh_p256(b: &mut test::Bencher) {
        bench_any(b, super::SECP256R1);
    }

    #[bench]
    fn bench_ecdh_p384(b: &mut test::Bencher) {
        bench_any(b, super::SECP384R1);
    }
}