// microsandbox_network/tls/state.rs

1//! Shared TLS state: CA, certificate cache, and upstream connectors.
2
3use std::collections::HashMap;
4use std::num::NonZeroUsize;
5use std::path::{Path, PathBuf};
6use std::sync::{Arc, Mutex};
7
8use lru::LruCache;
9use microsandbox_utils::TLS_SUBDIR;
10use rustls::DigitallySignedStruct;
11use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
12use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
13use time::{Duration, OffsetDateTime};
14use tokio_rustls::TlsConnector;
15
16use super::ca::CertAuthority;
17use super::certgen::{self, DomainCert, DomainCertError};
18use super::config::TlsConfig;
19use crate::secrets::config::SecretsConfig;
20
21//--------------------------------------------------------------------------------------------------
22// Types
23//--------------------------------------------------------------------------------------------------
24
25/// Shared TLS interception state.
26///
27/// Holds the CA, per-domain certificate cache, upstream TLS connectors,
28/// and configuration. Shared across all TLS proxy tasks via `Arc`.
29pub struct TlsState {
30    /// Interception CA for signing per-domain certs presented to the guest.
31    pub intercept_ca: CertAuthority,
32    /// LRU cache of generated domain certificates.
33    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
34    /// Default TLS connector for upstream (real server) connections.
35    pub connector: TlsConnector,
36    /// Host-scoped TLS connectors for upstream connections.
37    scoped_upstream_connectors: Vec<ScopedUpstreamConnector>,
38    /// TLS configuration.
39    pub config: TlsConfig,
40    /// Secrets configuration for placeholder substitution.
41    pub secrets: SecretsConfig,
42    /// Pre-computed lowercased bypass patterns for efficient matching.
43    bypass_patterns: Vec<DomainPattern>,
44}
45
/// A pre-processed domain pattern (avoids per-connection allocations).
///
/// Values are plain normalized strings, so the type derives `Debug`, `Clone`,
/// `PartialEq` and `Eq`; this lets patterns be compared and printed directly
/// in assertions and log output.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DomainPattern {
    /// Exact domain match (lowercased, trailing dots removed).
    Exact(String),
    /// Wildcard suffix match. `suffix` is the bare suffix, `dotted` is `.suffix`
    /// (pre-computed to avoid per-connection `format!` allocations).
    Wildcard { suffix: String, dotted: String },
}
54
55/// An upstream connector selected only for matching server names.
56struct ScopedUpstreamConnector {
57    pattern: DomainPattern,
58    connector: TlsConnector,
59}
60
/// Effective upstream TLS settings for one host pattern.
///
/// Plain data, so it derives `Debug`, `Clone`, `PartialEq` and `Eq`; grouped
/// settings can then be compared as whole values in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ScopedUpstreamSettings {
    /// Host pattern exactly as it was first written in the configuration.
    pattern: String,
    /// Extra CA certificate files to trust for hosts matching `pattern`.
    ca_cert: Vec<PathBuf>,
    /// Per-pattern override of upstream verification; `None` means no
    /// override was configured for this pattern.
    verify_upstream: Option<bool>,
}
67
/// Certificate verifier that performs no validation at all.
///
/// Its [`ServerCertVerifier`] implementation accepts every server certificate
/// and handshake signature. It is installed on upstream connectors built with
/// `verify_upstream` set to `false`.
#[derive(Debug)]
struct NoVerify;
72
73/// Refresh cached leaf certs shortly before expiry so long-lived sandboxes
74/// do not start serving an already-expired intercept certificate.
75const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
76
77//--------------------------------------------------------------------------------------------------
78// Methods
79//--------------------------------------------------------------------------------------------------
80
81impl TlsState {
82    /// Create TLS state from configuration.
83    ///
84    /// CA resolution order:
85    /// 1. User-provided paths (`config.intercept_ca.cert_path` + `config.intercept_ca.key_path`)
86    /// 2. Microsandbox home TLS path (`$MSB_HOME/tls` or `~/.microsandbox/tls`)
87    /// 3. Auto-generate and persist to the microsandbox home TLS path
88    pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
89        let ca = load_or_generate_ca(&config);
90
91        let capacity =
92            NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
93        let cert_cache = Mutex::new(LruCache::new(capacity));
94
95        let connector = build_upstream_connector(&config, config.verify_upstream, &[]);
96        let scoped_upstream_connectors = build_scoped_upstream_connectors(&config);
97
98        // Pre-compute lowercased bypass patterns to avoid per-connection allocations.
99        let bypass_patterns = config
100            .bypass
101            .iter()
102            .map(|pattern| DomainPattern::new(pattern))
103            .collect();
104
105        Self {
106            intercept_ca: ca,
107            cert_cache,
108            connector,
109            scoped_upstream_connectors,
110            config,
111            secrets,
112            bypass_patterns,
113        }
114    }
115
116    /// Get or generate a certificate for the given domain.
117    pub fn get_or_generate_cert(&self, domain: &str) -> Result<Arc<DomainCert>, DomainCertError> {
118        let mut cache = match self.cert_cache.lock() {
119            Ok(cache) => cache,
120            Err(poisoned) => {
121                tracing::warn!("TLS certificate cache was poisoned; recovering");
122                poisoned.into_inner()
123            }
124        };
125        if let Some(cert) = cache.get(domain)
126            && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
127        {
128            return Ok(cert.clone());
129        }
130
131        let cert = Arc::new(certgen::generate_domain_cert(
132            domain,
133            &self.intercept_ca,
134            self.config.cache.validity_hours,
135        )?);
136        cache.put(domain.to_string(), cert.clone());
137        Ok(cert)
138    }
139
140    /// Check if a domain should bypass TLS interception.
141    pub fn should_bypass(&self, sni: &str) -> bool {
142        let sni_lower = normalize_domain(sni);
143        self.bypass_patterns
144            .iter()
145            .any(|pattern| pattern.matches_normalized(&sni_lower))
146    }
147
148    /// Select the upstream connector for the given server name.
149    pub fn upstream_connector_for(&self, sni: &str) -> &TlsConnector {
150        let sni_lower = normalize_domain(sni);
151        let mut best = None;
152
153        for scoped in &self.scoped_upstream_connectors {
154            if !scoped.pattern.matches_normalized(&sni_lower) {
155                continue;
156            }
157            let specificity = scoped.pattern.specificity();
158            if best
159                .map(|(_, best_specificity)| specificity > best_specificity)
160                .unwrap_or(true)
161            {
162                best = Some((scoped, specificity));
163            }
164        }
165
166        best.map_or(&self.connector, |(scoped, _)| &scoped.connector)
167    }
168
169    /// Get the CA certificate PEM bytes for guest installation.
170    pub fn ca_cert_pem(&self) -> Vec<u8> {
171        self.intercept_ca.cert_pem()
172    }
173}
174
175impl DomainPattern {
176    fn new(pattern: &str) -> Self {
177        let lower = normalize_domain(pattern);
178        if let Some(suffix) = lower.strip_prefix("*.") {
179            let dotted = format!(".{suffix}");
180            DomainPattern::Wildcard {
181                suffix: suffix.to_string(),
182                dotted,
183            }
184        } else {
185            DomainPattern::Exact(lower)
186        }
187    }
188
189    fn matches_normalized(&self, sni_lower: &str) -> bool {
190        match self {
191            DomainPattern::Exact(exact) => sni_lower == exact,
192            DomainPattern::Wildcard { suffix, dotted } => {
193                sni_lower == suffix || sni_lower.ends_with(dotted.as_str())
194            }
195        }
196    }
197
198    fn specificity(&self) -> usize {
199        match self {
200            DomainPattern::Exact(exact) => exact.len() + 1,
201            DomainPattern::Wildcard { suffix, .. } => suffix.len(),
202        }
203    }
204}
205
// Unit tests for certificate caching, CA directory resolution, domain pattern
// matching, and scoped upstream connector selection.
#[cfg(test)]
mod tests {
    use super::super::config::{ScopedUpstreamCaCert, ScopedVerifyUpstream};
    use super::*;

    use crate::secrets::config::SecretsConfig;

    // A cached certificate that expires inside the refresh window must be
    // replaced by a newly generated one on the next lookup.
    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        // Installing the provider may fail if another test already did so;
        // the result is deliberately ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let state = TlsState::new(TlsConfig::default(), SecretsConfig::default());
        let first = state.get_or_generate_cert("openrouter.ai").unwrap();
        let original_expires_at = first.expires_at;

        {
            // Overwrite the cache entry with a copy that expires in 30 seconds,
            // i.e. well inside the 5-minute refresh window.
            let mut cache = state.cert_cache.lock().unwrap();
            let stale = Arc::new(DomainCert {
                chain: first.chain.clone(),
                key: first.key.clone_key(),
                expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
                server_config: first.server_config.clone(),
            });
            cache.put("openrouter.ai".to_string(), stale);
        }

        let refreshed = state.get_or_generate_cert("openrouter.ai").unwrap();
        // NOTE(review): the 23-hour bound presumably relies on the default
        // `validity_hours` being at least 24 — confirm against `TlsConfig`.
        assert!(refreshed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));
        assert!(refreshed.expires_at > original_expires_at - Duration::minutes(10));
    }

    // A failed generation must leave the cache usable for later requests.
    #[test]
    fn invalid_domain_cert_request_does_not_poison_cache() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let state = TlsState::new(TlsConfig::default(), SecretsConfig::default());

        assert!(state.get_or_generate_cert("snowman.☃").is_err());
        assert!(state.get_or_generate_cert("openrouter.ai").is_ok());
    }

    // The CA directory is the TLS subdirectory of the given home.
    #[test]
    fn default_ca_dir_uses_microsandbox_home_tls_subdir() {
        let home = PathBuf::from("isolated-msb-home");

        assert_eq!(
            default_ca_dir_from_home(&home),
            home.join(microsandbox_utils::TLS_SUBDIR)
        );
    }

    // Exact patterns ignore a trailing dot; wildcards match the bare suffix
    // and dotted subdomains but not names that merely end with the suffix text.
    #[test]
    fn domain_patterns_match_exact_and_wildcard_hosts() {
        let exact = DomainPattern::new("api.internal.");
        assert!(exact.matches_normalized("api.internal"));
        assert!(!exact.matches_normalized("other.api.internal"));

        let wildcard = DomainPattern::new("*.internal");
        assert!(wildcard.matches_normalized("internal"));
        assert!(wildcard.matches_normalized("api.internal"));
        assert!(!wildcard.matches_normalized("notinternal"));
    }

    // An exact pattern must outrank a wildcard that matches the same host.
    #[test]
    fn domain_patterns_score_exact_as_more_specific() {
        let exact = DomainPattern::new("api.internal");
        let wildcard = DomainPattern::new("*.internal");

        assert!(exact.specificity() > wildcard.specificity());
    }

    // Entries whose patterns normalize to the same value are merged into one
    // group that keeps the first-seen spelling of the pattern.
    #[test]
    fn scoped_upstream_settings_group_ca_and_verify_by_pattern() {
        let mut config = TlsConfig::default();
        config.scoped_upstream_ca_cert.push(ScopedUpstreamCaCert {
            pattern: "*.internal".to_string(),
            path: PathBuf::from("/tmp/one.pem"),
        });
        // Same pattern with a trailing dot: must land in the same group.
        config.scoped_upstream_ca_cert.push(ScopedUpstreamCaCert {
            pattern: "*.internal.".to_string(),
            path: PathBuf::from("/tmp/two.pem"),
        });
        config.scoped_verify_upstream.push(ScopedVerifyUpstream {
            pattern: "*.internal".to_string(),
            verify: false,
        });

        let settings = grouped_scoped_upstream_settings(&config);

        assert_eq!(settings.len(), 1);
        assert_eq!(settings[0].pattern, "*.internal");
        assert_eq!(
            settings[0].ca_cert,
            vec![PathBuf::from("/tmp/one.pem"), PathBuf::from("/tmp/two.pem")]
        );
        assert_eq!(settings[0].verify_upstream, Some(false));
    }

    // Connector identity is compared by address: a matching host gets the
    // scoped connector, a non-matching host gets the default one.
    #[test]
    fn upstream_connector_for_selects_scoped_no_verify_connector() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let mut config = TlsConfig::default();
        config.scoped_verify_upstream.push(ScopedVerifyUpstream {
            pattern: "*.internal".to_string(),
            verify: false,
        });
        let state = TlsState::new(config, SecretsConfig::default());

        let default = &state.connector as *const TlsConnector;
        let scoped = state.upstream_connector_for("api.internal") as *const TlsConnector;
        let unmatched = state.upstream_connector_for("api.example.com") as *const TlsConnector;

        assert_ne!(default, scoped);
        assert_eq!(default, unmatched);
    }
}
320
321//--------------------------------------------------------------------------------------------------
322// Trait Implementations
323//--------------------------------------------------------------------------------------------------
324
325impl ServerCertVerifier for NoVerify {
326    fn verify_server_cert(
327        &self,
328        _end_entity: &CertificateDer<'_>,
329        _intermediates: &[CertificateDer<'_>],
330        _server_name: &ServerName<'_>,
331        _ocsp_response: &[u8],
332        _now: UnixTime,
333    ) -> Result<ServerCertVerified, rustls::Error> {
334        Ok(ServerCertVerified::assertion())
335    }
336
337    fn verify_tls12_signature(
338        &self,
339        _message: &[u8],
340        _cert: &CertificateDer<'_>,
341        _dss: &DigitallySignedStruct,
342    ) -> Result<HandshakeSignatureValid, rustls::Error> {
343        Ok(HandshakeSignatureValid::assertion())
344    }
345
346    fn verify_tls13_signature(
347        &self,
348        _message: &[u8],
349        _cert: &CertificateDer<'_>,
350        _dss: &DigitallySignedStruct,
351    ) -> Result<HandshakeSignatureValid, rustls::Error> {
352        Ok(HandshakeSignatureValid::assertion())
353    }
354
355    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
356        static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
357            std::sync::OnceLock::new();
358        SCHEMES
359            .get_or_init(|| {
360                rustls::crypto::ring::default_provider()
361                    .signature_verification_algorithms
362                    .supported_schemes()
363            })
364            .clone()
365    }
366}
367
368//--------------------------------------------------------------------------------------------------
369// Functions
370//--------------------------------------------------------------------------------------------------
371
372/// Build the upstream TLS connector based on configuration.
373///
374/// When `verify_upstream` is true, loads the system's native root certificates.
375/// When false, uses a permissive verifier that accepts all server certificates.
376fn build_upstream_connector(
377    config: &TlsConfig,
378    verify_upstream: bool,
379    scoped_ca_cert: &[PathBuf],
380) -> TlsConnector {
381    let client_config = if verify_upstream {
382        let mut root_store = rustls::RootCertStore::empty();
383        let certs = rustls_native_certs::load_native_certs();
384        if !certs.errors.is_empty() {
385            tracing::warn!(
386                count = certs.errors.len(),
387                "errors loading native certificates"
388            );
389        }
390        let mut added = 0usize;
391        for cert in certs.certs {
392            if root_store.add(cert).is_ok() {
393                added += 1;
394            }
395        }
396        if added == 0 {
397            tracing::error!("no native root certificates loaded — all upstream TLS will fail");
398        }
399
400        load_upstream_ca_certificates(&mut root_store, &config.upstream_ca_cert);
401        load_upstream_ca_certificates(&mut root_store, scoped_ca_cert);
402
403        rustls::ClientConfig::builder()
404            .with_root_certificates(root_store)
405            .with_no_client_auth()
406    } else {
407        rustls::ClientConfig::builder()
408            .dangerous()
409            .with_custom_certificate_verifier(Arc::new(NoVerify))
410            .with_no_client_auth()
411    };
412
413    TlsConnector::from(Arc::new(client_config))
414}
415
416/// Build host-scoped upstream TLS connectors from grouped scoped settings.
417fn build_scoped_upstream_connectors(config: &TlsConfig) -> Vec<ScopedUpstreamConnector> {
418    grouped_scoped_upstream_settings(config)
419        .into_iter()
420        .filter_map(|settings| {
421            let verify_upstream = settings.verify_upstream.unwrap_or(config.verify_upstream);
422            if verify_upstream == config.verify_upstream && settings.ca_cert.is_empty() {
423                return None;
424            }
425
426            Some(ScopedUpstreamConnector {
427                pattern: DomainPattern::new(&settings.pattern),
428                connector: build_upstream_connector(config, verify_upstream, &settings.ca_cert),
429            })
430        })
431        .collect()
432}
433
434/// Group repeated scoped upstream settings by host pattern while preserving first-seen order.
435fn grouped_scoped_upstream_settings(config: &TlsConfig) -> Vec<ScopedUpstreamSettings> {
436    let mut grouped = Vec::<ScopedUpstreamSettings>::new();
437    let mut indexes = HashMap::<String, usize>::new();
438
439    for scoped in &config.scoped_upstream_ca_cert {
440        let index = scoped_settings_index(&mut grouped, &mut indexes, &scoped.pattern);
441        grouped[index].ca_cert.push(scoped.path.clone());
442    }
443
444    for scoped in &config.scoped_verify_upstream {
445        let index = scoped_settings_index(&mut grouped, &mut indexes, &scoped.pattern);
446        grouped[index].verify_upstream = Some(scoped.verify);
447    }
448
449    grouped
450}
451
452/// Return the grouped settings index for `pattern`, creating it if needed.
453fn scoped_settings_index(
454    grouped: &mut Vec<ScopedUpstreamSettings>,
455    indexes: &mut HashMap<String, usize>,
456    pattern: &str,
457) -> usize {
458    let normalized = normalize_domain(pattern);
459    if let Some(index) = indexes.get(&normalized) {
460        return *index;
461    }
462
463    let index = grouped.len();
464    indexes.insert(normalized, index);
465    grouped.push(ScopedUpstreamSettings {
466        pattern: pattern.to_string(),
467        ca_cert: Vec::new(),
468        verify_upstream: None,
469    });
470    index
471}
472
473/// Load extra upstream CA certificates into the provided root store.
474fn load_upstream_ca_certificates(root_store: &mut rustls::RootCertStore, paths: &[PathBuf]) {
475    for path in paths {
476        match std::fs::read(path) {
477            Ok(pem_data) => {
478                let mut extra_added = 0usize;
479                for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
480                    if root_store.add(cert).is_ok() {
481                        extra_added += 1;
482                    }
483                }
484                tracing::info!(
485                    path = %path.display(),
486                    count = extra_added,
487                    "loaded upstream CA certificates"
488                );
489            }
490            Err(e) => {
491                tracing::error!(
492                    path = %path.display(),
493                    error = %e,
494                    "failed to read upstream CA certificate file"
495                );
496            }
497        }
498    }
499}
500
/// Canonical form used when comparing host patterns and SNI names.
///
/// Removes every trailing `.` and lowercases ASCII letters; non-ASCII
/// characters are left untouched.
fn normalize_domain(domain: &str) -> String {
    let mut normalized = domain.trim_end_matches('.').to_owned();
    normalized.make_ascii_lowercase();
    normalized
}
505
506/// Load or generate a CA based on the TLS configuration.
507///
508/// Resolution order:
509/// 1. User-provided paths (`cert_path` + `key_path`)
510/// 2. Microsandbox home TLS path (`$MSB_HOME/tls` or `~/.microsandbox/tls`)
511/// 3. Auto-generate and persist to the microsandbox home TLS path
512fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
513    // Warn if only one of cert_path/key_path is set (likely a config error).
514    if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
515        tracing::warn!(
516            "incomplete CA config: both cert_path and key_path must be set together, ignoring"
517        );
518    }
519
520    // 1. Try user-provided paths.
521    if let (Some(cert_path), Some(key_path)) = (
522        &config.intercept_ca.cert_path,
523        &config.intercept_ca.key_path,
524    ) {
525        match (std::fs::read(cert_path), std::fs::read(key_path)) {
526            (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
527                Ok(ca) => {
528                    tracing::info!("loaded user-provided CA from {:?}", cert_path);
529                    return ca;
530                }
531                Err(e) => {
532                    tracing::error!(
533                        error = %e,
534                        "failed to load user-provided CA, falling back to auto-generate"
535                    );
536                }
537            },
538            _ => {
539                tracing::error!(
540                    "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
541                    cert_path,
542                    key_path,
543                );
544            }
545        }
546    }
547
548    // 2. Try the same microsandbox home root used by cache/db/logs/metrics.
549    let default_dir = default_ca_dir();
550    let cert_path = default_dir.join("ca.crt");
551    let key_path = default_dir.join("ca.key");
552
553    if cert_path.exists()
554        && key_path.exists()
555        && let (Ok(cert_pem), Ok(key_pem)) = (std::fs::read(&cert_path), std::fs::read(&key_path))
556        && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
557    {
558        tracing::debug!("loaded persisted CA from {:?}", cert_path);
559        return ca;
560    }
561
562    // 3. Auto-generate and persist.
563    let ca = CertAuthority::generate();
564    if let Err(e) = std::fs::create_dir_all(&default_dir) {
565        tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
566    } else {
567        if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
568            tracing::warn!(error = %e, "failed to persist CA certificate");
569        }
570        if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
571            tracing::warn!(error = %e, "failed to persist CA key");
572        } else {
573            tracing::info!("generated and persisted CA to {:?}", default_dir);
574        }
575    }
576    ca
577}
578
579/// Default CA persistence directory under the resolved microsandbox home.
580fn default_ca_dir() -> PathBuf {
581    default_ca_dir_from_home(microsandbox_utils::resolve_home())
582}
583
584/// Build the CA directory from a known microsandbox home.
585fn default_ca_dir_from_home(home: impl AsRef<Path>) -> PathBuf {
586    home.as_ref().join(TLS_SUBDIR)
587}
588
/// Write a private key file with restricted permissions (0o600).
///
/// On Unix the file is opened with mode 0o600 so a newly created file is
/// never briefly world-readable (avoiding the write-then-chmod race). The
/// creation mode is ignored when the file already exists, so the permissions
/// are also set explicitly on the open handle before any key material is
/// written. An existing file with looser permissions is therefore tightened
/// as well.
///
/// On other platforms the data is written with the default permissions.
///
/// # Errors
///
/// Returns any I/O error from opening the file, changing its permissions, or
/// writing the data.
fn write_key_file(path: &Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        const KEY_FILE_MODE: u32 = 0o600;

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(KEY_FILE_MODE)
            .open(path)?;
        // `mode` only applies on creation; enforce it for pre-existing files
        // before writing the secret.
        file.set_permissions(std::fs::Permissions::from_mode(KEY_FILE_MODE))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}