use crate::builder::{ConfigBuilder, WantsCipherSuites};
use crate::conn::{CommonState, ConnectionCommon, Side, State};
use crate::error::Error;
use crate::kx::SupportedKxGroup;
#[cfg(feature = "logging")]
use crate::log::trace;
use crate::msgs::base::{Payload, PayloadU8};
#[cfg(feature = "quic")]
use crate::msgs::enums::AlertDescription;
use crate::msgs::enums::ProtocolVersion;
use crate::msgs::enums::SignatureScheme;
use crate::msgs::handshake::{ClientHelloPayload, ServerExtension};
use crate::msgs::message::Message;
use crate::suites::SupportedCipherSuite;
use crate::vecbuf::ChunkVecBuffer;
use crate::verify;
use crate::KeyLog;
#[cfg(feature = "quic")]
use crate::{conn::Protocol, quic};
use crate::{sign, CipherSuite};
use super::hs;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::{fmt, io};
/// Server-side storage of session values, keyed by opaque byte strings.
///
/// The `Send + Sync` supertraits allow implementations to be shared through
/// the `Arc` held in `ServerConfig::session_storage`.
pub trait StoresServerSessions: Send + Sync {
    /// Stores `value` under `key`.
    ///
    /// NOTE(review): the `bool` result presumably reports whether the value
    /// was actually stored — confirm against the callers in the handshake code.
    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool;
    /// Returns the value stored under `key`, if any.
    fn get(&self, key: &[u8]) -> Option<Vec<u8>>;
    /// Returns the value stored under `key`, if any; by its name this
    /// presumably also removes it from the store (TODO confirm).
    fn take(&self, key: &[u8]) -> Option<Vec<u8>>;
    /// Reports whether this store can cache sessions at all.
    fn can_cache(&self) -> bool;
}
/// Produces and opens session tickets; shared via `ServerConfig::ticketer`.
pub trait ProducesTickets: Send + Sync {
    /// Reports whether ticket production is enabled.
    fn enabled(&self) -> bool;
    /// Ticket lifetime. NOTE(review): the unit (presumably seconds) is not
    /// established here — confirm with implementations.
    fn lifetime(&self) -> u32;
    /// Encrypts `plain` into a ticket, or returns `None` on failure.
    fn encrypt(&self, plain: &[u8]) -> Option<Vec<u8>>;
    /// Decrypts a ticket previously produced by `encrypt`, or returns `None`
    /// if it cannot be opened.
    fn decrypt(&self, cipher: &[u8]) -> Option<Vec<u8>>;
}
/// Chooses the certificate and key to present for a given client hello;
/// shared via `ServerConfig::cert_resolver`.
pub trait ResolvesServerCert: Send + Sync {
    /// Returns the certified key to use for `client_hello`.
    ///
    /// NOTE(review): `None` presumably means no suitable certificate is
    /// available for this hello — confirm how the handshake code reacts.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>>;
}
/// A borrowed view of selected parts of a client's hello message.
///
/// Passed to [`ResolvesServerCert::resolve`] and returned by
/// `Accepted::client_hello`.
pub struct ClientHello<'a> {
    // SNI name recorded for the connection, if any.
    server_name: &'a Option<webpki::DnsName>,
    // Signature schemes to consider (for `Accepted`, the list returned by
    // `hs::process_client_hello`).
    signature_schemes: &'a [SignatureScheme],
    // ALPN protocol list from the hello, or `None` when absent.
    alpn: Option<&'a Vec<PayloadU8>>,
    // Cipher suites as listed in the hello.
    cipher_suites: &'a [CipherSuite],
}
impl<'a> ClientHello<'a> {
    /// Builds the view from its borrowed parts and traces each of them.
    pub(super) fn new(
        server_name: &'a Option<webpki::DnsName>,
        signature_schemes: &'a [SignatureScheme],
        alpn: Option<&'a Vec<PayloadU8>>,
        cipher_suites: &'a [CipherSuite],
    ) -> Self {
        let hello = ClientHello {
            server_name,
            signature_schemes,
            alpn,
            cipher_suites,
        };
        trace!("sni {:?}", hello.server_name);
        trace!("sig schemes {:?}", hello.signature_schemes);
        trace!("alpn protocols {:?}", hello.alpn);
        trace!("cipher suites {:?}", hello.cipher_suites);
        hello
    }

    /// The SNI name as a string slice, or `None` if no name is recorded.
    pub fn server_name(&self) -> Option<&str> {
        let name = self.server_name.as_ref()?;
        Some(<webpki::DnsName as AsRef<str>>::as_ref(name))
    }

    /// The signature schemes carried by this view.
    pub fn signature_schemes(&self) -> &[SignatureScheme] {
        self.signature_schemes
    }

    /// Iterates over the raw ALPN protocol names, or `None` if the hello
    /// carried no ALPN list.
    pub fn alpn(&self) -> Option<impl Iterator<Item = &'a [u8]>> {
        let protocols = self.alpn?;
        Some(protocols.iter().map(|entry| entry.0.as_slice()))
    }

    /// The cipher suites listed in the hello.
    pub fn cipher_suites(&self) -> &[CipherSuite] {
        self.cipher_suites
    }
}
/// Settings shared by server connections, built via [`ServerConfig::builder`]
/// and handed to connections as `Arc<ServerConfig>`.
#[derive(Clone)]
pub struct ServerConfig {
    /// Cipher suites available to this server; consulted by `supports_version`.
    pub(super) cipher_suites: Vec<SupportedCipherSuite>,
    /// Key-exchange groups available to this server.
    pub(super) kx_groups: Vec<&'static SupportedKxGroup>,
    /// Whether to disregard the client's preference order.
    /// NOTE(review): which negotiation this affects is decided in `hs` — confirm.
    pub ignore_client_order: bool,
    /// Maximum fragment size handed to `CommonState`; `None` presumably keeps
    /// the library default (confirm in `CommonState`).
    pub max_fragment_size: Option<usize>,
    /// Server-side session store.
    pub session_storage: Arc<dyn StoresServerSessions + Send + Sync>,
    /// Session ticket producer.
    pub ticketer: Arc<dyn ProducesTickets>,
    /// Chooses the certificate/key to present for each client hello.
    pub cert_resolver: Arc<dyn ResolvesServerCert>,
    /// ALPN protocol identifiers for this server, as raw bytes.
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Enabled protocol versions; consulted by `supports_version`.
    pub(super) versions: crate::versions::EnabledVersions,
    /// Verifier used for client certificates.
    pub(super) verifier: Arc<dyn verify::ClientCertVerifier>,
    /// Destination for key-log output.
    pub key_log: Arc<dyn KeyLog>,
    /// Early-data size limit (unit presumably bytes — confirm). QUIC
    /// connections require this to be `0` or `0xffff_ffff` (see `new_quic`).
    pub max_early_data_size: u32,
    /// Whether to send "half-RTT" data; exact semantics live in the
    /// handshake code — confirm in `hs`.
    pub send_half_rtt_data: bool,
}
impl ServerConfig {
    /// Starts a builder for a `ServerConfig`; the first step selects cipher suites.
    pub fn builder() -> ConfigBuilder<Self, WantsCipherSuites> {
        ConfigBuilder {
            side: PhantomData,
            state: WantsCipherSuites(()),
        }
    }

    /// True when `v` is enabled AND at least one configured cipher suite
    /// belongs to that protocol version.
    #[doc(hidden)]
    pub fn supports_version(&self, v: ProtocolVersion) -> bool {
        if !self.versions.contains(v) {
            return false;
        }
        self.cipher_suites
            .iter()
            .any(|suite| suite.version().version == v)
    }
}
/// An [`std::io::Read`] adapter over a connection's accepted early data,
/// obtained from `ServerConnection::early_data`.
pub struct ReadEarlyData<'a> {
    // Borrowed early-data state that reads are delegated to.
    early_data: &'a mut EarlyDataState,
}
impl<'a> ReadEarlyData<'a> {
fn new(early_data: &'a mut EarlyDataState) -> Self {
ReadEarlyData { early_data }
}
}
impl<'a> std::io::Read for ReadEarlyData<'a> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.early_data.read(buf)
}
#[cfg(read_buf)]
fn read_buf(&mut self, buf: &mut io::ReadBuf<'_>) -> io::Result<()> {
self.early_data.read_buf(buf)
}
}
/// A server-side TLS connection.
///
/// Most operations are reached through `Deref`/`DerefMut` to the wrapped
/// `ConnectionCommon<ServerConnectionData>`.
pub struct ServerConnection {
    inner: ConnectionCommon<ServerConnectionData>,
}
impl ServerConnection {
    /// Creates a server connection governed by `config`, with no extra
    /// server extensions.
    ///
    /// # Errors
    /// Propagates failures from setting up the common connection state.
    pub fn new(config: Arc<ServerConfig>) -> Result<Self, Error> {
        Self::from_config(config, Vec::new())
    }

    /// Shared constructor: starts in `ExpectClientHello`, which will also
    /// emit `extra_exts` (used by the QUIC constructor for transport params).
    fn from_config(
        config: Arc<ServerConfig>,
        extra_exts: Vec<ServerExtension>,
    ) -> Result<Self, Error> {
        let common = CommonState::new(config.max_fragment_size, Side::Server)?;
        let initial_state = Box::new(hs::ExpectClientHello::new(config, extra_exts));
        let inner = ConnectionCommon::new(
            initial_state,
            ServerConnectionData::default(),
            common,
        );
        Ok(Self { inner })
    }

    /// The SNI hostname recorded for this connection, if any.
    pub fn sni_hostname(&self) -> Option<&str> {
        self.inner.data.get_sni_str()
    }

    /// Resumption data recorded as received for this connection, if any.
    pub fn received_resumption_data(&self) -> Option<&[u8]> {
        self.inner.data.received_resumption_data.as_deref()
    }

    /// Sets the resumption data stored on this connection.
    ///
    /// # Panics
    /// If `data` is 2^15 bytes or longer.
    pub fn set_resumption_data(&mut self, data: &[u8]) {
        assert!(data.len() < 2usize.pow(15));
        self.inner.data.resumption_data = data.to_vec();
    }

    /// Marks early data as rejected.
    ///
    /// # Panics
    /// If the handshake is no longer in progress.
    pub fn reject_early_data(&mut self) {
        assert!(
            self.is_handshaking(),
            "cannot retroactively reject early data"
        );
        self.inner.data.early_data.reject();
    }

    /// A reader over early data, or `None` unless early data was accepted.
    pub fn early_data(&mut self) -> Option<ReadEarlyData> {
        let early_data = &mut self.inner.data.early_data;
        if !early_data.was_accepted() {
            return None;
        }
        Some(ReadEarlyData::new(early_data))
    }
}
impl fmt::Debug for ServerConnection {
    /// Prints only the type name; no connection state is included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ServerConnection");
        out.finish()
    }
}
impl Deref for ServerConnection {
    type Target = ConnectionCommon<ServerConnectionData>;

    /// Exposes the wrapped common connection state for shared access.
    fn deref(&self) -> &ConnectionCommon<ServerConnectionData> {
        &self.inner
    }
}
impl DerefMut for ServerConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl From<ServerConnection> for crate::Connection {
fn from(conn: ServerConnection) -> Self {
Self::Server(conn)
}
}
/// Reads a client's first handshake message before a [`ServerConfig`] is chosen.
pub struct Acceptor {
    // `Some` until `accept` takes it: that happens on success, and also when
    // processing the ClientHello fails (the connection is not restored then).
    inner: Option<ConnectionCommon<ServerConnectionData>>,
}
impl Acceptor {
pub fn new() -> Result<Self, Error> {
let common = CommonState::new(None, Side::Server)?;
let state = Box::new(Accepting);
Ok(Self {
inner: Some(ConnectionCommon::new(state, Default::default(), common)),
})
}
pub fn wants_read(&self) -> bool {
self.inner
.as_ref()
.map(|conn| conn.common_state.wants_read())
.unwrap_or(false)
}
pub fn read_tls(&mut self, rd: &mut dyn io::Read) -> Result<usize, io::Error> {
match &mut self.inner {
Some(conn) => conn.read_tls(rd),
None => Err(io::Error::new(
io::ErrorKind::Other,
"acceptor cannot read after successful acceptance",
)),
}
}
pub fn accept(&mut self) -> Result<Option<Accepted>, Error> {
let mut connection = match self.inner.take() {
Some(conn) => conn,
None => {
return Err(Error::General(
"cannot accept after successful acceptance".into(),
));
}
};
let message = match connection.first_handshake_message() {
Ok(Some(msg)) => msg,
Ok(None) => {
self.inner = Some(connection);
return Ok(None);
}
Err(e) => {
self.inner = Some(connection);
return Err(e);
}
};
let supported_cipher_suites = &crate::ALL_CIPHER_SUITES;
let (_, sig_schemes) = hs::process_client_hello(
&message,
false,
supported_cipher_suites,
&mut connection.common_state,
&mut connection.data,
)?;
Ok(Some(Accepted {
connection,
message,
sig_schemes,
}))
}
}
/// The result of a successful [`Acceptor::accept`].
pub struct Accepted {
    // Connection state after processing the ClientHello.
    connection: ConnectionCommon<ServerConnectionData>,
    // The handshake message holding the ClientHello.
    message: Message,
    // Signature schemes returned by `hs::process_client_hello`.
    sig_schemes: Vec<SignatureScheme>,
}
impl Accepted {
    /// A borrowed view of the received ClientHello.
    pub fn client_hello(&self) -> ClientHello<'_> {
        let hello = Self::client_hello_payload(&self.message);
        let alpn = hello.get_alpn_extension();
        ClientHello::new(
            &self.connection.data.sni,
            &self.sig_schemes,
            alpn,
            &hello.cipher_suites,
        )
    }

    /// Applies `config` and continues the handshake from the stored ClientHello.
    ///
    /// # Errors
    /// If `set_max_fragment_size` or `ExpectClientHello::with_certified_key` fails.
    pub fn into_connection(mut self, config: Arc<ServerConfig>) -> Result<ServerConnection, Error> {
        self.connection
            .common_state
            .set_max_fragment_size(config.max_fragment_size)?;

        let expect_hello = hs::ExpectClientHello::new(config, Vec::new());
        // Scope the context so its mutable borrows end before `replace_state`.
        let next_state = {
            let mut cx = hs::ServerContext {
                common: &mut self.connection.common_state,
                data: &mut self.connection.data,
            };
            expect_hello.with_certified_key(
                self.sig_schemes,
                Self::client_hello_payload(&self.message),
                &self.message,
                &mut cx,
            )?
        };

        self.connection.replace_state(next_state);
        Ok(ServerConnection {
            inner: self.connection,
        })
    }

    /// Extracts the ClientHello payload. `Acceptor::accept` only builds an
    /// `Accepted` after `process_client_hello` succeeded on this message, so
    /// other shapes are treated as unreachable.
    fn client_hello_payload(message: &Message) -> &ClientHelloPayload {
        use crate::msgs::handshake::HandshakePayload;
        use crate::msgs::message::MessagePayload;

        let parsed = match &message.payload {
            MessagePayload::Handshake { parsed, .. } => parsed,
            _ => unreachable!(),
        };
        match &parsed.payload {
            HandshakePayload::ClientHello(hello) => hello,
            _ => unreachable!(),
        }
    }
}
struct Accepting;
impl State<ServerConnectionData> for Accepting {
fn handle(
self: Box<Self>,
_cx: &mut hs::ServerContext<'_>,
_m: Message,
) -> Result<Box<dyn State<ServerConnectionData>>, Error> {
Err(Error::General("unreachable state".into()))
}
}
/// Server-side status of early data for one connection.
pub(super) enum EarlyDataState {
    /// No decision made yet (the `Default`).
    New,
    /// Accepted; received bytes are buffered here, subject to the limit
    /// given to `accept`.
    Accepted(ChunkVecBuffer),
    /// Rejected; reads fail with `BrokenPipe`.
    Rejected,
}
impl Default for EarlyDataState {
fn default() -> Self {
Self::New
}
}
impl EarlyDataState {
pub(super) fn reject(&mut self) {
*self = Self::Rejected;
}
pub(super) fn accept(&mut self, max_size: usize) {
*self = Self::Accepted(ChunkVecBuffer::new(Some(max_size)));
}
fn was_accepted(&self) -> bool {
matches!(self, Self::Accepted(_))
}
pub(super) fn was_rejected(&self) -> bool {
matches!(self, Self::Rejected)
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Accepted(ref mut received) => received.read(buf),
_ => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
}
#[cfg(read_buf)]
fn read_buf(&mut self, buf: &mut io::ReadBuf<'_>) -> io::Result<()> {
match self {
Self::Accepted(ref mut received) => received.read_buf(buf),
_ => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
}
pub(super) fn take_received_plaintext(&mut self, bytes: Payload) -> bool {
let available = bytes.0.len();
match self {
Self::Accepted(ref mut received) if received.apply_limit(available) == available => {
received.append(bytes.0);
true
}
_ => false,
}
}
}
/// Reading early data before any accept/reject decision must fail with
/// `BrokenPipe`.
///
/// The assertion checks `io::Error::kind()` instead of the error's `Debug`
/// text, because `io::Error`'s `Debug` output is not a stable format.
#[test]
fn test_read_in_new_state() {
    let err = EarlyDataState::default()
        .read(&mut [0u8; 5])
        .expect_err("reading early data in the New state must fail");
    assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
}
/// `read_buf` variant of `test_read_in_new_state`: it must fail with
/// `BrokenPipe` before any accept/reject decision. Checks `kind()` rather
/// than the unstable `Debug` text of `io::Error`.
#[cfg(read_buf)]
#[test]
fn test_read_buf_in_new_state() {
    let mut storage = [0u8; 5];
    let err = EarlyDataState::default()
        .read_buf(&mut io::ReadBuf::new(&mut storage))
        .expect_err("read_buf of early data in the New state must fail");
    assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
}
/// Server-specific per-connection data held by `ConnectionCommon`.
#[derive(Default)]
pub struct ServerConnectionData {
    // SNI name recorded for this connection, if any.
    pub(super) sni: Option<webpki::DnsName>,
    // Resumption data recorded as received, exposed by `received_resumption_data`.
    pub(super) received_resumption_data: Option<Vec<u8>>,
    // Resumption data set via `set_resumption_data` (< 2^15 bytes, enforced there).
    pub(super) resumption_data: Vec<u8>,
    // Early-data status and buffer.
    pub(super) early_data: EarlyDataState,
}
impl ServerConnectionData {
    /// The recorded SNI name as a string slice, if any.
    pub(super) fn get_sni_str(&self) -> Option<&str> {
        let name = self.sni.as_ref()?;
        Some(AsRef::<str>::as_ref(name))
    }
}
// Marker impl so `ServerConnectionData` can parameterize `ConnectionCommon`
// (presumably the role of `SideData` — confirm in `crate::conn`).
impl crate::conn::SideData for ServerConnectionData {}
#[cfg(feature = "quic")]
impl quic::QuicExt for ServerConnection {
    /// Transport parameters stored in the QUIC state, if any.
    fn quic_transport_parameters(&self) -> Option<&[u8]> {
        let quic_state = &self.inner.common_state.quic;
        quic_state.params.as_ref().map(|params| params.as_ref())
    }

    /// 0-RTT keys, available only when a TLS 1.3 suite is negotiated and an
    /// early secret has been stored.
    fn zero_rtt_keys(&self) -> Option<quic::DirectionalKeys> {
        let common = &self.inner.common_state;
        let suite = common.suite.and_then(|suite| suite.tls13())?;
        let early_secret = common.quic.early_secret.as_ref()?;
        Some(quic::DirectionalKeys::new(suite, early_secret))
    }

    /// Feeds handshake bytes received via QUIC into the connection.
    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
        self.inner.read_quic_hs(plaintext)
    }

    /// Appends pending handshake bytes to `buf`; may report a key change.
    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<quic::KeyChange> {
        let common = &mut self.inner.common_state;
        quic::write_hs(common, buf)
    }

    /// The alert recorded in the QUIC state, if any.
    fn alert(&self) -> Option<AlertDescription> {
        let quic_state = &self.inner.common_state.quic;
        quic_state.alert
    }
}
#[cfg(feature = "quic")]
pub trait ServerQuicExt {
    /// Creates a QUIC server connection carrying `params` as the transport
    /// parameters extension for `quic_version`.
    ///
    /// # Errors
    /// If `config` lacks TLS 1.3 support, if its `max_early_data_size` is
    /// neither `0` nor `0xffff_ffff`, or if connection setup fails.
    fn new_quic(
        config: Arc<ServerConfig>,
        quic_version: quic::Version,
        params: Vec<u8>,
    ) -> Result<ServerConnection, Error> {
        if !config.supports_version(ProtocolVersion::TLSv1_3) {
            return Err(Error::General(
                "TLS 1.3 support is required for QUIC".into(),
            ));
        }

        match config.max_early_data_size {
            0 | 0xffff_ffff => {}
            _ => {
                return Err(Error::General(
                    "QUIC sessions must set a max early data of 0 or 2^32-1".into(),
                ))
            }
        }

        let transport_params = match quic_version {
            quic::Version::V1Draft => ServerExtension::TransportParametersDraft(params),
            quic::Version::V1 => ServerExtension::TransportParameters(params),
        };

        let mut conn = ServerConnection::from_config(config, vec![transport_params])?;
        conn.inner.common_state.protocol = Protocol::Quic;
        Ok(conn)
    }
}
// `ServerConnection` gets `new_quic` from the trait's default method.
#[cfg(feature = "quic")]
impl ServerQuicExt for ServerConnection {}