use crate::builder::{ConfigBuilder, WantsCipherSuites};
use crate::conn::{CommonState, ConnectionCommon, Protocol, Side};
use crate::error::Error;
use crate::kx::SupportedKxGroup;
#[cfg(feature = "logging")]
use crate::log::trace;
#[cfg(feature = "quic")]
use crate::msgs::enums::AlertDescription;
use crate::msgs::enums::CipherSuite;
use crate::msgs::enums::ProtocolVersion;
use crate::msgs::enums::SignatureScheme;
use crate::msgs::handshake::ClientExtension;
use crate::sign;
use crate::suites::SupportedCipherSuite;
use crate::verify;
use crate::versions;
use crate::KeyLog;
use super::hs;
#[cfg(feature = "quic")]
use crate::quic;
use std::convert::TryFrom;
use std::error::Error as StdError;
use std::marker::PhantomData;
use std::net::IpAddr;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::{fmt, io, mem};
/// Byte-oriented key/value storage for client session data.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait StoresClientSessions: Send + Sync {
    /// Looks up the value previously stored under `key`, if any.
    fn get(&self, key: &[u8]) -> Option<Vec<u8>>;

    /// Stores `value` under `key`. The returned `bool` presumably reports whether
    /// the value was stored — NOTE(review): confirm against implementations.
    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool;
}
/// Chooses the certificate and key a client presents, if any.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait ResolvesClientCert: Send + Sync {
    /// Reports whether this resolver has any certificates available.
    fn has_certs(&self) -> bool;

    /// Picks a certified key given the acceptable issuers (each as raw bytes —
    /// presumably encoded distinguished names, TODO confirm) and the signature
    /// schemes on offer. Returns `None` when nothing suitable is available.
    fn resolve(
        &self,
        acceptable_issuers: &[&[u8]],
        sigschemes: &[SignatureScheme],
    ) -> Option<Arc<sign::CertifiedKey>>;
}
/// Configuration shared by client connections, which take it as `Arc<ClientConfig>`.
///
/// Obtain one through [`ClientConfig::builder`].
#[derive(Clone)]
pub struct ClientConfig {
    /// Cipher suites available to the client; searched by `supports_version`
    /// and `find_cipher_suite`.
    pub(super) cipher_suites: Vec<SupportedCipherSuite>,
    /// Key-exchange groups available to the client (not read in this file).
    pub(super) kx_groups: Vec<&'static SupportedKxGroup>,
    /// ALPN protocol identifiers, each as raw bytes (not read in this file).
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Byte-keyed storage for client session data.
    pub session_storage: Arc<dyn StoresClientSessions>,
    /// Optional maximum fragment size. It is handed to `CommonState::new` when
    /// a connection is created, and any error from that call is propagated.
    pub max_fragment_size: Option<usize>,
    /// Resolver used to pick a client certificate and key.
    pub client_auth_cert_resolver: Arc<dyn ResolvesClientCert>,
    /// Ticket toggle (not read in this file).
    pub enable_tickets: bool,
    /// Enabled protocol versions; consulted by `supports_version`.
    pub(super) versions: versions::EnabledVersions,
    /// SNI toggle (not read in this file).
    pub enable_sni: bool,
    /// Server certificate verifier. It can be replaced via the
    /// `dangerous_configuration` feature.
    pub(super) verifier: Arc<dyn verify::ServerCertVerifier>,
    /// Key-log sink (not read in this file).
    pub key_log: Arc<dyn KeyLog>,
    /// Early-data toggle (not read in this file).
    pub enable_early_data: bool,
}
impl ClientConfig {
    /// Starts building a client configuration. The builder begins in the
    /// [`WantsCipherSuites`] state.
    pub fn builder() -> ConfigBuilder<Self, WantsCipherSuites> {
        ConfigBuilder {
            side: PhantomData,
            state: WantsCipherSuites(()),
        }
    }

    /// True when `v` is an enabled version AND at least one configured cipher
    /// suite belongs to that version.
    #[doc(hidden)]
    pub fn supports_version(&self, v: ProtocolVersion) -> bool {
        if !self.versions.contains(v) {
            return false;
        }
        self.cipher_suites
            .iter()
            .any(|candidate| candidate.version().version == v)
    }

    /// Grants access to settings that are only exposed under the
    /// `dangerous_configuration` feature.
    #[cfg(feature = "dangerous_configuration")]
    pub fn dangerous(&mut self) -> danger::DangerousClientConfig {
        danger::DangerousClientConfig { cfg: self }
    }

    /// Returns the first configured cipher suite whose identifier equals `suite`.
    pub(super) fn find_cipher_suite(&self, suite: CipherSuite) -> Option<SupportedCipherSuite> {
        self.cipher_suites
            .iter()
            .find(|candidate| candidate.suite() == suite)
            .copied()
    }
}
/// The server a client connects to, identified either by DNS name or by IP
/// address. This enum is non-exhaustive, so more variants may be added.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub enum ServerName {
    /// A validated DNS name.
    DnsName(verify::DnsName),
    /// An IPv4 or IPv6 address.
    IpAddress(IpAddr),
}
impl ServerName {
    /// Returns the DNS name to use for SNI, or `None` for an IP address.
    pub(crate) fn for_sni(&self) -> Option<webpki::DnsNameRef> {
        match self {
            Self::IpAddress(_) => None,
            Self::DnsName(dns_name) => Some(dns_name.0.as_ref()),
        }
    }

    /// Serializes the name as `[type code, length, bytes...]`.
    ///
    /// The type code is 0x01 for a DNS name and 0x02 for an IP address. IP
    /// addresses are rendered in their textual form. The length is a single byte.
    /// This assumes the payload fits in a `u8`, which holds for validated DNS
    /// names and for textual IP addresses.
    pub(crate) fn encode(&self) -> Vec<u8> {
        enum UniqueTypeCode {
            DnsName = 0x01,
            IpAddr = 0x02,
        }

        /// Builds `code || len(payload) as u8 || payload`.
        fn tagged(code: UniqueTypeCode, payload: &[u8]) -> Vec<u8> {
            let mut out = Vec::with_capacity(2 + payload.len());
            out.push(code as u8);
            out.push(payload.len() as u8);
            out.extend_from_slice(payload);
            out
        }

        match self {
            Self::DnsName(dns_name) => {
                let name_ref = dns_name.0.as_ref();
                tagged(UniqueTypeCode::DnsName, name_ref.as_ref())
            }
            Self::IpAddress(address) => {
                let text = address.to_string();
                tagged(UniqueTypeCode::IpAddr, text.as_bytes())
            }
        }
    }
}
/// Parses a server name from text.
///
/// The string is tried as a DNS name first and as an IP address second. If
/// neither parse succeeds, an [`InvalidDnsNameError`] is returned.
impl TryFrom<&str> for ServerName {
    type Error = InvalidDnsNameError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        if let Ok(dns) = webpki::DnsNameRef::try_from_ascii_str(s) {
            return Ok(Self::DnsName(verify::DnsName(dns.into())));
        }
        s.parse()
            .map(Self::IpAddress)
            .map_err(|_| InvalidDnsNameError)
    }
}
/// Error returned when a string is neither a valid DNS name nor an IP address.
///
/// It is produced by the `TryFrom<&str>` implementation for [`ServerName`]. It
/// carries no data, so it derives the common value traits (copy, compare,
/// hash), which lets callers match on or store it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidDnsNameError;

impl fmt::Display for InvalidDnsNameError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("invalid dns name")
    }
}

impl StdError for InvalidDnsNameError {}
#[cfg(feature = "dangerous_configuration")]
pub(super) mod danger {
    use super::verify::ServerCertVerifier;
    use super::ClientConfig;
    use std::sync::Arc;

    /// Mutable access to client settings that are only exposed under the
    /// `dangerous_configuration` feature. Obtained from
    /// [`ClientConfig::dangerous`].
    pub struct DangerousClientConfig<'a> {
        /// The configuration being modified.
        pub cfg: &'a mut ClientConfig,
    }

    impl DangerousClientConfig<'_> {
        /// Replaces the verifier used to check server certificates.
        pub fn set_certificate_verifier(&mut self, verifier: Arc<dyn ServerCertVerifier>) {
            let cfg = &mut *self.cfg;
            cfg.verifier = verifier;
        }
    }
}
/// Lifecycle of client early data, as driven by the `EarlyData` methods.
#[derive(PartialEq, Debug)]
enum EarlyDataState {
    /// Initial state; no early data may be written.
    Disabled,
    /// Enabled with a byte budget, awaiting the server's decision.
    Ready,
    /// Accepted by the server; writes are still allowed.
    Accepted,
    /// Accepted, and `finished` has since been called; further writes are refused.
    AcceptedFinished,
    /// Rejected; further writes are refused.
    Rejected,
}
/// Early-data bookkeeping for a client connection: the current lifecycle
/// state plus the remaining byte budget.
pub(super) struct EarlyData {
    /// Current position in the early-data lifecycle.
    state: EarlyDataState,
    /// Bytes that may still be written. It is set by `enable` and reduced by
    /// `check_write`.
    left: usize,
}
impl EarlyData {
    /// Starts disabled, with an empty byte budget.
    fn new() -> Self {
        Self {
            state: EarlyDataState::Disabled,
            left: 0,
        }
    }

    /// True while writes are permitted (`Ready` or `Accepted`).
    pub(super) fn is_enabled(&self) -> bool {
        matches!(self.state, EarlyDataState::Ready | EarlyDataState::Accepted)
    }

    /// True once the server has accepted early data, whether or not `finished`
    /// has been called since.
    fn is_accepted(&self) -> bool {
        matches!(
            self.state,
            EarlyDataState::Accepted | EarlyDataState::AcceptedFinished
        )
    }

    /// Enables early data with a budget of `max_data` bytes.
    ///
    /// # Panics
    /// Panics unless the current state is `Disabled`.
    pub(super) fn enable(&mut self, max_data: usize) {
        assert_eq!(self.state, EarlyDataState::Disabled);
        self.left = max_data;
        self.state = EarlyDataState::Ready;
    }

    /// Records a rejection, whatever the current state.
    pub(super) fn rejected(&mut self) {
        trace!("EarlyData rejected");
        self.state = EarlyDataState::Rejected;
    }

    /// Records acceptance.
    ///
    /// # Panics
    /// Panics unless the current state is `Ready`.
    pub(super) fn accepted(&mut self) {
        trace!("EarlyData accepted");
        assert_eq!(self.state, EarlyDataState::Ready);
        self.state = EarlyDataState::Accepted;
    }

    /// Moves from `Accepted` to `AcceptedFinished`.
    ///
    /// # Panics
    /// Panics with "bad EarlyData state" from any other state.
    pub(super) fn finished(&mut self) {
        trace!("EarlyData finished");
        if self.state != EarlyDataState::Accepted {
            panic!("bad EarlyData state");
        }
        self.state = EarlyDataState::AcceptedFinished;
    }

    /// Determines how many of `sz` bytes may be written now and charges them
    /// against the budget.
    ///
    /// Returns `InvalidInput` once early data was rejected or accepted-and-finished.
    /// Reaching this while `Disabled` is treated as impossible (`unreachable!`).
    fn check_write(&mut self, sz: usize) -> io::Result<usize> {
        match self.state {
            EarlyDataState::Disabled => unreachable!(),
            EarlyDataState::Rejected | EarlyDataState::AcceptedFinished => {
                Err(io::Error::from(io::ErrorKind::InvalidInput))
            }
            EarlyDataState::Ready | EarlyDataState::Accepted => {
                // Clamp the request to what remains, then store the remainder.
                let budget = mem::take(&mut self.left);
                let granted = budget.min(sz);
                self.left = budget - granted;
                Ok(granted)
            }
        }
    }

    /// Remaining early-data byte budget.
    fn bytes_left(&self) -> usize {
        self.left
    }
}
/// Writer for early data on a [`ClientConnection`].
///
/// Obtained from [`ClientConnection::early_data`] while early data is enabled.
pub struct WriteEarlyData<'a> {
    sess: &'a mut ClientConnection,
}

impl<'a> WriteEarlyData<'a> {
    fn new(sess: &'a mut ClientConnection) -> Self {
        Self { sess }
    }

    /// Number of early-data bytes that may still be written.
    pub fn bytes_left(&self) -> usize {
        let tracker = &self.sess.inner.data.early_data;
        tracker.bytes_left()
    }
}
/// `io::Write` adapter.
///
/// Each write is clamped to the remaining early-data budget before being
/// handed to the connection.
impl io::Write for WriteEarlyData<'_> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let conn = &mut *self.sess;
        conn.write_early_data(buf)
    }

    /// Nothing is buffered here, so flushing always succeeds.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Client-side connection. It wraps the shared `ConnectionCommon` machinery
/// with client-specific data.
pub struct ClientConnection {
    inner: ConnectionCommon<ClientConnectionData>,
}

/// Prints only the type name; no connection state is exposed.
impl fmt::Debug for ClientConnection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ClientConnection");
        builder.finish()
    }
}
impl ClientConnection {
    /// Creates a connection to `name` over TCP using `config` and starts the
    /// handshake.
    ///
    /// # Errors
    /// Propagates failures from building the common state (which is given
    /// `config.max_fragment_size`) and from starting the handshake.
    pub fn new(config: Arc<ClientConfig>, name: ServerName) -> Result<Self, Error> {
        let no_extra_extensions = Vec::new();
        Self::new_inner(config, name, no_extra_extensions, Protocol::Tcp)
    }

    /// Shared constructor for TCP and QUIC.
    ///
    /// Builds the common and client-side state and lets the handshake starter
    /// mutate both. It then wraps the resulting state, data and common state in a
    /// `ConnectionCommon`.
    fn new_inner(
        config: Arc<ClientConfig>,
        name: ServerName,
        extra_exts: Vec<ClientExtension>,
        proto: Protocol,
    ) -> Result<Self, Error> {
        let mut common = CommonState::new(config.max_fragment_size, Side::Client)?;
        common.protocol = proto;
        let mut client_data = ClientConnectionData::new();

        // The handshake context only needs to borrow both halves while starting.
        let state = {
            let mut cx = hs::ClientContext {
                common: &mut common,
                data: &mut client_data,
            };
            hs::start_handshake(name, extra_exts, config, &mut cx)?
        };

        Ok(Self {
            inner: ConnectionCommon::new(state, client_data, common),
        })
    }

    /// Returns a writer for early data while early data is enabled (`Ready` or
    /// `Accepted`); otherwise returns `None`.
    pub fn early_data(&mut self) -> Option<WriteEarlyData> {
        if !self.inner.data.early_data.is_enabled() {
            return None;
        }
        Some(WriteEarlyData::new(self))
    }

    /// True once the server has accepted early data.
    pub fn is_early_data_accepted(&self) -> bool {
        self.inner.data.early_data.is_accepted()
    }

    /// Charges as much of `data` as the budget allows and sends that prefix as
    /// early plaintext.
    fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
        let allowed = self.inner.data.early_data.check_write(data.len())?;
        Ok(self
            .inner
            .common_state
            .send_early_plaintext(&data[..allowed]))
    }
}
/// Exposes the shared `ConnectionCommon` API directly on `ClientConnection`.
impl Deref for ClientConnection {
    type Target = ConnectionCommon<ClientConnectionData>;

    fn deref(&self) -> &ConnectionCommon<ClientConnectionData> {
        &self.inner
    }
}

/// Mutable counterpart of the `Deref` impl.
impl DerefMut for ClientConnection {
    fn deref_mut(&mut self) -> &mut ConnectionCommon<ClientConnectionData> {
        &mut self.inner
    }
}
#[doc(hidden)]
impl<'a> TryFrom<&'a mut crate::Connection> for &'a mut ClientConnection {
type Error = ();
fn try_from(value: &'a mut crate::Connection) -> Result<Self, Self::Error> {
use crate::Connection::*;
match value {
Client(conn) => Ok(conn),
Server(_) => Err(()),
}
}
}
impl From<ClientConnection> for crate::Connection {
fn from(conn: ClientConnection) -> Self {
Self::Client(conn)
}
}
/// Client-specific state carried inside the connection's `ConnectionCommon`.
pub struct ClientConnectionData {
    /// Early-data lifecycle and remaining byte budget.
    pub(super) early_data: EarlyData,
    /// Cipher suite recorded for resumption. It starts as `None` and is read
    /// when deriving QUIC zero-RTT keys.
    pub(super) resumption_ciphersuite: Option<SupportedCipherSuite>,
}

impl ClientConnectionData {
    fn new() -> Self {
        let early_data = EarlyData::new();
        Self {
            resumption_ciphersuite: None,
            early_data,
        }
    }
}

impl crate::conn::SideData for ClientConnectionData {}
#[cfg(feature = "quic")]
impl quic::QuicExt for ClientConnection {
    fn quic_transport_parameters(&self) -> Option<&[u8]> {
        let state = &self.inner.common_state.quic;
        state.params.as_ref().map(|bytes| bytes.as_ref())
    }

    /// Returns `None` unless there is a resumption suite with a TLS 1.3 form
    /// AND an early secret is present.
    fn zero_rtt_keys(&self) -> Option<quic::DirectionalKeys> {
        let suite = self
            .inner
            .data
            .resumption_ciphersuite
            .and_then(|resumed| resumed.tls13())?;
        let secret = self.inner.common_state.quic.early_secret.as_ref()?;
        Some(quic::DirectionalKeys::new(suite, secret))
    }

    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
        self.inner.read_quic_hs(plaintext)
    }

    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<quic::KeyChange> {
        let common = &mut self.inner.common_state;
        quic::write_hs(common, buf)
    }

    fn alert(&self) -> Option<AlertDescription> {
        let state = &self.inner.common_state.quic;
        state.alert
    }
}
/// QUIC-specific constructor for client connections.
#[cfg(feature = "quic")]
pub trait ClientQuicExt {
    /// Creates a QUIC client connection to `name`.
    ///
    /// `params` is sent as the transport-parameters extension. The draft or
    /// final extension is chosen from `quic_version`.
    ///
    /// # Errors
    /// Returns `Error::General` when `config` does not support TLS 1.3.
    /// Otherwise, any error from connection setup is propagated.
    fn new_quic(
        config: Arc<ClientConfig>,
        quic_version: quic::Version,
        name: ServerName,
        params: Vec<u8>,
    ) -> Result<ClientConnection, Error> {
        if config.supports_version(ProtocolVersion::TLSv1_3) {
            let transport_ext = match quic_version {
                quic::Version::V1Draft => ClientExtension::TransportParametersDraft(params),
                quic::Version::V1 => ClientExtension::TransportParameters(params),
            };
            ClientConnection::new_inner(config, name, vec![transport_ext], Protocol::Quic)
        } else {
            Err(Error::General(
                "TLS 1.3 support is required for QUIC".into(),
            ))
        }
    }
}

#[cfg(feature = "quic")]
impl ClientQuicExt for ClientConnection {}