blueprint_auth/
tls_client.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use blueprint_core::{debug, info};
6use hyper_rustls::HttpsConnector;
7use hyper_util::client::legacy::Client;
8use hyper_util::rt::TokioExecutor;
9use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
11use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
12use rustls_pemfile;
13
14use crate::db::RocksDb;
15use crate::models::ServiceModel;
16use crate::types::ServiceId;
17
/// TLS client configuration for outbound connections.
///
/// `Debug` is implemented by hand rather than derived so that the mTLS
/// private key bytes are never written to logs or panic messages; every
/// other field is rendered exactly as a derived implementation would.
#[derive(Clone)]
pub struct TlsClientConfig {
    /// Whether to verify server certificates
    pub verify_server_cert: bool,
    /// Custom CA certificates to trust (one PEM bundle per element)
    pub custom_ca_certs: Vec<Vec<u8>>,
    /// Client certificate for mTLS
    pub client_cert: Option<Vec<u8>>,
    /// Client private key for mTLS
    pub client_key: Option<Vec<u8>>,
    /// ALPN protocols to negotiate, in order of preference
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Timeout for TLS handshake
    pub handshake_timeout: Duration,
}

impl std::fmt::Debug for TlsClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsClientConfig")
            .field("verify_server_cert", &self.verify_server_cert)
            .field("custom_ca_certs", &self.custom_ca_certs)
            .field("client_cert", &self.client_cert)
            // Only reveal whether a key is present, never its contents.
            .field(
                "client_key",
                &self.client_key.as_ref().map(|_| "<redacted>"),
            )
            .field("alpn_protocols", &self.alpn_protocols)
            .field("handshake_timeout", &self.handshake_timeout)
            .finish()
    }
}

impl Default for TlsClientConfig {
    /// Secure defaults: server verification on, no extra CAs, no client
    /// identity, HTTP/2 preferred over HTTP/1.1, 10 second handshake timeout.
    fn default() -> Self {
        Self {
            verify_server_cert: true,
            custom_ca_certs: Vec::new(),
            client_cert: None,
            client_key: None,
            alpn_protocols: vec![
                b"h2".to_vec(),       // HTTP/2
                b"http/1.1".to_vec(), // HTTP/1.1
            ],
            handshake_timeout: Duration::from_secs(10),
        }
    }
}
50
51/// Cached TLS client with configuration
52#[derive(Clone, Debug)]
53pub struct CachedTlsClient {
54    /// HTTP/1.1 client with TLS support
55    pub http_client: Client<
56        HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
57        axum::body::Body,
58    >,
59    /// HTTP/2 client with TLS support
60    pub http2_client: Client<
61        HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
62        axum::body::Body,
63    >,
64    /// Configuration used to create this client
65    pub config: TlsClientConfig,
66    /// Last access timestamp for cache eviction
67    pub last_access: std::time::Instant,
68}
69
/// TLS client manager with caching and per-service configuration.
///
/// Clients are keyed by a hash of their [`TlsClientConfig`], so services
/// that resolve to the same configuration share one cache entry. The cache
/// lives behind an `Arc<Mutex<..>>`, so clones of the manager operate on the
/// same underlying map.
#[derive(Clone, Debug)]
pub struct TlsClientManager {
    /// Cache of TLS clients by configuration hash
    clients: Arc<Mutex<HashMap<String, CachedTlsClient>>>,
    /// Database for persistent storage; used to look up the `ServiceModel`
    /// for a requested service id
    db: RocksDb,
    /// Maximum cache size; reaching it triggers eviction before an insert
    max_cache_size: usize,
    /// Cache entry TTL, measured against each entry's `last_access`
    cache_ttl: Duration,
}
82
83impl TlsClientManager {
84    /// Create a new TLS client manager
85    pub fn new(db: RocksDb) -> Self {
86        Self {
87            clients: Arc::new(Mutex::new(HashMap::new())),
88            db,
89            max_cache_size: 100,
90            cache_ttl: Duration::from_secs(3600), // 1 hour
91        }
92    }
93
94    /// Get or create a TLS client for a service
95    pub async fn get_client_for_service(
96        &self,
97        service_id: ServiceId,
98    ) -> Result<CachedTlsClient, crate::Error> {
99        // Load service model
100        let service = ServiceModel::find_by_id(service_id, &self.db)?
101            .ok_or(crate::Error::ServiceNotFound(service_id))?;
102
103        // Get TLS configuration for service
104        let tls_config = self.get_service_tls_config(&service).await?;
105
106        // Generate configuration hash for caching
107        let config_hash = self.hash_config(&tls_config);
108
109        // Check cache first
110        {
111            let clients = self.clients.lock().unwrap();
112            if let Some(cached_client) = clients.get(&config_hash) {
113                // Check if cache entry is still valid
114                if cached_client.last_access.elapsed() < self.cache_ttl {
115                    debug!("Using cached TLS client for service {}", service_id);
116                    return Ok(cached_client.clone());
117                }
118            }
119        }
120
121        // Create new client
122        debug!("Creating new TLS client for service {}", service_id);
123        let client = self.create_tls_client(tls_config.clone()).await?;
124
125        // Update cache
126        {
127            let mut clients = self.clients.lock().unwrap();
128
129            // Evict old entries if cache is full
130            if clients.len() >= self.max_cache_size {
131                self.evict_old_entries(&mut clients);
132            }
133
134            clients.insert(config_hash.clone(), client.clone());
135        }
136
137        Ok(client)
138    }
139
140    /// Get TLS configuration for a service
141    async fn get_service_tls_config(
142        &self,
143        service: &ServiceModel,
144    ) -> Result<TlsClientConfig, crate::Error> {
145        let mut config = TlsClientConfig::default();
146
147        // Check if service has a TLS profile
148        if let Some(profile) = &service.tls_profile {
149            if profile.tls_enabled {
150                // Apply profile configuration
151                config.verify_server_cert = true; // Default to true for TLS-enabled services
152
153                // Load custom CA certificates if specified
154                if !profile.encrypted_upstream_ca_bundle.is_empty() {
155                    config
156                        .custom_ca_certs
157                        .push(profile.encrypted_upstream_ca_bundle.clone());
158                }
159
160                // Load client certificate for mTLS if specified
161                if !profile.encrypted_upstream_client_cert.is_empty()
162                    && !profile.encrypted_upstream_client_key.is_empty()
163                {
164                    config.client_cert = Some(profile.encrypted_upstream_client_cert.clone());
165                    config.client_key = Some(profile.encrypted_upstream_client_key.clone());
166                }
167            }
168        }
169
170        Ok(config)
171    }
172
173    /// Create a TLS client with the given configuration
174    async fn create_tls_client(
175        &self,
176        config: TlsClientConfig,
177    ) -> Result<CachedTlsClient, crate::Error> {
178        let executor = TokioExecutor::new();
179
180        let mut root_store = RootCertStore::empty();
181        if !config.custom_ca_certs.is_empty() {
182            for ca_bundle in &config.custom_ca_certs {
183                merge_ca_bundle(&mut root_store, ca_bundle)?;
184            }
185        }
186
187        let builder = ClientConfig::builder().with_root_certificates(root_store);
188
189        let mut client_config = if let (Some(client_cert), Some(client_key)) =
190            (config.client_cert.as_ref(), config.client_key.as_ref())
191        {
192            let cert_chain = load_cert_chain_from_bytes(client_cert)?;
193            let private_key = load_private_key_from_bytes(client_key)?;
194            builder
195                .with_client_auth_cert(cert_chain, private_key)
196                .map_err(|err| {
197                    crate::Error::Tls(format!("failed to configure client mTLS: {err}"))
198                })?
199        } else {
200            builder.with_no_client_auth()
201        };
202
203        client_config.alpn_protocols = config.alpn_protocols.clone();
204
205        if !config.verify_server_cert {
206            client_config
207                .dangerous()
208                .set_certificate_verifier(Arc::new(NoCertificateVerification));
209        }
210
211        let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
212            .with_tls_config(client_config)
213            .https_or_http()
214            .enable_http2()
215            .build();
216
217        // Build HTTP/1.1 client (for REST APIs)
218        let http_client = Client::builder(executor.clone())
219            .http2_only(false)
220            .build(https_connector.clone());
221
222        // Build HTTP/2 client (for gRPC)
223        let http2_client = Client::builder(executor)
224            .http2_only(true)
225            .http2_adaptive_window(true)
226            .build(https_connector);
227
228        Ok(CachedTlsClient {
229            http_client,
230            http2_client,
231            config,
232            last_access: std::time::Instant::now(),
233        })
234    }
235
236    /// Generate a hash for TLS configuration
237    fn hash_config(&self, config: &TlsClientConfig) -> String {
238        use std::collections::hash_map::DefaultHasher;
239        use std::hash::{Hash, Hasher};
240
241        let mut hasher = DefaultHasher::new();
242
243        config.verify_server_cert.hash(&mut hasher);
244        config.custom_ca_certs.hash(&mut hasher);
245        config.client_cert.hash(&mut hasher);
246        config.client_key.hash(&mut hasher);
247        config.alpn_protocols.hash(&mut hasher);
248        config.handshake_timeout.hash(&mut hasher);
249
250        format!("{:016x}", hasher.finish())
251    }
252
253    /// Evict old entries from the cache
254    fn evict_old_entries(&self, clients: &mut HashMap<String, CachedTlsClient>) {
255        let now = std::time::Instant::now();
256
257        // Remove expired entries
258        clients.retain(|_, client| now.duration_since(client.last_access) < self.cache_ttl);
259
260        // If still too many, remove the oldest entries
261        if clients.len() >= self.max_cache_size {
262            let to_remove = clients.len() - self.max_cache_size + 10; // Remove 10 extra to avoid frequent evictions
263
264            // Collect keys to remove first to avoid borrow issues
265            let mut keys_to_remove: Vec<(String, std::time::Instant)> = clients
266                .iter()
267                .map(|(key, client)| (key.clone(), client.last_access))
268                .collect();
269
270            // Sort by last access time
271            keys_to_remove.sort_by_key(|(_, last_access)| *last_access);
272
273            // Remove oldest entries
274            for (key, _) in keys_to_remove.iter().take(to_remove) {
275                clients.remove(key);
276            }
277
278            info!("Evicted {} old TLS client cache entries", to_remove);
279        }
280    }
281
282    /// Clean up expired cache entries
283    pub fn cleanup_expired_entries(&self) {
284        let mut clients = self.clients.lock().unwrap();
285        let now = std::time::Instant::now();
286        let initial_size = clients.len();
287
288        clients.retain(|_, client| now.duration_since(client.last_access) < self.cache_ttl);
289
290        let removed = initial_size - clients.len();
291        if removed > 0 {
292            info!("Cleaned up {} expired TLS client cache entries", removed);
293        }
294    }
295
296    /// Get cache statistics
297    pub fn get_cache_stats(&self) -> TlsClientCacheStats {
298        let clients = self.clients.lock().unwrap();
299        let now = std::time::Instant::now();
300
301        let mut active_count = 0;
302        let mut expired_count = 0;
303
304        for client in clients.values() {
305            if now.duration_since(client.last_access) < self.cache_ttl {
306                active_count += 1;
307            } else {
308                expired_count += 1;
309            }
310        }
311
312        TlsClientCacheStats {
313            total_entries: clients.len(),
314            active_entries: active_count,
315            expired_entries: expired_count,
316            max_cache_size: self.max_cache_size,
317            cache_ttl: self.cache_ttl,
318        }
319    }
320}
321
/// Point-in-time statistics about the TLS client cache.
#[derive(Debug, Clone)]
pub struct TlsClientCacheStats {
    /// Number of entries currently stored
    pub total_entries: usize,
    /// Entries still within the TTL
    pub active_entries: usize,
    /// Entries that have outlived the TTL but are still stored
    pub expired_entries: usize,
    /// Configured cache size limit
    pub max_cache_size: usize,
    /// Configured entry TTL
    pub cache_ttl: Duration,
}

impl TlsClientCacheStats {
    /// Cache fill level as a percentage of `max_cache_size`.
    ///
    /// Returns `0.0` when the limit is zero instead of dividing by zero.
    pub fn usage_percentage(&self) -> f64 {
        match self.max_cache_size {
            0 => 0.0,
            capacity => (self.total_entries as f64 / capacity as f64) * 100.0,
        }
    }
}
341
342fn merge_ca_bundle(store: &mut RootCertStore, pem_data: &[u8]) -> Result<(), crate::Error> {
343    let mut reader = std::io::Cursor::new(pem_data);
344    let mut loaded_any = false;
345
346    for item in rustls_pemfile::read_all(&mut reader) {
347        let item =
348            item.map_err(|err| crate::Error::Tls(format!("Failed to parse CA bundle: {err}")))?;
349        if let rustls_pemfile::Item::X509Certificate(cert) = item {
350            store.add(cert).map_err(|err| {
351                crate::Error::Tls(format!("Failed to add CA certificate to root store: {err}"))
352            })?;
353            loaded_any = true;
354        }
355    }
356
357    if !loaded_any {
358        return Err(crate::Error::Tls(
359            "CA bundle does not contain any certificates".into(),
360        ));
361    }
362
363    Ok(())
364}
365
366#[derive(Debug)]
367struct NoCertificateVerification;
368
369impl ServerCertVerifier for NoCertificateVerification {
370    fn verify_server_cert(
371        &self,
372        _end_entity: &CertificateDer<'_>,
373        _intermediates: &[CertificateDer<'_>],
374        _server_name: &ServerName,
375        _ocsp_response: &[u8],
376        _now: UnixTime,
377    ) -> Result<ServerCertVerified, rustls::Error> {
378        let _ = (
379            _end_entity,
380            _intermediates,
381            _server_name,
382            _ocsp_response,
383            _now,
384        );
385        Ok(ServerCertVerified::assertion())
386    }
387
388    fn verify_tls12_signature(
389        &self,
390        _message: &[u8],
391        _cert: &CertificateDer<'_>,
392        _dss: &DigitallySignedStruct,
393    ) -> Result<HandshakeSignatureValid, rustls::Error> {
394        Ok(HandshakeSignatureValid::assertion())
395    }
396
397    fn verify_tls13_signature(
398        &self,
399        _message: &[u8],
400        _cert: &CertificateDer<'_>,
401        _dss: &DigitallySignedStruct,
402    ) -> Result<HandshakeSignatureValid, rustls::Error> {
403        Ok(HandshakeSignatureValid::assertion())
404    }
405
406    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
407        vec![
408            SignatureScheme::RSA_PSS_SHA256,
409            SignatureScheme::RSA_PSS_SHA384,
410            SignatureScheme::RSA_PSS_SHA512,
411            SignatureScheme::ECDSA_NISTP256_SHA256,
412            SignatureScheme::ECDSA_NISTP384_SHA384,
413            SignatureScheme::ED25519,
414        ]
415    }
416}
417
418fn load_cert_chain_from_bytes(pem: &[u8]) -> Result<Vec<CertificateDer<'static>>, crate::Error> {
419    let mut reader = std::io::Cursor::new(pem);
420    let certs = rustls_pemfile::certs(&mut reader)
421        .collect::<Result<Vec<_>, _>>()
422        .map_err(|err| crate::Error::Tls(format!("Failed to parse client certificate: {err}")))?;
423    if certs.is_empty() {
424        return Err(crate::Error::Tls(
425            "Client certificate chain is empty".into(),
426        ));
427    }
428    Ok(certs)
429}
430
431fn load_private_key_from_bytes(pem: &[u8]) -> Result<PrivateKeyDer<'static>, crate::Error> {
432    if let Some(result) = {
433        let mut reader = std::io::Cursor::new(pem);
434        rustls_pemfile::pkcs8_private_keys(&mut reader).next()
435    } {
436        return result.map(PrivateKeyDer::from).map_err(|err| {
437            crate::Error::Tls(format!("Failed to parse PKCS#8 private key: {err}"))
438        });
439    }
440
441    if let Some(result) = {
442        let mut reader = std::io::Cursor::new(pem);
443        rustls_pemfile::rsa_private_keys(&mut reader).next()
444    } {
445        return result
446            .map(PrivateKeyDer::from)
447            .map_err(|err| crate::Error::Tls(format!("Failed to parse RSA private key: {err}")));
448    }
449
450    Err(crate::Error::Tls(
451        "Client private key not found in provided PEM".into(),
452    ))
453}