Skip to main content

microsandbox_network/tls/
state.rs

1//! Shared TLS state: CA, certificate cache, and upstream connector.
2
3use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use time::{Duration, OffsetDateTime};
11use tokio_rustls::TlsConnector;
12
13use super::ca::CertAuthority;
14use super::certgen::{self, DomainCert};
15use super::config::TlsConfig;
16use crate::secrets::config::SecretsConfig;
17
18//--------------------------------------------------------------------------------------------------
19// Types
20//--------------------------------------------------------------------------------------------------
21
22/// Shared TLS interception state.
23///
24/// Holds the CA, per-domain certificate cache, upstream TLS connector,
25/// and configuration. Shared across all TLS proxy tasks via `Arc`.
26pub struct TlsState {
27    /// Interception CA for signing per-domain certs presented to the guest.
28    pub intercept_ca: CertAuthority,
29    /// LRU cache of generated domain certificates.
30    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
31    /// TLS connector for upstream (real server) connections.
32    pub connector: TlsConnector,
33    /// TLS configuration.
34    pub config: TlsConfig,
35    /// Secrets configuration for placeholder substitution.
36    pub secrets: SecretsConfig,
37    /// Pre-computed lowercased bypass patterns for efficient matching.
38    bypass_patterns: Vec<BypassPattern>,
39}
40
/// A pre-processed bypass pattern (avoids per-connection allocations).
///
/// Patterns are stored lowercased. The derives make patterns printable in
/// logs and comparable in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BypassPattern {
    /// Exact domain match (lowercased).
    Exact(String),
    /// Wildcard suffix match for patterns written as `*.suffix`.
    ///
    /// `suffix` is the bare suffix and `dotted` is `.suffix`, pre-computed so
    /// matching does not need a per-connection `format!` allocation.
    Wildcard { suffix: String, dotted: String },
}
49
/// Permissive [`ServerCertVerifier`]: every upstream server certificate and
/// handshake signature is accepted without any validation.
///
/// Installed on the upstream connector only when `verify_upstream` is `false`.
#[derive(Debug)]
struct NoVerify;
54
55/// Refresh cached leaf certs shortly before expiry so long-lived sandboxes
56/// do not start serving an already-expired intercept certificate.
57const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
58
59//--------------------------------------------------------------------------------------------------
60// Methods
61//--------------------------------------------------------------------------------------------------
62
63impl TlsState {
64    /// Create TLS state from configuration.
65    ///
66    /// CA resolution order:
67    /// 1. User-provided paths (`config.intercept_ca.cert_path` + `config.intercept_ca.key_path`)
68    /// 2. Default persistence path (`~/.microsandbox/tls/ca.{crt,key}`)
69    /// 3. Auto-generate and persist to default path
70    pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
71        let ca = load_or_generate_ca(&config);
72
73        let capacity =
74            NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
75        let cert_cache = Mutex::new(LruCache::new(capacity));
76
77        let connector = build_upstream_connector(&config);
78
79        // Pre-compute lowercased bypass patterns to avoid per-connection allocations.
80        let bypass_patterns = config
81            .bypass
82            .iter()
83            .map(|pattern| {
84                let lower = pattern.to_lowercase();
85                if let Some(suffix) = lower.strip_prefix("*.") {
86                    let dotted = format!(".{suffix}");
87                    BypassPattern::Wildcard {
88                        suffix: suffix.to_string(),
89                        dotted,
90                    }
91                } else {
92                    BypassPattern::Exact(lower)
93                }
94            })
95            .collect();
96
97        Self {
98            intercept_ca: ca,
99            cert_cache,
100            connector,
101            config,
102            secrets,
103            bypass_patterns,
104        }
105    }
106
107    /// Get or generate a certificate for the given domain.
108    pub fn get_or_generate_cert(&self, domain: &str) -> Arc<DomainCert> {
109        let mut cache = self.cert_cache.lock().unwrap();
110        if let Some(cert) = cache.get(domain)
111            && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
112        {
113            return cert.clone();
114        }
115
116        let cert = Arc::new(certgen::generate_domain_cert(
117            domain,
118            &self.intercept_ca,
119            self.config.cache.validity_hours,
120        ));
121        cache.put(domain.to_string(), cert.clone());
122        cert
123    }
124
125    /// Check if a domain should bypass TLS interception.
126    pub fn should_bypass(&self, sni: &str) -> bool {
127        let sni_lower = sni.to_lowercase();
128        self.bypass_patterns.iter().any(|pattern| match pattern {
129            BypassPattern::Exact(exact) => sni_lower == *exact,
130            BypassPattern::Wildcard { suffix, dotted } => {
131                sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
132            }
133        })
134    }
135
136    /// Get the CA certificate PEM bytes for guest installation.
137    pub fn ca_cert_pem(&self) -> Vec<u8> {
138        self.intercept_ca.cert_pem()
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::SecretsConfig;

    const DOMAIN: &str = "openrouter.ai";

    /// Build a `TlsState` from default configuration.
    ///
    /// Installing the crypto provider is idempotent across tests; the result
    /// is ignored because a second install reports an error. Note that
    /// `TlsState::new` goes through `load_or_generate_ca`, which may read or
    /// create the default CA directory.
    fn default_state() -> TlsState {
        let _ = rustls::crypto::ring::default_provider().install_default();
        TlsState::new(TlsConfig::default(), SecretsConfig::default())
    }

    #[test]
    fn reuses_cached_domain_cert_while_fresh() {
        let state = default_state();

        let first = state.get_or_generate_cert(DOMAIN);
        let second = state.get_or_generate_cert(DOMAIN);

        // A fresh entry must be served from the cache, not regenerated.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        let state = default_state();
        let first = state.get_or_generate_cert(DOMAIN);
        let original_expires_at = first.expires_at;

        // Replace the cached entry with one that expires inside the refresh
        // window, keeping a handle so the replacement can be verified below.
        let stale = Arc::new(DomainCert {
            chain: first.chain.clone(),
            key: first.key.clone_key(),
            expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
            server_config: first.server_config.clone(),
        });
        {
            let mut cache = state.cert_cache.lock().unwrap();
            cache.put(DOMAIN.to_string(), Arc::clone(&stale));
        }

        let refreshed = state.get_or_generate_cert(DOMAIN);

        // The stale entry must not be handed out again.
        assert!(!Arc::ptr_eq(&stale, &refreshed));
        assert!(refreshed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));
        assert!(refreshed.expires_at > original_expires_at - Duration::minutes(10));

        // The regenerated certificate must have replaced the stale one in the cache.
        let again = state.get_or_generate_cert(DOMAIN);
        assert!(Arc::ptr_eq(&refreshed, &again));
    }
}
170
171//--------------------------------------------------------------------------------------------------
172// Trait Implementations
173//--------------------------------------------------------------------------------------------------
174
175impl ServerCertVerifier for NoVerify {
176    fn verify_server_cert(
177        &self,
178        _end_entity: &CertificateDer<'_>,
179        _intermediates: &[CertificateDer<'_>],
180        _server_name: &ServerName<'_>,
181        _ocsp_response: &[u8],
182        _now: UnixTime,
183    ) -> Result<ServerCertVerified, rustls::Error> {
184        Ok(ServerCertVerified::assertion())
185    }
186
187    fn verify_tls12_signature(
188        &self,
189        _message: &[u8],
190        _cert: &CertificateDer<'_>,
191        _dss: &DigitallySignedStruct,
192    ) -> Result<HandshakeSignatureValid, rustls::Error> {
193        Ok(HandshakeSignatureValid::assertion())
194    }
195
196    fn verify_tls13_signature(
197        &self,
198        _message: &[u8],
199        _cert: &CertificateDer<'_>,
200        _dss: &DigitallySignedStruct,
201    ) -> Result<HandshakeSignatureValid, rustls::Error> {
202        Ok(HandshakeSignatureValid::assertion())
203    }
204
205    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
206        static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
207            std::sync::OnceLock::new();
208        SCHEMES
209            .get_or_init(|| {
210                rustls::crypto::ring::default_provider()
211                    .signature_verification_algorithms
212                    .supported_schemes()
213            })
214            .clone()
215    }
216}
217
218//--------------------------------------------------------------------------------------------------
219// Functions
220//--------------------------------------------------------------------------------------------------
221
222/// Build the upstream TLS connector based on configuration.
223///
224/// When `verify_upstream` is true, loads the system's native root certificates.
225/// When false, uses a permissive verifier that accepts all server certificates.
226fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
227    let client_config = if config.verify_upstream {
228        let mut root_store = rustls::RootCertStore::empty();
229        let certs = rustls_native_certs::load_native_certs();
230        if !certs.errors.is_empty() {
231            tracing::warn!(
232                count = certs.errors.len(),
233                "errors loading native certificates"
234            );
235        }
236        let mut added = 0usize;
237        for cert in certs.certs {
238            if root_store.add(cert).is_ok() {
239                added += 1;
240            }
241        }
242        if added == 0 {
243            tracing::error!("no native root certificates loaded — all upstream TLS will fail");
244        }
245
246        // Load extra CA certificates from user-provided PEM files.
247        for path in &config.upstream_ca_cert {
248            match std::fs::read(path) {
249                Ok(pem_data) => {
250                    let mut extra_added = 0usize;
251                    for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
252                        if root_store.add(cert).is_ok() {
253                            extra_added += 1;
254                        }
255                    }
256                    tracing::info!(
257                        path = %path.display(),
258                        count = extra_added,
259                        "loaded upstream CA certificates"
260                    );
261                }
262                Err(e) => {
263                    tracing::error!(
264                        path = %path.display(),
265                        error = %e,
266                        "failed to read upstream CA certificate file"
267                    );
268                }
269            }
270        }
271
272        rustls::ClientConfig::builder()
273            .with_root_certificates(root_store)
274            .with_no_client_auth()
275    } else {
276        rustls::ClientConfig::builder()
277            .dangerous()
278            .with_custom_certificate_verifier(Arc::new(NoVerify))
279            .with_no_client_auth()
280    };
281
282    TlsConnector::from(Arc::new(client_config))
283}
284
285/// Load or generate a CA based on the TLS configuration.
286///
287/// Resolution order:
288/// 1. User-provided paths (`cert_path` + `key_path`)
289/// 2. Default persistence path (`~/.microsandbox/tls/ca.{crt,key}`)
290/// 3. Auto-generate and persist to default path
291fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
292    // Warn if only one of cert_path/key_path is set (likely a config error).
293    if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
294        tracing::warn!(
295            "incomplete CA config: both cert_path and key_path must be set together, ignoring"
296        );
297    }
298
299    // 1. Try user-provided paths.
300    if let (Some(cert_path), Some(key_path)) = (
301        &config.intercept_ca.cert_path,
302        &config.intercept_ca.key_path,
303    ) {
304        match (std::fs::read(cert_path), std::fs::read(key_path)) {
305            (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
306                Ok(ca) => {
307                    tracing::info!("loaded user-provided CA from {:?}", cert_path);
308                    return ca;
309                }
310                Err(e) => {
311                    tracing::error!(
312                        error = %e,
313                        "failed to load user-provided CA, falling back to auto-generate"
314                    );
315                }
316            },
317            _ => {
318                tracing::error!(
319                    "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
320                    cert_path,
321                    key_path,
322                );
323            }
324        }
325    }
326
327    // 2. Try default persistence path.
328    if let Some(default_dir) = default_ca_dir() {
329        let cert_path = default_dir.join("ca.crt");
330        let key_path = default_dir.join("ca.key");
331
332        if cert_path.exists()
333            && key_path.exists()
334            && let (Ok(cert_pem), Ok(key_pem)) =
335                (std::fs::read(&cert_path), std::fs::read(&key_path))
336            && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
337        {
338            tracing::debug!("loaded persisted CA from {:?}", cert_path);
339            return ca;
340        }
341
342        // 3. Auto-generate and persist.
343        let ca = CertAuthority::generate();
344        if let Err(e) = std::fs::create_dir_all(&default_dir) {
345            tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
346        } else {
347            if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
348                tracing::warn!(error = %e, "failed to persist CA certificate");
349            }
350            if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
351                tracing::warn!(error = %e, "failed to persist CA key");
352            } else {
353                tracing::info!("generated and persisted CA to {:?}", default_dir);
354            }
355        }
356        return ca;
357    }
358
359    // Fallback: generate without persistence.
360    tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
361    CertAuthority::generate()
362}
363
364/// Default CA persistence directory: `~/.microsandbox/tls/`.
365fn default_ca_dir() -> Option<std::path::PathBuf> {
366    dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
367}
368
/// Write a private key file with restricted permissions (0o600).
///
/// On Unix the file is opened with mode `0o600` at creation time, which avoids
/// the write-then-chmod race where a new file is briefly world-readable. The
/// creation mode is ignored when the file already exists, so the permissions
/// are also set explicitly on the open handle *before* any key material is
/// written; a pre-existing file with wider permissions is tightened rather
/// than silently reused.
///
/// On other platforms the data is written with the default permissions.
///
/// # Errors
///
/// Returns the underlying I/O error if the file cannot be opened, its
/// permissions cannot be changed, or the data cannot be written.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` only applies to newly created files; enforce it for existing ones.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}