1use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Response header that `proxy_handler` sets to `1` on every response it relays
/// from a backend.
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header carrying the number of times a request has already passed
/// through this proxy. `proxy_handler` reads it, forwards it incremented by one,
/// and removes it from backend responses.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// Hop count at which `proxy_handler` stops forwarding and answers with
/// `508 Loop Detected`.
const MAX_PROXY_HOPS: u64 = 5;

/// Header names that `proxy_handler` removes from backend responses when the
/// proxy runs with TLS enabled and the response is not `101 Switching Protocols`.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
/// How long a slug table built by `get_cached_slugs` is served before the next
/// caller triggers a rebuild.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);
62
/// One routable slug as captured by `build_slug_entries`.
#[derive(Clone, Debug)]
struct CachedSlugEntry {
    /// The slug itself (also the key under which this entry is stored).
    slug: String,
    /// Namespace computed from `dir`; `None` when that lookup failed.
    namespace: Option<String>,
    /// Daemon name registered for the slug, or the slug when none was registered.
    daemon_name: String,
    /// Directory registered for the slug.
    dir: std::path::PathBuf,
    /// Worktrees discovered under `dir`, de-duplicated by sanitized branch name.
    /// Empty when worktree routing is disabled in settings.
    worktrees: Vec<crate::proxy::worktree::WorktreeEntry>,
}
77
/// Snapshot of the slug table together with its expiry time.
struct SlugCache {
    /// Shared, immutable slug table handed out to callers of `get_cached_slugs`.
    entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
    /// Instant after which `entries` is considered stale and rebuilt.
    expires_at: std::time::Instant,
}
83
/// Process-wide slug cache. It starts empty and already expired
/// (`expires_at` is the creation instant), so the first lookup rebuilds it.
static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
    once_cell::sync::Lazy::new(|| {
        tokio::sync::Mutex::new(SlugCache {
            entries: Arc::new(std::collections::HashMap::new()),
            expires_at: std::time::Instant::now(),
        })
    });
91
92fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
95 let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
96 let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
97 let worktree_enabled = crate::settings::settings().proxy.worktree;
98 for (slug, entry) in &global_slugs {
99 let ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&entry.dir).ok();
100 let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
101 let worktrees = if worktree_enabled {
102 let wts = crate::proxy::worktree::discover_worktrees(&entry.dir);
103 let mut seen = std::collections::HashMap::with_capacity(wts.len());
106 let mut deduped = Vec::with_capacity(wts.len());
107 for mut wt in wts {
108 let wt_ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&wt.path).ok();
109 wt.namespace = wt_ns;
110 match seen.entry(wt.sanitized_branch.clone()) {
111 std::collections::hash_map::Entry::Occupied(e) => {
112 log::warn!(
113 "Worktree slug collision: '{}' and '{}' both sanitize to '{}'. \
114 Only the first (in discovery order) will be routed.",
115 e.get(),
116 wt.branch,
117 wt.sanitized_branch,
118 );
119 }
120 std::collections::hash_map::Entry::Vacant(e) => {
121 e.insert(wt.branch.clone());
122 deduped.push(wt);
123 }
124 }
125 }
126 deduped
127 } else {
128 vec![]
129 };
130 entries.insert(
131 slug.clone(),
132 CachedSlugEntry {
133 slug: slug.clone(),
134 namespace: ns,
135 daemon_name,
136 dir: entry.dir.clone(),
137 worktrees,
138 },
139 );
140 }
141 entries
142}
143
144async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
150 {
152 let cache = SLUG_CACHE.lock().await;
153 if std::time::Instant::now() < cache.expires_at {
154 return Arc::clone(&cache.entries);
155 }
156 } let new_entries = Arc::new(
160 tokio::task::spawn_blocking(build_slug_entries)
161 .await
162 .unwrap_or_else(|e| {
163 log::warn!("Failed to refresh slug cache: {e}");
164 std::collections::HashMap::new()
165 }),
166 );
167
168 {
170 let mut cache = SLUG_CACHE.lock().await;
171 cache.entries = Arc::clone(&new_entries);
172 cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
173 }
174
175 new_entries
176}
177
178fn wildcard_slug_lookup<'a>(
184 subdomain: &str,
185 entries: &'a std::collections::HashMap<String, CachedSlugEntry>,
186 wildcard: bool,
187) -> Option<&'a CachedSlugEntry> {
188 entries.get(subdomain).or_else(|| {
189 if !wildcard {
190 return None;
191 }
192 subdomain
194 .match_indices('.')
195 .map(|(i, _)| &subdomain[i + 1..])
196 .find_map(|candidate| entries.get(candidate))
197 })
198}
199
200async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
207 let entries = get_cached_slugs().await;
208 wildcard_slug_lookup(subdomain, &entries, settings().proxy.wildcard).cloned()
209}
210
/// Daemons for which an auto-start is currently running. `try_auto_start`
/// inserts an id before starting and answers `Starting` when the id is already
/// present; `AutoStartGuard` removes the id again when it is dropped.
static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
    tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
220
/// Outcome of mapping a request host to a backend, as consumed by `proxy_handler`.
enum ResolveResult {
    /// A running daemon was found; the request is forwarded to this local port.
    Ready(u16),
    /// An auto-start for this slug is already under way; the handler serves the
    /// "starting" HTML page instead of proxying.
    Starting { slug: String },
    /// No daemon matches the host; the handler answers `502 Bad Gateway`.
    NotFound,
    /// Resolution failed; the message is returned to the client with `502 Bad Gateway`.
    Error(String),
}
233
/// Shared callback that receives the human-readable message when forwarding to a
/// backend fails or times out.
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;
237
/// Per-server state shared by all invocations of `proxy_handler`.
#[derive(Clone)]
struct ProxyState {
    /// HTTP client used to forward requests to backends.
    client: Arc<Client<HttpConnector, Body>>,
    /// TLD passed to `resolve_target` when mapping the request host to a slug.
    tld: String,
    /// Whether the proxy is configured for HTTPS. Selects `https`/`443` for the
    /// forwarded headers and enables hop-by-hop header stripping on responses.
    is_tls: bool,
    /// Optional error callback; `serve` leaves it unset.
    on_error: Option<OnErrorFn>,
}
249
250pub async fn serve(
258 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
259 cancel: tokio_util::sync::CancellationToken,
260) -> crate::Result<()> {
261 let s = settings();
262 let lan_enabled = s.proxy.lan || !s.proxy.lan_ip.is_empty();
263
264 let effective_tld = if lan_enabled {
265 "local".to_string()
266 } else {
267 s.proxy.tld.clone()
268 };
269
270 let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
271 let msg = format!(
272 "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
273 s.proxy.port
274 );
275 let _ = bind_tx.send(Err(msg.clone()));
276 miette::bail!("{msg}");
277 };
278
279 let mut connector = HttpConnector::new();
280 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
284
285 let client = Client::builder(TokioExecutor::new())
286 .pool_idle_timeout(std::time::Duration::from_secs(30))
289 .build(connector);
290
291 let state = ProxyState {
292 client: Arc::new(client),
293 tld: effective_tld.clone(),
294 is_tls: s.proxy.https,
295 on_error: None,
296 };
297
298 let app = Router::new().fallback(proxy_handler).with_state(state);
299
300 let bind_ip: std::net::IpAddr = if lan_enabled && s.proxy.host == "127.0.0.1" {
304 std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)
305 } else {
306 match s.proxy.host.parse() {
307 Ok(ip) => ip,
308 Err(_) => {
309 log::warn!(
310 "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
311 The proxy will only be reachable on the loopback interface.",
312 s.proxy.host
313 );
314 std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
315 }
316 }
317 };
318 let addr = SocketAddr::from((bind_ip, effective_port));
319
320 if s.proxy.https {
321 serve_https_with_http_fallback(app, addr, s, effective_port, bind_tx, cancel).await
322 } else {
323 serve_http(app, addr, effective_port, bind_tx, cancel).await
324 }
325}
326
327async fn serve_http(
329 app: Router,
330 addr: SocketAddr,
331 effective_port: u16,
332 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
333 cancel: tokio_util::sync::CancellationToken,
334) -> crate::Result<()> {
335 let listener = match TcpListener::bind(addr).await {
336 Ok(l) => {
337 if settings().proxy.sync_hosts {
338 crate::proxy::hosts::sync_hosts_from_settings();
339 }
340 let _ = bind_tx.send(Ok(()));
341 l
342 }
343 Err(e) => {
344 let msg = bind_error_message(effective_port, &e);
345 let _ = bind_tx.send(Err(msg.clone()));
346 return Err(miette::miette!("{msg}"));
347 }
348 };
349
350 log::info!("Proxy server listening on http://{addr}");
351 if effective_port < 1024 {
352 log::info!(
353 "Note: port {effective_port} is a privileged port. \
354 The supervisor must be started with sudo to bind to this port."
355 );
356 }
357 let shutdown_signal = cancel.clone().cancelled_owned();
358 axum::serve(
359 listener,
360 app.into_make_service_with_connect_info::<SocketAddr>(),
361 )
362 .with_graceful_shutdown(shutdown_signal)
363 .await
364 .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
365 Ok(())
366}
367
/// Serves `app` over HTTPS on `addr`, while also accepting plain HTTP on the
/// same port and answering it with `redirect_to_https_handler`.
///
/// Startup: resolves the CA paths, generates a local CA when either file is
/// missing, builds an SNI-based certificate resolver, and binds the listener.
/// The bind outcome is reported on `bind_tx`.
///
/// Per connection: the first byte is peeked; `0x16` (TLS handshake record) goes
/// through the TLS acceptor and is served by `app`, anything else is served by
/// the redirect router. The peer address is attached to TLS-served requests as
/// `ConnectInfo<SocketAddr>` so that `inject_forwarded_headers` reports the real
/// client address instead of its `127.0.0.1` fallback.
///
/// Shutdown: when `cancel` fires the accept loop stops and in-flight
/// connections get up to 10 seconds to finish.
///
/// # Errors
/// Returns an error when CA generation, resolver construction, or binding fails.
#[cfg(feature = "proxy-tls")]
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
    cancel: tokio_util::sync::CancellationToken,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);

    if !ca_cert_path.exists() || !ca_key_path.exists() {
        generate_ca(&ca_cert_path, &ca_key_path)?;
        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }

    // Installing fails when a provider is already installed; that is fine.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;

    let mut tls_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(resolver));
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            if settings().proxy.sync_hosts {
                crate::proxy::hosts::sync_hosts_from_settings();
            }
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    let redirect_app = Router::new().fallback(redirect_to_https_handler);

    let mut conn_tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    loop {
        // Reap finished connection tasks so the set does not grow without bound.
        while conn_tasks.try_join_next().is_some() {}

        tokio::select! {
            accept_result = listener.accept() => {
                let (stream, peer_addr) = match accept_result {
                    Ok(conn) => conn,
                    Err(e) => {
                        log::warn!("Accept error (will retry): {e}");
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        continue;
                    }
                };

                let acceptor = acceptor.clone();
                let app = app.clone();
                let redirect_app = redirect_app.clone();

                conn_tasks.spawn(async move {
                    let mut peek_buf = [0u8; 1];
                    match stream.peek(&mut peek_buf).await {
                        Ok(0) | Err(_) => return,
                        _ => {}
                    }

                    // 0x16 is the TLS "handshake" record type.
                    if peek_buf[0] == 0x16 {
                        match acceptor.accept(stream).await {
                            Ok(tls_stream) => {
                                // Expose the client address to handlers the same way
                                // `into_make_service_with_connect_info` does on the
                                // plain-HTTP path.
                                let app = app.layer(axum::Extension(
                                    axum::extract::ConnectInfo(peer_addr),
                                ));
                                let io = hyper_util::rt::TokioIo::new(tls_stream);
                                let svc = hyper_util::service::TowerToHyperService::new(app);
                                if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                                    .serve_connection_with_upgrades(io, svc)
                                    .await
                                {
                                    log::debug!("Connection error: {e}");
                                }
                            }
                            Err(e) => {
                                log::debug!("TLS handshake error: {e}");
                            }
                        }
                    } else {
                        let io = hyper_util::rt::TokioIo::new(stream);
                        let svc = hyper_util::service::TowerToHyperService::new(redirect_app);
                        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await;
                    }
                });

                while conn_tasks.try_join_next().is_some() {}
            }
            _ = cancel.cancelled() => {
                log::info!("Proxy server shutting down (cancel signal received)");
                break;
            }
        }
    }

    // Give in-flight connections a bounded amount of time to finish.
    let drain_timeout = std::time::Duration::from_secs(10);
    let _ = tokio::time::timeout(drain_timeout, async {
        while conn_tasks.join_next().await.is_some() {}
    })
    .await;

    Ok(())
}
516
517#[cfg(not(feature = "proxy-tls"))]
519async fn serve_https_with_http_fallback(
520 _app: Router,
521 _addr: SocketAddr,
522 _s: &crate::settings::Settings,
523 _effective_port: u16,
524 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
525 _cancel: tokio_util::sync::CancellationToken,
526) -> crate::Result<()> {
527 let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
528 Rebuild pitchfork with: cargo build --features proxy-tls"
529 .to_string();
530 let _ = bind_tx.send(Err(msg.clone()));
531 miette::bail!("{msg}")
532}
533
/// Returns the `(certificate, key)` paths of the CA used by the TLS proxy.
///
/// A non-empty `proxy.tls_cert` / `proxy.tls_key` setting is used verbatim;
/// an empty one falls back to `ca.pem` / `ca-key.pem` inside the `proxy`
/// sub-directory of the pitchfork state directory.
#[cfg(feature = "proxy-tls")]
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let proxy_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let pick = |configured: &str, fallback_name: &str| -> std::path::PathBuf {
        match configured {
            "" => proxy_dir.join(fallback_name),
            custom => std::path::PathBuf::from(custom),
        }
    };

    let cert_path = pick(&s.proxy.tls_cert, "ca.pem");
    let key_path = pick(&s.proxy.tls_key, "ca-key.pem");
    (cert_path, key_path)
}
553
/// Generates a self-signed local CA and writes it to `cert_path` / `key_path`
/// in PEM format.
///
/// The parent directories of *both* paths are created when missing, so a key
/// configured in a different directory than the certificate can be written.
/// On Unix the key file is forced to mode `0600` even when it already existed
/// (the `mode` passed to `OpenOptions` only applies to newly created files).
/// On other platforms the key is written with default permissions and a debug
/// message points that out.
///
/// # Errors
/// Returns an error when a directory cannot be created, key generation or
/// self-signing fails, or either file cannot be written.
#[cfg(feature = "proxy-tls")]
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    if let Some(parent) = cert_path.parent() {
        std::fs::create_dir_all(parent)
            .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
    }
    if let Some(parent) = key_path.parent() {
        std::fs::create_dir_all(parent)
            .map_err(|e| miette::miette!("Failed to create proxy key directory: {e}"))?;
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let write_key = || -> std::io::Result<()> {
            let mut file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(key_path)?;
            // Tighten a pre-existing file before any key material is written.
            file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
            file.write_all(key_pair.serialize_pem().as_bytes())
        };
        write_key().map_err(|e| {
            miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
        })?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
            miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
        })?;
        log::debug!(
            "CA private key written to {} (file permissions are not restricted \
             on non-Unix platforms — consider restricting access manually)",
            key_path.display()
        );
    }

    Ok(())
}
626
/// Certificate resolver that issues per-hostname leaf certificates signed by
/// the local CA, keeping them in memory and in a `host-certs` directory on disk.
#[cfg(feature = "proxy-tls")]
struct SniCertResolver {
    /// CA certificate and key used to sign leaf certificates.
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory holding one combined cert+key PEM file per hostname.
    host_certs_dir: std::path::PathBuf,
    /// In-memory certificates keyed by hostname.
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Hostnames whose certificate is currently being loaded or signed.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Notified whenever a hostname is removed from `pending`.
    pending_cv: std::sync::Condvar,
}
661
/// Opaque `Debug` output: prints only the type name and omits all fields, so
/// key material never ends up in logs.
#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SniCertResolver");
        builder.finish_non_exhaustive()
    }
}
668
#[cfg(feature = "proxy-tls")]
impl SniCertResolver {
    /// Loads the CA key and certificate and prepares the `host-certs` directory
    /// next to the CA certificate.
    ///
    /// # Errors
    /// Fails when either file cannot be read or parsed, when the certificate
    /// file contains no PEM certificate marker, or when the directory cannot be
    /// created.
    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
        let key_pem = std::fs::read_to_string(ca_key_path)
            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
        let cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
        })?;

        if !cert_pem.contains("BEGIN CERTIFICATE") {
            miette::bail!("CA cert file does not contain a valid PEM certificate");
        }

        let ca_key = rcgen::KeyPair::from_pem(&key_pem)
            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;
        let issuer = rcgen::Issuer::from_ca_cert_pem(&cert_pem, ca_key)
            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;

        let ca_dir = ca_cert_path.parent().unwrap_or(std::path::Path::new("."));
        let host_certs_dir = ca_dir.join("host-certs");
        std::fs::create_dir_all(&host_certs_dir)
            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;

        Ok(Self {
            issuer,
            host_certs_dir,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
            pending_cv: std::sync::Condvar::new(),
        })
    }

    /// Returns the certificate for `domain`, creating it when necessary.
    ///
    /// Only one caller at a time loads or signs the certificate for a given
    /// domain: others wait on `pending_cv` and re-check the in-memory cache once
    /// woken. Returns `None` when a lock is poisoned or creation fails.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        // Fast path: already in memory.
        if let Some(hit) = self.cache.lock().ok()?.get(domain) {
            return Some(Arc::clone(hit));
        }

        loop {
            let mut in_flight = self.pending.lock().ok()?;
            if !in_flight.contains(domain) {
                // Nobody is working on this domain: claim it and do the work.
                in_flight.insert(domain.to_string());
                break;
            }

            // Someone else holds the claim; sleep until any claim is released.
            let woken = self.pending_cv.wait(in_flight).ok()?;
            drop(woken);

            // The other caller may have filled the cache; otherwise loop and
            // try to claim the domain ourselves.
            if let Some(hit) = self.cache.lock().ok()?.get(domain) {
                return Some(Arc::clone(hit));
            }
        }

        let outcome = self.get_or_create_inner(domain);

        // Release the claim even if the `pending` lock was poisoned, so waiters
        // are not stuck forever.
        let mut in_flight = self
            .pending
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        in_flight.remove(domain);
        self.pending_cv.notify_all();
        drop(in_flight);

        outcome
    }

    /// Loads the certificate for `domain` from disk, or signs a new one when no
    /// usable file exists. An unusable file is deleted before re-signing. The
    /// result is stored in the in-memory cache.
    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let disk_path = self.disk_path_for(domain);

        if disk_path.exists() {
            match self.load_from_disk(&disk_path) {
                Ok(loaded) => return Some(self.remember(domain, loaded)),
                Err(_) => {
                    let _ = std::fs::remove_file(&disk_path);
                }
            }
        }

        let signed = self.sign_for_domain(domain).ok()?;
        Some(self.remember(domain, signed))
    }

    /// Path of the combined PEM file for `domain`: dots become underscores and
    /// `*` becomes `wildcard`.
    fn disk_path_for(&self, domain: &str) -> std::path::PathBuf {
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        self.host_certs_dir.join(format!("{safe_name}.pem"))
    }

    /// Wraps `ck` in an `Arc` and stores it in the in-memory cache (skipped when
    /// the cache lock is poisoned).
    fn remember(
        &self,
        domain: &str,
        ck: rustls::sign::CertifiedKey,
    ) -> Arc<rustls::sign::CertifiedKey> {
        let shared = Arc::new(ck);
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(domain.to_string(), Arc::clone(&shared));
        }
        shared
    }

    /// Reads a combined cert+key PEM file written by `sign_for_domain`.
    ///
    /// # Errors
    /// Fails when the file cannot be read or parsed, contains no certificate or
    /// no private key, or when the first certificate has already expired.
    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::{certs, private_key};

        let pem = std::fs::read_to_string(path)
            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;

        let chain: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;

        {
            // Scoped so the parsed view of the leaf is gone before `chain` is moved.
            let Some(leaf_der) = chain.first() else {
                miette::bail!("No certificates found in {}", path.display());
            };
            let (_, leaf) = x509_parser::parse_x509_certificate(leaf_der).map_err(|e| {
                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
            })?;
            let now_ts = chrono::Utc::now().timestamp();
            let not_after_ts = leaf.validity().not_after.timestamp();
            if not_after_ts < now_ts {
                miette::bail!(
                    "Cached certificate at {} has expired — will regenerate",
                    path.display()
                );
            }
        }

        let key_der = private_key(&mut pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;

        Ok(rustls::sign::CertifiedKey::new(chain, signing_key))
    }

    /// Issues a new leaf certificate for `domain`, signed by the CA.
    ///
    /// The certificate is valid from yesterday until 397 days from now and
    /// covers `domain` plus, when `domain` has at least three labels, a wildcard
    /// for its parent (`a.b.c` → `*.b.c`). The certificate and key are also
    /// written to the `host-certs` directory; a failure to persist is only
    /// logged.
    ///
    /// # Errors
    /// Fails when `domain` is not a valid DNS name or when key generation,
    /// signing, or key parsing fails.
    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
        use rcgen::date_time_ymd;
        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::private_key;

        let mut params = CertificateParams::default();
        let mut subject = DistinguishedName::new();
        subject.push(DnType::CommonName, domain);
        params.distinguished_name = subject;

        {
            use chrono::{Datelike, Duration, Utc};
            // Backdated by a day, presumably to tolerate clock skew — confirm.
            let valid_from = Utc::now() - Duration::days(1);
            let valid_until = Utc::now() + Duration::days(397);
            params.not_before = date_time_ymd(
                valid_from.year(),
                valid_from.month() as u8,
                valid_from.day() as u8,
            );
            params.not_after = date_time_ymd(
                valid_until.year(),
                valid_until.month() as u8,
                valid_until.day() as u8,
            );
        }

        let mut sans =
            vec![SanType::DnsName(domain.to_string().try_into().map_err(
                |e| miette::miette!("Invalid domain name '{domain}': {e}"),
            )?)];
        // Only add a wildcard when the parent itself still contains a dot.
        let wildcard_parent = domain
            .split_once('.')
            .map(|(_, parent)| parent)
            .filter(|parent| parent.contains('.'));
        if let Some(parent) = wildcard_parent {
            if let Ok(wildcard_name) = format!("*.{parent}").try_into() {
                sans.push(SanType::DnsName(wildcard_name));
            }
        }
        params.subject_alt_names = sans;

        let leaf_key = rcgen::KeyPair::generate()
            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
        let leaf_cert = params
            .signed_by(&leaf_key, &self.issuer)
            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;

        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
        let key_pem = leaf_key.serialize_pem();
        let key_der = private_key(&mut key_pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;

        // Best-effort persistence: certificate followed by key in one PEM file.
        let disk_path = self.disk_path_for(domain);
        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::OpenOptionsExt;
            let written = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&disk_path)
                .and_then(|mut f| f.write_all(combined_pem.as_bytes()));
            if let Err(e) = written {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            }
        }
        #[cfg(not(unix))]
        {
            match std::fs::write(&disk_path, combined_pem) {
                Err(e) => {
                    log::warn!(
                        "Failed to persist cert for '{domain}' to {}: {e}",
                        disk_path.display()
                    );
                }
                Ok(()) => {
                    log::debug!(
                        "Leaf cert for '{domain}' written to {} (file permissions are not \
                         restricted on non-Unix platforms — consider restricting access manually)",
                        disk_path.display()
                    );
                }
            }
        }

        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
    }
}
979
/// Picks the certificate for a TLS handshake from the SNI name sent by the
/// client. Handshakes without a server name get no certificate.
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|sni| self.get_or_create(sni))
    }
}
990
991fn get_request_host(req: &Request) -> Option<String> {
997 let authority = req
999 .uri()
1000 .authority()
1001 .map(|a| a.as_str().to_string())
1002 .filter(|s| !s.is_empty());
1003
1004 authority.or_else(|| {
1005 req.headers()
1006 .get(HOST)
1007 .and_then(|h| h.to_str().ok())
1008 .map(str::to_string)
1009 })
1010}
1011
1012fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
1023 let remote_addr = req
1024 .extensions()
1025 .get::<axum::extract::ConnectInfo<SocketAddr>>()
1026 .map(|ci| ci.0.ip().to_string())
1027 .unwrap_or_else(|| "127.0.0.1".to_string());
1028
1029 let proto = if is_tls { "https" } else { "http" };
1030 let default_port = if is_tls { "443" } else { "80" };
1031
1032 let forwarded_for = remote_addr.clone();
1035 let forwarded_proto = proto.to_string();
1036 let forwarded_host = host_header.to_string();
1037 let forwarded_port = host_header
1038 .rsplit_once(':')
1039 .map(|(_, port)| port.to_string())
1040 .unwrap_or_else(|| default_port.to_string());
1041
1042 for name in [
1048 "x-forwarded-for",
1049 "x-forwarded-proto",
1050 "x-forwarded-host",
1051 "x-forwarded-port",
1052 "forwarded",
1053 ] {
1054 if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
1055 req.headers_mut().remove(&header_name);
1056 }
1057 }
1058
1059 let headers = [
1060 ("x-forwarded-for", forwarded_for),
1061 ("x-forwarded-proto", forwarded_proto),
1062 ("x-forwarded-host", forwarded_host),
1063 ("x-forwarded-port", forwarded_port),
1064 ];
1065
1066 for (name, value) in headers {
1067 if let Ok(v) = HeaderValue::from_str(&value) {
1068 let header_name = axum::http::HeaderName::from_static(name);
1069 req.headers_mut().insert(header_name, v);
1070 }
1071 }
1072}
1073
/// Fallback handler of the proxy router: forwards the request to the daemon
/// that owns the requested host and relays the response.
///
/// Steps, in order:
/// 1. Determine the requested host (400 when missing) and strip its port.
/// 2. Read the hop counter and answer `508 Loop Detected` once it reaches
///    `MAX_PROXY_HOPS`.
/// 3. Resolve the host to a local port; a starting daemon yields the "starting"
///    page, an unknown host or resolution error yields `502 Bad Gateway`.
/// 4. Rewrite the URI and `Host` header to `localhost:<port>`, set the
///    forwarding headers and the incremented hop counter, and drop
///    pseudo-headers.
/// 5. Send the request with a 120 second timeout (504 on timeout, 502 on
///    connection failure).
/// 6. Tag the response, strip internal/hop-by-hop headers, and — for
///    `101 Switching Protocols` — bridge the upgraded client and backend
///    connections in a background task.
async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
    let Some(raw_host) = get_request_host(&req) else {
        return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
    };
    // Strip the port. Bracketed IPv6 literals are cut at "]:" and un-bracketed;
    // everything else is cut at the first ':'.
    let host = if raw_host.starts_with('[') {
        raw_host
            .split("]:")
            .next()
            .unwrap_or(&raw_host)
            .trim_start_matches('[')
            .trim_end_matches(']')
            .to_string()
    } else {
        raw_host.split(':').next().unwrap_or(&raw_host).to_string()
    };

    // A missing or unparseable hop header counts as zero hops.
    let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
    let hops: u64 = if is_from_pitchfork {
        req.headers()
            .get(PROXY_HOPS_HEADER)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse().ok())
            .unwrap_or(0)
    } else {
        0
    };
    if hops >= MAX_PROXY_HOPS {
        return error_response(
            StatusCode::LOOP_DETECTED,
            &format!(
                "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
                 This usually means a backend is proxying back through pitchfork without rewriting \n\
                 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
            ),
        );
    }

    let target_port = match resolve_target(&host, &state.tld).await {
        ResolveResult::Ready(port) => port,
        ResolveResult::Starting { slug } => {
            return starting_html_response(&slug, &raw_host);
        }
        ResolveResult::NotFound => {
            return error_response(
                StatusCode::BAD_GATEWAY,
                &format!(
                    "No daemon found for host '{host}'.\n\
                     Make sure the daemon has a slug, is running, and has a port configured.\n\
                     Expected format: <slug>.{tld}",
                    tld = state.tld
                ),
            );
        }
        ResolveResult::Error(msg) => {
            return error_response(StatusCode::BAD_GATEWAY, &msg);
        }
    };

    // Keep the original path and query; only scheme and authority change.
    let path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or("/");

    let forward_uri = match Uri::builder()
        .scheme("http")
        .authority(format!("localhost:{target_port}"))
        .path_and_query(path_and_query)
        .build()
    {
        Ok(uri) => uri,
        Err(e) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("Failed to build forward URI: {e}"),
            );
        }
    };

    *req.uri_mut() = forward_uri;
    // The backend sees itself as the host; the original host travels in
    // X-Forwarded-Host (set below from `raw_host`).
    req.headers_mut().insert(
        HOST,
        HeaderValue::from_str(&format!("localhost:{target_port}"))
            .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
    );

    inject_forwarded_headers(&mut req, state.is_tls, &raw_host);

    // Record this hop so a request looping back through the proxy is detected.
    if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
        req.headers_mut()
            .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
    }

    // Drop any header whose name starts with ':' before forwarding.
    let pseudo_headers: Vec<_> = req
        .headers()
        .keys()
        .filter(|k| k.as_str().starts_with(':'))
        .cloned()
        .collect();
    for key in pseudo_headers {
        req.headers_mut().remove(&key);
    }

    // Must be taken before `req` is moved into the client below.
    let client_upgrade = hyper::upgrade::on(&mut req);

    let result = match tokio::time::timeout(
        std::time::Duration::from_secs(120),
        state.client.request(req),
    )
    .await
    {
        Ok(r) => r,
        Err(_elapsed) => {
            let msg = format!(
                "Request to daemon on port {target_port} timed out after 120 s.\n\
                 The daemon accepted the connection but did not respond in time."
            );
            log::warn!("{msg}");
            if let Some(ref on_error) = state.on_error {
                on_error(&msg);
            }
            return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
        }
    };
    match result {
        Ok(mut resp) => {
            // Taken before the response is split into parts and body.
            let backend_upgrade = hyper::upgrade::on(&mut resp);
            let (mut parts, body) = resp.into_parts();

            parts.headers.insert(
                axum::http::HeaderName::from_static(PITCHFORK_HEADER),
                HeaderValue::from_static("1"),
            );

            // The hop counter is internal; never expose it to the client.
            parts.headers.remove(PROXY_HOPS_HEADER);

            // NOTE(review): presumably needed because TLS clients may negotiate
            // HTTP/2, which disallows connection-specific headers — confirm.
            // Upgrade responses keep them.
            if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
                for h in HOP_BY_HOP_HEADERS {
                    if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
                        parts.headers.remove(&name);
                    }
                }
            }

            if parts.status == StatusCode::SWITCHING_PROTOCOLS {
                // Once both sides have upgraded, copy bytes in both directions
                // until either side closes. Copy errors are ignored.
                tokio::spawn(async move {
                    if let (Ok(client_upgraded), Ok(backend_upgraded)) =
                        (client_upgrade.await, backend_upgrade.await)
                    {
                        let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
                        let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
                        let _ =
                            tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
                    }
                });
                return Response::from_parts(parts, Body::empty());
            }

            Response::from_parts(parts, Body::new(body))
        }
        Err(e) => {
            let msg = format!(
                "Failed to connect to daemon on port {target_port}: {e}\n\
                 The daemon may have stopped or is not yet ready."
            );
            // Unlike the timeout path, this only logs when no callback is set.
            if let Some(ref on_error) = state.on_error {
                on_error(&msg);
            } else {
                log::warn!("{msg}");
            }
            error_response(StatusCode::BAD_GATEWAY, &msg)
        }
    }
}
1307
1308async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1327 let Some(subdomain) = strip_tld(host, tld) else {
1328 return ResolveResult::NotFound;
1329 };
1330
1331 let Some(cached) = cached_slug_lookup(&subdomain).await else {
1332 return ResolveResult::NotFound;
1333 };
1334
1335 let (expected_namespace, worktree_dir) = if subdomain != cached.slug {
1339 let prefix = subdomain
1340 .strip_suffix(&format!(".{}", cached.slug))
1341 .map(|s| s.to_string());
1342 match prefix {
1343 Some(ref p) => match cached.worktrees.iter().find(|w| w.sanitized_branch == *p) {
1344 Some(wt) => {
1345 let ns = wt.namespace.clone().or_else(|| {
1346 log::warn!(
1347 "Worktree '{}' has no cached namespace; \
1348 falling back to parent slug namespace.",
1349 wt.path.display()
1350 );
1351 cached.namespace.clone()
1352 });
1353 (ns, Some(wt.path.clone()))
1354 }
1355 None => (cached.namespace.clone(), None),
1356 },
1357 None => (cached.namespace.clone(), None),
1358 }
1359 } else {
1360 (cached.namespace.clone(), None)
1361 };
1362
1363 let daemon_name = &cached.daemon_name;
1364
1365 let daemons = {
1366 let state_file = SUPERVISOR.state_file.lock().await;
1367 state_file.daemons.clone()
1368 };
1369
1370 let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1371 .iter()
1372 .filter(|(id, d)| {
1373 id.name() == daemon_name
1374 && d.status.is_running()
1375 && match &expected_namespace {
1376 Some(ns) => id.namespace() == ns,
1377 None => true,
1378 }
1379 })
1380 .collect();
1381
1382 match running_matches.as_slice() {
1383 [] => {
1384 try_auto_start(
1385 &cached.slug,
1386 &cached,
1387 worktree_dir.as_deref(),
1388 expected_namespace.as_deref(),
1389 )
1390 .await
1391 }
1392 [(_, d)] => {
1393 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1394 ResolveResult::Ready(port)
1395 } else {
1396 ResolveResult::NotFound
1397 }
1398 }
1399 _ => {
1400 let d = running_matches[0].1;
1401 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1402 ResolveResult::Ready(port)
1403 } else {
1404 ResolveResult::NotFound
1405 }
1406 }
1407 }
1408}
1409
1410struct AutoStartGuard {
1417 daemon_id: DaemonId,
1418}
1419
1420impl Drop for AutoStartGuard {
1421 fn drop(&mut self) {
1422 let daemon_id = self.daemon_id.clone();
1423 tokio::spawn(async move {
1427 AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1428 });
1429 }
1430}
1431
1432async fn try_auto_start(
1443 slug: &str,
1444 cached: &CachedSlugEntry,
1445 worktree_dir: Option<&std::path::Path>,
1446 expected_namespace: Option<&str>,
1447) -> ResolveResult {
1448 let s = settings();
1449 if !s.proxy.auto_start {
1450 return ResolveResult::NotFound;
1451 }
1452
1453 let ns = expected_namespace
1454 .map(|s| s.to_string())
1455 .or_else(|| cached.namespace.clone())
1456 .unwrap_or_else(|| "global".to_string());
1457 let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1458 Ok(id) => id,
1459 Err(_) => return ResolveResult::NotFound,
1460 };
1461
1462 {
1463 let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1464 if !in_progress.insert(daemon_id.clone()) {
1465 return ResolveResult::Starting {
1466 slug: slug.to_string(),
1467 };
1468 }
1469 }
1470
1471 let _guard = AutoStartGuard {
1472 daemon_id: daemon_id.clone(),
1473 };
1474
1475 let timeout = s.proxy_auto_start_timeout();
1476
1477 match tokio::time::timeout(
1478 timeout,
1479 try_auto_start_inner(slug, cached, &daemon_id, worktree_dir),
1480 )
1481 .await
1482 {
1483 Ok(result) => result,
1484 Err(_elapsed) => {
1485 log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1486 ResolveResult::Error(format!(
1487 "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1488 The daemon did not become ready and bind a port within the configured \
1489 proxy_auto_start_timeout.\n\
1490 Increase the timeout or check the daemon's logs for slow startup."
1491 ))
1492 }
1493 }
1494}
1495
1496async fn try_auto_start_inner(
1500 slug: &str,
1501 cached: &CachedSlugEntry,
1502 daemon_id: &DaemonId,
1503 worktree_dir: Option<&std::path::Path>,
1504) -> ResolveResult {
1505 let config_dir = worktree_dir.unwrap_or(&cached.dir);
1506
1507 let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(config_dir) {
1508 Ok(pt) => pt,
1509 Err(e) => {
1510 log::warn!(
1511 "Auto-start: failed to load config from {}: {e}",
1512 config_dir.display()
1513 );
1514 return ResolveResult::NotFound;
1515 }
1516 };
1517
1518 let daemon_config = match pt.daemons.get(daemon_id) {
1519 Some(cfg) => cfg,
1520 None => {
1521 log::debug!(
1522 "Auto-start: daemon {daemon_id} not found in config at {}",
1523 config_dir.display()
1524 );
1525 return ResolveResult::NotFound;
1526 }
1527 };
1528
1529 let opts = crate::ipc::batch::StartOptions {
1530 quiet: true,
1531 ..crate::ipc::batch::StartOptions::default()
1532 };
1533 let mut run_opts = match crate::ipc::batch::build_run_options(daemon_id, daemon_config, &opts) {
1534 Ok(o) => o,
1535 Err(e) => {
1536 log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1537 return ResolveResult::Error(format!("Failed to build run options: {e}"));
1538 }
1539 };
1540
1541 if run_opts.dir.0.as_os_str().is_empty() {
1544 run_opts.dir = crate::config_types::Dir(config_dir.to_path_buf());
1545 }
1546
1547 log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1548
1549 let run_result = SUPERVISOR.run(run_opts).await;
1550
1551 if let Err(e) = run_result {
1552 log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1553 return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1554 }
1555
1556 let poll_interval = std::time::Duration::from_millis(250);
1557
1558 loop {
1559 let daemons = {
1560 let sf = SUPERVISOR.state_file.lock().await;
1561 sf.daemons.clone()
1562 };
1563
1564 if let Some(d) = daemons.get(daemon_id) {
1565 if d.status.is_running() {
1566 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1567 log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1568 return ResolveResult::Ready(port);
1569 }
1570 } else {
1571 log::warn!(
1572 "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1573 d.status
1574 );
1575 return ResolveResult::Error(format!(
1576 "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1577 Check its logs for errors."
1578 ));
1579 }
1580 } else {
1581 log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1582 return ResolveResult::Error(format!(
1583 "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1584 Check its logs for errors."
1585 ));
1586 }
1587
1588 tokio::time::sleep(poll_interval).await;
1589 }
1590}
1591
/// Returns the part of `host` that precedes `.{tld}`.
///
/// Yields `None` when `host` does not end in `.{tld}` or when nothing
/// precedes the dot (for example `".localhost"`).
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    // Removing `tld` and then the separating dot is the same as removing
    // ".{tld}" in one go, without building the suffix string.
    let without_tld = host.strip_suffix(tld)?;
    let subdomain = without_tld.strip_suffix('.')?;
    if subdomain.is_empty() {
        None
    } else {
        Some(subdomain.to_owned())
    }
}
1603
/// Builds the message shown when the proxy server cannot bind `port`.
///
/// The first line reports the port and the underlying I/O error; the second
/// is a hint chosen by port range: ports below 1024 get a privilege hint,
/// all others a hint that the port may already be in use.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    let hint = match port {
        0..=1023 => {
            "ports below 1024 require elevated privileges. \
             Try: sudo pitchfork supervisor start"
        }
        _ => "another process may already be using this port.",
    };
    format!("Failed to bind proxy server to port {port}: {err}\nHint: {hint}")
}
1619
1620fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1625 let escaped_slug = slug
1626 .replace('&', "&")
1627 .replace('<', "<")
1628 .replace('>', ">")
1629 .replace('"', """)
1630 .replace('\'', "'");
1631 let escaped_host = raw_host
1632 .replace('&', "&")
1633 .replace('<', "<")
1634 .replace('>', ">")
1635 .replace('"', """)
1636 .replace('\'', "'");
1637
1638 let html = format!(
1639 r##"<!DOCTYPE html>
1640<html lang="en">
1641<head>
1642 <meta charset="UTF-8">
1643 <meta name="viewport" content="width=device-width, initial-scale=1">
1644 <meta http-equiv="refresh" content="2">
1645 <title>Starting {escaped_slug}… — pitchfork</title>
1646 <style>
1647 * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1648 body {{
1649 font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1650 background: #0f1117;
1651 color: #e1e4e8;
1652 display: flex;
1653 align-items: center;
1654 justify-content: center;
1655 min-height: 100vh;
1656 }}
1657 .container {{
1658 text-align: center;
1659 max-width: 480px;
1660 padding: 2rem;
1661 }}
1662 .spinner {{
1663 width: 48px;
1664 height: 48px;
1665 border: 4px solid rgba(255, 255, 255, 0.1);
1666 border-top-color: #58a6ff;
1667 border-radius: 50%;
1668 animation: spin 0.8s linear infinite;
1669 margin: 0 auto 1.5rem;
1670 }}
1671 @keyframes spin {{
1672 to {{ transform: rotate(360deg); }}
1673 }}
1674 h1 {{
1675 font-size: 1.5rem;
1676 font-weight: 600;
1677 margin-bottom: 0.5rem;
1678 }}
1679 .slug {{
1680 color: #58a6ff;
1681 font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1682 }}
1683 .host {{
1684 color: #8b949e;
1685 font-size: 0.875rem;
1686 margin-top: 0.25rem;
1687 }}
1688 .hint {{
1689 color: #8b949e;
1690 font-size: 0.8rem;
1691 margin-top: 1.5rem;
1692 }}
1693 </style>
1694</head>
1695<body>
1696 <div class="container">
1697 <div class="spinner"></div>
1698 <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1699 <p class="host">{escaped_host}</p>
1700 <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1701 </div>
1702</body>
1703</html>"##
1704 );
1705
1706 Response::builder()
1707 .status(StatusCode::SERVICE_UNAVAILABLE)
1708 .header("content-type", "text/html; charset=utf-8")
1709 .header("retry-after", "2")
1710 .body(Body::from(html))
1711 .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1712}
1713
1714async fn redirect_to_https_handler(req: Request) -> Response {
1723 if req.headers().contains_key("upgrade") {
1725 log::warn!("Dropping plain-HTTP WebSocket upgrade attempt — use wss:// instead of ws://");
1726 return (
1727 StatusCode::BAD_REQUEST,
1728 "WebSocket over plain HTTP is not supported on the HTTPS port. Use wss:// instead.",
1729 )
1730 .into_response();
1731 }
1732
1733 let raw_host = get_request_host(&req);
1734 let Some(raw_host) = raw_host else {
1735 return (StatusCode::BAD_REQUEST, "Missing Host header").into_response();
1736 };
1737
1738 let hostname = if raw_host.starts_with('[') {
1740 raw_host
1742 .split_once("]:")
1743 .map(|(host, _)| host)
1744 .unwrap_or(&raw_host)
1745 .trim_start_matches('[')
1746 .trim_end_matches(']')
1747 } else {
1748 let mut parts = raw_host.rsplitn(2, ':');
1750 let last = parts.next().unwrap_or(&raw_host);
1751 parts.next().unwrap_or(last)
1752 };
1753
1754 let path = req
1755 .uri()
1756 .path_and_query()
1757 .map(|pq| pq.as_str())
1758 .unwrap_or("/");
1759
1760 let https_port = match u16::try_from(settings().proxy.port).ok().filter(|&p| p > 0) {
1761 Some(443) | None => String::new(),
1762 Some(port) => format!(":{port}"),
1763 };
1764
1765 let host_for_url = if raw_host.starts_with('[') {
1766 format!("[{hostname}]")
1767 } else {
1768 hostname.to_string()
1769 };
1770
1771 let location = format!("https://{host_for_url}{https_port}{path}");
1772 (
1773 StatusCode::FOUND,
1774 [(axum::http::header::LOCATION, location)],
1775 )
1776 .into_response()
1777}
1778
1779fn error_response(status: StatusCode, message: &str) -> Response {
1781 (status, message.to_string()).into_response()
1782}
1783
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal cache entry whose slug and daemon name are both `name`.
    fn make_entry(name: &str) -> CachedSlugEntry {
        CachedSlugEntry {
            slug: name.to_string(),
            namespace: None,
            daemon_name: name.to_string(),
            dir: std::path::PathBuf::from(format!("/tmp/{name}")),
            worktrees: vec![],
        }
    }

    /// Slug table containing only the `myapp` entry.
    fn myapp_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        entries
    }

    /// Runs `wildcard_slug_lookup` and returns the matched daemon name, if any.
    fn lookup_daemon(
        query: &str,
        entries: &std::collections::HashMap<String, CachedSlugEntry>,
        wildcard: bool,
    ) -> Option<String> {
        wildcard_slug_lookup(query, entries, wildcard).map(|entry| entry.daemon_name.clone())
    }

    #[test]
    fn test_strip_tld() {
        let cases = [
            ("api.myproject.localhost", "localhost", Some("api.myproject")),
            ("api.localhost", "localhost", Some("api")),
            ("localhost", "localhost", None),
            ("api.myproject.test", "test", Some("api.myproject")),
            ("other.com", "localhost", None),
        ];
        for (host, tld, expected) in cases {
            assert_eq!(
                strip_tld(host, tld),
                expected.map(str::to_string),
                "strip_tld({host:?}, {tld:?})"
            );
        }
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_match() {
        let entries = myapp_entries();
        assert_eq!(
            lookup_daemon("myapp", &entries, true).as_deref(),
            Some("myapp")
        );
    }

    #[test]
    fn test_wildcard_slug_lookup_subdomain_fallback() {
        let entries = myapp_entries();
        assert_eq!(
            lookup_daemon("tenant.myapp", &entries, true).as_deref(),
            Some("myapp")
        );
    }

    #[test]
    fn test_wildcard_slug_lookup_nested_fallback() {
        let entries = myapp_entries();
        assert_eq!(
            lookup_daemon("a.b.myapp", &entries, true).as_deref(),
            Some("myapp")
        );
    }

    #[test]
    fn test_wildcard_slug_lookup_no_match() {
        let entries = std::collections::HashMap::new();
        assert_eq!(lookup_daemon("tenant.myapp", &entries, true), None);
    }

    #[test]
    fn test_wildcard_slug_lookup_disabled() {
        let entries = myapp_entries();
        // With wildcard matching off, only the exact slug resolves.
        assert_eq!(lookup_daemon("tenant.myapp", &entries, false), None);
        assert!(lookup_daemon("myapp", &entries, false).is_some());
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_beats_wildcard() {
        let mut entries = myapp_entries();
        let mut tenant_entry = make_entry("tenant-daemon");
        tenant_entry.slug = "tenant.myapp".to_string();
        entries.insert("tenant.myapp".to_string(), tenant_entry);
        assert_eq!(
            lookup_daemon("tenant.myapp", &entries, true).as_deref(),
            Some("tenant-daemon")
        );
    }

    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let tmp = tempfile::tempdir().unwrap();
        let cert_path = tmp.path().join("ca.pem");
        let key_path = tmp.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();

        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        let looks_like_key = key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY");
        assert!(looks_like_key, "should be PEM key");
    }
}