use std::path::{Path, PathBuf};
use std::sync::Arc;
use rustls::{ClientConfig, RootCertStore};
use rustls_pki_types::CertificateDer;
use rustls_pki_types::pem::PemObject;
/// Errors produced while assembling the client TLS trust store and `ClientConfig`.
#[derive(Debug, thiserror::Error)]
pub enum TlsError {
/// A certificate PEM file could not be used: the path is not a regular file, or
/// opening/reading it failed (returned by `add_pem_file_certs`).
#[error("TLS: cannot read certificate file {path}: {reason}")]
PemRead {
/// The offending file.
path: PathBuf,
/// Human-readable cause: the underlying error's message or a short description.
reason: String,
},
/// The PEM file was readable but contributed no usable certificate to the store.
#[error("TLS: no certificates parsed from {path}")]
NoCertsFound {
/// The file that yielded nothing.
path: PathBuf,
},
/// After consulting every enabled source, the root store is still empty
/// (checked at the end of `build_root_store`).
#[error(
"TLS: trust store is empty -- no roots from native/webpki/extra sources \
(check `native_roots`/`webpki_roots`/`extra_roots`, or `exclusive` with no files)"
)]
EmptyTrustStore,
/// `rustls` rejected the configuration while building the `ClientConfig`
/// (the message is the stringified `rustls` error).
#[error("TLS: failed to build client config: {0}")]
Build(String),
}
/// Selects which certificate sources make up the client's root trust store;
/// consumed by [`build_root_store`].
#[derive(Debug, Clone)]
pub struct TlsTrust {
/// Load the platform's native trust store via `rustls-native-certs`.
/// Ignored when `exclusive` is set.
pub native_roots: bool,
/// Add the roots bundled in the `webpki-roots` crate (`TLS_SERVER_ROOTS`).
/// Ignored when `exclusive` is set.
pub webpki_roots: bool,
/// PEM files whose certificates are added as trust anchors. Each file must yield
/// at least one usable certificate, otherwise building the store fails.
pub extra_roots: Vec<PathBuf>,
/// PEM files of intermediate certificates. NOTE: these go into the same
/// `RootCertStore` as `extra_roots`, so they are trusted as anchors in their own
/// right rather than only being used for chain building.
pub extra_intermediates: Vec<PathBuf>,
/// When true, the native and `webpki-roots` sources are skipped entirely and only
/// the `extra_roots` / `extra_intermediates` files are trusted.
pub exclusive: bool,
}
impl Default for TlsTrust {
fn default() -> Self {
Self {
native_roots: true,
webpki_roots: false,
extra_roots: Vec::new(),
extra_intermediates: Vec::new(),
exclusive: false,
}
}
}
impl TlsTrust {
#[must_use]
pub fn private_ca(pem_path: impl Into<PathBuf>) -> Self {
Self {
native_roots: false,
webpki_roots: false,
extra_roots: vec![pem_path.into()],
extra_intermediates: Vec::new(),
exclusive: true,
}
}
}
/// Where the client TLS configuration comes from; consumed by [`build_client_config`].
// `Debug` lets callers log the source or `unwrap`/`expect` results that carry it, in
// line with the other public types here. `Clone` is cheap: an `Arc` bump or a clone
// of the `TlsTrust` settings.
#[derive(Debug, Clone)]
pub enum TlsConfigSource {
    /// A fully built config supplied by the caller; returned unchanged.
    Explicit(Arc<ClientConfig>),
    /// Build a fresh config from these trust settings.
    Trust(TlsTrust),
}
pub fn add_pem_file_certs(store: &mut RootCertStore, path: &Path) -> Result<usize, TlsError> {
if !path.is_file() {
return Err(TlsError::PemRead {
path: path.to_path_buf(),
reason: "not a readable file".to_string(),
});
}
let iter = CertificateDer::pem_file_iter(path).map_err(|e| TlsError::PemRead {
path: path.to_path_buf(),
reason: e.to_string(),
})?;
let certs: Vec<CertificateDer<'static>> = iter.filter_map(Result::ok).collect();
let (added, _ignored) = store.add_parsable_certificates(certs);
if added == 0 {
return Err(TlsError::NoCertsFound {
path: path.to_path_buf(),
});
}
Ok(added)
}
pub fn build_root_store(trust: &TlsTrust) -> Result<RootCertStore, TlsError> {
let mut store = RootCertStore::empty();
if !trust.exclusive {
if trust.native_roots {
let result = rustls_native_certs::load_native_certs();
let (_added, _ignored) = store.add_parsable_certificates(result.certs);
for err in result.errors {
tracing::warn!(error = %err, "TLS: error loading a native root (continuing)");
}
}
if trust.webpki_roots {
store
.roots
.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}
}
for path in &trust.extra_roots {
add_pem_file_certs(&mut store, path)?;
}
for path in &trust.extra_intermediates {
add_pem_file_certs(&mut store, path)?;
}
if store.is_empty() {
return Err(TlsError::EmptyTrustStore);
}
Ok(store)
}
/// Produces the [`ClientConfig`] to use for outgoing TLS connections.
///
/// A [`TlsConfigSource::Explicit`] config is returned as-is (the same `Arc`). For
/// [`TlsConfigSource::Trust`], a root store is built with [`build_root_store`]. A fresh
/// config is then assembled on the `aws-lc-rs` crypto provider, with rustls' safe
/// default protocol versions and no client authentication.
///
/// # Errors
///
/// - Any error from [`build_root_store`].
/// - [`TlsError::Build`] if the provider rejects the default protocol versions.
pub fn build_client_config(source: TlsConfigSource) -> Result<Arc<ClientConfig>, TlsError> {
    let trust = match source {
        TlsConfigSource::Explicit(prebuilt) => return Ok(prebuilt),
        TlsConfigSource::Trust(settings) => settings,
    };

    let root_store = build_root_store(&trust)?;
    let crypto = rustls::crypto::aws_lc_rs::default_provider();
    let versioned = ClientConfig::builder_with_provider(Arc::new(crypto))
        .with_safe_default_protocol_versions()
        .map_err(|err| TlsError::Build(err.to_string()))?;
    let config = versioned
        .with_root_certificates(root_store)
        .with_no_client_auth();
    Ok(Arc::new(config))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Concatenates `n` freshly generated self-signed certificates into one PEM string.
    fn gen_ca_bundle(n: usize) -> String {
        let mut bundle = String::new();
        for i in 0..n {
            let cert = rcgen::generate_simple_self_signed(vec![format!("ca-{i}.test")])
                .expect("rcgen self-signed");
            bundle.push_str(&cert.cert.pem());
        }
        bundle
    }

    /// Writes `contents` to a fresh temporary file and returns its handle.
    fn write_temp(contents: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::NamedTempFile::new().expect("temp file");
        f.write_all(contents.as_bytes()).expect("write");
        f.flush().expect("flush");
        f
    }

    #[test]
    fn add_pem_counts_multi_cert_bundle() {
        let f = write_temp(&gen_ca_bundle(3));
        let mut store = RootCertStore::empty();
        let added = add_pem_file_certs(&mut store, f.path()).unwrap();
        assert_eq!(added, 3, "all three certs in the bundle are added");
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn add_pem_is_lenient_with_junk_plus_valid() {
        let mut contents = String::from("this is not a PEM block\ngarbage line\n");
        contents.push_str(&gen_ca_bundle(1));
        contents.push_str("\ntrailing junk\n");
        let f = write_temp(&contents);
        let mut store = RootCertStore::empty();
        let added = add_pem_file_certs(&mut store, f.path()).unwrap();
        assert_eq!(added, 1, "junk is skipped, the valid cert still loads");
    }

    #[test]
    fn add_pem_zero_certs_is_error() {
        let f = write_temp("no certificates here at all\n");
        let mut store = RootCertStore::empty();
        let err = add_pem_file_certs(&mut store, f.path()).unwrap_err();
        assert!(matches!(err, TlsError::NoCertsFound { .. }));
    }

    #[test]
    fn add_pem_unreadable_path_is_error() {
        let mut store = RootCertStore::empty();
        let err = add_pem_file_certs(&mut store, Path::new("/nonexistent/nope.pem")).unwrap_err();
        assert!(matches!(err, TlsError::PemRead { .. }));
    }

    #[test]
    fn add_pem_directory_is_error() {
        let dir = tempfile::tempdir().expect("temp dir");
        let mut store = RootCertStore::empty();
        let err = add_pem_file_certs(&mut store, dir.path()).unwrap_err();
        assert!(matches!(err, TlsError::PemRead { .. }));
        assert!(store.is_empty(), "nothing is added on error");
    }

    #[test]
    fn build_store_augments_native_with_extra() {
        let f = write_temp(&gen_ca_bundle(1));
        // Count what the platform contributes on this machine (may be zero in CI).
        let mut native_only = RootCertStore::empty();
        let _ = native_only
            .add_parsable_certificates(rustls_native_certs::load_native_certs().certs);
        let trust = TlsTrust {
            native_roots: true,
            webpki_roots: false,
            extra_roots: vec![f.path().to_path_buf()],
            extra_intermediates: Vec::new(),
            exclusive: false,
        };
        let store = build_root_store(&trust).unwrap();
        assert_eq!(
            store.len(),
            native_only.len() + 1,
            "the extra root is added on top of every native root"
        );
    }

    #[test]
    fn build_store_webpki_only_loads_bundled_roots() {
        let trust = TlsTrust {
            native_roots: false,
            webpki_roots: true,
            extra_roots: Vec::new(),
            extra_intermediates: Vec::new(),
            exclusive: false,
        };
        let store = build_root_store(&trust).unwrap();
        assert_eq!(store.len(), webpki_roots::TLS_SERVER_ROOTS.len());
    }

    #[test]
    fn build_store_exclusive_uses_only_extra() {
        let f = write_temp(&gen_ca_bundle(2));
        let trust = TlsTrust::private_ca(f.path());
        let store = build_root_store(&trust).unwrap();
        assert_eq!(
            store.roots.len(),
            2,
            "exclusive store holds exactly the private-CA certs, no native roots"
        );
    }

    #[test]
    fn build_store_exclusive_ignores_builtin_sources() {
        let f = write_temp(&gen_ca_bundle(1));
        let trust = TlsTrust {
            native_roots: true,
            webpki_roots: true,
            extra_roots: vec![f.path().to_path_buf()],
            extra_intermediates: Vec::new(),
            exclusive: true,
        };
        let store = build_root_store(&trust).unwrap();
        assert_eq!(store.len(), 1, "native/webpki flags are ignored when exclusive");
    }

    #[test]
    fn build_store_adds_intermediates_to_store() {
        let root = write_temp(&gen_ca_bundle(1));
        let intermediates = write_temp(&gen_ca_bundle(2));
        let trust = TlsTrust {
            extra_intermediates: vec![intermediates.path().to_path_buf()],
            ..TlsTrust::private_ca(root.path())
        };
        let store = build_root_store(&trust).unwrap();
        assert_eq!(store.len(), 3, "one root plus two intermediates");
    }

    #[test]
    fn build_store_missing_extra_root_is_error() {
        let trust = TlsTrust::private_ca("/nonexistent/nope.pem");
        let err = build_root_store(&trust).unwrap_err();
        assert!(matches!(err, TlsError::PemRead { .. }));
    }

    #[test]
    fn build_store_exclusive_with_no_files_is_error() {
        let trust = TlsTrust {
            native_roots: true,
            webpki_roots: true,
            extra_roots: Vec::new(),
            extra_intermediates: Vec::new(),
            exclusive: true,
        };
        let err = build_root_store(&trust).unwrap_err();
        assert!(matches!(err, TlsError::EmptyTrustStore));
    }

    #[test]
    fn build_store_no_sources_is_empty_error() {
        let trust = TlsTrust {
            native_roots: false,
            webpki_roots: false,
            extra_roots: Vec::new(),
            extra_intermediates: Vec::new(),
            exclusive: false,
        };
        let err = build_root_store(&trust).unwrap_err();
        assert!(matches!(err, TlsError::EmptyTrustStore));
    }

    #[test]
    fn build_client_config_from_private_ca() {
        let f = write_temp(&gen_ca_bundle(1));
        let cfg =
            build_client_config(TlsConfigSource::Trust(TlsTrust::private_ca(f.path()))).unwrap();
        assert_eq!(
            Arc::strong_count(&cfg),
            1,
            "a freshly built config has a single owner"
        );
    }

    #[test]
    fn build_client_config_explicit_is_passed_through() {
        let f = write_temp(&gen_ca_bundle(1));
        let built =
            build_client_config(TlsConfigSource::Trust(TlsTrust::private_ca(f.path()))).unwrap();
        let passed = build_client_config(TlsConfigSource::Explicit(Arc::clone(&built))).unwrap();
        assert!(
            Arc::ptr_eq(&built, &passed),
            "explicit configs are returned unchanged"
        );
    }

    #[test]
    fn build_client_config_propagates_trust_errors() {
        let trust = TlsTrust {
            native_roots: false,
            webpki_roots: false,
            extra_roots: Vec::new(),
            extra_intermediates: Vec::new(),
            exclusive: false,
        };
        let err = build_client_config(TlsConfigSource::Trust(trust)).unwrap_err();
        assert!(matches!(err, TlsError::EmptyTrustStore));
    }
}