use std::{collections::HashMap, fmt, io::Cursor, marker::PhantomData, sync::Arc};
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName},
server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier},
sign::CertifiedKey,
ClientConfig, RootCertStore, ServerConfig,
};
use tokio_rustls::{client, server, TlsAcceptor, TlsConnector};
use crate::{Error, Result};
/// Stateless entry point for constructing TLS contexts.
///
/// Use [`TlsContextBuilder::for_server`] or [`TlsContextBuilder::for_client`]
/// to obtain a configurable builder.
#[derive(Debug, Clone, Copy)]
pub struct TlsContextBuilder;
/// Type-state marker: the client builder has not been given any trust source
/// yet, so it cannot be built.
#[derive(Debug, Clone, Copy)]
pub struct NoTrust;
/// Type-state marker: the client builder has been given a trust source
/// (root certificates or a custom verifier) and may be built.
#[derive(Debug, Clone, Copy)]
pub struct HasTrust;
/// Snapshot of the TLS parameters observed on one established connection.
#[derive(Clone, Debug)]
pub struct TlsInfo {
/// Certificates presented by the peer; empty when the connection reports none.
peer_certificates: Vec<CertificateDer<'static>>,
/// ALPN protocol reported by the connection, if any.
selected_alpn_protocol: Option<Vec<u8>>,
/// Server name: the name reported by the server connection, or the name the
/// client passed in when the snapshot was taken.
server_name: Option<String>,
}
impl TlsInfo {
    /// Certificates presented by the peer (empty when there were none).
    pub fn peer_certificates(&self) -> &[CertificateDer<'static>] {
        self.peer_certificates.as_slice()
    }

    /// ALPN protocol reported by the connection, if any.
    pub fn selected_alpn_protocol(&self) -> Option<&[u8]> {
        match &self.selected_alpn_protocol {
            Some(protocol) => Some(protocol.as_slice()),
            None => None,
        }
    }

    /// Server name associated with the connection, if any.
    pub fn server_name(&self) -> Option<&str> {
        match &self.server_name {
            Some(name) => Some(name.as_str()),
            None => None,
        }
    }

    /// Captures the peer certificates, ALPN protocol and server name
    /// reported by a server-side connection.
    fn from_server_connection(connection: &rustls::ServerConnection) -> Self {
        let peer_certificates = match connection.peer_certificates() {
            Some(chain) => chain.to_vec(),
            None => Vec::new(),
        };
        let selected_alpn_protocol = connection.alpn_protocol().map(|protocol| protocol.to_vec());
        let server_name = connection.server_name().map(|name| name.to_owned());
        Self {
            peer_certificates,
            selected_alpn_protocol,
            server_name,
        }
    }

    /// Captures the peer certificates and ALPN protocol reported by a
    /// client-side connection; `server_name` is stored as given by the caller.
    fn from_client_connection(connection: &rustls::ClientConnection, server_name: String) -> Self {
        let peer_certificates = match connection.peer_certificates() {
            Some(chain) => chain.to_vec(),
            None => Vec::new(),
        };
        let selected_alpn_protocol = connection.alpn_protocol().map(|protocol| protocol.to_vec());
        Self {
            peer_certificates,
            selected_alpn_protocol,
            server_name: Some(server_name),
        }
    }
}
/// Server-side TLS configuration. The rustls config is held in an `Arc`, so
/// cloning the context shares it.
#[derive(Clone)]
pub struct ServerTlsContext {
/// Shared rustls server configuration handed to each acceptor.
config: Arc<ServerConfig>,
}
impl ServerTlsContext {
    /// Creates a `TlsAcceptor` that shares this context's configuration.
    pub(crate) fn acceptor(&self) -> TlsAcceptor {
        let shared = Arc::clone(&self.config);
        TlsAcceptor::from(shared)
    }

    /// Builds a [`TlsInfo`] snapshot from the connection half of an accepted stream.
    pub(crate) fn info_for_stream<S>(&self, stream: &server::TlsStream<S>) -> Arc<TlsInfo> {
        let (_, connection) = stream.get_ref();
        let info = TlsInfo::from_server_connection(connection);
        Arc::new(info)
    }
}
/// Client-side TLS configuration. The rustls config is held in an `Arc`, so
/// cloning the context shares it.
#[derive(Clone)]
pub struct ClientTlsContext {
/// Shared rustls client configuration handed to each connector.
config: Arc<ClientConfig>,
/// Optional override used instead of the connection host when resolving the
/// TLS server name.
server_name: Option<String>,
}
impl ClientTlsContext {
    /// Creates a `TlsConnector` that shares this context's configuration.
    pub(crate) fn connector(&self) -> TlsConnector {
        let shared = Arc::clone(&self.config);
        TlsConnector::from(shared)
    }

    /// Resolves the TLS server name for a connection to `host`.
    ///
    /// The configured override takes precedence over `host`.
    ///
    /// # Errors
    ///
    /// Returns a server-name validation error when the chosen name is not
    /// accepted by `ServerName::try_from`.
    pub(crate) fn server_name_for(&self, host: &str) -> Result<ResolvedServerName> {
        let name = match &self.server_name {
            Some(configured) => configured.as_str(),
            None => host,
        };
        match ServerName::try_from(name.to_string()) {
            Ok(server_name) => Ok(ResolvedServerName {
                display: name.to_string(),
                server_name,
            }),
            Err(err) => Err(tls_invalid_server_name(format!(
                "invalid TLS server name `{name}`: {err}"
            ))),
        }
    }

    /// Builds a [`TlsInfo`] snapshot from the connection half of a connected
    /// stream, recording `server_name` as the name used.
    pub(crate) fn info_for_stream<S>(
        &self,
        stream: &client::TlsStream<S>,
        server_name: String,
    ) -> Arc<TlsInfo> {
        let (_, connection) = stream.get_ref();
        let info = TlsInfo::from_client_connection(connection, server_name);
        Arc::new(info)
    }
}
/// A server name that has passed `ServerName` validation, together with the
/// text it was parsed from.
pub(crate) struct ResolvedServerName {
/// The name as text, exactly as it was passed to the parser.
pub(crate) display: String,
/// The parsed, owned rustls server name.
pub(crate) server_name: ServerName<'static>,
}
/// Accumulates server TLS settings. Problems found by the setters are recorded
/// in `errors` and reported when the builder is built.
pub struct ServerTlsContextBuilder {
/// Default certificate chain (used as the fallback when SNI identities exist).
certificates: Vec<CertificateDer<'static>>,
/// Private key matching the default certificate chain.
private_key: Option<PrivateKeyDer<'static>>,
/// Trust anchors used to verify client certificates.
client_auth_roots: RootCertStore,
/// Whether client certificates are not requested, required, or optional.
client_auth_mode: ClientAuthMode,
/// Per-hostname identities selected via SNI.
sni_identities: Vec<SniIdentity>,
/// ALPN protocols copied into the final configuration.
alpn_protocols: Vec<Vec<u8>>,
/// Deferred configuration errors; only the first one is reported at build time.
errors: Vec<String>,
}
/// Accumulates client TLS settings. `TrustState` is a compile-time marker
/// (`NoTrust` / `HasTrust`); `build` is only defined for `HasTrust`.
pub struct ClientTlsContextBuilder<TrustState> {
/// Trust anchors used to verify the server certificate.
roots: RootCertStore,
/// Custom server certificate verifier; when set, `build` uses it instead of `roots`.
verifier: Option<Arc<dyn rustls::client::danger::ServerCertVerifier>>,
/// Optional certificate chain and key presented for client authentication.
client_identity: Option<ClientIdentity>,
/// ALPN protocols copied into the final configuration.
alpn_protocols: Vec<Vec<u8>>,
/// Optional override for the TLS server name.
server_name: Option<String>,
/// Deferred configuration errors; only the first one is reported at build time.
errors: Vec<String>,
/// Zero-sized marker carrying the trust type-state.
_state: PhantomData<TrustState>,
}
/// Certificate chain and private key a client presents for client authentication.
struct ClientIdentity {
/// Client certificate chain; `build` rejects an empty chain.
certificates: Vec<CertificateDer<'static>>,
/// Private key for the client certificate chain.
private_key: PrivateKeyDer<'static>,
}
/// Server identity served when the client's SNI matches `name`.
struct SniIdentity {
/// Hostname as supplied by the caller; validated and lowercased when the
/// resolver is built.
name: String,
/// Certificate chain for this hostname; must be non-empty when the resolver is built.
certificates: Vec<CertificateDer<'static>>,
/// Private key for this hostname's certificate chain.
private_key: PrivateKeyDer<'static>,
}
/// How the server treats client certificates.
#[derive(Clone, Copy)]
enum ClientAuthMode {
/// No client certificate verifier is installed.
None,
/// A client certificate verifier is installed and unauthenticated clients are not allowed.
Required,
/// A client certificate verifier is installed with `allow_unauthenticated()`.
Optional,
}
impl TlsContextBuilder {
    /// Starts a server builder with no certificates, no key, no SNI
    /// identities, no ALPN protocols, and client authentication disabled.
    pub fn for_server() -> ServerTlsContextBuilder {
        let client_auth_roots = RootCertStore::empty();
        ServerTlsContextBuilder {
            errors: vec![],
            alpn_protocols: vec![],
            sni_identities: vec![],
            client_auth_mode: ClientAuthMode::None,
            client_auth_roots,
            private_key: None,
            certificates: vec![],
        }
    }

    /// Starts a client builder with an empty root store and no verifier,
    /// identity, ALPN protocols, or server-name override.
    pub fn for_client() -> ClientTlsContextBuilder<NoTrust> {
        let roots = RootCertStore::empty();
        ClientTlsContextBuilder {
            _state: PhantomData,
            errors: vec![],
            server_name: None,
            alpn_protocols: vec![],
            client_identity: None,
            verifier: None,
            roots,
        }
    }

    /// Shorthand for `for_client().native_roots()`.
    #[cfg(feature = "tls-native-roots")]
    pub fn for_client_with_native_roots() -> ClientTlsContextBuilder<HasTrust> {
        let builder = Self::for_client();
        builder.native_roots()
    }

    /// Shorthand for `for_client().webpki_roots()`.
    #[cfg(feature = "tls-webpki-roots")]
    pub fn for_client_with_webpki_roots() -> ClientTlsContextBuilder<HasTrust> {
        let builder = Self::for_client();
        builder.webpki_roots()
    }
}
impl ServerTlsContextBuilder {
pub fn certificate_chain_pem(mut self, pem: impl AsRef<[u8]>) -> Self {
match parse_certificates_pem(pem.as_ref()) {
Ok(mut certificates) => self.certificates.append(&mut certificates),
Err(err) => self.errors.push(err),
}
self
}
pub fn certificate_der(mut self, der: impl Into<Vec<u8>>) -> Self {
self.certificates.push(CertificateDer::from(der.into()));
self
}
pub fn private_key_pem(mut self, pem: impl AsRef<[u8]>) -> Self {
match parse_private_key_pem(pem.as_ref()) {
Ok(private_key) => self.private_key = Some(private_key),
Err(err) => self.errors.push(err),
}
self
}
pub fn private_key_der(mut self, der: impl Into<Vec<u8>>) -> Self {
self.private_key = Some(PrivatePkcs8KeyDer::from(der.into()).into());
self
}
pub fn client_auth_required_pem(mut self, pem: impl AsRef<[u8]>) -> Self {
self.client_auth_mode = ClientAuthMode::Required;
match parse_certificates_pem(pem.as_ref()) {
Ok(certificates) => {
add_roots(&mut self.client_auth_roots, certificates, &mut self.errors)
}
Err(err) => self.errors.push(err),
}
self
}
pub fn client_auth_required_der(mut self, der: impl Into<Vec<u8>>) -> Self {
self.client_auth_mode = ClientAuthMode::Required;
add_roots(
&mut self.client_auth_roots,
vec![CertificateDer::from(der.into())],
&mut self.errors,
);
self
}
pub fn client_auth_optional_pem(mut self, pem: impl AsRef<[u8]>) -> Self {
self.client_auth_mode = ClientAuthMode::Optional;
match parse_certificates_pem(pem.as_ref()) {
Ok(certificates) => {
add_roots(&mut self.client_auth_roots, certificates, &mut self.errors)
}
Err(err) => self.errors.push(err),
}
self
}
pub fn client_auth_optional_der(mut self, der: impl Into<Vec<u8>>) -> Self {
self.client_auth_mode = ClientAuthMode::Optional;
add_roots(
&mut self.client_auth_roots,
vec![CertificateDer::from(der.into())],
&mut self.errors,
);
self
}
pub fn alpn_protocols<I, P>(mut self, protocols: I) -> Self
where
I: IntoIterator<Item = P>,
P: AsRef<[u8]>,
{
self.alpn_protocols = collect_alpn_protocols(protocols, &mut self.errors);
self
}
pub fn sni_certificate_pem(
mut self,
name: impl Into<String>,
certificate_chain: impl AsRef<[u8]>,
private_key: impl AsRef<[u8]>,
) -> Self {
let certificates = match parse_certificates_pem(certificate_chain.as_ref()) {
Ok(certificates) => certificates,
Err(err) => {
self.errors.push(err);
Vec::new()
}
};
let private_key = match parse_private_key_pem(private_key.as_ref()) {
Ok(private_key) => Some(private_key),
Err(err) => {
self.errors.push(err);
None
}
};
if let Some(private_key) = private_key {
self.sni_identities.push(SniIdentity {
name: name.into(),
certificates,
private_key,
});
}
self
}
pub fn sni_certificate_der(
mut self,
name: impl Into<String>,
certificate_chain: impl IntoIterator<Item = impl Into<Vec<u8>>>,
private_key: impl Into<Vec<u8>>,
) -> Self {
self.sni_identities.push(SniIdentity {
name: name.into(),
certificates: certificate_chain
.into_iter()
.map(|certificate| CertificateDer::from(certificate.into()))
.collect(),
private_key: PrivatePkcs8KeyDer::from(private_key.into()).into(),
});
self
}
pub fn build(self) -> Result<ServerTlsContext> {
if let Some(error) = first_error(self.errors) {
return Err(tls_config_error(error));
}
if self.certificates.is_empty() && self.sni_identities.is_empty() {
return Err(tls_config_error(
"server TLS context requires a certificate chain".to_string(),
));
}
if !self.certificates.is_empty() && self.private_key.is_none() {
return Err(tls_config_error(
"server TLS context requires a private key".to_string(),
));
}
if self.certificates.is_empty() && self.private_key.is_some() {
return Err(tls_config_error(
"server TLS context private key requires a certificate chain".to_string(),
));
}
let builder = ServerConfig::builder();
let builder = match self.client_auth_mode {
ClientAuthMode::None => builder.with_no_client_auth(),
ClientAuthMode::Required | ClientAuthMode::Optional => {
let mut verifier = WebPkiClientVerifier::builder(Arc::new(self.client_auth_roots));
if matches!(self.client_auth_mode, ClientAuthMode::Optional) {
verifier = verifier.allow_unauthenticated();
}
let verifier = verifier.build().map_err(|err| {
tls_config_error(format!("invalid client authentication roots: {err}"))
})?;
builder.with_client_cert_verifier(verifier)
}
};
let mut config = if self.sni_identities.is_empty() {
let private_key = self.private_key.ok_or_else(|| {
tls_config_error("server TLS context requires a private key".to_string())
})?;
builder
.with_single_cert(self.certificates, private_key)
.map_err(|err| tls_config_error(format!("invalid server TLS identity: {err}")))?
} else {
let resolver = build_sni_resolver(
builder.crypto_provider(),
self.certificates,
self.private_key,
self.sni_identities,
)?;
builder.with_cert_resolver(Arc::new(resolver))
};
config.alpn_protocols = self.alpn_protocols;
Ok(ServerTlsContext {
config: Arc::new(config),
})
}
}
impl<State> ClientTlsContextBuilder<State> {
    /// Overrides the TLS server name used instead of the connection host.
    pub fn server_name(mut self, server_name: impl Into<String>) -> Self {
        self.server_name = Some(server_name.into());
        self
    }

    /// Replaces the ALPN protocol list.
    ///
    /// Empty names and names longer than 255 bytes are dropped and recorded
    /// as errors reported at build time.
    pub fn alpn_protocols<I, P>(mut self, protocols: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        self.alpn_protocols = collect_alpn_protocols(protocols, &mut self.errors);
        self
    }

    /// Adds the PEM-encoded root certificates in `pem` to the trust store.
    ///
    /// Parse failures and rejected certificates are recorded as errors
    /// reported at build time.
    pub fn root_certificate_pem(
        mut self,
        pem: impl AsRef<[u8]>,
    ) -> ClientTlsContextBuilder<HasTrust> {
        match parse_certificates_pem(pem.as_ref()) {
            Ok(certificates) => add_roots(&mut self.roots, certificates, &mut self.errors),
            Err(err) => self.errors.push(err),
        }
        self.with_state()
    }

    /// Adds one DER-encoded root certificate to the trust store.
    ///
    /// A rejected certificate is recorded as an error reported at build time.
    pub fn root_certificate_der(
        mut self,
        der: impl Into<Vec<u8>>,
    ) -> ClientTlsContextBuilder<HasTrust> {
        add_roots(
            &mut self.roots,
            vec![CertificateDer::from(der.into())],
            &mut self.errors,
        );
        self.with_state()
    }

    /// Adds the platform's native root certificates to the trust store.
    ///
    /// Native certificates that cannot be parsed are skipped rather than
    /// treated as configuration errors, so one unusable platform entry does
    /// not fail the build. An error is recorded only when loading failed and
    /// produced no certificates, or when no certificate could be added.
    #[cfg(feature = "tls-native-roots")]
    pub fn native_roots(mut self) -> ClientTlsContextBuilder<HasTrust> {
        let certificates = rustls_native_certs::load_native_certs();
        if certificates.certs.is_empty() && !certificates.errors.is_empty() {
            self.errors.push(format!(
                "failed to load native root certificates: {:?}",
                certificates.errors
            ));
        }
        // Deliberately bypasses `add_roots`: that helper records an error for
        // every ignored certificate, which is too strict for a platform store.
        let (added, _ignored) = self.roots.add_parsable_certificates(certificates.certs);
        if added == 0 {
            self.errors
                .push("no valid root certificates were added".to_string());
        }
        self.with_state()
    }

    /// Adds the roots from `webpki_roots::TLS_SERVER_ROOTS` to the trust store.
    #[cfg(feature = "tls-webpki-roots")]
    pub fn webpki_roots(mut self) -> ClientTlsContextBuilder<HasTrust> {
        self.roots
            .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        self.with_state()
    }

    /// Installs a custom server certificate verifier.
    ///
    /// When a verifier is set, the built context uses it instead of the root store.
    #[cfg(feature = "tls-dangerous")]
    pub fn custom_verifier(
        mut self,
        verifier: Arc<dyn rustls::client::danger::ServerCertVerifier>,
    ) -> ClientTlsContextBuilder<HasTrust> {
        self.verifier = Some(verifier);
        self.with_state()
    }

    /// Installs a verifier that accepts every server certificate and
    /// handshake signature without checking them.
    #[cfg(feature = "tls-dangerous")]
    pub fn danger_accept_invalid_certs(self) -> ClientTlsContextBuilder<HasTrust> {
        self.custom_verifier(Arc::new(DangerAcceptInvalidCerts))
    }

    /// Moves every field into a builder carrying a different trust-state marker.
    fn with_state<NextState>(self) -> ClientTlsContextBuilder<NextState> {
        ClientTlsContextBuilder {
            roots: self.roots,
            verifier: self.verifier,
            client_identity: self.client_identity,
            alpn_protocols: self.alpn_protocols,
            server_name: self.server_name,
            errors: self.errors,
            _state: PhantomData,
        }
    }
}
impl ClientTlsContextBuilder<HasTrust> {
pub fn client_identity_pem(
mut self,
certificate_chain: impl AsRef<[u8]>,
private_key: impl AsRef<[u8]>,
) -> Self {
let certificates = match parse_certificates_pem(certificate_chain.as_ref()) {
Ok(certificates) => certificates,
Err(err) => {
self.errors.push(err);
Vec::new()
}
};
let private_key = match parse_private_key_pem(private_key.as_ref()) {
Ok(private_key) => Some(private_key),
Err(err) => {
self.errors.push(err);
None
}
};
if let Some(private_key) = private_key {
self.client_identity = Some(ClientIdentity {
certificates,
private_key,
});
}
self
}
pub fn client_identity_der(
mut self,
certificate_chain: impl IntoIterator<Item = impl Into<Vec<u8>>>,
private_key: impl Into<Vec<u8>>,
) -> Self {
self.client_identity = Some(ClientIdentity {
certificates: certificate_chain
.into_iter()
.map(|certificate| CertificateDer::from(certificate.into()))
.collect(),
private_key: PrivatePkcs8KeyDer::from(private_key.into()).into(),
});
self
}
pub fn build(self) -> Result<ClientTlsContext> {
if let Some(error) = first_error(self.errors) {
return Err(tls_config_error(error));
}
let builder = if let Some(verifier) = self.verifier {
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)
} else {
if self.roots.is_empty() {
return Err(tls_config_error(
"client TLS context requires at least one root certificate".to_string(),
));
}
ClientConfig::builder().with_root_certificates(self.roots)
};
let mut config = if let Some(identity) = self.client_identity {
if identity.certificates.is_empty() {
return Err(tls_config_error(
"client TLS identity requires a certificate chain".to_string(),
));
}
builder
.with_client_auth_cert(identity.certificates, identity.private_key)
.map_err(|err| tls_config_error(format!("invalid client TLS identity: {err}")))?
} else {
builder.with_no_client_auth()
};
config.alpn_protocols = self.alpn_protocols;
Ok(ClientTlsContext {
config: Arc::new(config),
server_name: self.server_name,
})
}
}
/// Certificate resolver that selects an identity by the client's SNI hostname.
#[derive(Debug)]
struct SniCertResolver {
/// Identities keyed by normalized (lowercased) DNS name.
by_name: HashMap<String, Arc<CertifiedKey>>,
/// Identity used when the SNI is absent or has no entry in `by_name`.
fallback: Option<Arc<CertifiedKey>>,
}
impl ResolvesServerCert for SniCertResolver {
    /// Returns the identity registered for the client's SNI name, or the
    /// fallback identity when the name is absent or not registered.
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let by_sni = match client_hello.server_name() {
            Some(name) => self.by_name.get(name),
            None => None,
        };
        match by_sni {
            Some(certified_key) => Some(Arc::clone(certified_key)),
            None => self.fallback.clone(),
        }
    }
}
fn build_sni_resolver(
provider: &Arc<rustls::crypto::CryptoProvider>,
default_certificates: Vec<CertificateDer<'static>>,
default_private_key: Option<PrivateKeyDer<'static>>,
sni_identities: Vec<SniIdentity>,
) -> Result<SniCertResolver> {
let fallback = if default_certificates.is_empty() {
None
} else {
let private_key = default_private_key.ok_or_else(|| {
tls_config_error("server TLS context requires a private key".to_string())
})?;
Some(Arc::new(
CertifiedKey::from_der(default_certificates, private_key, provider)
.map_err(|err| tls_config_error(format!("invalid fallback TLS identity: {err}")))?,
))
};
let mut by_name = HashMap::new();
for identity in sni_identities {
let name = normalize_sni_name(&identity.name)?;
if identity.certificates.is_empty() {
return Err(tls_config_error(format!(
"SNI identity `{}` requires a certificate chain",
identity.name
)));
}
let certified_key =
CertifiedKey::from_der(identity.certificates, identity.private_key, provider).map_err(
|err| {
tls_config_error(format!(
"invalid SNI TLS identity `{}`: {err}",
identity.name
))
},
)?;
by_name.insert(name, Arc::new(certified_key));
}
Ok(SniCertResolver { by_name, fallback })
}
/// Validates `name` as a DNS server name and returns it in ASCII lowercase.
///
/// # Errors
///
/// Returns a TLS configuration error when `name` does not parse as a server
/// name, parses as an IP address, or parses as any other non-DNS kind.
fn normalize_sni_name(name: &str) -> Result<String> {
    let parsed = match ServerName::try_from(name.to_string()) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(tls_config_error(format!(
                "invalid SNI server name `{name}`: {err}"
            )))
        }
    };
    let message = match parsed {
        ServerName::DnsName(dns_name) => return Ok(dns_name.as_ref().to_ascii_lowercase()),
        ServerName::IpAddress(_) => format!(
            "invalid SNI server name `{name}`: IP addresses are not valid SNI names"
        ),
        // `ServerName` may gain further variants; reject anything that is not a DNS name.
        _ => format!("invalid SNI server name `{name}`: unsupported server name type"),
    };
    Err(tls_config_error(message))
}
/// Copies the valid ALPN protocol names out of `protocols`, in order.
///
/// Empty names and names longer than 255 bytes are left out, and one message
/// per rejected name is appended to `errors`.
fn collect_alpn_protocols<I, P>(protocols: I, errors: &mut Vec<String>) -> Vec<Vec<u8>>
where
    I: IntoIterator<Item = P>,
    P: AsRef<[u8]>,
{
    protocols
        .into_iter()
        .filter_map(|candidate| {
            let bytes = candidate.as_ref();
            match bytes.len() {
                0 => {
                    errors.push("ALPN protocol name must not be empty".to_string());
                    None
                }
                length if length > 255 => {
                    errors.push(format!(
                        "ALPN protocol name must be at most 255 bytes, got {}",
                        length
                    ));
                    None
                }
                _ => Some(bytes.to_vec()),
            }
        })
        .collect()
}
pub(crate) fn tls_handshake_error(action: &str, err: impl fmt::Display) -> Error {
Error::Tls(format!("TLS handshake failed during {action}: {err}"))
}
fn tls_config_error(message: impl Into<String>) -> Error {
Error::Tls(format!("TLS configuration failed: {}", message.into()))
}
fn tls_invalid_server_name(message: impl Into<String>) -> Error {
Error::Tls(format!(
"TLS server name validation failed: {}",
message.into()
))
}
/// Parses every certificate in `pem`, stopping at the first malformed entry.
///
/// Returns the certificates in input order, or a message describing the
/// parse failure.
fn parse_certificates_pem(pem: &[u8]) -> std::result::Result<Vec<CertificateDer<'static>>, String> {
    let mut reader = Cursor::new(pem);
    let mut chain = Vec::new();
    for entry in rustls_pemfile::certs(&mut reader) {
        match entry {
            Ok(certificate) => chain.push(certificate),
            Err(err) => return Err(format!("failed to parse PEM certificate chain: {err}")),
        }
    }
    Ok(chain)
}
/// Parses the private key that `rustls_pemfile::private_key` finds in `pem`.
///
/// Returns a message when parsing fails or when no private key is present.
fn parse_private_key_pem(pem: &[u8]) -> std::result::Result<PrivateKeyDer<'static>, String> {
    let mut reader = Cursor::new(pem);
    match rustls_pemfile::private_key(&mut reader) {
        Ok(Some(private_key)) => Ok(private_key),
        Ok(None) => Err("PEM private key was not found".to_string()),
        Err(err) => Err(format!("failed to parse PEM private key: {err}")),
    }
}
/// Adds `certificates` to `roots`, recording problems in `errors`.
///
/// One message is recorded when nothing was added, and another when any
/// certificate was ignored; both may be recorded for the same call.
fn add_roots(
    roots: &mut RootCertStore,
    certificates: Vec<CertificateDer<'static>>,
    errors: &mut Vec<String>,
) {
    let (added, ignored) = roots.add_parsable_certificates(certificates);
    let nothing_added =
        (added == 0).then(|| "no valid root certificates were added".to_string());
    let some_ignored =
        (ignored > 0).then(|| format!("{ignored} root certificate(s) could not be parsed"));
    // Order matters: callers report only the first recorded error.
    errors.extend(nothing_added);
    errors.extend(some_ignored);
}
/// Returns the earliest recorded error, or `None` when the list is empty.
fn first_error(mut errors: Vec<String>) -> Option<String> {
    // Keep only the first entry, then take it.
    errors.truncate(1);
    errors.pop()
}
/// Server certificate verifier that performs no verification at all: every
/// certificate and handshake signature is accepted. Only compiled with the
/// `tls-dangerous` feature.
#[cfg(feature = "tls-dangerous")]
#[derive(Debug)]
struct DangerAcceptInvalidCerts;
#[cfg(feature = "tls-dangerous")]
impl rustls::client::danger::ServerCertVerifier for DangerAcceptInvalidCerts {
    /// Accepts any server certificate without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the server.
    ///
    /// Signatures are never checked, so the list is kept broad. It includes
    /// `ECDSA_NISTP521_SHA512` so that servers able to sign only with a P-521
    /// key can still negotiate a scheme.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::ED448,
        ]
    }
}