1use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Connection-specific response headers dropped before replying through the
/// TLS listener (they are kept on `101 Switching Protocols` responses).
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];

/// Response header (value `1`) added to every response relayed from a daemon.
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header carrying how many times the request has already passed
/// through this proxy; it is incremented on every forward.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// A request whose hop count reaches this value is answered with 508 Loop Detected.
const MAX_PROXY_HOPS: u64 = 5;
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
49const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);
62
63#[derive(Clone, Debug)]
65struct CachedSlugEntry {
66 namespace: Option<String>,
68 daemon_name: String,
70 dir: std::path::PathBuf,
72}
73
74struct SlugCache {
76 entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
77 expires_at: std::time::Instant,
78}
79
80static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
81 once_cell::sync::Lazy::new(|| {
82 tokio::sync::Mutex::new(SlugCache {
83 entries: Arc::new(std::collections::HashMap::new()),
84 expires_at: std::time::Instant::now(), })
86 });
87
88fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
91 let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
92 let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
93 for (slug, entry) in &global_slugs {
94 let ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&entry.dir).ok();
95 let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
96 entries.insert(
97 slug.clone(),
98 CachedSlugEntry {
99 namespace: ns,
100 daemon_name,
101 dir: entry.dir.clone(),
102 },
103 );
104 }
105 entries
106}
107
108async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
114 {
116 let cache = SLUG_CACHE.lock().await;
117 if std::time::Instant::now() < cache.expires_at {
118 return Arc::clone(&cache.entries);
119 }
120 } let new_entries = Arc::new(build_slug_entries());
124
125 {
127 let mut cache = SLUG_CACHE.lock().await;
128 cache.entries = Arc::clone(&new_entries);
129 cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
130 }
131
132 new_entries
133}
134
135async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
137 let entries = get_cached_slugs().await;
138 entries.get(subdomain).cloned()
139}
140
141static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
148 tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
149> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
150
151enum ResolveResult {
153 Ready(u16),
156 Starting { slug: String },
158 NotFound,
160 Error(String),
162}
163
164type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;
167
168#[derive(Clone)]
169struct ProxyState {
170 client: Arc<Client<HttpConnector, Body>>,
172 tld: String,
174 is_tls: bool,
176 on_error: Option<OnErrorFn>,
178}
179
180pub async fn serve(
188 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
189) -> crate::Result<()> {
190 let s = settings();
191 let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
192 let msg = format!(
193 "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
194 s.proxy.port
195 );
196 let _ = bind_tx.send(Err(msg.clone()));
197 miette::bail!("{msg}");
198 };
199
200 let mut connector = HttpConnector::new();
201 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
205
206 let client = Client::builder(TokioExecutor::new())
207 .pool_idle_timeout(std::time::Duration::from_secs(30))
210 .build(connector);
211
212 let state = ProxyState {
213 client: Arc::new(client),
214 tld: s.proxy.tld.clone(),
215 is_tls: s.proxy.https,
216 on_error: None,
217 };
218
219 let app = Router::new().fallback(proxy_handler).with_state(state);
220
221 let bind_ip: std::net::IpAddr = match s.proxy.host.parse() {
223 Ok(ip) => ip,
224 Err(_) => {
225 log::warn!(
226 "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
227 The proxy will only be reachable on the loopback interface.",
228 s.proxy.host
229 );
230 std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
231 }
232 };
233 let addr = SocketAddr::from((bind_ip, effective_port));
234
235 if s.proxy.https {
236 serve_https_with_http_fallback(app, addr, s, effective_port, bind_tx).await
237 } else {
238 serve_http(app, addr, effective_port, bind_tx).await
239 }
240}
241
242async fn serve_http(
244 app: Router,
245 addr: SocketAddr,
246 effective_port: u16,
247 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
248) -> crate::Result<()> {
249 let listener = match TcpListener::bind(addr).await {
250 Ok(l) => {
251 let _ = bind_tx.send(Ok(()));
252 l
253 }
254 Err(e) => {
255 let msg = bind_error_message(effective_port, &e);
256 let _ = bind_tx.send(Err(msg.clone()));
257 return Err(miette::miette!("{msg}"));
258 }
259 };
260
261 log::info!("Proxy server listening on http://{addr}");
262 if effective_port < 1024 {
263 log::info!(
264 "Note: port {effective_port} is a privileged port. \
265 The supervisor must be started with sudo to bind to this port."
266 );
267 }
268 axum::serve(
269 listener,
270 app.into_make_service_with_connect_info::<SocketAddr>(),
271 )
272 .await
273 .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
274 Ok(())
275}
276
#[cfg(feature = "proxy-tls")]
/// Serves the proxy on `addr`, accepting both HTTPS and plain HTTP on the same port.
///
/// A local CA is generated on first use if either CA file is missing. Leaf
/// certificates are then issued per SNI name by `SniCertResolver`.
///
/// Each accepted connection is classified by its first byte. `0x16` (a TLS
/// handshake record) goes through the TLS acceptor; anything else is served as
/// plain HTTP.
///
/// Because this loop drives hyper directly rather than going through
/// `axum::serve`, it attaches the peer address as `ConnectInfo<SocketAddr>`
/// itself, so handlers see the real client IP.
///
/// # Errors
/// Returns an error if the CA cannot be generated or loaded, or if binding
/// fails (also reported through `bind_tx`). Once bound, the accept loop never
/// returns.
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);

    // First run, or one of the CA files went missing: create a fresh local CA.
    if !ca_cert_path.exists() || !ca_key_path.exists() {
        generate_ca(&ca_cert_path, &ca_key_path)?;
        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }

    // Install ring as the process-wide crypto provider; an `Err` only means a
    // provider was already installed.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;
    let tls_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(resolver));
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    loop {
        let (stream, peer_addr) = match listener.accept().await {
            Ok(conn) => conn,
            Err(e) => {
                // Accept errors are treated as transient: back off briefly and
                // keep serving instead of tearing down the listener.
                log::warn!("Accept error (will retry): {e}");
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                continue;
            }
        };

        let acceptor = acceptor.clone();
        // `into_make_service_with_connect_info` is not used on this path, so
        // attach `ConnectInfo<SocketAddr>` explicitly; without it the
        // X-Forwarded-For value falls back to 127.0.0.1 for every client.
        let app = app
            .clone()
            .layer(axum::Extension(axum::extract::ConnectInfo(peer_addr)));

        tokio::spawn(async move {
            // Look at (without consuming) the first byte to choose the protocol.
            let mut first_byte = [0u8; 1];
            match stream.peek(&mut first_byte).await {
                Ok(0) | Err(_) => return,
                Ok(_) => {}
            }

            // 0x16 is the TLS handshake record type that begins a ClientHello.
            if first_byte[0] == 0x16 {
                match acceptor.accept(stream).await {
                    Ok(tls_stream) => {
                        let io = hyper_util::rt::TokioIo::new(tls_stream);
                        let svc = hyper_util::service::TowerToHyperService::new(app);
                        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await;
                    }
                    Err(e) => {
                        log::debug!("TLS handshake error: {e}");
                    }
                }
            } else {
                let io = hyper_util::rt::TokioIo::new(stream);
                let svc = hyper_util::service::TowerToHyperService::new(app);
                let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                    .serve_connection_with_upgrades(io, svc)
                    .await;
            }
        });
    }
}
388
389#[cfg(not(feature = "proxy-tls"))]
391async fn serve_https_with_http_fallback(
392 _app: Router,
393 _addr: SocketAddr,
394 _s: &crate::settings::Settings,
395 _effective_port: u16,
396 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
397) -> crate::Result<()> {
398 let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
399 Rebuild pitchfork with: cargo build --features proxy-tls"
400 .to_string();
401 let _ = bind_tx.send(Err(msg.clone()));
402 miette::bail!("{msg}")
403}
404
#[cfg(feature = "proxy-tls")]
/// Returns `(ca_cert_path, ca_key_path)`.
///
/// Each path is the configured `proxy.tls_cert` / `proxy.tls_key` when
/// non-empty. Otherwise it is `ca.pem` / `ca-key.pem` inside the `proxy`
/// subdirectory of the pitchfork state directory.
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let proxy_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let pick = |configured: &str, fallback: &str| -> std::path::PathBuf {
        match configured {
            "" => proxy_dir.join(fallback),
            explicit => std::path::PathBuf::from(explicit),
        }
    };
    let cert = pick(&s.proxy.tls_cert, "ca.pem");
    let key = pick(&s.proxy.tls_key, "ca-key.pem");
    (cert, key)
}
424
#[cfg(feature = "proxy-tls")]
/// Generates a self-signed local CA and writes it as PEM to `cert_path` and `key_path`.
///
/// The CA has subject `CN=Pitchfork Local CA, O=Pitchfork`, unconstrained basic
/// constraints, and key usages limited to certificate and CRL signing. Parent
/// directories of both output paths are created if missing, because the key
/// location can be configured independently of the certificate.
///
/// On Unix the key file is restricted to mode `0o600`, including when it
/// already existed. On other platforms it is written with default permissions.
///
/// # Errors
/// Returns an error if a directory cannot be created, key generation or
/// self-signing fails, or either file cannot be written.
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    for path in [cert_path, key_path] {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
        }
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    let key_pem = key_pair.serialize_pem();

    #[cfg(unix)]
    {
        // Scoped to this cfg so non-Unix builds have no unused imports.
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
        std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(key_path)
            .and_then(|mut file| {
                // `mode` only applies when the file is created, so also tighten a
                // pre-existing key file before writing secret material into it.
                if let Err(e) = file.set_permissions(std::fs::Permissions::from_mode(0o600)) {
                    log::warn!(
                        "Failed to restrict permissions on {}: {e}",
                        key_path.display()
                    );
                }
                file.write_all(key_pem.as_bytes())
            })
            .map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(key_path, &key_pem).map_err(|e| {
            miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
        })?;
        log::debug!(
            "CA private key written to {} (file permissions are not restricted \
             on non-Unix platforms — consider restricting access manually)",
            key_path.display()
        );
    }

    Ok(())
}
497
#[cfg(feature = "proxy-tls")]
/// TLS certificate resolver that issues one leaf certificate per SNI hostname.
///
/// Leaf certificates are signed by the local CA and kept both in memory and
/// as PEM files under `host_certs_dir`.
struct SniCertResolver {
    /// Signs leaf certificates with the local CA key.
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory holding persisted leaf certificate + key PEM files.
    host_certs_dir: std::path::PathBuf,
    /// Certificates already loaded or issued, keyed by domain.
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Domains whose certificate is currently being loaded or issued.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Notified whenever a domain is removed from `pending`.
    pending_cv: std::sync::Condvar,
}

#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    // Prints only the type name (`SniCertResolver { .. }`); no fields are shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("SniCertResolver")
            .finish_non_exhaustive()
    }
}
539
#[cfg(feature = "proxy-tls")]
impl SniCertResolver {
    /// Builds a resolver from the local CA certificate and private key (PEM files).
    ///
    /// Leaf certificates are persisted in a `host-certs` directory next to the
    /// CA certificate. That directory is created if missing.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - either file cannot be read;
    /// - the certificate file has no `BEGIN CERTIFICATE` block;
    /// - the key or certificate fails to parse;
    /// - the `host-certs` directory cannot be created.
    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
        let ca_key_pem = std::fs::read_to_string(ca_key_path)
            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
        })?;

        // Cheap sanity check so an obviously wrong file produces a clear message.
        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
            miette::bail!("CA cert file does not contain a valid PEM certificate");
        }

        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;
        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;

        let host_certs_dir = ca_cert_path
            .parent()
            .unwrap_or(std::path::Path::new("."))
            .join("host-certs");
        std::fs::create_dir_all(&host_certs_dir)
            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;

        Ok(Self {
            issuer,
            host_certs_dir,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
            pending_cv: std::sync::Condvar::new(),
        })
    }

    /// Path under `host_certs_dir` where the leaf certificate for `domain` is stored.
    ///
    /// `.` becomes `_` and `*` becomes `wildcard` so the domain forms a flat
    /// file name. Both the loader and the issuer use this single definition.
    fn disk_path_for(&self, domain: &str) -> std::path::PathBuf {
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        self.host_certs_dir.join(format!("{safe_name}.pem"))
    }

    /// Stores `ck` in the in-memory cache (best effort) and returns it shared.
    fn remember(
        &self,
        domain: &str,
        ck: rustls::sign::CertifiedKey,
    ) -> Arc<rustls::sign::CertifiedKey> {
        let ck = Arc::new(ck);
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(domain.to_string(), Arc::clone(&ck));
        }
        ck
    }

    /// Returns the certificate for `domain`, loading or issuing it on first use.
    ///
    /// Lookup order:
    /// 1. the in-memory cache;
    /// 2. the persisted PEM under `host_certs_dir`;
    /// 3. a freshly signed certificate.
    ///
    /// Concurrent calls for one domain are serialized through
    /// `pending`/`pending_cv`. One caller does the work while the others block
    /// the calling thread on the condvar, then re-check the cache. Returns
    /// `None` if a lock is poisoned or issuance fails.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        {
            let cache = self.cache.lock().ok()?;
            if let Some(ck) = cache.get(domain) {
                return Some(Arc::clone(ck));
            }
        }

        // Claim `domain`, or wait for whoever currently holds the claim.
        loop {
            {
                let mut pending = self.pending.lock().ok()?;
                if !pending.contains(domain) {
                    pending.insert(domain.to_string());
                    break;
                }
                // Another thread is working on this domain; wait until it finishes.
                drop(self.pending_cv.wait(pending).ok()?);
            }
            // The other thread may have succeeded. If not, loop and try to claim it.
            {
                let cache = self.cache.lock().ok()?;
                if let Some(ck) = cache.get(domain) {
                    return Some(Arc::clone(ck));
                }
            }
        }

        let result = self.get_or_create_inner(domain);

        // Release the claim and wake waiters even if the lock was poisoned;
        // otherwise they would block forever.
        let mut pending = self
            .pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        pending.remove(domain);
        self.pending_cv.notify_all();
        drop(pending);

        result
    }

    /// Loads the persisted certificate for `domain` or issues a new one, then caches it.
    ///
    /// A persisted file that cannot be loaded (unreadable, malformed or
    /// expired) is deleted and replaced. Issuance failures are logged and
    /// yield `None`.
    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let disk_path = self.disk_path_for(domain);

        if disk_path.exists() {
            if let Ok(ck) = self.load_from_disk(&disk_path) {
                return Some(self.remember(domain, ck));
            }
            let _ = std::fs::remove_file(&disk_path);
        }

        match self.sign_for_domain(domain) {
            Ok(ck) => Some(self.remember(domain, ck)),
            Err(e) => {
                log::warn!("Failed to issue TLS certificate for '{domain}': {e}");
                None
            }
        }
    }

    /// Reads a combined certificate + private-key PEM file written by `sign_for_domain`.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the file cannot be read;
    /// - it contains no certificate or no private key;
    /// - the first certificate cannot be parsed, or its `notAfter` is in the past;
    /// - the key type is unsupported.
    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::{certs, private_key};

        let pem = std::fs::read_to_string(path)
            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;

        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;

        if cert_ders.is_empty() {
            miette::bail!("No certificates found in {}", path.display());
        }

        // Reject expired leaves so the caller regenerates them.
        {
            let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
            })?;
            let now_ts = chrono::Utc::now().timestamp();
            let not_after_ts = cert.validity().not_after.timestamp();
            if not_after_ts < now_ts {
                miette::bail!(
                    "Cached certificate at {} has expired — will regenerate",
                    path.display()
                );
            }
        }

        let key_der = private_key(&mut pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;

        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
    }

    /// Issues a leaf certificate for `domain`, signed by the local CA.
    ///
    /// Certificate contents:
    /// - validity runs from yesterday's date to 397 days from now;
    /// - the SANs are `domain`, plus `*.<parent>` when the parent part of
    ///   `domain` itself contains a dot.
    ///
    /// The certificate and its key are persisted to `disk_path_for(domain)` on
    /// a best-effort basis: failures are only logged. On Unix a newly created
    /// file gets mode `0o600`.
    ///
    /// # Errors
    /// Returns an error if `domain` is not a valid DNS name, or if key
    /// generation, signing, or signing-key construction fails.
    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
        use rcgen::date_time_ymd;
        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::private_key;

        let mut params = CertificateParams::default();
        let mut dn = DistinguishedName::new();
        dn.push(DnType::CommonName, domain);
        params.distinguished_name = dn;

        {
            use chrono::{Datelike, Duration, Utc};
            let yesterday = Utc::now() - Duration::days(1);
            let expiry = Utc::now() + Duration::days(397);
            params.not_before = date_time_ymd(
                yesterday.year(),
                yesterday.month() as u8,
                yesterday.day() as u8,
            );
            params.not_after =
                date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);
        }

        let mut sans =
            vec![SanType::DnsName(domain.to_string().try_into().map_err(
                |e| miette::miette!("Invalid domain name '{domain}': {e}"),
            )?)];
        if let Some(dot_pos) = domain.find('.') {
            let parent = &domain[dot_pos + 1..];
            // Skip a wildcard directly under a single-label name such as `*.localhost`.
            if parent.contains('.') {
                let wildcard = format!("*.{parent}");
                if let Ok(wc) = wildcard.try_into() {
                    sans.push(SanType::DnsName(wc));
                }
            }
        }
        params.subject_alt_names = sans;

        let leaf_key = rcgen::KeyPair::generate()
            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
        let leaf_cert = params
            .signed_by(&leaf_key, &self.issuer)
            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;

        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
        let key_pem = leaf_key.serialize_pem();
        let key_der = private_key(&mut key_pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;

        let disk_path = self.disk_path_for(domain);
        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);

        #[cfg(unix)]
        {
            // Scoped to this cfg so non-Unix builds have no unused imports.
            use std::io::Write;
            use std::os::unix::fs::OpenOptionsExt;
            if let Err(e) = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&disk_path)
                .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
            {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            }
        }
        #[cfg(not(unix))]
        {
            if let Err(e) = std::fs::write(&disk_path, combined_pem) {
                log::warn!(
                    "Failed to persist cert for '{domain}' to {}: {e}",
                    disk_path.display()
                );
            } else {
                log::debug!(
                    "Leaf cert for '{domain}' written to {} (file permissions are not \
                     restricted on non-Unix platforms — consider restricting access manually)",
                    disk_path.display()
                );
            }
        }

        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
    }
}
850
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    /// Supplies the certificate for the SNI hostname in the ClientHello,
    /// issuing it on first use.
    ///
    /// A ClientHello without SNI gets `None`, which makes rustls abort the
    /// handshake.
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|sni| self.get_or_create(sni))
    }
}
861
862fn get_request_host(req: &Request) -> Option<String> {
868 let authority = req
870 .uri()
871 .authority()
872 .map(|a| a.as_str().to_string())
873 .filter(|s| !s.is_empty());
874
875 authority.or_else(|| {
876 req.headers()
877 .get(HOST)
878 .and_then(|h| h.to_str().ok())
879 .map(str::to_string)
880 })
881}
882
883fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
894 let remote_addr = req
895 .extensions()
896 .get::<axum::extract::ConnectInfo<SocketAddr>>()
897 .map(|ci| ci.0.ip().to_string())
898 .unwrap_or_else(|| "127.0.0.1".to_string());
899
900 let proto = if is_tls { "https" } else { "http" };
901 let default_port = if is_tls { "443" } else { "80" };
902
903 let forwarded_for = remote_addr.clone();
906 let forwarded_proto = proto.to_string();
907 let forwarded_host = host_header.to_string();
908 let forwarded_port = host_header
909 .rsplit_once(':')
910 .map(|(_, port)| port.to_string())
911 .unwrap_or_else(|| default_port.to_string());
912
913 for name in [
919 "x-forwarded-for",
920 "x-forwarded-proto",
921 "x-forwarded-host",
922 "x-forwarded-port",
923 "forwarded",
924 ] {
925 if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
926 req.headers_mut().remove(&header_name);
927 }
928 }
929
930 let headers = [
931 ("x-forwarded-for", forwarded_for),
932 ("x-forwarded-proto", forwarded_proto),
933 ("x-forwarded-host", forwarded_host),
934 ("x-forwarded-port", forwarded_port),
935 ];
936
937 for (name, value) in headers {
938 if let Ok(v) = HeaderValue::from_str(&value) {
939 let header_name = axum::http::HeaderName::from_static(name);
940 req.headers_mut().insert(header_name, v);
941 }
942 }
943}
944
945async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
950 let Some(raw_host) = get_request_host(&req) else {
952 return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
953 };
954 let host = if raw_host.starts_with('[') {
958 raw_host
960 .split("]:")
961 .next()
962 .unwrap_or(&raw_host)
963 .trim_start_matches('[')
964 .trim_end_matches(']')
965 .to_string()
966 } else {
967 raw_host.split(':').next().unwrap_or(&raw_host).to_string()
969 };
970
971 let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
982 let hops: u64 = if is_from_pitchfork {
983 req.headers()
984 .get(PROXY_HOPS_HEADER)
985 .and_then(|v| v.to_str().ok())
986 .and_then(|s| s.parse().ok())
987 .unwrap_or(0)
988 } else {
989 0
991 };
992 if hops >= MAX_PROXY_HOPS {
993 return error_response(
994 StatusCode::LOOP_DETECTED,
995 &format!(
996 "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
997 This usually means a backend is proxying back through pitchfork without rewriting \n\
998 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
999 ),
1000 );
1001 }
1002
1003 let target_port = match resolve_target(&host, &state.tld).await {
1005 ResolveResult::Ready(port) => port,
1006 ResolveResult::Starting { slug } => {
1007 return starting_html_response(&slug, &raw_host);
1008 }
1009 ResolveResult::NotFound => {
1010 return error_response(
1011 StatusCode::BAD_GATEWAY,
1012 &format!(
1013 "No daemon found for host '{host}'.\n\
1014 Make sure the daemon has a slug, is running, and has a port configured.\n\
1015 Expected format: <slug>.{tld}",
1016 tld = state.tld
1017 ),
1018 );
1019 }
1020 ResolveResult::Error(msg) => {
1021 return error_response(StatusCode::BAD_GATEWAY, &msg);
1022 }
1023 };
1024
1025 let path_and_query = req
1027 .uri()
1028 .path_and_query()
1029 .map(|pq| pq.as_str())
1030 .unwrap_or("/");
1031
1032 let forward_uri = match Uri::builder()
1033 .scheme("http")
1034 .authority(format!("localhost:{target_port}"))
1035 .path_and_query(path_and_query)
1036 .build()
1037 {
1038 Ok(uri) => uri,
1039 Err(e) => {
1040 return error_response(
1041 StatusCode::INTERNAL_SERVER_ERROR,
1042 &format!("Failed to build forward URI: {e}"),
1043 );
1044 }
1045 };
1046
1047 *req.uri_mut() = forward_uri;
1049 req.headers_mut().insert(
1050 HOST,
1051 HeaderValue::from_str(&format!("localhost:{target_port}"))
1052 .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
1053 );
1054
1055 inject_forwarded_headers(&mut req, state.is_tls, &raw_host);
1057
1058 if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
1060 req.headers_mut()
1061 .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
1062 }
1063
1064 let pseudo_headers: Vec<_> = req
1069 .headers()
1070 .keys()
1071 .filter(|k| k.as_str().starts_with(':'))
1072 .cloned()
1073 .collect();
1074 for key in pseudo_headers {
1075 req.headers_mut().remove(&key);
1076 }
1077
1078 let client_upgrade = hyper::upgrade::on(&mut req);
1080
1081 let result = match tokio::time::timeout(
1089 std::time::Duration::from_secs(120),
1090 state.client.request(req),
1091 )
1092 .await
1093 {
1094 Ok(r) => r,
1095 Err(_elapsed) => {
1096 let msg = format!(
1097 "Request to daemon on port {target_port} timed out after 120 s.\n\
1098 The daemon accepted the connection but did not respond in time."
1099 );
1100 log::warn!("{msg}");
1101 if let Some(ref on_error) = state.on_error {
1102 on_error(&msg);
1103 }
1104 return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
1105 }
1106 };
1107 match result {
1108 Ok(mut resp) => {
1109 let backend_upgrade = hyper::upgrade::on(&mut resp);
1111 let (mut parts, body) = resp.into_parts();
1112
1113 parts.headers.insert(
1115 axum::http::HeaderName::from_static(PITCHFORK_HEADER),
1116 HeaderValue::from_static("1"),
1117 );
1118
1119 parts.headers.remove(PROXY_HOPS_HEADER);
1121
1122 if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
1127 for h in HOP_BY_HOP_HEADERS {
1128 if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
1129 parts.headers.remove(&name);
1130 }
1131 }
1132 }
1133
1134 if parts.status == StatusCode::SWITCHING_PROTOCOLS {
1136 tokio::spawn(async move {
1141 if let (Ok(client_upgraded), Ok(backend_upgraded)) =
1142 (client_upgrade.await, backend_upgrade.await)
1143 {
1144 let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
1145 let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
1146 let _ =
1154 tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
1155 }
1156 });
1157 return Response::from_parts(parts, Body::empty());
1158 }
1159
1160 Response::from_parts(parts, Body::new(body))
1163 }
1164 Err(e) => {
1165 let msg = format!(
1166 "Failed to connect to daemon on port {target_port}: {e}\n\
1167 The daemon may have stopped or is not yet ready."
1168 );
1169 if let Some(ref on_error) = state.on_error {
1170 on_error(&msg);
1171 } else {
1172 log::warn!("{msg}");
1173 }
1174 error_response(StatusCode::BAD_GATEWAY, &msg)
1175 }
1176 }
1177}
1178
1179async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1198 let Some(subdomain) = strip_tld(host, tld) else {
1200 return ResolveResult::NotFound;
1201 };
1202
1203 let Some(cached) = cached_slug_lookup(&subdomain).await else {
1205 return ResolveResult::NotFound;
1206 };
1207
1208 let daemon_name = &cached.daemon_name;
1209 let expected_namespace = &cached.namespace;
1210
1211 let daemons = {
1213 let state_file = SUPERVISOR.state_file.lock().await;
1214 state_file.daemons.clone()
1215 };
1216
1217 let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1220 .iter()
1221 .filter(|(id, d)| {
1222 id.name() == daemon_name
1223 && d.status.is_running()
1224 && match expected_namespace {
1225 Some(ns) => id.namespace() == ns,
1226 None => true,
1227 }
1228 })
1229 .collect();
1230
1231 match running_matches.as_slice() {
1232 [] => {
1233 try_auto_start(&subdomain, &cached).await
1235 }
1236 [(_, d)] => {
1237 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1238 ResolveResult::Ready(port)
1239 } else {
1240 ResolveResult::NotFound
1241 }
1242 }
1243 _ => {
1244 let d = running_matches[0].1;
1245 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1246 ResolveResult::Ready(port)
1247 } else {
1248 ResolveResult::NotFound
1249 }
1250 }
1251 }
1252}
1253
1254struct AutoStartGuard {
1261 daemon_id: DaemonId,
1262}
1263
1264impl Drop for AutoStartGuard {
1265 fn drop(&mut self) {
1266 let daemon_id = self.daemon_id.clone();
1267 tokio::spawn(async move {
1271 AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1272 });
1273 }
1274}
1275
1276async fn try_auto_start(slug: &str, cached: &CachedSlugEntry) -> ResolveResult {
1287 let s = settings();
1288 if !s.proxy.auto_start {
1289 return ResolveResult::NotFound;
1290 }
1291
1292 let ns = cached
1297 .namespace
1298 .clone()
1299 .unwrap_or_else(|| "global".to_string());
1300 let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1301 Ok(id) => id,
1302 Err(_) => return ResolveResult::NotFound,
1303 };
1304
1305 {
1308 let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1309 if !in_progress.insert(daemon_id.clone()) {
1310 return ResolveResult::Starting {
1311 slug: slug.to_string(),
1312 };
1313 }
1314 }
1315
1316 let _guard = AutoStartGuard {
1318 daemon_id: daemon_id.clone(),
1319 };
1320
1321 let timeout = s.proxy_auto_start_timeout();
1325
1326 match tokio::time::timeout(timeout, try_auto_start_inner(slug, cached, &daemon_id)).await {
1327 Ok(result) => result,
1328 Err(_elapsed) => {
1329 log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1330 ResolveResult::Error(format!(
1331 "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1332 The daemon did not become ready and bind a port within the configured \
1333 proxy_auto_start_timeout.\n\
1334 Increase the timeout or check the daemon's logs for slow startup."
1335 ))
1336 }
1337 }
1338}
1339
1340async fn try_auto_start_inner(
1344 slug: &str,
1345 cached: &CachedSlugEntry,
1346 daemon_id: &DaemonId,
1347) -> ResolveResult {
1348 let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(&cached.dir) {
1350 Ok(pt) => pt,
1351 Err(e) => {
1352 log::warn!(
1353 "Auto-start: failed to load config from {}: {e}",
1354 cached.dir.display()
1355 );
1356 return ResolveResult::NotFound;
1357 }
1358 };
1359
1360 let daemon_config = match pt.daemons.get(daemon_id) {
1361 Some(cfg) => cfg,
1362 None => {
1363 log::debug!(
1364 "Auto-start: daemon {daemon_id} not found in config at {}",
1365 cached.dir.display()
1366 );
1367 return ResolveResult::NotFound;
1368 }
1369 };
1370
1371 let opts = crate::ipc::batch::StartOptions::default();
1375 let mut run_opts = match crate::ipc::batch::build_run_options(daemon_id, daemon_config, &opts) {
1376 Ok(o) => o,
1377 Err(e) => {
1378 log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1379 return ResolveResult::Error(format!("Failed to build run options: {e}"));
1380 }
1381 };
1382
1383 if run_opts.dir.as_os_str().is_empty() {
1384 run_opts.dir = cached.dir.clone();
1385 }
1386
1387 log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1388
1389 let run_result = SUPERVISOR.run(run_opts).await;
1392
1393 if let Err(e) = run_result {
1394 log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1395 return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1396 }
1397
1398 let poll_interval = std::time::Duration::from_millis(250);
1402
1403 loop {
1404 let daemons = {
1405 let sf = SUPERVISOR.state_file.lock().await;
1406 sf.daemons.clone()
1407 };
1408
1409 if let Some(d) = daemons.get(daemon_id) {
1410 if d.status.is_running() {
1411 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1412 log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1413 return ResolveResult::Ready(port);
1414 }
1415 } else {
1416 log::warn!(
1417 "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1418 d.status
1419 );
1420 return ResolveResult::Error(format!(
1421 "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1422 Check its logs for errors."
1423 ));
1424 }
1425 } else {
1426 log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1429 return ResolveResult::Error(format!(
1430 "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1431 Check its logs for errors."
1432 ));
1433 }
1434
1435 tokio::time::sleep(poll_interval).await;
1436 }
1437}
1438
/// Returns the labels of `host` that precede the `.{tld}` suffix.
///
/// Yields `None` when `host` does not end in `.{tld}` (the match is exact and
/// case-sensitive) or when nothing comes before that suffix.
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    let dotted_tld = format!(".{tld}");
    match host.strip_suffix(dotted_tld.as_str()) {
        Some(prefix) if !prefix.is_empty() => Some(prefix.to_owned()),
        _ => None,
    }
}
1450
/// Builds the user-facing message for a failed attempt to bind the proxy to `port`.
///
/// The message always contains the port and the underlying error, followed by
/// a hint. The "elevated privileges / sudo" hint is shown only for ports below
/// 1024 when the failure was *not* `AddrInUse`. A privileged port that is
/// already taken (e.g. port 80 while running as root) gets the "port in use"
/// hint instead, since sudo would not help there.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    let needs_privileges = port < 1024 && err.kind() != std::io::ErrorKind::AddrInUse;
    if needs_privileges {
        format!(
            "Failed to bind proxy server to port {port}: {err}\n\
             Hint: ports below 1024 require elevated privileges. \
             Try: sudo pitchfork supervisor start"
        )
    } else {
        format!(
            "Failed to bind proxy server to port {port}: {err}\n\
             Hint: another process may already be using this port."
        )
    }
}
1466
1467fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1472 let escaped_slug = slug
1473 .replace('&', "&")
1474 .replace('<', "<")
1475 .replace('>', ">")
1476 .replace('"', """)
1477 .replace('\'', "'");
1478 let escaped_host = raw_host
1479 .replace('&', "&")
1480 .replace('<', "<")
1481 .replace('>', ">")
1482 .replace('"', """)
1483 .replace('\'', "'");
1484
1485 let html = format!(
1486 r##"<!DOCTYPE html>
1487<html lang="en">
1488<head>
1489 <meta charset="UTF-8">
1490 <meta name="viewport" content="width=device-width, initial-scale=1">
1491 <meta http-equiv="refresh" content="2">
1492 <title>Starting {escaped_slug}… — pitchfork</title>
1493 <style>
1494 * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1495 body {{
1496 font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1497 background: #0f1117;
1498 color: #e1e4e8;
1499 display: flex;
1500 align-items: center;
1501 justify-content: center;
1502 min-height: 100vh;
1503 }}
1504 .container {{
1505 text-align: center;
1506 max-width: 480px;
1507 padding: 2rem;
1508 }}
1509 .spinner {{
1510 width: 48px;
1511 height: 48px;
1512 border: 4px solid rgba(255, 255, 255, 0.1);
1513 border-top-color: #58a6ff;
1514 border-radius: 50%;
1515 animation: spin 0.8s linear infinite;
1516 margin: 0 auto 1.5rem;
1517 }}
1518 @keyframes spin {{
1519 to {{ transform: rotate(360deg); }}
1520 }}
1521 h1 {{
1522 font-size: 1.5rem;
1523 font-weight: 600;
1524 margin-bottom: 0.5rem;
1525 }}
1526 .slug {{
1527 color: #58a6ff;
1528 font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1529 }}
1530 .host {{
1531 color: #8b949e;
1532 font-size: 0.875rem;
1533 margin-top: 0.25rem;
1534 }}
1535 .hint {{
1536 color: #8b949e;
1537 font-size: 0.8rem;
1538 margin-top: 1.5rem;
1539 }}
1540 </style>
1541</head>
1542<body>
1543 <div class="container">
1544 <div class="spinner"></div>
1545 <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1546 <p class="host">{escaped_host}</p>
1547 <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1548 </div>
1549</body>
1550</html>"##
1551 );
1552
1553 Response::builder()
1554 .status(StatusCode::SERVICE_UNAVAILABLE)
1555 .header("content-type", "text/html; charset=utf-8")
1556 .header("retry-after", "2")
1557 .body(Body::from(html))
1558 .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1559}
1560
1561fn error_response(status: StatusCode, message: &str) -> Response {
1563 (status, message.to_string()).into_response()
1564}
1565
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strip_tld() {
        // (host, tld, expected prefix)
        let cases: &[(&str, &str, Option<&str>)] = &[
            ("api.myproject.localhost", "localhost", Some("api.myproject")),
            ("api.localhost", "localhost", Some("api")),
            ("api.myproject.test", "test", Some("api.myproject")),
            // Bare TLD: there is no ".localhost" suffix to strip at all.
            ("localhost", "localhost", None),
            // Suffix present but nothing in front of it: exercises the
            // empty-label filter, which the bare-TLD case never reaches.
            (".localhost", "localhost", None),
            // Unrelated domain.
            ("other.com", "localhost", None),
        ];
        for &(host, tld, expected) in cases {
            assert_eq!(
                strip_tld(host, tld).as_deref(),
                expected,
                "strip_tld({host:?}, {tld:?})"
            );
        }
    }

    #[test]
    fn test_bind_error_message_hints() {
        let denied = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "denied");
        assert!(
            bind_error_message(80, &denied).contains("sudo pitchfork supervisor start"),
            "privileged port + permission error should suggest sudo"
        );

        let in_use = std::io::Error::new(std::io::ErrorKind::AddrInUse, "in use");
        let msg = bind_error_message(8080, &in_use);
        assert!(msg.starts_with("Failed to bind proxy server to port 8080: in use\n"));
        assert!(msg.contains("another process may already be using this port"));
    }

    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("ca.pem");
        let key_path = dir.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();

        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        assert!(
            key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY"),
            "should be PEM key"
        );
    }
}