microsandbox_network/tls/
state.rs1use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use time::{Duration, OffsetDateTime};
11use tokio_rustls::TlsConnector;
12
13use super::ca::CertAuthority;
14use super::certgen::{self, DomainCert};
15use super::config::TlsConfig;
16use crate::secrets::config::SecretsConfig;
17
/// Shared TLS-interception state: the CA that signs per-domain leaf certificates, a cache
/// of those certificates, the connector used for upstream connections, and the compiled
/// bypass list.
pub struct TlsState {
    /// CA that signs the per-domain certificates returned by `get_or_generate_cert`;
    /// its PEM certificate is exposed through `ca_cert_pem`.
    pub intercept_ca: CertAuthority,
    /// LRU cache of generated leaf certificates, keyed by the domain string passed to
    /// `get_or_generate_cert` exactly as given (no case normalization).
    cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
    /// Connector for connections to the real upstream server, built from `config` by
    /// `build_upstream_connector`.
    pub connector: TlsConnector,
    /// Configuration this state was built from.
    pub config: TlsConfig,
    /// Secrets configuration supplied at construction.
    /// NOTE(review): not read anywhere in this file — presumably consumed by callers; confirm.
    pub secrets: SecretsConfig,
    /// Lowercased patterns compiled from `config.bypass`, evaluated by `should_bypass`.
    bypass_patterns: Vec<BypassPattern>,
}
40
/// One compiled entry from the `bypass` list of `TlsConfig`, stored lowercased so matching
/// against a lowercased SNI is case-insensitive.
enum BypassPattern {
    /// Matches a host name exactly.
    Exact(String),
    /// Built from a `*.suffix` pattern. `suffix` is the text after `*.` and `dotted` is
    /// `.suffix`. Matches `suffix` itself as well as any name ending in `dotted`, so
    /// `*.example.com` covers `example.com` and every subdomain at any depth.
    Wildcard { suffix: String, dotted: String },
}
49
/// Certificate verifier that accepts every server certificate and every handshake
/// signature without checking them.
///
/// Installed by `build_upstream_connector` when `config.verify_upstream` is false; with it
/// in place, upstream servers are not authenticated at all.
#[derive(Debug)]
struct NoVerify;
54
55const CERT_REFRESH_WINDOW: Duration = Duration::minutes(5);
58
59impl TlsState {
64 pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
71 let ca = load_or_generate_ca(&config);
72
73 let capacity =
74 NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
75 let cert_cache = Mutex::new(LruCache::new(capacity));
76
77 let connector = build_upstream_connector(&config);
78
79 let bypass_patterns = config
81 .bypass
82 .iter()
83 .map(|pattern| {
84 let lower = pattern.to_lowercase();
85 if let Some(suffix) = lower.strip_prefix("*.") {
86 let dotted = format!(".{suffix}");
87 BypassPattern::Wildcard {
88 suffix: suffix.to_string(),
89 dotted,
90 }
91 } else {
92 BypassPattern::Exact(lower)
93 }
94 })
95 .collect();
96
97 Self {
98 intercept_ca: ca,
99 cert_cache,
100 connector,
101 config,
102 secrets,
103 bypass_patterns,
104 }
105 }
106
107 pub fn get_or_generate_cert(&self, domain: &str) -> Arc<DomainCert> {
109 let mut cache = self.cert_cache.lock().unwrap();
110 if let Some(cert) = cache.get(domain)
111 && cert.expires_at > OffsetDateTime::now_utc() + CERT_REFRESH_WINDOW
112 {
113 return cert.clone();
114 }
115
116 let cert = Arc::new(certgen::generate_domain_cert(
117 domain,
118 &self.intercept_ca,
119 self.config.cache.validity_hours,
120 ));
121 cache.put(domain.to_string(), cert.clone());
122 cert
123 }
124
125 pub fn should_bypass(&self, sni: &str) -> bool {
127 let sni_lower = sni.to_lowercase();
128 self.bypass_patterns.iter().any(|pattern| match pattern {
129 BypassPattern::Exact(exact) => sni_lower == *exact,
130 BypassPattern::Wildcard { suffix, dotted } => {
131 sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
132 }
133 })
134 }
135
136 pub fn ca_cert_pem(&self) -> Vec<u8> {
138 self.intercept_ca.cert_pem()
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::SecretsConfig;

    /// Exercises both cache paths of `get_or_generate_cert`:
    /// - a fresh cached cert is returned as-is;
    /// - a cert inside the refresh window is regenerated and replaces the cache entry.
    ///
    /// Kept as a single test because `TlsState::new` may persist a CA under the home
    /// directory, and parallel tests could race on those files.
    #[test]
    fn regenerates_cached_domain_cert_when_near_expiry() {
        // Make ring the process default provider; the error when one is already
        // installed is intentionally ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let state = TlsState::new(TlsConfig::default(), SecretsConfig::default());

        let first = state.get_or_generate_cert("openrouter.ai");
        // A fresh cached cert is served from the cache (same allocation).
        let again = state.get_or_generate_cert("openrouter.ai");
        assert!(Arc::ptr_eq(&first, &again));

        // Replace the entry with a copy that expires inside the refresh window.
        let stale = Arc::new(DomainCert {
            chain: first.chain.clone(),
            key: first.key.clone_key(),
            expires_at: OffsetDateTime::now_utc() + Duration::seconds(30),
            server_config: first.server_config.clone(),
        });
        state
            .cert_cache
            .lock()
            .unwrap()
            .put("openrouter.ai".to_string(), Arc::clone(&stale));

        let refreshed = state.get_or_generate_cert("openrouter.ai");
        assert!(!Arc::ptr_eq(&refreshed, &stale));
        assert!(refreshed.expires_at > OffsetDateTime::now_utc() + Duration::hours(23));

        // The regenerated cert must now be the cached entry.
        let cached = state
            .cert_cache
            .lock()
            .unwrap()
            .get("openrouter.ai")
            .cloned()
            .expect("domain should be cached after regeneration");
        assert!(Arc::ptr_eq(&cached, &refreshed));
    }
}
170
171impl ServerCertVerifier for NoVerify {
176 fn verify_server_cert(
177 &self,
178 _end_entity: &CertificateDer<'_>,
179 _intermediates: &[CertificateDer<'_>],
180 _server_name: &ServerName<'_>,
181 _ocsp_response: &[u8],
182 _now: UnixTime,
183 ) -> Result<ServerCertVerified, rustls::Error> {
184 Ok(ServerCertVerified::assertion())
185 }
186
187 fn verify_tls12_signature(
188 &self,
189 _message: &[u8],
190 _cert: &CertificateDer<'_>,
191 _dss: &DigitallySignedStruct,
192 ) -> Result<HandshakeSignatureValid, rustls::Error> {
193 Ok(HandshakeSignatureValid::assertion())
194 }
195
196 fn verify_tls13_signature(
197 &self,
198 _message: &[u8],
199 _cert: &CertificateDer<'_>,
200 _dss: &DigitallySignedStruct,
201 ) -> Result<HandshakeSignatureValid, rustls::Error> {
202 Ok(HandshakeSignatureValid::assertion())
203 }
204
205 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
206 static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
207 std::sync::OnceLock::new();
208 SCHEMES
209 .get_or_init(|| {
210 rustls::crypto::ring::default_provider()
211 .signature_verification_algorithms
212 .supported_schemes()
213 })
214 .clone()
215 }
216}
217
218fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
227 let client_config = if config.verify_upstream {
228 let mut root_store = rustls::RootCertStore::empty();
229 let certs = rustls_native_certs::load_native_certs();
230 if !certs.errors.is_empty() {
231 tracing::warn!(
232 count = certs.errors.len(),
233 "errors loading native certificates"
234 );
235 }
236 let mut added = 0usize;
237 for cert in certs.certs {
238 if root_store.add(cert).is_ok() {
239 added += 1;
240 }
241 }
242 if added == 0 {
243 tracing::error!("no native root certificates loaded — all upstream TLS will fail");
244 }
245
246 for path in &config.upstream_ca_cert {
248 match std::fs::read(path) {
249 Ok(pem_data) => {
250 let mut extra_added = 0usize;
251 for cert in rustls_pemfile::certs(&mut pem_data.as_slice()).flatten() {
252 if root_store.add(cert).is_ok() {
253 extra_added += 1;
254 }
255 }
256 tracing::info!(
257 path = %path.display(),
258 count = extra_added,
259 "loaded upstream CA certificates"
260 );
261 }
262 Err(e) => {
263 tracing::error!(
264 path = %path.display(),
265 error = %e,
266 "failed to read upstream CA certificate file"
267 );
268 }
269 }
270 }
271
272 rustls::ClientConfig::builder()
273 .with_root_certificates(root_store)
274 .with_no_client_auth()
275 } else {
276 rustls::ClientConfig::builder()
277 .dangerous()
278 .with_custom_certificate_verifier(Arc::new(NoVerify))
279 .with_no_client_auth()
280 };
281
282 TlsConnector::from(Arc::new(client_config))
283}
284
285fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
292 if config.intercept_ca.cert_path.is_some() != config.intercept_ca.key_path.is_some() {
294 tracing::warn!(
295 "incomplete CA config: both cert_path and key_path must be set together, ignoring"
296 );
297 }
298
299 if let (Some(cert_path), Some(key_path)) = (
301 &config.intercept_ca.cert_path,
302 &config.intercept_ca.key_path,
303 ) {
304 match (std::fs::read(cert_path), std::fs::read(key_path)) {
305 (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
306 Ok(ca) => {
307 tracing::info!("loaded user-provided CA from {:?}", cert_path);
308 return ca;
309 }
310 Err(e) => {
311 tracing::error!(
312 error = %e,
313 "failed to load user-provided CA, falling back to auto-generate"
314 );
315 }
316 },
317 _ => {
318 tracing::error!(
319 "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
320 cert_path,
321 key_path,
322 );
323 }
324 }
325 }
326
327 if let Some(default_dir) = default_ca_dir() {
329 let cert_path = default_dir.join("ca.crt");
330 let key_path = default_dir.join("ca.key");
331
332 if cert_path.exists()
333 && key_path.exists()
334 && let (Ok(cert_pem), Ok(key_pem)) =
335 (std::fs::read(&cert_path), std::fs::read(&key_path))
336 && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
337 {
338 tracing::debug!("loaded persisted CA from {:?}", cert_path);
339 return ca;
340 }
341
342 let ca = CertAuthority::generate();
344 if let Err(e) = std::fs::create_dir_all(&default_dir) {
345 tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
346 } else {
347 if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
348 tracing::warn!(error = %e, "failed to persist CA certificate");
349 }
350 if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
351 tracing::warn!(error = %e, "failed to persist CA key");
352 } else {
353 tracing::info!("generated and persisted CA to {:?}", default_dir);
354 }
355 }
356 return ca;
357 }
358
359 tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
361 CertAuthority::generate()
362}
363
364fn default_ca_dir() -> Option<std::path::PathBuf> {
366 dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
367}
368
/// Writes private-key material to `path`, replacing any existing contents.
///
/// On Unix the file is restricted to owner read/write (`0o600`) before any key bytes are
/// written. This includes the case where `path` already exists, because
/// `OpenOptionsExt::mode` only applies to newly created files. On other platforms this is
/// a plain [`std::fs::write`] with default permissions.
///
/// # Errors
/// Returns any I/O error from opening, re-permissioning, truncating, or writing the file.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        // Open without truncating: if tightening permissions fails, the previous contents
        // stay intact instead of being replaced by an empty file.
        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(false)
            .mode(0o600)
            .open(path)?;
        // `mode` above is ignored for a pre-existing file; enforce 0o600 explicitly.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.set_len(0)?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}