1use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Response header added to every reply that went through the pitchfork proxy.
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header counting how many times a request has already been forwarded
/// by the proxy; used to detect forwarding loops.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// Hop count at which `proxy_handler` answers `508 Loop Detected` instead of forwarding.
const MAX_PROXY_HOPS: u64 = 5;

/// Connection-specific response headers removed before replying on the TLS
/// listener (which offers `h2` via ALPN, where such headers are not allowed).
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
/// How long a slug table built by `build_slug_entries` is reused before it is rebuilt.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_millis(2_000);
62
/// One slug from the global slug registry, pre-resolved for proxy lookups.
#[derive(Clone, Debug)]
struct CachedSlugEntry {
    /// The slug itself (also the key in the slug table).
    slug: String,
    /// Namespace derived from `dir`, or `None` when it could not be determined.
    namespace: Option<String>,
    /// Name of the daemon the slug routes to.
    daemon_name: String,
    /// Directory the slug was registered from.
    dir: std::path::PathBuf,
}
75
76struct SlugCache {
78 entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
79 expires_at: std::time::Instant,
80}
81
82static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
83 once_cell::sync::Lazy::new(|| {
84 tokio::sync::Mutex::new(SlugCache {
85 entries: Arc::new(std::collections::HashMap::new()),
86 expires_at: std::time::Instant::now(), })
88 });
89
90fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
93 let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
94 let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
95 for (slug, entry) in &global_slugs {
96 let ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&entry.dir).ok();
97 let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
98 entries.insert(
99 slug.clone(),
100 CachedSlugEntry {
101 slug: slug.clone(),
102 namespace: ns,
103 daemon_name,
104 dir: entry.dir.clone(),
105 },
106 );
107 }
108 entries
109}
110
111async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
117 {
119 let cache = SLUG_CACHE.lock().await;
120 if std::time::Instant::now() < cache.expires_at {
121 return Arc::clone(&cache.entries);
122 }
123 } let new_entries = Arc::new(build_slug_entries());
127
128 {
130 let mut cache = SLUG_CACHE.lock().await;
131 cache.entries = Arc::clone(&new_entries);
132 cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
133 }
134
135 new_entries
136}
137
138fn wildcard_slug_lookup<'a>(
144 subdomain: &str,
145 entries: &'a std::collections::HashMap<String, CachedSlugEntry>,
146 wildcard: bool,
147) -> Option<&'a CachedSlugEntry> {
148 entries.get(subdomain).or_else(|| {
149 if !wildcard {
150 return None;
151 }
152 subdomain
154 .match_indices('.')
155 .map(|(i, _)| &subdomain[i + 1..])
156 .find_map(|candidate| entries.get(candidate))
157 })
158}
159
160async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
167 let entries = get_cached_slugs().await;
168 wildcard_slug_lookup(subdomain, &entries, settings().proxy.wildcard).cloned()
169}
170
171static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
178 tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
179> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
180
/// Outcome of mapping a request host to a backend daemon.
enum ResolveResult {
    /// A running daemon can be reached on this local port.
    Ready(u16),
    /// The daemon behind `slug` is being auto-started.
    Starting { slug: String },
    /// No slug, or no usable daemon, matched the host.
    NotFound,
    /// Resolution failed; the message is returned to the client in a 502.
    Error(String),
}
193
/// Shared callback that receives a human-readable message when forwarding to a daemon fails.
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;
197
198#[derive(Clone)]
199struct ProxyState {
200 client: Arc<Client<HttpConnector, Body>>,
202 tld: String,
204 is_tls: bool,
206 on_error: Option<OnErrorFn>,
208}
209
210pub async fn serve(
218 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
219 cancel: tokio_util::sync::CancellationToken,
220) -> crate::Result<()> {
221 let s = settings();
222 let lan_enabled = s.proxy.lan || !s.proxy.lan_ip.is_empty();
223
224 let effective_tld = if lan_enabled {
225 "local".to_string()
226 } else {
227 s.proxy.tld.clone()
228 };
229
230 let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
231 let msg = format!(
232 "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
233 s.proxy.port
234 );
235 let _ = bind_tx.send(Err(msg.clone()));
236 miette::bail!("{msg}");
237 };
238
239 let mut connector = HttpConnector::new();
240 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
244
245 let client = Client::builder(TokioExecutor::new())
246 .pool_idle_timeout(std::time::Duration::from_secs(30))
249 .build(connector);
250
251 let state = ProxyState {
252 client: Arc::new(client),
253 tld: effective_tld.clone(),
254 is_tls: s.proxy.https,
255 on_error: None,
256 };
257
258 let app = Router::new().fallback(proxy_handler).with_state(state);
259
260 let bind_ip: std::net::IpAddr = if lan_enabled && s.proxy.host == "127.0.0.1" {
264 std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)
265 } else {
266 match s.proxy.host.parse() {
267 Ok(ip) => ip,
268 Err(_) => {
269 log::warn!(
270 "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
271 The proxy will only be reachable on the loopback interface.",
272 s.proxy.host
273 );
274 std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
275 }
276 }
277 };
278 let addr = SocketAddr::from((bind_ip, effective_port));
279
280 if s.proxy.https {
281 serve_https_with_http_fallback(app, addr, s, effective_port, bind_tx, cancel).await
282 } else {
283 serve_http(app, addr, effective_port, bind_tx, cancel).await
284 }
285}
286
287async fn serve_http(
289 app: Router,
290 addr: SocketAddr,
291 effective_port: u16,
292 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
293 cancel: tokio_util::sync::CancellationToken,
294) -> crate::Result<()> {
295 let listener = match TcpListener::bind(addr).await {
296 Ok(l) => {
297 if settings().proxy.sync_hosts {
298 crate::proxy::hosts::sync_hosts_from_settings();
299 }
300 let _ = bind_tx.send(Ok(()));
301 l
302 }
303 Err(e) => {
304 let msg = bind_error_message(effective_port, &e);
305 let _ = bind_tx.send(Err(msg.clone()));
306 return Err(miette::miette!("{msg}"));
307 }
308 };
309
310 log::info!("Proxy server listening on http://{addr}");
311 if effective_port < 1024 {
312 log::info!(
313 "Note: port {effective_port} is a privileged port. \
314 The supervisor must be started with sudo to bind to this port."
315 );
316 }
317 let shutdown_signal = cancel.clone().cancelled_owned();
318 axum::serve(
319 listener,
320 app.into_make_service_with_connect_info::<SocketAddr>(),
321 )
322 .with_graceful_shutdown(shutdown_signal)
323 .await
324 .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
325 Ok(())
326}
327
/// Serves the proxy over HTTPS, with plain-HTTP connections on the same port
/// handed to the `redirect_to_https_handler` router.
///
/// Before binding, it ensures a local CA exists at the configured TLS paths,
/// generating one if either file is missing. It then installs an SNI resolver
/// that issues per-host leaf certificates.
///
/// Each accepted connection is sniffed by its first byte. A TLS handshake
/// record (`0x16`) is served by `app` over TLS, with `h2` and `http/1.1`
/// offered via ALPN. Anything else goes to the redirect router.
///
/// TLS set-up failures and bind failures are both reported on `bind_tx`.
/// On cancellation, open connections get up to 10 s to finish.
///
/// # Errors
/// Returns an error when the CA or resolver cannot be set up or binding fails.
#[cfg(feature = "proxy-tls")]
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
    cancel: tokio_util::sync::CancellationToken,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    /// Time a new connection gets to send its first byte and, for TLS, to
    /// complete the handshake. Without a bound, an idle client keeps a task
    /// and a socket alive until it disconnects.
    const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    // All TLS set-up is fallible. Run it as one unit so a failure is reported
    // through `bind_tx` like a bind error, rather than dropping the sender silently.
    let build_acceptor = || -> crate::Result<TlsAcceptor> {
        let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);
        if !ca_cert_path.exists() || !ca_key_path.exists() {
            generate_ca(&ca_cert_path, &ca_key_path)?;
            log::info!(
                "Generated local CA certificate at {}",
                ca_cert_path.display()
            );
            log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
        }

        // `install_default` fails if a process-wide provider is already installed; that is fine.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;
        let mut tls_config = ServerConfig::builder()
            .with_no_client_auth()
            .with_cert_resolver(Arc::new(resolver));
        tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        Ok(TlsAcceptor::from(Arc::new(tls_config)))
    };
    let acceptor = match build_acceptor() {
        Ok(acceptor) => acceptor,
        Err(e) => {
            let _ = bind_tx.send(Err(e.to_string()));
            return Err(e);
        }
    };

    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            if settings().proxy.sync_hosts {
                crate::proxy::hosts::sync_hosts_from_settings();
            }
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    let redirect_app = Router::new().fallback(redirect_to_https_handler);

    let mut conn_tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    loop {
        // Reap finished connection tasks so the set does not grow without bound.
        while conn_tasks.try_join_next().is_some() {}

        tokio::select! {
            accept_result = listener.accept() => {
                let (stream, peer_addr) = match accept_result {
                    Ok(conn) => conn,
                    Err(e) => {
                        // Typically transient (e.g. fd exhaustion); back off briefly.
                        log::warn!("Accept error (will retry): {e}");
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        continue;
                    }
                };

                let acceptor = acceptor.clone();
                let app = app.clone();
                let redirect_app = redirect_app.clone();

                conn_tasks.spawn(async move {
                    // Peek (without consuming) the first byte: 0x16 starts a TLS handshake record.
                    let mut first_byte = [0u8; 1];
                    match tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.peek(&mut first_byte)).await {
                        Ok(Ok(n)) if n > 0 => {}
                        _ => return,
                    }

                    if first_byte[0] != 0x16 {
                        let io = hyper_util::rt::TokioIo::new(stream);
                        let svc = hyper_util::service::TowerToHyperService::new(redirect_app);
                        if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await
                        {
                            log::debug!("Redirect connection error: {e}");
                        }
                        return;
                    }

                    let tls_stream = match tokio::time::timeout(HANDSHAKE_TIMEOUT, acceptor.accept(stream)).await {
                        Ok(Ok(tls_stream)) => tls_stream,
                        Ok(Err(e)) => {
                            log::debug!("TLS handshake error: {e}");
                            return;
                        }
                        Err(_) => {
                            log::debug!("TLS handshake timed out after {HANDSHAKE_TIMEOUT:?}");
                            return;
                        }
                    };

                    // Expose the peer address as `ConnectInfo`, as `axum::serve` does on the
                    // plain-HTTP path, so forwarded headers carry the real client IP.
                    let app = app.layer(axum::Extension(axum::extract::ConnectInfo(peer_addr)));
                    let io = hyper_util::rt::TokioIo::new(tls_stream);
                    let svc = hyper_util::service::TowerToHyperService::new(app);
                    if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                        .serve_connection_with_upgrades(io, svc)
                        .await
                    {
                        log::debug!("Connection error: {e}");
                    }
                });
            }
            _ = cancel.cancelled() => {
                log::info!("Proxy server shutting down (cancel signal received)");
                break;
            }
        }
    }

    // Give in-flight connections a bounded grace period; the rest are aborted
    // when the JoinSet is dropped.
    let drain_timeout = std::time::Duration::from_secs(10);
    let _ = tokio::time::timeout(drain_timeout, async {
        while conn_tasks.join_next().await.is_some() {}
    })
    .await;

    Ok(())
}
476
477#[cfg(not(feature = "proxy-tls"))]
479async fn serve_https_with_http_fallback(
480 _app: Router,
481 _addr: SocketAddr,
482 _s: &crate::settings::Settings,
483 _effective_port: u16,
484 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
485 _cancel: tokio_util::sync::CancellationToken,
486) -> crate::Result<()> {
487 let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
488 Rebuild pitchfork with: cargo build --features proxy-tls"
489 .to_string();
490 let _ = bind_tx.send(Err(msg.clone()));
491 miette::bail!("{msg}")
492}
493
/// Returns the `(certificate, key)` paths of the local CA.
///
/// Empty `proxy.tls_cert` / `proxy.tls_key` settings fall back to `ca.pem` /
/// `ca-key.pem` inside `<state dir>/proxy`.
#[cfg(feature = "proxy-tls")]
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let default_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let pick = |configured: &str, fallback: &str| -> std::path::PathBuf {
        match configured {
            "" => default_dir.join(fallback),
            explicit => std::path::PathBuf::from(explicit),
        }
    };
    let cert = pick(&s.proxy.tls_cert, "ca.pem");
    let key = pick(&s.proxy.tls_key, "ca-key.pem");
    (cert, key)
}
513
/// Generates a self-signed local CA and writes it to `cert_path` and `key_path`.
///
/// The CA (`CN=Pitchfork Local CA`, `O=Pitchfork`) has unconstrained basic
/// constraints and may only sign certificates and CRLs. Existing files are
/// overwritten. On Unix the private key file is forced to mode `0600`, even
/// when it already existed.
///
/// # Errors
/// Fails when a parent directory cannot be created, key generation or
/// self-signing fails, or either file cannot be written.
#[cfg(feature = "proxy-tls")]
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    // The certificate and key paths are configured independently
    // (`proxy.tls_cert` / `proxy.tls_key`), so both parents must exist.
    for parent in [cert_path.parent(), key_path.parent()].into_iter().flatten() {
        std::fs::create_dir_all(parent)
            .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
        std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(key_path)
            .and_then(|mut f| {
                // `mode` only applies when the file is created, so tighten an
                // existing file's permissions too before writing the secret.
                f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                f.write_all(key_pair.serialize_pem().as_bytes())
            })
            .map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
            miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
        })?;
        log::debug!(
            "CA private key written to {} (file permissions are not restricted \
             on non-Unix platforms — consider restricting access manually)",
            key_path.display()
        );
    }

    Ok(())
}
586
/// TLS certificate resolver that issues per-host leaf certificates on demand.
///
/// Each SNI name is served a leaf certificate signed by the local CA.
/// Issued certificates are kept in memory and persisted under `host_certs_dir`.
#[cfg(feature = "proxy-tls")]
struct SniCertResolver {
    /// Local CA used to sign leaf certificates.
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory holding persisted leaf certificate + key PEM files.
    host_certs_dir: std::path::PathBuf,
    /// In-memory certificates, keyed by domain name.
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Domains whose certificate is currently being loaded or issued.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Notified whenever a domain is removed from `pending`.
    pending_cv: std::sync::Condvar,
}
621
/// Minimal `Debug` output (required by `ResolvesServerCert`): only the type
/// name is printed, never the CA key or the cached certificates.
#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SniCertResolver");
        builder.finish_non_exhaustive()
    }
}
628
#[cfg(feature = "proxy-tls")]
impl SniCertResolver {
    /// Loads the local CA and prepares the on-disk leaf cache
    /// (`host-certs/` next to the CA certificate).
    ///
    /// Persisted leaf certificates older than the CA certificate are deleted
    /// (see `purge_stale_host_certs`).
    ///
    /// # Errors
    /// Fails when either CA file cannot be read or parsed, or the cache
    /// directory cannot be created.
    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
        let ca_key_pem = std::fs::read_to_string(ca_key_path)
            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
        })?;

        // Cheap sanity check so an obviously wrong file gets a clear error.
        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
            miette::bail!("CA cert file does not contain a valid PEM certificate");
        }

        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;
        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;

        let host_certs_dir = ca_cert_path
            .parent()
            .unwrap_or(std::path::Path::new("."))
            .join("host-certs");
        std::fs::create_dir_all(&host_certs_dir)
            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;
        Self::purge_stale_host_certs(&host_certs_dir, ca_cert_path);

        Ok(Self {
            issuer,
            host_certs_dir,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
            pending_cv: std::sync::Condvar::new(),
        })
    }

    /// Deletes persisted leaf certificates that are older than the CA certificate file.
    ///
    /// Persisted leaves are reused without checking which CA signed them.
    /// After the CA is regenerated (e.g. `ca.pem` was deleted), every existing
    /// leaf was signed by the old CA. Clients that trust only the new CA would
    /// reject those leaves. Leaves issued by the current CA are written after
    /// it, so anything older is stale.
    ///
    /// This is best-effort: files with unreadable metadata are left alone. A CA
    /// restored with an older modification time is not detected.
    fn purge_stale_host_certs(host_certs_dir: &std::path::Path, ca_cert_path: &std::path::Path) {
        let Ok(ca_modified) = std::fs::metadata(ca_cert_path).and_then(|m| m.modified()) else {
            return;
        };
        let Ok(entries) = std::fs::read_dir(host_certs_dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.extension() != Some(std::ffi::OsStr::new("pem")) {
                continue;
            }
            let stale = entry
                .metadata()
                .and_then(|m| m.modified())
                .map(|modified| modified < ca_modified)
                .unwrap_or(false);
            if stale {
                match std::fs::remove_file(&path) {
                    Ok(()) => log::debug!("Removed leaf cert predating the CA: {}", path.display()),
                    Err(e) => {
                        log::debug!("Failed to remove stale leaf cert {}: {e}", path.display())
                    }
                }
            }
        }
    }

    /// Path of the persisted PEM (leaf certificate followed by its key) for `domain`.
    fn host_cert_path(&self, domain: &str) -> std::path::PathBuf {
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        self.host_certs_dir.join(format!("{safe_name}.pem"))
    }

    /// Returns the in-memory certificate for `domain`, if any.
    ///
    /// A poisoned lock is recovered: inserts leave the map consistent.
    fn cached(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(domain)
            .cloned()
    }

    /// Stores `ck` in the in-memory cache and returns the shared handle.
    fn remember(
        &self,
        domain: &str,
        ck: rustls::sign::CertifiedKey,
    ) -> Arc<rustls::sign::CertifiedKey> {
        let ck = Arc::new(ck);
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(domain.to_string(), Arc::clone(&ck));
        ck
    }

    /// Returns the certificate for `domain`, loading or issuing it on first use.
    ///
    /// At most one thread works on a given domain at a time. Other threads
    /// block on `pending_cv` until it finishes. They then reuse its result,
    /// or take over if it produced nothing.
    ///
    /// The claim in `pending` is released by a drop guard. A panic while
    /// issuing therefore cannot leave waiters blocked forever.
    ///
    /// Note: this blocks the calling thread while waiting or issuing.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        /// Removes a claimed domain from `pending` and wakes all waiters on drop.
        struct PendingClaim<'a> {
            resolver: &'a SniCertResolver,
            domain: &'a str,
        }

        impl Drop for PendingClaim<'_> {
            fn drop(&mut self) {
                let mut pending = self
                    .resolver
                    .pending
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                pending.remove(self.domain);
                self.resolver.pending_cv.notify_all();
            }
        }

        // Fast path: already issued, no need to touch `pending`.
        if let Some(ck) = self.cached(domain) {
            return Some(ck);
        }

        let _claim = {
            let mut pending = self
                .pending
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            while pending.contains(domain) {
                pending = self
                    .pending_cv
                    .wait(pending)
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
            }
            // The previous owner, or a thread that finished between the fast
            // path and taking the lock, may have cached the certificate already.
            if let Some(ck) = self.cached(domain) {
                return Some(ck);
            }
            pending.insert(domain.to_string());
            PendingClaim {
                resolver: self,
                domain,
            }
        };

        self.get_or_create_inner(domain)
    }

    /// Loads the persisted certificate for `domain`, or issues a new one.
    ///
    /// A new certificate is issued when none is persisted or the stored one
    /// cannot be used (unreadable, unparsable or expired). Unusable files are
    /// deleted. The result is cached in memory.
    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let disk_path = self.host_cert_path(domain);

        if disk_path.exists() {
            match self.load_from_disk(&disk_path) {
                Ok(ck) => return Some(self.remember(domain, ck)),
                Err(e) => {
                    log::debug!("Discarding cached cert for '{domain}': {e}");
                    let _ = std::fs::remove_file(&disk_path);
                }
            }
        }

        match self.sign_for_domain(domain) {
            Ok(ck) => Some(self.remember(domain, ck)),
            Err(e) => {
                log::warn!("Failed to issue certificate for '{domain}': {e}");
                None
            }
        }
    }

    /// Reads a persisted leaf certificate + private key PEM from `path`.
    ///
    /// # Errors
    /// Fails when the file cannot be read, or contains no certificate or key.
    /// Also fails when the first certificate cannot be parsed or has expired.
    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
        use chrono::Utc;
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::{certs, private_key};

        let pem = std::fs::read_to_string(path)
            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;

        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;

        if cert_ders.is_empty() {
            miette::bail!("No certificates found in {}", path.display());
        }

        let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
            miette::miette!("Failed to parse certificate from {}: {e}", path.display())
        })?;
        if cert.validity().not_after.timestamp() < Utc::now().timestamp() {
            miette::bail!(
                "Cached certificate at {} has expired — will regenerate",
                path.display()
            );
        }

        let key_der = private_key(&mut pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;

        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
    }

    /// Issues a leaf certificate for `domain` signed by the local CA and persists it.
    ///
    /// The certificate is valid from yesterday for 397 days. Its SANs are
    /// `domain` plus `*.<parent>` when the parent itself contains a dot.
    ///
    /// Persisting is best-effort: a failed write is logged, and the in-memory
    /// certificate is still returned.
    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
        use chrono::{Datelike, Duration, Utc};
        use rcgen::date_time_ymd;
        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::private_key;

        let mut params = CertificateParams::default();
        let mut dn = DistinguishedName::new();
        dn.push(DnType::CommonName, domain);
        params.distinguished_name = dn;

        let yesterday = Utc::now() - Duration::days(1);
        let expiry = Utc::now() + Duration::days(397);
        params.not_before = date_time_ymd(
            yesterday.year(),
            yesterday.month() as u8,
            yesterday.day() as u8,
        );
        params.not_after = date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);

        let mut sans = vec![SanType::DnsName(domain.to_string().try_into().map_err(
            |e| miette::miette!("Invalid domain name '{domain}': {e}"),
        )?)];
        if let Some((_, parent)) = domain.split_once('.') {
            // Only add a wildcard below multi-label parents (never `*.<tld>`).
            if parent.contains('.') {
                if let Ok(wc) = format!("*.{parent}").try_into() {
                    sans.push(SanType::DnsName(wc));
                }
            }
        }
        params.subject_alt_names = sans;

        let leaf_key = rcgen::KeyPair::generate()
            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
        let leaf_cert = params
            .signed_by(&leaf_key, &self.issuer)
            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;

        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
        let key_pem = leaf_key.serialize_pem();
        let key_der = private_key(&mut key_pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;

        let disk_path = self.host_cert_path(domain);
        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::OpenOptionsExt;
            if let Err(e) = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&disk_path)
                .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
            {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            }
        }
        #[cfg(not(unix))]
        {
            if let Err(e) = std::fs::write(&disk_path, combined_pem) {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            } else {
                log::debug!(
                    "Leaf cert for '{domain}' written to {} (file permissions are not \
                     restricted on non-Unix platforms — consider restricting access manually)",
                    disk_path.display()
                );
            }
        }

        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
    }
}
939
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    /// Selects the certificate for a handshake from the client's SNI name.
    ///
    /// Handshakes without SNI get `None`.
    ///
    /// NOTE(review): the SNI name is chosen by the client and is not checked
    /// against the proxy TLD here. Any client that can reach the listener can
    /// therefore make the resolver issue and persist a certificate for an
    /// arbitrary name. Consider restricting names to the configured TLD.
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|sni| self.get_or_create(sni))
    }
}
950
951fn get_request_host(req: &Request) -> Option<String> {
957 let authority = req
959 .uri()
960 .authority()
961 .map(|a| a.as_str().to_string())
962 .filter(|s| !s.is_empty());
963
964 authority.or_else(|| {
965 req.headers()
966 .get(HOST)
967 .and_then(|h| h.to_str().ok())
968 .map(str::to_string)
969 })
970}
971
972fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
983 let remote_addr = req
984 .extensions()
985 .get::<axum::extract::ConnectInfo<SocketAddr>>()
986 .map(|ci| ci.0.ip().to_string())
987 .unwrap_or_else(|| "127.0.0.1".to_string());
988
989 let proto = if is_tls { "https" } else { "http" };
990 let default_port = if is_tls { "443" } else { "80" };
991
992 let forwarded_for = remote_addr.clone();
995 let forwarded_proto = proto.to_string();
996 let forwarded_host = host_header.to_string();
997 let forwarded_port = host_header
998 .rsplit_once(':')
999 .map(|(_, port)| port.to_string())
1000 .unwrap_or_else(|| default_port.to_string());
1001
1002 for name in [
1008 "x-forwarded-for",
1009 "x-forwarded-proto",
1010 "x-forwarded-host",
1011 "x-forwarded-port",
1012 "forwarded",
1013 ] {
1014 if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
1015 req.headers_mut().remove(&header_name);
1016 }
1017 }
1018
1019 let headers = [
1020 ("x-forwarded-for", forwarded_for),
1021 ("x-forwarded-proto", forwarded_proto),
1022 ("x-forwarded-host", forwarded_host),
1023 ("x-forwarded-port", forwarded_port),
1024 ];
1025
1026 for (name, value) in headers {
1027 if let Ok(v) = HeaderValue::from_str(&value) {
1028 let header_name = axum::http::HeaderName::from_static(name);
1029 req.headers_mut().insert(header_name, v);
1030 }
1031 }
1032}
1033
1034async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
1039 let Some(raw_host) = get_request_host(&req) else {
1041 return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
1042 };
1043 let host = if raw_host.starts_with('[') {
1047 raw_host
1049 .split("]:")
1050 .next()
1051 .unwrap_or(&raw_host)
1052 .trim_start_matches('[')
1053 .trim_end_matches(']')
1054 .to_string()
1055 } else {
1056 raw_host.split(':').next().unwrap_or(&raw_host).to_string()
1058 };
1059
1060 let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
1071 let hops: u64 = if is_from_pitchfork {
1072 req.headers()
1073 .get(PROXY_HOPS_HEADER)
1074 .and_then(|v| v.to_str().ok())
1075 .and_then(|s| s.parse().ok())
1076 .unwrap_or(0)
1077 } else {
1078 0
1080 };
1081 if hops >= MAX_PROXY_HOPS {
1082 return error_response(
1083 StatusCode::LOOP_DETECTED,
1084 &format!(
1085 "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
1086 This usually means a backend is proxying back through pitchfork without rewriting \n\
1087 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
1088 ),
1089 );
1090 }
1091
1092 let target_port = match resolve_target(&host, &state.tld).await {
1094 ResolveResult::Ready(port) => port,
1095 ResolveResult::Starting { slug } => {
1096 return starting_html_response(&slug, &raw_host);
1097 }
1098 ResolveResult::NotFound => {
1099 return error_response(
1100 StatusCode::BAD_GATEWAY,
1101 &format!(
1102 "No daemon found for host '{host}'.\n\
1103 Make sure the daemon has a slug, is running, and has a port configured.\n\
1104 Expected format: <slug>.{tld}",
1105 tld = state.tld
1106 ),
1107 );
1108 }
1109 ResolveResult::Error(msg) => {
1110 return error_response(StatusCode::BAD_GATEWAY, &msg);
1111 }
1112 };
1113
1114 let path_and_query = req
1116 .uri()
1117 .path_and_query()
1118 .map(|pq| pq.as_str())
1119 .unwrap_or("/");
1120
1121 let forward_uri = match Uri::builder()
1122 .scheme("http")
1123 .authority(format!("localhost:{target_port}"))
1124 .path_and_query(path_and_query)
1125 .build()
1126 {
1127 Ok(uri) => uri,
1128 Err(e) => {
1129 return error_response(
1130 StatusCode::INTERNAL_SERVER_ERROR,
1131 &format!("Failed to build forward URI: {e}"),
1132 );
1133 }
1134 };
1135
1136 *req.uri_mut() = forward_uri;
1138 req.headers_mut().insert(
1139 HOST,
1140 HeaderValue::from_str(&format!("localhost:{target_port}"))
1141 .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
1142 );
1143
1144 inject_forwarded_headers(&mut req, state.is_tls, &raw_host);
1146
1147 if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
1149 req.headers_mut()
1150 .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
1151 }
1152
1153 let pseudo_headers: Vec<_> = req
1158 .headers()
1159 .keys()
1160 .filter(|k| k.as_str().starts_with(':'))
1161 .cloned()
1162 .collect();
1163 for key in pseudo_headers {
1164 req.headers_mut().remove(&key);
1165 }
1166
1167 let client_upgrade = hyper::upgrade::on(&mut req);
1169
1170 let result = match tokio::time::timeout(
1178 std::time::Duration::from_secs(120),
1179 state.client.request(req),
1180 )
1181 .await
1182 {
1183 Ok(r) => r,
1184 Err(_elapsed) => {
1185 let msg = format!(
1186 "Request to daemon on port {target_port} timed out after 120 s.\n\
1187 The daemon accepted the connection but did not respond in time."
1188 );
1189 log::warn!("{msg}");
1190 if let Some(ref on_error) = state.on_error {
1191 on_error(&msg);
1192 }
1193 return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
1194 }
1195 };
1196 match result {
1197 Ok(mut resp) => {
1198 let backend_upgrade = hyper::upgrade::on(&mut resp);
1200 let (mut parts, body) = resp.into_parts();
1201
1202 parts.headers.insert(
1204 axum::http::HeaderName::from_static(PITCHFORK_HEADER),
1205 HeaderValue::from_static("1"),
1206 );
1207
1208 parts.headers.remove(PROXY_HOPS_HEADER);
1210
1211 if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
1216 for h in HOP_BY_HOP_HEADERS {
1217 if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
1218 parts.headers.remove(&name);
1219 }
1220 }
1221 }
1222
1223 if parts.status == StatusCode::SWITCHING_PROTOCOLS {
1225 tokio::spawn(async move {
1230 if let (Ok(client_upgraded), Ok(backend_upgraded)) =
1231 (client_upgrade.await, backend_upgrade.await)
1232 {
1233 let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
1234 let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
1235 let _ =
1243 tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
1244 }
1245 });
1246 return Response::from_parts(parts, Body::empty());
1247 }
1248
1249 Response::from_parts(parts, Body::new(body))
1252 }
1253 Err(e) => {
1254 let msg = format!(
1255 "Failed to connect to daemon on port {target_port}: {e}\n\
1256 The daemon may have stopped or is not yet ready."
1257 );
1258 if let Some(ref on_error) = state.on_error {
1259 on_error(&msg);
1260 } else {
1261 log::warn!("{msg}");
1262 }
1263 error_response(StatusCode::BAD_GATEWAY, &msg)
1264 }
1265 }
1266}
1267
1268async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1287 let Some(subdomain) = strip_tld(host, tld) else {
1289 return ResolveResult::NotFound;
1290 };
1291
1292 let Some(cached) = cached_slug_lookup(&subdomain).await else {
1294 return ResolveResult::NotFound;
1295 };
1296
1297 let daemon_name = &cached.daemon_name;
1298 let expected_namespace = &cached.namespace;
1299
1300 let daemons = {
1302 let state_file = SUPERVISOR.state_file.lock().await;
1303 state_file.daemons.clone()
1304 };
1305
1306 let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1309 .iter()
1310 .filter(|(id, d)| {
1311 id.name() == daemon_name
1312 && d.status.is_running()
1313 && match expected_namespace {
1314 Some(ns) => id.namespace() == ns,
1315 None => true,
1316 }
1317 })
1318 .collect();
1319
1320 match running_matches.as_slice() {
1321 [] => {
1322 try_auto_start(&cached.slug, &cached).await
1326 }
1327 [(_, d)] => {
1328 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1329 ResolveResult::Ready(port)
1330 } else {
1331 ResolveResult::NotFound
1332 }
1333 }
1334 _ => {
1335 let d = running_matches[0].1;
1336 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1337 ResolveResult::Ready(port)
1338 } else {
1339 ResolveResult::NotFound
1340 }
1341 }
1342 }
1343}
1344
1345struct AutoStartGuard {
1352 daemon_id: DaemonId,
1353}
1354
1355impl Drop for AutoStartGuard {
1356 fn drop(&mut self) {
1357 let daemon_id = self.daemon_id.clone();
1358 tokio::spawn(async move {
1362 AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1363 });
1364 }
1365}
1366
1367async fn try_auto_start(slug: &str, cached: &CachedSlugEntry) -> ResolveResult {
1378 let s = settings();
1379 if !s.proxy.auto_start {
1380 return ResolveResult::NotFound;
1381 }
1382
1383 let ns = cached
1388 .namespace
1389 .clone()
1390 .unwrap_or_else(|| "global".to_string());
1391 let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1392 Ok(id) => id,
1393 Err(_) => return ResolveResult::NotFound,
1394 };
1395
1396 {
1399 let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1400 if !in_progress.insert(daemon_id.clone()) {
1401 return ResolveResult::Starting {
1402 slug: slug.to_string(),
1403 };
1404 }
1405 }
1406
1407 let _guard = AutoStartGuard {
1409 daemon_id: daemon_id.clone(),
1410 };
1411
1412 let timeout = s.proxy_auto_start_timeout();
1416
1417 match tokio::time::timeout(timeout, try_auto_start_inner(slug, cached, &daemon_id)).await {
1418 Ok(result) => result,
1419 Err(_elapsed) => {
1420 log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1421 ResolveResult::Error(format!(
1422 "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1423 The daemon did not become ready and bind a port within the configured \
1424 proxy_auto_start_timeout.\n\
1425 Increase the timeout or check the daemon's logs for slow startup."
1426 ))
1427 }
1428 }
1429}
1430
1431async fn try_auto_start_inner(
1435 slug: &str,
1436 cached: &CachedSlugEntry,
1437 daemon_id: &DaemonId,
1438) -> ResolveResult {
1439 let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(&cached.dir) {
1441 Ok(pt) => pt,
1442 Err(e) => {
1443 log::warn!(
1444 "Auto-start: failed to load config from {}: {e}",
1445 cached.dir.display()
1446 );
1447 return ResolveResult::NotFound;
1448 }
1449 };
1450
1451 let daemon_config = match pt.daemons.get(daemon_id) {
1452 Some(cfg) => cfg,
1453 None => {
1454 log::debug!(
1455 "Auto-start: daemon {daemon_id} not found in config at {}",
1456 cached.dir.display()
1457 );
1458 return ResolveResult::NotFound;
1459 }
1460 };
1461
1462 let opts = crate::ipc::batch::StartOptions::default();
1466 let mut run_opts = match crate::ipc::batch::build_run_options(daemon_id, daemon_config, &opts) {
1467 Ok(o) => o,
1468 Err(e) => {
1469 log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1470 return ResolveResult::Error(format!("Failed to build run options: {e}"));
1471 }
1472 };
1473
1474 if run_opts.dir.0.as_os_str().is_empty() {
1475 run_opts.dir = crate::config_types::Dir(cached.dir.clone());
1476 }
1477
1478 log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1479
1480 let run_result = SUPERVISOR.run(run_opts).await;
1483
1484 if let Err(e) = run_result {
1485 log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1486 return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1487 }
1488
1489 let poll_interval = std::time::Duration::from_millis(250);
1493
1494 loop {
1495 let daemons = {
1496 let sf = SUPERVISOR.state_file.lock().await;
1497 sf.daemons.clone()
1498 };
1499
1500 if let Some(d) = daemons.get(daemon_id) {
1501 if d.status.is_running() {
1502 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1503 log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1504 return ResolveResult::Ready(port);
1505 }
1506 } else {
1507 log::warn!(
1508 "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1509 d.status
1510 );
1511 return ResolveResult::Error(format!(
1512 "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1513 Check its logs for errors."
1514 ));
1515 }
1516 } else {
1517 log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1520 return ResolveResult::Error(format!(
1521 "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1522 Check its logs for errors."
1523 ));
1524 }
1525
1526 tokio::time::sleep(poll_interval).await;
1527 }
1528}
1529
/// Remove a trailing `.{tld}` from `host`, returning the label(s) in front of it.
///
/// Yields `None` when `host` does not end in `.` followed by `tld`, or when nothing
/// would be left after stripping (e.g. `"localhost"` or `".localhost"`).
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    let without_tld = host.strip_suffix(tld)?;
    let prefix = without_tld.strip_suffix('.')?;
    if prefix.is_empty() {
        None
    } else {
        Some(prefix.to_owned())
    }
}
1541
/// Build the user-facing message for a failed proxy listener bind on `port`.
///
/// Ports below 1024 get a hint about elevated privileges (including the `sudo`
/// command to run); any other port gets a hint that it may already be in use.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    match port {
        0..=1023 => format!(
            "Failed to bind proxy server to port {port}: {err}\n\
             Hint: ports below 1024 require elevated privileges. \
             Try: sudo pitchfork supervisor start"
        ),
        _ => format!(
            "Failed to bind proxy server to port {port}: {err}\n\
             Hint: another process may already be using this port."
        ),
    }
}
1557
1558fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1563 let escaped_slug = slug
1564 .replace('&', "&")
1565 .replace('<', "<")
1566 .replace('>', ">")
1567 .replace('"', """)
1568 .replace('\'', "'");
1569 let escaped_host = raw_host
1570 .replace('&', "&")
1571 .replace('<', "<")
1572 .replace('>', ">")
1573 .replace('"', """)
1574 .replace('\'', "'");
1575
1576 let html = format!(
1577 r##"<!DOCTYPE html>
1578<html lang="en">
1579<head>
1580 <meta charset="UTF-8">
1581 <meta name="viewport" content="width=device-width, initial-scale=1">
1582 <meta http-equiv="refresh" content="2">
1583 <title>Starting {escaped_slug}… — pitchfork</title>
1584 <style>
1585 * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1586 body {{
1587 font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1588 background: #0f1117;
1589 color: #e1e4e8;
1590 display: flex;
1591 align-items: center;
1592 justify-content: center;
1593 min-height: 100vh;
1594 }}
1595 .container {{
1596 text-align: center;
1597 max-width: 480px;
1598 padding: 2rem;
1599 }}
1600 .spinner {{
1601 width: 48px;
1602 height: 48px;
1603 border: 4px solid rgba(255, 255, 255, 0.1);
1604 border-top-color: #58a6ff;
1605 border-radius: 50%;
1606 animation: spin 0.8s linear infinite;
1607 margin: 0 auto 1.5rem;
1608 }}
1609 @keyframes spin {{
1610 to {{ transform: rotate(360deg); }}
1611 }}
1612 h1 {{
1613 font-size: 1.5rem;
1614 font-weight: 600;
1615 margin-bottom: 0.5rem;
1616 }}
1617 .slug {{
1618 color: #58a6ff;
1619 font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1620 }}
1621 .host {{
1622 color: #8b949e;
1623 font-size: 0.875rem;
1624 margin-top: 0.25rem;
1625 }}
1626 .hint {{
1627 color: #8b949e;
1628 font-size: 0.8rem;
1629 margin-top: 1.5rem;
1630 }}
1631 </style>
1632</head>
1633<body>
1634 <div class="container">
1635 <div class="spinner"></div>
1636 <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1637 <p class="host">{escaped_host}</p>
1638 <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1639 </div>
1640</body>
1641</html>"##
1642 );
1643
1644 Response::builder()
1645 .status(StatusCode::SERVICE_UNAVAILABLE)
1646 .header("content-type", "text/html; charset=utf-8")
1647 .header("retry-after", "2")
1648 .body(Body::from(html))
1649 .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1650}
1651
1652async fn redirect_to_https_handler(req: Request) -> Response {
1661 if req.headers().contains_key("upgrade") {
1663 log::warn!("Dropping plain-HTTP WebSocket upgrade attempt — use wss:// instead of ws://");
1664 return (
1665 StatusCode::BAD_REQUEST,
1666 "WebSocket over plain HTTP is not supported on the HTTPS port. Use wss:// instead.",
1667 )
1668 .into_response();
1669 }
1670
1671 let raw_host = get_request_host(&req);
1672 let Some(raw_host) = raw_host else {
1673 return (StatusCode::BAD_REQUEST, "Missing Host header").into_response();
1674 };
1675
1676 let hostname = if raw_host.starts_with('[') {
1678 raw_host
1680 .split_once("]:")
1681 .map(|(host, _)| host)
1682 .unwrap_or(&raw_host)
1683 .trim_start_matches('[')
1684 .trim_end_matches(']')
1685 } else {
1686 let mut parts = raw_host.rsplitn(2, ':');
1688 let last = parts.next().unwrap_or(&raw_host);
1689 parts.next().unwrap_or(last)
1690 };
1691
1692 let path = req
1693 .uri()
1694 .path_and_query()
1695 .map(|pq| pq.as_str())
1696 .unwrap_or("/");
1697
1698 let https_port = match u16::try_from(settings().proxy.port).ok().filter(|&p| p > 0) {
1699 Some(443) | None => String::new(),
1700 Some(port) => format!(":{port}"),
1701 };
1702
1703 let host_for_url = if raw_host.starts_with('[') {
1704 format!("[{hostname}]")
1705 } else {
1706 hostname.to_string()
1707 };
1708
1709 let location = format!("https://{host_for_url}{https_port}{path}");
1710 (
1711 StatusCode::FOUND,
1712 [(axum::http::header::LOCATION, location)],
1713 )
1714 .into_response()
1715}
1716
1717fn error_response(status: StatusCode, message: &str) -> Response {
1719 (status, message.to_string()).into_response()
1720}
1721
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strip_tld() {
        assert_eq!(
            strip_tld("api.myproject.localhost", "localhost"),
            Some("api.myproject".to_string())
        );
        assert_eq!(
            strip_tld("api.localhost", "localhost"),
            Some("api".to_string())
        );
        assert_eq!(strip_tld("localhost", "localhost"), None);
        assert_eq!(
            strip_tld("api.myproject.test", "test"),
            Some("api.myproject".to_string())
        );
        assert_eq!(strip_tld("other.com", "localhost"), None);
    }

    #[test]
    fn test_strip_tld_requires_dot_separated_label() {
        // Nothing left in front of the TLD.
        assert_eq!(strip_tld(".localhost", "localhost"), None);
        // TLD text present but not as its own dot-separated label.
        assert_eq!(strip_tld("foolocalhost", "localhost"), None);
        // TLD must be the final label, not a middle one.
        assert_eq!(strip_tld("api.localhost.test", "localhost"), None);
    }

    #[test]
    fn test_bind_error_message_hints() {
        let err = std::io::Error::new(std::io::ErrorKind::AddrInUse, "address in use");

        // 1023 is the highest privileged port: expect the sudo hint.
        let privileged = bind_error_message(1023, &err);
        assert!(privileged.starts_with("Failed to bind proxy server to port 1023: address in use\n"));
        assert!(privileged.contains("sudo pitchfork supervisor start"));

        // 1024 is the first unprivileged port: expect the port-in-use hint instead.
        let unprivileged = bind_error_message(1024, &err);
        assert!(unprivileged.starts_with("Failed to bind proxy server to port 1024: address in use\n"));
        assert!(unprivileged.ends_with("Hint: another process may already be using this port."));
        assert!(!unprivileged.contains("sudo"));
    }

    fn make_entry(name: &str) -> CachedSlugEntry {
        CachedSlugEntry {
            slug: name.to_string(),
            namespace: None,
            daemon_name: name.to_string(),
            dir: std::path::PathBuf::from(format!("/tmp/{name}")),
        }
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_match() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_subdomain_fallback() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_nested_fallback() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("a.b.myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_no_match() {
        let entries = std::collections::HashMap::new();
        let result = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(result.is_none());
    }

    #[test]
    fn test_wildcard_slug_lookup_disabled() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("tenant.myapp", &entries, false);
        assert!(result.is_none());
        let result = wildcard_slug_lookup("myapp", &entries, false);
        assert!(result.is_some());
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_beats_wildcard() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let mut tenant_entry = make_entry("tenant-daemon");
        tenant_entry.slug = "tenant.myapp".to_string();
        entries.insert("tenant.myapp".to_string(), tenant_entry);
        let result = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "tenant-daemon");
    }

    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("ca.pem");
        let key_path = dir.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();

        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        assert!(
            key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY"),
            "should be PEM key"
        );
    }
}