use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
/// Client-authentication material: a certificate chain together with its private key.
///
/// `Debug` is implemented manually for this type so that the private key is never printed.
#[derive(Clone)]
pub struct ClientAuth {
    /// DER-encoded certificate chain.
    /// NOTE(review): presumably ordered leaf-first as rustls expects — confirm at call sites.
    pub certificates: Vec<CertificateDer<'static>>,
    /// Private key matching the chain; held in an `Arc` so clones share one copy.
    pub key: Arc<PrivateKeyDer<'static>>,
}
impl ClientAuth {
    /// Bundles a certificate chain with its private key.
    ///
    /// The key is moved into an `Arc`, so cloning the resulting value bumps a
    /// reference count instead of copying the key bytes.
    pub fn new(certificates: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
        let shared_key = Arc::new(key);
        Self {
            certificates,
            key: shared_key,
        }
    }
}
impl std::fmt::Debug for ClientAuth {
    /// Prints the number of certificates and whether a key is set, never the key itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `key` is a non-optional field, so a key is always present; report that
        // fact rather than the secret material.
        let count = self.certificates.len();
        let mut out = f.debug_struct("ClientAuth");
        out.field("certificates_count", &count);
        out.field("has_key", &true);
        out.finish()
    }
}
/// Builder-style TLS settings; start from [`TlsConfig::new`] / `Default` and chain setters.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Whether to trust the server certificate without validating it.
    /// NOTE(review): the enforcement lives in the consumer of this config — confirm there.
    pub trust_server_certificate: bool,
    /// Extra trusted root certificates (DER).
    pub root_certificates: Vec<CertificateDer<'static>>,
    /// Optional client certificate chain and key.
    pub client_auth: Option<ClientAuth>,
    /// Optional explicit server name.
    /// NOTE(review): presumably used for SNI / hostname verification — confirm in the connector.
    pub server_name: Option<String>,
    /// Lowest protocol version to allow.
    pub min_protocol_version: TlsVersion,
    /// Highest protocol version to allow.
    pub max_protocol_version: TlsVersion,
    /// Strict-mode flag; its exact meaning is defined by the consumer of this config.
    pub strict_mode: bool,
    /// ALPN protocol identifiers as raw bytes (e.g. `b"h2".to_vec()`).
    pub alpn_protocols: Vec<Vec<u8>>,
}
impl Default for TlsConfig {
    /// Returns a config with both boolean flags off and no roots, ALPN protocols,
    /// client auth, or server name, allowing TLS 1.2 through TLS 1.3.
    fn default() -> Self {
        Self {
            // Flags.
            trust_server_certificate: false,
            strict_mode: false,
            // Collections and optional settings start empty.
            root_certificates: vec![],
            alpn_protocols: vec![],
            client_auth: None,
            server_name: None,
            // Version window.
            min_protocol_version: TlsVersion::Tls12,
            max_protocol_version: TlsVersion::Tls13,
        }
    }
}
impl TlsConfig {
    /// Creates a configuration populated with the values from `TlsConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `trust_server_certificate` flag.
    ///
    /// NOTE(review): whether `true` disables certificate verification is decided by the
    /// code that consumes this config — confirm there before relying on it.
    #[must_use]
    pub fn trust_server_certificate(mut self, trust: bool) -> Self {
        self.trust_server_certificate = trust;
        self
    }

    /// Appends one DER-encoded certificate to the trusted root set.
    #[must_use]
    pub fn add_root_certificate(mut self, cert: CertificateDer<'static>) -> Self {
        self.root_certificates.push(cert);
        self
    }

    /// Replaces the trusted root set with `certs`, discarding any roots added earlier.
    #[must_use]
    pub fn with_root_certificates(mut self, certs: Vec<CertificateDer<'static>>) -> Self {
        self.root_certificates = certs;
        self
    }

    /// Configures client authentication from a certificate chain and its private key.
    ///
    /// Replaces any client authentication configured earlier.
    #[must_use]
    pub fn with_client_auth(
        mut self,
        certs: Vec<CertificateDer<'static>>,
        key: PrivateKeyDer<'static>,
    ) -> Self {
        self.client_auth = Some(ClientAuth::new(certs, key));
        self
    }

    /// Sets an explicit server name, replacing any previous one.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Sets the lowest allowed protocol version.
    ///
    /// No check against `max_protocol_version` is made here.
    #[must_use]
    pub fn min_protocol_version(mut self, version: TlsVersion) -> Self {
        self.min_protocol_version = version;
        self
    }

    /// Sets the highest allowed protocol version.
    ///
    /// No check against `min_protocol_version` is made here.
    #[must_use]
    pub fn max_protocol_version(mut self, version: TlsVersion) -> Self {
        self.max_protocol_version = version;
        self
    }

    /// Sets the `strict_mode` flag.
    #[must_use]
    pub fn strict_mode(mut self, enabled: bool) -> Self {
        self.strict_mode = enabled;
        self
    }

    /// Replaces the list of ALPN protocol identifiers (raw bytes, in preference order).
    #[must_use]
    pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
        self.alpn_protocols = protocols;
        self
    }

    /// Returns `true` if client authentication has been configured.
    #[must_use]
    pub fn has_client_auth(&self) -> bool {
        self.client_auth.is_some()
    }

    /// Appends a root certificate supplied as raw DER bytes.
    #[must_use]
    pub fn add_root_certificate_der(self, der_bytes: Vec<u8>) -> Self {
        self.add_root_certificate(CertificateDer::from(der_bytes))
    }

    /// Configures client authentication from raw DER bytes.
    ///
    /// `cert_chain_der` holds one DER certificate per element. The private key encoding
    /// (PKCS#1 RSA, SEC1 EC, or PKCS#8) is detected from its DER structure. If detection
    /// fails, the bytes are treated as PKCS#8, which matches this method's earlier
    /// behaviour for every input.
    #[must_use]
    pub fn with_client_auth_der(
        self,
        cert_chain_der: Vec<Vec<u8>>,
        private_key_der: Vec<u8>,
    ) -> Self {
        let certs = cert_chain_der
            .into_iter()
            .map(CertificateDer::from)
            .collect();
        // Detection borrows the bytes; `clone_key` copies them into an owned key, so the
        // result no longer borrows `private_key_der` and the fallback may consume it.
        let detected =
            PrivateKeyDer::try_from(private_key_der.as_slice()).map(|key| key.clone_key());
        let key = detected.unwrap_or_else(|_| PrivateKeyDer::Pkcs8(private_key_der.into()));
        self.with_client_auth(certs, key)
    }
}
/// TLS protocol versions selectable in a `TlsConfig`.
///
/// Variants are declared oldest-first, so the derived ordering ranks `Tls12 < Tls13`.
/// The default is `Tls12`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TlsVersion {
    /// TLS 1.2 (the default).
    #[default]
    Tls12,
    /// TLS 1.3.
    Tls13,
}
impl TlsVersion {
#[must_use]
pub fn to_rustls(&self) -> &'static rustls::SupportedProtocolVersion {
match self {
Self::Tls12 => &rustls::version::TLS12,
Self::Tls13 => &rustls::version::TLS13,
}
}
}