microsandbox_network/tls/
state.rs1use std::num::NonZeroUsize;
4use std::sync::{Arc, Mutex};
5
6use lru::LruCache;
7use rustls::DigitallySignedStruct;
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
10use tokio_rustls::TlsConnector;
11
12use super::ca::CertAuthority;
13use super::certgen::{self, DomainCert};
14use super::config::TlsConfig;
15use crate::secrets::config::SecretsConfig;
16
17pub struct TlsState {
26 pub ca: CertAuthority,
28 cert_cache: Mutex<LruCache<String, Arc<DomainCert>>>,
30 pub connector: TlsConnector,
32 pub config: TlsConfig,
34 pub secrets: SecretsConfig,
36 bypass_patterns: Vec<BypassPattern>,
38}
39
/// One parsed entry of the TLS bypass list (see `TlsState::should_bypass`).
///
/// Both variants hold lowercase text, so matching a lowercased SNI is a plain string
/// comparison.
enum BypassPattern {
    /// Matches exactly this hostname and nothing else.
    Exact(String),
    /// Built from a `*.suffix` pattern: matches `suffix` itself or any hostname ending in
    /// `dotted`.
    Wildcard {
        /// The pattern with its leading `*.` removed.
        suffix: String,
        /// `suffix` prefixed with `.`, precomputed for `ends_with` checks.
        dotted: String,
    },
}
48
/// Upstream certificate verifier that accepts every server certificate and every handshake
/// signature without checking them.
///
/// Installed by `build_upstream_connector` only when `verify_upstream` is disabled. Using it
/// removes all protection against a man-in-the-middle on the upstream leg.
#[derive(Debug)]
struct NoVerify;
53
54impl TlsState {
59 pub fn new(config: TlsConfig, secrets: SecretsConfig) -> Self {
66 let ca = load_or_generate_ca(&config);
67
68 let capacity =
69 NonZeroUsize::new(config.cache.capacity).unwrap_or(NonZeroUsize::new(1000).unwrap());
70 let cert_cache = Mutex::new(LruCache::new(capacity));
71
72 let connector = build_upstream_connector(&config);
73
74 let bypass_patterns = config
76 .bypass
77 .iter()
78 .map(|pattern| {
79 let lower = pattern.to_lowercase();
80 if let Some(suffix) = lower.strip_prefix("*.") {
81 let dotted = format!(".{suffix}");
82 BypassPattern::Wildcard {
83 suffix: suffix.to_string(),
84 dotted,
85 }
86 } else {
87 BypassPattern::Exact(lower)
88 }
89 })
90 .collect();
91
92 Self {
93 ca,
94 cert_cache,
95 connector,
96 config,
97 secrets,
98 bypass_patterns,
99 }
100 }
101
102 pub fn get_or_generate_cert(&self, domain: &str) -> Arc<DomainCert> {
104 let mut cache = self.cert_cache.lock().unwrap();
105 if let Some(cert) = cache.get(domain) {
106 return cert.clone();
107 }
108
109 let cert = Arc::new(certgen::generate_domain_cert(
110 domain,
111 &self.ca,
112 self.config.cache.validity_hours,
113 ));
114 cache.put(domain.to_string(), cert.clone());
115 cert
116 }
117
118 pub fn should_bypass(&self, sni: &str) -> bool {
120 let sni_lower = sni.to_lowercase();
121 self.bypass_patterns.iter().any(|pattern| match pattern {
122 BypassPattern::Exact(exact) => sni_lower == *exact,
123 BypassPattern::Wildcard { suffix, dotted } => {
124 sni_lower == *suffix || sni_lower.ends_with(dotted.as_str())
125 }
126 })
127 }
128
129 pub fn ca_cert_pem(&self) -> Vec<u8> {
131 self.ca.cert_pem()
132 }
133}
134
135impl ServerCertVerifier for NoVerify {
140 fn verify_server_cert(
141 &self,
142 _end_entity: &CertificateDer<'_>,
143 _intermediates: &[CertificateDer<'_>],
144 _server_name: &ServerName<'_>,
145 _ocsp_response: &[u8],
146 _now: UnixTime,
147 ) -> Result<ServerCertVerified, rustls::Error> {
148 Ok(ServerCertVerified::assertion())
149 }
150
151 fn verify_tls12_signature(
152 &self,
153 _message: &[u8],
154 _cert: &CertificateDer<'_>,
155 _dss: &DigitallySignedStruct,
156 ) -> Result<HandshakeSignatureValid, rustls::Error> {
157 Ok(HandshakeSignatureValid::assertion())
158 }
159
160 fn verify_tls13_signature(
161 &self,
162 _message: &[u8],
163 _cert: &CertificateDer<'_>,
164 _dss: &DigitallySignedStruct,
165 ) -> Result<HandshakeSignatureValid, rustls::Error> {
166 Ok(HandshakeSignatureValid::assertion())
167 }
168
169 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
170 static SCHEMES: std::sync::OnceLock<Vec<rustls::SignatureScheme>> =
171 std::sync::OnceLock::new();
172 SCHEMES
173 .get_or_init(|| {
174 rustls::crypto::ring::default_provider()
175 .signature_verification_algorithms
176 .supported_schemes()
177 })
178 .clone()
179 }
180}
181
182fn build_upstream_connector(config: &TlsConfig) -> TlsConnector {
191 let client_config = if config.verify_upstream {
192 let mut root_store = rustls::RootCertStore::empty();
193 let certs = rustls_native_certs::load_native_certs();
194 if !certs.errors.is_empty() {
195 tracing::warn!(
196 count = certs.errors.len(),
197 "errors loading native certificates"
198 );
199 }
200 let mut added = 0usize;
201 for cert in certs.certs {
202 if root_store.add(cert).is_ok() {
203 added += 1;
204 }
205 }
206 if added == 0 {
207 tracing::error!("no native root certificates loaded — all upstream TLS will fail");
208 }
209 rustls::ClientConfig::builder()
210 .with_root_certificates(root_store)
211 .with_no_client_auth()
212 } else {
213 rustls::ClientConfig::builder()
214 .dangerous()
215 .with_custom_certificate_verifier(Arc::new(NoVerify))
216 .with_no_client_auth()
217 };
218
219 TlsConnector::from(Arc::new(client_config))
220}
221
222fn load_or_generate_ca(config: &TlsConfig) -> CertAuthority {
229 if config.ca.cert_path.is_some() != config.ca.key_path.is_some() {
231 tracing::warn!(
232 "incomplete CA config: both cert_path and key_path must be set together, ignoring"
233 );
234 }
235
236 if let (Some(cert_path), Some(key_path)) = (&config.ca.cert_path, &config.ca.key_path) {
238 match (std::fs::read(cert_path), std::fs::read(key_path)) {
239 (Ok(cert_pem), Ok(key_pem)) => match CertAuthority::load(&cert_pem, &key_pem) {
240 Ok(ca) => {
241 tracing::info!("loaded user-provided CA from {:?}", cert_path);
242 return ca;
243 }
244 Err(e) => {
245 tracing::error!(
246 error = %e,
247 "failed to load user-provided CA, falling back to auto-generate"
248 );
249 }
250 },
251 _ => {
252 tracing::error!(
253 "failed to read CA files from {:?} / {:?}, falling back to auto-generate",
254 cert_path,
255 key_path,
256 );
257 }
258 }
259 }
260
261 if let Some(default_dir) = default_ca_dir() {
263 let cert_path = default_dir.join("ca.crt");
264 let key_path = default_dir.join("ca.key");
265
266 if cert_path.exists()
267 && key_path.exists()
268 && let (Ok(cert_pem), Ok(key_pem)) =
269 (std::fs::read(&cert_path), std::fs::read(&key_path))
270 && let Ok(ca) = CertAuthority::load(&cert_pem, &key_pem)
271 {
272 tracing::debug!("loaded persisted CA from {:?}", cert_path);
273 return ca;
274 }
275
276 let ca = CertAuthority::generate();
278 if let Err(e) = std::fs::create_dir_all(&default_dir) {
279 tracing::warn!(error = %e, "failed to create CA directory, CA will not persist");
280 } else {
281 if let Err(e) = std::fs::write(&cert_path, ca.cert_pem()) {
282 tracing::warn!(error = %e, "failed to persist CA certificate");
283 }
284 if let Err(e) = write_key_file(&key_path, &ca.key_pem()) {
285 tracing::warn!(error = %e, "failed to persist CA key");
286 } else {
287 tracing::info!("generated and persisted CA to {:?}", default_dir);
288 }
289 }
290 return ca;
291 }
292
293 tracing::warn!("could not determine CA persistence path, generating ephemeral CA");
295 CertAuthority::generate()
296}
297
298fn default_ca_dir() -> Option<std::path::PathBuf> {
300 dirs::home_dir().map(|h| h.join(".microsandbox").join("tls"))
301}
302
/// Writes private-key material to `path`, replacing any existing contents.
///
/// On Unix the file is opened with mode `0o600`. Its permissions are then explicitly reset to
/// `0o600` before any data is written, because `OpenOptionsExt::mode` only takes effect when
/// the file is newly created. Without the reset, an existing key file with looser permissions
/// (e.g. `0o644`) would keep them. On other platforms this is a plain `std::fs::write` with
/// default permissions.
///
/// # Errors
///
/// Returns any I/O error from opening the file, changing its permissions, or writing to it.
fn write_key_file(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // Tighten permissions on pre-existing files before the secret hits the disk.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        file.write_all(data)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, data)?;
    }
    Ok(())
}