use std::fs;
use std::sync::Arc;
use anyhow::{Context, Result};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::ServerConfig;
use serde::{Deserialize, Serialize};
/// TLS file locations used to build the server-side TLS configurations
/// (`rustls_gateway_config` and `tonic_config`).
///
/// Every field holds a filesystem path to a PEM file; the files are only read when
/// one of the config builders is called, not when the settings are deserialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TlsSettings {
/// Path to the PEM file containing the server's certificate chain.
pub cert: String,
/// Path to the PEM file containing the server's private key.
pub key: String,
/// Optional path to a PEM bundle of CA certificates used to verify client
/// certificates. `None` (also the value when the key is absent on deserialize)
/// means no client-certificate verification is configured.
#[serde(default)]
pub client_ca: Option<String>,
}
impl TlsSettings {
    /// Reports whether a client CA bundle is configured (`client_ca` is `Some`).
    ///
    /// This reflects configuration only: the rustls gateway verifier is built with
    /// `allow_unauthenticated`, so clients without a certificate are still accepted there.
    pub fn mtls_enabled(&self) -> bool {
        self.client_ca.is_some()
    }

    /// Reads the PEM file at `path` and returns every certificate section it contains.
    ///
    /// # Errors
    /// Fails if the file cannot be read, if a certificate section fails to parse, or if
    /// the file contains no certificates at all.
    fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
        let bytes = fs::read(path).with_context(|| format!("reading cert {path}"))?;
        let mut reader: &[u8] = &bytes;
        let parsed: std::result::Result<Vec<_>, _> = rustls_pemfile::certs(&mut reader).collect();
        let chain = parsed.with_context(|| format!("parsing certs in {path}"))?;
        anyhow::ensure!(!chain.is_empty(), "no certificates found in {path}");
        Ok(chain)
    }

    /// Reads the PEM file at `path` and returns the first private key that
    /// `rustls_pemfile::private_key` finds in it.
    ///
    /// # Errors
    /// Fails if the file cannot be read, if PEM parsing fails, or if no key is present.
    fn load_key(path: &str) -> Result<PrivateKeyDer<'static>> {
        let bytes = fs::read(path).with_context(|| format!("reading key {path}"))?;
        let mut reader: &[u8] = &bytes;
        let maybe_key = rustls_pemfile::private_key(&mut reader)
            .with_context(|| format!("parsing key {path}"))?;
        match maybe_key {
            Some(key) => Ok(key),
            None => anyhow::bail!("no private key found in {path}"),
        }
    }

    /// Returns a freshly constructed `ring`-backed rustls crypto provider.
    fn provider() -> Arc<rustls::crypto::CryptoProvider> {
        let ring = rustls::crypto::ring::default_provider();
        Arc::new(ring)
    }
}
/// Installs rustls' `ring` crypto provider as the process-wide default, attempting the
/// installation at most once per process.
///
/// Safe to call repeatedly. If some default provider was already installed before the
/// first call, `install_default` reports an error; that error is deliberately ignored so
/// the existing provider stays in place.
pub fn ensure_crypto_provider() {
    static INSTALL_ATTEMPT: std::sync::Once = std::sync::Once::new();
    INSTALL_ATTEMPT.call_once(|| {
        let ring = rustls::crypto::ring::default_provider();
        // Ok: ours is now the default. Err: a default already existed. Either way a
        // process-wide provider is present afterwards.
        let _ = ring.install_default();
    });
}
impl TlsSettings {
    /// Builds the rustls [`ServerConfig`] used by the gateway listener.
    ///
    /// The server identity is the PEM chain at `cert` plus the PEM key at `key`. When
    /// `client_ca` is set, client certificates are verified against the CAs in that file,
    /// but the verifier is built with `allow_unauthenticated`, so clients presenting no
    /// certificate are still accepted at the TLS layer. ALPN advertises `h2` then
    /// `http/1.1`.
    ///
    /// NOTE(review): `tonic_config` below never calls `client_auth_optional`, so the same
    /// settings may enforce client certificates differently on the gRPC listener than on
    /// this one — confirm the asymmetry is intended.
    ///
    /// # Errors
    /// Fails if any PEM file cannot be read or parsed, if a CA certificate is rejected by
    /// the root store, or if rustls rejects the protocol versions, verifier, or identity.
    pub fn rustls_gateway_config(&self) -> Result<Arc<ServerConfig>> {
        let server_chain = Self::load_certs(&self.cert)?;
        let server_key = Self::load_key(&self.key)?;
        // One provider instance shared by the config builder and the client verifier.
        let provider = Self::provider();
        let versioned = ServerConfig::builder_with_provider(Arc::clone(&provider))
            .with_safe_default_protocol_versions()
            .context("rustls protocol versions")?;

        // Both arms yield the same builder stage, so the identity is attached once below.
        let wants_cert = match self.client_ca.as_deref() {
            None => versioned.with_no_client_auth(),
            Some(ca_path) => {
                let mut trusted = rustls::RootCertStore::empty();
                Self::load_certs(ca_path)?
                    .into_iter()
                    .try_for_each(|ca| trusted.add(ca))
                    .context("adding client CA")?;
                let verifier = rustls::server::WebPkiClientVerifier::builder_with_provider(
                    Arc::new(trusted),
                    provider,
                )
                .allow_unauthenticated()
                .build()
                .context("building client verifier")?;
                versioned.with_client_cert_verifier(verifier)
            }
        };

        let mut config = wants_cert
            .with_single_cert(server_chain, server_key)
            .context("loading server identity")?;
        // Listed most-preferred first: HTTP/2, then HTTP/1.1.
        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        Ok(Arc::new(config))
    }

    /// Builds a tonic `ServerTlsConfig` from the same PEM files.
    ///
    /// The files are read here and handed to tonic as raw PEM bytes; this function does
    /// not parse or validate them. When `client_ca` is set, its PEM is installed as the
    /// client CA root.
    ///
    /// # Errors
    /// Fails if the certificate, key, or client CA file cannot be read.
    pub fn tonic_config(&self) -> Result<tonic::transport::ServerTlsConfig> {
        use tonic::transport::{Certificate, Identity, ServerTlsConfig};

        let cert_path = &self.cert;
        let key_path = &self.key;
        let cert_pem = fs::read(cert_path).with_context(|| format!("reading cert {cert_path}"))?;
        let key_pem = fs::read(key_path).with_context(|| format!("reading key {key_path}"))?;
        let base = ServerTlsConfig::new().identity(Identity::from_pem(cert_pem, key_pem));

        Ok(match &self.client_ca {
            None => base,
            Some(ca) => {
                let ca_pem = fs::read(ca).with_context(|| format!("reading client CA {ca}"))?;
                base.client_ca_root(Certificate::from_pem(ca_pem))
            }
        })
    }
}
/// Returns the subject Common Name (CN) of the leaf certificate, i.e. the first entry of
/// `chain`. Later entries are never inspected.
///
/// Returns `None` in three cases:
/// - the chain is empty;
/// - the leaf DER fails to parse;
/// - no CN attribute in the subject decodes as a string.
///
/// If several CN attributes exist, the first one that decodes wins.
pub fn leaf_common_name(chain: &[CertificateDer<'_>]) -> Option<String> {
    let leaf_der = chain.first()?;
    let (_remaining, parsed) = x509_parser::parse_x509_certificate(leaf_der.as_ref()).ok()?;
    let subject = parsed.subject();
    // Bound to a local so the borrowing iterator is dropped before `parsed`.
    let first_cn = subject
        .iter_common_name()
        .find_map(|attr| attr.as_str().ok());
    first_cn.map(str::to_owned)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports a functional check result to the test matrix (only with `testmatrix`).
    #[cfg(feature = "testmatrix")]
    fn fstatus(component: &str, check: &str, ok: bool, detail: &str) {
        nornir_testmatrix::functional_status(component, check, ok, detail);
    }

    /// Builds a self-signed certificate whose subject CN is `cn` and returns its DER.
    fn self_signed_der(cn: &str) -> CertificateDer<'static> {
        let mut params = rcgen::CertificateParams::new(vec![]).unwrap();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, cn);
        let kp = rcgen::KeyPair::generate().unwrap();
        let cert = params.self_signed(&kp).unwrap();
        cert.der().clone()
    }

    #[test]
    fn extracts_common_name_from_cert() {
        let cn = leaf_common_name(&[self_signed_der("writer-bot")]);
        assert_eq!(cn.as_deref(), Some("writer-bot"));
        #[cfg(feature = "testmatrix")]
        fstatus(
            "tls",
            "extracts_common_name_from_cert",
            cn.as_deref() == Some("writer-bot"),
            &format!("leaf CN = {cn:?}"),
        );
    }

    #[test]
    fn empty_chain_yields_no_cn() {
        let cn = leaf_common_name(&[]);
        assert_eq!(cn, None);
        #[cfg(feature = "testmatrix")]
        fstatus(
            "tls",
            "empty_chain_yields_no_cn",
            cn.is_none(),
            &format!("empty cert chain -> CN {cn:?}"),
        );
    }

    /// Only the first (leaf) certificate of a chain may contribute the CN.
    #[test]
    fn uses_only_the_leaf_of_a_chain() {
        let chain = [self_signed_der("leaf-client"), self_signed_der("issuing-ca")];
        let cn = leaf_common_name(&chain);
        assert_eq!(cn.as_deref(), Some("leaf-client"));
        #[cfg(feature = "testmatrix")]
        fstatus(
            "tls",
            "uses_only_the_leaf_of_a_chain",
            cn.as_deref() == Some("leaf-client"),
            &format!("two-cert chain -> CN {cn:?}"),
        );
    }

    /// A leaf that is not valid DER must yield `None` rather than panic.
    #[test]
    fn malformed_der_yields_no_cn() {
        let garbage = CertificateDer::from(vec![0x00, 0x01, 0x02]);
        let cn = leaf_common_name(&[garbage]);
        assert_eq!(cn, None);
        #[cfg(feature = "testmatrix")]
        fstatus(
            "tls",
            "malformed_der_yields_no_cn",
            cn.is_none(),
            &format!("malformed leaf DER -> CN {cn:?}"),
        );
    }
}