use alloc::boxed::Box;
use alloc::vec::Vec;
use core::fmt;
use core::fmt::{Debug, Formatter};
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
#[cfg(feature = "std")]
use std::io;
use pki_types::{DnsName, UnixTime};
use super::hs;
#[cfg(feature = "std")]
use crate::WantsVerifier;
use crate::builder::ConfigBuilder;
use crate::common_state::{CommonState, Side};
#[cfg(feature = "std")]
use crate::common_state::{Protocol, State};
use crate::conn::{ConnectionCommon, ConnectionCore, UnbufferedConnectionCommon};
#[cfg(doc)]
use crate::crypto;
use crate::crypto::CryptoProvider;
use crate::enums::{CertificateType, CipherSuite, ProtocolVersion, SignatureScheme};
use crate::error::Error;
use crate::kernel::KernelConnection;
use crate::log::trace;
use crate::msgs::base::Payload;
use crate::msgs::handshake::{ClientHelloPayload, ProtocolName, ServerExtensionsInput};
use crate::msgs::message::Message;
use crate::suites::ExtractedSecrets;
use crate::sync::Arc;
#[cfg(feature = "std")]
use crate::time_provider::DefaultTimeProvider;
use crate::time_provider::TimeProvider;
use crate::vecbuf::ChunkVecBuffer;
use crate::{
DistinguishedName, KeyLog, NamedGroup, WantsVersions, compress, sign, verify, versions,
};
/// Storage for server-side session state, keyed and valued by opaque byte strings.
///
/// `ServerConfig::session_storage` holds this as an `Arc<dyn StoresServerSessions>`,
/// so one store may be shared by many connections (hence the `Send + Sync` bounds).
pub trait StoresServerSessions: Debug + Send + Sync {
    /// Stores `value` under `key`; presumably returns `true` if it was stored.
    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool;
    /// Looks up the value stored under `key`, if any.
    fn get(&self, key: &[u8]) -> Option<Vec<u8>>;
    /// Looks up the value under `key`; presumably also removes it so it can be
    /// used only once (unlike `get`) — confirm against the handshake callers.
    fn take(&self, key: &[u8]) -> Option<Vec<u8>>;
    /// Reports whether this store is able to cache sessions at all.
    fn can_cache(&self) -> bool;
}
/// Encrypts and decrypts the opaque session tickets handed to clients.
///
/// Held as `ServerConfig::ticketer` (`Arc<dyn ProducesTickets>`), hence the
/// `Debug + Send + Sync` bounds.
pub trait ProducesTickets: Debug + Send + Sync {
    /// Whether this ticketer should be used.
    fn enabled(&self) -> bool;
    /// Lifetime advertised for tickets; presumably in seconds — TODO confirm.
    fn lifetime(&self) -> u32;
    /// Encrypts `plain` into a ticket; `None` presumably signals failure.
    fn encrypt(&self, plain: &[u8]) -> Option<Vec<u8>>;
    /// Decrypts a ticket (presumably one produced by `encrypt`); `None` on failure.
    fn decrypt(&self, cipher: &[u8]) -> Option<Vec<u8>>;
}
/// Chooses which certificate and key the server presents, given the client's hello.
///
/// Held as `ServerConfig::cert_resolver` (`Arc<dyn ResolvesServerCert>`).
pub trait ResolvesServerCert: Debug + Send + Sync {
    /// Returns the certified key to use for `client_hello`, or `None` if no
    /// suitable certificate is available (presumably failing the handshake — confirm).
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>>;
    /// Whether this resolver only offers raw public keys. Defaults to `false`.
    fn only_raw_public_keys(&self) -> bool {
        false
    }
}
/// A borrowed view of a client's `ClientHello`, passed to
/// `ResolvesServerCert::resolve` and returned by `Accepted::client_hello`.
///
/// `Option` fields are `None` when the hello carried no such data (presumably the
/// corresponding extension was absent — confirm in `ClientHelloPayload`).
#[derive(Debug)]
pub struct ClientHello<'a> {
    /// SNI name recorded for the connection, if any.
    pub(super) server_name: &'a Option<DnsName<'a>>,
    /// Signature schemes usable for this handshake.
    pub(super) signature_schemes: &'a [SignatureScheme],
    /// ALPN protocol names offered by the client.
    pub(super) alpn: Option<&'a Vec<ProtocolName>>,
    /// Server certificate types offered by the client.
    pub(super) server_cert_types: Option<&'a [CertificateType]>,
    /// Client certificate types offered by the client.
    pub(super) client_cert_types: Option<&'a [CertificateType]>,
    /// Cipher suites offered by the client.
    pub(super) cipher_suites: &'a [CipherSuite],
    /// Certificate authority names offered by the client.
    pub(super) certificate_authorities: Option<&'a [DistinguishedName]>,
    /// Named groups offered by the client.
    pub(super) named_groups: Option<&'a [NamedGroup]>,
}
impl<'a> ClientHello<'a> {
pub fn server_name(&self) -> Option<&str> {
self.server_name
.as_ref()
.map(<DnsName<'_> as AsRef<str>>::as_ref)
}
pub fn signature_schemes(&self) -> &[SignatureScheme] {
self.signature_schemes
}
pub fn alpn(&self) -> Option<impl Iterator<Item = &'a [u8]>> {
self.alpn.map(|protocols| {
protocols
.iter()
.map(|proto| proto.as_ref())
})
}
pub fn cipher_suites(&self) -> &[CipherSuite] {
self.cipher_suites
}
pub fn server_cert_types(&self) -> Option<&'a [CertificateType]> {
self.server_cert_types
}
pub fn client_cert_types(&self) -> Option<&'a [CertificateType]> {
self.client_cert_types
}
pub fn certificate_authorities(&self) -> Option<&'a [DistinguishedName]> {
self.certificate_authorities
}
pub fn named_groups(&self) -> Option<&'a [NamedGroup]> {
self.named_groups
}
}
/// Configuration shared by server connections.
///
/// Usually built with `ServerConfig::builder` (std only) or
/// `ServerConfig::builder_with_details`, then shared as `Arc<ServerConfig>`.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Cryptographic provider; supplies cipher suites and the FIPS status.
    pub(super) provider: Arc<CryptoProvider>,
    /// NOTE(review): presumably prefer the server's cipher-suite order over the
    /// client's when set — confirm in `hs`.
    pub ignore_client_order: bool,
    /// Maximum fragment size, applied to each connection via
    /// `set_max_fragment_size`, which may reject the value with an `Error`.
    pub max_fragment_size: Option<usize>,
    /// Storage for server-side session state.
    pub session_storage: Arc<dyn StoresServerSessions>,
    /// Session ticket producer.
    pub ticketer: Arc<dyn ProducesTickets>,
    /// Chooses the certificate to present for each `ClientHello`.
    pub cert_resolver: Arc<dyn ResolvesServerCert>,
    /// ALPN protocol names supported by the server.
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Enabled protocol versions; checked by `supports_version`.
    pub(super) versions: versions::EnabledVersions,
    /// Client certificate verifier.
    pub(super) verifier: Arc<dyn verify::ClientCertVerifier>,
    /// Destination for key material logging.
    pub key_log: Arc<dyn KeyLog>,
    /// Copied onto each connection's `enable_secret_extraction` flag.
    pub enable_secret_extraction: bool,
    /// Maximum early (0-RTT) data size; presumably in bytes — confirm in `hs`.
    pub max_early_data_size: u32,
    /// NOTE(review): presumably allows sending data before the handshake
    /// completes — confirm in `hs`.
    pub send_half_rtt_data: bool,
    /// Number of TLS 1.3 tickets to send (semantics defined in `hs`).
    pub send_tls13_tickets: usize,
    /// When the `tls12` feature is on, must be `true` for `fips()` to report `true`.
    #[cfg(feature = "tls12")]
    pub require_ems: bool,
    /// Source of the current time; `current_time` fails if it yields `None`.
    pub time_provider: Arc<dyn TimeProvider>,
    /// Certificate compression algorithms available to the server.
    pub cert_compressors: Vec<&'static dyn compress::CertCompressor>,
    /// Cache used for certificate compression.
    pub cert_compression_cache: Arc<compress::CompressionCache>,
    /// Certificate decompression algorithms available to the server.
    pub cert_decompressors: Vec<&'static dyn compress::CertDecompressor>,
}
impl ServerConfig {
#[cfg(feature = "std")]
pub fn builder() -> ConfigBuilder<Self, WantsVerifier> {
Self::builder_with_protocol_versions(versions::DEFAULT_VERSIONS)
}
#[cfg(feature = "std")]
pub fn builder_with_protocol_versions(
versions: &[&'static versions::SupportedProtocolVersion],
) -> ConfigBuilder<Self, WantsVerifier> {
Self::builder_with_provider(
CryptoProvider::get_default_or_install_from_crate_features().clone(),
)
.with_protocol_versions(versions)
.unwrap()
}
#[cfg(feature = "std")]
pub fn builder_with_provider(
provider: Arc<CryptoProvider>,
) -> ConfigBuilder<Self, WantsVersions> {
ConfigBuilder {
state: WantsVersions {},
provider,
time_provider: Arc::new(DefaultTimeProvider),
side: PhantomData,
}
}
pub fn builder_with_details(
provider: Arc<CryptoProvider>,
time_provider: Arc<dyn TimeProvider>,
) -> ConfigBuilder<Self, WantsVersions> {
ConfigBuilder {
state: WantsVersions {},
provider,
time_provider,
side: PhantomData,
}
}
pub fn fips(&self) -> bool {
#[cfg(feature = "tls12")]
{
self.provider.fips() && self.require_ems
}
#[cfg(not(feature = "tls12"))]
{
self.provider.fips()
}
}
pub fn crypto_provider(&self) -> &Arc<CryptoProvider> {
&self.provider
}
pub(crate) fn supports_version(&self, v: ProtocolVersion) -> bool {
self.versions.contains(v)
&& self
.provider
.cipher_suites
.iter()
.any(|cs| cs.version().version == v)
}
#[cfg(feature = "std")]
pub(crate) fn supports_protocol(&self, proto: Protocol) -> bool {
self.provider
.cipher_suites
.iter()
.any(|cs| cs.usable_for_protocol(proto))
}
pub(super) fn current_time(&self) -> Result<UnixTime, Error> {
self.time_provider
.current_time()
.ok_or(Error::FailedToGetCurrentTime)
}
}
#[cfg(feature = "std")]
mod connection {
    use alloc::boxed::Box;
    use core::fmt;
    use core::fmt::{Debug, Formatter};
    use core::ops::{Deref, DerefMut};
    use std::io;

    use super::{
        Accepted, Accepting, EarlyDataState, ServerConfig, ServerConnectionData,
        ServerExtensionsInput,
    };
    use crate::common_state::{CommonState, Context, Side};
    use crate::conn::{ConnectionCommon, ConnectionCore};
    use crate::error::Error;
    use crate::server::hs;
    use crate::suites::ExtractedSecrets;
    use crate::sync::Arc;
    use crate::vecbuf::ChunkVecBuffer;

    /// An `io::Read` view over accepted early (0-RTT) data.
    ///
    /// Only handed out by `ServerConnection::early_data` while early data is in the
    /// accepted state.
    pub struct ReadEarlyData<'a> {
        early_data: &'a mut EarlyDataState,
    }

    impl<'a> ReadEarlyData<'a> {
        fn new(early_data: &'a mut EarlyDataState) -> Self {
            ReadEarlyData { early_data }
        }
    }

    impl io::Read for ReadEarlyData<'_> {
        // Delegates to `EarlyDataState::read`, which fails with `BrokenPipe` unless
        // early data was accepted.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.early_data.read(buf)
        }

        #[cfg(read_buf)]
        fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> io::Result<()> {
            self.early_data.read_buf(cursor)
        }
    }

    /// A buffered TLS server connection.
    ///
    /// Most of its API comes from `ConnectionCommon<ServerConnectionData>`, which this
    /// type dereferences to.
    pub struct ServerConnection {
        pub(super) inner: ConnectionCommon<ServerConnectionData>,
    }

    impl ServerConnection {
        /// Creates a new server connection driven by `config`.
        ///
        /// # Errors
        /// Propagates errors from `ConnectionCore::for_server`, e.g. an invalid
        /// `config.max_fragment_size`.
        pub fn new(config: Arc<ServerConfig>) -> Result<Self, Error> {
            Ok(Self {
                inner: ConnectionCommon::from(ConnectionCore::for_server(
                    config,
                    ServerExtensionsInput::default(),
                )?),
            })
        }

        /// The SNI host name recorded for this connection, if any.
        pub fn server_name(&self) -> Option<&str> {
            self.inner.core.get_sni_str()
        }

        /// The `received_resumption_data` recorded for this connection, if any
        /// (presumably populated when a session is resumed — confirm in `hs`).
        pub fn received_resumption_data(&self) -> Option<&[u8]> {
            self.inner
                .core
                .data
                .received_resumption_data
                .as_deref()
        }

        /// Replaces this connection's `resumption_data` with a copy of `data`.
        ///
        /// # Panics
        /// Panics if `data` is 2^15 bytes or longer.
        pub fn set_resumption_data(&mut self, data: &[u8]) {
            assert!(data.len() < 2usize.pow(15));
            self.inner.core.data.resumption_data = data.into();
        }

        /// Marks early data as rejected.
        ///
        /// # Panics
        /// Panics if the connection is no longer handshaking.
        pub fn reject_early_data(&mut self) {
            self.inner.core.reject_early_data()
        }

        /// A reader over accepted early data, or `None` if early data has not been
        /// accepted.
        pub fn early_data(&mut self) -> Option<ReadEarlyData<'_>> {
            let data = &mut self.inner.core.data;
            if data.early_data.was_accepted() {
                Some(ReadEarlyData::new(&mut data.early_data))
            } else {
                None
            }
        }

        /// The `fips` flag stored in this connection's common state.
        pub fn fips(&self) -> bool {
            self.inner.core.common_state.fips
        }

        /// Extracts the connection's secrets, consuming it.
        ///
        /// # Errors
        /// Propagates errors from `ConnectionCommon::dangerous_extract_secrets`.
        pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
            self.inner.dangerous_extract_secrets()
        }
    }

    impl Debug for ServerConnection {
        // Deliberately opaque: no connection state is printed.
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.debug_struct("ServerConnection")
                .finish()
        }
    }

    impl Deref for ServerConnection {
        type Target = ConnectionCommon<ServerConnectionData>;

        fn deref(&self) -> &Self::Target {
            &self.inner
        }
    }

    impl DerefMut for ServerConnection {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.inner
        }
    }

    impl From<ServerConnection> for crate::Connection {
        fn from(conn: ServerConnection) -> Self {
            Self::Server(conn)
        }
    }

    /// Accepts a connection in two steps: first read and process the client's
    /// hello (`accept`), then choose a `ServerConfig` for it
    /// (`Accepted::into_connection`).
    pub struct Acceptor {
        // `None` once `accept` has returned anything other than `Ok(None)`.
        inner: Option<ConnectionCommon<ServerConnectionData>>,
    }

    impl Default for Acceptor {
        // Starts in the placeholder `Accepting` state with no `ServerConfig`; the
        // config is only applied later, in `Accepted::into_connection`.
        fn default() -> Self {
            Self {
                inner: Some(
                    ConnectionCore::new(
                        Box::new(Accepting),
                        ServerConnectionData::default(),
                        CommonState::new(Side::Server),
                    )
                    .into(),
                ),
            }
        }
    }

    impl Acceptor {
        /// Reads TLS bytes from `rd` into the pending connection.
        ///
        /// # Errors
        /// Propagates read errors. Fails with `ErrorKind::Other` once the acceptor
        /// has been consumed by `accept`.
        pub fn read_tls(&mut self, rd: &mut dyn io::Read) -> Result<usize, io::Error> {
            match &mut self.inner {
                Some(conn) => conn.read_tls(rd),
                None => Err(io::Error::new(
                    io::ErrorKind::Other,
                    "acceptor cannot read after successful acceptance",
                )),
            }
        }

        /// Tries to read and process the client's first handshake message.
        ///
        /// Returns `Ok(None)` when no complete handshake message is available yet.
        /// Call `read_tls` and retry in that case. Returns `Ok(Some(_))` once the
        /// hello has been processed.
        ///
        /// # Errors
        /// Returns the error together with an `AcceptedAlert` holding any bytes to
        /// send to the peer. After any result other than `Ok(None)`, the acceptor is
        /// spent and later calls fail.
        pub fn accept(&mut self) -> Result<Option<Accepted>, (Error, AcceptedAlert)> {
            let Some(mut connection) = self.inner.take() else {
                return Err((
                    Error::General("Acceptor polled after completion".into()),
                    AcceptedAlert::empty(),
                ));
            };

            let message = match connection.first_handshake_message() {
                Ok(Some(msg)) => msg,
                Ok(None) => {
                    // Not enough data yet: put the connection back for the next call.
                    self.inner = Some(connection);
                    return Ok(None);
                }
                Err(err) => return Err((err, AcceptedAlert::from(connection))),
            };

            let mut cx = Context::from(&mut connection);
            let sig_schemes = match hs::process_client_hello(&message, false, &mut cx) {
                Ok((_, sig_schemes)) => sig_schemes,
                Err(err) => {
                    return Err((err, AcceptedAlert::from(connection)));
                }
            };

            Ok(Some(Accepted {
                connection,
                message,
                sig_schemes,
            }))
        }
    }

    /// TLS bytes (typically an alert) still to be sent to the peer after
    /// `Acceptor::accept` or `Accepted::into_connection` fails.
    pub struct AcceptedAlert(ChunkVecBuffer);

    impl AcceptedAlert {
        pub(super) fn empty() -> Self {
            Self(ChunkVecBuffer::new(None))
        }

        /// Writes some of the pending bytes to `wr`, returning how many were written.
        pub fn write(&mut self, wr: &mut dyn io::Write) -> Result<usize, io::Error> {
            self.0.write_to(wr)
        }

        /// Writes all pending bytes to `wr`.
        ///
        /// Like `io::Write::write_all`, it retries on `ErrorKind::Interrupted`. It
        /// fails with `ErrorKind::WriteZero` if `wr` accepts no bytes while data is
        /// still pending, rather than reporting success with the alert unsent.
        ///
        /// # Errors
        /// Any other error from `wr`.
        pub fn write_all(&mut self, wr: &mut dyn io::Write) -> Result<(), io::Error> {
            while !self.0.is_empty() {
                match self.write(wr) {
                    Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
                    Ok(_) => {}
                    Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
                    Err(err) => return Err(err),
                }
            }
            Ok(())
        }
    }

    impl From<ConnectionCommon<ServerConnectionData>> for AcceptedAlert {
        // Keeps only the connection's queued outgoing TLS bytes.
        fn from(conn: ConnectionCommon<ServerConnectionData>) -> Self {
            Self(conn.core.common_state.sendable_tls)
        }
    }

    impl Debug for AcceptedAlert {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.debug_struct("AcceptedAlert").finish()
        }
    }
}
#[cfg(feature = "std")]
pub use connection::{AcceptedAlert, Acceptor, ReadEarlyData, ServerConnection};
/// A server connection for the unbuffered API.
///
/// Dereferences to `UnbufferedConnectionCommon<ServerConnectionData>`.
pub struct UnbufferedServerConnection {
    inner: UnbufferedConnectionCommon<ServerConnectionData>,
}
impl UnbufferedServerConnection {
    /// Creates a new unbuffered server connection driven by `config`.
    ///
    /// # Errors
    /// Propagates errors from `ConnectionCore::for_server`.
    pub fn new(config: Arc<ServerConfig>) -> Result<Self, Error> {
        let core = ConnectionCore::for_server(config, ServerExtensionsInput::default())?;
        let inner = UnbufferedConnectionCommon::from(core);
        Ok(Self { inner })
    }

    /// Extracts the connection's secrets, consuming it.
    #[deprecated = "dangerous_extract_secrets() does not support session tickets or \
                    key updates, use dangerous_into_kernel_connection() instead"]
    pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
        let Self { inner } = self;
        inner.dangerous_extract_secrets()
    }

    /// Converts this connection into its extracted secrets plus a
    /// `KernelConnection`, consuming it.
    pub fn dangerous_into_kernel_connection(
        self,
    ) -> Result<(ExtractedSecrets, KernelConnection<ServerConnectionData>), Error> {
        let Self { inner } = self;
        inner.core.dangerous_into_kernel_connection()
    }
}
impl Deref for UnbufferedServerConnection {
    type Target = UnbufferedConnectionCommon<ServerConnectionData>;

    // Exposes the shared unbuffered-connection API.
    fn deref(&self) -> &Self::Target {
        let Self { inner } = self;
        inner
    }
}
impl DerefMut for UnbufferedServerConnection {
    // Mutable counterpart of `Deref`: exposes the shared unbuffered-connection API.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner } = self;
        inner
    }
}
impl UnbufferedConnectionCommon<ServerConnectionData> {
    /// Removes and returns a buffered chunk of early data. Returns `None` unless early
    /// data was accepted and something is buffered.
    pub(crate) fn pop_early_data(&mut self) -> Option<Vec<u8>> {
        let early = &mut self.core.data.early_data;
        early.pop()
    }

    /// Like `pop_early_data`, but leaves the data buffered.
    pub(crate) fn peek_early_data(&self) -> Option<&[u8]> {
        let early = &self.core.data.early_data;
        early.peek()
    }
}
/// A `ClientHello` that `Acceptor::accept` has read and pre-processed, waiting for
/// a `ServerConfig` to be supplied via `into_connection`.
pub struct Accepted {
    // The partially set-up connection the hello arrived on.
    connection: ConnectionCommon<ServerConnectionData>,
    // The hello message itself; always a handshake `ClientHello`.
    message: Message<'static>,
    // Signature schemes returned by `hs::process_client_hello`.
    sig_schemes: Vec<SignatureScheme>,
}
impl Accepted {
    /// A borrowed view of the received `ClientHello`, for choosing a `ServerConfig`.
    pub fn client_hello(&self) -> ClientHello<'_> {
        let payload = Self::client_hello_payload(&self.message);
        let ch = ClientHello {
            server_name: &self.connection.core.data.sni,
            signature_schemes: &self.sig_schemes,
            alpn: payload.protocols.as_ref(),
            server_cert_types: payload
                .server_certificate_types
                .as_deref(),
            client_cert_types: payload
                .client_certificate_types
                .as_deref(),
            cipher_suites: &payload.cipher_suites,
            certificate_authorities: payload
                .certificate_authority_names
                .as_deref(),
            named_groups: payload.named_groups.as_deref(),
        };

        trace!("Accepted::client_hello(): {ch:#?}");
        ch
    }

    /// Applies `config` to the pending connection and processes the buffered
    /// `ClientHello`, yielding a `ServerConnection`.
    ///
    /// # Errors
    /// Returns the error together with an `AcceptedAlert` holding any bytes to
    /// send to the peer. The alert is empty when `config.max_fragment_size` is
    /// rejected, since that error is about the config rather than the peer.
    #[cfg(feature = "std")]
    pub fn into_connection(
        mut self,
        config: Arc<ServerConfig>,
    ) -> Result<ServerConnection, (Error, AcceptedAlert)> {
        if let Err(err) = self
            .connection
            .set_max_fragment_size(config.max_fragment_size)
        {
            return Err((err, AcceptedAlert::empty()));
        }

        // `Acceptor` builds its connection with `ConnectionCore::new`, not
        // `for_server`, so the config-derived flags that `for_server` sets must be
        // applied here. That includes `fips`: without it, `ServerConnection::fips()`
        // reports `false` even for a FIPS config.
        self.connection.enable_secret_extraction = config.enable_secret_extraction;
        self.connection.fips = config.fips();

        let state = hs::ExpectClientHello::new(config, ServerExtensionsInput::default());
        let mut cx = hs::ServerContext::from(&mut self.connection);

        let ch = Self::client_hello_payload(&self.message);
        let new = match state.with_certified_key(self.sig_schemes, ch, &self.message, &mut cx) {
            Ok(new) => new,
            Err(err) => return Err((err, AcceptedAlert::from(self.connection))),
        };

        self.connection.replace_state(new);
        Ok(ServerConnection {
            inner: self.connection,
        })
    }

    // Extracts the `ClientHelloPayload` from `message`.
    //
    // `Accepted` is only constructed by `Acceptor::accept` after
    // `hs::process_client_hello` has accepted `message`. So `message` is always a
    // handshake `ClientHello`, and any other shape is a bug.
    fn client_hello_payload<'a>(message: &'a Message<'_>) -> &'a ClientHelloPayload {
        match &message.payload {
            crate::msgs::message::MessagePayload::Handshake { parsed, .. } => match &parsed.0 {
                crate::msgs::handshake::HandshakePayload::ClientHello(ch) => ch,
                _ => unreachable!(),
            },
            _ => unreachable!(),
        }
    }
}
impl Debug for Accepted {
    // Deliberately opaque: prints only the type name, no fields.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Accepted");
        out.finish()
    }
}
/// Placeholder handshake state for the connection held inside `Acceptor`, used
/// before any `ServerConfig` has been chosen.
#[cfg(feature = "std")]
struct Accepting;
#[cfg(feature = "std")]
impl State<ServerConnectionData> for Accepting {
    // Always fails. `Acceptor::accept` reads and processes the first message
    // itself, so this state is presumably never asked to handle a message.
    fn handle<'m>(
        self: Box<Self>,
        _cx: &mut hs::ServerContext<'_>,
        _m: Message<'m>,
    ) -> Result<Box<dyn State<ServerConnectionData> + 'm>, Error>
    where
        Self: 'm,
    {
        let reason = "unreachable state";
        Err(Error::General(reason.into()))
    }

    // `Accepting` borrows nothing, so it is already `'static`.
    fn into_owned(self: Box<Self>) -> hs::NextState<'static> {
        self
    }
}
/// Lifecycle of early (0-RTT) data on a server connection.
#[derive(Default)]
pub(super) enum EarlyDataState {
    /// Initial state: neither accepted nor rejected.
    #[default]
    New,
    /// Early data accepted. `received` buffers it; `left` is the remaining byte
    /// allowance, starting at the `max_size` passed to `accept`.
    Accepted {
        received: ChunkVecBuffer,
        left: usize,
    },
    /// Early data rejected.
    Rejected,
}
impl EarlyDataState {
pub(super) fn reject(&mut self) {
*self = Self::Rejected;
}
pub(super) fn accept(&mut self, max_size: usize) {
*self = Self::Accepted {
received: ChunkVecBuffer::new(Some(max_size)),
left: max_size,
};
}
#[cfg(feature = "std")]
fn was_accepted(&self) -> bool {
matches!(self, Self::Accepted { .. })
}
pub(super) fn was_rejected(&self) -> bool {
matches!(self, Self::Rejected)
}
fn peek(&self) -> Option<&[u8]> {
match self {
Self::Accepted { received, .. } => received.peek(),
_ => None,
}
}
fn pop(&mut self) -> Option<Vec<u8>> {
match self {
Self::Accepted { received, .. } => received.pop(),
_ => None,
}
}
#[cfg(feature = "std")]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Accepted { received, .. } => received.read(buf),
_ => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
}
#[cfg(read_buf)]
fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> io::Result<()> {
match self {
Self::Accepted { received, .. } => received.read_buf(cursor),
_ => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
}
pub(super) fn take_received_plaintext(&mut self, bytes: Payload<'_>) -> bool {
let available = bytes.bytes().len();
let Self::Accepted { received, left } = self else {
return false;
};
if received.apply_limit(available) != available || available > *left {
return false;
}
received.append(bytes.into_vec());
*left -= available;
true
}
}
impl Debug for EarlyDataState {
    // Summarises the buffer by its length instead of dumping its contents.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::New => f.write_str("EarlyDataState::New"),
            Self::Rejected => f.write_str("EarlyDataState::Rejected"),
            Self::Accepted { received, left } => {
                let buffered = received.len();
                write!(
                    f,
                    "EarlyDataState::Accepted {{ received: {buffered}, left: {left} }}"
                )
            }
        }
    }
}
impl ConnectionCore<ServerConnectionData> {
    /// Builds the core of a new server connection from `config`.
    ///
    /// The connection starts in `hs::ExpectClientHello` with `extra_exts`, and
    /// copies `max_fragment_size`, `enable_secret_extraction` and `fips()` from
    /// `config`.
    ///
    /// # Errors
    /// Fails if `set_max_fragment_size` rejects `config.max_fragment_size`.
    pub(crate) fn for_server(
        config: Arc<ServerConfig>,
        extra_exts: ServerExtensionsInput<'static>,
    ) -> Result<Self, Error> {
        let mut common = CommonState::new(Side::Server);
        common.set_max_fragment_size(config.max_fragment_size)?;
        common.enable_secret_extraction = config.enable_secret_extraction;
        common.fips = config.fips();

        let first_state = Box::new(hs::ExpectClientHello::new(config, extra_exts));
        Ok(Self::new(first_state, ServerConnectionData::default(), common))
    }

    /// Marks early data as rejected.
    ///
    /// # Panics
    /// Panics if the connection is no longer handshaking.
    #[cfg(feature = "std")]
    pub(crate) fn reject_early_data(&mut self) {
        let still_handshaking = self.common_state.is_handshaking();
        assert!(still_handshaking, "cannot retroactively reject early data");
        self.data.early_data.reject();
    }

    /// The SNI host name recorded in the connection data, if any.
    #[cfg(feature = "std")]
    pub(crate) fn get_sni_str(&self) -> Option<&str> {
        let data = &self.data;
        data.get_sni_str()
    }
}
/// Server-specific per-connection state.
#[derive(Default, Debug)]
pub struct ServerConnectionData {
    /// SNI name recorded for the connection; surfaced by `server_name()`.
    pub(super) sni: Option<DnsName<'static>>,
    /// Resumption data exposed by `ServerConnection::received_resumption_data`.
    pub(super) received_resumption_data: Option<Vec<u8>>,
    /// Resumption data set by `ServerConnection::set_resumption_data`.
    pub(super) resumption_data: Vec<u8>,
    /// Early (0-RTT) data state.
    pub(super) early_data: EarlyDataState,
}
impl ServerConnectionData {
    /// The recorded SNI name as a string slice, if any.
    #[cfg(feature = "std")]
    pub(super) fn get_sni_str(&self) -> Option<&str> {
        self.sni
            .as_ref()
            .map(<DnsName<'_> as AsRef<str>>::as_ref)
    }
}
// Registers `ServerConnectionData` as the side-specific data type for connections.
// The impl body is empty: no trait items are provided here.
impl crate::conn::SideData for ServerConnectionData {}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use std::format;

    use super::*;

    // Reading early data that was never accepted must fail with `BrokenPipe`.
    #[test]
    fn test_read_in_new_state() {
        let mut buf = [0u8; 5];
        let outcome = EarlyDataState::default().read(&mut buf);
        assert_eq!(format!("{outcome:?}"), "Err(Kind(BrokenPipe))");
    }

    // Same as above, through the `read_buf` entry point.
    #[cfg(read_buf)]
    #[test]
    fn test_read_buf_in_new_state() {
        use core::io::BorrowedBuf;

        let mut storage = [0u8; 5];
        let mut buf: BorrowedBuf<'_> = storage.as_mut_slice().into();
        let outcome = EarlyDataState::default().read_buf(buf.unfilled());
        assert_eq!(format!("{outcome:?}"), "Err(Kind(BrokenPipe))");
    }
}