1use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Response header set to `1` on every response relayed from a backend daemon.
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header carrying how many times a request has already passed through
/// this proxy. It is incremented on every forward (used for loop detection) and
/// stripped from backend responses before they reach the client.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// A request whose hop counter has reached this value is rejected with
/// `508 Loop Detected` instead of being forwarded again.
const MAX_PROXY_HOPS: u64 = 5;

/// Connection-specific headers removed from backend responses on the TLS
/// listener (which also negotiates HTTP/2 via ALPN, where these headers are not
/// allowed). They are kept on `101 Switching Protocols` responses, which need
/// `connection`/`upgrade` to complete the upgrade.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
/// How long a snapshot of the global slug registry is reused before
/// `get_cached_slugs` rebuilds it.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);

/// One registered slug, flattened from the global slug configuration so request
/// routing can reuse it for `SLUG_CACHE_TTL` without rebuilding it.
#[derive(Clone, Debug)]
pub struct CachedSlugEntry {
    /// The slug itself (also the key in the cache map).
    pub slug: String,
    /// Namespace resolved for the slug's registry entry, if any.
    pub namespace: Option<String>,
    /// Daemon to route to; the slug itself when the entry names no daemon.
    pub daemon_name: String,
    /// Resolved directory of the entry; empty when it could not be resolved.
    pub dir: std::path::PathBuf,
    /// Worktrees discovered under `dir` (only when `proxy.worktree` is enabled),
    /// de-duplicated by sanitized branch name.
    pub worktrees: Vec<crate::proxy::worktree::WorktreeEntry>,
}

/// Current slug snapshot together with the instant at which it goes stale.
struct SlugCache {
    /// Shared snapshot handed out to callers without copying.
    entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
    /// After this instant the snapshot must be rebuilt.
    expires_at: std::time::Instant,
}

/// Process-wide slug cache. It starts empty and already expired, so the first
/// lookup triggers a build.
static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
    once_cell::sync::Lazy::new(|| {
        tokio::sync::Mutex::new(SlugCache {
            entries: Arc::new(std::collections::HashMap::new()),
            expires_at: std::time::Instant::now(),
        })
    });
91
92fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
95 let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
96 let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
97 let worktree_enabled = crate::settings::settings().proxy.worktree;
98 for (slug, entry) in &global_slugs {
99 let ns = entry.resolve_namespace();
100 let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
101 let worktrees = if worktree_enabled {
102 let wts = match entry.resolve_dir() {
103 Some(dir) => crate::proxy::worktree::discover_worktrees(&dir),
104 None => vec![],
105 };
106 let mut seen = std::collections::HashMap::with_capacity(wts.len());
109 let mut deduped = Vec::with_capacity(wts.len());
110 for mut wt in wts {
111 let wt_ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&wt.path).ok();
112 wt.namespace = wt_ns;
113 match seen.entry(wt.sanitized_branch.clone()) {
114 std::collections::hash_map::Entry::Occupied(e) => {
115 log::warn!(
116 "Worktree slug collision: '{}' and '{}' both sanitize to '{}'. \
117 Only the first (in discovery order) will be routed.",
118 e.get(),
119 wt.branch,
120 wt.sanitized_branch,
121 );
122 }
123 std::collections::hash_map::Entry::Vacant(e) => {
124 e.insert(wt.branch.clone());
125 deduped.push(wt);
126 }
127 }
128 }
129 deduped
130 } else {
131 vec![]
132 };
133 entries.insert(
134 slug.clone(),
135 CachedSlugEntry {
136 slug: slug.clone(),
137 namespace: ns,
138 daemon_name,
139 dir: entry.resolve_dir().unwrap_or_default(),
140 worktrees,
141 },
142 );
143 }
144 entries
145}
146
147pub async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
153 {
155 let cache = SLUG_CACHE.lock().await;
156 if std::time::Instant::now() < cache.expires_at {
157 return Arc::clone(&cache.entries);
158 }
159 } let new_entries = Arc::new(
163 tokio::task::spawn_blocking(build_slug_entries)
164 .await
165 .unwrap_or_else(|e| {
166 log::warn!("Failed to refresh slug cache: {e}");
167 std::collections::HashMap::new()
168 }),
169 );
170
171 {
173 let mut cache = SLUG_CACHE.lock().await;
174 cache.entries = Arc::clone(&new_entries);
175 cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
176 }
177
178 new_entries
179}
180
181fn wildcard_slug_lookup<'a>(
187 subdomain: &str,
188 entries: &'a std::collections::HashMap<String, CachedSlugEntry>,
189 wildcard: bool,
190) -> Option<&'a CachedSlugEntry> {
191 entries.get(subdomain).or_else(|| {
192 if !wildcard {
193 return None;
194 }
195 subdomain
197 .match_indices('.')
198 .map(|(i, _)| &subdomain[i + 1..])
199 .find_map(|candidate| entries.get(candidate))
200 })
201}
202
203async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
210 let entries = get_cached_slugs().await;
211 wildcard_slug_lookup(subdomain, &entries, settings().proxy.wildcard).cloned()
212}
213
/// Daemons whose proxy-triggered auto-start is currently in flight. A request
/// for a daemon already in this set is answered with the "starting" page
/// instead of triggering a second start.
static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
    tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
223
/// Outcome of mapping a request host to a backend daemon.
enum ResolveResult {
    /// Forward the request to this local port.
    Ready(u16),
    /// The daemon for `slug` is being auto-started; the client gets the
    /// "starting" page.
    Starting { slug: String },
    /// No routable daemon for this host (answered with 502).
    NotFound,
    /// Resolution failed; the message is returned to the client with a 502.
    Error(String),
}
236
/// Hook invoked with a human-readable message when forwarding to a backend
/// fails or times out.
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;

/// Per-server state cloned into every `proxy_handler` invocation.
#[derive(Clone)]
struct ProxyState {
    /// Pooled HTTP client used to reach backend daemons on localhost.
    client: Arc<Client<HttpConnector, Body>>,
    /// Top-level domain served by the proxy (hosts look like `<slug>.<tld>`).
    tld: String,
    /// Whether the proxy runs in HTTPS mode; selects the forwarded proto/port
    /// defaults and whether hop-by-hop response headers are stripped.
    is_tls: bool,
    /// Optional error hook; `serve` always leaves this as `None`.
    on_error: Option<OnErrorFn>,
}
252
253pub async fn serve(
261 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
262 cancel: tokio_util::sync::CancellationToken,
263) -> crate::Result<()> {
264 let s = settings();
265 let lan_enabled = s.proxy.lan || !s.proxy.lan_ip.is_empty();
266
267 let effective_tld = if lan_enabled {
268 "local".to_string()
269 } else {
270 s.proxy.tld.clone()
271 };
272
273 let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
274 let msg = format!(
275 "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
276 s.proxy.port
277 );
278 let _ = bind_tx.send(Err(msg.clone()));
279 miette::bail!("{msg}");
280 };
281
282 let mut connector = HttpConnector::new();
283 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
287
288 let client = Client::builder(TokioExecutor::new())
289 .pool_idle_timeout(std::time::Duration::from_secs(30))
292 .build(connector);
293
294 let state = ProxyState {
295 client: Arc::new(client),
296 tld: effective_tld.clone(),
297 is_tls: s.proxy.https,
298 on_error: None,
299 };
300
301 let app = Router::new().fallback(proxy_handler).with_state(state);
302
303 let bind_ip: std::net::IpAddr = if lan_enabled && s.proxy.host == "127.0.0.1" {
307 std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)
308 } else {
309 match s.proxy.host.parse() {
310 Ok(ip) => ip,
311 Err(_) => {
312 log::warn!(
313 "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
314 The proxy will only be reachable on the loopback interface.",
315 s.proxy.host
316 );
317 std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
318 }
319 }
320 };
321 let addr = SocketAddr::from((bind_ip, effective_port));
322
323 if s.proxy.https {
324 serve_https_with_http_fallback(app, addr, &s, effective_port, bind_tx, cancel).await
325 } else {
326 serve_http(app, addr, effective_port, bind_tx, cancel).await
327 }
328}
329
330async fn serve_http(
332 app: Router,
333 addr: SocketAddr,
334 effective_port: u16,
335 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
336 cancel: tokio_util::sync::CancellationToken,
337) -> crate::Result<()> {
338 let listener = match TcpListener::bind(addr).await {
339 Ok(l) => {
340 if settings().proxy.sync_hosts {
341 crate::proxy::hosts::sync_hosts_from_settings();
342 }
343 let _ = bind_tx.send(Ok(()));
344 l
345 }
346 Err(e) => {
347 let msg = bind_error_message(effective_port, &e);
348 let _ = bind_tx.send(Err(msg.clone()));
349 return Err(miette::miette!("{msg}"));
350 }
351 };
352
353 log::info!("Proxy server listening on http://{addr}");
354 if effective_port < 1024 {
355 log::info!(
356 "Note: port {effective_port} is a privileged port. \
357 The supervisor must be started with sudo to bind to this port."
358 );
359 }
360 let shutdown_signal = cancel.clone().cancelled_owned();
361 axum::serve(
362 listener,
363 app.into_make_service_with_connect_info::<SocketAddr>(),
364 )
365 .with_graceful_shutdown(shutdown_signal)
366 .await
367 .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
368 Ok(())
369}
370
/// Serves the proxy over TLS on `addr`, with plain-HTTP clients on the same
/// port redirected by `redirect_to_https_handler`.
///
/// A local CA is generated on first use. When that happens, previously issued
/// leaf certificates are discarded. Leaf certificates are then issued per SNI
/// name by `SniCertResolver`. Each connection is classified by peeking its first
/// byte: `0x16` (a TLS handshake record) goes to the TLS acceptor, anything else
/// to the redirect app. On `cancel`, the accept loop stops and in-flight
/// connections get up to 10 s to finish.
///
/// # Errors
/// Fails if the CA cannot be created or loaded, or the listener cannot be bound
/// (the bind outcome is also sent on `bind_tx`).
#[cfg(feature = "proxy-tls")]
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
    cancel: tokio_util::sync::CancellationToken,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);

    if !ca_cert_path.exists() || !ca_key_path.exists() {
        generate_ca(&ca_cert_path, &ca_key_path)?;

        // Leaf certificates persisted by `SniCertResolver` were signed by the
        // previous CA, but it only checks their expiry before reusing them.
        // Remove them so they are re-issued under the new CA; otherwise clients
        // trusting the new CA would reject them.
        let host_certs_dir = ca_cert_path
            .parent()
            .unwrap_or(std::path::Path::new("."))
            .join("host-certs");
        match std::fs::remove_dir_all(&host_certs_dir) {
            Ok(()) => log::info!(
                "Removed leaf certificates issued by the previous CA from {}",
                host_certs_dir.display()
            ),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => log::warn!(
                "Failed to remove stale leaf certificates in {}: {e}",
                host_certs_dir.display()
            ),
        }

        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }

    // Installing fails harmlessly if a process-wide provider is already set.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;

    let mut tls_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(resolver));
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            if settings().proxy.sync_hosts {
                crate::proxy::hosts::sync_hosts_from_settings();
            }
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    let redirect_app = Router::new().fallback(redirect_to_https_handler);

    let mut conn_tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    loop {
        // Reap finished connection tasks so the set does not grow unboundedly.
        while conn_tasks.try_join_next().is_some() {}

        tokio::select! {
            accept_result = listener.accept() => {
                let (stream, _peer_addr) = match accept_result {
                    Ok(conn) => conn,
                    Err(e) => {
                        log::warn!("Accept error (will retry): {e}");
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        continue;
                    }
                };

                let acceptor = acceptor.clone();
                let app = app.clone();
                let redirect_app = redirect_app.clone();

                conn_tasks.spawn(async move {
                    // Peek (without consuming) the first byte to tell TLS from plain HTTP.
                    let mut peek_buf = [0u8; 1];
                    match stream.peek(&mut peek_buf).await {
                        Ok(0) | Err(_) => return,
                        _ => {}
                    }

                    if peek_buf[0] == 0x16 {
                        match acceptor.accept(stream).await {
                            Ok(tls_stream) => {
                                let io = hyper_util::rt::TokioIo::new(tls_stream);
                                let svc = hyper_util::service::TowerToHyperService::new(app);
                                if let Err(e) =
                                    hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                                        .serve_connection_with_upgrades(io, svc)
                                        .await
                                {
                                    log::debug!("Connection error: {e}");
                                }
                            }
                            Err(e) => {
                                log::debug!("TLS handshake error: {e}");
                            }
                        }
                    } else {
                        let io = hyper_util::rt::TokioIo::new(stream);
                        let svc = hyper_util::service::TowerToHyperService::new(redirect_app);
                        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await;
                    }
                });

                while conn_tasks.try_join_next().is_some() {}
            }
            _ = cancel.cancelled() => {
                log::info!("Proxy server shutting down (cancel signal received)");
                break;
            }
        }
    }

    // Give in-flight connections a bounded time to finish; any still running
    // are aborted when `conn_tasks` is dropped.
    let drain_timeout = std::time::Duration::from_secs(10);
    let _ = tokio::time::timeout(drain_timeout, async {
        while conn_tasks.join_next().await.is_some() {}
    })
    .await;

    Ok(())
}
519
520#[cfg(not(feature = "proxy-tls"))]
522async fn serve_https_with_http_fallback(
523 _app: Router,
524 _addr: SocketAddr,
525 _s: &crate::settings::Settings,
526 _effective_port: u16,
527 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
528 _cancel: tokio_util::sync::CancellationToken,
529) -> crate::Result<()> {
530 let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
531 Rebuild pitchfork with: cargo build --features proxy-tls"
532 .to_string();
533 let _ = bind_tx.send(Err(msg.clone()));
534 miette::bail!("{msg}")
535}
536
/// Returns `(ca_cert_path, ca_key_path)` for the proxy's local CA.
///
/// Each path is taken from `proxy.tls_cert` / `proxy.tls_key` when non-empty.
/// Otherwise it defaults to `ca.pem` / `ca-key.pem` inside the `proxy`
/// subdirectory of the pitchfork state directory.
#[cfg(feature = "proxy-tls")]
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let state_proxy_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let choose = |configured: &str, file_name: &str| -> std::path::PathBuf {
        match configured {
            "" => state_proxy_dir.join(file_name),
            custom => std::path::PathBuf::from(custom),
        }
    };
    let cert_path = choose(&s.proxy.tls_cert, "ca.pem");
    let key_path = choose(&s.proxy.tls_key, "ca-key.pem");
    (cert_path, key_path)
}
556
/// Generates a self-signed local CA ("Pitchfork Local CA") and writes its
/// certificate PEM to `cert_path` and private key PEM to `key_path`.
///
/// The parent directories of both paths are created if missing, since the key
/// may be configured to live apart from the certificate. On Unix the key file is
/// restricted to mode `0600`, even when it already existed. On other platforms
/// its permissions are left as-is and a debug message says so.
///
/// # Errors
/// Fails if a directory cannot be created, key generation or self-signing
/// fails, or either file cannot be written.
#[cfg(feature = "proxy-tls")]
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    for parent in [cert_path.parent(), key_path.parent()].into_iter().flatten() {
        std::fs::create_dir_all(parent)
            .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
        let key_pem = key_pair.serialize_pem();
        std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(key_path)
            .and_then(|mut f| {
                // `mode` only takes effect when the file is newly created; an
                // existing key file would otherwise keep its old permissions.
                // Tighten them before the new key is written.
                f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                f.write_all(key_pem.as_bytes())
            })
            .map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
            miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
        })?;
        log::debug!(
            "CA private key written to {} (file permissions are not restricted \
             on non-Unix platforms — consider restricting access manually)",
            key_path.display()
        );
    }

    Ok(())
}
629
/// rustls certificate resolver that issues a leaf certificate per SNI server
/// name, signed by the local CA.
///
/// Issued certificates are cached in memory and persisted as PEM files under
/// `host_certs_dir`. Concurrent handshakes for the same name wait for a single
/// issuance (see `pending` / `pending_cv`).
#[cfg(feature = "proxy-tls")]
struct SniCertResolver {
    /// Local CA used to sign leaf certificates.
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory holding persisted leaf cert + key PEM files.
    host_certs_dir: std::path::PathBuf,
    /// In-memory certificates keyed by server name.
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Server names whose certificate is currently being loaded or issued.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Notified whenever a name is removed from `pending`; waiters then
    /// re-check the cache.
    pending_cv: std::sync::Condvar,
}

#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    // Prints no fields, so neither the CA key nor cached certificates appear
    // in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SniCertResolver").finish_non_exhaustive()
    }
}
671
#[cfg(feature = "proxy-tls")]
impl SniCertResolver {
    /// Loads the local CA from `ca_cert_path` / `ca_key_path` and creates the
    /// `host-certs` directory next to the CA certificate.
    ///
    /// # Errors
    /// Fails if either file cannot be read or parsed, or the directory cannot
    /// be created.
    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
        let ca_key_pem = std::fs::read_to_string(ca_key_path)
            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
        })?;

        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
            miette::bail!("CA cert file does not contain a valid PEM certificate");
        }

        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;

        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;

        let host_certs_dir = ca_cert_path
            .parent()
            .unwrap_or(std::path::Path::new("."))
            .join("host-certs");
        std::fs::create_dir_all(&host_certs_dir)
            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;

        Ok(Self {
            issuer,
            host_certs_dir,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
            pending_cv: std::sync::Condvar::new(),
        })
    }

    /// Returns the certificate for `domain`, loading or issuing it if needed.
    ///
    /// At most one thread loads/issues a given name at a time; other threads
    /// handshaking for the same name block on `pending_cv` and then take the
    /// cached result. The claimed slot is released by a drop guard, so a panic
    /// during issuance cannot leave later handshakes waiting forever. Returns
    /// `None` (failing the handshake) if issuance fails or the pending-set lock
    /// is poisoned.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        /// Removes `domain` from `pending` and wakes all waiters when dropped.
        struct PendingSlot<'a> {
            resolver: &'a SniCertResolver,
            domain: &'a str,
        }

        impl Drop for PendingSlot<'_> {
            fn drop(&mut self) {
                let mut pending = self
                    .resolver
                    .pending
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                pending.remove(self.domain);
                self.resolver.pending_cv.notify_all();
            }
        }

        if let Some(ck) = self.cached(domain) {
            return Some(ck);
        }

        loop {
            {
                let mut pending = self.pending.lock().ok()?;
                if !pending.contains(domain) {
                    pending.insert(domain.to_string());
                    break;
                }
                // Another thread is handling this name; wait for it (spurious
                // wakeups simply loop around and re-check).
                drop(self.pending_cv.wait(pending).ok()?);
            }
            if let Some(ck) = self.cached(domain) {
                return Some(ck);
            }
        }
        let _slot = PendingSlot {
            resolver: self,
            domain,
        };

        // A previous holder may have finished between the first cache check
        // and claiming the slot.
        if let Some(ck) = self.cached(domain) {
            return Some(ck);
        }
        self.get_or_create_inner(domain)
    }

    /// Returns the in-memory certificate for `domain`, if any. A poisoned lock
    /// is recovered; the map itself is never left half-updated.
    fn cached(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(domain)
            .cloned()
    }

    /// Stores `ck` in the in-memory cache under `domain` and returns a shared
    /// handle to it.
    fn remember(
        &self,
        domain: &str,
        ck: rustls::sign::CertifiedKey,
    ) -> Arc<rustls::sign::CertifiedKey> {
        let ck = Arc::new(ck);
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(domain.to_string(), Arc::clone(&ck));
        ck
    }

    /// Location of the persisted cert+key PEM for `domain` inside
    /// `host_certs_dir` (dots become `_`, `*` becomes `wildcard`).
    fn disk_path(&self, domain: &str) -> std::path::PathBuf {
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        self.host_certs_dir.join(format!("{safe_name}.pem"))
    }

    /// Loads `domain`'s certificate from disk when a readable, unexpired copy
    /// exists; otherwise issues (and persists) a new one. The result is cached
    /// in memory.
    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let disk_path = self.disk_path(domain);

        if disk_path.exists() {
            match self.load_from_disk(&disk_path) {
                Ok(ck) => return Some(self.remember(domain, ck)),
                Err(e) => {
                    // Unreadable or expired: discard it and issue a fresh one.
                    log::debug!("Discarding cached certificate for '{domain}': {e}");
                    let _ = std::fs::remove_file(&disk_path);
                }
            }
        }

        match self.sign_for_domain(domain) {
            Ok(ck) => Some(self.remember(domain, ck)),
            Err(e) => {
                // Without this log the only trace is a generic handshake error.
                log::warn!("Failed to issue TLS certificate for '{domain}': {e}");
                None
            }
        }
    }

    /// Reads a persisted cert+key PEM, rejecting it if its first certificate
    /// has expired.
    ///
    /// # Errors
    /// Fails if the file cannot be read, contains no certificate or key, cannot
    /// be parsed, or the certificate is past its `notAfter`.
    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::{certs, private_key};

        let pem = std::fs::read_to_string(path)
            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;

        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;

        if cert_ders.is_empty() {
            miette::bail!("No certificates found in {}", path.display());
        }

        {
            use chrono::Utc;
            let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
            })?;
            let now_ts = Utc::now().timestamp();
            let not_after_ts = cert.validity().not_after.timestamp();
            if not_after_ts < now_ts {
                miette::bail!(
                    "Cached certificate at {} has expired — will regenerate",
                    path.display()
                );
            }
        }

        let key_der = private_key(&mut pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;

        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
    }

    /// Issues a new leaf certificate for `domain`, signed by the local CA, and
    /// persists it (best effort) to `disk_path(domain)`.
    ///
    /// The certificate is valid from yesterday for 397 days. Its SANs are
    /// `domain` plus, when `domain` has at least two dots, a wildcard for its
    /// parent (e.g. `a.b.test` also gets `*.b.test`).
    ///
    /// # Errors
    /// Fails if `domain` is not a valid DNS name or key generation/signing
    /// fails; a failure to persist is only logged.
    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
        use rcgen::date_time_ymd;
        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::private_key;

        let mut params = CertificateParams::default();
        let mut dn = DistinguishedName::new();
        dn.push(DnType::CommonName, domain);
        params.distinguished_name = dn;

        {
            use chrono::{Datelike, Duration, Utc};
            // Back-date by a day to tolerate clock skew.
            let yesterday = Utc::now() - Duration::days(1);
            let expiry = Utc::now() + Duration::days(397);
            params.not_before = date_time_ymd(
                yesterday.year(),
                yesterday.month() as u8,
                yesterday.day() as u8,
            );
            params.not_after =
                date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);
        }

        let mut sans = vec![SanType::DnsName(domain.to_string().try_into().map_err(
            |e| miette::miette!("Invalid domain name '{domain}': {e}"),
        )?)];
        if let Some(dot_pos) = domain.find('.') {
            let parent = &domain[dot_pos + 1..];
            // Avoid a wildcard directly under the TLD (e.g. `*.localhost`).
            if parent.contains('.') {
                let wildcard = format!("*.{parent}");
                if let Ok(wc) = wildcard.try_into() {
                    sans.push(SanType::DnsName(wc));
                }
            }
        }
        params.subject_alt_names = sans;

        let leaf_key = rcgen::KeyPair::generate()
            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
        let leaf_cert = params
            .signed_by(&leaf_key, &self.issuer)
            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;

        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
        let key_pem = leaf_key.serialize_pem();
        let key_der = private_key(&mut key_pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;

        let disk_path = self.disk_path(domain);
        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::OpenOptionsExt;
            if let Err(e) = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&disk_path)
                .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
            {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            }
        }
        #[cfg(not(unix))]
        {
            if let Err(e) = std::fs::write(&disk_path, combined_pem) {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            } else {
                log::debug!(
                    "Leaf cert for '{domain}' written to {} (file permissions are not \
                     restricted on non-Unix platforms — consider restricting access manually)",
                    disk_path.display()
                );
            }
        }

        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
    }
}
982
/// Selects the leaf certificate for the SNI name in the ClientHello, issuing
/// one on demand. A ClientHello without SNI gets `None`, so rustls has no
/// certificate to present.
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        match client_hello.server_name() {
            Some(server_name) => self.get_or_create(server_name),
            None => None,
        }
    }
}
993
994fn get_request_host(req: &Request) -> Option<String> {
1000 let authority = req
1002 .uri()
1003 .authority()
1004 .map(|a| a.as_str().to_string())
1005 .filter(|s| !s.is_empty());
1006
1007 authority.or_else(|| {
1008 req.headers()
1009 .get(HOST)
1010 .and_then(|h| h.to_str().ok())
1011 .map(str::to_string)
1012 })
1013}
1014
1015fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
1026 let remote_addr = req
1027 .extensions()
1028 .get::<axum::extract::ConnectInfo<SocketAddr>>()
1029 .map(|ci| ci.0.ip().to_string())
1030 .unwrap_or_else(|| "127.0.0.1".to_string());
1031
1032 let proto = if is_tls { "https" } else { "http" };
1033 let default_port = if is_tls { "443" } else { "80" };
1034
1035 let forwarded_for = remote_addr.clone();
1038 let forwarded_proto = proto.to_string();
1039 let forwarded_host = host_header.to_string();
1040 let forwarded_port = host_header
1041 .rsplit_once(':')
1042 .map(|(_, port)| port.to_string())
1043 .unwrap_or_else(|| default_port.to_string());
1044
1045 for name in [
1051 "x-forwarded-for",
1052 "x-forwarded-proto",
1053 "x-forwarded-host",
1054 "x-forwarded-port",
1055 "forwarded",
1056 ] {
1057 if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
1058 req.headers_mut().remove(&header_name);
1059 }
1060 }
1061
1062 let headers = [
1063 ("x-forwarded-for", forwarded_for),
1064 ("x-forwarded-proto", forwarded_proto),
1065 ("x-forwarded-host", forwarded_host),
1066 ("x-forwarded-port", forwarded_port),
1067 ];
1068
1069 for (name, value) in headers {
1070 if let Ok(v) = HeaderValue::from_str(&value) {
1071 let header_name = axum::http::HeaderName::from_static(name);
1072 req.headers_mut().insert(header_name, v);
1073 }
1074 }
1075}
1076
1077async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
1082 let Some(raw_host) = get_request_host(&req) else {
1084 return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
1085 };
1086 let host = if raw_host.starts_with('[') {
1090 raw_host
1092 .split("]:")
1093 .next()
1094 .unwrap_or(&raw_host)
1095 .trim_start_matches('[')
1096 .trim_end_matches(']')
1097 .to_string()
1098 } else {
1099 raw_host.split(':').next().unwrap_or(&raw_host).to_string()
1101 };
1102
1103 let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
1114 let hops: u64 = if is_from_pitchfork {
1115 req.headers()
1116 .get(PROXY_HOPS_HEADER)
1117 .and_then(|v| v.to_str().ok())
1118 .and_then(|s| s.parse().ok())
1119 .unwrap_or(0)
1120 } else {
1121 0
1123 };
1124 if hops >= MAX_PROXY_HOPS {
1125 return error_response(
1126 StatusCode::LOOP_DETECTED,
1127 &format!(
1128 "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
1129 This usually means a backend is proxying back through pitchfork without rewriting \n\
1130 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
1131 ),
1132 );
1133 }
1134
1135 let target_port = if let Some(subdomain) = strip_tld(&host, &state.tld) {
1137 if subdomain == "pitchfork" {
1138 crate::web::port()
1139 } else {
1140 None
1141 }
1142 } else {
1143 None
1144 };
1145
1146 let target_port = if let Some(port) = target_port {
1147 port
1148 } else {
1149 match resolve_target(&host, &state.tld).await {
1150 ResolveResult::Ready(port) => port,
1151 ResolveResult::Starting { slug } => {
1152 return starting_html_response(&slug, &raw_host);
1153 }
1154 ResolveResult::NotFound => {
1155 return error_response(
1156 StatusCode::BAD_GATEWAY,
1157 &format!(
1158 "No daemon found for host '{host}'.\n\
1159 Make sure the daemon has a slug, is running, and has a port configured.\n\
1160 Expected format: <slug>.{tld}",
1161 tld = state.tld
1162 ),
1163 );
1164 }
1165 ResolveResult::Error(msg) => {
1166 return error_response(StatusCode::BAD_GATEWAY, &msg);
1167 }
1168 }
1169 };
1170 let path_and_query = req
1172 .uri()
1173 .path_and_query()
1174 .map(|pq| pq.as_str())
1175 .unwrap_or("/");
1176
1177 let forward_uri = match Uri::builder()
1178 .scheme("http")
1179 .authority(format!("localhost:{target_port}"))
1180 .path_and_query(path_and_query)
1181 .build()
1182 {
1183 Ok(uri) => uri,
1184 Err(e) => {
1185 return error_response(
1186 StatusCode::INTERNAL_SERVER_ERROR,
1187 &format!("Failed to build forward URI: {e}"),
1188 );
1189 }
1190 };
1191
1192 *req.uri_mut() = forward_uri;
1194 req.headers_mut().insert(
1195 HOST,
1196 HeaderValue::from_str(&format!("localhost:{target_port}"))
1197 .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
1198 );
1199
1200 inject_forwarded_headers(&mut req, state.is_tls, &raw_host);
1202
1203 if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
1205 req.headers_mut()
1206 .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
1207 }
1208
1209 let pseudo_headers: Vec<_> = req
1214 .headers()
1215 .keys()
1216 .filter(|k| k.as_str().starts_with(':'))
1217 .cloned()
1218 .collect();
1219 for key in pseudo_headers {
1220 req.headers_mut().remove(&key);
1221 }
1222
1223 let client_upgrade = hyper::upgrade::on(&mut req);
1225
1226 let result = match tokio::time::timeout(
1234 std::time::Duration::from_secs(120),
1235 state.client.request(req),
1236 )
1237 .await
1238 {
1239 Ok(r) => r,
1240 Err(_elapsed) => {
1241 let msg = format!(
1242 "Request to daemon on port {target_port} timed out after 120 s.\n\
1243 The daemon accepted the connection but did not respond in time."
1244 );
1245 log::warn!("{msg}");
1246 if let Some(ref on_error) = state.on_error {
1247 on_error(&msg);
1248 }
1249 return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
1250 }
1251 };
1252 match result {
1253 Ok(mut resp) => {
1254 let backend_upgrade = hyper::upgrade::on(&mut resp);
1256 let (mut parts, body) = resp.into_parts();
1257
1258 parts.headers.insert(
1260 axum::http::HeaderName::from_static(PITCHFORK_HEADER),
1261 HeaderValue::from_static("1"),
1262 );
1263
1264 parts.headers.remove(PROXY_HOPS_HEADER);
1266
1267 if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
1272 for h in HOP_BY_HOP_HEADERS {
1273 if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
1274 parts.headers.remove(&name);
1275 }
1276 }
1277 }
1278
1279 if parts.status == StatusCode::SWITCHING_PROTOCOLS {
1281 tokio::spawn(async move {
1286 if let (Ok(client_upgraded), Ok(backend_upgraded)) =
1287 (client_upgrade.await, backend_upgrade.await)
1288 {
1289 let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
1290 let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
1291 let _ =
1299 tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
1300 }
1301 });
1302 return Response::from_parts(parts, Body::empty());
1303 }
1304
1305 Response::from_parts(parts, Body::new(body))
1308 }
1309 Err(e) => {
1310 let msg = format!(
1311 "Failed to connect to daemon on port {target_port}: {e}\n\
1312 The daemon may have stopped or is not yet ready."
1313 );
1314 if let Some(ref on_error) = state.on_error {
1315 on_error(&msg);
1316 } else {
1317 log::warn!("{msg}");
1318 }
1319 error_response(StatusCode::BAD_GATEWAY, &msg)
1320 }
1321 }
1322}
1323
1324async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1343 let Some(subdomain) = strip_tld(host, tld) else {
1344 return ResolveResult::NotFound;
1345 };
1346
1347 let Some(cached) = cached_slug_lookup(&subdomain).await else {
1348 return ResolveResult::NotFound;
1349 };
1350
1351 let (expected_namespace, worktree_dir) = if subdomain != cached.slug {
1355 let prefix = subdomain
1356 .strip_suffix(&format!(".{}", cached.slug))
1357 .map(|s| s.to_string());
1358 match prefix {
1359 Some(ref p) => match cached.worktrees.iter().find(|w| w.sanitized_branch == *p) {
1360 Some(wt) => {
1361 let ns = wt.namespace.clone().or_else(|| {
1362 log::warn!(
1363 "Worktree '{}' has no cached namespace; \
1364 falling back to parent slug namespace.",
1365 wt.path.display()
1366 );
1367 cached.namespace.clone()
1368 });
1369 (ns, Some(wt.path.clone()))
1370 }
1371 None => (cached.namespace.clone(), None),
1372 },
1373 None => (cached.namespace.clone(), None),
1374 }
1375 } else {
1376 (cached.namespace.clone(), None)
1377 };
1378
1379 let daemon_name = &cached.daemon_name;
1380
1381 let daemons = {
1382 let state_file = SUPERVISOR.state_file.lock().await;
1383 state_file.daemons.clone()
1384 };
1385
1386 let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1387 .iter()
1388 .filter(|(id, d)| {
1389 id.name() == daemon_name
1390 && d.status.is_running()
1391 && match &expected_namespace {
1392 Some(ns) => id.namespace() == ns,
1393 None => true,
1394 }
1395 })
1396 .collect();
1397
1398 match running_matches.as_slice() {
1399 [] => {
1400 try_auto_start(
1401 &cached.slug,
1402 &cached,
1403 worktree_dir.as_deref(),
1404 expected_namespace.as_deref(),
1405 )
1406 .await
1407 }
1408 [(_, d)] => {
1409 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1410 ResolveResult::Ready(port)
1411 } else {
1412 ResolveResult::NotFound
1413 }
1414 }
1415 _ => {
1416 let d = running_matches[0].1;
1417 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1418 ResolveResult::Ready(port)
1419 } else {
1420 ResolveResult::NotFound
1421 }
1422 }
1423 }
1424}
1425
/// Scope guard for an in-flight auto-start.
///
/// Creating one marks nothing by itself; `try_auto_start` inserts the id into
/// `AUTO_START_IN_PROGRESS` first and then creates this guard. When the guard
/// is dropped (normal return, error, timeout, or cancellation of the enclosing
/// future), its `Drop` impl removes `daemon_id` from that set again.
struct AutoStartGuard {
    /// Daemon whose entry in `AUTO_START_IN_PROGRESS` is released on drop.
    daemon_id: DaemonId,
}
1435
1436impl Drop for AutoStartGuard {
1437 fn drop(&mut self) {
1438 let daemon_id = self.daemon_id.clone();
1439 tokio::spawn(async move {
1443 AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1444 });
1445 }
1446}
1447
1448async fn try_auto_start(
1459 slug: &str,
1460 cached: &CachedSlugEntry,
1461 worktree_dir: Option<&std::path::Path>,
1462 expected_namespace: Option<&str>,
1463) -> ResolveResult {
1464 let s = settings();
1465 if !s.proxy.auto_start {
1466 return ResolveResult::NotFound;
1467 }
1468
1469 let ns = expected_namespace
1470 .map(|s| s.to_string())
1471 .or_else(|| cached.namespace.clone())
1472 .unwrap_or_else(|| "global".to_string());
1473 let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1474 Ok(id) => id,
1475 Err(_) => return ResolveResult::NotFound,
1476 };
1477
1478 {
1479 let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1480 if !in_progress.insert(daemon_id.clone()) {
1481 return ResolveResult::Starting {
1482 slug: slug.to_string(),
1483 };
1484 }
1485 }
1486
1487 let _guard = AutoStartGuard {
1488 daemon_id: daemon_id.clone(),
1489 };
1490
1491 let timeout = s.proxy_auto_start_timeout();
1492
1493 match tokio::time::timeout(
1494 timeout,
1495 try_auto_start_inner(slug, cached, &daemon_id, worktree_dir),
1496 )
1497 .await
1498 {
1499 Ok(result) => result,
1500 Err(_elapsed) => {
1501 log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1502 ResolveResult::Error(format!(
1503 "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1504 The daemon did not become ready and bind a port within the configured \
1505 proxy_auto_start_timeout.\n\
1506 Increase the timeout or check the daemon's logs for slow startup."
1507 ))
1508 }
1509 }
1510}
1511
1512async fn try_auto_start_inner(
1516 slug: &str,
1517 cached: &CachedSlugEntry,
1518 daemon_id: &DaemonId,
1519 worktree_dir: Option<&std::path::Path>,
1520) -> ResolveResult {
1521 let config_dir = worktree_dir.unwrap_or(&cached.dir);
1522
1523 let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(config_dir) {
1524 Ok(pt) => pt,
1525 Err(e) => {
1526 log::warn!(
1527 "Auto-start: failed to load config from {}: {e}",
1528 config_dir.display()
1529 );
1530 return ResolveResult::NotFound;
1531 }
1532 };
1533
1534 let daemon_config = match pt.daemons.get(daemon_id) {
1535 Some(cfg) => cfg,
1536 None => {
1537 log::debug!(
1538 "Auto-start: daemon {daemon_id} not found in config at {}",
1539 config_dir.display()
1540 );
1541 return ResolveResult::NotFound;
1542 }
1543 };
1544
1545 let opts = crate::ipc::batch::StartOptions {
1546 quiet: true,
1547 ..crate::ipc::batch::StartOptions::default()
1548 };
1549 let mut run_opts =
1550 match crate::ipc::batch::build_run_options(daemon_id, daemon_config, Some(&opts)) {
1551 Ok(o) => o,
1552 Err(e) => {
1553 log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1554 return ResolveResult::Error(format!("Failed to build run options: {e}"));
1555 }
1556 };
1557
1558 if run_opts.dir.0.as_os_str().is_empty() {
1561 run_opts.dir = crate::config_types::Dir(config_dir.to_path_buf());
1562 }
1563
1564 log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1565
1566 let run_result = SUPERVISOR.run(run_opts).await;
1567
1568 if let Err(e) = run_result {
1569 log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1570 return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1571 }
1572
1573 let poll_interval = std::time::Duration::from_millis(250);
1574
1575 loop {
1576 let daemons = {
1577 let sf = SUPERVISOR.state_file.lock().await;
1578 sf.daemons.clone()
1579 };
1580
1581 if let Some(d) = daemons.get(daemon_id) {
1582 if d.status.is_running() {
1583 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1584 log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1585 return ResolveResult::Ready(port);
1586 }
1587 } else {
1588 log::warn!(
1589 "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1590 d.status
1591 );
1592 return ResolveResult::Error(format!(
1593 "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1594 Check its logs for errors."
1595 ));
1596 }
1597 } else {
1598 log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1599 return ResolveResult::Error(format!(
1600 "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1601 Check its logs for errors."
1602 ));
1603 }
1604
1605 tokio::time::sleep(poll_interval).await;
1606 }
1607}
1608
/// Returns the part of `host` before `.<tld>`, or `None` when `host` does not
/// end in `.<tld>` or nothing precedes it (e.g. `"localhost"` or `".localhost"`).
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    let suffix = format!(".{tld}");
    match host.strip_suffix(suffix.as_str()) {
        Some(label) if !label.is_empty() => Some(label.to_owned()),
        _ => None,
    }
}
1620
/// Formats a user-facing message for a failed proxy-port bind, with a hint
/// chosen by port range: privileged ports (< 1024) suggest running with sudo,
/// other ports suggest a conflicting process.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    let hint = if port >= 1024 {
        "Hint: another process may already be using this port."
    } else {
        "Hint: ports below 1024 require elevated privileges. \
         Try: sudo pitchfork supervisor start"
    };
    format!("Failed to bind proxy server to port {port}: {err}\n{hint}")
}
1636
1637fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1642 let escaped_slug = slug
1643 .replace('&', "&")
1644 .replace('<', "<")
1645 .replace('>', ">")
1646 .replace('"', """)
1647 .replace('\'', "'");
1648 let escaped_host = raw_host
1649 .replace('&', "&")
1650 .replace('<', "<")
1651 .replace('>', ">")
1652 .replace('"', """)
1653 .replace('\'', "'");
1654
1655 let html = format!(
1656 r##"<!DOCTYPE html>
1657<html lang="en">
1658<head>
1659 <meta charset="UTF-8">
1660 <meta name="viewport" content="width=device-width, initial-scale=1">
1661 <meta http-equiv="refresh" content="2">
1662 <title>Starting {escaped_slug}… — pitchfork</title>
1663 <style>
1664 * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1665 body {{
1666 font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1667 background: #0f1117;
1668 color: #e1e4e8;
1669 display: flex;
1670 align-items: center;
1671 justify-content: center;
1672 min-height: 100vh;
1673 }}
1674 .container {{
1675 text-align: center;
1676 max-width: 480px;
1677 padding: 2rem;
1678 }}
1679 .spinner {{
1680 width: 48px;
1681 height: 48px;
1682 border: 4px solid rgba(255, 255, 255, 0.1);
1683 border-top-color: #58a6ff;
1684 border-radius: 50%;
1685 animation: spin 0.8s linear infinite;
1686 margin: 0 auto 1.5rem;
1687 }}
1688 @keyframes spin {{
1689 to {{ transform: rotate(360deg); }}
1690 }}
1691 h1 {{
1692 font-size: 1.5rem;
1693 font-weight: 600;
1694 margin-bottom: 0.5rem;
1695 }}
1696 .slug {{
1697 color: #58a6ff;
1698 font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1699 }}
1700 .host {{
1701 color: #8b949e;
1702 font-size: 0.875rem;
1703 margin-top: 0.25rem;
1704 }}
1705 .hint {{
1706 color: #8b949e;
1707 font-size: 0.8rem;
1708 margin-top: 1.5rem;
1709 }}
1710 </style>
1711</head>
1712<body>
1713 <div class="container">
1714 <div class="spinner"></div>
1715 <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1716 <p class="host">{escaped_host}</p>
1717 <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1718 </div>
1719</body>
1720</html>"##
1721 );
1722
1723 Response::builder()
1724 .status(StatusCode::SERVICE_UNAVAILABLE)
1725 .header("content-type", "text/html; charset=utf-8")
1726 .header("retry-after", "2")
1727 .body(Body::from(html))
1728 .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1729}
1730
1731async fn redirect_to_https_handler(req: Request) -> Response {
1740 if req.headers().contains_key("upgrade") {
1742 log::warn!("Dropping plain-HTTP WebSocket upgrade attempt — use wss:// instead of ws://");
1743 return (
1744 StatusCode::BAD_REQUEST,
1745 "WebSocket over plain HTTP is not supported on the HTTPS port. Use wss:// instead.",
1746 )
1747 .into_response();
1748 }
1749
1750 let raw_host = get_request_host(&req);
1751 let Some(raw_host) = raw_host else {
1752 return (StatusCode::BAD_REQUEST, "Missing Host header").into_response();
1753 };
1754
1755 let hostname = if raw_host.starts_with('[') {
1757 raw_host
1759 .split_once("]:")
1760 .map(|(host, _)| host)
1761 .unwrap_or(&raw_host)
1762 .trim_start_matches('[')
1763 .trim_end_matches(']')
1764 } else {
1765 let mut parts = raw_host.rsplitn(2, ':');
1767 let last = parts.next().unwrap_or(&raw_host);
1768 parts.next().unwrap_or(last)
1769 };
1770
1771 let path = req
1772 .uri()
1773 .path_and_query()
1774 .map(|pq| pq.as_str())
1775 .unwrap_or("/");
1776
1777 let https_port = match u16::try_from(settings().proxy.port).ok().filter(|&p| p > 0) {
1778 Some(443) | None => String::new(),
1779 Some(port) => format!(":{port}"),
1780 };
1781
1782 let host_for_url = if raw_host.starts_with('[') {
1783 format!("[{hostname}]")
1784 } else {
1785 hostname.to_string()
1786 };
1787
1788 let location = format!("https://{host_for_url}{https_port}{path}");
1789 (
1790 StatusCode::FOUND,
1791 [(axum::http::header::LOCATION, location)],
1792 )
1793 .into_response()
1794}
1795
1796fn error_response(status: StatusCode, message: &str) -> Response {
1798 (status, message.to_string()).into_response()
1799}
1800
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strip_tld() {
        assert_eq!(
            strip_tld("api.myproject.localhost", "localhost"),
            Some("api.myproject".to_string())
        );
        assert_eq!(
            strip_tld("api.localhost", "localhost"),
            Some("api".to_string())
        );
        assert_eq!(strip_tld("localhost", "localhost"), None);
        assert_eq!(
            strip_tld("api.myproject.test", "test"),
            Some("api.myproject".to_string())
        );
        assert_eq!(strip_tld("other.com", "localhost"), None);
    }

    #[test]
    fn test_strip_tld_empty_label() {
        // A bare `.<tld>` host has nothing to route on.
        assert_eq!(strip_tld(".localhost", "localhost"), None);
    }

    #[test]
    fn test_bind_error_message_hints() {
        let denied = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "denied");
        let low = bind_error_message(1023, &denied);
        assert!(low.starts_with("Failed to bind proxy server to port 1023: denied\n"));
        assert!(low.contains("elevated privileges"));

        let in_use = std::io::Error::new(std::io::ErrorKind::AddrInUse, "in use");
        let high = bind_error_message(1024, &in_use);
        assert!(high.starts_with("Failed to bind proxy server to port 1024: in use\n"));
        assert!(high.contains("another process may already be using this port"));
    }

    /// Builds a slug-cache entry whose slug and daemon name are both `name`.
    fn make_entry(name: &str) -> CachedSlugEntry {
        CachedSlugEntry {
            slug: name.to_string(),
            namespace: None,
            daemon_name: name.to_string(),
            dir: std::path::PathBuf::from(format!("/tmp/{name}")),
            worktrees: vec![],
        }
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_match() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_subdomain_fallback() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_nested_fallback() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("a.b.myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_no_match() {
        let entries = std::collections::HashMap::new();
        let result = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(result.is_none());
    }

    #[test]
    fn test_wildcard_slug_lookup_disabled() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let result = wildcard_slug_lookup("tenant.myapp", &entries, false);
        assert!(result.is_none());
        let result = wildcard_slug_lookup("myapp", &entries, false);
        assert!(result.is_some());
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_beats_wildcard() {
        let mut entries = std::collections::HashMap::new();
        entries.insert("myapp".to_string(), make_entry("myapp"));
        let mut tenant_entry = make_entry("tenant-daemon");
        tenant_entry.slug = "tenant.myapp".to_string();
        entries.insert("tenant.myapp".to_string(), tenant_entry);
        let result = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(result.is_some());
        assert_eq!(result.unwrap().daemon_name, "tenant-daemon");
    }

    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("ca.pem");
        let key_path = dir.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();

        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        assert!(
            key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY"),
            "should be PEM key"
        );
    }
}