microsandbox_network/tls/
state.rs1use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use tokio_rustls::TlsConnector;
11
12use super::ca::CertAuthority;
13use super::certgen::{self, DomainCert};
14use super::config::TlsConfig;
15use crate::secrets::config::SecretsConfig;
16
/// Shared state for TLS interception.
///
/// Holds the CA used to mint per-domain leaf certificates, a bounded cache of
/// those certificates, the connector used for upstream TLS, and the
/// configuration everything was built from.
pub struct TlsState {
    /// CA passed to `certgen::generate_domain_cert` when minting per-domain
    /// certificates; its certificate PEM is exposed via `ca_cert_pem`.
    pub intercept_ca: CertAuthority,

    /// LRU cache of generated leaf certificates, keyed by the domain string
    /// passed to `get_or_generate_cert`. Capacity comes from
    /// `config.cache.capacity` (0 falls back to 1000).
    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,

    /// Connector for outgoing (upstream) TLS connections. It verifies server
    /// certificates only when `config.verify_upstream` is set.
    pub connector: TlsConnector,

    /// The TLS configuration this state was built from.
    pub config: TlsConfig,

    /// Secrets configuration. It is not read anywhere in this file;
    /// presumably it is consumed by the interception layer (TODO confirm).
    pub secrets: SecretsConfig,

    /// Lower-cased patterns compiled from `config.bypass`; consulted by
    /// `should_bypass`.
    bypass_patterns: Vec<BypassPattern>,
}
39
/// A host-name pattern from `TlsConfig::bypass`, lower-cased when compiled.
enum BypassPattern {
    /// Matches a host name exactly equal to this string.
    Exact(String),
    /// Compiled from `*.suffix`. It matches `suffix` itself or any name ending
    /// in `dotted`. `suffix` has no leading dot, and `dotted` is `"." + suffix`,
    /// precomputed so matching needs no allocation.
    Wildcard { suffix: String, dotted: String },
}
48
/// Server-certificate verifier that accepts every certificate and signature.
///
/// Installed only when `TlsConfig::verify_upstream` is false. It disables all
/// authentication of upstream servers, so those connections can be
/// impersonated.
#[derive(Debug)]
struct NoVerify;
53
54impl TlsState {
59 pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
66 let ca = load_or_generate_ca(&config);
67
68 let capacity =
69 NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
70 let cert_cache = Mutex::new(LruCache::new(capacity));
71
72 let connector = build_upstream_connector(&config);
73
74 let bypass_patterns = config
76 .bypass
77 .iter()
78 .map(|pattern| {
79 let lower = pattern.to_lowercase();
80 if let Some(suffix) = lower.strip_prefix("*.") {
81 let dotted = format!(".{suffix}");
82 BypassPattern::Wildcard {
83 suffix: suffix.to_string(),
84 dotted,
85 }
86 } else {
87 BypassPattern::Exact(lower)
88 }
89 })
90 .collect();
91
92 Self {
93 intercept_ca: ca,
94 cert_cache,
95 connector,
96 config,
97 secrets,
98 bypass_patterns,
99 }
100 }
101
102 pub fn get_or_generate_cert(&self, domain: &str) -> Arc<DomainCert> {
104 let mut cache = self.cert_cache.lock().unwrap();
105 if let Some(cert) = cache.get(domain) {
106 return cert.clone();
107 }
108
109 let cert = Arc::new(certgen::generate_domain_cert(
110 domain,
111 &self.intercept_ca,
112 self.config.cache.validity_hours,
113 ));
114 cache.put(domain.to_string(), cert.clone());
115 cert
116 }
117
118 pub fn should_bypass(&self, sni: &str) -> bool {
120 let sni_lower = sni.to_lowercase();
121 self.bypass_patterns.iter().any(|pattern| match pattern {
122 BypassPattern::Exact(exact) => sni_lower == *exact,
123 BypassPattern::Wildcard { suffix, dotted } => {
124 sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
125 }
126 })
127 }
128
129 pub fn ca_cert_pem(&self) -> Vec<u8> {
131 self.intercept_ca.cert_pem()
132 }
133}
134
135impl ServerCertVerifier for NoVerify {
140 fn verify_server_cert(
141 &self,
142 _end_entity: &CertificateDer<'_>,
143 _intermediates: &[CertificateDer<'_>],
144 _server_name: &ServerName<'_>,
145 _ocsp_response: &[u8],
146 _now: UnixTime,
147 ) -> Result<ServerCertVerified, rustls::Error> {
148 Ok(ServerCertVerified::assertion())
149 }
150
151 fn verify_tls12_signature(
152 &self,
153 _message: &[u8],
154 _cert: &CertificateDer<'_>,
155 _dss: &DigitallySignedStruct,
156 ) -> Result<HandshakeSignatureValid, rustls::Error> {
157 Ok(HandshakeSignatureValid::assertion())
158 }
159
160 fn verify_tls13_signature(
161 &self,
162 _message: &[u8],
163 _cert: &CertificateDer<'_>,
164 _dss: &DigitallySignedStruct,
165 ) -> Result<HandshakeSignatureValid, rustls::Error> {
166 Ok(HandshakeSignatureValid::assertion())
167 }
168
169 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
170 static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
171 std::sync::OnceLock::new();
172 SCHEMES
173 .get_or_init(|| {
174 rustls::crypto::ring::default_provider()
175 .signature_verification_algorithms
176 .supported_schemes()
177 })
178 .clone()
179 }
180}
181
182fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
191 let client_config = if config.verify_upstream {
192 let mut root_store = rustls::RootCertStore::empty();
193 let certs = rustls_native_certs::load_native_certs();
194 if !certs.errors.is_empty() {
195 tracing::warn!(
196 count = certs.errors.len(),
197 "errors loading native certificates"
198 );
199 }
200 let mut added = 0usize;
201 for cert in certs.certs {
202 if root_store.add(cert).is_ok() {
203 added += 1;
204 }
205 }
206 if added == 0 {
207 tracing::error!("no native root certificates loaded — all upstream TLS will fail");
208 }
209
210 for path in &config.upstream_ca_cert {
212 match std::fs::read(path) {
213 Ok(pem_data) => {
214 let mut extra_added = 0usize;
215 for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
216 if root_store.add(cert).is_ok() {
217 extra_added += 1;
218 }
219 }
220 tracing::info!(
221 path = %path.display(),
222 count = extra_added,
223 "loaded upstream CA certificates"
224 );
225 }
226 Err(e) => {
227 tracing::error!(
228 path = %path.display(),
229 error = %e,
230 "failed to read upstream CA certificate file"
231 );
232 }
233 }
234 }
235
236 rustls::ClientConfig::builder()
237 .with_root_certificates(root_store)
238 .with_no_client_auth()
239 } else {
240 rustls::ClientConfig::builder()
241 .dangerous()
242 .with_custom_certificate_verifier(Arc::new(NoVerify))
243 .with_no_client_auth()
244 };
245
246 TlsConnector::from(Arc::new(client_config))
247}
248
249fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
256 if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
258 tracing::warn!(
259 "incomplete CA config: both cert_path and key_path must be set together, ignoring"
260 );
261 }
262
263 if let (Some(cert_path), Some(key_path)) = (
265 &config.intercept_ca.cert_path,
266 &config.intercept_ca.key_path,
267 ) {
268 match (std::fs::read(cert_path), std::fs::read(key_path)) {
269 (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
270 Ok(ca) => {
271 tracing::info!("loaded user-provided CA from {:?}", cert_path);
272 return ca;
273 }
274 Err(e) => {
275 tracing::error!(
276 error = %e,
277 "failed to load user-provided CA, falling back to auto-generate"
278 );
279 }
280 },
281 _ => {
282 tracing::error!(
283 "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
284 cert_path,
285 key_path,
286 );
287 }
288 }
289 }
290
291 if let Some(default_dir) = default_ca_dir() {
293 let cert_path = default_dir.join("ca.crt");
294 let key_path = default_dir.join("ca.key");
295
296 if cert_path.exists()
297 && key_path.exists()
298 && let (Ok(cert_pem), Ok(key_pem)) =
299 (std::fs::read(&cert_path), std::fs::read(&key_path))
300 && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
301 {
302 tracing::debug!("loaded persisted CA from {:?}", cert_path);
303 return ca;
304 }
305
306 let ca = CertAuthority::generate();
308 if let Err(e) = std::fs::create_dir_all(&default_dir) {
309 tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
310 } else {
311 if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
312 tracing::warn!(error = %e, "failed to persist CA certificate");
313 }
314 if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
315 tracing::warn!(error = %e, "failed to persist CA key");
316 } else {
317 tracing::info!("generated and persisted CA to {:?}", default_dir);
318 }
319 }
320 return ca;
321 }
322
323 tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
325 CertAuthority::generate()
326}
327
328fn default_ca_dir() -> Option<std::path::PathBuf> {
330 dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
331}
332
/// Writes private-key material to `path`, replacing any existing contents.
///
/// On Unix the file ends up owner read/write only (`0o600`). `OpenOptionsExt::mode`
/// only applies when the file is newly *created*. A pre-existing key file would
/// otherwise keep its old, possibly world-readable, permissions. The mode is
/// therefore also set explicitly, after the file is truncated and before any
/// key bytes are written.
///
/// On other platforms the file is written with default permissions.
///
/// # Errors
///
/// Returns any I/O error from opening, re-permissioning, or writing the file.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // The file is already truncated here, so no key bytes are ever exposed
        // under the old permissions.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}