use std::{any::Any, io, str, sync::Arc};
use aws_lc_rs::aead;
use bytes::BytesMut;
pub use rustls::Error;
use rustls::{
self, CipherSuite,
client::danger::ServerCertVerifier,
pki_types::{CertificateDer, PrivateKeyDer, ServerName},
quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
};
#[cfg(feature = "platform-verifier")]
use rustls_platform_verifier::BuilderVerifierExt;
use std::sync::atomic::{AtomicBool, Ordering};
/// Process-wide debug flag, initially `false`.
///
/// Within this file it is set to `true` by `configured_provider` and
/// `configured_provider_with_pqc`, and read through `debug_kem_only_enabled`.
/// Nothing in this file ever resets it to `false`.
static DEBUG_KEM_ONLY: AtomicBool = AtomicBool::new(false);
use crate::{
ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
crypto::{
self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, ServerStartError,
UnsupportedVersion,
tls_extension_simulation::{
ExtensionAwareTlsSession, SimulatedExtensionContext, TlsExtensionHooks,
},
},
transport_parameters::TransportParameters,
};
use crate::crypto::pqc::PqcConfig;
/// Maps this crate's [`Side`] onto the matching `rustls::Side` variant.
impl From<Side> for rustls::Side {
    fn from(side: Side) -> Self {
        match side {
            Side::Server => Self::Server,
            Side::Client => Self::Client,
        }
    }
}
/// A QUIC crypto session that wraps a rustls QUIC [`Connection`].
pub struct TlsSession {
// QUIC version the session was created for; selects initial keys and the
// retry-integrity key/nonce pair.
version: Version,
// Becomes `true` the first time `read_handshake` reports that handshake
// data is available; gates `handshake_data`.
got_handshake_data: bool,
// Secrets stored when `write_handshake` sees a `KeyChange::OneRtt`; used by
// `next_1rtt_keys` to produce the next packet keys.
next_secrets: Option<Secrets>,
// The underlying rustls client or server QUIC connection.
inner: Connection,
// Suite handed to `initial_keys` when deriving Initial-packet keys.
suite: Suite,
}
impl TlsSession {
    /// Reports which endpoint role the wrapped rustls connection has.
    fn side(&self) -> Side {
        match &self.inner {
            Connection::Server(_) => Side::Server,
            Connection::Client(_) => Side::Client,
        }
    }
}
/// [`crypto::Session`] implementation that delegates to the wrapped rustls
/// QUIC connection.
impl crypto::Session for TlsSession {
    /// Derives the Initial keys for `dst_cid` using this session's version and suite.
    fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
        let cid = *dst_cid;
        initial_keys(self.version, cid, side, &self.suite)
    }

    /// Returns a boxed [`HandshakeData`] once `read_handshake` has flagged it as
    /// available, and `None` before that.
    fn handshake_data(&self) -> Option<Box<dyn Any>> {
        if self.got_handshake_data {
            let protocol: Option<Vec<u8>> = self.inner.alpn_protocol().map(|alpn| alpn.into());
            // Only the server side exposes a server name here.
            let server_name: Option<String> = match self.inner {
                Connection::Server(ref server) => server.server_name().map(|name| name.into()),
                Connection::Client(_) => None,
            };
            Some(Box::new(HandshakeData {
                protocol,
                server_name,
            }))
        } else {
            None
        }
    }

    /// Returns an owned copy of the peer's certificate chain, boxed as
    /// `Vec<CertificateDer<'static>>`, if rustls has one.
    fn peer_identity(&self) -> Option<Box<dyn Any>> {
        let chain = self.inner.peer_certificates()?;
        let owned: Vec<CertificateDer<'static>> = chain
            .iter()
            .map(|cert| cert.clone().into_owned())
            .collect();
        Some(Box::new(owned))
    }

    /// Returns the 0-RTT header and packet keys when rustls has them.
    fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
        let zero_rtt = self.inner.zero_rtt_keys()?;
        let header: Box<dyn HeaderKey> = Box::new(zero_rtt.header);
        let packet: Box<dyn crypto::PacketKey> = Box::new(zero_rtt.packet);
        Some((header, packet))
    }

    /// For client sessions, reports whether early data was accepted; `None`
    /// for any other connection kind.
    fn early_data_accepted(&self) -> Option<bool> {
        if let Connection::Client(ref client) = self.inner {
            Some(client.is_early_data_accepted())
        } else {
            None
        }
    }

    /// Whether the underlying rustls connection is still handshaking.
    fn is_handshaking(&self) -> bool {
        self.inner.is_handshaking()
    }

    /// Feeds handshake bytes to rustls.
    ///
    /// Returns `Ok(true)` exactly once: the first time ALPN, a server name, or
    /// handshake completion is observed. Otherwise returns `Ok(false)`.
    ///
    /// # Errors
    /// A rustls failure is mapped to a crypto [`TransportError`] carrying the
    /// pending TLS alert when there is one, or to `PROTOCOL_VIOLATION` otherwise.
    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
        if let Err(e) = self.inner.read_hs(buf) {
            let error = match self.inner.alert() {
                Some(alert) => TransportError {
                    code: TransportErrorCode::crypto(alert.into()),
                    frame: None,
                    reason: e.to_string(),
                },
                None => TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}")),
            };
            return Err(error);
        }

        if self.got_handshake_data {
            return Ok(false);
        }

        let have_server_name = match self.inner {
            Connection::Server(ref server) => server.server_name().is_some(),
            Connection::Client(_) => false,
        };
        let available =
            self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking();
        if available {
            self.got_handshake_data = true;
        }
        Ok(available)
    }

    /// Parses the peer's QUIC transport parameters, if rustls has received any.
    ///
    /// # Errors
    /// Propagates a decoding failure converted into a [`TransportError`].
    fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
        let raw = match self.inner.quic_transport_parameters() {
            Some(raw) => raw,
            None => return Ok(None),
        };
        TransportParameters::read(self.side(), &mut io::Cursor::new(raw))
            .map(Some)
            .map_err(|e| e.into())
    }

    /// Writes pending handshake bytes into `buf` and returns new keys when
    /// rustls reports a key change.
    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
        let change = self.inner.write_hs(buf)?;
        let keys = match change {
            KeyChange::OneRtt { keys, next } => {
                // Keep the secrets so later key updates can be derived.
                self.next_secrets = Some(next);
                keys
            }
            KeyChange::Handshake { keys } => keys,
        };
        Some(Keys {
            header: KeyPair {
                local: Box::new(keys.local.header),
                remote: Box::new(keys.remote.header),
            },
            packet: KeyPair {
                local: Box::new(keys.local.packet),
                remote: Box::new(keys.remote.packet),
            },
        })
    }

    /// Produces the next pair of 1-RTT packet keys, or `None` if no 1-RTT
    /// secrets have been stored yet.
    fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
        let next = self.next_secrets.as_mut()?.next_packet_keys();
        let local: Box<dyn crypto::PacketKey> = Box::new(next.local);
        let remote: Box<dyn crypto::PacketKey> = Box::new(next.remote);
        Some(KeyPair { local, remote })
    }

    /// Checks the 16-byte integrity tag that ends a Retry packet's `payload`.
    ///
    /// The authenticated data is the length-prefixed `orig_dst_cid`, then
    /// `header`, then `payload` without its tag. Returns `false` when the
    /// payload is shorter than the tag or verification fails.
    fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
        const TAG_LEN: usize = 16;
        if payload.len() < TAG_LEN {
            return false;
        }

        let cid_len = orig_dst_cid.len();
        let mut pseudo_packet = Vec::with_capacity(header.len() + payload.len() + cid_len + 1);
        pseudo_packet.push(cid_len as u8);
        pseudo_packet.extend_from_slice(orig_dst_cid);
        pseudo_packet.extend_from_slice(header);
        pseudo_packet.extend_from_slice(payload);
        // Everything before the trailing tag is additional authenticated data.
        let tag_start = pseudo_packet.len() - TAG_LEN;

        let (nonce, key) = match self.version {
            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
            _ => unreachable!(),
        };

        let nonce = aead::Nonce::assume_unique_for_key(nonce);
        let key = match aead::UnboundKey::new(&aead::AES_128_GCM, &key) {
            Err(_) => {
                debug_assert!(false, "Failed to create AEAD key for retry integrity");
                return false;
            }
            Ok(unbound_key) => aead::LessSafeKey::new(unbound_key),
        };

        let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
        key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
    }

    /// Fills `output` with keying material exported by rustls for `label`
    /// and `context`.
    ///
    /// # Errors
    /// Any rustls failure is reported as [`ExportKeyingMaterialError`].
    fn export_keying_material(
        &self,
        output: &mut [u8],
        label: &[u8],
        context: &[u8],
    ) -> Result<(), ExportKeyingMaterialError> {
        match self
            .inner
            .export_keying_material(output, label, Some(context))
        {
            Ok(_) => Ok(()),
            Err(_) => Err(ExportKeyingMaterialError),
        }
    }
}
// Fixed AES-128-GCM key used to compute/verify the Retry integrity tag when
// the session version is `Version::V1Draft`.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
// Fixed 12-byte nonce paired with `RETRY_INTEGRITY_KEY_DRAFT`.
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];
// Fixed AES-128-GCM key used for the Retry integrity tag when the version is
// `Version::V1` (the value published in RFC 9001, Section 5.8).
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
// Fixed 12-byte nonce paired with `RETRY_INTEGRITY_KEY_V1` (RFC 9001, Section 5.8).
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
/// Adapts a rustls header-protection key to this crate's [`crypto::HeaderKey`].
impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
    /// Removes header protection in place.
    ///
    /// The sample is read starting at `pn_offset + 4`; the first byte and up
    /// to four packet-number bytes starting at `pn_offset` are unmasked.
    /// A rustls failure only triggers a `debug_assert!`.
    fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
        let (protected, sample) = packet.split_at_mut(pn_offset + 4);
        let (first_byte, after_first) = protected.split_at_mut(1);
        let pn_end = (pn_offset + 3).min(after_first.len());
        let outcome = self.decrypt_in_place(
            &sample[..self.sample_size()],
            &mut first_byte[0],
            &mut after_first[pn_offset - 1..pn_end],
        );
        if let Err(e) = outcome {
            debug_assert!(false, "Header protection decrypt failed: {:?}", e);
        }
    }

    /// Applies header protection in place, using the same layout as `decrypt`.
    /// A rustls failure only triggers a `debug_assert!`.
    fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
        let (protected, sample) = packet.split_at_mut(pn_offset + 4);
        let (first_byte, after_first) = protected.split_at_mut(1);
        let pn_end = (pn_offset + 3).min(after_first.len());
        let outcome = self.encrypt_in_place(
            &sample[..self.sample_size()],
            &mut first_byte[0],
            &mut after_first[pn_offset - 1..pn_end],
        );
        if let Err(e) = outcome {
            debug_assert!(false, "Header protection encrypt failed: {:?}", e);
        }
    }

    /// Number of sample bytes the underlying rustls key requires.
    fn sample_size(&self) -> usize {
        self.sample_len()
    }
}
/// Handshake information exposed by `TlsSession::handshake_data`.
pub struct HandshakeData {
/// The negotiated ALPN protocol bytes, when rustls reports one.
pub protocol: Option<Vec<u8>>,
/// The server name seen by a server-side session; always `None` when
/// produced by a client-side session.
pub server_name: Option<String>,
}
/// Client-side QUIC crypto configuration wrapping a `rustls::ClientConfig`.
pub struct QuicClientConfig {
// Shared rustls configuration handed to every new client connection.
pub(crate) inner: Arc<rustls::ClientConfig>,
// Suite copied into each session for deriving Initial-packet keys.
initial: Suite,
// When set, new sessions are wrapped in an `ExtensionAwareTlsSession` that
// uses this context as its extension hooks.
pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
}
impl QuicClientConfig {
#[cfg(feature = "platform-verifier")]
pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
.with_protocol_versions(&[&rustls::version::TLS13])
.map_err(|_| Error::General("default providers should support TLS 1.3".into()))?
.with_platform_verifier()?
.with_no_client_auth();
inner.enable_early_data = true;
Ok(Self {
initial: initial_suite_from_provider(inner.crypto_provider())
.ok_or_else(|| Error::General("no initial cipher suite found".into()))?,
inner: Arc::new(inner),
extension_context: None,
})
}
pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Result<Self, Error> {
let inner = Self::inner(verifier)?;
Ok(Self {
initial: initial_suite_from_provider(inner.crypto_provider())
.ok_or_else(|| Error::General("no initial cipher suite found".into()))?,
inner: Arc::new(inner),
extension_context: None,
})
}
pub fn with_initial(
inner: Arc<rustls::ClientConfig>,
initial: Suite,
) -> Result<Self, NoInitialCipherSuite> {
match initial.suite.common.suite {
CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
inner,
initial,
extension_context: None,
}),
_ => Err(NoInitialCipherSuite { specific: true }),
}
}
pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
self.extension_context = Some(context);
self
}
pub(crate) fn inner(
verifier: Arc<dyn ServerCertVerifier>,
) -> Result<rustls::ClientConfig, Error> {
let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
.with_protocol_versions(&[&rustls::version::TLS13])
.map_err(|_| Error::General("default providers should support TLS 1.3".into()))?
.dangerous()
.with_custom_certificate_verifier(verifier)
.with_no_client_auth();
config.enable_early_data = true;
Ok(config)
}
}
impl crypto::ClientConfig for QuicClientConfig {
    /// Starts a client TLS session for `server_name` at QUIC `version`.
    ///
    /// When an extension context is configured, the session is wrapped in an
    /// [`ExtensionAwareTlsSession`] labelled with a timestamp-based id.
    ///
    /// # Errors
    /// Fails on an unsupported version, an invalid server name, transport
    /// parameters that cannot be encoded, or a rustls connection error.
    fn start_session(
        self: Arc<Self>,
        version: u32,
        server_name: &str,
        params: &TransportParameters,
    ) -> Result<Box<dyn crypto::Session>, ConnectError> {
        let version = interpret_version(version)?;
        let name = ServerName::try_from(server_name)
            .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
            .to_owned();
        let encoded_params = to_vec(params).map_err(ConnectError::TransportParameters)?;
        let connection = rustls::quic::ClientConnection::new(
            self.inner.clone(),
            version,
            name,
            encoded_params,
        )?;
        let inner_session = Box::new(TlsSession {
            version,
            got_handshake_data: false,
            next_secrets: None,
            inner: rustls::quic::Connection::Client(connection),
            suite: self.initial,
        });

        match &self.extension_context {
            None => Ok(inner_session),
            Some(extension_context) => {
                // A clock set before the Unix epoch degrades to a zero timestamp.
                let nanos = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map_or(0, |elapsed| elapsed.as_nanos());
                let conn_id = format!("client-{}-{}", server_name, nanos);
                Ok(Box::new(ExtensionAwareTlsSession::new(
                    inner_session,
                    extension_context.clone() as Arc<dyn TlsExtensionHooks>,
                    conn_id,
                    true,
                )))
            }
        }
    }
}
/// Converts an owned rustls client config by wrapping it in an `Arc` and
/// reusing the `Arc`-based conversion.
impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
    type Error = NoInitialCipherSuite;

    fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
        let shared = Arc::new(inner);
        <Self as TryFrom<Arc<rustls::ClientConfig>>>::try_from(shared)
    }
}
/// Wraps a shared rustls client config, taking the initial suite from its
/// crypto provider.
impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
    type Error = NoInitialCipherSuite;

    /// # Errors
    /// Returns [`NoInitialCipherSuite`] when the provider offers no usable
    /// initial suite.
    fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
        match initial_suite_from_provider(inner.crypto_provider()) {
            Some(initial) => Ok(Self {
                inner,
                initial,
                extension_context: None,
            }),
            None => Err(NoInitialCipherSuite { specific: false }),
        }
    }
}
/// Error returned when no acceptable initial cipher suite is available.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
// `true` when an explicitly supplied suite was rejected (`with_initial`);
// `false` when none could be found in the crypto provider (`TryFrom` impls).
specific: bool,
}
/// Human-readable message that distinguishes a rejected explicit suite from a
/// provider with no suitable suite.
impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}
// Marks `NoInitialCipherSuite` as a standard error type; all trait methods use their defaults.
impl std::error::Error for NoInitialCipherSuite {}
/// Server-side QUIC crypto configuration wrapping a `rustls::ServerConfig`.
pub struct QuicServerConfig {
// Shared rustls configuration handed to every new server connection.
inner: Arc<rustls::ServerConfig>,
// Suite used for deriving Initial-packet keys.
initial: Suite,
// When set, new sessions are wrapped in an `ExtensionAwareTlsSession` that
// uses this context as its extension hooks.
pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
}
impl QuicServerConfig {
pub(crate) fn new(
cert_chain: Vec<CertificateDer<'static>>,
key: PrivateKeyDer<'static>,
) -> Result<Self, rustls::Error> {
let inner = Self::inner(cert_chain, key)?;
Ok(Self {
initial: initial_suite_from_provider(inner.crypto_provider())
.ok_or_else(|| rustls::Error::General("no initial cipher suite found".into()))?,
inner: Arc::new(inner),
extension_context: None,
})
}
pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
self.extension_context = Some(context);
self
}
pub fn with_initial(
inner: Arc<rustls::ServerConfig>,
initial: Suite,
) -> Result<Self, NoInitialCipherSuite> {
match initial.suite.common.suite {
CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
inner,
initial,
extension_context: None,
}),
_ => Err(NoInitialCipherSuite { specific: true }),
}
}
pub(crate) fn inner(
cert_chain: Vec<CertificateDer<'static>>,
key: PrivateKeyDer<'static>,
) -> Result<rustls::ServerConfig, rustls::Error> {
let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
.with_protocol_versions(&[&rustls::version::TLS13])
.map_err(|_| rustls::Error::General("TLS 1.3 not supported".into()))? .with_no_client_auth()
.with_single_cert(cert_chain, key)?;
inner.max_early_data_size = u32::MAX;
Ok(inner)
}
}
/// Converts an owned rustls server config by wrapping it in an `Arc` and
/// reusing the `Arc`-based conversion.
impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
    type Error = NoInitialCipherSuite;

    fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
        let shared = Arc::new(inner);
        <Self as TryFrom<Arc<rustls::ServerConfig>>>::try_from(shared)
    }
}
/// Wraps a shared rustls server config, taking the initial suite from its
/// crypto provider.
impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
    type Error = NoInitialCipherSuite;

    /// # Errors
    /// Returns [`NoInitialCipherSuite`] when the provider offers no usable
    /// initial suite.
    fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
        match initial_suite_from_provider(inner.crypto_provider()) {
            Some(initial) => Ok(Self {
                inner,
                initial,
                extension_context: None,
            }),
            None => Err(NoInitialCipherSuite { specific: false }),
        }
    }
}
impl crypto::ServerConfig for QuicServerConfig {
#[allow(clippy::expect_used)]
fn start_session(
self: Arc<Self>,
version: u32,
params: &TransportParameters,
) -> Result<Box<dyn crypto::Session>, ServerStartError> {
let version = interpret_version(version).map_err(|_| {
ServerStartError::TlsError("Invalid QUIC version for server connection".into())
})?;
let inner_session = Box::new(TlsSession {
version,
got_handshake_data: false,
next_secrets: None,
inner: rustls::quic::Connection::Server(
rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params)?)
.map_err(|e| ServerStartError::TlsError(e.to_string()))?,
),
suite: self.initial,
});
if let Some(extension_context) = &self.extension_context {
let conn_id = format!(
"server-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_nanos()
);
Ok(Box::new(ExtensionAwareTlsSession::new(
inner_session,
extension_context.clone() as Arc<dyn TlsExtensionHooks>,
conn_id,
false, )))
} else {
Ok(inner_session)
}
}
fn initial_keys(
&self,
version: u32,
dst_cid: &ConnectionId,
) -> Result<Keys, UnsupportedVersion> {
let version = interpret_version(version)?;
Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
}
#[allow(clippy::expect_used)]
fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
let version = interpret_version(version).map_err(|_| {
rustls::Error::General("Invalid QUIC version for retry tag".into())
}).expect("Version should be valid at this point - retry_tag() is never called if initial_keys() rejected version");
let (nonce, key) = match version {
Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
_ => unreachable!(),
};
let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
pseudo_packet.push(orig_dst_cid.len() as u8);
pseudo_packet.extend_from_slice(orig_dst_cid);
pseudo_packet.extend_from_slice(packet);
let nonce = aead::Nonce::assume_unique_for_key(nonce);
let key = match aead::UnboundKey::new(&aead::AES_128_GCM, &key) {
Ok(unbound_key) => aead::LessSafeKey::new(unbound_key),
Err(_) => {
debug_assert!(false, "Failed to create AEAD key for retry integrity");
return [0; 16];
}
};
let tag =
match key.seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut []) {
Ok(tag) => tag,
Err(_) => {
debug_assert!(false, "Failed to seal retry integrity tag");
return [0; 16];
}
};
let mut result = [0; 16];
result.copy_from_slice(tag.as_ref());
result
}
}
/// Finds the QUIC suite for `TLS13_AES_128_GCM_SHA256` in `provider`.
///
/// The first provider entry that is this suite and has a TLS 1.3 form
/// decides the result; its `quic_suite()` value (possibly `None`) is returned.
/// Returns `None` when no such entry exists.
pub(crate) fn initial_suite_from_provider(
    provider: &Arc<rustls::crypto::CryptoProvider>,
) -> Option<Suite> {
    for candidate in provider.cipher_suites.iter() {
        if !matches!(
            candidate.suite(),
            rustls::CipherSuite::TLS13_AES_128_GCM_SHA256
        ) {
            continue;
        }
        if let Some(tls13) = candidate.tls13() {
            return tls13.quic_suite();
        }
    }
    None
}
/// Returns the default aws-lc-rs crypto provider behind an `Arc`.
///
/// As a side effect, sets the `DEBUG_KEM_ONLY` flag to `true`.
pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
    DEBUG_KEM_ONLY.store(true, Ordering::Relaxed);
    Arc::new(rustls::crypto::aws_lc_rs::default_provider())
}
/// Builds a crypto provider from `config`, or from `PqcConfig::default()`
/// when `config` is `None`.
///
/// If provider creation fails, a warning is logged and the result of
/// `configured_provider()` is returned instead. As a side effect, sets the
/// `DEBUG_KEM_ONLY` flag to `true`.
pub fn configured_provider_with_pqc(
    config: Option<&PqcConfig>,
) -> Arc<rustls::crypto::CryptoProvider> {
    DEBUG_KEM_ONLY.store(true, Ordering::Relaxed);
    let fallback_config = PqcConfig::default();
    let effective_config = config.unwrap_or(&fallback_config);
    crate::crypto::pqc::create_crypto_provider(effective_config).unwrap_or_else(|e| {
        tracing::warn!(
            "Failed to create PQC provider: {:?}, falling back to default",
            e
        );
        configured_provider()
    })
}
/// Validates the key-exchange group negotiated for a connection.
///
/// This is a thin public wrapper that forwards `negotiated_group` to
/// `crate::crypto::pqc::validate_negotiated_group` and returns its result.
pub fn validate_pqc_connection(
    negotiated_group: rustls::NamedGroup,
) -> Result<(), crate::crypto::pqc::PqcError> {
    use crate::crypto::pqc::validate_negotiated_group;

    validate_negotiated_group(negotiated_group)
}
/// Reads the current value of the `DEBUG_KEM_ONLY` flag.
pub fn debug_kem_only_enabled() -> bool {
    let flag = &DEBUG_KEM_ONLY;
    flag.load(Ordering::Relaxed)
}
/// Serializes `params` into a freshly allocated byte vector.
///
/// # Errors
/// Propagates any error reported by `TransportParameters::write`.
fn to_vec(params: &TransportParameters) -> Result<Vec<u8>, crate::transport_parameters::Error> {
    let mut encoded = Vec::new();
    match params.write(&mut encoded) {
        Ok(_) => Ok(encoded),
        Err(e) => Err(e.into()),
    }
}
pub(crate) fn initial_keys(
version: Version,
dst_cid: ConnectionId,
side: Side,
suite: &Suite,
) -> Keys {
let keys = suite.keys(&dst_cid, side.into(), version);
Keys {
header: KeyPair {
local: Box::new(keys.local.header),
remote: Box::new(keys.remote.header),
},
packet: KeyPair {
local: Box::new(keys.local.packet),
remote: Box::new(keys.remote.packet),
},
}
}
/// Adapts a rustls packet key to this crate's [`crypto::PacketKey`].
impl crypto::PacketKey for Box<dyn PacketKey> {
    /// Encrypts the payload of `buf` in place and writes the tag into the
    /// final `tag_len()` bytes.
    ///
    /// `buf` is laid out as header (`header_len` bytes), payload, then tag space.
    ///
    /// # Panics
    /// Panics if rustls reports an encryption failure or if `buf` is too short
    /// for the header and tag.
    #[allow(clippy::expect_used)]
    fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
        let (header, body) = buf.split_at_mut(header_len);
        let tag_offset = body.len() - self.tag_len();
        let (payload, tag_storage) = body.split_at_mut(tag_offset);
        let tag = self
            .encrypt_in_place(packet, &*header, payload)
            .map_err(|_| rustls::Error::General("Packet encryption failed".into()))
            .expect("Packet encryption should not fail with valid parameters");
        tag_storage.copy_from_slice(tag.as_ref());
    }

    /// Decrypts `payload` in place and truncates it to the plaintext length.
    ///
    /// # Errors
    /// Returns [`CryptoError`] when rustls rejects the packet.
    fn decrypt(
        &self,
        packet: u64,
        header: &[u8],
        payload: &mut BytesMut,
    ) -> Result<(), CryptoError> {
        let plain_len = self
            .decrypt_in_place(packet, header, payload.as_mut())
            .map_err(|_| CryptoError)?
            .len();
        payload.truncate(plain_len);
        Ok(())
    }

    /// Length in bytes of the authentication tag, as reported by rustls.
    fn tag_len(&self) -> usize {
        (**self).tag_len()
    }

    /// Confidentiality limit reported by the underlying rustls key.
    fn confidentiality_limit(&self) -> u64 {
        (**self).confidentiality_limit()
    }

    /// Integrity limit reported by the underlying rustls key.
    fn integrity_limit(&self) -> u64 {
        (**self).integrity_limit()
    }
}
/// Maps a wire-format QUIC version number to the rustls [`Version`] it uses.
///
/// `0xff00_001d..=0xff00_0020` map to `Version::V1Draft`;
/// `0x0000_0001` and `0xff00_0021..=0xff00_0022` map to `Version::V1`.
///
/// # Errors
/// Returns [`UnsupportedVersion`] for every other value.
fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
    let interpreted = match version {
        0x0000_0001 | 0xff00_0021..=0xff00_0022 => Version::V1,
        0xff00_001d..=0xff00_0020 => Version::V1Draft,
        _ => return Err(UnsupportedVersion),
    };
    Ok(interpreted)
}