1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use blueprint_core::{debug, info};
6use hyper_rustls::HttpsConnector;
7use hyper_util::client::legacy::Client;
8use hyper_util::rt::TokioExecutor;
9use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
11use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
12use rustls_pemfile;
13
14use crate::db::RocksDb;
15use crate::models::ServiceModel;
16use crate::types::ServiceId;
17
/// TLS settings from which an outbound HTTP(S) client is built.
///
/// `Debug` is implemented by hand rather than derived so that the private key
/// bytes in `client_key` never end up in logs or panic messages.
#[derive(Clone)]
pub struct TlsClientConfig {
    /// When `false`, server certificate verification is disabled entirely.
    pub verify_server_cert: bool,
    /// PEM bundles whose certificates are added to the trusted root store.
    pub custom_ca_certs: Vec<Vec<u8>>,
    /// PEM-encoded client certificate chain used for mutual TLS, if any.
    pub client_cert: Option<Vec<u8>>,
    /// PEM-encoded private key belonging to `client_cert`, if any.
    pub client_key: Option<Vec<u8>>,
    /// ALPN protocol identifiers offered to the server, in preference order.
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Intended limit for the TLS handshake.
    ///
    /// NOTE(review): the client construction code visible in this file does
    /// not apply this value; confirm where (or whether) it is enforced.
    pub handshake_timeout: Duration,
}

impl std::fmt::Debug for TlsClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only presence/absence of the key is shown; the bytes are secret.
        let redacted_key = self.client_key.as_ref().map(|_| "<redacted>");

        f.debug_struct("TlsClientConfig")
            .field("verify_server_cert", &self.verify_server_cert)
            .field("custom_ca_certs", &self.custom_ca_certs)
            .field("client_cert", &self.client_cert)
            .field("client_key", &redacted_key)
            .field("alpn_protocols", &self.alpn_protocols)
            .field("handshake_timeout", &self.handshake_timeout)
            .finish()
    }
}
34
35impl Default for TlsClientConfig {
36 fn default() -> Self {
37 Self {
38 verify_server_cert: true,
39 custom_ca_certs: Vec::new(),
40 client_cert: None,
41 client_key: None,
42 alpn_protocols: vec![
43 b"h2".to_vec(), b"http/1.1".to_vec(), ],
46 handshake_timeout: Duration::from_secs(10),
47 }
48 }
49}
50
51#[derive(Clone, Debug)]
53pub struct CachedTlsClient {
54 pub http_client: Client<
56 HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
57 axum::body::Body,
58 >,
59 pub http2_client: Client<
61 HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
62 axum::body::Body,
63 >,
64 pub config: TlsClientConfig,
66 pub last_access: std::time::Instant,
68}
69
70#[derive(Clone, Debug)]
72pub struct TlsClientManager {
73 clients: Arc<Mutex<HashMap<String, CachedTlsClient>>>,
75 db: RocksDb,
77 max_cache_size: usize,
79 cache_ttl: Duration,
81}
82
83impl TlsClientManager {
84 pub fn new(db: RocksDb) -> Self {
86 Self {
87 clients: Arc::new(Mutex::new(HashMap::new())),
88 db,
89 max_cache_size: 100,
90 cache_ttl: Duration::from_secs(3600), }
92 }
93
94 pub async fn get_client_for_service(
96 &self,
97 service_id: ServiceId,
98 ) -> Result<CachedTlsClient, crate::Error> {
99 let service = ServiceModel::find_by_id(service_id, &self.db)?
101 .ok_or(crate::Error::ServiceNotFound(service_id))?;
102
103 let tls_config = self.get_service_tls_config(&service).await?;
105
106 let config_hash = self.hash_config(&tls_config);
108
109 {
111 let clients = self.clients.lock().unwrap();
112 if let Some(cached_client) = clients.get(&config_hash) {
113 if cached_client.last_access.elapsed() < self.cache_ttl {
115 debug!("Using cached TLS client for service {}", service_id);
116 return Ok(cached_client.clone());
117 }
118 }
119 }
120
121 debug!("Creating new TLS client for service {}", service_id);
123 let client = self.create_tls_client(tls_config.clone()).await?;
124
125 {
127 let mut clients = self.clients.lock().unwrap();
128
129 if clients.len() >= self.max_cache_size {
131 self.evict_old_entries(&mut clients);
132 }
133
134 clients.insert(config_hash.clone(), client.clone());
135 }
136
137 Ok(client)
138 }
139
140 async fn get_service_tls_config(
142 &self,
143 service: &ServiceModel,
144 ) -> Result<TlsClientConfig, crate::Error> {
145 let mut config = TlsClientConfig::default();
146
147 if let Some(profile) = &service.tls_profile {
149 if profile.tls_enabled {
150 config.verify_server_cert = true; if !profile.encrypted_upstream_ca_bundle.is_empty() {
155 config
156 .custom_ca_certs
157 .push(profile.encrypted_upstream_ca_bundle.clone());
158 }
159
160 if !profile.encrypted_upstream_client_cert.is_empty()
162 && !profile.encrypted_upstream_client_key.is_empty()
163 {
164 config.client_cert = Some(profile.encrypted_upstream_client_cert.clone());
165 config.client_key = Some(profile.encrypted_upstream_client_key.clone());
166 }
167 }
168 }
169
170 Ok(config)
171 }
172
173 async fn create_tls_client(
175 &self,
176 config: TlsClientConfig,
177 ) -> Result<CachedTlsClient, crate::Error> {
178 let executor = TokioExecutor::new();
179
180 let mut root_store = RootCertStore::empty();
181 if !config.custom_ca_certs.is_empty() {
182 for ca_bundle in &config.custom_ca_certs {
183 merge_ca_bundle(&mut root_store, ca_bundle)?;
184 }
185 }
186
187 let builder = ClientConfig::builder().with_root_certificates(root_store);
188
189 let mut client_config = if let (Some(client_cert), Some(client_key)) =
190 (config.client_cert.as_ref(), config.client_key.as_ref())
191 {
192 let cert_chain = load_cert_chain_from_bytes(client_cert)?;
193 let private_key = load_private_key_from_bytes(client_key)?;
194 builder
195 .with_client_auth_cert(cert_chain, private_key)
196 .map_err(|err| {
197 crate::Error::Tls(format!("failed to configure client mTLS: {err}"))
198 })?
199 } else {
200 builder.with_no_client_auth()
201 };
202
203 client_config.alpn_protocols = config.alpn_protocols.clone();
204
205 if !config.verify_server_cert {
206 client_config
207 .dangerous()
208 .set_certificate_verifier(Arc::new(NoCertificateVerification));
209 }
210
211 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
212 .with_tls_config(client_config)
213 .https_or_http()
214 .enable_http2()
215 .build();
216
217 let http_client = Client::builder(executor.clone())
219 .http2_only(false)
220 .build(https_connector.clone());
221
222 let http2_client = Client::builder(executor)
224 .http2_only(true)
225 .http2_adaptive_window(true)
226 .build(https_connector);
227
228 Ok(CachedTlsClient {
229 http_client,
230 http2_client,
231 config,
232 last_access: std::time::Instant::now(),
233 })
234 }
235
236 fn hash_config(&self, config: &TlsClientConfig) -> String {
238 use std::collections::hash_map::DefaultHasher;
239 use std::hash::{Hash, Hasher};
240
241 let mut hasher = DefaultHasher::new();
242
243 config.verify_server_cert.hash(&mut hasher);
244 config.custom_ca_certs.hash(&mut hasher);
245 config.client_cert.hash(&mut hasher);
246 config.client_key.hash(&mut hasher);
247 config.alpn_protocols.hash(&mut hasher);
248 config.handshake_timeout.hash(&mut hasher);
249
250 format!("{:016x}", hasher.finish())
251 }
252
253 fn evict_old_entries(&self, clients: &mut HashMap<String, CachedTlsClient>) {
255 let now = std::time::Instant::now();
256
257 clients.retain(|_, client| now.duration_since(client.last_access) < self.cache_ttl);
259
260 if clients.len() >= self.max_cache_size {
262 let to_remove = clients.len() - self.max_cache_size + 10; let mut keys_to_remove: Vec<(String, std::time::Instant)> = clients
266 .iter()
267 .map(|(key, client)| (key.clone(), client.last_access))
268 .collect();
269
270 keys_to_remove.sort_by_key(|(_, last_access)| *last_access);
272
273 for (key, _) in keys_to_remove.iter().take(to_remove) {
275 clients.remove(key);
276 }
277
278 info!("Evicted {} old TLS client cache entries", to_remove);
279 }
280 }
281
282 pub fn cleanup_expired_entries(&self) {
284 let mut clients = self.clients.lock().unwrap();
285 let now = std::time::Instant::now();
286 let initial_size = clients.len();
287
288 clients.retain(|_, client| now.duration_since(client.last_access) < self.cache_ttl);
289
290 let removed = initial_size - clients.len();
291 if removed > 0 {
292 info!("Cleaned up {} expired TLS client cache entries", removed);
293 }
294 }
295
296 pub fn get_cache_stats(&self) -> TlsClientCacheStats {
298 let clients = self.clients.lock().unwrap();
299 let now = std::time::Instant::now();
300
301 let mut active_count = 0;
302 let mut expired_count = 0;
303
304 for client in clients.values() {
305 if now.duration_since(client.last_access) < self.cache_ttl {
306 active_count += 1;
307 } else {
308 expired_count += 1;
309 }
310 }
311
312 TlsClientCacheStats {
313 total_entries: clients.len(),
314 active_entries: active_count,
315 expired_entries: expired_count,
316 max_cache_size: self.max_cache_size,
317 cache_ttl: self.cache_ttl,
318 }
319 }
320}
321
/// Point-in-time summary of the TLS client cache.
#[derive(Clone, Debug)]
pub struct TlsClientCacheStats {
    /// Number of entries currently stored, expired or not.
    pub total_entries: usize,
    /// Entries whose age is still below `cache_ttl`.
    pub active_entries: usize,
    /// Entries whose age has reached or exceeded `cache_ttl`.
    pub expired_entries: usize,
    /// Configured capacity of the cache.
    pub max_cache_size: usize,
    /// Configured time-to-live for cache entries.
    pub cache_ttl: Duration,
}
331
332impl TlsClientCacheStats {
333 pub fn usage_percentage(&self) -> f64 {
334 if self.max_cache_size == 0 {
335 0.0
336 } else {
337 (self.total_entries as f64 / self.max_cache_size as f64) * 100.0
338 }
339 }
340}
341
342fn merge_ca_bundle(store: &mut RootCertStore, pem_data: &[u8]) -> Result<(), crate::Error> {
343 let mut reader = std::io::Cursor::new(pem_data);
344 let mut loaded_any = false;
345
346 for item in rustls_pemfile::read_all(&mut reader) {
347 let item =
348 item.map_err(|err| crate::Error::Tls(format!("Failed to parse CA bundle: {err}")))?;
349 if let rustls_pemfile::Item::X509Certificate(cert) = item {
350 store.add(cert).map_err(|err| {
351 crate::Error::Tls(format!("Failed to add CA certificate to root store: {err}"))
352 })?;
353 loaded_any = true;
354 }
355 }
356
357 if !loaded_any {
358 return Err(crate::Error::Tls(
359 "CA bundle does not contain any certificates".into(),
360 ));
361 }
362
363 Ok(())
364}
365
366#[derive(Debug)]
367struct NoCertificateVerification;
368
369impl ServerCertVerifier for NoCertificateVerification {
370 fn verify_server_cert(
371 &self,
372 _end_entity: &CertificateDer<'_>,
373 _intermediates: &[CertificateDer<'_>],
374 _server_name: &ServerName,
375 _ocsp_response: &[u8],
376 _now: UnixTime,
377 ) -> Result<ServerCertVerified, rustls::Error> {
378 let _ = (
379 _end_entity,
380 _intermediates,
381 _server_name,
382 _ocsp_response,
383 _now,
384 );
385 Ok(ServerCertVerified::assertion())
386 }
387
388 fn verify_tls12_signature(
389 &self,
390 _message: &[u8],
391 _cert: &CertificateDer<'_>,
392 _dss: &DigitallySignedStruct,
393 ) -> Result<HandshakeSignatureValid, rustls::Error> {
394 Ok(HandshakeSignatureValid::assertion())
395 }
396
397 fn verify_tls13_signature(
398 &self,
399 _message: &[u8],
400 _cert: &CertificateDer<'_>,
401 _dss: &DigitallySignedStruct,
402 ) -> Result<HandshakeSignatureValid, rustls::Error> {
403 Ok(HandshakeSignatureValid::assertion())
404 }
405
406 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
407 vec![
408 SignatureScheme::RSA_PSS_SHA256,
409 SignatureScheme::RSA_PSS_SHA384,
410 SignatureScheme::RSA_PSS_SHA512,
411 SignatureScheme::ECDSA_NISTP256_SHA256,
412 SignatureScheme::ECDSA_NISTP384_SHA384,
413 SignatureScheme::ED25519,
414 ]
415 }
416}
417
418fn load_cert_chain_from_bytes(pem: &[u8]) -> Result<Vec<CertificateDer<'static>>, crate::Error> {
419 let mut reader = std::io::Cursor::new(pem);
420 let certs = rustls_pemfile::certs(&mut reader)
421 .collect::<Result<Vec<_>, _>>()
422 .map_err(|err| crate::Error::Tls(format!("Failed to parse client certificate: {err}")))?;
423 if certs.is_empty() {
424 return Err(crate::Error::Tls(
425 "Client certificate chain is empty".into(),
426 ));
427 }
428 Ok(certs)
429}
430
431fn load_private_key_from_bytes(pem: &[u8]) -> Result<PrivateKeyDer<'static>, crate::Error> {
432 if let Some(result) = {
433 let mut reader = std::io::Cursor::new(pem);
434 rustls_pemfile::pkcs8_private_keys(&mut reader).next()
435 } {
436 return result.map(PrivateKeyDer::from).map_err(|err| {
437 crate::Error::Tls(format!("Failed to parse PKCS#8 private key: {err}"))
438 });
439 }
440
441 if let Some(result) = {
442 let mut reader = std::io::Cursor::new(pem);
443 rustls_pemfile::rsa_private_keys(&mut reader).next()
444 } {
445 return result
446 .map(PrivateKeyDer::from)
447 .map_err(|err| crate::Error::Tls(format!("Failed to parse RSA private key: {err}")));
448 }
449
450 Err(crate::Error::Tls(
451 "Client private key not found in provided PEM".into(),
452 ))
453}