microsandbox_network/tls/
state.rs1use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use time::{Duration, OffsetDateTime};
11use tokio_rustls::TlsConnector;
12
13use super::ca::CertAuthority;
14use super::certgen::{self, DomainCert, DomainCertError};
15use super::config::TlsConfig;
16use crate::secrets::config::SecretsConfig;
17
18pub struct TlsState {
27 pub intercept_ca: CertAuthority,
29 cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
31 pub connector: TlsConnector,
33 pub config: TlsConfig,
35 pub secrets: SecretsConfig,
37 bypass_patterns: Vec<BypassPattern>,
39}
40
/// One compiled entry from the TLS bypass list, stored in lowercase.
enum BypassPattern {
    /// Matches a single host name exactly.
    Exact(String),

    /// Built from a `*.suffix` entry: matches `suffix` itself or any name ending in `.suffix`.
    Wildcard {
        /// The entry with its leading `*.` removed.
        suffix: String,
        /// `suffix` prefixed with `.`, precomputed for `ends_with` checks.
        dotted: String,
    },
}
49
/// Upstream certificate verifier that accepts every certificate and signature.
///
/// Used only when upstream verification is disabled in the TLS configuration.
#[derive(Debug)]
struct NoVerify;
54
55const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
58
59impl TlsState {
64 pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
71 let ca = load_or_generate_ca(&config);
72
73 let capacity =
74 NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
75 let cert_cache = Mutex::new(LruCache::new(capacity));
76
77 let connector = build_upstream_connector(&config);
78
79 let bypass_patterns = config
81 .bypass
82 .iter()
83 .map(|pattern| {
84 let lower = pattern.to_lowercase();
85 if let Some(suffix) = lower.strip_prefix("*.") {
86 let dotted = format!(".{suffix}");
87 BypassPattern::Wildcard {
88 suffix: suffix.to_string(),
89 dotted,
90 }
91 } else {
92 BypassPattern::Exact(lower)
93 }
94 })
95 .collect();
96
97 Self {
98 intercept_ca: ca,
99 cert_cache,
100 connector,
101 config,
102 secrets,
103 bypass_patterns,
104 }
105 }
106
107 pub fn get_or_generate_cert(&self, domain: &str) -> Result<Arc<DomainCert>, DomainCertError> {
109 let mut cache = match self.cert_cache.lock() {
110 Ok(cache) => cache,
111 Err(poisoned) => {
112 tracing::warn!("TLS certificate cache was poisoned; recovering");
113 poisoned.into_inner()
114 }
115 };
116 if let Some(cert) = cache.get(domain)
117 && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
118 {
119 return Ok(cert.clone());
120 }
121
122 let cert = Arc::new(certgen::generate_domain_cert(
123 domain,
124 &self.intercept_ca,
125 self.config.cache.validity_hours,
126 )?);
127 cache.put(domain.to_string(), cert.clone());
128 Ok(cert)
129 }
130
131 pub fn should_bypass(&self, sni: &str) -> bool {
133 let sni_lower = sni.to_lowercase();
134 self.bypass_patterns.iter().any(|pattern| match pattern {
135 BypassPattern::Exact(exact) => sni_lower == *exact,
136 BypassPattern::Wildcard { suffix, dotted } => {
137 sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
138 }
139 })
140 }
141
142 pub fn ca_cert_pem(&self) -> Vec<u8> {
144 self.intercept_ca.cert_pem()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::SecretsConfig;

    /// Installs ring as the process-wide crypto provider and builds state from default
    /// configuration. A provider left installed by an earlier test is acceptable.
    fn default_state() -> TlsState {
        let _ = rustls::crypto::ring::default_provider().install_default();
        TlsState::new(TlsConfig::default(), SecretsConfig::default())
    }

    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        const DOMAIN: &str = "openrouter.ai";
        let state = default_state();

        let first = state.get_or_generate_cert(DOMAIN).unwrap();
        let original_expires_at = first.expires_at;

        // Swap the cached entry for a copy that expires inside the refresh window.
        let near_expiry = Arc::new(DomainCert {
            chain: first.chain.clone(),
            key: first.key.clone_key(),
            expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
            server_config: first.server_config.clone(),
        });
        state
            .cert_cache
            .lock()
            .unwrap()
            .put(DOMAIN.to_string(), near_expiry);

        let refreshed = state.get_or_generate_cert(DOMAIN).unwrap();
        assert!(refreshed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));
        assert!(refreshed.expires_at > original_expires_at - Duration::minutes(10));
    }

    #[test]
    fn invalid_domain_cert_request_does_not_poison_cache() {
        let state = default_state();

        assert!(state.get_or_generate_cert("snowman.☃").is_err());
        assert!(state.get_or_generate_cert("openrouter.ai").is_ok());
    }
}
185
186impl ServerCertVerifier for NoVerify {
191 fn verify_server_cert(
192 &self,
193 _end_entity: &CertificateDer<'_>,
194 _intermediates: &[CertificateDer<'_>],
195 _server_name: &ServerName<'_>,
196 _ocsp_response: &[u8],
197 _now: UnixTime,
198 ) -> Result<ServerCertVerified, rustls::Error> {
199 Ok(ServerCertVerified::assertion())
200 }
201
202 fn verify_tls12_signature(
203 &self,
204 _message: &[u8],
205 _cert: &CertificateDer<'_>,
206 _dss: &DigitallySignedStruct,
207 ) -> Result<HandshakeSignatureValid, rustls::Error> {
208 Ok(HandshakeSignatureValid::assertion())
209 }
210
211 fn verify_tls13_signature(
212 &self,
213 _message: &[u8],
214 _cert: &CertificateDer<'_>,
215 _dss: &DigitallySignedStruct,
216 ) -> Result<HandshakeSignatureValid, rustls::Error> {
217 Ok(HandshakeSignatureValid::assertion())
218 }
219
220 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
221 static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
222 std::sync::OnceLock::new();
223 SCHEMES
224 .get_or_init(|| {
225 rustls::crypto::ring::default_provider()
226 .signature_verification_algorithms
227 .supported_schemes()
228 })
229 .clone()
230 }
231}
232
233fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
242 let client_config = if config.verify_upstream {
243 let mut root_store = rustls::RootCertStore::empty();
244 let certs = rustls_native_certs::load_native_certs();
245 if !certs.errors.is_empty() {
246 tracing::warn!(
247 count = certs.errors.len(),
248 "errors loading native certificates"
249 );
250 }
251 let mut added = 0usize;
252 for cert in certs.certs {
253 if root_store.add(cert).is_ok() {
254 added += 1;
255 }
256 }
257 if added == 0 {
258 tracing::error!("no native root certificates loaded — all upstream TLS will fail");
259 }
260
261 for path in &config.upstream_ca_cert {
263 match std::fs::read(path) {
264 Ok(pem_data) => {
265 let mut extra_added = 0usize;
266 for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
267 if root_store.add(cert).is_ok() {
268 extra_added += 1;
269 }
270 }
271 tracing::info!(
272 path = %path.display(),
273 count = extra_added,
274 "loaded upstream CA certificates"
275 );
276 }
277 Err(e) => {
278 tracing::error!(
279 path = %path.display(),
280 error = %e,
281 "failed to read upstream CA certificate file"
282 );
283 }
284 }
285 }
286
287 rustls::ClientConfig::builder()
288 .with_root_certificates(root_store)
289 .with_no_client_auth()
290 } else {
291 rustls::ClientConfig::builder()
292 .dangerous()
293 .with_custom_certificate_verifier(Arc::new(NoVerify))
294 .with_no_client_auth()
295 };
296
297 TlsConnector::from(Arc::new(client_config))
298}
299
300fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
307 if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
309 tracing::warn!(
310 "incomplete CA config: both cert_path and key_path must be set together, ignoring"
311 );
312 }
313
314 if let (Some(cert_path), Some(key_path)) = (
316 &config.intercept_ca.cert_path,
317 &config.intercept_ca.key_path,
318 ) {
319 match (std::fs::read(cert_path), std::fs::read(key_path)) {
320 (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
321 Ok(ca) => {
322 tracing::info!("loaded user-provided CA from {:?}", cert_path);
323 return ca;
324 }
325 Err(e) => {
326 tracing::error!(
327 error = %e,
328 "failed to load user-provided CA, falling back to auto-generate"
329 );
330 }
331 },
332 _ => {
333 tracing::error!(
334 "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
335 cert_path,
336 key_path,
337 );
338 }
339 }
340 }
341
342 if let Some(default_dir) = default_ca_dir() {
344 let cert_path = default_dir.join("ca.crt");
345 let key_path = default_dir.join("ca.key");
346
347 if cert_path.exists()
348 && key_path.exists()
349 && let (Ok(cert_pem), Ok(key_pem)) =
350 (std::fs::read(&cert_path), std::fs::read(&key_path))
351 && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
352 {
353 tracing::debug!("loaded persisted CA from {:?}", cert_path);
354 return ca;
355 }
356
357 let ca = CertAuthority::generate();
359 if let Err(e) = std::fs::create_dir_all(&default_dir) {
360 tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
361 } else {
362 if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
363 tracing::warn!(error = %e, "failed to persist CA certificate");
364 }
365 if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
366 tracing::warn!(error = %e, "failed to persist CA key");
367 } else {
368 tracing::info!("generated and persisted CA to {:?}", default_dir);
369 }
370 }
371 return ca;
372 }
373
374 tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
376 CertAuthority::generate()
377}
378
379fn default_ca_dir() -> Option<std::path::PathBuf> {
381 dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
382}
383
/// Writes private-key material to `path`, replacing any existing contents.
///
/// On Unix the file is restricted to owner read/write (`0o600`). The mode given to `open`
/// applies only when the file is created, so the permissions are also reset explicitly.
/// This happens before any data is written, covering a pre-existing file with broader
/// permissions.
///
/// On other platforms this is a plain [`std::fs::write`] and default permissions apply.
///
/// # Errors
///
/// Returns any I/O error from opening, re-permissioning, or writing the file.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` is ignored for an existing file; enforce 0600 before the secret lands.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}