use crate::error::TLSError;
use crate::key;
use crate::keylog::{KeyLog, NoKeyLog};
#[cfg(feature = "logging")]
use crate::log::trace;
use crate::msgs::enums::ContentType;
use crate::msgs::enums::SignatureScheme;
use crate::msgs::enums::{AlertDescription, HandshakeType, ProtocolVersion};
use crate::msgs::handshake::ServerExtension;
use crate::msgs::message::Message;
use crate::session::{MiddleboxCCS, Session, SessionCommon};
use crate::sign;
use crate::suites::{SupportedCipherSuite, ALL_CIPHERSUITES};
use crate::verify;
use webpki;
use std::fmt;
use std::io::{self, IoSlice};
use std::sync::Arc;
#[macro_use]
mod hs;
mod common;
pub mod handy;
mod tls12;
mod tls13;
/// A key/value store for server-side session state.
///
/// A `ServerConfig` holds one of these in its `session_storage` field. The default
/// built by `ServerConfig::with_ciphersuites` is `handy::ServerSessionMemoryCache::new(256)`.
/// Implementations must be `Send + Sync` (supertraits).
pub trait StoresServerSessions: Send + Sync {
    /// Stores `value` under `key`.
    ///
    /// NOTE(review): the returned `bool` presumably reports whether the value was
    /// stored — confirm against the implementations in `handy`.
    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool;
    /// Looks up the value stored under `key`, if any.
    ///
    /// Presumably non-destructive, in contrast to `take` — TODO confirm.
    fn get(&self, key: &[u8]) -> Option<Vec<u8>>;
    /// Looks up the value stored under `key`, if any.
    ///
    /// Presumably also removes it from the store — TODO confirm.
    fn take(&self, key: &[u8]) -> Option<Vec<u8>>;
}
/// Produces and consumes session tickets for a `ServerConfig` (its `ticketer` field).
///
/// The default built by `ServerConfig::with_ciphersuites` is
/// `handy::NeverProducesTickets`. Implementations must be `Send + Sync` (supertraits).
pub trait ProducesTickets: Send + Sync {
    /// Whether this ticketer is active.
    fn enabled(&self) -> bool;
    /// Ticket lifetime advertised to clients.
    ///
    /// NOTE(review): unit presumably seconds — confirm.
    fn get_lifetime(&self) -> u32;
    /// Turns `plain` into an opaque ticket, or returns `None` on failure.
    fn encrypt(&self, plain: &[u8]) -> Option<Vec<u8>>;
    /// Recovers the plaintext from a ticket, or returns `None` if it cannot be
    /// decrypted.
    fn decrypt(&self, cipher: &[u8]) -> Option<Vec<u8>>;
}
/// Chooses the certificate and key to present to a client.
///
/// A `ServerConfig` holds one of these in `cert_resolver`:
/// - the default is `handy::FailResolveChain`;
/// - `ServerConfig::set_single_cert` installs a `handy::AlwaysResolvesChain`.
pub trait ResolvesServerCert: Send + Sync {
    /// Picks a certificate for the given `ClientHello` details.
    ///
    /// Returns `None` when no suitable certificate is available.
    fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey>;
}
/// Borrowed view of selected ClientHello fields, passed to
/// `ResolvesServerCert::resolve`.
pub struct ClientHello<'a> {
    // Server name from SNI, if the client sent one (presumably — see callers).
    server_name: Option<webpki::DNSNameRef<'a>>,
    // Signature schemes associated with this hello (presumably client-offered).
    sigschemes: &'a [SignatureScheme],
    // ALPN protocol names, if present.
    alpn: Option<&'a [&'a [u8]]>,
}
impl<'a> ClientHello<'a> {
    /// Bundles the borrowed ClientHello fields handed to certificate resolvers.
    fn new(
        server_name: Option<webpki::DNSNameRef<'a>>,
        sigschemes: &'a [SignatureScheme],
        alpn: Option<&'a [&'a [u8]]>,
    ) -> Self {
        ClientHello {
            server_name,
            sigschemes,
            alpn,
        }
    }

    /// The server name this hello was built with, if any.
    ///
    /// The result borrows for `'a` (the underlying message data) rather than
    /// for the lifetime of `&self`, consistent with `alpn()`.
    pub fn server_name(&self) -> Option<webpki::DNSNameRef<'a>> {
        self.server_name
    }

    /// The signature schemes this hello was built with.
    ///
    /// Borrows for `'a`, like `server_name()` and `alpn()`.
    pub fn sigschemes(&self) -> &'a [SignatureScheme] {
        self.sigschemes
    }

    /// The ALPN protocol names this hello was built with, if any.
    pub fn alpn(&self) -> Option<&'a [&'a [u8]]> {
        self.alpn
    }
}
/// Configuration shared by server sessions.
///
/// Build it with `ServerConfig::new` or `ServerConfig::with_ciphersuites`. Sessions
/// hold it behind an `Arc` (see `ServerSessionImpl::new`).
#[derive(Clone)]
pub struct ServerConfig {
    /// Cipher suites this server may use; consulted by `supports_version`.
    pub ciphersuites: Vec<&'static SupportedCipherSuite>,
    /// Default `false`. Presumably selects the server's suite preference over the
    /// client's — TODO confirm in `hs`.
    pub ignore_client_order: bool,
    /// Passed to `SessionCommon::new` for each session; default `None`.
    pub mtu: Option<usize>,
    /// Session state store; default `handy::ServerSessionMemoryCache::new(256)`.
    pub session_storage: Arc<dyn StoresServerSessions + Send + Sync>,
    /// Ticket producer; default `handy::NeverProducesTickets`.
    pub ticketer: Arc<dyn ProducesTickets>,
    /// Certificate chooser; default `handy::FailResolveChain`.
    pub cert_resolver: Arc<dyn ResolvesServerCert>,
    /// ALPN protocol names; default empty. Replaced by `set_protocols`.
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Enabled protocol versions; default TLS 1.3 then TLS 1.2.
    pub versions: Vec<ProtocolVersion>,
    // Client certificate verifier: private.
    // Read via `get_verifier`, replaced via `set_client_certificate_verifier`.
    verifier: Arc<dyn verify::ClientCertVerifier>,
    /// Key logger; default `NoKeyLog`.
    pub key_log: Arc<dyn KeyLog>,
    /// QUIC-only early data limit; default 0.
    #[cfg(feature = "quic")] #[doc(hidden)]
    pub max_early_data_size: u32,
}
impl ServerConfig {
    /// Builds a config that offers every suite in `ALL_CIPHERSUITES`.
    ///
    /// All other defaults are those of `with_ciphersuites`.
    pub fn new(client_cert_verifier: Arc<dyn verify::ClientCertVerifier>) -> ServerConfig {
        Self::with_ciphersuites(client_cert_verifier, &ALL_CIPHERSUITES)
    }

    /// Builds a config restricted to `ciphersuites`.
    ///
    /// Defaults:
    /// - server order not forced, no MTU;
    /// - a 256-entry in-memory session cache, no tickets;
    /// - no ALPN protocols, a resolver that never finds a certificate;
    /// - TLS 1.3 and TLS 1.2 enabled, no key logging.
    pub fn with_ciphersuites(
        client_cert_verifier: Arc<dyn verify::ClientCertVerifier>,
        ciphersuites: &[&'static SupportedCipherSuite],
    ) -> ServerConfig {
        let suites = ciphersuites.to_vec();
        let cache = handy::ServerSessionMemoryCache::new(256);
        ServerConfig {
            ciphersuites: suites,
            ignore_client_order: false,
            mtu: None,
            session_storage: cache,
            ticketer: Arc::new(handy::NeverProducesTickets {}),
            alpn_protocols: Vec::new(),
            cert_resolver: Arc::new(handy::FailResolveChain {}),
            versions: vec![ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2],
            verifier: client_cert_verifier,
            key_log: Arc::new(NoKeyLog {}),
            #[cfg(feature = "quic")]
            max_early_data_size: 0,
        }
    }

    /// Reports whether `v` is enabled in `versions` and at least one configured
    /// suite can be used with it.
    #[doc(hidden)]
    pub fn supports_version(&self, v: ProtocolVersion) -> bool {
        if !self.versions.contains(&v) {
            return false;
        }
        self.ciphersuites
            .iter()
            .any(|suite| suite.usable_for_version(v))
    }

    /// Borrows the configured client certificate verifier.
    #[doc(hidden)]
    pub fn get_verifier(&self) -> &dyn verify::ClientCertVerifier {
        &*self.verifier
    }

    /// Replaces the session state store.
    pub fn set_persistence(&mut self, persist: Arc<dyn StoresServerSessions + Send + Sync>) {
        self.session_storage = persist;
    }

    /// Installs a resolver that always presents `cert_chain` with `key_der`.
    ///
    /// # Errors
    /// Propagates any error from `handy::AlwaysResolvesChain::new`. In that case
    /// the current resolver is left in place.
    pub fn set_single_cert(
        &mut self,
        cert_chain: Vec<key::Certificate>,
        key_der: key::PrivateKey,
    ) -> Result<(), TLSError> {
        self.cert_resolver = Arc::new(handy::AlwaysResolvesChain::new(cert_chain, &key_der)?);
        Ok(())
    }

    /// Like `set_single_cert`, but also attaches an OCSP response and SCT list.
    ///
    /// # Errors
    /// Propagates any error from `handy::AlwaysResolvesChain::new_with_extras`. In
    /// that case the current resolver is left in place.
    pub fn set_single_cert_with_ocsp_and_sct(
        &mut self,
        cert_chain: Vec<key::Certificate>,
        key_der: key::PrivateKey,
        ocsp: Vec<u8>,
        scts: Vec<u8>,
    ) -> Result<(), TLSError> {
        let chain_resolver =
            handy::AlwaysResolvesChain::new_with_extras(cert_chain, &key_der, ocsp, scts)?;
        self.cert_resolver = Arc::new(chain_resolver);
        Ok(())
    }

    /// Replaces the ALPN protocol list with a copy of `protocols`.
    pub fn set_protocols(&mut self, protocols: &[Vec<u8>]) {
        self.alpn_protocols = protocols.to_vec();
    }

    /// Replaces the client certificate verifier.
    pub fn set_client_certificate_verifier(
        &mut self,
        verifier: Arc<dyn verify::ClientCertVerifier>,
    ) {
        self.verifier = verifier;
    }
}
/// Internal state of one server connection; wrapped by `ServerSession`.
pub struct ServerSessionImpl {
    /// Shared server configuration.
    pub config: Arc<ServerConfig>,
    /// Record layer, buffers and traffic flags shared with the client code.
    pub common: SessionCommon,
    // SNI value; may be set at most once via `set_sni` (which asserts).
    sni: Option<webpki::DNSName>,
    /// Negotiated ALPN protocol, if any.
    pub alpn_protocol: Option<Vec<u8>>,
    /// QUIC transport parameters, if any.
    pub quic_params: Option<Vec<u8>>,
    /// Resumption data received from the client, if any.
    pub received_resumption_data: Option<Vec<u8>>,
    /// Resumption data set by the application (`ServerSession::set_resumption_data`).
    pub resumption_data: Vec<u8>,
    /// First error returned by `process_new_packets`; returned again on later calls.
    pub error: Option<TLSError>,
    /// Current handshake state.
    ///
    /// `None` while a handler runs, and after a handler returns an error (see
    /// `process_main_protocol`).
    pub state: Option<Box<dyn hs::State + Send + Sync>>,
    /// Client certificate chain, if one was received.
    pub client_cert_chain: Option<Vec<key::Certificate>>,
    /// Set by `ServerSession::reject_early_data` during the handshake.
    pub reject_early_data: bool,
}
impl fmt::Debug for ServerSessionImpl {
    /// Prints only the type name; no session internals are exposed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Identical to `debug_struct("ServerSessionImpl").finish()` with no fields.
        f.write_str("ServerSessionImpl")
    }
}
impl ServerSessionImpl {
    /// Creates the per-connection state for `server_config`.
    ///
    /// The session starts in `hs::ExpectClientHello`, which is given `extra_exts`.
    /// `SessionCommon` is created from `server_config.mtu` and `false`.
    /// NOTE(review): `false` presumably marks the server side — confirm against
    /// `SessionCommon::new`.
    pub fn new(
        server_config: &Arc<ServerConfig>,
        extra_exts: Vec<ServerExtension>,
    ) -> ServerSessionImpl {
        ServerSessionImpl {
            config: server_config.clone(),
            common: SessionCommon::new(server_config.mtu, false),
            sni: None,
            alpn_protocol: None,
            quic_params: None,
            received_resumption_data: None,
            resumption_data: Vec::new(),
            error: None,
            state: Some(Box::new(hs::ExpectClientHello::new(
                server_config,
                extra_exts,
            ))),
            client_cert_chain: None,
            reject_early_data: false,
        }
    }

    /// Returns `true` unless decrypted plaintext is waiting to be read.
    ///
    /// Presumably this lets unread plaintext apply back-pressure on incoming TLS
    /// data.
    pub fn wants_read(&self) -> bool {
        !self.common.has_readable_plaintext()
    }

    /// Returns `true` when `common.sendable_tls` holds data to be written out.
    pub fn wants_write(&self) -> bool {
        !self.common.sendable_tls.is_empty()
    }

    /// Returns `true` until `common.traffic` has been set.
    pub fn is_handshaking(&self) -> bool {
        !self.common.traffic
    }

    /// Forwards to `SessionCommon::set_buffer_limit`.
    pub fn set_buffer_limit(&mut self, len: usize) {
        self.common.set_buffer_limit(len)
    }

    /// Processes one deframed record.
    ///
    /// Steps, in order:
    /// 1. Drop TLS 1.3 middlebox-compatibility CCS records.
    /// 2. Decrypt, if the record layer is decrypting.
    /// 3. Route handshake records through the handshake joiner.
    /// 4. Decode the payload.
    /// 5. Send alerts to `SessionCommon::process_alert`, and everything else to the
    ///    current handshake state.
    ///
    /// # Errors
    /// Returns `TLSError::CorruptMessagePayload` after queueing a fatal
    /// `decode_error` alert when:
    /// - a handshake record cannot be joined, or
    /// - a payload fails to decode.
    ///
    /// Errors from CCS filtering, decryption, alert processing and the state
    /// machine are propagated.
    pub fn process_msg(&mut self, mut msg: Message) -> Result<(), TLSError> {
        // TLS 1.3: CCS records are dropped rather than processed.
        if let MiddleboxCCS::Drop = self.common.filter_tls13_ccs(&msg)? {
            trace!("Dropping CCS");
            return Ok(());
        }

        if self.common.record_layer.is_decrypting() {
            msg = self.common.decrypt_incoming(msg)?;
        }

        // Handshake records are buffered by the joiner. Complete messages are
        // queued in `handshake_joiner.frames` and processed from there.
        if self
            .common
            .handshake_joiner
            .want_message(&msg)
        {
            self.common
                .handshake_joiner
                .take_message(msg)
                .ok_or_else(|| {
                    self.common
                        .send_fatal_alert(AlertDescription::DecodeError);
                    TLSError::CorruptMessagePayload(ContentType::Handshake)
                })?;
            return self.process_new_handshake_messages();
        }

        // A payload that fails to decode must not reach alert handling or the
        // state machine. Treat it like a malformed handshake record.
        if !msg.decode_payload() {
            self.common
                .send_fatal_alert(AlertDescription::DecodeError);
            return Err(TLSError::CorruptMessagePayload(msg.typ));
        }

        if msg.is_content_type(ContentType::Alert) {
            return self.common.process_alert(msg);
        }

        self.process_main_protocol(msg)
    }

    /// Drains complete handshake messages from the joiner.
    ///
    /// Each message is passed to `process_main_protocol`, stopping at the first
    /// error.
    pub fn process_new_handshake_messages(&mut self) -> Result<(), TLSError> {
        while let Some(msg) = self
            .common
            .handshake_joiner
            .frames
            .pop_front()
        {
            self.process_main_protocol(msg)?;
        }
        Ok(())
    }

    /// Queues a fatal `unexpected_message` alert.
    fn queue_unexpected_alert(&mut self) {
        self.common
            .send_fatal_alert(AlertDescription::UnexpectedMessage);
    }

    /// Passes `rc` through unchanged.
    ///
    /// First queues an `unexpected_message` alert if `rc` is an
    /// inappropriate-message error.
    fn maybe_send_unexpected_alert(&mut self, rc: hs::NextStateOrError) -> hs::NextStateOrError {
        match rc {
            Err(TLSError::InappropriateMessage { .. })
            | Err(TLSError::InappropriateHandshakeMessage { .. }) => {
                self.queue_unexpected_alert();
            }
            _ => {}
        };
        rc
    }

    /// Hands a decoded message to the current handshake state and installs the
    /// state it returns.
    ///
    /// On a non-TLS 1.3 connection, a `ClientHello` received after the handshake
    /// gets a `no_renegotiation` warning alert and is otherwise ignored.
    ///
    /// # Errors
    /// Propagates the state's error. For inappropriate messages, an
    /// `unexpected_message` alert is queued first. After an error, `self.state`
    /// is left as `None`.
    ///
    /// # Panics
    /// Panics if `self.state` is `None` on entry.
    pub fn process_main_protocol(&mut self, msg: Message) -> Result<(), TLSError> {
        // Renegotiation is refused on pre-TLS 1.3 connections.
        if self.common.traffic
            && !self.common.is_tls13()
            && msg.is_handshake_type(HandshakeType::ClientHello)
        {
            self.common
                .send_warning_alert(AlertDescription::NoRenegotiation);
            return Ok(());
        }

        // The state is moved out so that it can borrow `self` mutably while handling.
        let state = self.state.take().unwrap();
        let maybe_next_state = state.handle(self, msg);
        let next_state = self.maybe_send_unexpected_alert(maybe_next_state)?;
        self.state = Some(next_state);
        Ok(())
    }

    /// Processes every record the deframer has queued.
    ///
    /// # Errors
    /// - Returns a previously stored error immediately.
    /// - Returns `TLSError::CorruptMessage` if the deframer is desynchronised.
    /// - Otherwise stores and returns the first error from `process_msg`.
    pub fn process_new_packets(&mut self) -> Result<(), TLSError> {
        if let Some(ref err) = self.error {
            return Err(err.clone());
        }

        if self.common.message_deframer.desynced {
            return Err(TLSError::CorruptMessage);
        }

        while let Some(msg) = self
            .common
            .message_deframer
            .frames
            .pop_front()
        {
            if let Err(err) = self.process_msg(msg) {
                // Remember the failure so later calls report it again.
                self.error = Some(err.clone());
                return Err(err);
            }
        }

        Ok(())
    }

    /// Returns a copy of the client's certificate chain, if one was received.
    pub fn get_peer_certificates(&self) -> Option<Vec<key::Certificate>> {
        self.client_cert_chain.clone()
    }

    /// Returns the negotiated ALPN protocol, if any.
    pub fn get_alpn_protocol(&self) -> Option<&[u8]> {
        self.alpn_protocol
            .as_ref()
            .map(AsRef::as_ref)
    }

    /// Returns `common.negotiated_version`.
    pub fn get_protocol_version(&self) -> Option<ProtocolVersion> {
        self.common.negotiated_version
    }

    /// Returns the suite reported by `common.get_suite()`.
    pub fn get_negotiated_ciphersuite(&self) -> Option<&'static SupportedCipherSuite> {
        self.common.get_suite()
    }

    /// Returns the stored SNI value, if any.
    pub fn get_sni(&self) -> Option<&webpki::DNSName> {
        self.sni.as_ref()
    }

    /// Records the SNI value.
    ///
    /// # Panics
    /// Panics if an SNI value was already set.
    pub fn set_sni(&mut self, value: webpki::DNSName) {
        assert!(self.sni.is_none());
        self.sni = Some(value)
    }

    /// Delegates keying-material export to the current handshake state.
    ///
    /// # Errors
    /// - Returns `TLSError::HandshakeNotComplete` when there is no current state.
    /// - Otherwise propagates the state's result.
    fn export_keying_material(
        &self,
        output: &mut [u8],
        label: &[u8],
        context: Option<&[u8]>,
    ) -> Result<(), TLSError> {
        self.state
            .as_ref()
            .ok_or(TLSError::HandshakeNotComplete)
            .and_then(|st| st.export_keying_material(output, label, context))
    }

    /// Lets the current state emit a key update if it needs to.
    ///
    /// Then queues as much of `buf` as `SessionCommon::send_some_plaintext`
    /// accepts, and returns that count.
    fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
        // Temporarily move the state out so it can borrow `self` mutably.
        let mut st = self.state.take();
        if let Some(state) = st.as_mut() {
            state.perhaps_write_key_update(self);
        }
        self.state = st;
        self.common.send_some_plaintext(buf)
    }
}
/// Public handle for one server-side TLS connection.
///
/// All work is delegated to the wrapped `ServerSessionImpl`.
#[derive(Debug)]
pub struct ServerSession {
    pub(crate) imp: ServerSessionImpl,
}
impl ServerSession {
    /// Starts a new server session that uses `config` and no extra extensions.
    pub fn new(config: &Arc<ServerConfig>) -> ServerSession {
        let imp = ServerSessionImpl::new(config, Vec::new());
        ServerSession { imp }
    }

    /// Returns the SNI host name as a string, if one was recorded.
    pub fn get_sni_hostname(&self) -> Option<&str> {
        let name = self.imp.get_sni()?;
        Some(name.as_ref().into())
    }

    /// Returns the resumption data received from the client, if any.
    pub fn received_resumption_data(&self) -> Option<&[u8]> {
        self.imp
            .received_resumption_data
            .as_ref()
            .map(Vec::as_slice)
    }

    /// Stores a copy of `data` as this session's resumption data.
    ///
    /// # Panics
    /// Panics unless `data` is shorter than 2^15 bytes.
    pub fn set_resumption_data(&mut self, data: &[u8]) {
        assert!(data.len() < 2usize.pow(15));
        self.imp.resumption_data = data.to_vec();
    }

    /// Flags early data for rejection.
    ///
    /// # Panics
    /// Panics once the handshake has finished.
    pub fn reject_early_data(&mut self) {
        let still_handshaking = self.is_handshaking();
        assert!(still_handshaking, "cannot retroactively reject early data");
        self.imp.reject_early_data = true;
    }
}
impl Session for ServerSession {
fn read_tls(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
self.imp.common.read_tls(rd)
}
fn write_tls(&mut self, wr: &mut dyn io::Write) -> io::Result<usize> {
self.imp.common.write_tls(wr)
}
fn process_new_packets(&mut self) -> Result<(), TLSError> {
self.imp.process_new_packets()
}
fn wants_read(&self) -> bool {
self.imp.wants_read()
}
fn wants_write(&self) -> bool {
self.imp.wants_write()
}
fn is_handshaking(&self) -> bool {
self.imp.is_handshaking()
}
fn set_buffer_limit(&mut self, len: usize) {
self.imp.set_buffer_limit(len)
}
fn send_close_notify(&mut self) {
self.imp.common.send_close_notify()
}
fn get_peer_certificates(&self) -> Option<Vec<key::Certificate>> {
self.imp.get_peer_certificates()
}
fn get_alpn_protocol(&self) -> Option<&[u8]> {
self.imp.get_alpn_protocol()
}
fn get_protocol_version(&self) -> Option<ProtocolVersion> {
self.imp.get_protocol_version()
}
fn export_keying_material(
&self,
output: &mut [u8],
label: &[u8],
context: Option<&[u8]>,
) -> Result<(), TLSError> {
self.imp
.export_keying_material(output, label, context)
}
fn get_negotiated_ciphersuite(&self) -> Option<&'static SupportedCipherSuite> {
self.imp.get_negotiated_ciphersuite()
}
}
impl io::Read for ServerSession {
    /// Reads decrypted application data via `SessionCommon::read`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let common = &mut self.imp.common;
        common.read(buf)
    }
}
impl io::Write for ServerSession {
    /// Queues application data from `buf`.
    ///
    /// Returns the byte count reported by `ServerSessionImpl::send_some_plaintext`.
    /// NOTE(review): presumably this can be less than `buf.len()` once the buffer
    /// limit is reached — see `set_buffer_limit`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Ok(self.imp.send_some_plaintext(buf))
    }

    /// Queues each slice in turn.
    ///
    /// Stops after the first slice that is not accepted in full. `io::Write`
    /// requires the returned count to describe a prefix of the concatenated
    /// input. Continuing past a short write could queue bytes from a later slice
    /// after earlier bytes were dropped.
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        let mut total = 0;
        for buf in bufs {
            let accepted = self.imp.send_some_plaintext(buf);
            total += accepted;
            if accepted < buf.len() {
                break;
            }
        }
        Ok(total)
    }

    /// Calls `SessionCommon::flush_plaintext`; never fails.
    fn flush(&mut self) -> io::Result<()> {
        self.imp.common.flush_plaintext();
        Ok(())
    }
}