use crate::builder::{ConfigBuilder, WantsCipherSuites};
use crate::conn::{Connection, ConnectionCommon, IoState, PlaintextSink, Protocol, Reader, Writer};
use crate::error::Error;
use crate::key;
use crate::keylog::KeyLog;
use crate::kx::SupportedKxGroup;
#[cfg(feature = "logging")]
use crate::log::trace;
#[cfg(feature = "quic")]
use crate::msgs::enums::AlertDescription;
use crate::msgs::enums::CipherSuite;
use crate::msgs::enums::ProtocolVersion;
use crate::msgs::enums::SignatureScheme;
use crate::msgs::handshake::{CertificatePayload, ClientExtension};
use crate::sign;
use crate::suites::SupportedCipherSuite;
use crate::verify;
use crate::versions;
#[cfg(feature = "quic")]
use crate::quic;
use std::convert::TryFrom;
use std::fmt;
use std::io::{self, IoSlice};
use std::marker::PhantomData;
use std::mem;
use std::sync::Arc;
#[macro_use]
mod hs;
pub(super) mod builder;
mod common;
pub(super) mod handy;
mod tls12;
mod tls13;
/// Storage for client session data, keyed by opaque byte strings.
///
/// Implementations must be `Send + Sync`; a `ClientConfig` holds one as an
/// `Arc<dyn StoresClientSessions>` in its `session_storage` field.
pub trait StoresClientSessions: Send + Sync {
    /// Stores `value` under `key`.
    ///
    /// NOTE(review): the meaning of the returned `bool` is not established in
    /// this module — presumably whether the value was stored; confirm.
    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool;
    /// Looks up `key`; presumably returns what an earlier `put` stored under
    /// it, or `None` if nothing is stored.
    fn get(&self, key: &[u8]) -> Option<Vec<u8>>;
}
/// Picks the client certificate/key to present, if any.
///
/// A `ClientConfig` holds one as `client_auth_cert_resolver`.
pub trait ResolvesClientCert: Send + Sync {
    /// Chooses a certified key given the acceptable issuer names and the
    /// signature schemes offered, or returns `None` to send no certificate.
    ///
    /// NOTE(review): `acceptable_issuers` is presumably a list of encoded
    /// distinguished names from the server's request — confirm with caller.
    fn resolve(
        &self,
        acceptable_issuers: &[&[u8]],
        sigschemes: &[SignatureScheme],
    ) -> Option<Arc<sign::CertifiedKey>>;
    /// Presumably reports whether this resolver has any certificate to offer;
    /// TODO confirm semantics with implementors.
    fn has_certs(&self) -> bool;
}
/// Configuration shared by client connections.
///
/// Construct one with [`ClientConfig::builder`]. The private fields
/// (`cipher_suites`, `kx_groups`, `versions`, `verifier`) are not directly
/// settable; `verifier` can be replaced afterwards only through
/// `ClientConfig::dangerous` (feature `dangerous_configuration`).
#[derive(Clone)]
pub struct ClientConfig {
    /// Cipher suites this client may use; consulted by `supports_version`
    /// and `find_cipher_suite`.
    cipher_suites: Vec<SupportedCipherSuite>,
    /// Key-exchange groups this client may use.
    kx_groups: Vec<&'static SupportedKxGroup>,
    /// Application protocol identifiers, presumably offered via ALPN
    /// (NOTE(review): confirm how the handshake uses this list).
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Storage used for client session data (see `StoresClientSessions`).
    pub session_storage: Arc<dyn StoresClientSessions>,
    /// Passed to `ConnectionCommon::new` when a connection is created;
    /// presumably a cap on outgoing fragment size (`None` = default). TODO confirm.
    pub max_fragment_size: Option<usize>,
    /// Resolver used to pick a client certificate (see `ResolvesClientCert`).
    pub client_auth_cert_resolver: Arc<dyn ResolvesClientCert>,
    /// Presumably enables session-ticket resumption; TODO confirm.
    pub enable_tickets: bool,
    /// Enabled protocol versions; consulted by `supports_version`.
    versions: versions::EnabledVersions,
    /// Presumably controls whether the server name is sent via SNI; TODO confirm.
    pub enable_sni: bool,
    /// Server certificate verifier; replaceable via `dangerous()`.
    verifier: Arc<dyn verify::ServerCertVerifier>,
    /// Key-log sink (see `KeyLog`).
    pub key_log: Arc<dyn KeyLog>,
    /// Presumably allows the client to attempt early data; TODO confirm.
    pub enable_early_data: bool,
}
impl ClientConfig {
    /// Starts building a `ClientConfig`.
    ///
    /// The returned builder is in its `WantsCipherSuites` state, i.e. cipher
    /// suites are the first thing to choose.
    pub fn builder() -> ConfigBuilder<Self, WantsCipherSuites> {
        ConfigBuilder {
            side: PhantomData::default(),
            state: WantsCipherSuites,
        }
    }

    /// Returns `true` when protocol version `v` is enabled *and* at least
    /// one configured cipher suite belongs to that version.
    #[doc(hidden)]
    pub fn supports_version(&self, v: ProtocolVersion) -> bool {
        if !self.versions.contains(v) {
            return false;
        }
        self.cipher_suites
            .iter()
            .any(|suite| suite.version().version == v)
    }

    /// Returns a `DangerousClientConfig` wrapping `self`, through which the
    /// server certificate verifier can be replaced.
    #[cfg(feature = "dangerous_configuration")]
    pub fn dangerous(&mut self) -> danger::DangerousClientConfig {
        danger::DangerousClientConfig { cfg: self }
    }

    /// Finds the configured cipher suite whose identifier equals `suite`.
    fn find_cipher_suite(&self, suite: CipherSuite) -> Option<SupportedCipherSuite> {
        self.cipher_suites
            .iter()
            .find(|candidate| candidate.suite() == suite)
            .copied()
    }
}
/// The name of the server a client connects to.
///
/// Currently only DNS names are representable; the enum is
/// `#[non_exhaustive]` so other kinds of name can be added later.
/// Build one from a string with `ServerName::try_from(&str)`.
#[non_exhaustive]
#[derive(Debug, PartialEq, Clone)]
pub enum ServerName {
    /// A syntactically valid DNS name.
    DnsName(verify::DnsName),
}
impl ServerName {
    /// Returns the DNS name to use for SNI.
    ///
    /// Every current variant is a DNS name, so this is always `Some`.
    pub(crate) fn for_sni(&self) -> Option<webpki::DnsNameRef> {
        let Self::DnsName(dns_name) = self;
        Some(dns_name.0.as_ref())
    }

    /// Serialises the name as: a one-byte type code (`0x01` for DNS names),
    /// a one-byte length, then the name bytes.
    ///
    /// NOTE(review): the length is truncated to `u8`; this assumes validated
    /// DNS names never exceed 255 bytes — confirm against webpki's limit.
    pub(crate) fn encode(&self) -> Vec<u8> {
        enum UniqueTypeCode {
            DnsName = 0x01,
        }

        let raw = match self {
            Self::DnsName(dns_name) => dns_name.0.as_ref(),
        };
        let len = raw.as_ref().len();

        let mut out = Vec::with_capacity(2 + len);
        out.push(UniqueTypeCode::DnsName as u8);
        out.push(len as u8);
        out.extend_from_slice(raw.as_ref());
        out
    }
}
/// Parses an ASCII string as a DNS server name.
///
/// # Errors
/// Returns [`InvalidDnsNameError`] when webpki rejects `s` as a DNS name.
impl TryFrom<&str> for ServerName {
    type Error = InvalidDnsNameError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let dns = webpki::DnsNameRef::try_from_ascii_str(s).map_err(|_| InvalidDnsNameError)?;
        Ok(Self::DnsName(verify::DnsName(dns.into())))
    }
}
/// Error returned by `ServerName::try_from(&str)` when the input is not a
/// syntactically valid DNS name.
///
/// Implements `Display` and `std::error::Error` so it can be printed and
/// propagated with `?` into generic error types such as
/// `Box<dyn std::error::Error>`.
#[derive(Debug)]
pub struct InvalidDnsNameError;

impl fmt::Display for InvalidDnsNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid dns name")
    }
}

// Fully qualified: `Error` in this module refers to `crate::error::Error`.
impl std::error::Error for InvalidDnsNameError {}
/// Configuration options that bypass the normal safe construction path,
/// available only with the `dangerous_configuration` feature.
#[cfg(feature = "dangerous_configuration")]
pub(super) mod danger {
    use std::sync::Arc;
    use super::verify::ServerCertVerifier;
    use super::ClientConfig;
    /// Mutable access to the normally-private parts of a `ClientConfig`,
    /// obtained from `ClientConfig::dangerous`.
    pub struct DangerousClientConfig<'a> {
        /// The configuration being modified.
        pub cfg: &'a mut ClientConfig,
    }
    impl<'a> DangerousClientConfig<'a> {
        /// Replaces the server certificate verifier in the wrapped config.
        pub fn set_certificate_verifier(&mut self, verifier: Arc<dyn ServerCertVerifier>) {
            self.cfg.verifier = verifier;
        }
    }
}
/// Lifecycle of client early data, as driven by the `EarlyData` methods:
/// `Disabled` -> `Ready` (`enable`) -> `Accepted` (`accepted`) ->
/// `AcceptedFinished` (`finished`); `rejected` moves to `Rejected` from any state.
#[derive(Debug, PartialEq)]
enum EarlyDataState {
    /// Initial state; early data is not in use.
    Disabled,
    /// Enabled with an allowance; writes are permitted.
    Ready,
    /// Accepted by the server; writes are still permitted.
    Accepted,
    /// Accepted and then finished; further writes fail.
    AcceptedFinished,
    /// Rejected; further writes fail.
    Rejected,
}
/// Early-data bookkeeping for one client connection.
pub(super) struct EarlyData {
    /// Where the connection is in the early-data lifecycle.
    state: EarlyDataState,
    /// Remaining number of bytes that may still be written as early data.
    left: usize,
}
impl EarlyData {
    /// Starts `Disabled` with no allowance.
    fn new() -> Self {
        Self {
            state: EarlyDataState::Disabled,
            left: 0,
        }
    }

    /// True while early data may still be written (`Ready` or `Accepted`).
    fn is_enabled(&self) -> bool {
        matches!(self.state, EarlyDataState::Accepted | EarlyDataState::Ready)
    }

    /// True once the server accepted early data, including after `finished`.
    fn is_accepted(&self) -> bool {
        matches!(
            self.state,
            EarlyDataState::AcceptedFinished | EarlyDataState::Accepted
        )
    }

    /// Moves `Disabled` -> `Ready` with an allowance of `max_data` bytes.
    ///
    /// # Panics
    /// If the current state is not `Disabled`.
    fn enable(&mut self, max_data: usize) {
        assert_eq!(self.state, EarlyDataState::Disabled);
        self.left = max_data;
        self.state = EarlyDataState::Ready;
    }

    /// Marks early data as rejected, whatever the current state.
    fn rejected(&mut self) {
        trace!("EarlyData rejected");
        self.state = EarlyDataState::Rejected;
    }

    /// Moves `Ready` -> `Accepted`.
    ///
    /// # Panics
    /// If the current state is not `Ready`.
    fn accepted(&mut self) {
        trace!("EarlyData accepted");
        assert_eq!(self.state, EarlyDataState::Ready);
        self.state = EarlyDataState::Accepted;
    }

    /// Moves `Accepted` -> `AcceptedFinished`.
    ///
    /// # Panics
    /// If the current state is not `Accepted`.
    fn finished(&mut self) {
        trace!("EarlyData finished");
        if self.state != EarlyDataState::Accepted {
            panic!("bad EarlyData state");
        }
        self.state = EarlyDataState::AcceptedFinished;
    }

    /// Reserves up to `sz` bytes of the remaining allowance and returns how
    /// many were granted (possibly fewer than `sz`, possibly 0).
    ///
    /// # Errors
    /// `InvalidInput` once early data was rejected or finished.
    ///
    /// # Panics
    /// If called while `Disabled` (callers only write when enabled).
    fn check_write(&mut self, sz: usize) -> io::Result<usize> {
        match self.state {
            EarlyDataState::Rejected | EarlyDataState::AcceptedFinished => {
                Err(io::Error::from(io::ErrorKind::InvalidInput))
            }
            EarlyDataState::Ready | EarlyDataState::Accepted => {
                let take = if sz <= self.left {
                    self.left -= sz;
                    sz
                } else {
                    // Not enough allowance: grant whatever is left and drain it.
                    mem::replace(&mut self.left, 0)
                };
                Ok(take)
            }
            EarlyDataState::Disabled => unreachable!(),
        }
    }

    /// Remaining early-data allowance in bytes.
    fn bytes_left(&self) -> usize {
        self.left
    }
}
/// `io::Write` adapter for sending early data, returned by
/// `ClientConnection::early_data` while early data is enabled.
pub struct WriteEarlyData<'a> {
    /// The connection the early data is written to.
    sess: &'a mut ClientConnection,
}
impl<'a> WriteEarlyData<'a> {
fn new(sess: &'a mut ClientConnection) -> WriteEarlyData<'a> {
WriteEarlyData { sess }
}
pub fn bytes_left(&self) -> usize {
self.sess.data.early_data.bytes_left()
}
}
/// Writes go to the connection's early-data path; they may be short once the
/// allowance runs out and fail once early data is rejected or finished.
impl<'a> io::Write for WriteEarlyData<'a> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let sess = &mut *self.sess;
        sess.write_early_data(buf)
    }

    /// Nothing is buffered at this layer, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// A client-side TLS connection.
pub struct ClientConnection {
    /// Connection machinery shared with the server side.
    common: ConnectionCommon,
    /// Current handshake state; set to `Some` by the constructor.
    state: Option<hs::NextState>,
    /// Client-only data (server certificates, early data, resumption suite).
    data: ClientConnectionData,
}
/// Opaque debug output: prints `ClientConnection` without any field contents.
impl fmt::Debug for ClientConnection {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ClientConnection");
        out.finish()
    }
}
impl ClientConnection {
    /// Creates a TCP client connection to `name` using `config` and starts
    /// the handshake.
    ///
    /// # Errors
    /// Propagates errors from setting up the shared connection state or from
    /// starting the handshake.
    pub fn new(config: Arc<ClientConfig>, name: ServerName) -> Result<Self, Error> {
        let no_extra_exts = Vec::new();
        Self::new_inner(config, name, no_extra_exts, Protocol::Tcp)
    }

    /// Shared constructor for TCP and QUIC connections: builds the common and
    /// client-specific state, then starts the handshake with `extra_exts`.
    fn new_inner(
        config: Arc<ClientConfig>,
        name: ServerName,
        extra_exts: Vec<ClientExtension>,
        proto: Protocol,
    ) -> Result<Self, Error> {
        let mut common = ConnectionCommon::new(config.max_fragment_size, true)?;
        common.protocol = proto;
        let mut data = ClientConnectionData::new();

        // `start_handshake` needs mutable access to both halves through a
        // `ClientContext`; scope that borrow so both can be moved afterwards.
        let first_state = {
            let mut cx = hs::ClientContext {
                common: &mut common,
                data: &mut data,
            };
            hs::start_handshake(name, extra_exts, config, &mut cx)?
        };

        Ok(Self {
            common,
            state: Some(first_state),
            data,
        })
    }

    /// Returns a writer for early data, or `None` unless early data is
    /// currently enabled (`Ready` or `Accepted`).
    pub fn early_data(&mut self) -> Option<WriteEarlyData> {
        if !self.data.early_data.is_enabled() {
            return None;
        }
        Some(WriteEarlyData::new(self))
    }

    /// Whether the server accepted early data (also `true` after it finished).
    pub fn is_early_data_accepted(&self) -> bool {
        self.data.early_data.is_accepted()
    }

    /// Sends as much of `data` as the remaining early-data allowance permits
    /// and returns the number of bytes taken.
    fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
        let allowed = self.data.early_data.check_write(data.len())?;
        Ok(self.common.send_early_plaintext(&data[..allowed]))
    }

    /// Gives the handshake state a chance to queue a key update, then passes
    /// `buf` to the common layer, returning how many bytes it accepted.
    fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
        // `state` and `common` are disjoint fields, so both can be borrowed
        // mutably at once.
        if let Some(st) = self.state.as_mut() {
            st.perhaps_write_key_update(&mut self.common);
        }
        self.common.send_some_plaintext(buf)
    }
}
impl Connection for ClientConnection {
fn read_tls(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
self.common.read_tls(rd)
}
fn write_tls(&mut self, wr: &mut dyn io::Write) -> io::Result<usize> {
self.common.write_tls(wr)
}
fn process_new_packets(&mut self) -> Result<IoState, Error> {
self.common
.process_new_packets(&mut self.state, &mut self.data)
}
fn wants_read(&self) -> bool {
self.common.wants_read()
}
fn wants_write(&self) -> bool {
!self.common.sendable_tls.is_empty()
}
fn is_handshaking(&self) -> bool {
!self.common.traffic
}
fn set_buffer_limit(&mut self, len: Option<usize>) {
self.common.set_buffer_limit(len)
}
fn send_close_notify(&mut self) {
self.common.send_close_notify()
}
fn peer_certificates(&self) -> Option<&[key::Certificate]> {
if self.data.server_cert_chain.is_empty() {
return None;
}
Some(&self.data.server_cert_chain)
}
fn alpn_protocol(&self) -> Option<&[u8]> {
self.common.get_alpn_protocol()
}
fn protocol_version(&self) -> Option<ProtocolVersion> {
self.common.negotiated_version
}
fn export_keying_material(
&self,
output: &mut [u8],
label: &[u8],
context: Option<&[u8]>,
) -> Result<(), Error> {
self.state
.as_ref()
.ok_or(Error::HandshakeNotComplete)
.and_then(|st| st.export_keying_material(output, label, context))
}
fn negotiated_cipher_suite(&self) -> Option<SupportedCipherSuite> {
self.common
.get_suite()
.or(self.data.resumption_ciphersuite)
}
fn writer(&mut self) -> Writer {
Writer::new(self)
}
fn reader(&mut self) -> Reader {
self.common.reader()
}
}
/// Plaintext entry point used by `Writer`: application data is handed to
/// `send_some_plaintext`, which reports how many bytes it accepted.
impl PlaintextSink for ClientConnection {
    /// Accepts as much of `buf` as the connection currently takes and returns
    /// that count.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Ok(self.send_some_plaintext(buf))
    }

    /// Writes `bufs` in order, stopping after the first slice that is only
    /// partly accepted.
    ///
    /// `io::Write::write_vectored` requires the returned count to describe a
    /// contiguous prefix of the concatenated slices. `send_some_plaintext`
    /// returns a count that may be smaller than the slice length. If later
    /// slices were still offered after such a short write, any bytes they
    /// contributed would follow a gap. The caller's retry of the unsent
    /// remainder would then reorder or duplicate stream data.
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        let mut total = 0;
        for buf in bufs {
            let written = self.send_some_plaintext(buf);
            total += written;
            if written < buf.len() {
                break;
            }
        }
        Ok(total)
    }

    /// Nothing is buffered at this layer, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Client-only state, passed alongside `ConnectionCommon` to the handshake
/// code (`hs::start_handshake`, `process_new_packets`).
struct ClientConnectionData {
    /// Certificates presented by the server; starts empty, and
    /// `peer_certificates` reports an empty chain as `None`.
    server_cert_chain: CertificatePayload,
    /// Early-data bookkeeping.
    early_data: EarlyData,
    /// Cipher suite of the session being resumed, if any; used as a fallback
    /// by `negotiated_cipher_suite` and by QUIC `zero_rtt_keys`.
    resumption_ciphersuite: Option<SupportedCipherSuite>,
}
impl ClientConnectionData {
    /// Empty certificate chain, early data disabled, no resumption suite.
    fn new() -> Self {
        let early_data = EarlyData::new();
        Self {
            early_data,
            server_cert_chain: Vec::new(),
            resumption_ciphersuite: None,
        }
    }
}
/// QUIC integration for client connections; mostly delegates to the `quic`
/// module helpers operating on `ConnectionCommon`.
#[cfg(feature = "quic")]
impl quic::QuicExt for ClientConnection {
    /// The peer's QUIC transport parameters, once known.
    fn quic_transport_parameters(&self) -> Option<&[u8]> {
        let params = self.common.quic.params.as_ref();
        params.map(|p| p.as_ref())
    }

    /// 0-RTT keys, available only when resuming with a TLS 1.3 suite and an
    /// early secret has been recorded.
    fn zero_rtt_keys(&self) -> Option<quic::DirectionalKeys> {
        let suite = self
            .data
            .resumption_ciphersuite
            .and_then(|suite| suite.tls13())?;
        let early_secret = self.common.quic.early_secret.as_ref()?;
        Some(quic::DirectionalKeys::new(suite, early_secret))
    }

    /// Feeds handshake bytes received over QUIC, then runs any complete
    /// handshake messages through the state machine.
    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
        quic::read_hs(&mut self.common, plaintext)?;
        let Self {
            common,
            state,
            data,
        } = self;
        common.process_new_handshake_messages(state, data)
    }

    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<quic::Keys> {
        quic::write_hs(&mut self.common, buf)
    }

    fn alert(&self) -> Option<AlertDescription> {
        self.common.quic.alert
    }

    fn next_1rtt_keys(&mut self) -> Option<quic::PacketKeySet> {
        quic::next_1rtt_keys(&mut self.common)
    }
}
/// QUIC-specific constructor for client connections.
#[cfg(feature = "quic")]
pub trait ClientQuicExt {
    /// Creates a QUIC client connection to `name`, sending `params` as the
    /// transport-parameters extension that matches `quic_version`.
    ///
    /// # Errors
    /// `Error::General` if `config` does not support TLS 1.3; otherwise any
    /// error from starting the handshake.
    fn new_quic(
        config: Arc<ClientConfig>,
        quic_version: quic::Version,
        name: ServerName,
        params: Vec<u8>,
    ) -> Result<ClientConnection, Error> {
        if config.supports_version(ProtocolVersion::TLSv1_3) {
            let ext = match quic_version {
                quic::Version::V1 => ClientExtension::TransportParameters(params),
                quic::Version::V1Draft => ClientExtension::TransportParametersDraft(params),
            };
            ClientConnection::new_inner(config, name, vec![ext], Protocol::Quic)
        } else {
            Err(Error::General(
                "TLS 1.3 support is required for QUIC".into(),
            ))
        }
    }
}
/// Makes `ClientConnection::new_quic` available via the trait's default method.
#[cfg(feature = "quic")]
impl ClientQuicExt for ClientConnection {}