Skip to main content

microsandbox_network/tls/
state.rs

1//! Shared TLS state: CA, certificate cache, and upstream connector.
2
3use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use time::{Duration, OffsetDateTime};
11use tokio_rustls::TlsConnector;
12
13use super::ca::CertAuthority;
14use super::certgen::{self, DomainCert, DomainCertError};
15use super::config::TlsConfig;
16use crate::secrets::config::SecretsConfig;
17
18//--------------------------------------------------------------------------------------------------
19// Types
20//--------------------------------------------------------------------------------------------------
21
/// Shared TLS interception state.
///
/// Holds the CA, per-domain certificate cache, upstream TLS connector,
/// and configuration. Shared across all TLS proxy tasks via `Arc`.
pub struct TlsState {
    /// Interception CA for signing per-domain certs presented to the guest.
    pub intercept_ca: CertAuthority,
    /// LRU cache of generated domain certificates, keyed by the domain string
    /// passed to `get_or_generate_cert`. Guarded by a mutex because the LRU
    /// bookkeeping mutates on every lookup.
    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
    /// TLS connector for upstream (real server) connections.
    pub connector: TlsConnector,
    /// TLS configuration.
    pub config: TlsConfig,
    /// Secrets configuration for placeholder substitution.
    pub secrets: SecretsConfig,
    /// Pre-computed lowercased bypass patterns for efficient matching, derived
    /// once from `config.bypass` in `TlsState::new`.
    bypass_patterns: Vec<BypassPattern>,
}
40
/// A pre-processed bypass pattern (avoids per-connection allocations).
///
/// Built from a `config.bypass` entry: entries starting with `*.` become
/// [`BypassPattern::Wildcard`], everything else becomes
/// [`BypassPattern::Exact`]. Both forms are stored lowercased.
enum BypassPattern {
    /// Exact domain match (lowercased).
    Exact(String),
    /// Wildcard suffix match. `suffix` is the bare suffix, `dotted` is `.suffix`
    /// (pre-computed to avoid per-connection `format!` allocations).
    ///
    /// Matches the bare suffix itself as well as any name ending in `.suffix`,
    /// so `*.example.com` covers `example.com` and `a.b.example.com`.
    Wildcard { suffix: String, dotted: String },
}
49
/// A [`ServerCertVerifier`] that accepts all server certificates without
/// validation. Used when `verify_upstream` is `false`.
///
/// SECURITY: with this verifier the upstream connection is encrypted but not
/// authenticated — any party able to intercept the upstream traffic can
/// impersonate the server. Only use it for deliberately untrusted/testing setups.
#[derive(Debug)]
struct NoVerify;
54
55/// Refresh cached leaf certs shortly before expiry so long-lived sandboxes
56/// do not start serving an already-expired intercept certificate.
57const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
58
59//--------------------------------------------------------------------------------------------------
60// Methods
61//--------------------------------------------------------------------------------------------------
62
63impl TlsState {
64    /// Create TLS state from configuration.
65    ///
66    /// CA resolution order:
67    /// 1. User-provided paths (`config.intercept_ca.cert_path` + `config.intercept_ca.key_path`)
68    /// 2. Default persistence path (`~/.microsandbox/tls/ca.{crt,key}`)
69    /// 3. Auto-generate and persist to default path
70    pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
71        let ca = load_or_generate_ca(&config);
72
73        let capacity =
74            NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
75        let cert_cache = Mutex::new(LruCache::new(capacity));
76
77        let connector = build_upstream_connector(&config);
78
79        // Pre-compute lowercased bypass patterns to avoid per-connection allocations.
80        let bypass_patterns = config
81            .bypass
82            .iter()
83            .map(|pattern| {
84                let lower = pattern.to_lowercase();
85                if let Some(suffix) = lower.strip_prefix("*.") {
86                    let dotted = format!(".{suffix}");
87                    BypassPattern::Wildcard {
88                        suffix: suffix.to_string(),
89                        dotted,
90                    }
91                } else {
92                    BypassPattern::Exact(lower)
93                }
94            })
95            .collect();
96
97        Self {
98            intercept_ca: ca,
99            cert_cache,
100            connector,
101            config,
102            secrets,
103            bypass_patterns,
104        }
105    }
106
107    /// Get or generate a certificate for the given domain.
108    pub fn get_or_generate_cert(&self, domain: &str) -> Result<Arc<DomainCert>, DomainCertError> {
109        let mut cache = match self.cert_cache.lock() {
110            Ok(cache) => cache,
111            Err(poisoned) => {
112                tracing::warn!("TLS certificate cache was poisoned; recovering");
113                poisoned.into_inner()
114            }
115        };
116        if let Some(cert) = cache.get(domain)
117            && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
118        {
119            return Ok(cert.clone());
120        }
121
122        let cert = Arc::new(certgen::generate_domain_cert(
123            domain,
124            &self.intercept_ca,
125            self.config.cache.validity_hours,
126        )?);
127        cache.put(domain.to_string(), cert.clone());
128        Ok(cert)
129    }
130
131    /// Check if a domain should bypass TLS interception.
132    pub fn should_bypass(&self, sni: &str) -> bool {
133        let sni_lower = sni.to_lowercase();
134        self.bypass_patterns.iter().any(|pattern| match pattern {
135            BypassPattern::Exact(exact) => sni_lower == *exact,
136            BypassPattern::Wildcard { suffix, dotted } => {
137                sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
138            }
139        })
140    }
141
142    /// Get the CA certificate PEM bytes for guest installation.
143    pub fn ca_cert_pem(&self) -> Vec<u8> {
144        self.intercept_ca.cert_pem()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::SecretsConfig;

    /// Build a `TlsState` from default configuration.
    ///
    /// The `ring` provider is installed first. The result is ignored because
    /// another test may already have installed it.
    fn default_state() -> TlsState {
        let _ = rustls::crypto::ring::default_provider().install_default();
        TlsState::new(TlsConfig::default(), SecretsConfig::default())
    }

    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        const DOMAIN: &str = "openrouter.ai";

        let state = default_state();
        let initial = state.get_or_generate_cert(DOMAIN).unwrap();
        let initial_expiry = initial.expires_at;

        // Replace the cache entry with an otherwise identical cert that expires
        // inside the refresh window, forcing the next lookup to regenerate.
        let nearly_expired = Arc::new(DomainCert {
            chain: initial.chain.clone(),
            key: initial.key.clone_key(),
            expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
            server_config: initial.server_config.clone(),
        });
        state
            .cert_cache
            .lock()
            .unwrap()
            .put(DOMAIN.to_string(), nearly_expired);

        let refreshed = state.get_or_generate_cert(DOMAIN).unwrap();
        assert!(refreshed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));
        assert!(refreshed.expires_at > initial_expiry - Duration::minutes(10));
    }

    #[test]
    fn invalid_domain_cert_request_does_not_poison_cache() {
        let state = default_state();

        // A failed generation must leave the cache usable for later requests.
        assert!(state.get_or_generate_cert("snowman.☃").is_err());
        assert!(state.get_or_generate_cert("openrouter.ai").is_ok());
    }
}
185
186//--------------------------------------------------------------------------------------------------
187// Trait Implementations
188//--------------------------------------------------------------------------------------------------
189
190impl ServerCertVerifier for NoVerify {
191    fn verify_server_cert(
192        &self,
193        _end_entity: &CertificateDer<'_>,
194        _intermediates: &[CertificateDer<'_>],
195        _server_name: &ServerName<'_>,
196        _ocsp_response: &[u8],
197        _now: UnixTime,
198    ) -> Result<ServerCertVerified, rustls::Error> {
199        Ok(ServerCertVerified::assertion())
200    }
201
202    fn verify_tls12_signature(
203        &self,
204        _message: &[u8],
205        _cert: &CertificateDer<'_>,
206        _dss: &DigitallySignedStruct,
207    ) -> Result<HandshakeSignatureValid, rustls::Error> {
208        Ok(HandshakeSignatureValid::assertion())
209    }
210
211    fn verify_tls13_signature(
212        &self,
213        _message: &[u8],
214        _cert: &CertificateDer<'_>,
215        _dss: &DigitallySignedStruct,
216    ) -> Result<HandshakeSignatureValid, rustls::Error> {
217        Ok(HandshakeSignatureValid::assertion())
218    }
219
220    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
221        static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
222            std::sync::OnceLock::new();
223        SCHEMES
224            .get_or_init(|| {
225                rustls::crypto::ring::default_provider()
226                    .signature_verification_algorithms
227                    .supported_schemes()
228            })
229            .clone()
230    }
231}
232
233//--------------------------------------------------------------------------------------------------
234// Functions
235//--------------------------------------------------------------------------------------------------
236
237/// Build the upstream TLS connector based on configuration.
238///
239/// When `verify_upstream` is true, loads the system's native root certificates.
240/// When false, uses a permissive verifier that accepts all server certificates.
241fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
242    let client_config = if config.verify_upstream {
243        let mut root_store = rustls::RootCertStore::empty();
244        let certs = rustls_native_certs::load_native_certs();
245        if !certs.errors.is_empty() {
246            tracing::warn!(
247                count = certs.errors.len(),
248                "errors loading native certificates"
249            );
250        }
251        let mut added = 0usize;
252        for cert in certs.certs {
253            if root_store.add(cert).is_ok() {
254                added += 1;
255            }
256        }
257        if added == 0 {
258            tracing::error!("no native root certificates loaded — all upstream TLS will fail");
259        }
260
261        // Load extra CA certificates from user-provided PEM files.
262        for path in &config.upstream_ca_cert {
263            match std::fs::read(path) {
264                Ok(pem_data) => {
265                    let mut extra_added = 0usize;
266                    for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
267                        if root_store.add(cert).is_ok() {
268                            extra_added += 1;
269                        }
270                    }
271                    tracing::info!(
272                        path = %path.display(),
273                        count = extra_added,
274                        "loaded upstream CA certificates"
275                    );
276                }
277                Err(e) => {
278                    tracing::error!(
279                        path = %path.display(),
280                        error = %e,
281                        "failed to read upstream CA certificate file"
282                    );
283                }
284            }
285        }
286
287        rustls::ClientConfig::builder()
288            .with_root_certificates(root_store)
289            .with_no_client_auth()
290    } else {
291        rustls::ClientConfig::builder()
292            .dangerous()
293            .with_custom_certificate_verifier(Arc::new(NoVerify))
294            .with_no_client_auth()
295    };
296
297    TlsConnector::from(Arc::new(client_config))
298}
299
300/// Load or generate a CA based on the TLS configuration.
301///
302/// Resolution order:
303/// 1. User-provided paths (`cert_path` + `key_path`)
304/// 2. Default persistence path (`~/.microsandbox/tls/ca.{crt,key}`)
305/// 3. Auto-generate and persist to default path
306fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
307    // Warn if only one of cert_path/key_path is set (likely a config error).
308    if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
309        tracing::warn!(
310            "incomplete CA config: both cert_path and key_path must be set together, ignoring"
311        );
312    }
313
314    // 1. Try user-provided paths.
315    if let (Some(cert_path), Some(key_path)) = (
316        &config.intercept_ca.cert_path,
317        &config.intercept_ca.key_path,
318    ) {
319        match (std::fs::read(cert_path), std::fs::read(key_path)) {
320            (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
321                Ok(ca) => {
322                    tracing::info!("loaded user-provided CA from {:?}", cert_path);
323                    return ca;
324                }
325                Err(e) => {
326                    tracing::error!(
327                        error = %e,
328                        "failed to load user-provided CA, falling back to auto-generate"
329                    );
330                }
331            },
332            _ => {
333                tracing::error!(
334                    "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
335                    cert_path,
336                    key_path,
337                );
338            }
339        }
340    }
341
342    // 2. Try default persistence path.
343    if let Some(default_dir) = default_ca_dir() {
344        let cert_path = default_dir.join("ca.crt");
345        let key_path = default_dir.join("ca.key");
346
347        if cert_path.exists()
348            && key_path.exists()
349            && let (Ok(cert_pem), Ok(key_pem)) =
350                (std::fs::read(&cert_path), std::fs::read(&key_path))
351            && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
352        {
353            tracing::debug!("loaded persisted CA from {:?}", cert_path);
354            return ca;
355        }
356
357        // 3. Auto-generate and persist.
358        let ca = CertAuthority::generate();
359        if let Err(e) = std::fs::create_dir_all(&default_dir) {
360            tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
361        } else {
362            if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
363                tracing::warn!(error = %e, "failed to persist CA certificate");
364            }
365            if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
366                tracing::warn!(error = %e, "failed to persist CA key");
367            } else {
368                tracing::info!("generated and persisted CA to {:?}", default_dir);
369            }
370        }
371        return ca;
372    }
373
374    // Fallback: generate without persistence.
375    tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
376    CertAuthority::generate()
377}
378
379/// Default CA persistence directory: `~/.microsandbox/tls/`.
380fn default_ca_dir() -> Option<std::path::PathBuf> {
381    dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
382}
383
/// Write a private key file that only its owner can read or write.
///
/// On Unix the file is opened with mode `0o600`, so a newly created file is
/// never group- or world-readable. Its permissions are then explicitly reset
/// to `0o600` *before* any key material is written. The reset is needed
/// because `OpenOptionsExt::mode` only applies when the file is created: a
/// pre-existing file (for example, one left behind with `0o644`) would
/// otherwise keep its looser permissions and receive the private key.
///
/// On other platforms this falls back to a plain [`std::fs::write`].
///
/// # Errors
///
/// Returns any I/O error from opening, re-permissioning, or writing the file.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` above is ignored for existing files, so tighten those too.
        // The file is already truncated, so no old data is exposed in between.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}