use alloc::borrow::ToOwned;
use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::Deref;
use pki_types::ServerName;
#[cfg(feature = "tls12")]
use super::tls12;
use super::{ResolvesClientCert, Tls12Resumption};
use crate::SupportedCipherSuite;
#[cfg(feature = "logging")]
use crate::bs_debug;
use crate::check::inappropriate_handshake_message;
use crate::client::client_conn::ClientConnectionData;
use crate::client::common::ClientHelloDetails;
use crate::client::ech::EchState;
use crate::client::{ClientConfig, EchMode, EchStatus, tls13};
use crate::common_state::{CommonState, HandshakeKind, KxState, State};
use crate::conn::ConnectionRandoms;
use crate::crypto::{ActiveKeyExchange, KeyExchangeAlgorithm};
use crate::enums::{
AlertDescription, CertificateType, CipherSuite, ContentType, HandshakeType, ProtocolVersion,
};
use crate::error::{Error, PeerIncompatible, PeerMisbehaved};
use crate::hash_hs::HandshakeHashBuffer;
use crate::log::{debug, trace};
use crate::msgs::base::Payload;
use crate::msgs::enums::{Compression, ExtensionType};
use crate::msgs::handshake::{
CertificateStatusRequest, ClientExtensions, ClientExtensionsInput, ClientHelloPayload,
ClientSessionTicket, EncryptedClientHello, HandshakeMessagePayload, HandshakePayload,
HelloRetryRequest, KeyShareEntry, ProtocolName, PskKeyExchangeModes, Random, ServerNamePayload,
SessionId, SupportedEcPointFormats, SupportedProtocolVersions, TransportParameters,
};
use crate::msgs::message::{Message, MessagePayload};
use crate::msgs::persist;
use crate::sync::Arc;
use crate::tls13::key_schedule::KeyScheduleEarly;
use crate::verify::ServerCertVerifier;
/// Boxed trait object for one state of the client handshake state machine.
pub(super) type NextState<'a> = Box<dyn State<ClientConnectionData> + 'a>;
/// Outcome of handling a message: the next state to move to, or an `Error`.
pub(super) type NextStateOrError<'a> = Result<NextState<'a>, Error>;
/// Shorthand for the shared `Context` type instantiated with client connection data.
pub(super) type ClientContext<'a> = crate::common_state::Context<'a, ClientConnectionData>;
/// State entered once a ClientHello has been sent and a ServerHello is awaited.
///
/// Constructed at the end of `emit_client_hello_for_retry`.
struct ExpectServerHello {
/// Inputs the ClientHello was built from (config, random, session id, ...).
input: ClientHelloInput,
/// Buffered handshake messages; converted into a running hash once the suite is known.
transcript_buffer: HandshakeHashBuffer,
/// Early key schedule obtained while preparing a TLS 1.3 resumption, if any.
early_data_key_schedule: Option<KeyScheduleEarly>,
/// Key exchange whose public share was placed in the ClientHello, if one was sent.
offered_key_share: Option<Box<dyn ActiveKeyExchange>>,
/// Cipher suite already fixed by a HelloRetryRequest; `None` for the first hello.
suite: Option<SupportedCipherSuite>,
/// Encrypted Client Hello state; `Some` only when the config enables ECH.
ech_state: Option<EchState>,
}
/// Like `ExpectServerHello`, but a HelloRetryRequest is accepted as well.
///
/// Used only for the first ClientHello when TLS 1.3 is offered.
struct ExpectServerHelloOrHelloRetryRequest {
/// State to continue with when a ServerHello arrives; its data is reused for a retry.
next: ExpectServerHello,
/// Caller-supplied extension inputs, kept so they can be sent again in a second ClientHello.
extra_exts: ClientExtensionsInput<'static>,
}
/// Everything needed to build (and, after a HelloRetryRequest, rebuild) a ClientHello.
pub(super) struct ClientHelloInput {
/// Client configuration used for this connection.
pub(super) config: Arc<ClientConfig>,
/// Cached session retrieved from the resumption store, if a usable one was found.
pub(super) resuming: Option<persist::Retrieved<ClientSessionValue>>,
/// Client random placed in the ClientHello.
pub(super) random: Random,
/// Flag handed to the TLS 1.3 helpers that emit the fake ChangeCipherSpec
/// (presumably tracks whether it has been sent already; confirm in `tls13`).
pub(super) sent_tls13_fake_ccs: bool,
/// Record of what was offered (ALPN protocols, sent extensions, extension order seed, ...).
pub(super) hello: ClientHelloDetails,
/// Session id placed in the ClientHello; a HelloRetryRequest must echo it.
pub(super) session_id: SessionId,
/// Name of the server this connection is for.
pub(super) server_name: ServerName<'static>,
/// ECH extension sent in the previous ClientHello, kept so a retry can carry it forward.
pub(super) prev_ech_ext: Option<EncryptedClientHello>,
}
impl ClientHelloInput {
    /// Gathers the per-connection inputs for a ClientHello.
    ///
    /// Looks up a cached session for `server_name`, decides which session id to
    /// send, records the offered ALPN protocols together with a random seed, and
    /// draws the client random.
    ///
    /// # Errors
    ///
    /// Propagates any failure of the configured secure random source.
    pub(super) fn new(
        server_name: ServerName<'static>,
        extra_exts: &ClientExtensionsInput<'_>,
        cx: &mut ClientContext<'_>,
        config: Arc<ClientConfig>,
    ) -> Result<Self, Error> {
        let mut resuming = ClientSessionValue::retrieve(&server_name, &config, cx);

        // Session id dictated by the cached session, if it dictates one at all.
        let resumed_session_id = if let Some(retrieved) = resuming.as_mut() {
            debug!("Resuming session");
            match &mut retrieved.value {
                #[cfg(feature = "tls12")]
                ClientSessionValue::Tls12(tls12_value) => {
                    // A TLS 1.2 session that carries a ticket gets a fresh random
                    // session id, which is also written back into the cached value.
                    if !tls12_value.ticket().0.is_empty() {
                        tls12_value.session_id =
                            SessionId::random(config.provider.secure_random)?;
                    }
                    Some(tls12_value.session_id)
                }
                _ => None,
            }
        } else {
            debug!("Not resuming any session");
            None
        };

        // Without a session-provided id: QUIC and configurations without TLS 1.3
        // send an empty id, everything else sends a random one.
        let session_id = match resumed_session_id {
            Some(id) => id,
            None => {
                if cx.common.is_quic() || !config.supports_version(ProtocolVersion::TLSv1_3) {
                    SessionId::empty()
                } else {
                    SessionId::random(config.provider.secure_random)?
                }
            }
        };

        let protocols = extra_exts
            .protocols
            .clone()
            .unwrap_or_default();
        let seed = crate::rand::random_u16(config.provider.secure_random)?;
        let hello = ClientHelloDetails::new(protocols, seed);

        let random = Random::new(config.provider.secure_random)?;

        Ok(Self {
            config,
            resuming,
            random,
            sent_tls13_fake_ccs: false,
            hello,
            session_id,
            server_name,
            prev_ech_ext: None,
        })
    }

    /// Sends the first ClientHello and returns the state that awaits the reply.
    ///
    /// Prepares the transcript buffer, the initial key share (when the config
    /// needs one) and the ECH state (when ECH is enabled), then delegates to
    /// `emit_client_hello_for_retry` with no retry request and no fixed suite.
    pub(super) fn start_handshake(
        self,
        extra_exts: ClientExtensionsInput<'static>,
        cx: &mut ClientContext<'_>,
    ) -> NextStateOrError<'static> {
        let mut transcript_buffer = HandshakeHashBuffer::new();
        let has_client_certs = self
            .config
            .client_auth_cert_resolver
            .has_certs();
        if has_client_certs {
            transcript_buffer.set_client_auth_enabled();
        }

        let key_share = match self.config.needs_key_share() {
            true => Some(tls13::initial_key_share(
                &self.config,
                &self.server_name,
                &mut cx.common.kx_state,
            )?),
            false => None,
        };

        let ech_state = if let Some(EchMode::Enable(ech_config)) = self.config.ech_mode.as_ref()
        {
            Some(ech_config.state(self.server_name.clone(), &self.config)?)
        } else {
            None
        };

        emit_client_hello_for_retry(
            transcript_buffer,
            None,
            key_share,
            extra_exts,
            None,
            self,
            cx,
            ech_state,
        )
    }
}
/// Builds and sends a ClientHello, then returns the state that awaits the server's reply.
///
/// Used both for the initial hello (`retryreq == None`, `suite == None`) and for the
/// second hello sent in answer to a HelloRetryRequest (`retryreq == Some(..)`,
/// `suite == Some(..)`).
///
/// * `transcript_buffer` - handshake transcript so far; the new hello is appended to it.
/// * `retryreq` - the HelloRetryRequest being answered, if any.
/// * `key_share` - key exchange to offer; expected only when TLS 1.3 is offered.
/// * `extra_exts` - caller-supplied ALPN protocols and QUIC transport parameters.
/// * `suite` - cipher suite already selected by the HelloRetryRequest, if any.
/// * `input` - per-connection ClientHello inputs; updated with what was sent.
/// * `cx` - connection context used to send the message and record ECH status.
/// * `ech_state` - Encrypted Client Hello state, if ECH is enabled.
///
/// Returns `ExpectServerHelloOrHelloRetryRequest` for a first hello that offers
/// TLS 1.3, and `ExpectServerHello` otherwise.
fn emit_client_hello_for_retry(
mut transcript_buffer: HandshakeHashBuffer,
retryreq: Option<&HelloRetryRequest>,
key_share: Option<Box<dyn ActiveKeyExchange>>,
extra_exts: ClientExtensionsInput<'static>,
suite: Option<SupportedCipherSuite>,
mut input: ClientHelloInput,
cx: &mut ClientContext<'_>,
mut ech_state: Option<EchState>,
) -> NextStateOrError<'static> {
let config = &input.config;
// TLS 1.2 is never offered over QUIC or when ECH state is present.
let forbids_tls12 = cx.common.is_quic() || ech_state.is_some();
let supported_versions = SupportedProtocolVersions {
tls12: config.supports_version(ProtocolVersion::TLSv1_2) && !forbids_tls12,
tls13: config.supports_version(ProtocolVersion::TLSv1_3),
};
// Panics if no version at all would be offered (presumably ruled out when the
// config is built -- confirm against the config builder).
assert!(supported_versions.any(|_| true));
let mut exts = Box::new(ClientExtensions {
// Offer every key exchange group usable with at least one offered version.
named_groups: Some(
config
.provider
.kx_groups
.iter()
.filter(|skxg| supported_versions.any(|v| skxg.usable_for_version(v)))
.map(|skxg| skxg.name())
.collect(),
),
supported_versions: Some(supported_versions),
signature_schemes: Some(
config
.verifier
.supported_verify_schemes(),
),
extended_master_secret_request: Some(()),
certificate_status_request: Some(CertificateStatusRequest::build_ocsp()),
protocols: extra_exts.protocols.clone(),
..Default::default()
});
// QUIC transport parameters go into the final or the draft extension slot,
// depending on which variant the caller supplied.
match extra_exts.transport_parameters.clone() {
Some(TransportParameters::Quic(v)) => exts.transport_parameters = Some(v),
Some(TransportParameters::QuicDraft(v)) => exts.transport_parameters_draft = Some(v),
None => {}
};
// The certificate authorities hint is sent only when TLS 1.3 is offered.
if supported_versions.tls13 {
if let Some(cas_extension) = config.verifier.root_hint_subjects() {
exts.certificate_authority_names = Some(cas_extension.to_owned());
}
}
// The EC point formats extension is sent only if some configured group uses ECDHE.
if config
.provider
.kx_groups
.iter()
.any(|skxg| skxg.name().key_exchange_algorithm() == KeyExchangeAlgorithm::ECDHE)
{
exts.ec_point_formats = Some(SupportedEcPointFormats::default());
}
// With ECH state the outer name is always sent as SNI, regardless of `enable_sni`.
// Otherwise SNI is sent only when enabled and the server name is a DNS name.
exts.server_name = match (ech_state.as_ref(), config.enable_sni) {
(Some(ech_state), _) => Some(ServerNamePayload::from(&ech_state.outer_name)),
(None, true) => match &input.server_name {
ServerName::DnsName(dns_name) => Some(ServerNamePayload::from(dns_name)),
_ => None,
},
(None, false) => None,
};
if let Some(key_share) = &key_share {
debug_assert!(supported_versions.tls13);
let mut shares = vec![KeyShareEntry::new(key_share.group(), key_share.pub_key())];
// Unless a HelloRetryRequest named a specific group, also offer the hybrid
// component share when that component group is itself configured for TLS 1.3.
if !retryreq
.map(|rr| rr.key_share.is_some())
.unwrap_or_default()
{
if let Some((component_group, component_share)) =
key_share
.hybrid_component()
.filter(|(group, _)| {
config
.find_kx_group(*group, ProtocolVersion::TLSv1_3)
.is_some()
})
{
shares.push(KeyShareEntry::new(component_group, component_share));
}
}
exts.key_shares = Some(shares);
}
// Echo the cookie from the HelloRetryRequest, if it carried one.
if let Some(cookie) = retryreq.and_then(|hrr| hrr.cookie.as_ref()) {
exts.cookie = Some(cookie.clone());
}
// Only the PSK-with-(EC)DHE key exchange mode is offered.
if supported_versions.tls13 {
exts.preshared_key_modes = Some(PskKeyExchangeModes {
psk: false,
psk_dhe: true,
});
}
// Offer certificate compression when TLS 1.3 is offered and decompressors are
// configured, and remember whether the offer was made.
input.hello.offered_cert_compression =
if supported_versions.tls13 && !config.cert_decompressors.is_empty() {
exts.certificate_compression_algorithms = Some(
config
.cert_decompressors
.iter()
.map(|dec| dec.algorithm())
.collect(),
);
true
} else {
false
};
if config
.client_auth_cert_resolver
.only_raw_public_keys()
{
exts.client_certificate_types = Some(vec![CertificateType::RawPublicKey]);
}
if config
.verifier
.requires_raw_public_keys()
{
exts.server_certificate_types = Some(vec![CertificateType::RawPublicKey]);
}
// On a retry after ECH was rejected or GREASEd, carry forward the ECH extension
// from the previous hello. Note: `&` here is a non-short-circuiting AND of two
// side-effect-free booleans.
if matches!(cx.data.ech_status, EchStatus::Rejected | EchStatus::Grease) & retryreq.is_some() {
if let Some(prev_ech_ext) = input.prev_ech_ext.take() {
exts.encrypted_client_hello = Some(prev_ech_ext);
}
}
// Adds ticket / PSK related extensions and yields the TLS 1.3 session being resumed.
let tls13_session = prepare_resumption(&input.resuming, &mut exts, suite, cx, config);
// Reuse the seed stored with the hello details so a retried hello uses the same seed.
exts.order_seed = input.hello.extension_order_seed;
let mut cipher_suites: Vec<_> = config
.provider
.cipher_suites
.iter()
.filter_map(|cs| match cs.usable_for_protocol(cx.common.protocol) {
true => Some(cs.suite()),
false => None,
})
.collect();
// The renegotiation-info signalling suite is appended only when TLS 1.2 is offered.
if supported_versions.tls12 {
cipher_suites.push(CipherSuite::TLS_EMPTY_RENEGOTIATION_INFO_SCSV);
}
let mut chp_payload = ClientHelloPayload {
client_version: ProtocolVersion::TLSv1_2,
random: input.random,
session_id: input.session_id,
cipher_suites,
compression_methods: vec![Compression::Null],
extensions: exts,
};
// Compute a GREASE ECH extension up front when the config asks for ECH GREASE;
// whether it is actually used is decided in the match below.
let ech_grease_ext = config
.ech_mode
.as_ref()
.and_then(|mode| match mode {
EchMode::Grease(cfg) => Some(cfg.grease_ext(
config.provider.secure_random,
input.server_name.clone(),
&chp_payload,
)),
_ => None,
});
match (cx.data.ech_status, &mut ech_state) {
// Real ECH: replace the payload with the one produced by the ECH state and
// remember its ECH extension for a possible later hello.
(EchStatus::NotOffered | EchStatus::Offered, Some(ech_state)) => {
chp_payload = ech_state.ech_hello(chp_payload, retryreq, &tls13_session)?;
cx.data.ech_status = EchStatus::Offered;
input.prev_ech_ext = chp_payload
.encrypted_client_hello
.clone();
}
// No ECH state and nothing offered yet: attach the GREASE extension if there is one.
(EchStatus::NotOffered, None) => {
if let Some(grease_ext) = ech_grease_ext {
let grease_ext = grease_ext?;
chp_payload.encrypted_client_hello = Some(grease_ext.clone());
cx.data.ech_status = EchStatus::Grease;
input.prev_ech_ext = Some(grease_ext);
}
}
_ => {}
}
// Record which extensions were sent.
input.hello.sent_extensions = chp_payload.collect_used();
let mut chp = HandshakeMessagePayload(HandshakePayload::ClientHello(chp_payload));
let tls13_early_data_key_schedule = match (ech_state.as_mut(), tls13_session) {
// ECH + resumption: take the early key schedule stored in the ECH state.
(Some(ech_state), Some(tls13_session)) => ech_state
.early_data_key_schedule
.take()
.map(|schedule| (tls13_session.suite(), schedule)),
// Resumption without ECH: fill in the PSK binder of this hello.
(_, Some(tls13_session)) => Some((
tls13_session.suite(),
tls13::fill_in_psk_binder(&tls13_session, &transcript_buffer, &mut chp),
)),
_ => None,
};
let ch = Message {
// Record-layer version: TLS 1.0 for the initial hello, TLS 1.2 for a retry
// (RFC 8446 section 5.1 allows 0x0301 only for an initial ClientHello).
version: match retryreq {
Some(_) => ProtocolVersion::TLSv1_2,
None => ProtocolVersion::TLSv1_0,
},
payload: MessagePayload::handshake(chp),
};
// A fake ChangeCipherSpec precedes the second hello.
if retryreq.is_some() {
tls13::emit_fake_ccs(&mut input.sent_tls13_fake_ccs, cx.common);
}
trace!("Sending ClientHello {ch:#?}");
transcript_buffer.add_message(&ch);
cx.common.send_msg(ch, false);
// When early data is enabled, derive the early traffic secret from the hello just sent.
let early_data_key_schedule =
tls13_early_data_key_schedule.map(|(resuming_suite, schedule)| {
if !cx.data.early_data.is_enabled() {
return schedule;
}
// With ECH the inner hello's transcript and random are used instead of the outer ones.
let (transcript_buffer, random) = match &ech_state {
Some(ech_state) => (
&ech_state.inner_hello_transcript,
&ech_state.inner_hello_random.0,
),
None => (&transcript_buffer, &input.random.0),
};
tls13::derive_early_traffic_secret(
&*config.key_log,
cx,
resuming_suite.common.hash_provider,
&schedule,
&mut input.sent_tls13_fake_ccs,
transcript_buffer,
random,
);
schedule
});
let next = ExpectServerHello {
input,
transcript_buffer,
early_data_key_schedule,
offered_key_share: key_share,
suite,
ech_state,
};
// A HelloRetryRequest is acceptable only in reply to a first hello offering TLS 1.3.
Ok(if supported_versions.tls13 && retryreq.is_none() {
Box::new(ExpectServerHelloOrHelloRetryRequest {
next,
extra_exts: extra_exts.into_owned(),
})
} else {
Box::new(next)
})
}
/// Adds resumption-related extensions to `exts` and reports the TLS 1.3 session
/// that is being offered for resumption, if any.
///
/// * No cached session, or one with an empty ticket: a new TLS 1.2 ticket is
///   requested (when TLS 1.2 tickets are enabled) and `None` is returned.
/// * Cached TLS 1.2 session with a ticket: the ticket is offered (same
///   condition) and `None` is returned.
/// * Cached TLS 1.3 session: returned after `tls13::prepare_resumption` has
///   filled in `exts`, unless TLS 1.3 is disabled, `suite` is a TLS 1.2 suite,
///   or `suite` cannot resume from the session's suite.
fn prepare_resumption<'a>(
    resuming: &'a Option<persist::Retrieved<ClientSessionValue>>,
    exts: &mut ClientExtensions<'_>,
    suite: Option<SupportedCipherSuite>,
    cx: &mut ClientContext<'_>,
    config: &ClientConfig,
) -> Option<persist::Retrieved<&'a persist::Tls13ClientSessionValue>> {
    // Evaluated lazily, exactly where the ticket extension may be set.
    let tls12_tickets_enabled = || {
        config.supports_version(ProtocolVersion::TLSv1_2)
            && config.resumption.tls12_resumption == Tls12Resumption::SessionIdOrTickets
    };

    // Only a session that actually carries a ticket is of interest here.
    let Some(retrieved) = resuming
        .as_ref()
        .filter(|candidate| !candidate.ticket().is_empty())
    else {
        if tls12_tickets_enabled() {
            exts.session_ticket = Some(ClientSessionTicket::Request);
        }
        return None;
    };

    let tls13_session = match retrieved.map(|csv| csv.tls13()) {
        Some(session) => session,
        None => {
            // Not a TLS 1.3 session: offer the stored ticket instead.
            if tls12_tickets_enabled() {
                exts.session_ticket =
                    Some(ClientSessionTicket::Offer(Payload::new(retrieved.ticket())));
            }
            return None;
        }
    };

    if !config.supports_version(ProtocolVersion::TLSv1_3) {
        return None;
    }

    // A previously selected TLS 1.2 suite rules out TLS 1.3 resumption.
    let tls13_suite = match suite {
        None => None,
        Some(SupportedCipherSuite::Tls13(selected)) => Some(selected),
        #[cfg(feature = "tls12")]
        Some(SupportedCipherSuite::Tls12(_)) => return None,
    };

    // A previously selected suite must be able to resume from the session's suite.
    if let Some(selected) = tls13_suite {
        selected.can_resume_from(tls13_session.suite())?;
    }

    tls13::prepare_resumption(config, cx, &tls13_session, exts, tls13_suite.is_some());
    Some(tls13_session)
}
/// Records the ALPN protocol chosen by the server and validates the choice.
///
/// `selected` is stored in `common.alpn_protocol` before any check is made.
///
/// # Errors
///
/// * `PeerMisbehaved::SelectedUnofferedApplicationProtocol` (alert
///   `IllegalParameter`) when `check_selected_offered` is set and the chosen
///   protocol is not in `offered_protocols`.
/// * `Error::NoApplicationProtocol` (alert `NoApplicationProtocol`) on QUIC
///   connections when protocols were offered but none was chosen.
pub(super) fn process_alpn_protocol(
    common: &mut CommonState,
    offered_protocols: &[ProtocolName],
    selected: Option<&ProtocolName>,
    check_selected_offered: bool,
) -> Result<(), Error> {
    common.alpn_protocol = selected.map(ToOwned::to_owned);

    let chose_unoffered = match &common.alpn_protocol {
        Some(chosen) => check_selected_offered && !offered_protocols.contains(chosen),
        None => false,
    };
    if chose_unoffered {
        return Err(common.send_fatal_alert(
            AlertDescription::IllegalParameter,
            PeerMisbehaved::SelectedUnofferedApplicationProtocol,
        ));
    }

    // On QUIC, offering protocols and getting none back is a fatal error.
    let quic_without_alpn =
        common.is_quic() && common.alpn_protocol.is_none() && !offered_protocols.is_empty();
    if quic_without_alpn {
        return Err(common.send_fatal_alert(
            AlertDescription::NoApplicationProtocol,
            Error::NoApplicationProtocol,
        ));
    }

    debug!(
        "ALPN protocol is {:?}",
        common
            .alpn_protocol
            .as_ref()
            .map(|v| bs_debug::BsDebug(v.as_ref()))
    );
    Ok(())
}
/// Checks the server's `ServerCertificateType` answer against what the
/// configured verifier requires.
///
/// Thin wrapper around `process_cert_type_extension`; raw public keys are
/// expected exactly when `config.verifier.requires_raw_public_keys()` is true.
pub(super) fn process_server_cert_type_extension(
    common: &mut CommonState,
    config: &ClientConfig,
    server_cert_extension: Option<&CertificateType>,
) -> Result<Option<(ExtensionType, CertificateType)>, Error> {
    let expects_raw_public_keys = config
        .verifier
        .requires_raw_public_keys();
    let negotiated = server_cert_extension.copied();

    process_cert_type_extension(
        common,
        expects_raw_public_keys,
        negotiated,
        ExtensionType::ServerCertificateType,
    )
}
/// Checks the server's `ClientCertificateType` answer against what the
/// configured client certificate resolver provides.
///
/// Thin wrapper around `process_cert_type_extension`; raw public keys are
/// expected exactly when the resolver reports `only_raw_public_keys()`.
pub(super) fn process_client_cert_type_extension(
    common: &mut CommonState,
    config: &ClientConfig,
    client_cert_extension: Option<&CertificateType>,
) -> Result<Option<(ExtensionType, CertificateType)>, Error> {
    let expects_raw_public_keys = config
        .client_auth_cert_resolver
        .only_raw_public_keys();
    let negotiated = client_cert_extension.copied();

    process_cert_type_extension(
        common,
        expects_raw_public_keys,
        negotiated,
        ExtensionType::ClientCertificateType,
    )
}
impl State<ClientConnectionData> for ExpectServerHello {
/// Processes the ServerHello: settles the protocol version and cipher suite,
/// validates the server's choices, then hands over to the TLS 1.3 or TLS 1.2
/// specific handling.
///
/// Any message other than a ServerHello is rejected by `require_handshake_msg!`.
fn handle<'m>(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message<'m>,
) -> NextStateOrError<'m>
where
Self: 'm,
{
let server_hello =
require_handshake_msg!(m, HandshakeType::ServerHello, HandshakePayload::ServerHello)?;
trace!("We got ServerHello {server_hello:#?}");
use crate::ProtocolVersion::{TLSv1_2, TLSv1_3};
let config = &self.input.config;
let tls13_supported = config.supports_version(TLSv1_3);
// The `selected_version` extension overrides the legacy field only when the
// legacy field says TLS 1.2.
let server_version = if server_hello.legacy_version == TLSv1_2 {
server_hello
.selected_version
.unwrap_or(server_hello.legacy_version)
} else {
server_hello.legacy_version
};
let version = match server_version {
TLSv1_3 if tls13_supported => TLSv1_3,
TLSv1_2 if config.supports_version(TLSv1_2) => {
// Early data was sent but the server picked TLS 1.2: fail without an alert.
if cx.data.early_data.is_enabled() && cx.common.early_traffic {
return Err(PeerMisbehaved::OfferedEarlyDataWithOldProtocolVersion.into());
}
// TLS 1.2 must not be selected through the `selected_version` extension.
if server_hello.selected_version.is_some() {
return Err({
cx.common.send_fatal_alert(
AlertDescription::IllegalParameter,
PeerMisbehaved::SelectedTls12UsingTls13VersionExtension,
)
});
}
TLSv1_2
}
_ => {
// Distinguish "version known but disabled locally" from "version unsupported".
let reason = match server_version {
TLSv1_2 | TLSv1_3 => PeerIncompatible::ServerTlsVersionIsDisabledByOurConfig,
_ => PeerIncompatible::ServerDoesNotSupportTls12Or13,
};
return Err(cx
.common
.send_fatal_alert(AlertDescription::ProtocolVersion, reason));
}
};
// Only the null compression method was offered.
if server_hello.compression_method != Compression::Null {
return Err({
cx.common.send_fatal_alert(
AlertDescription::IllegalParameter,
PeerMisbehaved::SelectedUnofferedCompression,
)
});
}
// Extensions not sent by us are rejected, with RenegotiationInfo as the one exception.
let allowed_unsolicited = [ExtensionType::RenegotiationInfo];
if self
.input
.hello
.server_sent_unsolicited_extensions(server_hello, &allowed_unsolicited)
{
return Err(cx.common.send_fatal_alert(
AlertDescription::UnsupportedExtension,
PeerMisbehaved::UnsolicitedServerHelloExtension,
));
}
cx.common.negotiated_version = Some(version);
// ALPN is read from the ServerHello only when TLS 1.3 was not negotiated.
if !cx.common.is_tls13() {
process_alpn_protocol(
cx.common,
&self.input.hello.alpn_protocols,
server_hello
.selected_protocol
.as_ref()
.map(|s| s.as_ref()),
self.input.config.check_selected_alpn,
)?;
}
// The point formats extension may be absent, but if present must allow uncompressed.
if let Some(point_fmts) = &server_hello.ec_point_formats {
if !point_fmts.uncompressed {
return Err(cx.common.send_fatal_alert(
AlertDescription::HandshakeFailure,
PeerMisbehaved::ServerHelloMustOfferUncompressedEcPoints,
));
}
}
let suite = config
.find_cipher_suite(server_hello.cipher_suite)
.ok_or_else(|| {
cx.common.send_fatal_alert(
AlertDescription::HandshakeFailure,
PeerMisbehaved::SelectedUnofferedCipherSuite,
)
})?;
// The chosen suite must belong to the negotiated protocol version.
if version != suite.version().version {
return Err({
cx.common.send_fatal_alert(
AlertDescription::IllegalParameter,
PeerMisbehaved::SelectedUnusableCipherSuiteForVersion,
)
});
}
// After a HelloRetryRequest the suite may not change.
match self.suite {
Some(prev_suite) if prev_suite != suite => {
return Err({
cx.common.send_fatal_alert(
AlertDescription::IllegalParameter,
PeerMisbehaved::SelectedDifferentCipherSuiteAfterRetry,
)
});
}
_ => {
debug!("Using ciphersuite {suite:?}");
self.suite = Some(suite);
cx.common.suite = Some(suite);
}
}
// Start the transcript hash with the suite's hash and add the ServerHello to it.
let mut transcript = self
.transcript_buffer
.start_hash(suite.hash_provider());
transcript.add_message(&m);
let randoms = ConnectionRandoms::new(self.input.random, server_hello.random);
match suite {
SupportedCipherSuite::Tls13(suite) => {
tls13::handle_server_hello(
cx,
server_hello,
randoms,
suite,
transcript,
self.early_data_key_schedule,
// Panics if no key share was offered; one is expected whenever
// TLS 1.3 was offered (see `debug_assert!` in the hello builder).
self.offered_key_share.unwrap(),
&m,
self.ech_state,
self.input,
)
}
#[cfg(feature = "tls12")]
SupportedCipherSuite::Tls12(suite) => tls12::CompleteServerHelloHandling {
randoms,
transcript,
input: self.input,
}
.handle_server_hello(cx, suite, server_hello, tls13_supported),
}
}
/// This state borrows nothing from the message, so it already is `'static`.
fn into_owned(self: Box<Self>) -> NextState<'static> {
self
}
}
impl ExpectServerHelloOrHelloRetryRequest {
    /// Drops the retry-only data and continues as a plain `ExpectServerHello` state.
    fn into_expect_server_hello(self) -> NextState<'static> {
        Box::new(self.next)
    }

    /// Validates a HelloRetryRequest and answers it with a second ClientHello.
    ///
    /// The request is rejected with a fatal alert when it:
    /// * has no cookie and asks for a group whose share was already sent,
    /// * carries an empty cookie,
    /// * changes nothing (neither cookie nor key share),
    /// * does not echo our session id,
    /// * does not select TLS 1.3 in `supported_versions`,
    /// * names a cipher suite unknown to the configuration,
    /// * carries an ECH extension although ECH was not offered,
    /// * names a non-TLS-1.3 cipher suite while ECH state is present, or
    /// * asks for a key exchange group that is not configured for TLS 1.3.
    ///
    /// On success the transcript is converted to its post-HRR form, early data
    /// is marked rejected, a new key share is started if another group was
    /// requested, and `emit_client_hello_for_retry` sends the second hello.
    ///
    /// # Panics
    ///
    /// Panics if no key share was offered in the first hello; this state is
    /// only entered when TLS 1.3 was offered, in which case one is expected.
    fn handle_hello_retry_request(
        mut self,
        cx: &mut ClientContext<'_>,
        m: Message<'_>,
    ) -> NextStateOrError<'static> {
        let hrr = require_handshake_msg!(
            m,
            HandshakeType::HelloRetryRequest,
            HandshakePayload::HelloRetryRequest
        )?;
        trace!("Got HRR {hrr:?}");

        cx.common.check_aligned_handshake()?;

        let offered_key_share = self.next.offered_key_share.unwrap();
        let config = &self.next.input.config;

        // Without a cookie, asking again for a group we already sent a share
        // for (directly or as the hybrid component) is illegal.
        if let (None, Some(req_group)) = (&hrr.cookie, hrr.key_share) {
            let offered_hybrid = offered_key_share
                .hybrid_component()
                .and_then(|(group_name, _)| {
                    config.find_kx_group(group_name, ProtocolVersion::TLSv1_3)
                })
                .map(|skxg| skxg.name());

            if req_group == offered_key_share.group() || Some(req_group) == offered_hybrid {
                return Err({
                    cx.common.send_fatal_alert(
                        AlertDescription::IllegalParameter,
                        PeerMisbehaved::IllegalHelloRetryRequestWithOfferedGroup,
                    )
                });
            }
        }

        // A cookie, when present, must not be empty.
        if let Some(cookie) = &hrr.cookie {
            if cookie.0.is_empty() {
                return Err({
                    cx.common.send_fatal_alert(
                        AlertDescription::IllegalParameter,
                        PeerMisbehaved::IllegalHelloRetryRequestWithEmptyCookie,
                    )
                });
            }
        }

        // The request must change something.
        if hrr.cookie.is_none() && hrr.key_share.is_none() {
            return Err({
                cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::IllegalHelloRetryRequestWithNoChanges,
                )
            });
        }

        // The session id from our ClientHello must be echoed unchanged.
        if hrr.session_id != self.next.input.session_id {
            return Err({
                cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::IllegalHelloRetryRequestWithWrongSessionId,
                )
            });
        }

        // Only TLS 1.3 may be selected by a HelloRetryRequest.
        match hrr.supported_versions {
            Some(ProtocolVersion::TLSv1_3) => {
                cx.common.negotiated_version = Some(ProtocolVersion::TLSv1_3);
            }
            _ => {
                return Err({
                    cx.common.send_fatal_alert(
                        AlertDescription::IllegalParameter,
                        PeerMisbehaved::IllegalHelloRetryRequestWithUnsupportedVersion,
                    )
                });
            }
        }

        // The cipher suite must be one known to our configuration.
        let Some(cs) = config.find_cipher_suite(hrr.cipher_suite) else {
            return Err({
                cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::IllegalHelloRetryRequestWithUnofferedCipherSuite,
                )
            });
        };

        // ECH related extensions are illegal when ECH was not offered.
        if cx.data.ech_status == EchStatus::NotOffered && hrr.encrypted_client_hello.is_some() {
            return Err({
                cx.common.send_fatal_alert(
                    AlertDescription::UnsupportedExtension,
                    PeerMisbehaved::IllegalHelloRetryRequestWithInvalidEch,
                )
            });
        }

        // The HelloRetryRequest fixes the cipher suite for the rest of the handshake.
        cx.common.suite = Some(cs);
        cx.common.handshake_kind = Some(HandshakeKind::FullWithHelloRetryRequest);

        // When ECH was offered, check whether the server confirmed acceptance.
        match (self.next.ech_state.as_ref(), cs.tls13()) {
            (Some(ech_state), Some(tls13_cs)) => {
                if !ech_state.confirm_hrr_acceptance(hrr, tls13_cs, cx.common)? {
                    // Not confirmed: note the rejection and carry on with the handshake.
                    cx.data.ech_status = EchStatus::Rejected;
                }
            }
            (Some(_), None) => {
                // The suite comes from the peer and the lookup above is not
                // visibly restricted to TLS 1.3 suites, so this case must be
                // answered with an alert rather than a panic.
                return Err(cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::SelectedUnusableCipherSuiteForVersion,
                ));
            }
            _ => {}
        };

        // Restart the transcript in its post-HRR form and add the request to it.
        let transcript = self
            .next
            .transcript_buffer
            .start_hash(cs.hash_provider());
        let mut transcript_buffer = transcript.into_hrr_buffer();
        transcript_buffer.add_message(&m);

        // The ECH state keeps its own transcript, which needs the same update.
        if let Some(ech_state) = self.next.ech_state.as_mut() {
            ech_state.transcript_hrr_update(cs.hash_provider(), &m);
        }

        // Early data cannot survive a HelloRetryRequest.
        if cx.data.early_data.is_enabled() {
            cx.data.early_data.rejected();
        }

        // Start a new key exchange if a different group was requested;
        // otherwise the original share is offered again.
        let key_share = match hrr.key_share {
            Some(group) if group != offered_key_share.group() => {
                let Some(skxg) = config.find_kx_group(group, ProtocolVersion::TLSv1_3) else {
                    return Err(cx.common.send_fatal_alert(
                        AlertDescription::IllegalParameter,
                        PeerMisbehaved::IllegalHelloRetryRequestWithUnofferedNamedGroup,
                    ));
                };

                cx.common.kx_state = KxState::Start(skxg);
                skxg.start()?
            }
            _ => offered_key_share,
        };

        emit_client_hello_for_retry(
            transcript_buffer,
            Some(hrr),
            Some(key_share),
            self.extra_exts,
            Some(cs),
            self.next.input,
            cx,
            self.next.ech_state,
        )
    }
}
impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
    /// Dispatches on the received handshake message.
    ///
    /// A ServerHello is forwarded to the wrapped `ExpectServerHello` state, a
    /// HelloRetryRequest is handled by `handle_hello_retry_request`, and
    /// anything else is reported as an inappropriate handshake message.
    fn handle<'m>(
        self: Box<Self>,
        cx: &mut ClientContext<'_>,
        m: Message<'m>,
    ) -> NextStateOrError<'m>
    where
        Self: 'm,
    {
        let is_server_hello = matches!(
            m.payload,
            MessagePayload::Handshake {
                parsed: HandshakeMessagePayload(HandshakePayload::ServerHello(..)),
                ..
            }
        );
        if is_server_hello {
            return self
                .into_expect_server_hello()
                .handle(cx, m);
        }

        let is_hello_retry_request = matches!(
            m.payload,
            MessagePayload::Handshake {
                parsed: HandshakeMessagePayload(HandshakePayload::HelloRetryRequest(..)),
                ..
            }
        );
        if is_hello_retry_request {
            return self.handle_hello_retry_request(cx, m);
        }

        Err(inappropriate_handshake_message(
            &m.payload,
            &[ContentType::Handshake],
            &[HandshakeType::ServerHello, HandshakeType::HelloRetryRequest],
        ))
    }

    /// This state borrows nothing from the message, so it already is `'static`.
    fn into_owned(self: Box<Self>) -> NextState<'static> {
        self
    }
}
/// Shared check for the client/server certificate type extensions.
///
/// * Raw public keys expected and negotiated: returns the extension type paired
///   with `CertificateType::RawPublicKey`.
/// * Raw public keys expected but anything else (or nothing) negotiated: sends a
///   `HandshakeFailure` alert and returns
///   `PeerIncompatible::IncorrectCertificateTypeExtension`.
/// * Raw public keys not expected and not negotiated: returns `Ok(None)`.
///
/// # Panics
///
/// Panics if raw public keys were negotiated without being expected; the panic
/// message states that this is caught earlier as an unsolicited extension.
fn process_cert_type_extension(
    common: &mut CommonState,
    client_expects: bool,
    server_negotiated: Option<CertificateType>,
    extension_type: ExtensionType,
) -> Result<Option<(ExtensionType, CertificateType)>, Error> {
    let negotiated_raw_public_key =
        matches!(server_negotiated, Some(CertificateType::RawPublicKey));

    if client_expects {
        if negotiated_raw_public_key {
            return Ok(Some((extension_type, CertificateType::RawPublicKey)));
        }
        return Err(common.send_fatal_alert(
            AlertDescription::HandshakeFailure,
            Error::PeerIncompatible(PeerIncompatible::IncorrectCertificateTypeExtension),
        ));
    }

    if negotiated_raw_public_key {
        unreachable!("Caught by `PeerMisbehaved::UnsolicitedEncryptedExtension`")
    }

    Ok(None)
}
/// A cached client session of either protocol version, as read from the resumption store.
pub(super) enum ClientSessionValue {
/// Session state for TLS 1.3 resumption.
Tls13(persist::Tls13ClientSessionValue),
/// Session state for TLS 1.2 resumption (only with the `tls12` feature).
#[cfg(feature = "tls12")]
Tls12(persist::Tls12ClientSessionValue),
}
impl ClientSessionValue {
    /// Looks up a resumable session for `server_name`.
    ///
    /// A TLS 1.3 ticket is taken from the store first; with the `tls12`
    /// feature a TLS 1.2 session is tried next. The candidate is dropped if it
    /// is incompatible with the current verifier / client credentials, if the
    /// current time cannot be obtained, or if it has expired.
    ///
    /// On QUIC connections `cx.common.quic.params` is set from the found
    /// session (to `None` when that session is not a TLS 1.3 one).
    fn retrieve(
        server_name: &ServerName<'static>,
        config: &ClientConfig,
        cx: &mut ClientContext<'_>,
    ) -> Option<persist::Retrieved<Self>> {
        let store = &config.resumption.store;

        let candidate = store
            .take_tls13_ticket(server_name)
            .map(ClientSessionValue::Tls13);
        #[cfg(feature = "tls12")]
        let candidate = candidate.or_else(|| {
            store
                .tls12_session(server_name)
                .map(ClientSessionValue::Tls12)
        });

        let found = candidate
            .and_then(|value| {
                value.compatible_config(&config.verifier, &config.client_auth_cert_resolver)
            })
            .and_then(|value| {
                let now = match config.current_time() {
                    Ok(now) => now,
                    Err(_err) => {
                        debug!("Could not get current time: {_err}");
                        return None;
                    }
                };

                let retrieved = persist::Retrieved::new(value, now);
                (!retrieved.has_expired()).then_some(retrieved)
            });

        if found.is_none() {
            debug!("No cached session for {server_name:?}");
        }

        match &found {
            Some(session) if cx.common.is_quic() => {
                cx.common.quic.params = session
                    .tls13()
                    .map(|v| v.quic_params());
            }
            _ => {}
        }

        found
    }

    /// Returns the part of the session state shared by both protocol versions.
    fn common(&self) -> &persist::ClientSessionCommon {
        let shared = match self {
            Self::Tls13(tls13_value) => &tls13_value.common,
            #[cfg(feature = "tls12")]
            Self::Tls12(tls12_value) => &tls12_value.common,
        };
        shared
    }

    /// Returns the TLS 1.3 session value, or `None` for a TLS 1.2 session.
    fn tls13(&self) -> Option<&persist::Tls13ClientSessionValue> {
        match self {
            #[cfg(feature = "tls12")]
            Self::Tls12(_) => None,
            Self::Tls13(tls13_value) => Some(tls13_value),
        }
    }

    /// Returns `Some(self)` if the session is compatible with the given
    /// verifier and client credentials, `None` otherwise.
    fn compatible_config(
        self,
        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
        client_creds: &Arc<dyn ResolvesClientCert>,
    ) -> Option<Self> {
        let compatible = match &self {
            Self::Tls13(tls13_value) => {
                tls13_value.compatible_config(server_cert_verifier, client_creds)
            }
            #[cfg(feature = "tls12")]
            Self::Tls12(tls12_value) => {
                tls12_value.compatible_config(server_cert_verifier, client_creds)
            }
        };
        compatible.then_some(self)
    }
}
/// Lets a `ClientSessionValue` be used directly as its version-independent
/// `ClientSessionCommon` part.
impl Deref for ClientSessionValue {
    type Target = persist::ClientSessionCommon;

    fn deref(&self) -> &persist::ClientSessionCommon {
        Self::common(self)
    }
}