use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use blueprint_core::{debug, info};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
use rustls_pemfile;
use crate::db::RocksDb;
use crate::models::ServiceModel;
use crate::types::ServiceId;
/// Settings that describe how an outbound TLS client should be built.
///
/// The manager hashes every field of this struct to derive its cache key, so
/// two configurations that differ in any field get separate cached clients.
#[derive(Debug, Clone)]
pub struct TlsClientConfig {
    /// When `false`, the client is built with a verifier that accepts any
    /// server certificate.
    pub verify_server_cert: bool,
    /// PEM-encoded CA bundles whose certificates are added to the root store.
    pub custom_ca_certs: Vec<Vec<u8>>,
    /// PEM-encoded client certificate chain; only used when `client_key` is
    /// also present.
    pub client_cert: Option<Vec<u8>>,
    /// PEM-encoded client private key; only used when `client_cert` is also
    /// present.
    pub client_key: Option<Vec<u8>>,
    /// ALPN protocol identifiers copied into the rustls client configuration.
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Handshake time budget.
    ///
    /// NOTE(review): the client construction code visible in this file does
    /// not apply this value; it only contributes to the cache key — confirm.
    pub handshake_timeout: Duration,
}

impl Default for TlsClientConfig {
    /// Verification on, no custom CAs, no client identity, ALPN offering
    /// `h2` then `http/1.1`, and a 10 second handshake timeout.
    fn default() -> Self {
        let alpn_protocols: Vec<Vec<u8>> = [&b"h2"[..], &b"http/1.1"[..]]
            .iter()
            .map(|protocol| protocol.to_vec())
            .collect();

        Self {
            verify_server_cert: true,
            custom_ca_certs: Vec::new(),
            client_cert: None,
            client_key: None,
            alpn_protocols,
            handshake_timeout: Duration::from_secs(10),
        }
    }
}
/// A pair of HTTP clients sharing one TLS connector, plus the configuration
/// they were built from.
#[derive(Debug, Clone)]
pub struct CachedTlsClient {
    /// Client built with `http2_only(false)`.
    pub http_client: Client<
        HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
        axum::body::Body,
    >,
    /// Client built with `http2_only(true)` and an adaptive HTTP/2 window.
    pub http2_client: Client<
        HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
        axum::body::Body,
    >,
    /// The TLS settings used to construct both clients.
    pub config: TlsClientConfig,
    /// Timestamp the cache compares against its TTL and uses to pick
    /// eviction candidates.
    pub last_access: std::time::Instant,
}
/// Builds TLS-capable HTTP clients per service and caches them by a hash of
/// their TLS configuration.
///
/// Cloning the manager shares the same underlying cache, because the map
/// lives behind an `Arc<Mutex<_>>`.
#[derive(Debug, Clone)]
pub struct TlsClientManager {
    /// Cached clients, keyed by the hex string produced by `hash_config`.
    clients: Arc<Mutex<HashMap<String, CachedTlsClient>>>,
    /// Database handle used to look up the service model.
    db: RocksDb,
    /// Entry count at which an insert first triggers eviction.
    max_cache_size: usize,
    /// Age beyond which a cached entry is no longer served.
    cache_ttl: Duration,
}
impl TlsClientManager {
/// Creates a manager with an empty cache, a capacity of 100 entries and a
/// one hour (3600 s) time-to-live.
pub fn new(db: RocksDb) -> Self {
    let clients = Arc::new(Mutex::new(HashMap::new()));
    let cache_ttl = Duration::from_secs(3600);

    Self {
        clients,
        db,
        max_cache_size: 100,
        cache_ttl,
    }
}
/// Returns a TLS client for `service_id`, reusing a cached one when possible.
///
/// The service is loaded from the database, its TLS configuration is derived
/// and hashed, and the hash is used as the cache key. A cached entry is
/// served only while its `last_access` is younger than the cache TTL; each
/// hit refreshes `last_access`, so the TTL acts as an idle timeout and the
/// oldest-first eviction removes the least recently used entries.
///
/// # Errors
///
/// Returns `crate::Error::ServiceNotFound` when no service with this id
/// exists, and propagates any error from the database lookup or from
/// building the TLS client.
///
/// # Panics
///
/// Panics if the cache mutex is poisoned.
pub async fn get_client_for_service(
    &self,
    service_id: ServiceId,
) -> Result<CachedTlsClient, crate::Error> {
    let service = ServiceModel::find_by_id(service_id, &self.db)?
        .ok_or(crate::Error::ServiceNotFound(service_id))?;
    let tls_config = self.get_service_tls_config(&service).await?;
    let config_hash = self.hash_config(&tls_config);

    // The guard is confined to this block so the std mutex is never held
    // across the `.await` below.
    {
        let mut clients = self.clients.lock().unwrap();
        if let Some(cached_client) = clients.get_mut(&config_hash) {
            if cached_client.last_access.elapsed() < self.cache_ttl {
                // Record the access; without this the timestamp would stay
                // at creation time and hot entries would be evicted first.
                cached_client.last_access = std::time::Instant::now();
                debug!("Using cached TLS client for service {}", service_id);
                return Ok(cached_client.clone());
            }
        }
    }

    debug!("Creating new TLS client for service {}", service_id);
    let client = self.create_tls_client(tls_config).await?;

    {
        let mut clients = self.clients.lock().unwrap();
        if clients.len() >= self.max_cache_size {
            self.evict_old_entries(&mut clients);
        }
        // Replaces any expired entry stored under the same key.
        clients.insert(config_hash, client.clone());
    }

    Ok(client)
}
/// Derives the outbound TLS configuration for `service`.
///
/// Starts from `TlsClientConfig::default()`. When the service has a TLS
/// profile with `tls_enabled` set, server verification is switched on, a
/// non-empty CA bundle is added as a custom CA, and the client certificate
/// and key are set only if both are non-empty. Otherwise the defaults are
/// returned unchanged.
///
/// NOTE(review): the profile fields are named `encrypted_*` but are passed
/// through as-is here and later parsed as PEM — confirm whether they are
/// expected to be decrypted before reaching this point.
async fn get_service_tls_config(
    &self,
    service: &ServiceModel,
) -> Result<TlsClientConfig, crate::Error> {
    let mut config = TlsClientConfig::default();

    let profile = match &service.tls_profile {
        Some(profile) if profile.tls_enabled => profile,
        _ => return Ok(config),
    };

    config.verify_server_cert = true;

    let ca_bundle = &profile.encrypted_upstream_ca_bundle;
    if !ca_bundle.is_empty() {
        config.custom_ca_certs.push(ca_bundle.clone());
    }

    let client_cert = &profile.encrypted_upstream_client_cert;
    let client_key = &profile.encrypted_upstream_client_key;
    let missing_identity_part = client_cert.is_empty() || client_key.is_empty();
    if !missing_identity_part {
        config.client_cert = Some(client_cert.clone());
        config.client_key = Some(client_key.clone());
    }

    Ok(config)
}
async fn create_tls_client(
&self,
config: TlsClientConfig,
) -> Result<CachedTlsClient, crate::Error> {
let executor = TokioExecutor::new();
let mut root_store = RootCertStore::empty();
if !config.custom_ca_certs.is_empty() {
for ca_bundle in &config.custom_ca_certs {
merge_ca_bundle(&mut root_store, ca_bundle)?;
}
}
let builder = ClientConfig::builder().with_root_certificates(root_store);
let mut client_config = if let (Some(client_cert), Some(client_key)) =
(config.client_cert.as_ref(), config.client_key.as_ref())
{
let cert_chain = load_cert_chain_from_bytes(client_cert)?;
let private_key = load_private_key_from_bytes(client_key)?;
builder
.with_client_auth_cert(cert_chain, private_key)
.map_err(|err| {
crate::Error::Tls(format!("failed to configure client mTLS: {err}"))
})?
} else {
builder.with_no_client_auth()
};
client_config.alpn_protocols = config.alpn_protocols.clone();
if !config.verify_server_cert {
client_config
.dangerous()
.set_certificate_verifier(Arc::new(NoCertificateVerification));
}
let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(client_config)
.https_or_http()
.enable_http2()
.build();
let http_client = Client::builder(executor.clone())
.http2_only(false)
.build(https_connector.clone());
let http2_client = Client::builder(executor)
.http2_only(true)
.http2_adaptive_window(true)
.build(https_connector);
Ok(CachedTlsClient {
http_client,
http2_client,
config,
last_access: std::time::Instant::now(),
})
}
/// Produces the cache key for `config`: a 16-digit lowercase hex rendering
/// of a 64-bit `DefaultHasher` digest over every field of the configuration.
///
/// The digest is only 64 bits wide and not cryptographic; two different
/// configurations that happen to collide would share one cached client.
fn hash_config(&self, config: &TlsClientConfig) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    // A tuple hashes its elements one after another, so this feeds the
    // hasher the same fields in the same order as hashing them individually.
    let fingerprint = (
        &config.verify_server_cert,
        &config.custom_ca_certs,
        &config.client_cert,
        &config.client_key,
        &config.alpn_protocols,
        &config.handshake_timeout,
    );
    fingerprint.hash(&mut state);

    let digest = state.finish();
    format!("{:016x}", digest)
}
/// Makes room in `clients` before a new entry is inserted.
///
/// Entries older than the cache TTL are dropped first. If the map is still
/// at or above `max_cache_size`, the entries with the oldest `last_access`
/// are removed: enough to get back to the limit plus ten more, so the next
/// few inserts do not trigger another eviction pass.
fn evict_old_entries(&self, clients: &mut HashMap<String, CachedTlsClient>) {
    let now = std::time::Instant::now();
    let ttl = self.cache_ttl;
    clients.retain(|_, entry| now.duration_since(entry.last_access) < ttl);

    if clients.len() < self.max_cache_size {
        return;
    }

    let to_remove = clients.len() - self.max_cache_size + 10;

    // Snapshot keys with their timestamps so the map can be mutated while
    // walking the sorted list.
    let mut by_age: Vec<(String, std::time::Instant)> = Vec::with_capacity(clients.len());
    for (key, entry) in clients.iter() {
        by_age.push((key.clone(), entry.last_access));
    }
    by_age.sort_by_key(|candidate| candidate.1);

    for (key, _) in by_age.iter().take(to_remove) {
        clients.remove(key);
    }

    info!("Evicted {} old TLS client cache entries", to_remove);
}
/// Drops every cached client whose `last_access` is at least `cache_ttl`
/// old, and logs how many were removed when that number is non-zero.
///
/// # Panics
///
/// Panics if the cache mutex is poisoned.
pub fn cleanup_expired_entries(&self) {
    let ttl = self.cache_ttl;
    let mut clients = self.clients.lock().unwrap();
    let now = std::time::Instant::now();

    let before = clients.len();
    clients.retain(|_, entry| now.duration_since(entry.last_access) < ttl);
    let removed = before - clients.len();

    if removed == 0 {
        return;
    }
    info!("Cleaned up {} expired TLS client cache entries", removed);
}
/// Takes a snapshot of the cache: total entries, how many are still within
/// the TTL, how many have outlived it, and the configured limits.
///
/// # Panics
///
/// Panics if the cache mutex is poisoned.
pub fn get_cache_stats(&self) -> TlsClientCacheStats {
    let clients = self.clients.lock().unwrap();
    let now = std::time::Instant::now();

    let total_entries = clients.len();
    let active_entries = clients
        .values()
        .filter(|entry| now.duration_since(entry.last_access) < self.cache_ttl)
        .count();
    // Every entry is either active or expired, so the rest are expired.
    let expired_entries = total_entries - active_entries;

    TlsClientCacheStats {
        total_entries,
        active_entries,
        expired_entries,
        max_cache_size: self.max_cache_size,
        cache_ttl: self.cache_ttl,
    }
}
}
/// Point-in-time figures describing the TLS client cache.
#[derive(Clone, Debug)]
pub struct TlsClientCacheStats {
    /// Number of entries currently stored, expired or not.
    pub total_entries: usize,
    /// Entries whose age is below `cache_ttl`.
    pub active_entries: usize,
    /// Entries whose age has reached or passed `cache_ttl`.
    pub expired_entries: usize,
    /// Configured capacity of the cache.
    pub max_cache_size: usize,
    /// Configured time-to-live for entries.
    pub cache_ttl: Duration,
}

impl TlsClientCacheStats {
    /// Returns `total_entries` as a percentage of `max_cache_size`.
    ///
    /// Yields `0.0` when the capacity is zero, avoiding a division by zero.
    pub fn usage_percentage(&self) -> f64 {
        match self.max_cache_size {
            0 => 0.0,
            capacity => (self.total_entries as f64 / capacity as f64) * 100.0,
        }
    }
}
fn merge_ca_bundle(store: &mut RootCertStore, pem_data: &[u8]) -> Result<(), crate::Error> {
let mut reader = std::io::Cursor::new(pem_data);
let mut loaded_any = false;
for item in rustls_pemfile::read_all(&mut reader) {
let item =
item.map_err(|err| crate::Error::Tls(format!("Failed to parse CA bundle: {err}")))?;
if let rustls_pemfile::Item::X509Certificate(cert) = item {
store.add(cert).map_err(|err| {
crate::Error::Tls(format!("Failed to add CA certificate to root store: {err}"))
})?;
loaded_any = true;
}
}
if !loaded_any {
return Err(crate::Error::Tls(
"CA bundle does not contain any certificates".into(),
));
}
Ok(())
}
/// Server certificate verifier that accepts every certificate and every
/// handshake signature.
///
/// SECURITY: installing this verifier removes all protection against
/// man-in-the-middle attacks. It is only installed when
/// `TlsClientConfig::verify_server_cert` is `false`.
#[derive(Debug)]
struct NoCertificateVerification;

impl ServerCertVerifier for NoCertificateVerification {
    /// Accepts the presented certificate chain without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Lists the signature schemes offered to the server.
    ///
    /// Since no signature is ever checked, the list is kept broad so that
    /// servers which can only sign with RSA PKCS#1, ECDSA P-521 or Ed448 are
    /// not turned away during negotiation. The original six schemes stay
    /// first so the preference order for servers that already worked is
    /// unchanged.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        vec![
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ED25519,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::ED448,
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
        ]
    }
}
fn load_cert_chain_from_bytes(pem: &[u8]) -> Result<Vec<CertificateDer<'static>>, crate::Error> {
let mut reader = std::io::Cursor::new(pem);
let certs = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| crate::Error::Tls(format!("Failed to parse client certificate: {err}")))?;
if certs.is_empty() {
return Err(crate::Error::Tls(
"Client certificate chain is empty".into(),
));
}
Ok(certs)
}
fn load_private_key_from_bytes(pem: &[u8]) -> Result<PrivateKeyDer<'static>, crate::Error> {
if let Some(result) = {
let mut reader = std::io::Cursor::new(pem);
rustls_pemfile::pkcs8_private_keys(&mut reader).next()
} {
return result.map(PrivateKeyDer::from).map_err(|err| {
crate::Error::Tls(format!("Failed to parse PKCS#8 private key: {err}"))
});
}
if let Some(result) = {
let mut reader = std::io::Cursor::new(pem);
rustls_pemfile::rsa_private_keys(&mut reader).next()
} {
return result
.map(PrivateKeyDer::from)
.map_err(|err| crate::Error::Tls(format!("Failed to parse RSA private key: {err}")));
}
Err(crate::Error::Tls(
"Client private key not found in provided PEM".into(),
))
}