Skip to main content

microsandbox_network/tls/
state.rs

1//! Shared TLS state: CA, certificate cache, and upstream connector.
2
3use std::num::NonZeroUsize;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6
7use lru::LruCache;
8use microsandbox_utils::TLS_SUBDIR;
9use rustls::DigitallySignedStruct;
10use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
11use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
12use time::{Duration, OffsetDateTime};
13use tokio_rustls::TlsConnector;
14
15use super::ca::CertAuthority;
16use super::certgen::{self, DomainCert, DomainCertError};
17use super::config::TlsConfig;
18use crate::secrets::config::SecretsConfig;
19
20//--------------------------------------------------------------------------------------------------
21// Types
22//--------------------------------------------------------------------------------------------------
23
24/// Shared TLS interception state.
25///
26/// Holds the CA, per-domain certificate cache, upstream TLS connector,
27/// and configuration. Shared across all TLS proxy tasks via `Arc`.
28pub struct TlsState {
29    /// Interception CA for signing per-domain certs presented to the guest.
30    pub intercept_ca: CertAuthority,
31    /// LRU cache of generated domain certificates.
32    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
33    /// TLS connector for upstream (real server) connections.
34    pub connector: TlsConnector,
35    /// TLS configuration.
36    pub config: TlsConfig,
37    /// Secrets configuration for placeholder substitution.
38    pub secrets: SecretsConfig,
39    /// Pre-computed lowercased bypass patterns for efficient matching.
40    bypass_patterns: Vec<BypassPattern>,
41}
42
/// One entry of the TLS bypass list, normalized once at construction time.
enum BypassPattern {
    /// Matches a single host name exactly (stored lowercased).
    Exact(String),
    /// A `*.suffix` rule. Holds the bare `suffix` and its `.`-prefixed form
    /// `dotted`; the latter is precomputed so per-connection checks need no
    /// `format!` allocation.
    Wildcard { suffix: String, dotted: String },
}
51
/// Certificate verifier that trusts every upstream server unconditionally.
///
/// Installed for the upstream connector only when `verify_upstream` is
/// `false`; its [`ServerCertVerifier`] methods perform no chain, hostname,
/// expiry, or signature checks.
#[derive(Debug)]
struct NoVerify;
56
57/// Refresh cached leaf certs shortly before expiry so long-lived sandboxes
58/// do not start serving an already-expired intercept certificate.
59const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
60
61//--------------------------------------------------------------------------------------------------
62// Methods
63//--------------------------------------------------------------------------------------------------
64
65impl TlsState {
66    /// Create TLS state from configuration.
67    ///
68    /// CA resolution order:
69    /// 1. User-provided paths (`config.intercept_ca.cert_path` + `config.intercept_ca.key_path`)
70    /// 2. Microsandbox home TLS path (`$MSB_HOME/tls` or `~/.microsandbox/tls`)
71    /// 3. Auto-generate and persist to the microsandbox home TLS path
72    pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
73        let ca = load_or_generate_ca(&config);
74
75        let capacity =
76            NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
77        let cert_cache = Mutex::new(LruCache::new(capacity));
78
79        let connector = build_upstream_connector(&config);
80
81        // Pre-compute lowercased bypass patterns to avoid per-connection allocations.
82        let bypass_patterns = config
83            .bypass
84            .iter()
85            .map(|pattern| {
86                let lower = pattern.to_lowercase();
87                if let Some(suffix) = lower.strip_prefix("*.") {
88                    let dotted = format!(".{suffix}");
89                    BypassPattern::Wildcard {
90                        suffix: suffix.to_string(),
91                        dotted,
92                    }
93                } else {
94                    BypassPattern::Exact(lower)
95                }
96            })
97            .collect();
98
99        Self {
100            intercept_ca: ca,
101            cert_cache,
102            connector,
103            config,
104            secrets,
105            bypass_patterns,
106        }
107    }
108
109    /// Get or generate a certificate for the given domain.
110    pub fn get_or_generate_cert(&self, domain: &str) -> Result<Arc<DomainCert>, DomainCertError> {
111        let mut cache = match self.cert_cache.lock() {
112            Ok(cache) => cache,
113            Err(poisoned) => {
114                tracing::warn!("TLS certificate cache was poisoned; recovering");
115                poisoned.into_inner()
116            }
117        };
118        if let Some(cert) = cache.get(domain)
119            && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
120        {
121            return Ok(cert.clone());
122        }
123
124        let cert = Arc::new(certgen::generate_domain_cert(
125            domain,
126            &self.intercept_ca,
127            self.config.cache.validity_hours,
128        )?);
129        cache.put(domain.to_string(), cert.clone());
130        Ok(cert)
131    }
132
133    /// Check if a domain should bypass TLS interception.
134    pub fn should_bypass(&self, sni: &str) -> bool {
135        let sni_lower = sni.to_lowercase();
136        self.bypass_patterns.iter().any(|pattern| match pattern {
137            BypassPattern::Exact(exact) => sni_lower == *exact,
138            BypassPattern::Wildcard { suffix, dotted } => {
139                sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
140            }
141        })
142    }
143
144    /// Get the CA certificate PEM bytes for guest installation.
145    pub fn ca_cert_pem(&self) -> Vec<u8> {
146        self.intercept_ca.cert_pem()
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::SecretsConfig;

    /// Build a `TlsState` from default configuration.
    ///
    /// The ring crypto provider is installed first; an "already installed"
    /// result from an earlier test is deliberately ignored.
    fn default_state() -> TlsState {
        let _ = rustls::crypto::ring::default_provider().install_default();
        TlsState::new(TlsConfig::default(), SecretsConfig::default())
    }

    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        const DOMAIN: &str = "openrouter.ai";

        let state = default_state();
        let initial = state.get_or_generate_cert(DOMAIN).unwrap();
        let initial_expiry = initial.expires_at;

        // Swap in a copy of the cert whose expiry falls inside the refresh window.
        let almost_expired = Arc::new(DomainCert {
            chain: initial.chain.clone(),
            key: initial.key.clone_key(),
            expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
            server_config: initial.server_config.clone(),
        });
        state
            .cert_cache
            .lock()
            .unwrap()
            .put(DOMAIN.to_string(), almost_expired);

        let renewed = state.get_or_generate_cert(DOMAIN).unwrap();
        assert!(renewed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));
        assert!(renewed.expires_at > initial_expiry - Duration::minutes(10));
    }

    #[test]
    fn invalid_domain_cert_request_does_not_poison_cache() {
        let state = default_state();

        let rejected = state.get_or_generate_cert("snowman.☃");
        assert!(rejected.is_err());

        let accepted = state.get_or_generate_cert("openrouter.ai");
        assert!(accepted.is_ok());
    }

    #[test]
    fn default_ca_dir_uses_microsandbox_home_tls_subdir() {
        let home = PathBuf::from("isolated-msb-home");
        let expected = home.join(microsandbox_utils::TLS_SUBDIR);

        assert_eq!(default_ca_dir_from_home(&home), expected);
    }
}
197
198//--------------------------------------------------------------------------------------------------
199// Trait Implementations
200//--------------------------------------------------------------------------------------------------
201
202impl ServerCertVerifier for NoVerify {
203    fn verify_server_cert(
204        &self,
205        _end_entity: &CertificateDer<'_>,
206        _intermediates: &[CertificateDer<'_>],
207        _server_name: &ServerName<'_>,
208        _ocsp_response: &[u8],
209        _now: UnixTime,
210    ) -> Result<ServerCertVerified, rustls::Error> {
211        Ok(ServerCertVerified::assertion())
212    }
213
214    fn verify_tls12_signature(
215        &self,
216        _message: &[u8],
217        _cert: &CertificateDer<'_>,
218        _dss: &DigitallySignedStruct,
219    ) -> Result<HandshakeSignatureValid, rustls::Error> {
220        Ok(HandshakeSignatureValid::assertion())
221    }
222
223    fn verify_tls13_signature(
224        &self,
225        _message: &[u8],
226        _cert: &CertificateDer<'_>,
227        _dss: &DigitallySignedStruct,
228    ) -> Result<HandshakeSignatureValid, rustls::Error> {
229        Ok(HandshakeSignatureValid::assertion())
230    }
231
232    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
233        static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
234            std::sync::OnceLock::new();
235        SCHEMES
236            .get_or_init(|| {
237                rustls::crypto::ring::default_provider()
238                    .signature_verification_algorithms
239                    .supported_schemes()
240            })
241            .clone()
242    }
243}
244
245//--------------------------------------------------------------------------------------------------
246// Functions
247//--------------------------------------------------------------------------------------------------
248
249/// Build the upstream TLS connector based on configuration.
250///
251/// When `verify_upstream` is true, loads the system's native root certificates.
252/// When false, uses a permissive verifier that accepts all server certificates.
253fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
254    let client_config = if config.verify_upstream {
255        let mut root_store = rustls::RootCertStore::empty();
256        let certs = rustls_native_certs::load_native_certs();
257        if !certs.errors.is_empty() {
258            tracing::warn!(
259                count = certs.errors.len(),
260                "errors loading native certificates"
261            );
262        }
263        let mut added = 0usize;
264        for cert in certs.certs {
265            if root_store.add(cert).is_ok() {
266                added += 1;
267            }
268        }
269        if added == 0 {
270            tracing::error!("no native root certificates loaded — all upstream TLS will fail");
271        }
272
273        // Load extra CA certificates from user-provided PEM files.
274        for path in &config.upstream_ca_cert {
275            match std::fs::read(path) {
276                Ok(pem_data) => {
277                    let mut extra_added = 0usize;
278                    for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
279                        if root_store.add(cert).is_ok() {
280                            extra_added += 1;
281                        }
282                    }
283                    tracing::info!(
284                        path = %path.display(),
285                        count = extra_added,
286                        "loaded upstream CA certificates"
287                    );
288                }
289                Err(e) => {
290                    tracing::error!(
291                        path = %path.display(),
292                        error = %e,
293                        "failed to read upstream CA certificate file"
294                    );
295                }
296            }
297        }
298
299        rustls::ClientConfig::builder()
300            .with_root_certificates(root_store)
301            .with_no_client_auth()
302    } else {
303        rustls::ClientConfig::builder()
304            .dangerous()
305            .with_custom_certificate_verifier(Arc::new(NoVerify))
306            .with_no_client_auth()
307    };
308
309    TlsConnector::from(Arc::new(client_config))
310}
311
312/// Load or generate a CA based on the TLS configuration.
313///
314/// Resolution order:
315/// 1. User-provided paths (`cert_path` + `key_path`)
316/// 2. Microsandbox home TLS path (`$MSB_HOME/tls` or `~/.microsandbox/tls`)
317/// 3. Auto-generate and persist to the microsandbox home TLS path
318fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
319    // Warn if only one of cert_path/key_path is set (likely a config error).
320    if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
321        tracing::warn!(
322            "incomplete CA config: both cert_path and key_path must be set together, ignoring"
323        );
324    }
325
326    // 1. Try user-provided paths.
327    if let (Some(cert_path), Some(key_path)) = (
328        &config.intercept_ca.cert_path,
329        &config.intercept_ca.key_path,
330    ) {
331        match (std::fs::read(cert_path), std::fs::read(key_path)) {
332            (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
333                Ok(ca) => {
334                    tracing::info!("loaded user-provided CA from {:?}", cert_path);
335                    return ca;
336                }
337                Err(e) => {
338                    tracing::error!(
339                        error = %e,
340                        "failed to load user-provided CA, falling back to auto-generate"
341                    );
342                }
343            },
344            _ => {
345                tracing::error!(
346                    "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
347                    cert_path,
348                    key_path,
349                );
350            }
351        }
352    }
353
354    // 2. Try the same microsandbox home root used by cache/db/logs/metrics.
355    let default_dir = default_ca_dir();
356    let cert_path = default_dir.join("ca.crt");
357    let key_path = default_dir.join("ca.key");
358
359    if cert_path.exists()
360        && key_path.exists()
361        && let (Ok(cert_pem), Ok(key_pem)) = (std::fs::read(&cert_path), std::fs::read(&key_path))
362        && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
363    {
364        tracing::debug!("loaded persisted CA from {:?}", cert_path);
365        return ca;
366    }
367
368    // 3. Auto-generate and persist.
369    let ca = CertAuthority::generate();
370    if let Err(e) = std::fs::create_dir_all(&default_dir) {
371        tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
372    } else {
373        if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
374            tracing::warn!(error = %e, "failed to persist CA certificate");
375        }
376        if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
377            tracing::warn!(error = %e, "failed to persist CA key");
378        } else {
379            tracing::info!("generated and persisted CA to {:?}", default_dir);
380        }
381    }
382    ca
383}
384
385/// Default CA persistence directory under the resolved microsandbox home.
386fn default_ca_dir() -> PathBuf {
387    default_ca_dir_from_home(microsandbox_utils::resolve_home())
388}
389
390/// Build the CA directory from a known microsandbox home.
391fn default_ca_dir_from_home(home: impl AsRef<Path>) -> PathBuf {
392    home.as_ref().join(TLS_SUBDIR)
393}
394
/// Write a private key file that is readable and writable only by its owner.
///
/// On Unix, the file is opened with mode `0o600`, which avoids the
/// write-then-chmod race where a new file is briefly world-readable. That
/// mode only applies when the file is *created*, however. If the file already
/// exists (for example, a previously persisted key that failed to load), its
/// permissions are reset to `0o600` through the open handle *before* any key
/// bytes are written. This ensures the new key never sits under the old,
/// possibly looser, permissions.
///
/// On non-Unix platforms this falls back to a plain [`std::fs::write`].
///
/// # Errors
///
/// Returns any I/O error from opening, re-permissioning, or writing the file.
fn write_key_file(path: &Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` is ignored for pre-existing files; tighten them explicitly
        // (fchmod on the handle) while the file is still empty after truncate.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}