microsandbox_network/tls/
state.rs1use std::num::NonZeroUsize;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6
7use lru::LruCache;
8use microsandbox_utils::TLS_SUBDIR;
9use rustls::DigitallySignedStruct;
10use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
11use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
12use time::{Duration, OffsetDateTime};
13use tokio_rustls::TlsConnector;
14
15use super::ca::CertAuthority;
16use super::certgen::{self, DomainCert, DomainCertError};
17use super::config::TlsConfig;
18use crate::secrets::config::SecretsConfig;
19
/// Shared state for TLS interception: the signing CA, a cache of generated
/// per-domain certificates, the upstream connector, and the configuration
/// values the state was constructed from.
pub struct TlsState {
    /// Certificate authority handed to `certgen::generate_domain_cert` whenever
    /// a per-domain certificate has to be generated.
    pub intercept_ca: CertAuthority,
    /// LRU cache of generated certificates, keyed by the domain string exactly
    /// as it was passed to `get_or_generate_cert` (no case normalisation).
    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
    /// Connector produced by `build_upstream_connector` from `config`.
    pub connector: TlsConnector,
    /// TLS configuration this state was built from.
    pub config: TlsConfig,
    /// Secrets configuration passed to `TlsState::new`; stored unchanged.
    pub secrets: SecretsConfig,
    /// Lower-cased bypass rules parsed once from `config.bypass`.
    bypass_patterns: Vec<BypassPattern>,
}
42
/// A bypass rule in the pre-parsed form that `TlsState::should_bypass`
/// matches server names against.
enum BypassPattern {
    /// Matches a name that equals this lower-cased string.
    Exact(String),
    /// Built from a `*.suffix` rule. `suffix` is the text after `*.` and
    /// `dotted` is that same text with a leading `.`; a name matches when it
    /// equals `suffix` or ends with `dotted`.
    Wildcard { suffix: String, dotted: String },
}
51
/// Certificate verifier that accepts every server certificate and handshake
/// signature without checking them. It is only installed by
/// `build_upstream_connector` when `config.verify_upstream` is false.
#[derive(Debug)]
struct NoVerify;
56
/// Safety margin before expiry: a cached certificate is reused only while it
/// stays valid for longer than this window; otherwise it is regenerated.
const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
60
61impl TlsState {
66 pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
73 let ca = load_or_generate_ca(&config);
74
75 let capacity =
76 NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
77 let cert_cache = Mutex::new(LruCache::new(capacity));
78
79 let connector = build_upstream_connector(&config);
80
81 let bypass_patterns = config
83 .bypass
84 .iter()
85 .map(|pattern| {
86 let lower = pattern.to_lowercase();
87 if let Some(suffix) = lower.strip_prefix("*.") {
88 let dotted = format!(".{suffix}");
89 BypassPattern::Wildcard {
90 suffix: suffix.to_string(),
91 dotted,
92 }
93 } else {
94 BypassPattern::Exact(lower)
95 }
96 })
97 .collect();
98
99 Self {
100 intercept_ca: ca,
101 cert_cache,
102 connector,
103 config,
104 secrets,
105 bypass_patterns,
106 }
107 }
108
109 pub fn get_or_generate_cert(&self, domain: &str) -> Result<Arc<DomainCert>, DomainCertError> {
111 let mut cache = match self.cert_cache.lock() {
112 Ok(cache) => cache,
113 Err(poisoned) => {
114 tracing::warn!("TLS certificate cache was poisoned; recovering");
115 poisoned.into_inner()
116 }
117 };
118 if let Some(cert) = cache.get(domain)
119 && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
120 {
121 return Ok(cert.clone());
122 }
123
124 let cert = Arc::new(certgen::generate_domain_cert(
125 domain,
126 &self.intercept_ca,
127 self.config.cache.validity_hours,
128 )?);
129 cache.put(domain.to_string(), cert.clone());
130 Ok(cert)
131 }
132
133 pub fn should_bypass(&self, sni: &str) -> bool {
135 let sni_lower = sni.to_lowercase();
136 self.bypass_patterns.iter().any(|pattern| match pattern {
137 BypassPattern::Exact(exact) => sni_lower == *exact,
138 BypassPattern::Wildcard { suffix, dotted } => {
139 sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
140 }
141 })
142 }
143
144 pub fn ca_cert_pem(&self) -> Vec<u8> {
146 self.intercept_ca.cert_pem()
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::SecretsConfig;

    /// Domain used by the certificate-cache tests.
    const DOMAIN: &str = "openrouter.ai";

    /// Installs the ring crypto provider (ignoring "already installed") and
    /// builds a `TlsState` from default configuration.
    fn default_state() -> TlsState {
        let _ = rustls::crypto::ring::default_provider().install_default();
        TlsState::new(TlsConfig::default(), SecretsConfig::default())
    }

    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        let state = default_state();
        let first = state.get_or_generate_cert(DOMAIN).unwrap();
        let original_expires_at = first.expires_at;

        // Replace the cached entry with a copy that expires inside the
        // refresh window so the next lookup has to regenerate it.
        {
            let mut cache = state.cert_cache.lock().unwrap();
            let about_to_expire = DomainCert {
                chain: first.chain.clone(),
                key: first.key.clone_key(),
                expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
                server_config: first.server_config.clone(),
            };
            cache.put(DOMAIN.to_string(), Arc::new(about_to_expire));
        }

        let refreshed = state.get_or_generate_cert(DOMAIN).unwrap();
        let now = OffsetDateTime::now_utc();
        assert!(refreshed.expires_at > now + Duration::hours(23));
        assert!(refreshed.expires_at > original_expires_at - Duration::minutes(10));
    }

    #[test]
    fn invalid_domain_cert_request_does_not_poison_cache() {
        let state = default_state();

        let rejected = state.get_or_generate_cert("snowman.☃");
        assert!(rejected.is_err());

        let accepted = state.get_or_generate_cert(DOMAIN);
        assert!(accepted.is_ok());
    }

    #[test]
    fn default_ca_dir_uses_microsandbox_home_tls_subdir() {
        let home = PathBuf::from("isolated-msb-home");
        let expected = home.join(microsandbox_utils::TLS_SUBDIR);

        assert_eq!(default_ca_dir_from_home(&home), expected);
    }
}
197
198impl ServerCertVerifier for NoVerify {
203 fn verify_server_cert(
204 &self,
205 _end_entity: &CertificateDer<'_>,
206 _intermediates: &[CertificateDer<'_>],
207 _server_name: &ServerName<'_>,
208 _ocsp_response: &[u8],
209 _now: UnixTime,
210 ) -> Result<ServerCertVerified, rustls::Error> {
211 Ok(ServerCertVerified::assertion())
212 }
213
214 fn verify_tls12_signature(
215 &self,
216 _message: &[u8],
217 _cert: &CertificateDer<'_>,
218 _dss: &DigitallySignedStruct,
219 ) -> Result<HandshakeSignatureValid, rustls::Error> {
220 Ok(HandshakeSignatureValid::assertion())
221 }
222
223 fn verify_tls13_signature(
224 &self,
225 _message: &[u8],
226 _cert: &CertificateDer<'_>,
227 _dss: &DigitallySignedStruct,
228 ) -> Result<HandshakeSignatureValid, rustls::Error> {
229 Ok(HandshakeSignatureValid::assertion())
230 }
231
232 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
233 static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
234 std::sync::OnceLock::new();
235 SCHEMES
236 .get_or_init(|| {
237 rustls::crypto::ring::default_provider()
238 .signature_verification_algorithms
239 .supported_schemes()
240 })
241 .clone()
242 }
243}
244
245fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
254 let client_config = if config.verify_upstream {
255 let mut root_store = rustls::RootCertStore::empty();
256 let certs = rustls_native_certs::load_native_certs();
257 if !certs.errors.is_empty() {
258 tracing::warn!(
259 count = certs.errors.len(),
260 "errors loading native certificates"
261 );
262 }
263 let mut added = 0usize;
264 for cert in certs.certs {
265 if root_store.add(cert).is_ok() {
266 added += 1;
267 }
268 }
269 if added == 0 {
270 tracing::error!("no native root certificates loaded — all upstream TLS will fail");
271 }
272
273 for path in &config.upstream_ca_cert {
275 match std::fs::read(path) {
276 Ok(pem_data) => {
277 let mut extra_added = 0usize;
278 for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
279 if root_store.add(cert).is_ok() {
280 extra_added += 1;
281 }
282 }
283 tracing::info!(
284 path = %path.display(),
285 count = extra_added,
286 "loaded upstream CA certificates"
287 );
288 }
289 Err(e) => {
290 tracing::error!(
291 path = %path.display(),
292 error = %e,
293 "failed to read upstream CA certificate file"
294 );
295 }
296 }
297 }
298
299 rustls::ClientConfig::builder()
300 .with_root_certificates(root_store)
301 .with_no_client_auth()
302 } else {
303 rustls::ClientConfig::builder()
304 .dangerous()
305 .with_custom_certificate_verifier(Arc::new(NoVerify))
306 .with_no_client_auth()
307 };
308
309 TlsConnector::from(Arc::new(client_config))
310}
311
312fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
319 if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
321 tracing::warn!(
322 "incomplete CA config: both cert_path and key_path must be set together, ignoring"
323 );
324 }
325
326 if let (Some(cert_path), Some(key_path)) = (
328 &config.intercept_ca.cert_path,
329 &config.intercept_ca.key_path,
330 ) {
331 match (std::fs::read(cert_path), std::fs::read(key_path)) {
332 (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
333 Ok(ca) => {
334 tracing::info!("loaded user-provided CA from {:?}", cert_path);
335 return ca;
336 }
337 Err(e) => {
338 tracing::error!(
339 error = %e,
340 "failed to load user-provided CA, falling back to auto-generate"
341 );
342 }
343 },
344 _ => {
345 tracing::error!(
346 "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
347 cert_path,
348 key_path,
349 );
350 }
351 }
352 }
353
354 let default_dir = default_ca_dir();
356 let cert_path = default_dir.join("ca.crt");
357 let key_path = default_dir.join("ca.key");
358
359 if cert_path.exists()
360 && key_path.exists()
361 && let (Ok(cert_pem), Ok(key_pem)) = (std::fs::read(&cert_path), std::fs::read(&key_path))
362 && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
363 {
364 tracing::debug!("loaded persisted CA from {:?}", cert_path);
365 return ca;
366 }
367
368 let ca = CertAuthority::generate();
370 if let Err(e) = std::fs::create_dir_all(&default_dir) {
371 tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
372 } else {
373 if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
374 tracing::warn!(error = %e, "failed to persist CA certificate");
375 }
376 if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
377 tracing::warn!(error = %e, "failed to persist CA key");
378 } else {
379 tracing::info!("generated and persisted CA to {:?}", default_dir);
380 }
381 }
382 ca
383}
384
385fn default_ca_dir() -> PathBuf {
387 default_ca_dir_from_home(microsandbox_utils::resolve_home())
388}
389
390fn default_ca_dir_from_home(home: impl AsRef<Path>) -> PathBuf {
392 home.as_ref().join(TLS_SUBDIR)
393}
394
/// Writes private-key material to `path`, replacing any existing content.
///
/// On Unix the file ends up with mode `0o600` whether it was just created or
/// already existed. On other platforms the data is written with the
/// platform's default permissions.
///
/// # Errors
///
/// Returns the I/O error from opening the file, tightening its permissions,
/// or writing the data.
fn write_key_file(path: &Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` only takes effect when the file is newly created. A key file
        // left behind with looser permissions would keep them, so tighten the
        // open handle explicitly before any key bytes are written.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}