1use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Response header (`x-pitchfork: 1`) added to every response returned from a
/// daemon through the proxy.
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header counting how many times a request has passed through this
/// proxy. It is incremented on every forward and stripped from responses.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// Requests whose hop counter is at or above this value are rejected with
/// `508 Loop Detected`.
const MAX_PROXY_HOPS: u64 = 5;

/// Connection-specific headers removed from proxied responses in HTTPS mode
/// (except on `101 Switching Protocols`, which needs `connection`/`upgrade`).
/// This is the set HTTP/2 forbids; presumably they are stripped because HTTPS
/// mode may negotiate h2 via ALPN.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
/// How long a built slug map is reused before `get_cached_slugs` rebuilds it
/// from the global slug registry.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);
62
/// One slug's routing target, as stored in the slug cache.
#[derive(Clone, Debug)]
struct CachedSlugEntry {
    /// Namespace derived from `dir`. `None` if it could not be determined; then
    /// `resolve_target` accepts a running daemon from any namespace.
    namespace: Option<String>,
    /// Daemon the slug routes to (defaults to the slug itself).
    daemon_name: String,
    /// Project directory the slug was registered for; its config is loaded
    /// from here when the daemon is auto-started.
    dir: std::path::PathBuf,
}
73
/// Current slug map plus the instant after which it must be rebuilt.
struct SlugCache {
    /// Shared snapshot handed out to callers; replaced wholesale on refresh.
    entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
    /// `entries` is served from cache only while `Instant::now()` is before this.
    expires_at: std::time::Instant,
}
79
80static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
81 once_cell::sync::Lazy::new(|| {
82 tokio::sync::Mutex::new(SlugCache {
83 entries: Arc::new(std::collections::HashMap::new()),
84 expires_at: std::time::Instant::now(), })
86 });
87
88fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
91 let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
92 let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
93 for (slug, entry) in &global_slugs {
94 let ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&entry.dir).ok();
95 let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
96 entries.insert(
97 slug.clone(),
98 CachedSlugEntry {
99 namespace: ns,
100 daemon_name,
101 dir: entry.dir.clone(),
102 },
103 );
104 }
105 entries
106}
107
108async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
114 {
116 let cache = SLUG_CACHE.lock().await;
117 if std::time::Instant::now() < cache.expires_at {
118 return Arc::clone(&cache.entries);
119 }
120 } let new_entries = Arc::new(build_slug_entries());
124
125 {
127 let mut cache = SLUG_CACHE.lock().await;
128 cache.entries = Arc::clone(&new_entries);
129 cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
130 }
131
132 new_entries
133}
134
135async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
137 let entries = get_cached_slugs().await;
138 entries.get(subdomain).cloned()
139}
140
141static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
148 tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
149> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
150
/// Outcome of mapping a request host to a daemon port.
enum ResolveResult {
    /// A running daemon serves this host on the given local port.
    Ready(u16),
    /// Another request is already auto-starting the daemon for `slug`; the
    /// handler answers with a "starting" page.
    Starting { slug: String },
    /// No daemon could be matched (unknown host, auto-start disabled, no port, …).
    NotFound,
    /// Resolution or auto-start failed; the message is sent to the client as a 502.
    Error(String),
}
163
/// Callback that receives the error message when forwarding a request to a
/// daemon fails (connection error or timeout).
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;
167
/// Shared state for `proxy_handler`; cloned per request, so every field is
/// cheap to clone.
#[derive(Clone)]
struct ProxyState {
    /// Pooled HTTP client used to forward requests to daemons.
    client: Arc<Client<HttpConnector, Body>>,
    /// Top-level domain the proxy serves; hosts look like `<slug>.<tld>`.
    tld: String,
    /// Whether the proxy serves HTTPS. Selects the `X-Forwarded-Proto`/port
    /// values and enables response hop-by-hop header stripping.
    is_tls: bool,
    /// Optional hook invoked with the message when forwarding fails.
    on_error: Option<OnErrorFn>,
}
179
180pub async fn serve(
188 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
189 cancel: tokio_util::sync::CancellationToken,
190) -> crate::Result<()> {
191 let s = settings();
192 let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
193 let msg = format!(
194 "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
195 s.proxy.port
196 );
197 let _ = bind_tx.send(Err(msg.clone()));
198 miette::bail!("{msg}");
199 };
200
201 let mut connector = HttpConnector::new();
202 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
206
207 let client = Client::builder(TokioExecutor::new())
208 .pool_idle_timeout(std::time::Duration::from_secs(30))
211 .build(connector);
212
213 let state = ProxyState {
214 client: Arc::new(client),
215 tld: s.proxy.tld.clone(),
216 is_tls: s.proxy.https,
217 on_error: None,
218 };
219
220 let app = Router::new().fallback(proxy_handler).with_state(state);
221
222 let bind_ip: std::net::IpAddr = match s.proxy.host.parse() {
224 Ok(ip) => ip,
225 Err(_) => {
226 log::warn!(
227 "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
228 The proxy will only be reachable on the loopback interface.",
229 s.proxy.host
230 );
231 std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
232 }
233 };
234 let addr = SocketAddr::from((bind_ip, effective_port));
235
236 if s.proxy.https {
237 serve_https_with_http_fallback(app, addr, s, effective_port, bind_tx, cancel).await
238 } else {
239 serve_http(app, addr, effective_port, bind_tx, cancel).await
240 }
241}
242
243async fn serve_http(
245 app: Router,
246 addr: SocketAddr,
247 effective_port: u16,
248 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
249 cancel: tokio_util::sync::CancellationToken,
250) -> crate::Result<()> {
251 let listener = match TcpListener::bind(addr).await {
252 Ok(l) => {
253 if settings().proxy.sync_hosts {
254 crate::proxy::hosts::sync_hosts_from_settings();
255 }
256 let _ = bind_tx.send(Ok(()));
257 l
258 }
259 Err(e) => {
260 let msg = bind_error_message(effective_port, &e);
261 let _ = bind_tx.send(Err(msg.clone()));
262 return Err(miette::miette!("{msg}"));
263 }
264 };
265
266 log::info!("Proxy server listening on http://{addr}");
267 if effective_port < 1024 {
268 log::info!(
269 "Note: port {effective_port} is a privileged port. \
270 The supervisor must be started with sudo to bind to this port."
271 );
272 }
273 let shutdown_signal = cancel.clone().cancelled_owned();
274 axum::serve(
275 listener,
276 app.into_make_service_with_connect_info::<SocketAddr>(),
277 )
278 .with_graceful_shutdown(shutdown_signal)
279 .await
280 .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
281 Ok(())
282}
283
#[cfg(feature = "proxy-tls")]
/// Serve the proxy over HTTPS on `addr`, also answering plain HTTP on the same port.
///
/// Ensures a local CA exists (generating one if either file is missing) and
/// installs an SNI resolver that mints per-host leaf certificates. Each
/// accepted connection is then sniffed: a first byte of `0x16` (TLS handshake
/// record) is served by `app` over TLS; anything else goes to
/// `redirect_to_https_handler`. On cancellation, in-flight connections get up
/// to 10 s to finish.
///
/// TLS connections carry the client's address as `ConnectInfo<SocketAddr>`,
/// matching the plain-HTTP server, so `X-Forwarded-For` reflects the real peer.
///
/// # Errors
/// Fails if the CA cannot be created or loaded, or the listener cannot be bound
/// (the bind result is also reported through `bind_tx`).
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
    cancel: tokio_util::sync::CancellationToken,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    /// Upper bound for completing a TLS handshake once its first byte arrived,
    /// so stalled handshakes do not pin a connection task forever.
    const TLS_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);

    if !ca_cert_path.exists() || !ca_key_path.exists() {
        generate_ca(&ca_cert_path, &ca_key_path)?;
        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }

    // An Err here only means a process-wide crypto provider is already installed.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;

    let mut tls_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(resolver));
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            if settings().proxy.sync_hosts {
                crate::proxy::hosts::sync_hosts_from_settings();
            }
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    let redirect_app = Router::new().fallback(redirect_to_https_handler);

    let mut conn_tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    loop {
        // Reap finished connection tasks so the set does not grow without bound.
        while conn_tasks.try_join_next().is_some() {}

        tokio::select! {
            accept_result = listener.accept() => {
                let (stream, peer_addr) = match accept_result {
                    Ok(conn) => conn,
                    Err(e) => {
                        log::warn!("Accept error (will retry): {e}");
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        continue;
                    }
                };

                let acceptor = acceptor.clone();
                let app = app.clone();
                let redirect_app = redirect_app.clone();

                conn_tasks.spawn(async move {
                    // Peek (without consuming) the first byte to pick TLS vs plain HTTP.
                    let mut peek_buf = [0u8; 1];
                    match stream.peek(&mut peek_buf).await {
                        Ok(0) | Err(_) => return,
                        _ => {}
                    }

                    if peek_buf[0] == 0x16 {
                        let tls_stream = match tokio::time::timeout(
                            TLS_HANDSHAKE_TIMEOUT,
                            acceptor.accept(stream),
                        )
                        .await
                        {
                            Ok(Ok(tls_stream)) => tls_stream,
                            Ok(Err(e)) => {
                                log::debug!("TLS handshake error: {e}");
                                return;
                            }
                            Err(_elapsed) => {
                                log::debug!(
                                    "TLS handshake timed out after {TLS_HANDSHAKE_TIMEOUT:?}"
                                );
                                return;
                            }
                        };
                        // Previously the peer address was dropped here, so
                        // `inject_forwarded_headers` always fell back to 127.0.0.1.
                        let app = app.layer(axum::Extension(axum::extract::ConnectInfo(
                            peer_addr,
                        )));
                        let io = hyper_util::rt::TokioIo::new(tls_stream);
                        let svc = hyper_util::service::TowerToHyperService::new(app);
                        if let Err(e) =
                            hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                                .serve_connection_with_upgrades(io, svc)
                                .await
                        {
                            log::debug!("Connection error: {e}");
                        }
                    } else {
                        let io = hyper_util::rt::TokioIo::new(stream);
                        let svc = hyper_util::service::TowerToHyperService::new(redirect_app);
                        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await;
                    }
                });
            }
            _ = cancel.cancelled() => {
                log::info!("Proxy server shutting down (cancel signal received)");
                break;
            }
        }
    }

    // Give in-flight connections a bounded grace period; the rest are aborted
    // when `conn_tasks` is dropped.
    let drain_timeout = std::time::Duration::from_secs(10);
    let _ = tokio::time::timeout(drain_timeout, async {
        while conn_tasks.join_next().await.is_some() {}
    })
    .await;

    Ok(())
}
432
433#[cfg(not(feature = "proxy-tls"))]
435async fn serve_https_with_http_fallback(
436 _app: Router,
437 _addr: SocketAddr,
438 _s: &crate::settings::Settings,
439 _effective_port: u16,
440 bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
441 _cancel: tokio_util::sync::CancellationToken,
442) -> crate::Result<()> {
443 let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
444 Rebuild pitchfork with: cargo build --features proxy-tls"
445 .to_string();
446 let _ = bind_tx.send(Err(msg.clone()));
447 miette::bail!("{msg}")
448}
449
#[cfg(feature = "proxy-tls")]
/// Resolve the `(certificate, key)` paths of the proxy's local CA.
///
/// Non-empty `proxy.tls_cert` / `proxy.tls_key` settings are used verbatim.
/// Empty ones fall back to `ca.pem` / `ca-key.pem` inside
/// `<PITCHFORK_STATE_DIR>/proxy`.
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let proxy_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let choose = |configured: &str, file_name: &str| -> std::path::PathBuf {
        match configured {
            "" => proxy_dir.join(file_name),
            explicit => std::path::PathBuf::from(explicit),
        }
    };
    let cert_path = choose(&s.proxy.tls_cert, "ca.pem");
    let key_path = choose(&s.proxy.tls_key, "ca-key.pem");
    (cert_path, key_path)
}
469
#[cfg(feature = "proxy-tls")]
/// Generate a fresh self-signed "Pitchfork Local CA" and write it to disk.
///
/// Writes the PEM certificate to `cert_path` and the PEM private key to
/// `key_path`, creating both parent directories if needed. On Unix the key
/// file is forced to mode `0600`, even if it already existed with looser
/// permissions.
///
/// Every leaf certificate issued earlier was signed by the old CA and would
/// no longer chain to the new one. Cached `*.pem` leaves in the `host-certs`
/// directory next to `cert_path` (where `SniCertResolver` persists them) are
/// therefore deleted so they get re-issued. That deletion is best effort and
/// only logged on failure.
///
/// # Errors
/// Fails if a directory cannot be created, key generation or self-signing
/// fails, or either file cannot be written.
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    // The key may be configured to live in a different directory than the cert.
    for path in [cert_path, key_path] {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
        }
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(key_path)
                .and_then(|mut f| {
                    // `mode` only applies when the file is created. Tighten a
                    // pre-existing file *before* the secret is written into it.
                    f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                    f.write_all(key_pair.serialize_pem().as_bytes())
                })
                .map_err(|e| {
                    miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
                })?;
        }
        #[cfg(not(unix))]
        {
            std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
            log::debug!(
                "CA private key written to {} (file permissions are not restricted \
                 on non-Unix platforms — consider restricting access manually)",
                key_path.display()
            );
        }
    }

    // Same directory `SniCertResolver::new` derives from the CA cert path.
    // Cached leaves there were signed by the previous CA (they are only checked
    // for expiry when loaded), so remove them to force re-issuance.
    let host_certs_dir = cert_path
        .parent()
        .unwrap_or(std::path::Path::new("."))
        .join("host-certs");
    if let Ok(entries) = std::fs::read_dir(&host_certs_dir) {
        for entry in entries.flatten() {
            let path = entry.path();
            if path.extension() == Some(std::ffi::OsStr::new("pem")) {
                if let Err(e) = std::fs::remove_file(&path) {
                    log::warn!(
                        "Failed to remove stale host certificate {}: {e}",
                        path.display()
                    );
                }
            }
        }
    }

    Ok(())
}
542
#[cfg(feature = "proxy-tls")]
/// rustls certificate resolver that issues one leaf certificate per SNI name,
/// signed by the local pitchfork CA. Results are cached in memory and on disk.
struct SniCertResolver {
    /// Local CA (certificate + private key) used to sign leaf certificates.
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory of persisted leaves, one `<safe_name>.pem` (cert + key) per domain.
    host_certs_dir: std::path::PathBuf,
    /// In-memory cache of issued certificates, keyed by domain.
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Domains whose certificate is currently being loaded or generated by some thread.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Notified (`notify_all`) whenever a pending domain is released, so
    /// threads waiting on the same domain re-check `cache`.
    pending_cv: std::sync::Condvar,
}
577
#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    /// Opaque debug output: no fields are printed, because `issuer` contains
    /// the CA private key.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("SniCertResolver");
        out.finish_non_exhaustive()
    }
}
584
#[cfg(feature = "proxy-tls")]
impl SniCertResolver {
    /// Load the CA from `ca_cert_path` / `ca_key_path` and create the on-disk
    /// leaf cache directory (`host-certs`, next to the CA certificate).
    ///
    /// # Errors
    /// Fails if either file cannot be read, the certificate file lacks a
    /// `BEGIN CERTIFICATE` marker, the key or certificate cannot be parsed, or
    /// the `host-certs` directory cannot be created.
    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
        let ca_key_pem = std::fs::read_to_string(ca_key_path)
            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
        })?;

        // Cheap sanity check so an obviously wrong file yields a clear message.
        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
            miette::bail!("CA cert file does not contain a valid PEM certificate");
        }

        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;

        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;

        let host_certs_dir = ca_cert_path
            .parent()
            .unwrap_or(std::path::Path::new("."))
            .join("host-certs");
        std::fs::create_dir_all(&host_certs_dir)
            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;

        Ok(Self {
            issuer,
            host_certs_dir,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
            pending_cv: std::sync::Condvar::new(),
        })
    }

    /// Return the leaf certificate for `domain`, creating it at most once even
    /// when several handshakes for the same domain race.
    ///
    /// This blocks the calling thread while waiting for another thread's
    /// generation, and while doing the (synchronous) load/generation itself.
    /// Returns `None` if a lock is poisoned or the certificate cannot be produced.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        // Fast path: already issued.
        {
            let cache = self.cache.lock().ok()?;
            if let Some(ck) = cache.get(domain) {
                return Some(Arc::clone(ck));
            }
        }
        // Claim `domain` in `pending`. If another thread already holds it, wait
        // for a notification, then re-check the cache. Looping covers spurious
        // wakeups, notifications for other domains, and failed generations
        // (the domain is free again but still not cached).
        loop {
            {
                let mut pending = self.pending.lock().ok()?;
                if pending.contains(domain) {
                    pending = self.pending_cv.wait(pending).ok()?;
                    drop(pending);
                } else {
                    pending.insert(domain.to_string());
                    break;
                }
            }
            {
                let cache = self.cache.lock().ok()?;
                if let Some(ck) = cache.get(domain) {
                    return Some(Arc::clone(ck));
                }
            }
        }
        // This thread now owns `domain` in `pending`.
        let result = self.get_or_create_inner(domain);

        // Release the claim, even through a poisoned lock, and wake all waiters.
        // NOTE(review): if `get_or_create_inner` panics this is skipped, leaving
        // `domain` in `pending` so later callers wait forever — consider a drop guard.
        {
            let mut pending = match self.pending.lock() {
                Ok(g) => g,
                Err(e) => e.into_inner(),
            };
            pending.remove(domain);
            self.pending_cv.notify_all();
        }

        result
    }

    /// Load `domain`'s leaf from the on-disk cache, or sign (and persist) a new
    /// one, then record it in the in-memory cache.
    ///
    /// A cached file that fails `load_from_disk` (unreadable, unparsable,
    /// expired) is deleted and regenerated. Returns `None` if signing fails.
    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        // Filesystem-safe name: '.' -> '_', '*' -> "wildcard".
        // NOTE(review): not injective (e.g. "a_b.c" and "a.b_c" map to the same
        // file), so two such domains could share a cached cert — confirm acceptable.
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));

        if disk_path.exists() {
            if let Ok(ck) = self.load_from_disk(&disk_path) {
                let ck = Arc::new(ck);
                if let Ok(mut cache) = self.cache.lock() {
                    cache.insert(domain.to_string(), Arc::clone(&ck));
                }
                return Some(ck);
            }
            // Unusable cached file: remove it (best effort) and regenerate below.
            let _ = std::fs::remove_file(&disk_path);
        }

        let ck = self.sign_for_domain(domain).ok()?;

        let ck = Arc::new(ck);
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(domain.to_string(), Arc::clone(&ck));
        }
        Some(ck)
    }

    /// Parse a cached file containing a PEM certificate chain followed by a PEM
    /// private key into a `CertifiedKey`.
    ///
    /// Only the first certificate's `notAfter` is validated.
    /// NOTE(review): the issuer is not compared against the current CA, so a leaf
    /// signed by an earlier CA is still accepted here.
    ///
    /// # Errors
    /// Fails if the file cannot be read, holds no certificate, the first
    /// certificate cannot be parsed or has expired, or no usable key is present.
    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::{certs, private_key};

        let pem = std::fs::read_to_string(path)
            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;

        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;

        if cert_ders.is_empty() {
            miette::bail!("No certificates found in {}", path.display());
        }

        // Reject expired leaves so the caller regenerates them.
        {
            let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
            })?;
            use chrono::Utc;
            let now_ts = Utc::now().timestamp();
            let not_after_ts = cert.validity().not_after.timestamp();
            if not_after_ts < now_ts {
                miette::bail!(
                    "Cached certificate at {} has expired — will regenerate",
                    path.display()
                );
            }
        }

        let key_der = private_key(&mut pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;

        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
    }

    /// Issue a new leaf certificate for `domain`, signed by the local CA.
    ///
    /// Validity runs from yesterday's date to the date 397 days from now.
    /// SANs are `domain` plus `*.<parent>` when `domain` contains at least two
    /// dots. The cert and key are also written to `host_certs_dir` as one PEM
    /// file; a write failure is only logged.
    ///
    /// # Errors
    /// Fails if `domain` is not accepted as a DNS name, or key generation,
    /// signing, or signing-key conversion fails.
    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
        use rcgen::date_time_ymd;
        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::private_key;

        let mut params = CertificateParams::default();
        let mut dn = DistinguishedName::new();
        dn.push(DnType::CommonName, domain);
        params.distinguished_name = dn;

        // Backdated one day, presumably to tolerate clock skew. The 397-day
        // lifetime presumably stays under browsers' 398-day maximum — confirm.
        {
            use chrono::{Datelike, Duration, Utc};
            let yesterday = Utc::now() - Duration::days(1);
            let expiry = Utc::now() + Duration::days(397);
            params.not_before = date_time_ymd(
                yesterday.year(),
                yesterday.month() as u8,
                yesterday.day() as u8,
            );
            params.not_after =
                date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);
        }

        let mut sans =
            vec![SanType::DnsName(domain.to_string().try_into().map_err(
                |e| miette::miette!("Invalid domain name '{domain}': {e}"),
            )?)];
        // For `a.b.c` also cover siblings via `*.b.c`; skipped when the parent has no dot.
        if let Some(dot_pos) = domain.find('.') {
            let parent = &domain[dot_pos + 1..];
            if parent.contains('.') {
                let wildcard = format!("*.{parent}");
                if let Ok(wc) = wildcard.try_into() {
                    sans.push(SanType::DnsName(wc));
                }
            }
        }
        params.subject_alt_names = sans;

        let leaf_key = rcgen::KeyPair::generate()
            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
        let leaf_cert = params
            .signed_by(&leaf_key, &self.issuer)
            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;

        // Round-trip the key through PEM to obtain rustls' private-key DER type.
        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
        let key_pem = leaf_key.serialize_pem();
        let key_der = private_key(&mut key_pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;

        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;

        // Persist cert + key in the layout `load_from_disk` reads back.
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
        {
            #[cfg(unix)]
            {
                use std::io::Write;
                use std::os::unix::fs::OpenOptionsExt;
                if let Err(e) = std::fs::OpenOptions::new()
                    .write(true)
                    .create(true)
                    .truncate(true)
                    .mode(0o600)
                    .open(&disk_path)
                    .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
                {
                    log::warn!(
                        "Failed to persist cert for '{domain}' to {}: {e}",
                        disk_path.display()
                    );
                }
            }
            #[cfg(not(unix))]
            {
                if let Err(e) = std::fs::write(&disk_path, combined_pem) {
                    log::warn!(
                        "Failed to persist cert for '{domain}' to {}: {e}",
                        disk_path.display()
                    );
                } else {
                    log::debug!(
                        "Leaf cert for '{domain}' written to {} (file permissions are not \
                         restricted on non-Unix platforms — consider restricting access manually)",
                        disk_path.display()
                    );
                }
            }
        }

        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
    }
}
895
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    /// Select (or mint) the leaf certificate for the SNI name in the ClientHello.
    /// A handshake without SNI gets `None`.
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|sni| self.get_or_create(sni))
    }
}
906
907fn get_request_host(req: &Request) -> Option<String> {
913 let authority = req
915 .uri()
916 .authority()
917 .map(|a| a.as_str().to_string())
918 .filter(|s| !s.is_empty());
919
920 authority.or_else(|| {
921 req.headers()
922 .get(HOST)
923 .and_then(|h| h.to_str().ok())
924 .map(str::to_string)
925 })
926}
927
928fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
939 let remote_addr = req
940 .extensions()
941 .get::<axum::extract::ConnectInfo<SocketAddr>>()
942 .map(|ci| ci.0.ip().to_string())
943 .unwrap_or_else(|| "127.0.0.1".to_string());
944
945 let proto = if is_tls { "https" } else { "http" };
946 let default_port = if is_tls { "443" } else { "80" };
947
948 let forwarded_for = remote_addr.clone();
951 let forwarded_proto = proto.to_string();
952 let forwarded_host = host_header.to_string();
953 let forwarded_port = host_header
954 .rsplit_once(':')
955 .map(|(_, port)| port.to_string())
956 .unwrap_or_else(|| default_port.to_string());
957
958 for name in [
964 "x-forwarded-for",
965 "x-forwarded-proto",
966 "x-forwarded-host",
967 "x-forwarded-port",
968 "forwarded",
969 ] {
970 if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
971 req.headers_mut().remove(&header_name);
972 }
973 }
974
975 let headers = [
976 ("x-forwarded-for", forwarded_for),
977 ("x-forwarded-proto", forwarded_proto),
978 ("x-forwarded-host", forwarded_host),
979 ("x-forwarded-port", forwarded_port),
980 ];
981
982 for (name, value) in headers {
983 if let Ok(v) = HeaderValue::from_str(&value) {
984 let header_name = axum::http::HeaderName::from_static(name);
985 req.headers_mut().insert(header_name, v);
986 }
987 }
988}
989
990async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
995 let Some(raw_host) = get_request_host(&req) else {
997 return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
998 };
999 let host = if raw_host.starts_with('[') {
1003 raw_host
1005 .split("]:")
1006 .next()
1007 .unwrap_or(&raw_host)
1008 .trim_start_matches('[')
1009 .trim_end_matches(']')
1010 .to_string()
1011 } else {
1012 raw_host.split(':').next().unwrap_or(&raw_host).to_string()
1014 };
1015
1016 let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
1027 let hops: u64 = if is_from_pitchfork {
1028 req.headers()
1029 .get(PROXY_HOPS_HEADER)
1030 .and_then(|v| v.to_str().ok())
1031 .and_then(|s| s.parse().ok())
1032 .unwrap_or(0)
1033 } else {
1034 0
1036 };
1037 if hops >= MAX_PROXY_HOPS {
1038 return error_response(
1039 StatusCode::LOOP_DETECTED,
1040 &format!(
1041 "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
1042 This usually means a backend is proxying back through pitchfork without rewriting \n\
1043 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
1044 ),
1045 );
1046 }
1047
1048 let target_port = match resolve_target(&host, &state.tld).await {
1050 ResolveResult::Ready(port) => port,
1051 ResolveResult::Starting { slug } => {
1052 return starting_html_response(&slug, &raw_host);
1053 }
1054 ResolveResult::NotFound => {
1055 return error_response(
1056 StatusCode::BAD_GATEWAY,
1057 &format!(
1058 "No daemon found for host '{host}'.\n\
1059 Make sure the daemon has a slug, is running, and has a port configured.\n\
1060 Expected format: <slug>.{tld}",
1061 tld = state.tld
1062 ),
1063 );
1064 }
1065 ResolveResult::Error(msg) => {
1066 return error_response(StatusCode::BAD_GATEWAY, &msg);
1067 }
1068 };
1069
1070 let path_and_query = req
1072 .uri()
1073 .path_and_query()
1074 .map(|pq| pq.as_str())
1075 .unwrap_or("/");
1076
1077 let forward_uri = match Uri::builder()
1078 .scheme("http")
1079 .authority(format!("localhost:{target_port}"))
1080 .path_and_query(path_and_query)
1081 .build()
1082 {
1083 Ok(uri) => uri,
1084 Err(e) => {
1085 return error_response(
1086 StatusCode::INTERNAL_SERVER_ERROR,
1087 &format!("Failed to build forward URI: {e}"),
1088 );
1089 }
1090 };
1091
1092 *req.uri_mut() = forward_uri;
1094 req.headers_mut().insert(
1095 HOST,
1096 HeaderValue::from_str(&format!("localhost:{target_port}"))
1097 .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
1098 );
1099
1100 inject_forwarded_headers(&mut req, state.is_tls, &raw_host);
1102
1103 if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
1105 req.headers_mut()
1106 .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
1107 }
1108
1109 let pseudo_headers: Vec<_> = req
1114 .headers()
1115 .keys()
1116 .filter(|k| k.as_str().starts_with(':'))
1117 .cloned()
1118 .collect();
1119 for key in pseudo_headers {
1120 req.headers_mut().remove(&key);
1121 }
1122
1123 let client_upgrade = hyper::upgrade::on(&mut req);
1125
1126 let result = match tokio::time::timeout(
1134 std::time::Duration::from_secs(120),
1135 state.client.request(req),
1136 )
1137 .await
1138 {
1139 Ok(r) => r,
1140 Err(_elapsed) => {
1141 let msg = format!(
1142 "Request to daemon on port {target_port} timed out after 120 s.\n\
1143 The daemon accepted the connection but did not respond in time."
1144 );
1145 log::warn!("{msg}");
1146 if let Some(ref on_error) = state.on_error {
1147 on_error(&msg);
1148 }
1149 return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
1150 }
1151 };
1152 match result {
1153 Ok(mut resp) => {
1154 let backend_upgrade = hyper::upgrade::on(&mut resp);
1156 let (mut parts, body) = resp.into_parts();
1157
1158 parts.headers.insert(
1160 axum::http::HeaderName::from_static(PITCHFORK_HEADER),
1161 HeaderValue::from_static("1"),
1162 );
1163
1164 parts.headers.remove(PROXY_HOPS_HEADER);
1166
1167 if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
1172 for h in HOP_BY_HOP_HEADERS {
1173 if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
1174 parts.headers.remove(&name);
1175 }
1176 }
1177 }
1178
1179 if parts.status == StatusCode::SWITCHING_PROTOCOLS {
1181 tokio::spawn(async move {
1186 if let (Ok(client_upgraded), Ok(backend_upgraded)) =
1187 (client_upgrade.await, backend_upgrade.await)
1188 {
1189 let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
1190 let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
1191 let _ =
1199 tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
1200 }
1201 });
1202 return Response::from_parts(parts, Body::empty());
1203 }
1204
1205 Response::from_parts(parts, Body::new(body))
1208 }
1209 Err(e) => {
1210 let msg = format!(
1211 "Failed to connect to daemon on port {target_port}: {e}\n\
1212 The daemon may have stopped or is not yet ready."
1213 );
1214 if let Some(ref on_error) = state.on_error {
1215 on_error(&msg);
1216 } else {
1217 log::warn!("{msg}");
1218 }
1219 error_response(StatusCode::BAD_GATEWAY, &msg)
1220 }
1221 }
1222}
1223
1224async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1243 let Some(subdomain) = strip_tld(host, tld) else {
1245 return ResolveResult::NotFound;
1246 };
1247
1248 let Some(cached) = cached_slug_lookup(&subdomain).await else {
1250 return ResolveResult::NotFound;
1251 };
1252
1253 let daemon_name = &cached.daemon_name;
1254 let expected_namespace = &cached.namespace;
1255
1256 let daemons = {
1258 let state_file = SUPERVISOR.state_file.lock().await;
1259 state_file.daemons.clone()
1260 };
1261
1262 let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1265 .iter()
1266 .filter(|(id, d)| {
1267 id.name() == daemon_name
1268 && d.status.is_running()
1269 && match expected_namespace {
1270 Some(ns) => id.namespace() == ns,
1271 None => true,
1272 }
1273 })
1274 .collect();
1275
1276 match running_matches.as_slice() {
1277 [] => {
1278 try_auto_start(&subdomain, &cached).await
1280 }
1281 [(_, d)] => {
1282 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1283 ResolveResult::Ready(port)
1284 } else {
1285 ResolveResult::NotFound
1286 }
1287 }
1288 _ => {
1289 let d = running_matches[0].1;
1290 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1291 ResolveResult::Ready(port)
1292 } else {
1293 ResolveResult::NotFound
1294 }
1295 }
1296 }
1297}
1298
1299struct AutoStartGuard {
1306 daemon_id: DaemonId,
1307}
1308
1309impl Drop for AutoStartGuard {
1310 fn drop(&mut self) {
1311 let daemon_id = self.daemon_id.clone();
1312 tokio::spawn(async move {
1316 AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1317 });
1318 }
1319}
1320
1321async fn try_auto_start(slug: &str, cached: &CachedSlugEntry) -> ResolveResult {
1332 let s = settings();
1333 if !s.proxy.auto_start {
1334 return ResolveResult::NotFound;
1335 }
1336
1337 let ns = cached
1342 .namespace
1343 .clone()
1344 .unwrap_or_else(|| "global".to_string());
1345 let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1346 Ok(id) => id,
1347 Err(_) => return ResolveResult::NotFound,
1348 };
1349
1350 {
1353 let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1354 if !in_progress.insert(daemon_id.clone()) {
1355 return ResolveResult::Starting {
1356 slug: slug.to_string(),
1357 };
1358 }
1359 }
1360
1361 let _guard = AutoStartGuard {
1363 daemon_id: daemon_id.clone(),
1364 };
1365
1366 let timeout = s.proxy_auto_start_timeout();
1370
1371 match tokio::time::timeout(timeout, try_auto_start_inner(slug, cached, &daemon_id)).await {
1372 Ok(result) => result,
1373 Err(_elapsed) => {
1374 log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1375 ResolveResult::Error(format!(
1376 "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1377 The daemon did not become ready and bind a port within the configured \
1378 proxy_auto_start_timeout.\n\
1379 Increase the timeout or check the daemon's logs for slow startup."
1380 ))
1381 }
1382 }
1383}
1384
1385async fn try_auto_start_inner(
1389 slug: &str,
1390 cached: &CachedSlugEntry,
1391 daemon_id: &DaemonId,
1392) -> ResolveResult {
1393 let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(&cached.dir) {
1395 Ok(pt) => pt,
1396 Err(e) => {
1397 log::warn!(
1398 "Auto-start: failed to load config from {}: {e}",
1399 cached.dir.display()
1400 );
1401 return ResolveResult::NotFound;
1402 }
1403 };
1404
1405 let daemon_config = match pt.daemons.get(daemon_id) {
1406 Some(cfg) => cfg,
1407 None => {
1408 log::debug!(
1409 "Auto-start: daemon {daemon_id} not found in config at {}",
1410 cached.dir.display()
1411 );
1412 return ResolveResult::NotFound;
1413 }
1414 };
1415
1416 let opts = crate::ipc::batch::StartOptions::default();
1420 let mut run_opts = match crate::ipc::batch::build_run_options(daemon_id, daemon_config, &opts) {
1421 Ok(o) => o,
1422 Err(e) => {
1423 log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1424 return ResolveResult::Error(format!("Failed to build run options: {e}"));
1425 }
1426 };
1427
1428 if run_opts.dir.0.as_os_str().is_empty() {
1429 run_opts.dir = crate::config_types::Dir(cached.dir.clone());
1430 }
1431
1432 log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1433
1434 let run_result = SUPERVISOR.run(run_opts).await;
1437
1438 if let Err(e) = run_result {
1439 log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1440 return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1441 }
1442
1443 let poll_interval = std::time::Duration::from_millis(250);
1447
1448 loop {
1449 let daemons = {
1450 let sf = SUPERVISOR.state_file.lock().await;
1451 sf.daemons.clone()
1452 };
1453
1454 if let Some(d) = daemons.get(daemon_id) {
1455 if d.status.is_running() {
1456 if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1457 log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1458 return ResolveResult::Ready(port);
1459 }
1460 } else {
1461 log::warn!(
1462 "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1463 d.status
1464 );
1465 return ResolveResult::Error(format!(
1466 "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1467 Check its logs for errors."
1468 ));
1469 }
1470 } else {
1471 log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1474 return ResolveResult::Error(format!(
1475 "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1476 Check its logs for errors."
1477 ));
1478 }
1479
1480 tokio::time::sleep(poll_interval).await;
1481 }
1482}
1483
/// Removes a trailing `.{tld}` from `host` and returns the remaining label(s).
///
/// Returns `None` when `host` does not end in `.{tld}`, or when nothing would remain
/// (for example `host == ".{tld}"`). A bare `host == tld` also has no leading dot, so it
/// yields `None`. Matching is exact and case-sensitive.
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    let dotted_tld = format!(".{tld}");
    match host.strip_suffix(dotted_tld.as_str()) {
        Some(prefix) if !prefix.is_empty() => Some(prefix.to_owned()),
        _ => None,
    }
}
1495
/// Formats a user-facing error for a failed proxy bind on `port`, including a hint.
///
/// Ports below 1024 get a hint about elevated privileges. Any other port gets a hint
/// that the port may already be in use.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    let hint = if port >= 1024 {
        "another process may already be using this port."
    } else {
        "ports below 1024 require elevated privileges. Try: sudo pitchfork supervisor start"
    };
    format!("Failed to bind proxy server to port {port}: {err}\nHint: {hint}")
}
1511
1512fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1517 let escaped_slug = slug
1518 .replace('&', "&")
1519 .replace('<', "<")
1520 .replace('>', ">")
1521 .replace('"', """)
1522 .replace('\'', "'");
1523 let escaped_host = raw_host
1524 .replace('&', "&")
1525 .replace('<', "<")
1526 .replace('>', ">")
1527 .replace('"', """)
1528 .replace('\'', "'");
1529
1530 let html = format!(
1531 r##"<!DOCTYPE html>
1532<html lang="en">
1533<head>
1534 <meta charset="UTF-8">
1535 <meta name="viewport" content="width=device-width, initial-scale=1">
1536 <meta http-equiv="refresh" content="2">
1537 <title>Starting {escaped_slug}… — pitchfork</title>
1538 <style>
1539 * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1540 body {{
1541 font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1542 background: #0f1117;
1543 color: #e1e4e8;
1544 display: flex;
1545 align-items: center;
1546 justify-content: center;
1547 min-height: 100vh;
1548 }}
1549 .container {{
1550 text-align: center;
1551 max-width: 480px;
1552 padding: 2rem;
1553 }}
1554 .spinner {{
1555 width: 48px;
1556 height: 48px;
1557 border: 4px solid rgba(255, 255, 255, 0.1);
1558 border-top-color: #58a6ff;
1559 border-radius: 50%;
1560 animation: spin 0.8s linear infinite;
1561 margin: 0 auto 1.5rem;
1562 }}
1563 @keyframes spin {{
1564 to {{ transform: rotate(360deg); }}
1565 }}
1566 h1 {{
1567 font-size: 1.5rem;
1568 font-weight: 600;
1569 margin-bottom: 0.5rem;
1570 }}
1571 .slug {{
1572 color: #58a6ff;
1573 font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1574 }}
1575 .host {{
1576 color: #8b949e;
1577 font-size: 0.875rem;
1578 margin-top: 0.25rem;
1579 }}
1580 .hint {{
1581 color: #8b949e;
1582 font-size: 0.8rem;
1583 margin-top: 1.5rem;
1584 }}
1585 </style>
1586</head>
1587<body>
1588 <div class="container">
1589 <div class="spinner"></div>
1590 <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1591 <p class="host">{escaped_host}</p>
1592 <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1593 </div>
1594</body>
1595</html>"##
1596 );
1597
1598 Response::builder()
1599 .status(StatusCode::SERVICE_UNAVAILABLE)
1600 .header("content-type", "text/html; charset=utf-8")
1601 .header("retry-after", "2")
1602 .body(Body::from(html))
1603 .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1604}
1605
1606async fn redirect_to_https_handler(req: Request) -> Response {
1615 if req.headers().contains_key("upgrade") {
1617 log::warn!("Dropping plain-HTTP WebSocket upgrade attempt — use wss:// instead of ws://");
1618 return (
1619 StatusCode::BAD_REQUEST,
1620 "WebSocket over plain HTTP is not supported on the HTTPS port. Use wss:// instead.",
1621 )
1622 .into_response();
1623 }
1624
1625 let raw_host = get_request_host(&req);
1626 let Some(raw_host) = raw_host else {
1627 return (StatusCode::BAD_REQUEST, "Missing Host header").into_response();
1628 };
1629
1630 let hostname = if raw_host.starts_with('[') {
1632 raw_host
1634 .split_once("]:")
1635 .map(|(host, _)| host)
1636 .unwrap_or(&raw_host)
1637 .trim_start_matches('[')
1638 .trim_end_matches(']')
1639 } else {
1640 let mut parts = raw_host.rsplitn(2, ':');
1642 let last = parts.next().unwrap_or(&raw_host);
1643 parts.next().unwrap_or(last)
1644 };
1645
1646 let path = req
1647 .uri()
1648 .path_and_query()
1649 .map(|pq| pq.as_str())
1650 .unwrap_or("/");
1651
1652 let https_port = match u16::try_from(settings().proxy.port).ok().filter(|&p| p > 0) {
1653 Some(443) | None => String::new(),
1654 Some(port) => format!(":{port}"),
1655 };
1656
1657 let host_for_url = if raw_host.starts_with('[') {
1658 format!("[{hostname}]")
1659 } else {
1660 hostname.to_string()
1661 };
1662
1663 let location = format!("https://{host_for_url}{https_port}{path}");
1664 (
1665 StatusCode::FOUND,
1666 [(axum::http::header::LOCATION, location)],
1667 )
1668 .into_response()
1669}
1670
1671fn error_response(status: StatusCode, message: &str) -> Response {
1673 (status, message.to_string()).into_response()
1674}
1675
#[cfg(test)]
mod tests {
    use super::*;

    /// `strip_tld` removes exactly one trailing `.{tld}` and rejects empty or non-matching hosts.
    #[test]
    fn test_strip_tld() {
        let cases: &[(&str, &str, Option<&str>)] = &[
            ("api.myproject.localhost", "localhost", Some("api.myproject")),
            ("api.localhost", "localhost", Some("api")),
            ("localhost", "localhost", None),
            ("api.myproject.test", "test", Some("api.myproject")),
            ("other.com", "localhost", None),
        ];

        for &(host, tld, expected) in cases {
            assert_eq!(strip_tld(host, tld).as_deref(), expected);
        }
    }

    /// `generate_ca` writes a PEM certificate and a PEM private key to the given paths.
    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let tmp = tempfile::tempdir().unwrap();
        let cert_path = tmp.path().join("ca.pem");
        let key_path = tmp.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();

        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        let looks_like_key = key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY");
        assert!(looks_like_key, "should be PEM key");
    }
}