Skip to main content

pitchfork_cli/proxy/
server.rs

1//! Reverse proxy server implementation.
2//!
3//! Listens on a configured port and routes requests to daemon processes based
4//! on the `Host` header subdomain pattern.
5//!
6//! When `proxy.https = true`, a local CA is auto-generated (via `rcgen`) and
7//! each incoming TLS connection is served with a per-domain certificate signed
8//! by that CA (SNI-based dynamic certificate issuance).
9
10use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Response header used to identify a pitchfork proxy (for health checks and debugging).
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header tracking how many times a request has passed through the proxy.
/// Used to detect forwarding loops; the limit is `MAX_PROXY_HOPS`.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// Maximum number of proxy hops before rejecting as a loop.
///
/// The hop count itself travels in the `PROXY_HOPS_HEADER` request header.
const MAX_PROXY_HOPS: u64 = 5;

/// HTTP/1.1 hop-by-hop headers that are forbidden in HTTP/2 responses.
/// These must be stripped when proxying an HTTP/1.1 backend response back to an HTTP/2 client.
///
/// All names are spelled in lowercase.
/// NOTE(review): presumably they are compared against already-lowercased
/// header names — confirm at the use site.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
49// ─── Slug resolution cache ──────────────────────────────────────────────────
50//
51// `read_global_slugs()` reads ~/.config/pitchfork/config.toml from disk on every
52// call, and `namespace_for_dir()` traverses the filesystem upward to find the
53// nearest pitchfork.toml.  Both are called from `resolve_target_port()` which
54// sits in the hot path of every proxied HTTP request.
55//
56// This cache stores the resolved slug → (namespace, daemon_name) mapping
57// in memory with a short TTL so that the proxy does zero disk I/O for the vast
58// majority of requests while still picking up config changes within seconds.
59
/// How long to cache the slug resolution table before re-reading from disk.
///
/// While the cached table is younger than this, lookups are served from
/// memory; the first lookup after the TTL has elapsed triggers a re-read.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);
62
/// Cached slug entry: pre-resolved namespace + daemon name for a slug.
///
/// Built once per cache refresh by `build_slug_entries()` and handed out by
/// value (`Clone`) from `cached_slug_lookup()`.
#[derive(Clone, Debug)]
struct CachedSlugEntry {
    /// Expected namespace derived from `entry.dir` (None if derivation failed).
    namespace: Option<String>,
    /// Daemon short name (defaults to slug name when not explicitly set).
    daemon_name: String,
    /// Project directory for this slug (needed for auto-start).
    dir: std::path::PathBuf,
}
73
/// In-memory cache for the global slug registry + derived namespaces.
struct SlugCache {
    /// Snapshot of the slug → entry table.  Kept behind an `Arc` so readers
    /// can take a cheap clone of the pointer and release the cache lock
    /// before using the table.
    entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
    /// Instant after which `entries` is considered stale and must be rebuilt.
    expires_at: std::time::Instant,
}
79
/// Process-wide slug cache, lazily initialised on first use.
///
/// It starts out empty with an `expires_at` of "now", so the very first
/// lookup already sees it as expired and populates it from disk.
static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
    once_cell::sync::Lazy::new(|| {
        tokio::sync::Mutex::new(SlugCache {
            entries: Arc::new(std::collections::HashMap::new()),
            expires_at: std::time::Instant::now(), // expired → will be populated on first access
        })
    });
87
88/// Build the slug lookup table from disk (expensive — involves file I/O).
89/// Called outside the cache lock to avoid blocking concurrent proxy requests.
90fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
91    let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
92    let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
93    for (slug, entry) in &global_slugs {
94        let ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&entry.dir).ok();
95        let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
96        entries.insert(
97            slug.clone(),
98            CachedSlugEntry {
99                namespace: ns,
100                daemon_name,
101                dir: entry.dir.clone(),
102            },
103        );
104    }
105    entries
106}
107
108/// Return a snapshot of the cached slug table, refreshing from disk if expired.
109///
110/// The disk I/O happens *outside* the mutex to avoid blocking concurrent requests
111/// during the refresh.  A short race window exists where two threads may both
112/// refresh, but that is harmless (last writer wins with identical data).
113async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
114    // Fast path: cache still valid — just clone the Arc.
115    {
116        let cache = SLUG_CACHE.lock().await;
117        if std::time::Instant::now() < cache.expires_at {
118            return Arc::clone(&cache.entries);
119        }
120    } // lock released before disk I/O
121
122    // Slow path: refresh from disk (no lock held).
123    let new_entries = Arc::new(build_slug_entries());
124
125    // Store the refreshed entries.
126    {
127        let mut cache = SLUG_CACHE.lock().await;
128        cache.entries = Arc::clone(&new_entries);
129        cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
130    }
131
132    new_entries
133}
134
135/// Look up a single slug in the cached table.
136async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
137    let entries = get_cached_slugs().await;
138    entries.get(subdomain).cloned()
139}
140
141// ─── Auto-start deduplication ───────────────────────────────────────────────
142//
143// When auto_start is enabled, concurrent proxy requests for the same stopped
144// daemon must not trigger multiple start operations.  This set tracks daemon
145// IDs that are currently being auto-started.
146
/// Set of daemon IDs whose auto-start is currently in progress.
///
/// Guarded by an async mutex and lazily initialised as an empty set.
static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
    tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
150
/// Result of resolving a proxy target for a given host.
enum ResolveResult {
    /// Daemon is running and ready — forward to this port.
    /// Covers both already-running daemons and freshly auto-started ones.
    Ready(u16),
    /// Daemon is currently starting (auto-start in progress or just triggered).
    /// `slug` names the slug the request was addressed to.
    Starting { slug: String },
    /// No matching slug or daemon found.
    NotFound,
    /// Routing refused with a descriptive reason (the contained message).
    Error(String),
}
163
/// Callback type invoked on proxy errors (e.g. for logging/alerting).
///
/// The callback receives a single `&str` argument.
/// NOTE(review): presumably a human-readable error message — confirm at the
/// call site.
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;

/// Shared proxy state passed to each request handler.
///
/// Cloning is cheap for the client and the callback, which are both held
/// behind an `Arc`.
#[derive(Clone)]
struct ProxyState {
    /// HTTP client used to forward requests to daemon backends.
    client: Arc<Client<HttpConnector, Body>>,
    /// The configured TLD (e.g. "localhost").
    tld: String,
    /// Whether the proxy is serving HTTPS.
    is_tls: bool,
    /// Optional error callback invoked on proxy errors (e.g. for logging/alerting).
    on_error: Option<OnErrorFn>,
}
179
180/// Start the reverse proxy server.
181///
182/// Binds to the configured port and serves until the process exits.
183/// When `proxy.https = true`, TLS is terminated here using a self-signed
184/// certificate (auto-generated if not present).
185///
186/// This function is intended to be spawned as a background task.
187pub async fn serve(
188    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
189    cancel: tokio_util::sync::CancellationToken,
190) -> crate::Result<()> {
191    let s = settings();
192    let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
193        let msg = format!(
194            "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
195            s.proxy.port
196        );
197        let _ = bind_tx.send(Err(msg.clone()));
198        miette::bail!("{msg}");
199    };
200
201    let mut connector = HttpConnector::new();
202    // Limit how long the proxy waits to establish a TCP connection to a backend.
203    // Without this, a daemon that accepts the SYN but never completes the handshake
204    // would stall the proxy indefinitely.
205    connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
206
207    let client = Client::builder(TokioExecutor::new())
208        // Reclaim idle keep-alive connections after 30 s so that file descriptors
209        // are not held open forever when a backend goes quiet.
210        .pool_idle_timeout(std::time::Duration::from_secs(30))
211        .build(connector);
212
213    let state = ProxyState {
214        client: Arc::new(client),
215        tld: s.proxy.tld.clone(),
216        is_tls: s.proxy.https,
217        on_error: None,
218    };
219
220    let app = Router::new().fallback(proxy_handler).with_state(state);
221
222    // Resolve bind address from settings (default: 127.0.0.1 for local-only access).
223    let bind_ip: std::net::IpAddr = match s.proxy.host.parse() {
224        Ok(ip) => ip,
225        Err(_) => {
226            log::warn!(
227                "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
228                 The proxy will only be reachable on the loopback interface.",
229                s.proxy.host
230            );
231            std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
232        }
233    };
234    let addr = SocketAddr::from((bind_ip, effective_port));
235
236    if s.proxy.https {
237        serve_https_with_http_fallback(app, addr, s, effective_port, bind_tx, cancel).await
238    } else {
239        serve_http(app, addr, effective_port, bind_tx, cancel).await
240    }
241}
242
243/// Serve plain HTTP.
244async fn serve_http(
245    app: Router,
246    addr: SocketAddr,
247    effective_port: u16,
248    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
249    cancel: tokio_util::sync::CancellationToken,
250) -> crate::Result<()> {
251    let listener = match TcpListener::bind(addr).await {
252        Ok(l) => {
253            if settings().proxy.sync_hosts {
254                crate::proxy::hosts::sync_hosts_from_settings();
255            }
256            let _ = bind_tx.send(Ok(()));
257            l
258        }
259        Err(e) => {
260            let msg = bind_error_message(effective_port, &e);
261            let _ = bind_tx.send(Err(msg.clone()));
262            return Err(miette::miette!("{msg}"));
263        }
264    };
265
266    log::info!("Proxy server listening on http://{addr}");
267    if effective_port < 1024 {
268        log::info!(
269            "Note: port {effective_port} is a privileged port. \
270             The supervisor must be started with sudo to bind to this port."
271        );
272    }
273    let shutdown_signal = cancel.clone().cancelled_owned();
274    axum::serve(
275        listener,
276        app.into_make_service_with_connect_info::<SocketAddr>(),
277    )
278    .with_graceful_shutdown(shutdown_signal)
279    .await
280    .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
281    Ok(())
282}
283
/// Serve HTTPS with automatic HTTP detection on the same port.
///
/// Start-up sequence:
/// 1. Resolve the CA cert/key paths and generate a local CA if either file is
///    missing.
/// 2. Build a rustls server config whose certificates come from the SNI
///    resolver, advertising `h2` and `http/1.1` via ALPN.
/// 3. Bind the listener and report the result through `bind_tx`.
///
/// Each accepted TCP connection is then handled on its own task, which peeks
/// at the first byte:
/// - `0x16` (TLS ClientHello) → TLS handshake, then the main router
/// - anything else → the plain-HTTP redirect router
///
/// When `cancel` fires the accept loop stops and in-flight connections get up
/// to 10 seconds to finish.
#[cfg(feature = "proxy-tls")]
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
    cancel: tokio_util::sync::CancellationToken,
) -> crate::Result<()> {
    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);

    // Generate the CA unless both halves are already on disk.
    let ca_present = ca_cert_path.exists() && ca_key_path.exists();
    if !ca_present {
        generate_ca(&ca_cert_path, &ca_key_path)?;
        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }

    // Install ring as the default CryptoProvider; the result is ignored
    // because a provider may already have been installed.
    let _ = rustls::crypto::ring::default_provider().install_default();

    // The SNI resolver loads the CA and caches per-domain certificates.
    let cert_resolver = Arc::new(SniCertResolver::new(&ca_cert_path, &ca_key_path)?);
    let mut tls_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(cert_resolver);
    // Offer HTTP/2 first, then HTTP/1.1, so browsers can multiplex requests
    // over a single connection.
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));

    let listener = match TcpListener::bind(addr).await {
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
        Ok(bound) => bound,
    };
    if settings().proxy.sync_hosts {
        crate::proxy::hosts::sync_hosts_from_settings();
    }
    let _ = bind_tx.send(Ok(()));

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    // Lightweight router used only for plain-HTTP requests on the TLS port.
    let redirect_app = Router::new().fallback(redirect_to_https_handler);

    let mut conn_tasks = tokio::task::JoinSet::<()>::new();
    loop {
        // Reap finished connection tasks so the JoinSet does not keep one
        // entry per historical connection.
        while conn_tasks.try_join_next().is_some() {}

        tokio::select! {
            accepted = listener.accept() => match accepted {
                Ok((stream, _peer_addr)) => {
                    conn_tasks.spawn(serve_sniffed_connection(
                        stream,
                        acceptor.clone(),
                        app.clone(),
                        redirect_app.clone(),
                    ));
                    while conn_tasks.try_join_next().is_some() {}
                }
                Err(e) => {
                    log::warn!("Accept error (will retry): {e}");
                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                }
            },
            _ = cancel.cancelled() => {
                log::info!("Proxy server shutting down (cancel signal received)");
                break;
            }
        }
    }

    // Give in-flight connections a bounded amount of time to finish.
    let drain_timeout = std::time::Duration::from_secs(10);
    let drain = async { while conn_tasks.join_next().await.is_some() {} };
    let _ = tokio::time::timeout(drain_timeout, drain).await;

    Ok(())
}

/// Handle one accepted TCP connection on the HTTPS port.
///
/// Peeks at the first byte without consuming it.  A TLS ClientHello always
/// starts with `0x16` (content type "handshake"); such connections go through
/// the TLS acceptor and are served by `app`.  Everything else is served by
/// `redirect_app` as plain HTTP.  Connections that close or fail before
/// sending a byte are dropped silently.
#[cfg(feature = "proxy-tls")]
async fn serve_sniffed_connection(
    stream: tokio::net::TcpStream,
    acceptor: tokio_rustls::TlsAcceptor,
    app: Router,
    redirect_app: Router,
) {
    let mut first_byte = [0u8; 1];
    match stream.peek(&mut first_byte).await {
        Ok(n) if n > 0 => {}
        _ => return,
    }

    if first_byte[0] != 0x16 {
        // Plain HTTP on the TLS port → answered by the redirect router.
        let io = hyper_util::rt::TokioIo::new(stream);
        let svc = hyper_util::service::TowerToHyperService::new(redirect_app);
        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
            .serve_connection_with_upgrades(io, svc)
            .await;
        return;
    }

    // TLS handshake → HTTP/2 or HTTP/1.1 (negotiated via ALPN).
    let tls_stream = match acceptor.accept(stream).await {
        Ok(tls) => tls,
        Err(e) => {
            log::debug!("TLS handshake error: {e}");
            return;
        }
    };

    let io = hyper_util::rt::TokioIo::new(tls_stream);
    let svc = hyper_util::service::TowerToHyperService::new(app);
    let outcome = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
        .serve_connection_with_upgrades(io, svc)
        .await;
    if let Err(e) = outcome {
        // HTTP/2 RST_STREAM errors from cancelled browser requests
        // (navigation, HMR) are normal — log at debug to avoid noise.
        log::debug!("Connection error: {e}");
    }
}
432
433/// Fallback when proxy-tls feature is not enabled.
434#[cfg(not(feature = "proxy-tls"))]
435async fn serve_https_with_http_fallback(
436    _app: Router,
437    _addr: SocketAddr,
438    _s: &crate::settings::Settings,
439    _effective_port: u16,
440    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
441    _cancel: tokio_util::sync::CancellationToken,
442) -> crate::Result<()> {
443    let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
444         Rebuild pitchfork with: cargo build --features proxy-tls"
445        .to_string();
446    let _ = bind_tx.send(Err(msg.clone()));
447    miette::bail!("{msg}")
448}
449
/// Work out where the CA certificate and CA private key live.
///
/// Returns `(cert_path, key_path)`.  A non-empty `proxy.tls_cert` /
/// `proxy.tls_key` setting is used verbatim; an empty one falls back to
/// `ca.pem` / `ca-key.pem` inside `$PITCHFORK_STATE_DIR/proxy/`.
#[cfg(feature = "proxy-tls")]
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let proxy_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");

    let configured_cert: &str = &s.proxy.tls_cert;
    let cert_path = if configured_cert.is_empty() {
        proxy_dir.join("ca.pem")
    } else {
        std::path::PathBuf::from(configured_cert)
    };

    let configured_key: &str = &s.proxy.tls_key;
    let key_path = if configured_key.is_empty() {
        proxy_dir.join("ca-key.pem")
    } else {
        std::path::PathBuf::from(configured_key)
    };

    (cert_path, key_path)
}
469
/// Generate a local root CA certificate and private key using `rcgen`.
///
/// The CA is used to sign per-domain certificates on demand (SNI).
/// Files are written in PEM format to `cert_path` and `key_path`; existing
/// files at those paths are overwritten.
///
/// The parent directories of *both* paths are created if missing, so a key
/// path configured in a different directory than the certificate works too.
///
/// # Errors
/// Returns an error if a parent directory cannot be created, if key or
/// certificate generation fails, or if either file cannot be written.  On
/// Unix it also fails if the key file's permissions cannot be restricted to
/// `0600`.
#[cfg(feature = "proxy-tls")]
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    // Create the parent directory of each output file if needed.
    for path in [cert_path, key_path] {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
        }
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    // Write the CA certificate (public — 0644 is fine)
    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    // Write the CA private key with restrictive permissions (0600).
    {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                // Applies when the file is newly created, so a fresh key file
                // is never world-readable, not even briefly.
                .mode(0o600)
                .open(key_path)
                .and_then(|mut f| {
                    // `mode()` is ignored when the file already exists, so a
                    // leftover key file with looser permissions would keep
                    // them.  Tighten explicitly *before* writing the new key
                    // (the old contents were already truncated by `open`).
                    f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                    f.write_all(key_pair.serialize_pem().as_bytes())
                })
                .map_err(|e| {
                    miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
                })?;
        }
        #[cfg(not(unix))]
        {
            std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
            log::debug!(
                "CA private key written to {} (file permissions are not restricted \
                 on non-Unix platforms — consider restricting access manually)",
                key_path.display()
            );
        }
    }

    Ok(())
}
542
/// SNI-based certificate resolver.
///
/// Holds the local CA and a two-level cache of per-domain certificates:
/// - L1: in-memory `HashMap` (fastest, process-lifetime)
/// - L2: on-disk `host-certs/<safe_name>.pem` (survives restarts)
///
/// A `pending` set prevents concurrent requests for the same domain from
/// triggering multiple simultaneous cert-generation operations.
///
/// `get_or_create()` is the lookup entry point: if no cached cert exists for
/// a domain, one is obtained from disk or signed by the CA on the fly.
///
/// # Locking strategy
/// `cache` and `pending` are each guarded by their own `std::sync::Mutex`;
/// only `pending` is paired with a `std::sync::Condvar` (`pending_cv`).
/// The critical sections are intentionally short (hash-map lookups /
/// inserts); certificate generation itself runs with no lock held.
///
/// Waiting on the condvar parks the calling OS thread.
/// NOTE(review): the original note states that `get_or_create` is only
/// reached from the synchronous `ResolvesServerCert` trait method; that call
/// site is outside this view.  If it runs on a tokio worker thread during
/// the TLS handshake, a wait here blocks that worker — confirm this is
/// acceptable.
#[cfg(feature = "proxy-tls")]
struct SniCertResolver {
    /// The CA issuer (key + parsed cert params, used to sign leaf certs).
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory where per-domain PEM files are cached on disk.
    host_certs_dir: std::path::PathBuf,
    /// L1 cache: domain → certified key (in-memory).
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Pending set: domains currently being generated (dedup concurrent requests).
    /// Threads that find their domain in this set wait on `pending_cv`
    /// (parked, not spin-sleeping) until the set changes.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Condvar paired with `pending` — notified when a domain is removed from the set.
    pending_cv: std::sync::Condvar,
}
577
/// Opaque `Debug` output: prints only the type name and elides every field,
/// so the CA key material can never end up in logs.
#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SniCertResolver");
        builder.finish_non_exhaustive()
    }
}
584
585#[cfg(feature = "proxy-tls")]
586impl SniCertResolver {
587    /// Load the CA from disk and prepare the resolver.
588    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
589        let ca_key_pem = std::fs::read_to_string(ca_key_path)
590            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
591        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
592            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
593        })?;
594
595        // Verify the PEM is readable (sanity check)
596        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
597            miette::bail!("CA cert file does not contain a valid PEM certificate");
598        }
599
600        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
601            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;
602
603        // Parse the CA cert + key into an Issuer for signing leaf certs.
604        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
605            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;
606
607        // Ensure the host-certs directory exists
608        let host_certs_dir = ca_cert_path
609            .parent()
610            .unwrap_or(std::path::Path::new("."))
611            .join("host-certs");
612        std::fs::create_dir_all(&host_certs_dir)
613            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;
614
615        Ok(Self {
616            issuer,
617            host_certs_dir,
618            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
619            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
620            pending_cv: std::sync::Condvar::new(),
621        })
622    }
623
    /// Get or create a `CertifiedKey` for the given domain.
    ///
    /// Resolution order:
    /// 1. L1 in-memory cache (checked here)
    /// 2. Everything else is delegated to `get_or_create_inner`, which tries
    ///    the on-disk cache (`host-certs/<safe_name>.pem`) and then signs a
    ///    fresh certificate.
    ///
    /// Returns `None` if certificate creation fails or if the `cache` /
    /// `pending` mutex is poisoned when it is acquired on the way in.
    ///
    /// Concurrent requests for the same domain are deduplicated: later
    /// threads wait on a `Condvar` until the first thread finishes, then read
    /// from the cache instead of generating again.  The wait parks the
    /// calling thread rather than spin-sleeping.
    ///
    /// # Locking discipline
    /// `cache` and `pending` are **never held simultaneously**.  The protocol is:
    /// 1. Check `cache` (lock, read, unlock).
    /// 2. Lock `pending`.  If the domain is in the set, wait on the condvar,
    ///    release `pending`, re-check `cache`, and repeat from step 2.
    /// 3. Otherwise insert the domain into `pending` and release the lock.
    /// 4. Generate the cert via `get_or_create_inner` (no lock held here;
    ///    the inner function takes the `cache` lock briefly to store the
    ///    result).
    /// 5. Remove the domain from `pending` and notify all waiters while the
    ///    `pending` lock is held.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        // L1: memory cache (fast path — no pending lock needed)
        {
            let cache = self.cache.lock().ok()?;
            if let Some(ck) = cache.get(domain) {
                return Some(Arc::clone(ck));
            }
        } // cache lock released here

        // Dedup loop.  Each iteration either claims the generation slot for
        // `domain` (and breaks out), or waits for the current owner and then
        // re-checks the cache.
        //
        // The condvar is shared by all domains and is signalled with
        // `notify_all`, so a waiter can be woken because a *different* domain
        // finished (or spuriously).  That is fine: the cache re-check misses,
        // the loop runs again, and the thread goes back to waiting if its
        // domain is still pending.
        //
        // If the owner's generation failed, nothing was cached; the first
        // woken waiter to re-acquire `pending` claims the slot and retries.
        loop {
            {
                let mut pending = self.pending.lock().ok()?;
                if pending.contains(domain) {
                    // Another thread is generating; wait until it finishes.
                    pending = self.pending_cv.wait(pending).ok()?;
                    // pending lock re-acquired; release it again so the cache
                    // re-check below never nests inside the pending lock.
                    drop(pending);
                } else {
                    // No one else is generating; claim the slot and proceed.
                    pending.insert(domain.to_string());
                    break;
                }
            } // pending lock released

            // Re-check cache after being woken (the generating thread may have
            // already populated it).  Cache lock is acquired independently of
            // pending lock here — no nesting.
            {
                let cache = self.cache.lock().ok()?;
                if let Some(ck) = cache.get(domain) {
                    return Some(Arc::clone(ck));
                }
            } // cache lock released
        } // on `break`, the pending guard is dropped as its scope is left

        // We own the slot for `domain`; no lock is held at this point.
        let result = self.get_or_create_inner(domain);

        // Always clear the pending flag and wake waiting threads.
        // notify_all() is called *inside* the lock scope so that the domain is
        // guaranteed to be removed before any waiting thread is woken up.
        // If the lock is poisoned we recover it (the data is still valid) so
        // that the domain is always removed and waiters are always notified.
        {
            let mut pending = match self.pending.lock() {
                Ok(g) => g,
                Err(e) => e.into_inner(),
            };
            pending.remove(domain);
            self.pending_cv.notify_all();
        }

        result
    }
709
710    /// Inner implementation: check disk cache, then generate.
711    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
712        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
713        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
714
715        // L2: disk cache — try to load existing cert+key PEM
716        if disk_path.exists() {
717            if let Ok(ck) = self.load_from_disk(&disk_path) {
718                let ck = Arc::new(ck);
719                if let Ok(mut cache) = self.cache.lock() {
720                    cache.insert(domain.to_string(), Arc::clone(&ck));
721                }
722                return Some(ck);
723            }
724            // Disk cache corrupt/expired — fall through to regenerate
725            let _ = std::fs::remove_file(&disk_path);
726        }
727
728        // L3: generate fresh cert
729        let ck = self.sign_for_domain(domain).ok()?;
730
731        let ck = Arc::new(ck);
732        if let Ok(mut cache) = self.cache.lock() {
733            cache.insert(domain.to_string(), Arc::clone(&ck));
734        }
735        Some(ck)
736    }
737
738    /// Load a `CertifiedKey` from a combined cert+key PEM file on disk.
739    ///
740    /// Returns an error if the certificate has already expired, so the caller
741    /// can fall through to regeneration rather than serving a stale cert.
742    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
743        use rustls::pki_types::CertificateDer;
744        use rustls_pemfile::{certs, private_key};
745
746        let pem = std::fs::read_to_string(path)
747            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;
748
749        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
750            .collect::<Result<Vec<_>, _>>()
751            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;
752
753        if cert_ders.is_empty() {
754            miette::bail!("No certificates found in {}", path.display());
755        }
756
757        // Check that the first certificate has not expired using x509-parser.
758        {
759            let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
760                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
761            })?;
762            use chrono::Utc;
763            let now_ts = Utc::now().timestamp();
764            let not_after_ts = cert.validity().not_after.timestamp();
765            if not_after_ts < now_ts {
766                miette::bail!(
767                    "Cached certificate at {} has expired — will regenerate",
768                    path.display()
769                );
770            }
771        }
772
773        let key_der = private_key(&mut pem.as_bytes())
774            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
775            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;
776
777        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
778            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;
779
780        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
781    }
782
783    /// Sign a leaf certificate for `domain` using the CA.
784    ///
785    /// SANs include:
786    /// - `DNS:<domain>` (exact match)
787    /// - `DNS:*.<parent>` (sibling wildcard, e.g. `*.pf.localhost` for `docs.pf.localhost`)
788    ///
789    /// Returns both the `CertifiedKey` and the combined PEM for disk caching.
790    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
791        use rcgen::date_time_ymd;
792        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
793        use rustls::pki_types::CertificateDer;
794        use rustls_pemfile::private_key;
795
796        let mut params = CertificateParams::default();
797        let mut dn = DistinguishedName::new();
798        dn.push(DnType::CommonName, domain);
799        params.distinguished_name = dn;
800
801        // Set validity dynamically: from yesterday to 10 years from now.
802        {
803            use chrono::{Datelike, Duration, Utc};
804            let yesterday = Utc::now() - Duration::days(1);
805            // 397 days: stays within Chrome/Safari's 398-day maximum validity limit
806            // for TLS certificates (including locally-trusted CA leaf certs).
807            let expiry = Utc::now() + Duration::days(397);
808            params.not_before = date_time_ymd(
809                yesterday.year(),
810                yesterday.month() as u8,
811                yesterday.day() as u8,
812            );
813            params.not_after =
814                date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);
815        }
816
817        // Build SANs: exact domain + sibling wildcard (e.g. *.pf.localhost)
818        let mut sans =
819            vec![SanType::DnsName(domain.to_string().try_into().map_err(
820                |e| miette::miette!("Invalid domain name '{domain}': {e}"),
821            )?)];
822        // Add wildcard SAN for the parent domain (one level up)
823        if let Some(dot_pos) = domain.find('.') {
824            let parent = &domain[dot_pos + 1..];
825            // Only add wildcard if parent has at least one dot (not a bare TLD)
826            if parent.contains('.') {
827                let wildcard = format!("*.{parent}");
828                if let Ok(wc) = wildcard.try_into() {
829                    sans.push(SanType::DnsName(wc));
830                }
831            }
832        }
833        params.subject_alt_names = sans;
834
835        let leaf_key = rcgen::KeyPair::generate()
836            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
837        let leaf_cert = params
838            .signed_by(&leaf_key, &self.issuer)
839            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;
840
841        // Convert to rustls types
842        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
843        let key_pem = leaf_key.serialize_pem();
844        let key_der = private_key(&mut key_pem.as_bytes())
845            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
846            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;
847
848        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
849            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;
850
851        // Persist cert + key to disk cache as combined PEM.
852        // Use 0600 so the private key is not world-readable.
853        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
854        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
855        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
856        {
857            #[cfg(unix)]
858            {
859                use std::io::Write;
860                use std::os::unix::fs::OpenOptionsExt;
861                if let Err(e) = std::fs::OpenOptions::new()
862                    .write(true)
863                    .create(true)
864                    .truncate(true)
865                    .mode(0o600)
866                    .open(&disk_path)
867                    .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
868                {
869                    log::warn!(
870                        "Failed to persist cert for '{domain}' to {}: {e}",
871                        disk_path.display()
872                    );
873                }
874            }
875            #[cfg(not(unix))]
876            {
877                if let Err(e) = std::fs::write(&disk_path, combined_pem) {
878                    log::warn!(
879                        "Failed to persist cert for '{domain}' to {}: {e}",
880                        disk_path.display()
881                    );
882                } else {
883                    log::debug!(
884                        "Leaf cert for '{domain}' written to {} (file permissions are not \
885                         restricted on non-Unix platforms — consider restricting access manually)",
886                        disk_path.display()
887                    );
888                }
889            }
890        }
891
892        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
893    }
894}
895
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    /// Pick the certificate for an incoming TLS handshake.
    ///
    /// The SNI server name from the ClientHello is handed to
    /// `get_or_create`.  A handshake without a server name yields `None`.
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|sni| self.get_or_create(sni))
    }
}
906
907/// Get the effective host from a request.
908///
909/// HTTP/2 uses the `:authority` pseudo-header, which hyper exposes via
910/// `req.uri().authority()` rather than in the `HeaderMap`.
911/// HTTP/1.1 uses the `Host` header.
912fn get_request_host(req: &Request) -> Option<String> {
913    // HTTP/2: :authority is available via the request URI, not the HeaderMap.
914    let authority = req
915        .uri()
916        .authority()
917        .map(|a| a.as_str().to_string())
918        .filter(|s| !s.is_empty());
919
920    authority.or_else(|| {
921        req.headers()
922            .get(HOST)
923            .and_then(|h| h.to_str().ok())
924            .map(str::to_string)
925    })
926}
927
928/// Inject `X-Forwarded-*` headers into a proxied request.
929///
930/// Because the proxy is a **first-hop** dev tool (not a mid-tier forwarder),
931/// all four headers are **unconditionally overwritten** with values derived
932/// from the actual incoming connection.  Any values supplied by the connecting
933/// client are discarded.
934///
935/// Trusting client-supplied `x-forwarded-for` / `x-forwarded-proto` would
936/// allow a local process to spoof a remote IP or trick a backend's
937/// HTTPS-detection logic (CSRF checks, secure-cookie flags, redirect rules).
938fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
939    let remote_addr = req
940        .extensions()
941        .get::<axum::extract::ConnectInfo<SocketAddr>>()
942        .map(|ci| ci.0.ip().to_string())
943        .unwrap_or_else(|| "127.0.0.1".to_string());
944
945    let proto = if is_tls { "https" } else { "http" };
946    let default_port = if is_tls { "443" } else { "80" };
947
948    // Always set fresh values — we are the edge, never a mid-tier forwarder.
949    // Discard any x-forwarded-* headers supplied by the connecting client.
950    let forwarded_for = remote_addr.clone();
951    let forwarded_proto = proto.to_string();
952    let forwarded_host = host_header.to_string();
953    let forwarded_port = host_header
954        .rsplit_once(':')
955        .map(|(_, port)| port.to_string())
956        .unwrap_or_else(|| default_port.to_string());
957
958    // Strip any client-supplied x-forwarded-* and RFC 7239 Forwarded headers
959    // before inserting ours, so that no trace of the original values reaches
960    // the backend.  The RFC 7239 `Forwarded` header is stripped alongside the
961    // legacy `x-forwarded-*` set because backends that read it (Django, Rails,
962    // Spring) would otherwise see client-injected spoofed IPs or protocols.
963    for name in [
964        "x-forwarded-for",
965        "x-forwarded-proto",
966        "x-forwarded-host",
967        "x-forwarded-port",
968        "forwarded",
969    ] {
970        if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
971            req.headers_mut().remove(&header_name);
972        }
973    }
974
975    let headers = [
976        ("x-forwarded-for", forwarded_for),
977        ("x-forwarded-proto", forwarded_proto),
978        ("x-forwarded-host", forwarded_host),
979        ("x-forwarded-port", forwarded_port),
980    ];
981
982    for (name, value) in headers {
983        if let Ok(v) = HeaderValue::from_str(&value) {
984            let header_name = axum::http::HeaderName::from_static(name);
985            req.headers_mut().insert(header_name, v);
986        }
987    }
988}
989
990/// Main proxy request handler.
991///
992/// Parses the `Host` header, resolves the target daemon, and forwards the request.
993/// WebSocket / HTTP upgrade requests are forwarded transparently via hyper's upgrade mechanism.
994async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
995    // Extract the host (supports both HTTP/2 :authority and HTTP/1.1 Host)
996    let Some(raw_host) = get_request_host(&req) else {
997        return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
998    };
999    // Strip port from host for routing.
1000    // IPv6 addresses in Host headers are bracketed per RFC 2732: `[::1]:port`.
1001    // Splitting naïvely on ':' would break on the colons inside the address.
1002    let host = if raw_host.starts_with('[') {
1003        // IPv6: "[::1]:port" or "[::1]"
1004        raw_host
1005            .split("]:")
1006            .next()
1007            .unwrap_or(&raw_host)
1008            .trim_start_matches('[')
1009            .trim_end_matches(']')
1010            .to_string()
1011    } else {
1012        // IPv4 / hostname: "host:port" or "host"
1013        raw_host.split(':').next().unwrap_or(&raw_host).to_string()
1014    };
1015
1016    // Loop detection: check hop count.
1017    //
1018    // Security: strip (zero out) the hop counter on the very first hop to
1019    // prevent external clients from forging a high value and triggering a
1020    // 508 Loop Detected response (denial-of-service).  A request is
1021    // considered "first hop" when it does not carry the `x-pitchfork-hops`
1022    // request header that pitchfork injects when forwarding — i.e. it did
1023    // not come from another pitchfork proxy instance.
1024    // Note: `x-pitchfork` is a *response* header added by pitchfork and is
1025    // never present on incoming requests, so it cannot be used here.
1026    let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
1027    let hops: u64 = if is_from_pitchfork {
1028        req.headers()
1029            .get(PROXY_HOPS_HEADER)
1030            .and_then(|v| v.to_str().ok())
1031            .and_then(|s| s.parse().ok())
1032            .unwrap_or(0)
1033    } else {
1034        // External request: ignore any forged hop counter.
1035        0
1036    };
1037    if hops >= MAX_PROXY_HOPS {
1038        return error_response(
1039            StatusCode::LOOP_DETECTED,
1040            &format!(
1041                "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
1042                 This usually means a backend is proxying back through pitchfork without rewriting \n\
1043                 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
1044            ),
1045        );
1046    }
1047
1048    // Resolve the target port from the host
1049    let target_port = match resolve_target(&host, &state.tld).await {
1050        ResolveResult::Ready(port) => port,
1051        ResolveResult::Starting { slug } => {
1052            return starting_html_response(&slug, &raw_host);
1053        }
1054        ResolveResult::NotFound => {
1055            return error_response(
1056                StatusCode::BAD_GATEWAY,
1057                &format!(
1058                    "No daemon found for host '{host}'.\n\
1059                     Make sure the daemon has a slug, is running, and has a port configured.\n\
1060                     Expected format: <slug>.{tld}",
1061                    tld = state.tld
1062                ),
1063            );
1064        }
1065        ResolveResult::Error(msg) => {
1066            return error_response(StatusCode::BAD_GATEWAY, &msg);
1067        }
1068    };
1069
1070    // Build the forwarding URI
1071    let path_and_query = req
1072        .uri()
1073        .path_and_query()
1074        .map(|pq| pq.as_str())
1075        .unwrap_or("/");
1076
1077    let forward_uri = match Uri::builder()
1078        .scheme("http")
1079        .authority(format!("localhost:{target_port}"))
1080        .path_and_query(path_and_query)
1081        .build()
1082    {
1083        Ok(uri) => uri,
1084        Err(e) => {
1085            return error_response(
1086                StatusCode::INTERNAL_SERVER_ERROR,
1087                &format!("Failed to build forward URI: {e}"),
1088            );
1089        }
1090    };
1091
1092    // Update the request URI and Host header
1093    *req.uri_mut() = forward_uri;
1094    req.headers_mut().insert(
1095        HOST,
1096        HeaderValue::from_str(&format!("localhost:{target_port}"))
1097            .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
1098    );
1099
1100    // Inject X-Forwarded-* headers
1101    inject_forwarded_headers(&mut req, state.is_tls, &raw_host);
1102
1103    // Increment hop counter
1104    if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
1105        req.headers_mut()
1106            .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
1107    }
1108
1109    // Explicitly strip HTTP/2 pseudo-headers (":authority", ":method", etc.)
1110    // before forwarding to an HTTP/1.1 backend. Although hyper typically does
1111    // not store pseudo-headers in the HeaderMap, some middleware layers or
1112    // future hyper versions might; stripping them here is a defensive measure.
1113    let pseudo_headers: Vec<_> = req
1114        .headers()
1115        .keys()
1116        .filter(|k| k.as_str().starts_with(':'))
1117        .cloned()
1118        .collect();
1119    for key in pseudo_headers {
1120        req.headers_mut().remove(&key);
1121    }
1122
1123    // Extract the client-side OnUpgrade handle *before* consuming req
1124    let client_upgrade = hyper::upgrade::on(&mut req);
1125
1126    // Forward the request with a per-request timeout so that a backend that
1127    // accepts the TCP connection but then stalls (deadlock, blocking I/O, etc.)
1128    // cannot hold the proxy connection open forever and exhaust file descriptors.
1129    //
1130    // 120 s is intentionally generous for a local dev proxy — it covers slow
1131    // test suites, large file uploads, and SSE streams while still bounding
1132    // the worst-case resource leak.
1133    let result = match tokio::time::timeout(
1134        std::time::Duration::from_secs(120),
1135        state.client.request(req),
1136    )
1137    .await
1138    {
1139        Ok(r) => r,
1140        Err(_elapsed) => {
1141            let msg = format!(
1142                "Request to daemon on port {target_port} timed out after 120 s.\n\
1143                 The daemon accepted the connection but did not respond in time."
1144            );
1145            log::warn!("{msg}");
1146            if let Some(ref on_error) = state.on_error {
1147                on_error(&msg);
1148            }
1149            return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
1150        }
1151    };
1152    match result {
1153        Ok(mut resp) => {
1154            // Extract backend upgrade handle *before* consuming resp
1155            let backend_upgrade = hyper::upgrade::on(&mut resp);
1156            let (mut parts, body) = resp.into_parts();
1157
1158            // Add pitchfork identification header
1159            parts.headers.insert(
1160                axum::http::HeaderName::from_static(PITCHFORK_HEADER),
1161                HeaderValue::from_static("1"),
1162            );
1163
1164            // Strip the internal hop-counter so it is never leaked to external clients.
1165            parts.headers.remove(PROXY_HOPS_HEADER);
1166
1167            // Strip hop-by-hop headers when serving HTTPS (HTTP/2 forbids them).
1168            // Skip 101 Switching Protocols — that response is always HTTP/1.1 and
1169            // the client needs the `Upgrade` header to complete the WS handshake
1170            // (RFC 6455 §4.1 requires `Upgrade: websocket` in the 101 response).
1171            if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
1172                for h in HOP_BY_HOP_HEADERS {
1173                    if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
1174                        parts.headers.remove(&name);
1175                    }
1176                }
1177            }
1178
1179            // If the backend returned 101 Switching Protocols, pipe the upgraded streams.
1180            if parts.status == StatusCode::SWITCHING_PROTOCOLS {
1181                // Note: loop detection for WebSocket upgrades is already handled at the
1182                // top of proxy_handler (hops >= MAX_PROXY_HOPS check) before the request
1183                // is forwarded.  A 101 response here means the backend accepted the
1184                // upgrade, so the hop count was already within limits.
1185                tokio::spawn(async move {
1186                    if let (Ok(client_upgraded), Ok(backend_upgraded)) =
1187                        (client_upgrade.await, backend_upgrade.await)
1188                    {
1189                        let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
1190                        let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
1191                        // No application-level timeout here: tokio::time::timeout would be a
1192                        // hard wall-clock deadline for the entire tunnel, not an idle timeout.
1193                        // Long-lived connections (Vite/webpack HMR, SSE-over-WS) would be
1194                        // silently terminated after the deadline even if data is actively
1195                        // flowing.  The OS TCP keepalive is sufficient to reap truly dead
1196                        // connections; a proper idle timeout would require a custom
1197                        // AsyncRead/AsyncWrite wrapper that resets the timer on each I/O op.
1198                        let _ =
1199                            tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
1200                    }
1201                });
1202                return Response::from_parts(parts, Body::empty());
1203            }
1204
1205            // Backend refused the upgrade (returned a non-101 response) — forward it as-is.
1206            // This can happen when the backend rejects a WebSocket handshake with e.g. 400.
1207            Response::from_parts(parts, Body::new(body))
1208        }
1209        Err(e) => {
1210            let msg = format!(
1211                "Failed to connect to daemon on port {target_port}: {e}\n\
1212                 The daemon may have stopped or is not yet ready."
1213            );
1214            if let Some(ref on_error) = state.on_error {
1215                on_error(&msg);
1216            } else {
1217                log::warn!("{msg}");
1218            }
1219            error_response(StatusCode::BAD_GATEWAY, &msg)
1220        }
1221    }
1222}
1223
1224/// Resolve the target for a given hostname.
1225///
1226/// Slug-based routing using the global config's `[slugs]` section:
1227/// 1. Strip TLD to get subdomain (the slug)
1228/// 2. Look up slug in global config → find project dir + daemon name
1229/// 3. Check state file for a running daemon with that name → get its port
1230/// 4. If `proxy.auto_start` is enabled and the daemon is not running,
1231///    trigger an automatic start and wait for it to become ready.
1232///
1233/// # Returns
1234/// - `ResolveResult::Ready(port)`       — daemon running (or just auto-started), forward to this port
1235/// - `ResolveResult::Starting { slug }` — daemon start in progress (show waiting page)
1236/// - `ResolveResult::NotFound`          — no daemon matched
1237/// - `ResolveResult::Error(msg)`        — routing refused with a descriptive reason
1238///
1239/// # Locking
1240/// The state file lock is held only for the duration of the snapshot copy,
1241/// then released immediately to avoid serialising all proxy requests.
1242async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1243    // Strip the TLD suffix to get the subdomain part.
1244    let Some(subdomain) = strip_tld(host, tld) else {
1245        return ResolveResult::NotFound;
1246    };
1247
1248    // Look up the slug via the in-memory cache (refreshed every SLUG_CACHE_TTL).
1249    let Some(cached) = cached_slug_lookup(&subdomain).await else {
1250        return ResolveResult::NotFound;
1251    };
1252
1253    let daemon_name = &cached.daemon_name;
1254    let expected_namespace = &cached.namespace;
1255
1256    // Find matching daemons in the state file.
1257    let daemons = {
1258        let state_file = SUPERVISOR.state_file.lock().await;
1259        state_file.daemons.clone()
1260    };
1261
1262    // Find running daemons whose short name matches the slug's daemon name,
1263    // scoped to the slug's registered project namespace when available.
1264    let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1265        .iter()
1266        .filter(|(id, d)| {
1267            id.name() == daemon_name
1268                && d.status.is_running()
1269                && match expected_namespace {
1270                    Some(ns) => id.namespace() == ns,
1271                    None => true,
1272                }
1273        })
1274        .collect();
1275
1276    match running_matches.as_slice() {
1277        [] => {
1278            // Daemon not running — try auto-start if enabled.
1279            try_auto_start(&subdomain, &cached).await
1280        }
1281        [(_, d)] => {
1282            if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1283                ResolveResult::Ready(port)
1284            } else {
1285                ResolveResult::NotFound
1286            }
1287        }
1288        _ => {
1289            let d = running_matches[0].1;
1290            if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1291                ResolveResult::Ready(port)
1292            } else {
1293                ResolveResult::NotFound
1294            }
1295        }
1296    }
1297}
1298
1299/// RAII guard that removes a `DaemonId` from `AUTO_START_IN_PROGRESS` on drop.
1300///
1301/// This ensures the in-progress flag is cleared even if the auto-start future
1302/// panics (e.g. an unexpected `unwrap` inside a dependency).  Without this,
1303/// the daemon ID would stay in the set permanently and every subsequent proxy
1304/// request would return "Starting …" forever.
1305struct AutoStartGuard {
1306    daemon_id: DaemonId,
1307}
1308
1309impl Drop for AutoStartGuard {
1310    fn drop(&mut self) {
1311        let daemon_id = self.daemon_id.clone();
1312        // Spawn a cleanup task because `Drop` is synchronous and the mutex is
1313        // async.  If the runtime is shutting down this may not execute, but in
1314        // that case the entire set is being dropped anyway.
1315        tokio::spawn(async move {
1316            AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1317        });
1318    }
1319}
1320
1321/// Attempt to auto-start a daemon for the given slug.
1322///
1323/// If `proxy.auto_start` is disabled, returns `NotFound`.
1324/// Uses a dedup set to prevent concurrent starts for the same daemon.
1325/// Calls `SUPERVISOR.run()` with `wait_ready = true` so the daemon goes
1326/// through the same readiness lifecycle as `pf start`, then polls for the
1327/// active port.
1328///
1329/// The entire operation — including `SUPERVISOR.run()` and the port-polling
1330/// loop — is bounded by `proxy_auto_start_timeout`.
1331async fn try_auto_start(slug: &str, cached: &CachedSlugEntry) -> ResolveResult {
1332    let s = settings();
1333    if !s.proxy.auto_start {
1334        return ResolveResult::NotFound;
1335    }
1336
1337    // Resolve the daemon ID from the slug's project directory.
1338    // Fall back to "global" when no namespace is resolved so that global
1339    // daemons can also benefit from auto-start (matching the name-only
1340    // routing fallback in `resolve_target`).
1341    let ns = cached
1342        .namespace
1343        .clone()
1344        .unwrap_or_else(|| "global".to_string());
1345    let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1346        Ok(id) => id,
1347        Err(_) => return ResolveResult::NotFound,
1348    };
1349
1350    // Atomically check-and-mark the daemon as in-progress so that concurrent
1351    // requests for the same stopped daemon don't trigger multiple starts.
1352    {
1353        let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1354        if !in_progress.insert(daemon_id.clone()) {
1355            return ResolveResult::Starting {
1356                slug: slug.to_string(),
1357            };
1358        }
1359    }
1360
1361    // RAII guard: ensures the in-progress flag is cleared even on panic.
1362    let _guard = AutoStartGuard {
1363        daemon_id: daemon_id.clone(),
1364    };
1365
1366    // Apply proxy_auto_start_timeout to the *entire* auto-start operation,
1367    // including SUPERVISOR.run() (which waits for the daemon's readiness
1368    // signal) and the subsequent port-detection polling loop.
1369    let timeout = s.proxy_auto_start_timeout();
1370
1371    match tokio::time::timeout(timeout, try_auto_start_inner(slug, cached, &daemon_id)).await {
1372        Ok(result) => result,
1373        Err(_elapsed) => {
1374            log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1375            ResolveResult::Error(format!(
1376                "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1377                 The daemon did not become ready and bind a port within the configured \
1378                 proxy_auto_start_timeout.\n\
1379                 Increase the timeout or check the daemon's logs for slow startup."
1380            ))
1381        }
1382    }
1383}
1384
1385/// Inner implementation of [`try_auto_start`] extracted so that the caller can
1386/// wrap it with `tokio::time::timeout` and unconditionally clean up
1387/// `AUTO_START_IN_PROGRESS` regardless of the outcome.
1388async fn try_auto_start_inner(
1389    slug: &str,
1390    cached: &CachedSlugEntry,
1391    daemon_id: &DaemonId,
1392) -> ResolveResult {
1393    // Load config from the slug's project directory to find the daemon definition.
1394    let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(&cached.dir) {
1395        Ok(pt) => pt,
1396        Err(e) => {
1397            log::warn!(
1398                "Auto-start: failed to load config from {}: {e}",
1399                cached.dir.display()
1400            );
1401            return ResolveResult::NotFound;
1402        }
1403    };
1404
1405    let daemon_config = match pt.daemons.get(daemon_id) {
1406        Some(cfg) => cfg,
1407        None => {
1408            log::debug!(
1409                "Auto-start: daemon {daemon_id} not found in config at {}",
1410                cached.dir.display()
1411            );
1412            return ResolveResult::NotFound;
1413        }
1414    };
1415
1416    // Build run options — keep wait_ready=true (set by build_run_options) so
1417    // SUPERVISOR.run() waits for the daemon's readiness signal before returning,
1418    // matching the same lifecycle as `pf start` via IPC.
1419    let opts = crate::ipc::batch::StartOptions::default();
1420    let mut run_opts = match crate::ipc::batch::build_run_options(daemon_id, daemon_config, &opts) {
1421        Ok(o) => o,
1422        Err(e) => {
1423            log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1424            return ResolveResult::Error(format!("Failed to build run options: {e}"));
1425        }
1426    };
1427
1428    if run_opts.dir.0.as_os_str().is_empty() {
1429        run_opts.dir = crate::config_types::Dir(cached.dir.clone());
1430    }
1431
1432    log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1433
1434    // Trigger the start and wait for daemon readiness.
1435    // This call is bounded by the tokio::time::timeout in try_auto_start().
1436    let run_result = SUPERVISOR.run(run_opts).await;
1437
1438    if let Err(e) = run_result {
1439        log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1440        return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1441    }
1442
1443    // Daemon is ready. Poll briefly for the active_port to be detected
1444    // (detect_and_store_active_port runs asynchronously after readiness).
1445    // No per-loop timeout needed — the outer tokio::time::timeout covers this.
1446    let poll_interval = std::time::Duration::from_millis(250);
1447
1448    loop {
1449        let daemons = {
1450            let sf = SUPERVISOR.state_file.lock().await;
1451            sf.daemons.clone()
1452        };
1453
1454        if let Some(d) = daemons.get(daemon_id) {
1455            if d.status.is_running() {
1456                if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1457                    log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1458                    return ResolveResult::Ready(port);
1459                }
1460            } else {
1461                log::warn!(
1462                    "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1463                    d.status
1464                );
1465                return ResolveResult::Error(format!(
1466                    "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1467                     Check its logs for errors."
1468                ));
1469            }
1470        } else {
1471            // Daemon not found in state file after a successful run() —
1472            // it was likely cleaned up immediately.  Don't spin until timeout.
1473            log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1474            return ResolveResult::Error(format!(
1475                "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1476                 Check its logs for errors."
1477            ));
1478        }
1479
1480        tokio::time::sleep(poll_interval).await;
1481    }
1482}
1483
/// Return the subdomain portion of `host` once the `.{tld}` suffix is removed.
///
/// Yields `None` when `host` does not end in `.{tld}`, or when nothing
/// precedes the separating dot.
///
/// Examples:
/// - `api.myproject.localhost` with tld `localhost` → `api.myproject`
/// - `api.localhost` with tld `localhost` → `api`
/// - `localhost` with tld `localhost` → `None` (no subdomain)
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    // Peel off the TLD first, then the dot that must separate it from the
    // subdomain; either step failing means the host is not under this TLD.
    let subdomain = host.strip_suffix(tld)?.strip_suffix('.')?;
    if subdomain.is_empty() {
        None
    } else {
        Some(subdomain.to_owned())
    }
}
1495
/// Build a human-friendly error message for port binding failures.
///
/// The first line always reports the port and the underlying error; the
/// second line carries a hint chosen as follows:
///
/// - the error kind is `AddrInUse` → another process may hold the port.
///   This applies to ports below 1024 as well: suggesting `sudo` for an
///   occupied port would send the user in the wrong direction.
/// - any other error on a port below 1024 → elevated privileges are needed.
/// - any other error on a port of 1024 or above → another process may hold
///   the port.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    const PRIVILEGE_HINT: &str = "ports below 1024 require elevated privileges. \
                                  Try: sudo pitchfork supervisor start";
    const IN_USE_HINT: &str = "another process may already be using this port.";

    // Trust the error kind when it names the cause; otherwise fall back to
    // the port-number heuristic.
    let address_in_use = err.kind() == std::io::ErrorKind::AddrInUse;
    let hint = if port < 1024 && !address_in_use {
        PRIVILEGE_HINT
    } else {
        IN_USE_HINT
    };

    format!("Failed to bind proxy server to port {port}: {err}\nHint: {hint}")
}
1511
1512/// Build an HTML "Starting…" response that auto-refreshes every 2 seconds.
1513///
1514/// Displayed when a proxy request triggers an auto-start for a stopped daemon.
1515/// Once the daemon is ready, the next refresh will proxy normally to the backend.
1516fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1517    let escaped_slug = slug
1518        .replace('&', "&amp;")
1519        .replace('<', "&lt;")
1520        .replace('>', "&gt;")
1521        .replace('"', "&quot;")
1522        .replace('\'', "&#x27;");
1523    let escaped_host = raw_host
1524        .replace('&', "&amp;")
1525        .replace('<', "&lt;")
1526        .replace('>', "&gt;")
1527        .replace('"', "&quot;")
1528        .replace('\'', "&#x27;");
1529
1530    let html = format!(
1531        r##"<!DOCTYPE html>
1532<html lang="en">
1533<head>
1534    <meta charset="UTF-8">
1535    <meta name="viewport" content="width=device-width, initial-scale=1">
1536    <meta http-equiv="refresh" content="2">
1537    <title>Starting {escaped_slug}… — pitchfork</title>
1538    <style>
1539        * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1540        body {{
1541            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1542            background: #0f1117;
1543            color: #e1e4e8;
1544            display: flex;
1545            align-items: center;
1546            justify-content: center;
1547            min-height: 100vh;
1548        }}
1549        .container {{
1550            text-align: center;
1551            max-width: 480px;
1552            padding: 2rem;
1553        }}
1554        .spinner {{
1555            width: 48px;
1556            height: 48px;
1557            border: 4px solid rgba(255, 255, 255, 0.1);
1558            border-top-color: #58a6ff;
1559            border-radius: 50%;
1560            animation: spin 0.8s linear infinite;
1561            margin: 0 auto 1.5rem;
1562        }}
1563        @keyframes spin {{
1564            to {{ transform: rotate(360deg); }}
1565        }}
1566        h1 {{
1567            font-size: 1.5rem;
1568            font-weight: 600;
1569            margin-bottom: 0.5rem;
1570        }}
1571        .slug {{
1572            color: #58a6ff;
1573            font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1574        }}
1575        .host {{
1576            color: #8b949e;
1577            font-size: 0.875rem;
1578            margin-top: 0.25rem;
1579        }}
1580        .hint {{
1581            color: #8b949e;
1582            font-size: 0.8rem;
1583            margin-top: 1.5rem;
1584        }}
1585    </style>
1586</head>
1587<body>
1588    <div class="container">
1589        <div class="spinner"></div>
1590        <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1591        <p class="host">{escaped_host}</p>
1592        <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1593    </div>
1594</body>
1595</html>"##
1596    );
1597
1598    Response::builder()
1599        .status(StatusCode::SERVICE_UNAVAILABLE)
1600        .header("content-type", "text/html; charset=utf-8")
1601        .header("retry-after", "2")
1602        .body(Body::from(html))
1603        .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1604}
1605
1606/// Handler that redirects plain-HTTP requests to HTTPS.
1607///
1608/// Used when the proxy is configured for HTTPS but receives a plain-HTTP
1609/// request on the same port (after the first-byte peek determines it is
1610/// not a TLS ClientHello).  Returns a 302 redirect to the HTTPS equivalent.
1611///
1612/// WebSocket upgrade attempts over plain HTTP are rejected with 400
1613/// because WS-over-plain-HTTP to a TLS port is inherently broken.
1614async fn redirect_to_https_handler(req: Request) -> Response {
1615    // Reject WebSocket upgrades over plain HTTP
1616    if req.headers().contains_key("upgrade") {
1617        log::warn!("Dropping plain-HTTP WebSocket upgrade attempt — use wss:// instead of ws://");
1618        return (
1619            StatusCode::BAD_REQUEST,
1620            "WebSocket over plain HTTP is not supported on the HTTPS port. Use wss:// instead.",
1621        )
1622            .into_response();
1623    }
1624
1625    let raw_host = get_request_host(&req);
1626    let Some(raw_host) = raw_host else {
1627        return (StatusCode::BAD_REQUEST, "Missing Host header").into_response();
1628    };
1629
1630    // Strip any incoming port from Host and use the configured HTTPS port.
1631    let hostname = if raw_host.starts_with('[') {
1632        // IPv6: "[::1]:port" or "[::1]"
1633        raw_host
1634            .split_once("]:")
1635            .map(|(host, _)| host)
1636            .unwrap_or(&raw_host)
1637            .trim_start_matches('[')
1638            .trim_end_matches(']')
1639    } else {
1640        // IPv4/hostname: "host:port" or "host"
1641        let mut parts = raw_host.rsplitn(2, ':');
1642        let last = parts.next().unwrap_or(&raw_host);
1643        parts.next().unwrap_or(last)
1644    };
1645
1646    let path = req
1647        .uri()
1648        .path_and_query()
1649        .map(|pq| pq.as_str())
1650        .unwrap_or("/");
1651
1652    let https_port = match u16::try_from(settings().proxy.port).ok().filter(|&p| p > 0) {
1653        Some(443) | None => String::new(),
1654        Some(port) => format!(":{port}"),
1655    };
1656
1657    let host_for_url = if raw_host.starts_with('[') {
1658        format!("[{hostname}]")
1659    } else {
1660        hostname.to_string()
1661    };
1662
1663    let location = format!("https://{host_for_url}{https_port}{path}");
1664    (
1665        StatusCode::FOUND,
1666        [(axum::http::header::LOCATION, location)],
1667    )
1668        .into_response()
1669}
1670
1671/// Build a plain-text error response.
1672fn error_response(status: StatusCode, message: &str) -> Response {
1673    (status, message.to_string()).into_response()
1674}
1675
#[cfg(test)]
mod tests {
    use super::*;

    /// `strip_tld` returns the subdomain for hosts under the TLD and `None`
    /// for a bare TLD or a host with a different suffix.
    #[test]
    fn test_strip_tld() {
        // (host, tld, expected subdomain)
        let cases: [(&str, &str, Option<&str>); 5] = [
            ("api.myproject.localhost", "localhost", Some("api.myproject")),
            ("api.localhost", "localhost", Some("api")),
            ("localhost", "localhost", None),
            ("api.myproject.test", "test", Some("api.myproject")),
            ("other.com", "localhost", None),
        ];

        for (host, tld, expected) in cases {
            assert_eq!(strip_tld(host, tld), expected.map(str::to_string));
        }
    }

    /// `generate_ca` writes a PEM certificate and a PEM private key to the
    /// requested paths.
    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let workspace = tempfile::tempdir().unwrap();
        let cert_path = workspace.path().join("ca.pem");
        let key_path = workspace.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();

        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        let has_key_markers = key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY");
        assert!(has_key_markers, "should be PEM key");
    }
}