microsandbox_network/tls/
state.rs1use std::collections::HashMap;
4use std::num::NonZeroUsize;
5use std::path::{Path, PathBuf};
6use std::sync::{Arc, Mutex};
7
8use lru::LruCache;
9use microsandbox_utils::TLS_SUBDIR;
10use rustls::DigitallySignedStruct;
11use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
12use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
13use time::{Duration, OffsetDateTime};
14use tokio_rustls::TlsConnector;
15
16use super::ca::CertAuthority;
17use super::certgen::{self, DomainCert, DomainCertError};
18use super::config::TlsConfig;
19use crate::secrets::config::SecretsConfig;
20
21pub struct TlsState {
30 pub intercept_ca: CertAuthority,
32 cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
34 pub connector: TlsConnector,
36 scoped_upstream_connectors: Vec<ScopedUpstreamConnector>,
38 pub config: TlsConfig,
40 pub secrets: SecretsConfig,
42 bypass_patterns: Vec<DomainPattern>,
44}
45
/// A host-name pattern from configuration, normalized by `DomainPattern::new`
/// (ASCII-lowercased, trailing dots removed).
enum DomainPattern {
    /// Matches exactly one host.
    Exact(String),
    /// `*.suffix`: matches `suffix` itself and any host ending in `.suffix`.
    Wildcard {
        /// The pattern with its leading `*.` removed.
        suffix: String,
        /// `suffix` prefixed with `.`, precomputed for `ends_with` checks.
        dotted: String,
    },
}
54
55struct ScopedUpstreamConnector {
57 pattern: DomainPattern,
58 connector: TlsConnector,
59}
60
/// Scoped upstream settings merged per pattern before connectors are built.
struct ScopedUpstreamSettings {
    /// Pattern as first spelled in configuration; entries are grouped by its normalized form.
    pattern: String,
    /// Extra CA certificate files trusted for this pattern, in configuration order.
    ca_cert: Vec<PathBuf>,
    /// Verification override; `None` inherits the global `verify_upstream` setting.
    verify_upstream: Option<bool>,
}
67
/// Certificate verifier that accepts any upstream certificate and handshake signature.
///
/// `build_upstream_connector` installs it when upstream verification is disabled; connections
/// using it perform no authentication of the server.
#[derive(Debug)]
struct NoVerify;
72
73const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
76
77impl TlsState {
82 pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
89 let ca = load_or_generate_ca(&config);
90
91 let capacity =
92 NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
93 let cert_cache = Mutex::new(LruCache::new(capacity));
94
95 let connector = build_upstream_connector(&config, config.verify_upstream, &[]);
96 let scoped_upstream_connectors = build_scoped_upstream_connectors(&config);
97
98 let bypass_patterns = config
100 .bypass
101 .iter()
102 .map(|pattern| DomainPattern::new(pattern))
103 .collect();
104
105 Self {
106 intercept_ca: ca,
107 cert_cache,
108 connector,
109 scoped_upstream_connectors,
110 config,
111 secrets,
112 bypass_patterns,
113 }
114 }
115
116 pub fn get_or_generate_cert(&self, domain: &str) -> Result<Arc<DomainCert>, DomainCertError> {
118 let mut cache = match self.cert_cache.lock() {
119 Ok(cache) => cache,
120 Err(poisoned) => {
121 tracing::warn!("TLS certificate cache was poisoned; recovering");
122 poisoned.into_inner()
123 }
124 };
125 if let Some(cert) = cache.get(domain)
126 && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
127 {
128 return Ok(cert.clone());
129 }
130
131 let cert = Arc::new(certgen::generate_domain_cert(
132 domain,
133 &self.intercept_ca,
134 self.config.cache.validity_hours,
135 )?);
136 cache.put(domain.to_string(), cert.clone());
137 Ok(cert)
138 }
139
140 pub fn should_bypass(&self, sni: &str) -> bool {
142 let sni_lower = normalize_domain(sni);
143 self.bypass_patterns
144 .iter()
145 .any(|pattern| pattern.matches_normalized(&sni_lower))
146 }
147
148 pub fn upstream_connector_for(&self, sni: &str) -> &TlsConnector {
150 let sni_lower = normalize_domain(sni);
151 let mut best = None;
152
153 for scoped in &self.scoped_upstream_connectors {
154 if !scoped.pattern.matches_normalized(&sni_lower) {
155 continue;
156 }
157 let specificity = scoped.pattern.specificity();
158 if best
159 .map(|(_, best_specificity)| specificity > best_specificity)
160 .unwrap_or(true)
161 {
162 best = Some((scoped, specificity));
163 }
164 }
165
166 best.map_or(&self.connector, |(scoped, _)| &scoped.connector)
167 }
168
169 pub fn ca_cert_pem(&self) -> Vec<u8> {
171 self.intercept_ca.cert_pem()
172 }
173}
174
175impl DomainPattern {
176 fn new(pattern: &str) -> Self {
177 let lower = normalize_domain(pattern);
178 if let Some(suffix) = lower.strip_prefix("*.") {
179 let dotted = format!(".{suffix}");
180 DomainPattern::Wildcard {
181 suffix: suffix.to_string(),
182 dotted,
183 }
184 } else {
185 DomainPattern::Exact(lower)
186 }
187 }
188
189 fn matches_normalized(&self, sni_lower: &str) -> bool {
190 match self {
191 DomainPattern::Exact(exact) => sni_lower == exact,
192 DomainPattern::Wildcard { suffix, dotted } => {
193 sni_lower == suffix || sni_lower.ends_with(dotted.as_str())
194 }
195 }
196 }
197
198 fn specificity(&self) -> usize {
199 match self {
200 DomainPattern::Exact(exact) => exact.len() + 1,
201 DomainPattern::Wildcard { suffix, .. } => suffix.len(),
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::super::config::{ScopedUpstreamCaCert, ScopedVerifyUpstream};
    use super::*;

    use crate::secrets::config::SecretsConfig;

    /// Installs ring as the process-wide rustls provider. Tests share a process, so a repeat
    /// installation fails harmlessly and the error is ignored.
    fn install_ring_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Constructs a `TlsState` for `config` with default secrets settings.
    fn state_for(config: TlsConfig) -> TlsState {
        TlsState::new(config, SecretsConfig::default())
    }

    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        install_ring_provider();
        let state = state_for(TlsConfig::default());
        let initial = state.get_or_generate_cert("openrouter.ai").unwrap();
        let initial_expiry = initial.expires_at;

        // Swap the cached entry for a copy that expires well inside the refresh window.
        let nearly_expired = Arc::new(DomainCert {
            chain: initial.chain.clone(),
            key: initial.key.clone_key(),
            expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
            server_config: initial.server_config.clone(),
        });
        state
            .cert_cache
            .lock()
            .unwrap()
            .put("openrouter.ai".to_string(), nearly_expired);

        let refreshed = state.get_or_generate_cert("openrouter.ai").unwrap();
        assert!(refreshed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));
        assert!(refreshed.expires_at > initial_expiry - Duration::minutes(10));
    }

    #[test]
    fn invalid_domain_cert_request_does_not_poison_cache() {
        install_ring_provider();
        let state = state_for(TlsConfig::default());

        let invalid = state.get_or_generate_cert("snowman.☃");
        assert!(invalid.is_err());
        let valid = state.get_or_generate_cert("openrouter.ai");
        assert!(valid.is_ok());
    }

    #[test]
    fn default_ca_dir_uses_microsandbox_home_tls_subdir() {
        let home = PathBuf::from("isolated-msb-home");
        let expected = home.join(microsandbox_utils::TLS_SUBDIR);

        assert_eq!(default_ca_dir_from_home(&home), expected);
    }

    #[test]
    fn domain_patterns_match_exact_and_wildcard_hosts() {
        let exact = DomainPattern::new("api.internal.");
        let wildcard = DomainPattern::new("*.internal");

        let cases = [
            (&exact, "api.internal", true),
            (&exact, "other.api.internal", false),
            (&wildcard, "internal", true),
            (&wildcard, "api.internal", true),
            (&wildcard, "notinternal", false),
        ];
        for (pattern, host, expected) in cases {
            assert_eq!(pattern.matches_normalized(host), expected, "host: {host}");
        }
    }

    #[test]
    fn domain_patterns_score_exact_as_more_specific() {
        let exact_score = DomainPattern::new("api.internal").specificity();
        let wildcard_score = DomainPattern::new("*.internal").specificity();

        assert!(exact_score > wildcard_score);
    }

    #[test]
    fn scoped_upstream_settings_group_ca_and_verify_by_pattern() {
        let mut config = TlsConfig::default();
        let first_ca = ScopedUpstreamCaCert {
            pattern: "*.internal".to_string(),
            path: PathBuf::from("/tmp/one.pem"),
        };
        let second_ca = ScopedUpstreamCaCert {
            pattern: "*.internal.".to_string(),
            path: PathBuf::from("/tmp/two.pem"),
        };
        config.scoped_upstream_ca_cert.push(first_ca);
        config.scoped_upstream_ca_cert.push(second_ca);
        config.scoped_verify_upstream.push(ScopedVerifyUpstream {
            pattern: "*.internal".to_string(),
            verify: false,
        });

        let grouped = grouped_scoped_upstream_settings(&config);
        let [only] = grouped.as_slice() else {
            panic!("expected exactly one grouped pattern, got {}", grouped.len());
        };

        assert_eq!(only.pattern, "*.internal");
        assert_eq!(
            only.ca_cert,
            vec![PathBuf::from("/tmp/one.pem"), PathBuf::from("/tmp/two.pem")]
        );
        assert_eq!(only.verify_upstream, Some(false));
    }

    #[test]
    fn upstream_connector_for_selects_scoped_no_verify_connector() {
        install_ring_provider();
        let mut config = TlsConfig::default();
        config.scoped_verify_upstream.push(ScopedVerifyUpstream {
            pattern: "*.internal".to_string(),
            verify: false,
        });
        let state = state_for(config);

        let scoped = state.upstream_connector_for("api.internal");
        let unmatched = state.upstream_connector_for("api.example.com");

        assert!(!std::ptr::eq(scoped, &state.connector));
        assert!(std::ptr::eq(unmatched, &state.connector));
    }
}
320
321impl ServerCertVerifier for NoVerify {
326 fn verify_server_cert(
327 &self,
328 _end_entity: &CertificateDer<'_>,
329 _intermediates: &[CertificateDer<'_>],
330 _server_name: &ServerName<'_>,
331 _ocsp_response: &[u8],
332 _now: UnixTime,
333 ) -> Result<ServerCertVerified, rustls::Error> {
334 Ok(ServerCertVerified::assertion())
335 }
336
337 fn verify_tls12_signature(
338 &self,
339 _message: &[u8],
340 _cert: &CertificateDer<'_>,
341 _dss: &DigitallySignedStruct,
342 ) -> Result<HandshakeSignatureValid, rustls::Error> {
343 Ok(HandshakeSignatureValid::assertion())
344 }
345
346 fn verify_tls13_signature(
347 &self,
348 _message: &[u8],
349 _cert: &CertificateDer<'_>,
350 _dss: &DigitallySignedStruct,
351 ) -> Result<HandshakeSignatureValid, rustls::Error> {
352 Ok(HandshakeSignatureValid::assertion())
353 }
354
355 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
356 static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
357 std::sync::OnceLock::new();
358 SCHEMES
359 .get_or_init(|| {
360 rustls::crypto::ring::default_provider()
361 .signature_verification_algorithms
362 .supported_schemes()
363 })
364 .clone()
365 }
366}
367
368fn build_upstream_connector(
377 config: &TlsConfig,
378 verify_upstream: bool,
379 scoped_ca_cert: &[PathBuf],
380) -> TlsConnector {
381 let client_config = if verify_upstream {
382 let mut root_store = rustls::RootCertStore::empty();
383 let certs = rustls_native_certs::load_native_certs();
384 if !certs.errors.is_empty() {
385 tracing::warn!(
386 count = certs.errors.len(),
387 "errors loading native certificates"
388 );
389 }
390 let mut added = 0usize;
391 for cert in certs.certs {
392 if root_store.add(cert).is_ok() {
393 added += 1;
394 }
395 }
396 if added == 0 {
397 tracing::error!("no native root certificates loaded — all upstream TLS will fail");
398 }
399
400 load_upstream_ca_certificates(&mut root_store, &config.upstream_ca_cert);
401 load_upstream_ca_certificates(&mut root_store, scoped_ca_cert);
402
403 rustls::ClientConfig::builder()
404 .with_root_certificates(root_store)
405 .with_no_client_auth()
406 } else {
407 rustls::ClientConfig::builder()
408 .dangerous()
409 .with_custom_certificate_verifier(Arc::new(NoVerify))
410 .with_no_client_auth()
411 };
412
413 TlsConnector::from(Arc::new(client_config))
414}
415
416fn build_scoped_upstream_connectors(config: &TlsConfig) -> Vec<ScopedUpstreamConnector> {
418 grouped_scoped_upstream_settings(config)
419 .into_iter()
420 .filter_map(|settings| {
421 let verify_upstream = settings.verify_upstream.unwrap_or(config.verify_upstream);
422 if verify_upstream == config.verify_upstream && settings.ca_cert.is_empty() {
423 return None;
424 }
425
426 Some(ScopedUpstreamConnector {
427 pattern: DomainPattern::new(&settings.pattern),
428 connector: build_upstream_connector(config, verify_upstream, &settings.ca_cert),
429 })
430 })
431 .collect()
432}
433
434fn grouped_scoped_upstream_settings(config: &TlsConfig) -> Vec<ScopedUpstreamSettings> {
436 let mut grouped = Vec::<ScopedUpstreamSettings>::new();
437 let mut indexes = HashMap::<String, usize>::new();
438
439 for scoped in &config.scoped_upstream_ca_cert {
440 let index = scoped_settings_index(&mut grouped, &mut indexes, &scoped.pattern);
441 grouped[index].ca_cert.push(scoped.path.clone());
442 }
443
444 for scoped in &config.scoped_verify_upstream {
445 let index = scoped_settings_index(&mut grouped, &mut indexes, &scoped.pattern);
446 grouped[index].verify_upstream = Some(scoped.verify);
447 }
448
449 grouped
450}
451
452fn scoped_settings_index(
454 grouped: &mut Vec<ScopedUpstreamSettings>,
455 indexes: &mut HashMap<String, usize>,
456 pattern: &str,
457) -> usize {
458 let normalized = normalize_domain(pattern);
459 if let Some(index) = indexes.get(&normalized) {
460 return *index;
461 }
462
463 let index = grouped.len();
464 indexes.insert(normalized, index);
465 grouped.push(ScopedUpstreamSettings {
466 pattern: pattern.to_string(),
467 ca_cert: Vec::new(),
468 verify_upstream: None,
469 });
470 index
471}
472
473fn load_upstream_ca_certificates(root_store: &mut rustls::RootCertStore, paths: &[PathBuf]) {
475 for path in paths {
476 match std::fs::read(path) {
477 Ok(pem_data) => {
478 let mut extra_added = 0usize;
479 for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
480 if root_store.add(cert).is_ok() {
481 extra_added += 1;
482 }
483 }
484 tracing::info!(
485 path = %path.display(),
486 count = extra_added,
487 "loaded upstream CA certificates"
488 );
489 }
490 Err(e) => {
491 tracing::error!(
492 path = %path.display(),
493 error = %e,
494 "failed to read upstream CA certificate file"
495 );
496 }
497 }
498 }
499}
500
/// Canonicalizes a host name or pattern for comparison. It strips every trailing `.`, so the
/// fully-qualified `example.com.` equals `example.com`, then ASCII-lowercases the rest.
/// Non-ASCII characters are left unchanged.
fn normalize_domain(domain: &str) -> String {
    let without_root = domain.trim_end_matches('.');
    without_root.to_ascii_lowercase()
}
505
506fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
513 if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
515 tracing::warn!(
516 "incomplete CA config: both cert_path and key_path must be set together, ignoring"
517 );
518 }
519
520 if let (Some(cert_path), Some(key_path)) = (
522 &config.intercept_ca.cert_path,
523 &config.intercept_ca.key_path,
524 ) {
525 match (std::fs::read(cert_path), std::fs::read(key_path)) {
526 (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
527 Ok(ca) => {
528 tracing::info!("loaded user-provided CA from {:?}", cert_path);
529 return ca;
530 }
531 Err(e) => {
532 tracing::error!(
533 error = %e,
534 "failed to load user-provided CA, falling back to auto-generate"
535 );
536 }
537 },
538 _ => {
539 tracing::error!(
540 "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
541 cert_path,
542 key_path,
543 );
544 }
545 }
546 }
547
548 let default_dir = default_ca_dir();
550 let cert_path = default_dir.join("ca.crt");
551 let key_path = default_dir.join("ca.key");
552
553 if cert_path.exists()
554 && key_path.exists()
555 && let (Ok(cert_pem), Ok(key_pem)) = (std::fs::read(&cert_path), std::fs::read(&key_path))
556 && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
557 {
558 tracing::debug!("loaded persisted CA from {:?}", cert_path);
559 return ca;
560 }
561
562 let ca = CertAuthority::generate();
564 if let Err(e) = std::fs::create_dir_all(&default_dir) {
565 tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
566 } else {
567 if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
568 tracing::warn!(error = %e, "failed to persist CA certificate");
569 }
570 if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
571 tracing::warn!(error = %e, "failed to persist CA key");
572 } else {
573 tracing::info!("generated and persisted CA to {:?}", default_dir);
574 }
575 }
576 ca
577}
578
579fn default_ca_dir() -> PathBuf {
581 default_ca_dir_from_home(microsandbox_utils::resolve_home())
582}
583
584fn default_ca_dir_from_home(home: impl AsRef<Path>) -> PathBuf {
586 home.as_ref().join(TLS_SUBDIR)
587}
588
/// Writes private-key bytes `data` to `path`, creating the file or truncating an existing one.
///
/// On Unix the file ends up readable and writable by its owner only (`0o600`).
/// `OpenOptions::mode` applies only when the file is newly created, so permissions are also
/// reset explicitly. That covers a pre-existing key file left with broader access. On other
/// platforms the file is written with default permissions.
///
/// # Errors
///
/// Returns any I/O error from opening, re-permissioning, or writing the file.
fn write_key_file(path: &Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // Tighten permissions before any key material is written to the file.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}