use std::sync::Arc;
use rustls::{ClientConfig, RootCertStore};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
#[derive(Debug, thiserror::Error)]
pub enum MtlsClientError {
#[error("client certificate decode failed: {0}")]
CertParse(String),
#[error("private key decode failed: {0}")]
KeyParse(String),
#[error("server root certificate decode failed: {0}")]
RootsParse(String),
#[error("rustls ClientConfig construction failed: {0}")]
ClientConfig(String),
}
/// Client-side identity (certificate chain plus private key) for outbound mutual-TLS
/// connections, together with an optional store of server trust anchors.
///
/// The private key is kept out of this type's hand-written `Debug` output.
pub struct OutboundMtlsClient {
/// DER certificates presented to the server; both constructors reject an empty chain.
cert_chain: Vec<CertificateDer<'static>>,
/// Client private key in DER form.
private_key: PrivateKeyDer<'static>,
/// Trust anchors used to verify the server; `None` until a `with_server_roots*` call.
server_roots: Option<RootCertStore>,
}
impl std::fmt::Debug for OutboundMtlsClient {
    /// Renders a summary that never exposes key material: only the chain length and
    /// whether a server-roots store has been configured are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            cert_chain,
            server_roots,
            ..
        } = self;
        let mut summary = f.debug_struct("OutboundMtlsClient");
        summary
            .field("cert_chain_len", &cert_chain.len())
            .field("private_key", &"<redacted>")
            .field("server_roots_present", &server_roots.is_some());
        summary.finish()
    }
}
impl OutboundMtlsClient {
    /// Builds a client identity from PEM-encoded input.
    ///
    /// Every `CERTIFICATE` section in `cert_chain_pem` is collected in order of appearance.
    /// rustls expects the end-entity certificate first. The first private-key section
    /// decodable from `private_key_pem` becomes the client key.
    ///
    /// The returned client has no server roots. Configure them with
    /// [`Self::with_server_roots`] or [`Self::with_server_roots_pem`] before calling
    /// [`Self::rustls_client_config`].
    ///
    /// # Errors
    ///
    /// - [`MtlsClientError::CertParse`] if a certificate section cannot be decoded, or if
    ///   `cert_chain_pem` contains no `CERTIFICATE` section.
    /// - [`MtlsClientError::KeyParse`] if no private key can be decoded from
    ///   `private_key_pem`.
    pub fn new_from_pem(
        cert_chain_pem: &[u8],
        private_key_pem: &[u8],
    ) -> Result<Self, MtlsClientError> {
        let cert_chain: Vec<CertificateDer<'static>> =
            CertificateDer::pem_slice_iter(cert_chain_pem)
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| MtlsClientError::CertParse(e.to_string()))?;
        // An empty chain would only surface later as a rustls error; reject it up front.
        if cert_chain.is_empty() {
            return Err(MtlsClientError::CertParse(
                "no PEM `CERTIFICATE` block found".to_string(),
            ));
        }
        let private_key = PrivateKeyDer::from_pem_slice(private_key_pem)
            .map_err(|e| MtlsClientError::KeyParse(e.to_string()))?;
        Ok(Self {
            cert_chain,
            private_key,
            server_roots: None,
        })
    }

    /// Builds a client identity from already-decoded DER material.
    ///
    /// The chain is stored as given (end-entity certificate first, as rustls expects). The
    /// returned client has no server roots configured.
    ///
    /// # Errors
    ///
    /// - [`MtlsClientError::CertParse`] if `cert_chain_der` is empty.
    pub fn new_from_der(
        cert_chain_der: Vec<CertificateDer<'static>>,
        private_key_der: PrivateKeyDer<'static>,
    ) -> Result<Self, MtlsClientError> {
        if cert_chain_der.is_empty() {
            return Err(MtlsClientError::CertParse(
                "empty cert chain provided".to_string(),
            ));
        }
        Ok(Self {
            cert_chain: cert_chain_der,
            private_key: private_key_der,
            server_roots: None,
        })
    }

    /// Sets the trust anchors used to verify the server, replacing any previous store.
    ///
    /// The store is not validated here. An empty store is rejected later by
    /// [`Self::rustls_client_config`].
    #[must_use]
    pub fn with_server_roots(mut self, roots: RootCertStore) -> Self {
        self.server_roots = Some(roots);
        self
    }

    /// Parses every `CERTIFICATE` section of `roots_pem` into a fresh trust-anchor store
    /// and installs it, replacing any previously configured roots.
    ///
    /// # Errors
    ///
    /// - [`MtlsClientError::RootsParse`] if a section cannot be decoded, if a certificate
    ///   cannot be added to the [`RootCertStore`], or if no certificate is present.
    pub fn with_server_roots_pem(self, roots_pem: &[u8]) -> Result<Self, MtlsClientError> {
        let mut store = RootCertStore::empty();
        for cert in CertificateDer::pem_slice_iter(roots_pem) {
            let cert = cert.map_err(|e| MtlsClientError::RootsParse(e.to_string()))?;
            store
                .add(cert)
                .map_err(|e| MtlsClientError::RootsParse(e.to_string()))?;
        }
        if store.is_empty() {
            return Err(MtlsClientError::RootsParse(
                "no `CERTIFICATE` blocks found in roots PEM".to_string(),
            ));
        }
        Ok(self.with_server_roots(store))
    }

    /// Returns the client certificate chain (never empty).
    pub fn cert_chain(&self) -> &[CertificateDer<'static>] {
        &self.cert_chain
    }

    /// Returns the client private key.
    pub fn private_key(&self) -> &PrivateKeyDer<'static> {
        &self.private_key
    }

    /// Returns the configured server trust anchors, if any.
    pub fn server_roots(&self) -> Option<&RootCertStore> {
        self.server_roots.as_ref()
    }

    /// Builds a rustls [`ClientConfig`] that verifies the server against the configured
    /// roots and presents this client's certificate chain and key.
    ///
    /// Each call clones the stored material and builds a fresh config.
    ///
    /// # Errors
    ///
    /// Returns [`MtlsClientError::ClientConfig`] in any of these cases:
    /// - no server-roots store has been configured;
    /// - the configured store is empty, since such a config would reject every server;
    /// - rustls rejects the certificate chain or private key.
    ///
    /// # Panics
    ///
    /// This method uses [`ClientConfig::builder`], which relies on rustls' process-default
    /// `CryptoProvider`. Per the rustls documentation, that call panics when no default
    /// provider is installed and none can be inferred from rustls' crate features.
    pub fn rustls_client_config(&self) -> Result<Arc<ClientConfig>, MtlsClientError> {
        let Some(roots) = self.server_roots.as_ref() else {
            return Err(MtlsClientError::ClientConfig(
                "no server-roots store configured; call with_server_roots or \
                 with_server_roots_pem before building the ClientConfig"
                    .to_string(),
            ));
        };
        // `with_server_roots` accepts any store. An empty one would still build a config,
        // but every handshake would then fail verification. Fail here instead, matching the
        // emptiness check in `with_server_roots_pem`.
        if roots.is_empty() {
            return Err(MtlsClientError::ClientConfig(
                "server-roots store is empty; every server certificate would be rejected"
                    .to_string(),
            ));
        }
        let config = ClientConfig::builder()
            .with_root_certificates(roots.clone())
            .with_client_auth_cert(self.cert_chain.clone(), self.private_key.clone_key())
            .map_err(|e| MtlsClientError::ClientConfig(e.to_string()))?;
        Ok(Arc::new(config))
    }
}
// Unit tests for this module live in a separate `tests` module file and are compiled only
// in test builds.
#[cfg(test)]
mod tests;