use alloc::vec::Vec;
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use core::{fmt, mem};
use pki_types::{ServerName, UnixTime};
use super::handy::NoClientSessionStorage;
use super::hs::{self, ClientHelloInput};
#[cfg(feature = "std")]
use crate::WantsVerifier;
use crate::builder::ConfigBuilder;
use crate::client::{EchMode, EchStatus};
use crate::common_state::{CommonState, Protocol, Side};
use crate::conn::{ConnectionCore, UnbufferedConnectionCommon};
use crate::crypto::{CryptoProvider, SupportedKxGroup};
use crate::enums::{CipherSuite, ProtocolVersion, SignatureScheme};
use crate::error::Error;
use crate::kernel::KernelConnection;
use crate::log::trace;
use crate::msgs::enums::NamedGroup;
use crate::msgs::handshake::ClientExtensionsInput;
use crate::msgs::persist;
use crate::suites::{ExtractedSecrets, SupportedCipherSuite};
use crate::sync::Arc;
#[cfg(feature = "std")]
use crate::time_provider::DefaultTimeProvider;
use crate::time_provider::TimeProvider;
use crate::unbuffered::{EncryptError, TransmitTlsData};
#[cfg(doc)]
use crate::{DistinguishedName, crypto};
use crate::{KeyLog, WantsVersions, compress, sign, verify, versions};
pub trait ClientSessionStore: fmt::Debug + Send + Sync {
fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup);
fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup>;
fn set_tls12_session(
&self,
server_name: ServerName<'static>,
value: persist::Tls12ClientSessionValue,
);
fn tls12_session(
&self,
server_name: &ServerName<'_>,
) -> Option<persist::Tls12ClientSessionValue>;
fn remove_tls12_session(&self, server_name: &ServerName<'static>);
fn insert_tls13_ticket(
&self,
server_name: ServerName<'static>,
value: persist::Tls13ClientSessionValue,
);
fn take_tls13_ticket(
&self,
server_name: &ServerName<'static>,
) -> Option<persist::Tls13ClientSessionValue>;
}
pub trait ResolvesClientCert: fmt::Debug + Send + Sync {
fn resolve(
&self,
root_hint_subjects: &[&[u8]],
sigschemes: &[SignatureScheme],
) -> Option<Arc<sign::CertifiedKey>>;
fn only_raw_public_keys(&self) -> bool {
false
}
fn has_certs(&self) -> bool;
}
#[derive(Clone, Debug)]
pub struct ClientConfig {
pub alpn_protocols: Vec<Vec<u8>>,
pub check_selected_alpn: bool,
pub resumption: Resumption,
pub max_fragment_size: Option<usize>,
pub client_auth_cert_resolver: Arc<dyn ResolvesClientCert>,
pub enable_sni: bool,
pub key_log: Arc<dyn KeyLog>,
pub enable_secret_extraction: bool,
pub enable_early_data: bool,
#[cfg(feature = "tls12")]
pub require_ems: bool,
pub time_provider: Arc<dyn TimeProvider>,
pub(super) provider: Arc<CryptoProvider>,
pub(super) versions: versions::EnabledVersions,
pub(super) verifier: Arc<dyn verify::ServerCertVerifier>,
pub cert_decompressors: Vec<&'static dyn compress::CertDecompressor>,
pub cert_compressors: Vec<&'static dyn compress::CertCompressor>,
pub cert_compression_cache: Arc<compress::CompressionCache>,
pub(super) ech_mode: Option<EchMode>,
}
impl ClientConfig {
#[cfg(feature = "std")]
pub fn builder() -> ConfigBuilder<Self, WantsVerifier> {
Self::builder_with_protocol_versions(versions::DEFAULT_VERSIONS)
}
#[cfg(feature = "std")]
pub fn builder_with_protocol_versions(
versions: &[&'static versions::SupportedProtocolVersion],
) -> ConfigBuilder<Self, WantsVerifier> {
Self::builder_with_provider(
CryptoProvider::get_default_or_install_from_crate_features().clone(),
)
.with_protocol_versions(versions)
.unwrap()
}
#[cfg(feature = "std")]
pub fn builder_with_provider(
provider: Arc<CryptoProvider>,
) -> ConfigBuilder<Self, WantsVersions> {
ConfigBuilder {
state: WantsVersions {},
provider,
time_provider: Arc::new(DefaultTimeProvider),
side: PhantomData,
}
}
pub fn builder_with_details(
provider: Arc<CryptoProvider>,
time_provider: Arc<dyn TimeProvider>,
) -> ConfigBuilder<Self, WantsVersions> {
ConfigBuilder {
state: WantsVersions {},
provider,
time_provider,
side: PhantomData,
}
}
pub fn fips(&self) -> bool {
let mut is_fips = self.provider.fips();
#[cfg(feature = "tls12")]
{
is_fips = is_fips && self.require_ems
}
if let Some(ech_mode) = &self.ech_mode {
is_fips = is_fips && ech_mode.fips();
}
is_fips
}
pub fn crypto_provider(&self) -> &Arc<CryptoProvider> {
&self.provider
}
pub fn dangerous(&mut self) -> danger::DangerousClientConfig<'_> {
danger::DangerousClientConfig { cfg: self }
}
pub(super) fn needs_key_share(&self) -> bool {
self.supports_version(ProtocolVersion::TLSv1_3)
}
pub(crate) fn supports_version(&self, v: ProtocolVersion) -> bool {
self.versions.contains(v)
&& self
.provider
.cipher_suites
.iter()
.any(|cs| cs.version().version == v)
}
#[cfg(feature = "std")]
pub(crate) fn supports_protocol(&self, proto: Protocol) -> bool {
self.provider
.cipher_suites
.iter()
.any(|cs| cs.usable_for_protocol(proto))
}
pub(super) fn find_cipher_suite(&self, suite: CipherSuite) -> Option<SupportedCipherSuite> {
self.provider
.cipher_suites
.iter()
.copied()
.find(|&scs| scs.suite() == suite)
}
pub(super) fn find_kx_group(
&self,
group: NamedGroup,
version: ProtocolVersion,
) -> Option<&'static dyn SupportedKxGroup> {
self.provider
.kx_groups
.iter()
.copied()
.find(|skxg| skxg.usable_for_version(version) && skxg.name() == group)
}
pub(super) fn current_time(&self) -> Result<UnixTime, Error> {
self.time_provider
.current_time()
.ok_or(Error::FailedToGetCurrentTime)
}
}
#[derive(Clone, Debug)]
pub struct Resumption {
pub(super) store: Arc<dyn ClientSessionStore>,
pub(super) tls12_resumption: Tls12Resumption,
}
impl Resumption {
#[cfg(feature = "std")]
pub fn in_memory_sessions(num: usize) -> Self {
Self {
store: Arc::new(super::handy::ClientSessionMemoryCache::new(num)),
tls12_resumption: Tls12Resumption::SessionIdOrTickets,
}
}
pub fn store(store: Arc<dyn ClientSessionStore>) -> Self {
Self {
store,
tls12_resumption: Tls12Resumption::SessionIdOrTickets,
}
}
pub fn disabled() -> Self {
Self {
store: Arc::new(NoClientSessionStorage),
tls12_resumption: Tls12Resumption::Disabled,
}
}
pub fn tls12_resumption(mut self, tls12: Tls12Resumption) -> Self {
self.tls12_resumption = tls12;
self
}
}
impl Default for Resumption {
fn default() -> Self {
#[cfg(feature = "std")]
let ret = Self::in_memory_sessions(256);
#[cfg(not(feature = "std"))]
let ret = Self::disabled();
ret
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Tls12Resumption {
Disabled,
SessionIdOnly,
SessionIdOrTickets,
}
pub(super) mod danger {
use super::ClientConfig;
use super::verify::ServerCertVerifier;
use crate::sync::Arc;
#[derive(Debug)]
pub struct DangerousClientConfig<'a> {
pub cfg: &'a mut ClientConfig,
}
impl DangerousClientConfig<'_> {
pub fn set_certificate_verifier(&mut self, verifier: Arc<dyn ServerCertVerifier>) {
self.cfg.verifier = verifier;
}
}
}
#[derive(Debug, PartialEq)]
enum EarlyDataState {
Disabled,
Ready,
Accepted,
AcceptedFinished,
Rejected,
}
#[derive(Debug)]
pub(super) struct EarlyData {
state: EarlyDataState,
left: usize,
}
impl EarlyData {
fn new() -> Self {
Self {
left: 0,
state: EarlyDataState::Disabled,
}
}
pub(super) fn is_enabled(&self) -> bool {
matches!(self.state, EarlyDataState::Ready | EarlyDataState::Accepted)
}
#[cfg(feature = "std")]
fn is_accepted(&self) -> bool {
matches!(
self.state,
EarlyDataState::Accepted | EarlyDataState::AcceptedFinished
)
}
pub(super) fn enable(&mut self, max_data: usize) {
assert_eq!(self.state, EarlyDataState::Disabled);
self.state = EarlyDataState::Ready;
self.left = max_data;
}
pub(super) fn rejected(&mut self) {
trace!("EarlyData rejected");
self.state = EarlyDataState::Rejected;
}
pub(super) fn accepted(&mut self) {
trace!("EarlyData accepted");
assert_eq!(self.state, EarlyDataState::Ready);
self.state = EarlyDataState::Accepted;
}
pub(super) fn finished(&mut self) {
trace!("EarlyData finished");
self.state = match self.state {
EarlyDataState::Accepted => EarlyDataState::AcceptedFinished,
_ => panic!("bad EarlyData state"),
}
}
fn check_write_opt(&mut self, sz: usize) -> Option<usize> {
match self.state {
EarlyDataState::Disabled => unreachable!(),
EarlyDataState::Ready | EarlyDataState::Accepted => {
let take = if self.left < sz {
mem::replace(&mut self.left, 0)
} else {
self.left -= sz;
sz
};
Some(take)
}
EarlyDataState::Rejected | EarlyDataState::AcceptedFinished => None,
}
}
}
#[cfg(feature = "std")]
mod connection {
use alloc::vec::Vec;
use core::fmt;
use core::ops::{Deref, DerefMut};
use std::io;
use pki_types::ServerName;
use super::{ClientConnectionData, ClientExtensionsInput};
use crate::ClientConfig;
use crate::client::EchStatus;
use crate::common_state::Protocol;
use crate::conn::{ConnectionCommon, ConnectionCore};
use crate::error::Error;
use crate::suites::ExtractedSecrets;
use crate::sync::Arc;
pub struct WriteEarlyData<'a> {
sess: &'a mut ClientConnection,
}
impl<'a> WriteEarlyData<'a> {
fn new(sess: &'a mut ClientConnection) -> Self {
WriteEarlyData { sess }
}
pub fn bytes_left(&self) -> usize {
self.sess
.inner
.core
.data
.early_data
.bytes_left()
}
}
impl io::Write for WriteEarlyData<'_> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.sess.write_early_data(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl super::EarlyData {
fn check_write(&mut self, sz: usize) -> io::Result<usize> {
self.check_write_opt(sz)
.ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))
}
fn bytes_left(&self) -> usize {
self.left
}
}
pub struct ClientConnection {
inner: ConnectionCommon<ClientConnectionData>,
}
impl fmt::Debug for ClientConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClientConnection")
.finish()
}
}
impl ClientConnection {
pub fn new(config: Arc<ClientConfig>, name: ServerName<'static>) -> Result<Self, Error> {
Self::new_with_alpn(config.clone(), name, config.alpn_protocols.clone())
}
pub fn new_with_alpn(
config: Arc<ClientConfig>,
name: ServerName<'static>,
alpn_protocols: Vec<Vec<u8>>,
) -> Result<Self, Error> {
Ok(Self {
inner: ConnectionCommon::from(ConnectionCore::for_client(
config,
name,
ClientExtensionsInput::from_alpn(alpn_protocols),
Protocol::Tcp,
)?),
})
}
pub fn early_data(&mut self) -> Option<WriteEarlyData<'_>> {
if self
.inner
.core
.data
.early_data
.is_enabled()
{
Some(WriteEarlyData::new(self))
} else {
None
}
}
pub fn is_early_data_accepted(&self) -> bool {
self.inner.core.is_early_data_accepted()
}
pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
self.inner.dangerous_extract_secrets()
}
pub fn ech_status(&self) -> EchStatus {
self.inner.core.data.ech_status
}
pub fn tls13_tickets_received(&self) -> u32 {
self.inner.tls13_tickets_received
}
pub fn fips(&self) -> bool {
self.inner.core.common_state.fips
}
fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
self.inner
.core
.data
.early_data
.check_write(data.len())
.map(|sz| {
self.inner
.send_early_plaintext(&data[..sz])
})
}
}
impl Deref for ClientConnection {
type Target = ConnectionCommon<ClientConnectionData>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for ClientConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
#[doc(hidden)]
impl<'a> TryFrom<&'a mut crate::Connection> for &'a mut ClientConnection {
type Error = ();
fn try_from(value: &'a mut crate::Connection) -> Result<Self, Self::Error> {
use crate::Connection::*;
match value {
Client(conn) => Ok(conn),
Server(_) => Err(()),
}
}
}
impl From<ClientConnection> for crate::Connection {
fn from(conn: ClientConnection) -> Self {
Self::Client(conn)
}
}
}
#[cfg(feature = "std")]
pub use connection::{ClientConnection, WriteEarlyData};
impl ConnectionCore<ClientConnectionData> {
pub(crate) fn for_client(
config: Arc<ClientConfig>,
name: ServerName<'static>,
extra_exts: ClientExtensionsInput<'static>,
proto: Protocol,
) -> Result<Self, Error> {
let mut common_state = CommonState::new(Side::Client);
common_state.set_max_fragment_size(config.max_fragment_size)?;
common_state.protocol = proto;
common_state.enable_secret_extraction = config.enable_secret_extraction;
common_state.fips = config.fips();
let mut data = ClientConnectionData::new();
let mut cx = hs::ClientContext {
common: &mut common_state,
data: &mut data,
sendable_plaintext: None,
};
let input = ClientHelloInput::new(name, &extra_exts, &mut cx, config)?;
let state = input.start_handshake(extra_exts, &mut cx)?;
Ok(Self::new(state, data, common_state))
}
#[cfg(feature = "std")]
pub(crate) fn is_early_data_accepted(&self) -> bool {
self.data.early_data.is_accepted()
}
}
pub struct UnbufferedClientConnection {
inner: UnbufferedConnectionCommon<ClientConnectionData>,
}
impl UnbufferedClientConnection {
pub fn new(config: Arc<ClientConfig>, name: ServerName<'static>) -> Result<Self, Error> {
Self::new_with_extensions(
config.clone(),
name,
ClientExtensionsInput::from_alpn(config.alpn_protocols.clone()),
)
}
pub fn new_with_alpn(
config: Arc<ClientConfig>,
name: ServerName<'static>,
alpn_protocols: Vec<Vec<u8>>,
) -> Result<Self, Error> {
Self::new_with_extensions(
config,
name,
ClientExtensionsInput::from_alpn(alpn_protocols),
)
}
fn new_with_extensions(
config: Arc<ClientConfig>,
name: ServerName<'static>,
extensions: ClientExtensionsInput<'static>,
) -> Result<Self, Error> {
Ok(Self {
inner: UnbufferedConnectionCommon::from(ConnectionCore::for_client(
config,
name,
extensions,
Protocol::Tcp,
)?),
})
}
#[deprecated = "dangerous_extract_secrets() does not support session tickets or \
key updates, use dangerous_into_kernel_connection() instead"]
pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
self.inner.dangerous_extract_secrets()
}
pub fn dangerous_into_kernel_connection(
self,
) -> Result<(ExtractedSecrets, KernelConnection<ClientConnectionData>), Error> {
self.inner
.core
.dangerous_into_kernel_connection()
}
pub fn tls13_tickets_received(&self) -> u32 {
self.inner.tls13_tickets_received
}
}
impl Deref for UnbufferedClientConnection {
type Target = UnbufferedConnectionCommon<ClientConnectionData>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for UnbufferedClientConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl TransmitTlsData<'_, ClientConnectionData> {
pub fn may_encrypt_early_data(&mut self) -> Option<MayEncryptEarlyData<'_>> {
if self
.conn
.core
.data
.early_data
.is_enabled()
{
Some(MayEncryptEarlyData { conn: self.conn })
} else {
None
}
}
}
pub struct MayEncryptEarlyData<'c> {
conn: &'c mut UnbufferedConnectionCommon<ClientConnectionData>,
}
impl MayEncryptEarlyData<'_> {
pub fn encrypt(
&mut self,
early_data: &[u8],
outgoing_tls: &mut [u8],
) -> Result<usize, EarlyDataError> {
let Some(allowed) = self
.conn
.core
.data
.early_data
.check_write_opt(early_data.len())
else {
return Err(EarlyDataError::ExceededAllowedEarlyData);
};
self.conn
.core
.common_state
.write_plaintext(early_data[..allowed].into(), outgoing_tls)
.map_err(|e| e.into())
}
}
#[derive(Debug)]
pub enum EarlyDataError {
ExceededAllowedEarlyData,
Encrypt(EncryptError),
}
impl From<EncryptError> for EarlyDataError {
fn from(v: EncryptError) -> Self {
Self::Encrypt(v)
}
}
impl fmt::Display for EarlyDataError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ExceededAllowedEarlyData => f.write_str("cannot send any more early data"),
Self::Encrypt(e) => fmt::Display::fmt(e, f),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for EarlyDataError {}
#[derive(Debug)]
pub struct ClientConnectionData {
pub(super) early_data: EarlyData,
pub(super) ech_status: EchStatus,
}
impl ClientConnectionData {
fn new() -> Self {
Self {
early_data: EarlyData::new(),
ech_status: EchStatus::NotOffered,
}
}
}
impl crate::conn::SideData for ClientConnectionData {}