pub(crate) mod client;
pub(crate) mod server;
use std::convert::TryFrom;
use std::io::{self, ErrorKind};
use std::path::Path;
pub(crate) use client::*;
pub(crate) use server::*;
use tokio_rustls::{rustls, webpki};
/// Selects between the two certificate handling modes known to this module.
///
/// NOTE(review): the code that consumes this enum is not in this file; the
/// per-variant descriptions below are inferred from the names and should be
/// confirmed against the `client` / `server` modules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertificateMode {
    /// Presumably: the peer certificate is validated through a certificate
    /// authority chain.
    AuthorityBased,
    /// Presumably: the peer presents a single self-signed certificate.
    SelfSigned,
}
/// Errors produced while loading TLS material or building a TLS configuration.
#[derive(Debug)]
pub enum TlsError {
    /// The peer certificate file could not be read or parsed
    /// (returned by `load_certs` when `is_local` is `false`).
    InvalidPeerCertificate(io::Error),
    /// The local certificate file could not be read or parsed
    /// (returned by `load_certs` when `is_local` is `true`).
    InvalidLocalCertificate(io::Error),
    /// The private key file could not be read, parsed, or decrypted.
    InvalidPrivateKey(io::Error),
    /// A DNS name was rejected as invalid.
    InvalidDnsName,
    /// The configuration was rejected; the payload is a free-form message.
    BadConfig(String),
}
impl std::fmt::Display for TlsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidPeerCertificate(err) => {
write!(f, "invalid peer certificate file: {err}")
}
Self::InvalidLocalCertificate(err) => {
write!(f, "invalid local certificate file: {err}")
}
Self::InvalidPrivateKey(err) => write!(f, "invalid private key file: {err}"),
Self::InvalidDnsName => write!(f, "invalid DNS name"),
Self::BadConfig(err) => write!(f, "bad config: {err}"),
}
}
}
// Empty impl: every `std::error::Error` method keeps its default, so
// `source()` reports no cause. The wrapped `io::Error`, where present, is
// already included in the `Display` output instead.
impl std::error::Error for TlsError {}
/// Lowest TLS protocol version that may be negotiated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MinTlsVersion {
    /// Allow TLS 1.2 and TLS 1.3.
    V1_2,
    /// Allow TLS 1.3 only.
    V1_3,
}
impl MinTlsVersion {
fn to_rustls(self) -> &'static [&'static rustls::SupportedProtocolVersion] {
static MIN_TLS12_VERSIONS: &[&rustls::SupportedProtocolVersion] =
&[&rustls::version::TLS13, &rustls::version::TLS12];
static MIN_TLS13_VERSIONS: &[&rustls::SupportedProtocolVersion] =
&[&rustls::version::TLS13];
match self {
Self::V1_2 => MIN_TLS12_VERSIONS,
Self::V1_3 => MIN_TLS13_VERSIONS,
}
}
}
fn verify_dns_name(cert: &rustls::Certificate, server_name: &str) -> Result<(), rustls::Error> {
let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name)
.map_err(|_| rustls::Error::InvalidCertificateData("invalid DNS name".to_string()))?;
let end_entity_cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?;
match end_entity_cert.verify_is_valid_for_dns_name(dns_name) {
Ok(()) => Ok(()), Err(webpki::Error::CertNotValidForName) => {
let parsed_cert = rx509::x509::Certificate::parse(&cert.0).map_err(|err| {
rustls::Error::InvalidCertificateData(format!(
"unable to parse cert with rasn: {err:?}"
))
})?;
if let Some(extensions) = &parsed_cert.tbs_certificate.value.extensions {
let extensions = extensions.parse().map_err(|err| {
rustls::Error::InvalidCertificateData(format!(
"unable to parse certificate extensions with rasn: {err:?}"
))
})?;
if extensions.iter().any(|x| {
matches!(
x.content,
rx509::x509::ext::SpecificExtension::SubjectAlternativeName(_)
)
}) {
return Err(rustls::Error::InvalidCertificateData(
"certificate not valid for name, SAN extensions do not match".to_string(),
));
}
}
let subject = parsed_cert
.tbs_certificate
.value
.subject
.parse()
.map_err(|err| {
rustls::Error::InvalidCertificateData(format!(
"unable to parse certificate subject: {err:?}"
))
})?;
let common_name = subject.common_name.ok_or_else(|| {
rustls::Error::InvalidCertificateData(
"certificate not valid for name, no SAN and no CN present".to_string(),
)
})?;
match common_name == server_name {
true => Ok(()),
false => Err(rustls::Error::InvalidCertificateData(
"certificate not valid for name, no SAN and CN doesn't match".to_string(),
)),
}
}
Err(err) => Err(pki_error(err)), }
}
fn pki_error(error: webpki::Error) -> rustls::Error {
use webpki::Error::*;
match error {
BadDer | BadDerTime => rustls::Error::InvalidCertificateEncoding,
InvalidSignatureForPublicKey => rustls::Error::InvalidCertificateSignature,
UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => {
rustls::Error::InvalidCertificateSignatureType
}
e => rustls::Error::InvalidCertificateData(format!("invalid peer certificate: {e}")),
}
}
fn load_certs(path: &Path, is_local: bool) -> Result<Vec<rustls::Certificate>, TlsError> {
let map_error_fn = match is_local {
false => TlsError::InvalidPeerCertificate,
true => TlsError::InvalidLocalCertificate,
};
let content = std::fs::read(path).map_err(map_error_fn)?;
let certs = pem::parse_many(content)
.map_err(|err| map_error_fn(io::Error::new(ErrorKind::InvalidData, err.to_string())))?
.into_iter()
.filter(|x| x.tag == "CERTIFICATE")
.map(|x| rustls::Certificate(x.contents))
.collect::<Vec<_>>();
if certs.is_empty() {
return Err(map_error_fn(io::Error::new(
ErrorKind::InvalidData,
"no certificate in pem file",
)));
}
Ok(certs)
}
fn load_private_key(path: &Path, password: Option<&str>) -> Result<rustls::PrivateKey, TlsError> {
let expected_tag = match &password {
Some(_) => "ENCRYPTED PRIVATE KEY",
None => "PRIVATE KEY",
};
let content = std::fs::read(path).map_err(TlsError::InvalidPrivateKey)?;
let mut iter = pem::parse_many(content)
.map_err(|err| {
TlsError::InvalidPrivateKey(io::Error::new(ErrorKind::InvalidData, err.to_string()))
})?
.into_iter()
.filter(|x| x.tag == expected_tag)
.map(|x| x.contents);
let key = match iter.next() {
Some(key) => match password {
Some(password) => {
let encrypted = pkcs8::EncryptedPrivateKeyDocument::from_der(&key)?;
let decrypted = encrypted.decrypt(password)?;
rustls::PrivateKey(decrypted.as_ref().to_owned())
}
None => rustls::PrivateKey(key),
},
None => {
return Err(TlsError::InvalidPrivateKey(io::Error::new(
ErrorKind::InvalidData,
"no private key found in PEM file",
)));
}
};
if iter.next().is_some() {
return Err(TlsError::InvalidPrivateKey(io::Error::new(
ErrorKind::InvalidData,
"more than one private key is present in the PEM file",
)));
}
Ok(key)
}
impl From<pkcs8::Error> for TlsError {
fn from(from: pkcs8::Error) -> Self {
TlsError::InvalidPrivateKey(io::Error::new(ErrorKind::InvalidData, from.to_string()))
}
}