use {
rustls::{
sign::{any_supported_type, CertifiedKey},
ResolvesServerCert,
},
std::{
ffi::OsStr,
fmt::{Display, Formatter},
path::Path,
sync::Arc,
},
webpki::DNSNameRef,
};
/// Certificate/key pairs loaded from the certificate directory, one per domain.
///
/// `CertStore::load_from` fills this list and sorts it so that more specific domains come
/// before their parent domains. The fallback certificate, if present, is stored under the empty
/// domain name `""` and sorts last.
pub(crate) struct CertStore {
    /// `(domain, certified key)` pairs; an empty `domain` marks the fallback certificate.
    certs: Vec<(String, CertifiedKey)>,
}
/// File name of the DER-encoded certificate inside a domain's certificate directory.
pub static CERT_FILE_NAME: &'static str = "cert.der";
/// File name of the DER-encoded private key inside a domain's certificate directory.
pub static KEY_FILE_NAME: &'static str = "key.der";
/// Errors that can occur while loading certificates and keys from the certificate directory.
///
/// Variants carrying a `String` name the domain (certificate subdirectory) the problem was
/// found for; problems with the top-level fallback certificate are reported as `"fallback"`.
#[derive(Debug)]
pub enum CertLoadError {
    /// The certificate directory itself could not be read.
    NoReadCertDir,
    /// Neither a fallback certificate nor any per-domain certificate was found.
    Empty,
    /// A subdirectory name could not be parsed as an ASCII DNS name.
    BadDomain(String),
    /// The key file for the domain is not a private key of a supported type.
    BadKey(String),
    /// The certificate for the domain was rejected; the second field is the reason.
    BadCert(String, String),
    /// The key file for the domain is missing (or could not be read).
    MissingKey(String),
    /// The certificate file for the domain is missing (or could not be read).
    MissingCert(String),
    /// The domain's folder exists but holds neither a certificate nor a key file.
    EmptyDomain(String),
}

impl Display for CertLoadError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoReadCertDir => f.write_str("Could not read from certificate directory."),
            Self::Empty => f.write_str(
                "No keys or certificates were found in the given directory.\nSpecify the --hostname option to generate these automatically.",
            ),
            Self::BadDomain(name) => {
                if name.is_ascii() {
                    write!(f, "The domain name {} cannot be processed.", name)
                } else {
                    // Only ASCII names are accepted, so point the user at punycode.
                    write!(
                        f,
                        "The domain name {} cannot be processed, it must be punycoded.",
                        name
                    )
                }
            }
            Self::BadKey(name) => write!(f, "The key file for {} is malformed.", name),
            Self::BadCert(name, reason) => write!(
                f,
                "The certificate file for {} is malformed: {}",
                name, reason
            ),
            Self::MissingKey(name) => write!(f, "The key file for {} is missing.", name),
            Self::MissingCert(name) => {
                write!(f, "The certificate file for {} is missing.", name)
            }
            Self::EmptyDomain(name) => write!(
                f,
                "A folder for {} exists, but there is no certificate or key file.",
                name
            ),
        }
    }
}

impl std::error::Error for CertLoadError {}
/// Loads the certificate and private key stored for `domain` below `certs_dir`.
///
/// The files are looked up as `certs_dir/<domain>/` + [`CERT_FILE_NAME`] and
/// `certs_dir/<domain>/` + [`KEY_FILE_NAME`]; an empty `domain` refers to the files placed
/// directly in `certs_dir`. Both files are read as raw bytes and handed to rustls unchanged.
/// The certificate itself is not validated here.
///
/// # Errors
///
/// - [`CertLoadError::EmptyDomain`] if neither file exists,
/// - [`CertLoadError::MissingCert`] if the certificate file is absent or cannot be read,
/// - [`CertLoadError::MissingKey`] if the key file is absent or cannot be read,
/// - [`CertLoadError::BadKey`] if rustls does not accept the key bytes.
fn load_domain(certs_dir: &Path, domain: String) -> Result<CertifiedKey, CertLoadError> {
    let domain_dir = certs_dir.join(&domain);
    let cert_path = domain_dir.join(CERT_FILE_NAME);
    let key_path = domain_dir.join(KEY_FILE_NAME);

    if !cert_path.is_file() {
        // Tell "nothing here at all" apart from "only the certificate is absent".
        let err = if key_path.is_file() {
            CertLoadError::MissingCert(domain)
        } else {
            CertLoadError::EmptyDomain(domain)
        };
        return Err(err);
    }
    let cert_der = match std::fs::read(&cert_path) {
        Ok(bytes) => bytes,
        Err(_) => return Err(CertLoadError::MissingCert(domain)),
    };

    if !key_path.is_file() {
        return Err(CertLoadError::MissingKey(domain));
    }
    let key_der = match std::fs::read(&key_path) {
        Ok(bytes) => bytes,
        Err(_) => return Err(CertLoadError::MissingKey(domain)),
    };

    let signing_key = any_supported_type(&rustls::PrivateKey(key_der))
        .map_err(|_| CertLoadError::BadKey(domain))?;
    Ok(CertifiedKey::new(
        vec![rustls::Certificate(cert_der)],
        Arc::new(signing_key),
    ))
}
/// Returns whether the certificate stored under `cert_domain` may be used for `name`.
///
/// The empty `cert_domain` denotes the fallback certificate and matches every name.
/// Otherwise `name` must equal `cert_domain` or be a subdomain of it. A plain string suffix
/// such as `notexample.com` for `example.com` is rejected because the match does not start
/// on a label boundary.
fn domain_matches(name: &str, cert_domain: &str) -> bool {
    if cert_domain.is_empty() || name == cert_domain {
        return true;
    }
    name.len() > cert_domain.len()
        && name.ends_with(cert_domain)
        // the byte just before the matched suffix must be the label separator
        && name.as_bytes()[name.len() - cert_domain.len() - 1] == b'.'
}

impl CertStore {
    /// Loads all certificates from `certs_dir`.
    ///
    /// A fallback certificate may be placed directly in `certs_dir` (as [`CERT_FILE_NAME`] and
    /// [`KEY_FILE_NAME`]). It is registered under the empty domain name and therefore matches
    /// any requested name. Every subdirectory of `certs_dir` is treated as a domain name and
    /// must contain the same two files. Its certificate is cross-checked against that name.
    ///
    /// The entries are sorted so that more specific domains precede their parent domains and
    /// the fallback comes last, which lets lookups simply take the first match.
    ///
    /// # Errors
    ///
    /// Returns a [`CertLoadError`] in these cases:
    /// - the directory cannot be read;
    /// - a subdirectory name is not a valid (UTF-8, ASCII) DNS name;
    /// - a certificate or key is missing or malformed;
    /// - no certificate was found at all.
    ///
    /// Problems with the fallback certificate are reported under the name `"fallback"`.
    pub fn load_from(certs_dir: &Path) -> Result<Self, CertLoadError> {
        const FALLBACK: &str = "fallback";
        let mut certs = vec![];

        match load_domain(certs_dir, String::new()) {
            Ok(key) => certs.push((String::new(), key)),
            // Nothing directly in `certs_dir`: there simply is no fallback certificate.
            Err(CertLoadError::EmptyDomain(_)) => {}
            // `load_domain` never produces these variants.
            Err(CertLoadError::Empty)
            | Err(CertLoadError::NoReadCertDir)
            | Err(CertLoadError::BadDomain(_)) => unreachable!(),
            // Re-label fallback errors so the message does not show an empty domain name.
            Err(CertLoadError::BadKey(_)) => {
                return Err(CertLoadError::BadKey(FALLBACK.to_string()))
            }
            Err(CertLoadError::BadCert(_, e)) => {
                return Err(CertLoadError::BadCert(FALLBACK.to_string(), e))
            }
            Err(CertLoadError::MissingKey(_)) => {
                return Err(CertLoadError::MissingKey(FALLBACK.to_string()))
            }
            Err(CertLoadError::MissingCert(_)) => {
                return Err(CertLoadError::MissingCert(FALLBACK.to_string()))
            }
        }

        let entries = certs_dir
            .read_dir()
            .map_err(|_| CertLoadError::NoReadCertDir)?
            .filter_map(Result::ok);
        for entry in entries {
            let path = entry.path();
            if !path.is_dir() {
                continue;
            }
            // A non-UTF-8 directory name cannot be a domain name; report it instead of
            // panicking.
            let filename = match path.file_name().and_then(OsStr::to_str) {
                Some(name) => name.to_string(),
                None => {
                    return Err(CertLoadError::BadDomain(
                        entry.file_name().to_string_lossy().into_owned(),
                    ))
                }
            };
            let dns_name = match DNSNameRef::try_from_ascii_str(&filename) {
                Ok(name) => name,
                Err(_) => return Err(CertLoadError::BadDomain(filename)),
            };
            let key = load_domain(certs_dir, filename.clone())?;
            key.cross_check_end_entity_cert(Some(dns_name))
                .map_err(|e| CertLoadError::BadCert(filename.clone(), e.to_string()))?;
            certs.push((filename, key));
        }

        if certs.is_empty() {
            return Err(CertLoadError::Empty);
        }

        // Compare labels from the top-level domain downwards (reversed), and on a common
        // label prefix put the longer (more specific) name first. Subdomains thus precede
        // their parents, and the fallback "" sorts after every real domain.
        certs.sort_unstable_by(|(a, _), (b, _)| {
            for (a_part, b_part) in a.split('.').rev().zip(b.split('.').rev()) {
                if a_part != b_part {
                    return a_part.cmp(b_part).reverse();
                }
            }
            a.len().cmp(&b.len()).reverse()
        });

        log::debug!(
            "certs loaded for {:?}",
            certs.iter().map(|t| &t.0).collect::<Vec<_>>()
        );
        Ok(Self { certs })
    }

    /// Returns whether some loaded certificate can serve `domain`.
    ///
    /// A certificate matches its own domain and all of its subdomains. If a fallback
    /// certificate was loaded, this is always `true`.
    pub fn has_domain(&self, domain: &str) -> bool {
        self.certs.iter().any(|(s, _)| domain_matches(domain, s))
    }
}

impl ResolvesServerCert for CertStore {
    /// Picks the certificate for the server name the client sent via SNI.
    ///
    /// Because the list is sorted most-specific-first, the first matching entry is the
    /// closest enclosing domain, or the fallback certificate if nothing else matches.
    /// Returns `None` when the client sent no server name.
    fn resolve(&self, client_hello: rustls::ClientHello<'_>) -> Option<CertifiedKey> {
        let name: &str = client_hello.server_name()?.into();
        self.certs
            .iter()
            .find(|(s, _)| domain_matches(name, s))
            .map(|(_, k)| k.clone())
    }
}