use crate::crypto;
use crate::server::{ServerTlsConfig, ServerTlsInfo};
use anyhow::Context;
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName, UnixTime};
use std::fs;
use std::io::{self, Cursor, Read};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
/// Client-side certificate verifier that trusts a server purely by the SHA-256 digest of
/// its leaf certificate, instead of validating a chain against trusted roots.
#[derive(Debug)]
pub(crate) struct FingerprintVerifier {
    /// Provider used to compute the digest and to verify handshake signatures.
    provider: crypto::Provider,
    /// Expected raw digest bytes of the server's leaf certificate.
    fingerprint: Vec<u8>,
}

impl FingerprintVerifier {
    /// Builds a verifier pinned to `fingerprint`, hashing with `provider`.
    ///
    /// The pin is compared byte-for-byte against the raw digest output, so it must be the
    /// decoded digest rather than a hex string. NOTE(review): confirm callers decode hex first.
    pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
        Self { fingerprint, provider }
    }
}
impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp: &[u8],
_now: UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let fingerprint = crypto::sha256(&self.provider, end_entity);
if fingerprint.as_ref() == self.fingerprint.as_slice() {
Ok(rustls::client::danger::ServerCertVerified::assertion())
} else {
Err(rustls::Error::General("fingerprint mismatch".into()))
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.provider.signature_verification_algorithms.supported_schemes()
}
}
/// Server-side certificate store and resolver: holds the loaded certificates (plus their
/// fingerprints) and picks one for each TLS handshake, preferring an SNI match.
#[derive(Debug)]
pub(crate) struct ServeCerts {
    /// Loaded certificates and their hex-encoded SHA-256 fingerprints. It sits behind a lock
    /// so it can be replaced at runtime (e.g. on reload) while handshakes read it.
    pub info: Arc<RwLock<ServerTlsInfo>>,
    /// Crypto provider used to load private keys and compute fingerprints.
    provider: crypto::Provider,
}
impl ServeCerts {
pub fn new(provider: crypto::Provider) -> Self {
Self {
info: Arc::new(RwLock::new(ServerTlsInfo {
certs: Vec::new(),
fingerprints: Vec::new(),
})),
provider,
}
}
pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> {
anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key");
anyhow::ensure!(
!config.cert.is_empty() || !config.generate.is_empty(),
"must provide at least one cert/key pair or generate entry"
);
let mut certs = Vec::new();
for (cert, key) in config.cert.iter().zip(config.key.iter()) {
certs.push(Arc::new(self.load(cert, key)?));
}
if !config.generate.is_empty() {
certs.push(Arc::new(self.generate(&config.generate)?));
}
self.set_certs(certs);
Ok(())
}
fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<rustls::sign::CertifiedKey> {
let chain = fs::File::open(chain_path).context("failed to open cert file")?;
let mut chain = io::BufReader::new(chain);
let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
.collect::<Result<_, _>>()
.context("failed to read certs")?;
anyhow::ensure!(!chain.is_empty(), "could not find certificate");
let mut keys = fs::File::open(key_path).context("failed to open key file")?;
let mut buf = Vec::new();
keys.read_to_end(&mut buf)?;
let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
let key = self.provider.key_provider.load_private_key(key)?;
let certified_key = rustls::sign::CertifiedKey::new(chain, key);
certified_key.keys_match().context(format!(
"private key {} doesn't match certificate {}",
key_path.display(),
chain_path.display()
))?;
Ok(certified_key)
}
#[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
fn generate(&self, hostnames: &[String]) -> anyhow::Result<rustls::sign::CertifiedKey> {
let key_pair = rcgen::KeyPair::generate()?;
let mut params = rcgen::CertificateParams::new(hostnames)?;
params.not_before = ::time::OffsetDateTime::now_utc() - ::time::Duration::days(1);
params.not_after = params.not_before + ::time::Duration::days(14);
let cert = params.self_signed(&key_pair)?;
let key_der = key_pair.serialized_der().to_vec();
let key_der = PrivatePkcs8KeyDer::from(key_der);
let key = self.provider.key_provider.load_private_key(key_der.into())?;
Ok(rustls::sign::CertifiedKey::new(vec![cert.into()], key))
}
#[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))]
fn generate(&self, _hostnames: &[String]) -> anyhow::Result<rustls::sign::CertifiedKey> {
anyhow::bail!("no crypto provider available; enable aws-lc-rs or ring feature");
}
pub fn set_certs(&self, certs: Vec<Arc<rustls::sign::CertifiedKey>>) {
let fingerprints = certs
.iter()
.map(|ck| {
let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
hex::encode(fingerprint)
})
.collect();
let mut info = self.info.write().expect("info write lock poisoned");
info.certs = certs;
info.fingerprints = fingerprints;
}
fn best_certificate(
&self,
client_hello: &rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
let server_name = client_hello.server_name()?;
let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
for ck in self.info.read().expect("info read lock poisoned").certs.iter() {
let leaf: webpki::EndEntityCert = ck
.end_entity_cert()
.expect("missing certificate")
.try_into()
.expect("failed to parse certificate");
if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
return Some(ck.clone());
}
}
None
}
}
impl rustls::server::ResolvesServerCert for ServeCerts {
    /// Chooses the certificate to present for this handshake.
    ///
    /// A certificate matching the client's SNI name wins. Otherwise a warning is logged and
    /// the first loaded certificate is used, or `None` if nothing is loaded.
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<rustls::sign::CertifiedKey>> {
        match self.best_certificate(&client_hello) {
            Some(matched) => Some(matched),
            None => {
                tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
                let info = self.info.read().expect("info read lock poisoned");
                info.certs.first().cloned()
            }
        }
    }
}
#[cfg(unix)]
pub(crate) async fn reload_certs(certs: Arc<ServeCerts>, tls_config: ServerTlsConfig) {
use tokio::signal::unix::{SignalKind, signal};
let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals");
while listener.recv().await.is_some() {
tracing::info!("reloading server certificates");
if let Err(err) = certs.load_certs(&tls_config) {
tracing::warn!(%err, "failed to reload server certificates");
}
}
}