use anyhow::{Result, anyhow};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use std::fs;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use tokio_rustls::{TlsAcceptor, TlsConnector, rustls};
use tracing::{debug, info};
/// Locations of the PEM files needed to set up mutual TLS.
///
/// The same configuration shape is consumed by both the acceptor and the
/// connector constructors in this module.
#[derive(Clone, Debug)]
pub struct MtlsConfig {
    /// Path to the PEM file whose certificates populate the root store.
    pub ca_path: String,
    /// Path to the PEM file holding this endpoint's certificate chain.
    pub cert_path: String,
    /// Path to the PEM file holding this endpoint's private key.
    pub key_path: String,
}
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
let file = fs::File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
Ok(certs)
}
fn load_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
let file = fs::File::open(path)?;
let mut reader = BufReader::new(file);
let keys: Vec<_> =
rustls_pemfile::pkcs8_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?;
if let Some(key) = keys.into_iter().next() {
return Ok(PrivateKeyDer::Pkcs8(key));
}
let file = fs::File::open(path)?;
let mut reader = BufReader::new(file);
let keys: Vec<_> =
rustls_pemfile::rsa_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?;
keys.into_iter()
.next()
.map(PrivateKeyDer::Pkcs1)
.ok_or_else(|| anyhow!("No private key found"))
}
/// Builds the server-side `TlsAcceptor` described by `config`.
///
/// The certificate chain and private key come from `config.cert_path` and
/// `config.key_path`. Certificates from `config.ca_path` fill the root store
/// handed to the client-certificate verifier.
///
/// # Errors
///
/// Propagates any failure from loading the three files, adding a CA
/// certificate to the root store, building the verifier, or assembling the
/// server configuration.
pub fn create_acceptor(config: &MtlsConfig) -> Result<TlsAcceptor> {
    let cert_chain = load_certs(Path::new(&config.cert_path))?;
    let private_key = load_key(Path::new(&config.key_path))?;

    // Files are loaded in the order cert, key, CA so the first reported error
    // is predictable.
    let mut trusted_roots = rustls::RootCertStore::empty();
    load_certs(Path::new(&config.ca_path))?
        .into_iter()
        .try_for_each(|ca| trusted_roots.add(ca))?;

    let verifier =
        rustls::server::WebPkiClientVerifier::builder(Arc::new(trusted_roots)).build()?;
    let tls_config = rustls::ServerConfig::builder()
        .with_client_cert_verifier(verifier)
        .with_single_cert(cert_chain, private_key)?;

    info!("mTLS acceptor created");
    Ok(TlsAcceptor::from(Arc::new(tls_config)))
}
/// Builds the client-side `TlsConnector` described by `config`.
///
/// The client certificate chain and private key come from `config.cert_path`
/// and `config.key_path`. Certificates from `config.ca_path` become the root
/// certificates of the client configuration.
///
/// # Errors
///
/// Propagates any failure from loading the three files, adding a CA
/// certificate to the root store, or assembling the client configuration.
pub fn create_connector(config: &MtlsConfig) -> Result<TlsConnector> {
    let cert_chain = load_certs(Path::new(&config.cert_path))?;
    let private_key = load_key(Path::new(&config.key_path))?;

    // Files are loaded in the order cert, key, CA so the first reported error
    // is predictable.
    let mut trusted_roots = rustls::RootCertStore::empty();
    load_certs(Path::new(&config.ca_path))?
        .into_iter()
        .try_for_each(|ca| trusted_roots.add(ca))?;

    let tls_config = rustls::ClientConfig::builder()
        .with_root_certificates(trusted_roots)
        .with_client_auth_cert(cert_chain, private_key)?;

    info!("mTLS connector created");
    Ok(TlsConnector::from(Arc::new(tls_config)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each path given to `MtlsConfig` is stored unchanged and survives `Clone`.
    #[test]
    fn test_mtls_config() {
        let config = MtlsConfig {
            ca_path: "/path/to/ca.pem".to_string(),
            cert_path: "/path/to/cert.pem".to_string(),
            key_path: "/path/to/key.pem".to_string(),
        };
        assert_eq!(config.ca_path, "/path/to/ca.pem");
        assert_eq!(config.cert_path, "/path/to/cert.pem");
        assert_eq!(config.key_path, "/path/to/key.pem");

        let copy = config.clone();
        assert_eq!(copy.ca_path, config.ca_path);
        assert_eq!(copy.cert_path, config.cert_path);
        assert_eq!(copy.key_path, config.key_path);
    }
}