Skip to main content

microsandbox_network/tls/
state.rs

1//! Shared TLS state: CA, certificate cache, and upstream connector.
2
3use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use tokio_rustls::TlsConnector;
11
12use super::ca::CertAuthority;
13use super::certgen::{self, DomainCert};
14use super::config::TlsConfig;
15use crate::secrets::config::SecretsConfig;
16
17//--------------------------------------------------------------------------------------------------
18// Types
19//--------------------------------------------------------------------------------------------------
20
21/// Shared TLS interception state.
22///
23/// Holds the CA, per-domain certificate cache, upstream TLS connector,
24/// and configuration. Shared across all TLS proxy tasks via `Arc`.
25pub struct TlsState {
26    /// Certificate authority for signing per-domain certs.
27    pub ca: CertAuthority,
28    /// LRU cache of generated domain certificates.
29    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
30    /// TLS connector for upstream (real server) connections.
31    pub connector: TlsConnector,
32    /// TLS configuration.
33    pub config: TlsConfig,
34    /// Secrets configuration for placeholder substitution.
35    pub secrets: SecretsConfig,
36    /// Pre-computed lowercased bypass patterns for efficient matching.
37    bypass_patterns: Vec<BypassPattern>,
38}
39
/// A bypass pattern normalised ahead of time, so that matching a connection's
/// SNI needs no per-pattern string processing.
enum BypassPattern {
    /// Matches exactly one domain; the stored name is already lowercased.
    Exact(String),
    /// Matches a wildcard pattern. `suffix` is the bare suffix, and `dotted`
    /// is the same suffix with a leading `.`, stored up front so no
    /// `format!` allocation is needed for each connection.
    Wildcard { suffix: String, dotted: String },
}
48
/// Permissive [`ServerCertVerifier`]: every server certificate and handshake
/// signature is accepted without any validation.
///
/// Installed on the upstream connector only when `verify_upstream` is `false`.
#[derive(Debug)]
struct NoVerify;
53
54//--------------------------------------------------------------------------------------------------
55// Methods
56//--------------------------------------------------------------------------------------------------
57
58impl TlsState {
59    /// Create TLS state from configuration.
60    ///
61    /// CA resolution order:
62    /// 1. User-provided paths (`config.ca.cert_path` + `config.ca.key_path`)
63    /// 2. Default persistence path (`~/.microsandbox/tls/ca.{crt,key}`)
64    /// 3. Auto-generate and persist to default path
65    pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
66        let ca = load_or_generate_ca(&config);
67
68        let capacity =
69            NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
70        let cert_cache = Mutex::new(LruCache::new(capacity));
71
72        let connector = build_upstream_connector(&config);
73
74        // Pre-compute lowercased bypass patterns to avoid per-connection allocations.
75        let bypass_patterns = config
76            .bypass
77            .iter()
78            .map(|pattern| {
79                let lower = pattern.to_lowercase();
80                if let Some(suffix) = lower.strip_prefix("*.") {
81                    let dotted = format!(".{suffix}");
82                    BypassPattern::Wildcard {
83                        suffix: suffix.to_string(),
84                        dotted,
85                    }
86                } else {
87                    BypassPattern::Exact(lower)
88                }
89            })
90            .collect();
91
92        Self {
93            ca,
94            cert_cache,
95            connector,
96            config,
97            secrets,
98            bypass_patterns,
99        }
100    }
101
102    /// Get or generate a certificate for the given domain.
103    pub fn get_or_generate_cert(&self, domain: &str) -> Arc<DomainCert> {
104        let mut cache = self.cert_cache.lock().unwrap();
105        if let Some(cert) = cache.get(domain) {
106            return cert.clone();
107        }
108
109        let cert = Arc::new(certgen::generate_domain_cert(
110            domain,
111            &self.ca,
112            self.config.cache.validity_hours,
113        ));
114        cache.put(domain.to_string(), cert.clone());
115        cert
116    }
117
118    /// Check if a domain should bypass TLS interception.
119    pub fn should_bypass(&self, sni: &str) -> bool {
120        let sni_lower = sni.to_lowercase();
121        self.bypass_patterns.iter().any(|pattern| match pattern {
122            BypassPattern::Exact(exact) => sni_lower == *exact,
123            BypassPattern::Wildcard { suffix, dotted } => {
124                sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
125            }
126        })
127    }
128
129    /// Get the CA certificate PEM bytes for guest installation.
130    pub fn ca_cert_pem(&self) -> Vec<u8> {
131        self.ca.cert_pem()
132    }
133}
134
135//--------------------------------------------------------------------------------------------------
136// Trait Implementations
137//--------------------------------------------------------------------------------------------------
138
139impl ServerCertVerifier for NoVerify {
140    fn verify_server_cert(
141        &self,
142        _end_entity: &CertificateDer<'_>,
143        _intermediates: &[CertificateDer<'_>],
144        _server_name: &ServerName<'_>,
145        _ocsp_response: &[u8],
146        _now: UnixTime,
147    ) -> Result<ServerCertVerified, rustls::Error> {
148        Ok(ServerCertVerified::assertion())
149    }
150
151    fn verify_tls12_signature(
152        &self,
153        _message: &[u8],
154        _cert: &CertificateDer<'_>,
155        _dss: &DigitallySignedStruct,
156    ) -> Result<HandshakeSignatureValid, rustls::Error> {
157        Ok(HandshakeSignatureValid::assertion())
158    }
159
160    fn verify_tls13_signature(
161        &self,
162        _message: &[u8],
163        _cert: &CertificateDer<'_>,
164        _dss: &DigitallySignedStruct,
165    ) -> Result<HandshakeSignatureValid, rustls::Error> {
166        Ok(HandshakeSignatureValid::assertion())
167    }
168
169    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
170        static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
171            std::sync::OnceLock::new();
172        SCHEMES
173            .get_or_init(|| {
174                rustls::crypto::ring::default_provider()
175                    .signature_verification_algorithms
176                    .supported_schemes()
177            })
178            .clone()
179    }
180}
181
182//--------------------------------------------------------------------------------------------------
183// Functions
184//--------------------------------------------------------------------------------------------------
185
186/// Build the upstream TLS connector based on configuration.
187///
188/// When `verify_upstream` is true, loads the system's native root certificates.
189/// When false, uses a permissive verifier that accepts all server certificates.
190fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
191    let client_config = if config.verify_upstream {
192        let mut root_store = rustls::RootCertStore::empty();
193        let certs = rustls_native_certs::load_native_certs();
194        if !certs.errors.is_empty() {
195            tracing::warn!(
196                count = certs.errors.len(),
197                "errors loading native certificates"
198            );
199        }
200        let mut added = 0usize;
201        for cert in certs.certs {
202            if root_store.add(cert).is_ok() {
203                added += 1;
204            }
205        }
206        if added == 0 {
207            tracing::error!("no native root certificates loaded — all upstream TLS will fail");
208        }
209        rustls::ClientConfig::builder()
210            .with_root_certificates(root_store)
211            .with_no_client_auth()
212    } else {
213        rustls::ClientConfig::builder()
214            .dangerous()
215            .with_custom_certificate_verifier(Arc::new(NoVerify))
216            .with_no_client_auth()
217    };
218
219    TlsConnector::from(Arc::new(client_config))
220}
221
222/// Load or generate a CA based on the TLS configuration.
223///
224/// Resolution order:
225/// 1. User-provided paths (`cert_path` + `key_path`)
226/// 2. Default persistence path (`~/.microsandbox/tls/ca.{crt,key}`)
227/// 3. Auto-generate and persist to default path
228fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
229    // Warn if only one of cert_path/key_path is set (likely a config error).
230    if config.ca.cert_path.is_some() != config.ca.key_path.is_some() {
231        tracing::warn!(
232            "incomplete CA config: both cert_path and key_path must be set together, ignoring"
233        );
234    }
235
236    // 1. Try user-provided paths.
237    if let (Some(cert_path), Some(key_path)) = (&config.ca.cert_path, &config.ca.key_path) {
238        match (std::fs::read(cert_path), std::fs::read(key_path)) {
239            (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
240                Ok(ca) => {
241                    tracing::info!("loaded user-provided CA from {:?}", cert_path);
242                    return ca;
243                }
244                Err(e) => {
245                    tracing::error!(
246                        error = %e,
247                        "failed to load user-provided CA, falling back to auto-generate"
248                    );
249                }
250            },
251            _ => {
252                tracing::error!(
253                    "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
254                    cert_path,
255                    key_path,
256                );
257            }
258        }
259    }
260
261    // 2. Try default persistence path.
262    if let Some(default_dir) = default_ca_dir() {
263        let cert_path = default_dir.join("ca.crt");
264        let key_path = default_dir.join("ca.key");
265
266        if cert_path.exists()
267            && key_path.exists()
268            && let (Ok(cert_pem), Ok(key_pem)) =
269                (std::fs::read(&cert_path), std::fs::read(&key_path))
270            && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
271        {
272            tracing::debug!("loaded persisted CA from {:?}", cert_path);
273            return ca;
274        }
275
276        // 3. Auto-generate and persist.
277        let ca = CertAuthority::generate();
278        if let Err(e) = std::fs::create_dir_all(&default_dir) {
279            tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
280        } else {
281            if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
282                tracing::warn!(error = %e, "failed to persist CA certificate");
283            }
284            if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
285                tracing::warn!(error = %e, "failed to persist CA key");
286            } else {
287                tracing::info!("generated and persisted CA to {:?}", default_dir);
288            }
289        }
290        return ca;
291    }
292
293    // Fallback: generate without persistence.
294    tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
295    CertAuthority::generate()
296}
297
298/// Default CA persistence directory: `~/.microsandbox/tls/`.
299fn default_ca_dir() -> Option<std::path::PathBuf> {
300    dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
301}
302
/// Write a private key file with restricted permissions (0o600).
///
/// On Unix the file is opened with mode 0o600 set at creation time, which
/// avoids the TOCTOU race of write-then-chmod where the file is briefly
/// world-readable. Because the creation mode is ignored when the file already
/// exists, the permissions are also set to 0o600 on the open handle before
/// any key material is written. On other platforms the data is written with
/// the platform's default permissions.
///
/// # Errors
///
/// Returns any I/O error from opening the file, changing its permissions, or
/// writing the data.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` only applies to newly created files. A pre-existing file
        // keeps its old (possibly wider) permissions, so tighten them here.
        // The file is already truncated and the key is not yet written, so
        // nothing sensitive is exposed in between.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}