Skip to main content

pitchfork_cli/proxy/
server.rs

1//! Reverse proxy server implementation.
2//!
3//! Listens on a configured port and routes requests to daemon processes based
4//! on the `Host` header subdomain pattern.
5//!
6//! When `proxy.https = true`, a local CA is auto-generated (via `rcgen`) and
7//! each incoming TLS connection is served with a per-domain certificate signed
8//! by that CA (SNI-based dynamic certificate issuance).
9
10use std::net::SocketAddr;
11use std::sync::Arc;
12
13use axum::Router;
14use axum::body::Body;
15use axum::extract::{Request, State};
16use axum::http::{HeaderValue, StatusCode, Uri};
17use axum::response::{IntoResponse, Response};
18use hyper::header::HOST;
19
/// Response header used to identify a pitchfork proxy (for health checks and debugging).
const PITCHFORK_HEADER: &str = "x-pitchfork";

/// Request header tracking how many times a request has passed through the proxy.
/// Used to detect forwarding loops; see [`MAX_PROXY_HOPS`] for the limit.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";

/// Maximum number of proxy hops before rejecting as a loop.
const MAX_PROXY_HOPS: u64 = 5;

/// HTTP/1.1 hop-by-hop headers that are forbidden in HTTP/2 responses.
/// These must be stripped when proxying an HTTP/1.1 backend response back to an HTTP/2 client.
///
/// All names are spelled in lowercase.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
39
40use hyper_util::client::legacy::Client;
41use hyper_util::client::legacy::connect::HttpConnector;
42use hyper_util::rt::TokioExecutor;
43use tokio::net::TcpListener;
44
45use crate::daemon_id::DaemonId;
46use crate::settings::settings;
47use crate::supervisor::SUPERVISOR;
48
49// ─── Slug resolution cache ──────────────────────────────────────────────────
50//
51// `read_global_slugs()` reads ~/.config/pitchfork/config.toml from disk on every
52// call, and `namespace_for_dir()` traverses the filesystem upward to find the
53// nearest pitchfork.toml.  Both are called from `resolve_target_port()` which
54// sits in the hot path of every proxied HTTP request.
55//
56// This cache stores the resolved slug → (namespace, daemon_name) mapping
57// in memory with a short TTL so that the proxy does zero disk I/O for the vast
58// majority of requests while still picking up config changes within seconds.
59
/// How long to cache the slug resolution table before re-reading from disk.
///
/// `get_cached_slugs` adds this to `Instant::now()` after each refresh to compute
/// the next expiry, so a cached table is reused for at most two seconds.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);
62
/// Cached slug entry: pre-resolved namespace + daemon name for a slug.
///
/// Instances are built by `build_slug_entries` and handed out (cloned) by the
/// cached lookup helpers.
#[derive(Clone, Debug)]
pub struct CachedSlugEntry {
    /// The slug key as registered in config (needed for display in auto-start pages).
    pub slug: String,
    /// Namespace as returned by `entry.resolve_namespace()` when the table is built
    /// (`None` if it could not be resolved).
    pub namespace: Option<String>,
    /// Daemon short name (defaults to slug name when not explicitly set).
    pub daemon_name: String,
    /// Project directory for this slug (needed for auto-start).
    /// Holds an empty path when `entry.resolve_dir()` yields `None`.
    pub dir: std::path::PathBuf,
    /// Worktrees (git) / workspaces (jj) discovered under this slug's project directory.
    /// Empty when the `proxy.worktree` setting is off; entries whose sanitized branch
    /// name collides with an earlier one are dropped at build time.
    pub worktrees: Vec<crate::proxy::worktree::WorktreeEntry>,
}
77
/// In-memory cache for the global slug registry + derived namespaces.
struct SlugCache {
    /// Current snapshot of the slug table, shared with readers by cloning the `Arc`.
    entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
    /// Instant after which `entries` is considered stale and must be rebuilt.
    expires_at: std::time::Instant,
}
83
/// Process-wide slug cache, lazily initialised on first use.
///
/// Starts out empty with `expires_at` set to the creation instant, so the first
/// lookup always takes the refresh path.
static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
    once_cell::sync::Lazy::new(|| {
        tokio::sync::Mutex::new(SlugCache {
            entries: Arc::new(std::collections::HashMap::new()),
            expires_at: std::time::Instant::now(), // expired → will be populated on first access
        })
    });
91
92/// Build the slug lookup table from disk (expensive — involves file I/O + subprocesses).
93/// Called outside the cache lock via `spawn_blocking` to avoid blocking the Tokio runtime.
94fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
95    let global_slugs = crate::pitchfork_toml::PitchforkToml::read_global_slugs();
96    let mut entries = std::collections::HashMap::with_capacity(global_slugs.len());
97    let worktree_enabled = crate::settings::settings().proxy.worktree;
98    for (slug, entry) in &global_slugs {
99        let ns = entry.resolve_namespace();
100        let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
101        let worktrees = if worktree_enabled {
102            let wts = match entry.resolve_dir() {
103                Some(dir) => crate::proxy::worktree::discover_worktrees(&dir),
104                None => vec![],
105            };
106            // Warn about sanitized-branch collisions and drop duplicates so that
107            // unreachable entries don't waste memory in the cache.
108            let mut seen = std::collections::HashMap::with_capacity(wts.len());
109            let mut deduped = Vec::with_capacity(wts.len());
110            for mut wt in wts {
111                let wt_ns = crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&wt.path).ok();
112                wt.namespace = wt_ns;
113                match seen.entry(wt.sanitized_branch.clone()) {
114                    std::collections::hash_map::Entry::Occupied(e) => {
115                        log::warn!(
116                            "Worktree slug collision: '{}' and '{}' both sanitize to '{}'. \
117                             Only the first (in discovery order) will be routed.",
118                            e.get(),
119                            wt.branch,
120                            wt.sanitized_branch,
121                        );
122                    }
123                    std::collections::hash_map::Entry::Vacant(e) => {
124                        e.insert(wt.branch.clone());
125                        deduped.push(wt);
126                    }
127                }
128            }
129            deduped
130        } else {
131            vec![]
132        };
133        entries.insert(
134            slug.clone(),
135            CachedSlugEntry {
136                slug: slug.clone(),
137                namespace: ns,
138                daemon_name,
139                dir: entry.resolve_dir().unwrap_or_default(),
140                worktrees,
141            },
142        );
143    }
144    entries
145}
146
147/// Return a snapshot of the cached slug table, refreshing from disk if expired.
148///
149/// The disk I/O happens *outside* the mutex to avoid blocking concurrent requests
150/// during the refresh.  A short race window exists where two threads may both
151/// refresh, but that is harmless (last writer wins with identical data).
152pub async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
153    // Fast path: cache still valid — just clone the Arc.
154    {
155        let cache = SLUG_CACHE.lock().await;
156        if std::time::Instant::now() < cache.expires_at {
157            return Arc::clone(&cache.entries);
158        }
159    } // lock released before disk I/O
160
161    // Slow path: refresh from disk on a blocking thread (involves subprocess calls).
162    let new_entries = Arc::new(
163        tokio::task::spawn_blocking(build_slug_entries)
164            .await
165            .unwrap_or_else(|e| {
166                log::warn!("Failed to refresh slug cache: {e}");
167                std::collections::HashMap::new()
168            }),
169    );
170
171    // Store the refreshed entries.
172    {
173        let mut cache = SLUG_CACHE.lock().await;
174        cache.entries = Arc::clone(&new_entries);
175        cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
176    }
177
178    new_entries
179}
180
181/// Try to match a subdomain against a slug table, with optional wildcard fallback.
182///
183/// When `wildcard` is true and no exact match is found, progressively strips
184/// subdomain prefixes from the left until a match is found or no dots remain.
185/// For example, with slug "myapp" registered, `tenant.myapp` matches "myapp".
186fn wildcard_slug_lookup<'a>(
187    subdomain: &str,
188    entries: &'a std::collections::HashMap<String, CachedSlugEntry>,
189    wildcard: bool,
190) -> Option<&'a CachedSlugEntry> {
191    entries.get(subdomain).or_else(|| {
192        if !wildcard {
193            return None;
194        }
195        // "a.b.myapp" has dots at 1,3 → "b.myapp", "myapp"
196        subdomain
197            .match_indices('.')
198            .map(|(i, _)| &subdomain[i + 1..])
199            .find_map(|candidate| entries.get(candidate))
200    })
201}
202
203/// Look up a slug in the cached table.
204///
205/// With wildcard enabled (default), falls back to progressively shorter
206/// subdomain suffixes when an exact match is not found.  For example,
207/// `tenant.myapp` will match slug `myapp` if no slug named `tenant.myapp`
208/// exists.
209async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
210    let entries = get_cached_slugs().await;
211    wildcard_slug_lookup(subdomain, &entries, settings().proxy.wildcard).cloned()
212}
213
214// ─── Auto-start deduplication ───────────────────────────────────────────────
215//
216// When auto_start is enabled, concurrent proxy requests for the same stopped
217// daemon must not trigger multiple start operations.  This set tracks daemon
218// IDs that are currently being auto-started.
219
/// Set of daemon IDs that currently have an auto-start in flight.
///
/// Lazily initialised to an empty set and guarded by an async (`tokio`) mutex.
static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
    tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
223
/// Result of resolving a proxy target for a given host.
enum ResolveResult {
    /// Daemon is running and ready — forward to this port.
    /// Covers both already-running daemons and freshly auto-started ones.
    Ready(u16),
    /// Daemon is currently starting (auto-start in progress or just triggered).
    /// Carries the slug of the daemon being started.
    Starting { slug: String },
    /// No matching slug or daemon found.
    NotFound,
    /// Routing refused with a descriptive reason.
    Error(String),
}
236
/// Callback type invoked on proxy errors (e.g. for logging/alerting).
/// Receives the error text as `&str`; must be `Send + Sync` because it is
/// shared behind an `Arc`.
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;

/// Shared proxy state passed to each request handler.
#[derive(Clone)]
struct ProxyState {
    /// HTTP client used to forward requests to daemon backends.
    client: Arc<Client<HttpConnector, Body>>,
    /// The configured TLD (e.g. "localhost").
    tld: String,
    /// Whether the proxy is serving HTTPS.
    is_tls: bool,
    /// Optional error callback invoked on proxy errors (e.g. for logging/alerting).
    on_error: Option<OnErrorFn>,
}
252
253/// Start the reverse proxy server.
254///
255/// Binds to the configured port and serves until the process exits.
256/// When `proxy.https = true`, TLS is terminated here using a self-signed
257/// certificate (auto-generated if not present).
258///
259/// This function is intended to be spawned as a background task.
260pub async fn serve(
261    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
262    cancel: tokio_util::sync::CancellationToken,
263) -> crate::Result<()> {
264    let s = settings();
265    let lan_enabled = s.proxy.lan || !s.proxy.lan_ip.is_empty();
266
267    let effective_tld = if lan_enabled {
268        "local".to_string()
269    } else {
270        s.proxy.tld.clone()
271    };
272
273    let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
274        let msg = format!(
275            "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
276            s.proxy.port
277        );
278        let _ = bind_tx.send(Err(msg.clone()));
279        miette::bail!("{msg}");
280    };
281
282    let mut connector = HttpConnector::new();
283    // Limit how long the proxy waits to establish a TCP connection to a backend.
284    // Without this, a daemon that accepts the SYN but never completes the handshake
285    // would stall the proxy indefinitely.
286    connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
287
288    let client = Client::builder(TokioExecutor::new())
289        // Reclaim idle keep-alive connections after 30 s so that file descriptors
290        // are not held open forever when a backend goes quiet.
291        .pool_idle_timeout(std::time::Duration::from_secs(30))
292        .build(connector);
293
294    let state = ProxyState {
295        client: Arc::new(client),
296        tld: effective_tld.clone(),
297        is_tls: s.proxy.https,
298        on_error: None,
299    };
300
301    let app = Router::new().fallback(proxy_handler).with_state(state);
302
303    // Resolve bind address from settings.
304    // In LAN mode, default to 0.0.0.0 so the proxy is reachable from other
305    // devices on the network.  Users can still override with proxy.host.
306    let bind_ip: std::net::IpAddr = if lan_enabled && s.proxy.host == "127.0.0.1" {
307        std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)
308    } else {
309        match s.proxy.host.parse() {
310            Ok(ip) => ip,
311            Err(_) => {
312                log::warn!(
313                    "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
314                     The proxy will only be reachable on the loopback interface.",
315                    s.proxy.host
316                );
317                std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
318            }
319        }
320    };
321    let addr = SocketAddr::from((bind_ip, effective_port));
322
323    if s.proxy.https {
324        serve_https_with_http_fallback(app, addr, &s, effective_port, bind_tx, cancel).await
325    } else {
326        serve_http(app, addr, effective_port, bind_tx, cancel).await
327    }
328}
329
330/// Serve plain HTTP.
331async fn serve_http(
332    app: Router,
333    addr: SocketAddr,
334    effective_port: u16,
335    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
336    cancel: tokio_util::sync::CancellationToken,
337) -> crate::Result<()> {
338    let listener = match TcpListener::bind(addr).await {
339        Ok(l) => {
340            if settings().proxy.sync_hosts {
341                crate::proxy::hosts::sync_hosts_from_settings();
342            }
343            let _ = bind_tx.send(Ok(()));
344            l
345        }
346        Err(e) => {
347            let msg = bind_error_message(effective_port, &e);
348            let _ = bind_tx.send(Err(msg.clone()));
349            return Err(miette::miette!("{msg}"));
350        }
351    };
352
353    log::info!("Proxy server listening on http://{addr}");
354    if effective_port < 1024 {
355        log::info!(
356            "Note: port {effective_port} is a privileged port. \
357             The supervisor must be started with sudo to bind to this port."
358        );
359    }
360    let shutdown_signal = cancel.clone().cancelled_owned();
361    axum::serve(
362        listener,
363        app.into_make_service_with_connect_info::<SocketAddr>(),
364    )
365    .with_graceful_shutdown(shutdown_signal)
366    .await
367    .map_err(|e| miette::miette!("Proxy server error: {e}"))?;
368    Ok(())
369}
370
/// Serve HTTPS with automatic HTTP detection on the same port.
///
/// Setup: resolves the CA cert/key paths from `s`, generates the CA when either
/// file is missing, builds a rustls `ServerConfig` whose certificates come from
/// an `SniCertResolver`, and binds `addr`.  The bind outcome is reported through
/// `bind_tx` (`Ok(())` on success, `Err(message)` on failure, in which case the
/// same message is returned as the error).
///
/// Peeks at the first byte of each incoming TCP connection:
/// - `0x16` (TLS ClientHello) → hand off to the TLS acceptor (HTTP/2 + HTTP/1.1 via ALPN)
/// - anything else → served by the redirect app (`redirect_to_https_handler`)
///
/// The accept loop runs until `cancel` fires.  In-flight connection tasks are
/// then awaited for at most 10 seconds before the function returns `Ok(())`.
#[cfg(feature = "proxy-tls")]
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
    cancel: tokio_util::sync::CancellationToken,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);

    // Generate CA if not present.  Note the `||`: if only one of the two files
    // is missing, `generate_ca` is still invoked and writes both.
    if !ca_cert_path.exists() || !ca_key_path.exists() {
        generate_ca(&ca_cert_path, &ca_key_path)?;
        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }

    // Install ring as the default CryptoProvider if none has been set yet.
    // The result is ignored on purpose: an error here is treated as "already installed".
    let _ = rustls::crypto::ring::default_provider().install_default();

    // Build the SNI resolver (loads CA, caches per-domain certs)
    let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;

    let mut tls_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(resolver));
    // Advertise HTTP/2 and HTTP/1.1 via ALPN so browsers negotiate HTTP/2
    // for multiplexed requests (eliminates the 6-connection-per-host limit).
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    // Bind, then report the outcome to whoever is waiting on `bind_tx`.
    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            if settings().proxy.sync_hosts {
                crate::proxy::hosts::sync_hosts_from_settings();
            }
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };

    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
             The supervisor must be started with sudo to bind to this port."
        );
    }

    // Build a lightweight redirect app for plain-HTTP requests.
    let redirect_app = Router::new().fallback(redirect_to_https_handler);

    // Accept connections and sniff the first byte to decide TLS vs plain HTTP.
    // Each connection runs in its own task tracked by this JoinSet.
    let mut conn_tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    loop {
        // Reap finished connection tasks during normal operation so the JoinSet
        // does not retain one entry per historical connection.
        while conn_tasks.try_join_next().is_some() {}

        tokio::select! {
            accept_result = listener.accept() => {
                let (stream, _peer_addr) = match accept_result {
                    Ok(conn) => conn,
                    Err(e) => {
                        // Back off briefly before retrying so a persistent accept
                        // error does not turn into a busy loop.
                        log::warn!("Accept error (will retry): {e}");
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        continue;
                    }
                };

                // Per-connection clones moved into the spawned task.
                let acceptor = acceptor.clone();
                let app = app.clone();
                let redirect_app = redirect_app.clone();

                conn_tasks.spawn(async move {
                    // Peek at the first byte without consuming it.
                    // TLS ClientHello always starts with 0x16 (content type "handshake").
                    let mut peek_buf = [0u8; 1];
                    match stream.peek(&mut peek_buf).await {
                        // Peer closed without sending anything, or the peek failed:
                        // drop the connection silently.
                        Ok(0) | Err(_) => return,
                        _ => {}
                    }

                    if peek_buf[0] == 0x16 {
                        // TLS handshake → HTTP/2 or HTTP/1.1 (negotiated via ALPN)
                        match acceptor.accept(stream).await {
                            Ok(tls_stream) => {
                                let io = hyper_util::rt::TokioIo::new(tls_stream);
                                let svc = hyper_util::service::TowerToHyperService::new(app);
                                if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                                    .serve_connection_with_upgrades(io, svc)
                                    .await
                                {
                                    // HTTP/2 RST_STREAM errors from cancelled browser requests
                                    // (navigation, HMR) are normal — log at debug to avoid noise.
                                    log::debug!("Connection error: {e}");
                                }
                            }
                            Err(e) => {
                                log::debug!("TLS handshake error: {e}");
                            }
                        }
                    } else {
                        // Plain HTTP on the TLS port → 302 redirect to HTTPS
                        // (errors while serving the redirect are deliberately ignored).
                        let io = hyper_util::rt::TokioIo::new(stream);
                        let svc = hyper_util::service::TowerToHyperService::new(redirect_app);
                        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await;
                    }
                });

                // Opportunistic second reap right after spawning.
                while conn_tasks.try_join_next().is_some() {}
            }
            _ = cancel.cancelled() => {
                log::info!("Proxy server shutting down (cancel signal received)");
                break;
            }
        }
    }

    // Drain in-flight connections with a timeout.  The timeout result is
    // ignored: whether or not all tasks finished, the function returns Ok.
    let drain_timeout = std::time::Duration::from_secs(10);
    let _ = tokio::time::timeout(drain_timeout, async {
        while conn_tasks.join_next().await.is_some() {}
    })
    .await;

    Ok(())
}
519
520/// Fallback when proxy-tls feature is not enabled.
521#[cfg(not(feature = "proxy-tls"))]
522async fn serve_https_with_http_fallback(
523    _app: Router,
524    _addr: SocketAddr,
525    _s: &crate::settings::Settings,
526    _effective_port: u16,
527    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
528    _cancel: tokio_util::sync::CancellationToken,
529) -> crate::Result<()> {
530    let msg = "HTTPS proxy support requires the `proxy-tls` feature.\n\
531         Rebuild pitchfork with: cargo build --features proxy-tls"
532        .to_string();
533    let _ = bind_tx.send(Err(msg.clone()));
534    miette::bail!("{msg}")
535}
536
/// Work out where the CA certificate and CA private key live.
///
/// Returns `(cert_path, key_path)`.  A non-empty `proxy.tls_cert` /
/// `proxy.tls_key` setting is used verbatim; an empty one falls back to
/// `ca.pem` / `ca-key.pem` inside `$PITCHFORK_STATE_DIR/proxy/`.
#[cfg(feature = "proxy-tls")]
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    /// Pick the configured path, or `dir/file_name` when nothing is configured.
    fn configured_or_default(
        configured: &str,
        dir: &std::path::Path,
        file_name: &str,
    ) -> std::path::PathBuf {
        if configured.is_empty() {
            dir.join(file_name)
        } else {
            std::path::PathBuf::from(configured)
        }
    }

    let default_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let cert = configured_or_default(&s.proxy.tls_cert, &default_dir, "ca.pem");
    let key = configured_or_default(&s.proxy.tls_key, &default_dir, "ca-key.pem");
    (cert, key)
}
556
/// Generate a local root CA certificate and private key using `rcgen`.
///
/// The CA is used to sign per-domain certificates on demand (SNI).
/// Files are written in PEM format to `cert_path` and `key_path`; existing
/// files at those paths are overwritten.  The parent directories of both
/// paths are created if they do not exist.
///
/// On Unix the private key file ends up with mode `0600` regardless of whether
/// it was newly created or already existed.
///
/// # Errors
/// Returns an error if a directory cannot be created, key generation or
/// self-signing fails, or either file cannot be written.
#[cfg(feature = "proxy-tls")]
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };

    // Create the parent directory of *both* outputs: the key path can be
    // configured independently of the cert path and may live elsewhere.
    for path in [cert_path, key_path] {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
        }
    }

    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;

    // Write the CA certificate (public — 0644 is fine)
    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;

    // Write the CA private key with restrictive permissions (0600).
    {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                // Applies when the file is created, so a fresh key file is
                // never world-readable, even briefly.
                .mode(0o600)
                .open(key_path)
                .and_then(|mut f| {
                    // `mode()` has no effect on a file that already existed, so
                    // tighten the permissions explicitly before any secret
                    // bytes are written (the old contents were truncated above).
                    f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                    f.write_all(key_pair.serialize_pem().as_bytes())
                })
                .map_err(|e| {
                    miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
                })?;
        }
        #[cfg(not(unix))]
        {
            std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
            log::debug!(
                "CA private key written to {} (file permissions are not restricted \
                 on non-Unix platforms — consider restricting access manually)",
                key_path.display()
            );
        }
    }

    Ok(())
}
629
/// SNI-based certificate resolver.
///
/// Holds the local CA and a two-level cache of per-domain certificates:
/// - L1: in-memory `HashMap` (fastest, process-lifetime)
/// - L2: on-disk `host-certs/<safe_name>.pem` (survives restarts)
///
/// A `pending` set prevents concurrent requests for the same domain from
/// triggering multiple simultaneous cert-generation operations.
///
/// On each new TLS connection, `resolve()` is called with the SNI hostname;
/// if no cached cert exists for that domain, one is signed by the CA on the fly.
///
/// # Locking strategy
/// Both `cache` and `pending` use `std::sync::Mutex`; `pending` is paired with a
/// `std::sync::Condvar`.  The critical sections are intentionally short
/// (hash-map lookups / inserts), so the blocking time is negligible.
/// `get_or_create` is only called from the synchronous `ResolvesServerCert`
/// trait method (not from an async context), so blocking a thread here is
/// acceptable.
///
/// NOTE(review): the TLS handshake that triggers the resolver is awaited inside
/// a spawned tokio task, so a wait on `pending_cv` presumably still blocks the
/// thread running that task — confirm against the `ResolvesServerCert` impl.
#[cfg(feature = "proxy-tls")]
struct SniCertResolver {
    /// The CA issuer (key + parsed cert params, used to sign leaf certs).
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Directory where per-domain PEM files are cached on disk.
    host_certs_dir: std::path::PathBuf,
    /// L1 cache: domain → certified key (in-memory).
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    /// Pending set: domains currently being generated (dedup concurrent requests).
    /// Waiters block on `pending_cv` (a `Condvar`) instead of spin-sleeping.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Condvar paired with `pending` — notified when a domain is removed from the set.
    pending_cv: std::sync::Condvar,
}
664
#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    /// Prints only the type name; none of the fields are included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opaque = f.debug_struct("SniCertResolver");
        opaque.finish_non_exhaustive()
    }
}
671
672#[cfg(feature = "proxy-tls")]
673impl SniCertResolver {
674    /// Load the CA from disk and prepare the resolver.
675    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
676        let ca_key_pem = std::fs::read_to_string(ca_key_path)
677            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
678        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
679            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
680        })?;
681
682        // Verify the PEM is readable (sanity check)
683        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
684            miette::bail!("CA cert file does not contain a valid PEM certificate");
685        }
686
687        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
688            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;
689
690        // Parse the CA cert + key into an Issuer for signing leaf certs.
691        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
692            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;
693
694        // Ensure the host-certs directory exists
695        let host_certs_dir = ca_cert_path
696            .parent()
697            .unwrap_or(std::path::Path::new("."))
698            .join("host-certs");
699        std::fs::create_dir_all(&host_certs_dir)
700            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;
701
702        Ok(Self {
703            issuer,
704            host_certs_dir,
705            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
706            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
707            pending_cv: std::sync::Condvar::new(),
708        })
709    }
710
711    /// Get or create a `CertifiedKey` for the given domain.
712    ///
713    /// Resolution order:
714    /// 1. L1 in-memory cache
715    /// 2. L2 on-disk cache (`host-certs/<safe_name>.pem`)
716    /// 3. Generate fresh cert, persist to disk, populate both caches
717    ///
718    /// Concurrent requests for the same domain are deduplicated: the second
719    /// thread waits on a `Condvar` until the first thread finishes, then reads
720    /// from the cache.  This avoids both duplicate cert generation and the
721    /// spin-sleep anti-pattern that would block tokio worker threads.
722    ///
723    /// # Locking discipline
724    /// `cache` and `pending` are **never held simultaneously**.  The protocol is:
725    /// 1. Check `cache` (lock, read, unlock).
726    /// 2. Acquire `pending`; wait if domain is in-progress; re-check `cache`
727    ///    after waking (unlock `cache` before re-acquiring `pending` is not
728    ///    needed because we release `cache` before entering the `pending` block).
729    /// 3. Insert domain into `pending`; release `pending` lock.
730    /// 4. Generate cert (no locks held).
731    /// 5. Insert into `cache` (lock, write, unlock).
732    /// 6. Remove from `pending` and notify (lock, write, unlock).
733    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
734        // L1: memory cache (fast path — no pending lock needed)
735        {
736            let cache = self.cache.lock().ok()?;
737            if let Some(ck) = cache.get(domain) {
738                return Some(Arc::clone(ck));
739            }
740        } // cache lock released here
741
742        // Dedup: acquire the pending lock, wait if another thread is generating
743        // this domain, then re-check the cache (without holding pending) before
744        // deciding to generate.
745        //
746        // We deliberately release the pending lock before re-checking the cache
747        // to avoid holding both locks simultaneously.  The re-check is safe
748        // because: if the generating thread inserted into the cache and then
749        // removed from pending, we will see the cert in the cache.  If we miss
750        // the window (extremely unlikely), we will generate a duplicate cert,
751        // which is harmless — the last writer wins in the cache.
752        loop {
753            {
754                let mut pending = self.pending.lock().ok()?;
755                if pending.contains(domain) {
756                    // Another thread is generating; wait until it finishes.
757                    pending = self.pending_cv.wait(pending).ok()?;
758                    // pending lock re-acquired; loop to re-check cache below.
759                    drop(pending);
760                } else {
761                    // No one else is generating; claim the slot and proceed.
762                    pending.insert(domain.to_string());
763                    break;
764                }
765            } // pending lock released
766
767            // Re-check cache after being woken (the generating thread may have
768            // already populated it).  Cache lock is acquired independently of
769            // pending lock here — no nesting.
770            {
771                let cache = self.cache.lock().ok()?;
772                if let Some(ck) = cache.get(domain) {
773                    return Some(Arc::clone(ck));
774                }
775            } // cache lock released
776        } // pending lock released at break
777
778        let result = self.get_or_create_inner(domain);
779
780        // Always clear the pending flag and wake waiting threads.
781        // notify_all() is called *inside* the lock scope so that the domain is
782        // guaranteed to be removed before any waiting thread is woken up.
783        // If the lock is poisoned we recover it (the data is still valid) so
784        // that the domain is always removed and waiters are always notified.
785        {
786            let mut pending = match self.pending.lock() {
787                Ok(g) => g,
788                Err(e) => e.into_inner(),
789            };
790            pending.remove(domain);
791            self.pending_cv.notify_all();
792        }
793
794        result
795    }
796
797    /// Inner implementation: check disk cache, then generate.
798    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
799        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
800        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
801
802        // L2: disk cache — try to load existing cert+key PEM
803        if disk_path.exists() {
804            if let Ok(ck) = self.load_from_disk(&disk_path) {
805                let ck = Arc::new(ck);
806                if let Ok(mut cache) = self.cache.lock() {
807                    cache.insert(domain.to_string(), Arc::clone(&ck));
808                }
809                return Some(ck);
810            }
811            // Disk cache corrupt/expired — fall through to regenerate
812            let _ = std::fs::remove_file(&disk_path);
813        }
814
815        // L3: generate fresh cert
816        let ck = self.sign_for_domain(domain).ok()?;
817
818        let ck = Arc::new(ck);
819        if let Ok(mut cache) = self.cache.lock() {
820            cache.insert(domain.to_string(), Arc::clone(&ck));
821        }
822        Some(ck)
823    }
824
825    /// Load a `CertifiedKey` from a combined cert+key PEM file on disk.
826    ///
827    /// Returns an error if the certificate has already expired, so the caller
828    /// can fall through to regeneration rather than serving a stale cert.
829    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
830        use rustls::pki_types::CertificateDer;
831        use rustls_pemfile::{certs, private_key};
832
833        let pem = std::fs::read_to_string(path)
834            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;
835
836        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
837            .collect::<Result<Vec<_>, _>>()
838            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;
839
840        if cert_ders.is_empty() {
841            miette::bail!("No certificates found in {}", path.display());
842        }
843
844        // Check that the first certificate has not expired using x509-parser.
845        {
846            let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
847                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
848            })?;
849            use chrono::Utc;
850            let now_ts = Utc::now().timestamp();
851            let not_after_ts = cert.validity().not_after.timestamp();
852            if not_after_ts < now_ts {
853                miette::bail!(
854                    "Cached certificate at {} has expired — will regenerate",
855                    path.display()
856                );
857            }
858        }
859
860        let key_der = private_key(&mut pem.as_bytes())
861            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
862            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;
863
864        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
865            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;
866
867        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
868    }
869
870    /// Sign a leaf certificate for `domain` using the CA.
871    ///
872    /// SANs include:
873    /// - `DNS:<domain>` (exact match)
874    /// - `DNS:*.<parent>` (sibling wildcard, e.g. `*.pf.localhost` for `docs.pf.localhost`)
875    ///
876    /// Returns both the `CertifiedKey` and the combined PEM for disk caching.
877    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
878        use rcgen::date_time_ymd;
879        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
880        use rustls::pki_types::CertificateDer;
881        use rustls_pemfile::private_key;
882
883        let mut params = CertificateParams::default();
884        let mut dn = DistinguishedName::new();
885        dn.push(DnType::CommonName, domain);
886        params.distinguished_name = dn;
887
888        // Set validity dynamically: from yesterday to 10 years from now.
889        {
890            use chrono::{Datelike, Duration, Utc};
891            let yesterday = Utc::now() - Duration::days(1);
892            // 397 days: stays within Chrome/Safari's 398-day maximum validity limit
893            // for TLS certificates (including locally-trusted CA leaf certs).
894            let expiry = Utc::now() + Duration::days(397);
895            params.not_before = date_time_ymd(
896                yesterday.year(),
897                yesterday.month() as u8,
898                yesterday.day() as u8,
899            );
900            params.not_after =
901                date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);
902        }
903
904        // Build SANs: exact domain + sibling wildcard (e.g. *.pf.localhost)
905        let mut sans =
906            vec![SanType::DnsName(domain.to_string().try_into().map_err(
907                |e| miette::miette!("Invalid domain name '{domain}': {e}"),
908            )?)];
909        // Add wildcard SAN for the parent domain (one level up)
910        if let Some(dot_pos) = domain.find('.') {
911            let parent = &domain[dot_pos + 1..];
912            // Only add wildcard if parent has at least one dot (not a bare TLD)
913            if parent.contains('.') {
914                let wildcard = format!("*.{parent}");
915                if let Ok(wc) = wildcard.try_into() {
916                    sans.push(SanType::DnsName(wc));
917                }
918            }
919        }
920        params.subject_alt_names = sans;
921
922        let leaf_key = rcgen::KeyPair::generate()
923            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
924        let leaf_cert = params
925            .signed_by(&leaf_key, &self.issuer)
926            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;
927
928        // Convert to rustls types
929        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
930        let key_pem = leaf_key.serialize_pem();
931        let key_der = private_key(&mut key_pem.as_bytes())
932            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
933            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;
934
935        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
936            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;
937
938        // Persist cert + key to disk cache as combined PEM.
939        // Use 0600 so the private key is not world-readable.
940        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
941        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
942        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
943        {
944            #[cfg(unix)]
945            {
946                use std::io::Write;
947                use std::os::unix::fs::OpenOptionsExt;
948                if let Err(e) = std::fs::OpenOptions::new()
949                    .write(true)
950                    .create(true)
951                    .truncate(true)
952                    .mode(0o600)
953                    .open(&disk_path)
954                    .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
955                {
956                    log::warn!(
957                        "Failed to persist cert for '{domain}' to {}: {e}",
958                        disk_path.display()
959                    );
960                }
961            }
962            #[cfg(not(unix))]
963            {
964                if let Err(e) = std::fs::write(&disk_path, combined_pem) {
965                    log::warn!(
966                        "Failed to persist cert for '{domain}' to {}: {e}",
967                        disk_path.display()
968                    );
969                } else {
970                    log::debug!(
971                        "Leaf cert for '{domain}' written to {} (file permissions are not \
972                         restricted on non-Unix platforms — consider restricting access manually)",
973                        disk_path.display()
974                    );
975                }
976            }
977        }
978
979        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
980    }
981}
982
/// rustls certificate-resolution hook: serve a certificate matching the SNI
/// name sent by the client.
///
/// A handshake without a server name, or one for which no certificate can be
/// obtained, resolves to `None`.
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|sni| self.get_or_create(sni))
    }
}
993
994/// Get the effective host from a request.
995///
996/// HTTP/2 uses the `:authority` pseudo-header, which hyper exposes via
997/// `req.uri().authority()` rather than in the `HeaderMap`.
998/// HTTP/1.1 uses the `Host` header.
999fn get_request_host(req: &Request) -> Option<String> {
1000    // HTTP/2: :authority is available via the request URI, not the HeaderMap.
1001    let authority = req
1002        .uri()
1003        .authority()
1004        .map(|a| a.as_str().to_string())
1005        .filter(|s| !s.is_empty());
1006
1007    authority.or_else(|| {
1008        req.headers()
1009            .get(HOST)
1010            .and_then(|h| h.to_str().ok())
1011            .map(str::to_string)
1012    })
1013}
1014
1015/// Inject `X-Forwarded-*` headers into a proxied request.
1016///
1017/// Because the proxy is a **first-hop** dev tool (not a mid-tier forwarder),
1018/// all four headers are **unconditionally overwritten** with values derived
1019/// from the actual incoming connection.  Any values supplied by the connecting
1020/// client are discarded.
1021///
1022/// Trusting client-supplied `x-forwarded-for` / `x-forwarded-proto` would
1023/// allow a local process to spoof a remote IP or trick a backend's
1024/// HTTPS-detection logic (CSRF checks, secure-cookie flags, redirect rules).
1025fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
1026    let remote_addr = req
1027        .extensions()
1028        .get::<axum::extract::ConnectInfo<SocketAddr>>()
1029        .map(|ci| ci.0.ip().to_string())
1030        .unwrap_or_else(|| "127.0.0.1".to_string());
1031
1032    let proto = if is_tls { "https" } else { "http" };
1033    let default_port = if is_tls { "443" } else { "80" };
1034
1035    // Always set fresh values — we are the edge, never a mid-tier forwarder.
1036    // Discard any x-forwarded-* headers supplied by the connecting client.
1037    let forwarded_for = remote_addr.clone();
1038    let forwarded_proto = proto.to_string();
1039    let forwarded_host = host_header.to_string();
1040    let forwarded_port = host_header
1041        .rsplit_once(':')
1042        .map(|(_, port)| port.to_string())
1043        .unwrap_or_else(|| default_port.to_string());
1044
1045    // Strip any client-supplied x-forwarded-* and RFC 7239 Forwarded headers
1046    // before inserting ours, so that no trace of the original values reaches
1047    // the backend.  The RFC 7239 `Forwarded` header is stripped alongside the
1048    // legacy `x-forwarded-*` set because backends that read it (Django, Rails,
1049    // Spring) would otherwise see client-injected spoofed IPs or protocols.
1050    for name in [
1051        "x-forwarded-for",
1052        "x-forwarded-proto",
1053        "x-forwarded-host",
1054        "x-forwarded-port",
1055        "forwarded",
1056    ] {
1057        if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
1058            req.headers_mut().remove(&header_name);
1059        }
1060    }
1061
1062    let headers = [
1063        ("x-forwarded-for", forwarded_for),
1064        ("x-forwarded-proto", forwarded_proto),
1065        ("x-forwarded-host", forwarded_host),
1066        ("x-forwarded-port", forwarded_port),
1067    ];
1068
1069    for (name, value) in headers {
1070        if let Ok(v) = HeaderValue::from_str(&value) {
1071            let header_name = axum::http::HeaderName::from_static(name);
1072            req.headers_mut().insert(header_name, v);
1073        }
1074    }
1075}
1076
/// Main proxy request handler.
///
/// Parses the `Host` header, resolves the target daemon, and forwards the request.
/// WebSocket / HTTP upgrade requests are forwarded transparently via hyper's upgrade mechanism.
///
/// Processing steps, in order:
/// 1. Determine the requested host and strip the port (IPv6-bracket aware).
/// 2. Reject the request with 508 when the hop counter has reached
///    `MAX_PROXY_HOPS`.
/// 3. Pick the target port: the built-in web UI for `pitchfork.<tld>`,
///    otherwise whatever `resolve_target` returns.
/// 4. Rewrite URI, `Host`, forwarding headers and hop counter, and force the
///    outgoing request to HTTP/1.1.
/// 5. Forward with a 120 s timeout and relay the response, tunnelling the
///    connection when the backend answers `101 Switching Protocols`.
///
/// # Responses produced by the proxy itself
/// - `400 Bad Request` — no host could be determined.
/// - `508 Loop Detected` — hop limit reached.
/// - `502 Bad Gateway` — no daemon matched, routing was refused, or the
///   forward request failed.
/// - `504 Gateway Timeout` — the forward request did not complete in 120 s.
/// - `500 Internal Server Error` — the forward URI could not be built.
/// - The "starting" HTML page when the daemon start is in progress.
async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
    // Extract the host (supports both HTTP/2 :authority and HTTP/1.1 Host)
    let Some(raw_host) = get_request_host(&req) else {
        return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
    };
    // Strip port from host for routing.
    // IPv6 addresses in Host headers are bracketed per RFC 2732: `[::1]:port`.
    // Splitting naïvely on ':' would break on the colons inside the address.
    let host = if raw_host.starts_with('[') {
        // IPv6: "[::1]:port" or "[::1]"
        raw_host
            .split("]:")
            .next()
            .unwrap_or(&raw_host)
            .trim_start_matches('[')
            .trim_end_matches(']')
            .to_string()
    } else {
        // IPv4 / hostname: "host:port" or "host"
        raw_host.split(':').next().unwrap_or(&raw_host).to_string()
    };

    // Loop detection: read the hop counter.
    //
    // The counter comes from the `x-pitchfork-hops` request header, which this
    // handler sets to `hops + 1` on every request it forwards.  A missing or
    // unparsable header counts as zero hops.
    // NOTE(review): the presence of the header is the only signal used to
    // decide that a request came from another pitchfork instance, so a client
    // that sends the header itself is treated the same way and can trigger the
    // 508 below.  The `else` branch only covers requests without the header.
    // Confirm whether this is the intended trust model.
    // Note: `x-pitchfork` is a *response* header added by pitchfork and is
    // never present on incoming requests, so it cannot be used here.
    let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
    let hops: u64 = if is_from_pitchfork {
        req.headers()
            .get(PROXY_HOPS_HEADER)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse().ok())
            .unwrap_or(0)
    } else {
        // No hop header on the request: treat it as the first hop.
        0
    };
    if hops >= MAX_PROXY_HOPS {
        return error_response(
            StatusCode::LOOP_DETECTED,
            &format!(
                "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
                 This usually means a backend is proxying back through pitchfork without rewriting \n\
                 the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
            ),
        );
    }

    // Intercept "pitchfork.<tld>" — route to the built-in web UI.
    // `crate::web::port()` yields an `Option`; when it is `None` the request
    // falls through to normal slug resolution below.
    let target_port = if let Some(subdomain) = strip_tld(&host, &state.tld) {
        if subdomain == "pitchfork" {
            crate::web::port()
        } else {
            None
        }
    } else {
        None
    };

    // Otherwise resolve the host to a daemon port; every non-`Ready` outcome
    // is answered directly without contacting a backend.
    let target_port = if let Some(port) = target_port {
        port
    } else {
        match resolve_target(&host, &state.tld).await {
            ResolveResult::Ready(port) => port,
            ResolveResult::Starting { slug } => {
                return starting_html_response(&slug, &raw_host);
            }
            ResolveResult::NotFound => {
                return error_response(
                    StatusCode::BAD_GATEWAY,
                    &format!(
                        "No daemon found for host '{host}'.\n\
                         Make sure the daemon has a slug, is running, and has a port configured.\n\
                         Expected format: <slug>.{tld}",
                        tld = state.tld
                    ),
                );
            }
            ResolveResult::Error(msg) => {
                return error_response(StatusCode::BAD_GATEWAY, &msg);
            }
        }
    };
    // Build the forwarding URI: same path and query, but plain HTTP to
    // localhost on the resolved port.
    let path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or("/");

    let forward_uri = match Uri::builder()
        .scheme("http")
        .authority(format!("localhost:{target_port}"))
        .path_and_query(path_and_query)
        .build()
    {
        Ok(uri) => uri,
        Err(e) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("Failed to build forward URI: {e}"),
            );
        }
    };

    // Update the request URI and Host header
    *req.uri_mut() = forward_uri;
    req.headers_mut().insert(
        HOST,
        HeaderValue::from_str(&format!("localhost:{target_port}"))
            .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
    );

    // Inject X-Forwarded-* headers (based on the host exactly as received,
    // port included).
    inject_forwarded_headers(&mut req, state.is_tls, &raw_host);

    // Increment hop counter
    if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
        req.headers_mut()
            .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
    }

    // Explicitly strip HTTP/2 pseudo-headers (":authority", ":method", etc.)
    // before forwarding to an HTTP/1.1 backend. Although hyper typically does
    // not store pseudo-headers in the HeaderMap, some middleware layers or
    // future hyper versions might; stripping them here is a defensive measure.
    let pseudo_headers: Vec<_> = req
        .headers()
        .keys()
        .filter(|k| k.as_str().starts_with(':'))
        .cloned()
        .collect();
    for key in pseudo_headers {
        req.headers_mut().remove(&key);
    }

    // Downgrade the forwarded request to HTTP/1.1. TLS connections negotiate
    // HTTP/2 inbound via ALPN, but the upstream forward client speaks HTTP/1 to
    // the daemon. Without this, the still-h2-tagged request is rejected by the
    // client with `UserUnsupportedVersion`, surfacing as a 502 to the browser.
    *req.version_mut() = axum::http::Version::HTTP_11;

    // Extract the client-side OnUpgrade handle *before* consuming req
    let client_upgrade = hyper::upgrade::on(&mut req);

    // Forward the request with a per-request timeout so that a backend that
    // accepts the TCP connection but then stalls (deadlock, blocking I/O, etc.)
    // cannot hold the proxy connection open forever and exhaust file descriptors.
    //
    // 120 s is intentionally generous for a local dev proxy — it covers slow
    // test suites, large file uploads, and SSE streams while still bounding
    // the worst-case resource leak.
    // NOTE(review): the timeout wraps only the `request` future; presumably
    // that resolves once the response head arrives, so the body stream that is
    // relayed afterwards is not bounded by it — confirm.
    let result = match tokio::time::timeout(
        std::time::Duration::from_secs(120),
        state.client.request(req),
    )
    .await
    {
        Ok(r) => r,
        Err(_elapsed) => {
            let msg = format!(
                "Request to daemon on port {target_port} timed out after 120 s.\n\
                 The daemon accepted the connection but did not respond in time."
            );
            // Timeouts are always logged and additionally reported through the
            // optional error callback.
            log::warn!("{msg}");
            if let Some(ref on_error) = state.on_error {
                on_error(&msg);
            }
            return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
        }
    };
    match result {
        Ok(mut resp) => {
            // Extract backend upgrade handle *before* consuming resp
            let backend_upgrade = hyper::upgrade::on(&mut resp);
            let (mut parts, body) = resp.into_parts();

            // Add pitchfork identification header
            parts.headers.insert(
                axum::http::HeaderName::from_static(PITCHFORK_HEADER),
                HeaderValue::from_static("1"),
            );

            // Strip the internal hop-counter so it is never leaked to external clients.
            parts.headers.remove(PROXY_HOPS_HEADER);

            // Strip hop-by-hop headers when serving HTTPS (HTTP/2 forbids them).
            // Skip 101 Switching Protocols — that response is always HTTP/1.1 and
            // the client needs the `Upgrade` header to complete the WS handshake
            // (RFC 6455 §4.1 requires `Upgrade: websocket` in the 101 response).
            if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
                for h in HOP_BY_HOP_HEADERS {
                    if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
                        parts.headers.remove(&name);
                    }
                }
            }

            // If the backend returned 101 Switching Protocols, pipe the upgraded streams.
            if parts.status == StatusCode::SWITCHING_PROTOCOLS {
                // Note: loop detection for WebSocket upgrades is already handled at the
                // top of proxy_handler (hops >= MAX_PROXY_HOPS check) before the request
                // is forwarded.  A 101 response here means the backend accepted the
                // upgrade, so the hop count was already within limits.
                //
                // The tunnel runs in a detached task; if either side fails to
                // upgrade, the task ends without copying anything.
                tokio::spawn(async move {
                    if let (Ok(client_upgraded), Ok(backend_upgraded)) =
                        (client_upgrade.await, backend_upgrade.await)
                    {
                        let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
                        let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
                        // No application-level timeout here: tokio::time::timeout would be a
                        // hard wall-clock deadline for the entire tunnel, not an idle timeout.
                        // Long-lived connections (Vite/webpack HMR, SSE-over-WS) would be
                        // silently terminated after the deadline even if data is actively
                        // flowing.  The OS TCP keepalive is sufficient to reap truly dead
                        // connections; a proper idle timeout would require a custom
                        // AsyncRead/AsyncWrite wrapper that resets the timer on each I/O op.
                        let _ =
                            tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
                    }
                });
                // The 101 head is returned with an empty body; the data flows
                // through the spawned tunnel instead.
                return Response::from_parts(parts, Body::empty());
            }

            // Ordinary (non-101) response — forward it as-is.  This is also the path
            // taken when the backend rejects a WebSocket handshake with e.g. 400.
            Response::from_parts(parts, Body::new(body))
        }
        Err(e) => {
            let msg = format!(
                "Failed to connect to daemon on port {target_port}: {e}\n\
                 The daemon may have stopped or is not yet ready."
            );
            // Unlike the timeout path, the message is logged only when no
            // error callback is installed.
            if let Some(ref on_error) = state.on_error {
                on_error(&msg);
            } else {
                log::warn!("{msg}");
            }
            error_response(StatusCode::BAD_GATEWAY, &msg)
        }
    }
}
1329
1330/// Resolve the target for a given hostname.
1331///
1332/// Slug-based routing using the global config's `[slugs]` section:
1333/// 1. Strip TLD to get subdomain (the slug)
1334/// 2. Look up slug in global config → find project dir + daemon name
1335/// 3. Check state file for a running daemon with that name → get its port
1336/// 4. If `proxy.auto_start` is enabled and the daemon is not running,
1337///    trigger an automatic start and wait for it to become ready.
1338///
1339/// # Returns
1340/// - `ResolveResult::Ready(port)`       — daemon running (or just auto-started), forward to this port
1341/// - `ResolveResult::Starting { slug }` — daemon start in progress (show waiting page)
1342/// - `ResolveResult::NotFound`          — no daemon matched
1343/// - `ResolveResult::Error(msg)`        — routing refused with a descriptive reason
1344///
1345/// # Locking
1346/// The state file lock is held only for the duration of the snapshot copy,
1347/// then released immediately to avoid serialising all proxy requests.
1348async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
1349    let Some(subdomain) = strip_tld(host, tld) else {
1350        return ResolveResult::NotFound;
1351    };
1352
1353    let Some(cached) = cached_slug_lookup(&subdomain).await else {
1354        return ResolveResult::NotFound;
1355    };
1356
1357    // ─── Worktree prefix extraction ──────────────────────────────────────
1358    // When a wildcard subdomain like "feature-a.myapp" matched slug "myapp",
1359    // the prefix "feature-a" may correspond to a git worktree or jj workspace.
1360    let (expected_namespace, worktree_dir) = if subdomain != cached.slug {
1361        let prefix = subdomain
1362            .strip_suffix(&format!(".{}", cached.slug))
1363            .map(|s| s.to_string());
1364        match prefix {
1365            Some(ref p) => match cached.worktrees.iter().find(|w| w.sanitized_branch == *p) {
1366                Some(wt) => {
1367                    let ns = wt.namespace.clone().or_else(|| {
1368                        log::warn!(
1369                            "Worktree '{}' has no cached namespace; \
1370                             falling back to parent slug namespace.",
1371                            wt.path.display()
1372                        );
1373                        cached.namespace.clone()
1374                    });
1375                    (ns, Some(wt.path.clone()))
1376                }
1377                None => (cached.namespace.clone(), None),
1378            },
1379            None => (cached.namespace.clone(), None),
1380        }
1381    } else {
1382        (cached.namespace.clone(), None)
1383    };
1384
1385    let daemon_name = &cached.daemon_name;
1386
1387    let daemons = {
1388        let state_file = SUPERVISOR.state_file.lock().await;
1389        state_file.daemons.clone()
1390    };
1391
1392    let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
1393        .iter()
1394        .filter(|(id, d)| {
1395            id.name() == daemon_name
1396                && d.status.is_running()
1397                && match &expected_namespace {
1398                    Some(ns) => id.namespace() == ns,
1399                    None => true,
1400                }
1401        })
1402        .collect();
1403
1404    match running_matches.as_slice() {
1405        [] => {
1406            try_auto_start(
1407                &cached.slug,
1408                &cached,
1409                worktree_dir.as_deref(),
1410                expected_namespace.as_deref(),
1411            )
1412            .await
1413        }
1414        [(_, d)] => {
1415            if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1416                ResolveResult::Ready(port)
1417            } else {
1418                ResolveResult::NotFound
1419            }
1420        }
1421        _ => {
1422            let d = running_matches[0].1;
1423            if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1424                ResolveResult::Ready(port)
1425            } else {
1426                ResolveResult::NotFound
1427            }
1428        }
1429    }
1430}
1431
1432/// RAII guard that removes a `DaemonId` from `AUTO_START_IN_PROGRESS` on drop.
1433///
1434/// This ensures the in-progress flag is cleared even if the auto-start future
1435/// panics (e.g. an unexpected `unwrap` inside a dependency).  Without this,
1436/// the daemon ID would stay in the set permanently and every subsequent proxy
1437/// request would return "Starting …" forever.
1438struct AutoStartGuard {
1439    daemon_id: DaemonId,
1440}
1441
1442impl Drop for AutoStartGuard {
1443    fn drop(&mut self) {
1444        let daemon_id = self.daemon_id.clone();
1445        // Spawn a cleanup task because `Drop` is synchronous and the mutex is
1446        // async.  If the runtime is shutting down this may not execute, but in
1447        // that case the entire set is being dropped anyway.
1448        tokio::spawn(async move {
1449            AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
1450        });
1451    }
1452}
1453
1454/// Attempt to auto-start a daemon for the given slug.
1455///
1456/// If `proxy.auto_start` is disabled, returns `NotFound`.
1457/// Uses a dedup set to prevent concurrent starts for the same daemon.
1458/// Calls `SUPERVISOR.run()` with `wait_ready = true` so the daemon goes
1459/// through the same readiness lifecycle as `pf start`, then polls for the
1460/// active port.
1461///
1462/// The entire operation — including `SUPERVISOR.run()` and the port-polling
1463/// loop — is bounded by `proxy_auto_start_timeout`.
1464async fn try_auto_start(
1465    slug: &str,
1466    cached: &CachedSlugEntry,
1467    worktree_dir: Option<&std::path::Path>,
1468    expected_namespace: Option<&str>,
1469) -> ResolveResult {
1470    let s = settings();
1471    if !s.proxy.auto_start {
1472        return ResolveResult::NotFound;
1473    }
1474
1475    let ns = expected_namespace
1476        .map(|s| s.to_string())
1477        .or_else(|| cached.namespace.clone())
1478        .unwrap_or_else(|| "global".to_string());
1479    let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
1480        Ok(id) => id,
1481        Err(_) => return ResolveResult::NotFound,
1482    };
1483
1484    {
1485        let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
1486        if !in_progress.insert(daemon_id.clone()) {
1487            return ResolveResult::Starting {
1488                slug: slug.to_string(),
1489            };
1490        }
1491    }
1492
1493    let _guard = AutoStartGuard {
1494        daemon_id: daemon_id.clone(),
1495    };
1496
1497    let timeout = s.proxy_auto_start_timeout();
1498
1499    match tokio::time::timeout(
1500        timeout,
1501        try_auto_start_inner(slug, cached, &daemon_id, worktree_dir),
1502    )
1503    .await
1504    {
1505        Ok(result) => result,
1506        Err(_elapsed) => {
1507            log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
1508            ResolveResult::Error(format!(
1509                "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
1510                 The daemon did not become ready and bind a port within the configured \
1511                 proxy_auto_start_timeout.\n\
1512                 Increase the timeout or check the daemon's logs for slow startup."
1513            ))
1514        }
1515    }
1516}
1517
1518/// Inner implementation of [`try_auto_start`] extracted so that the caller can
1519/// wrap it with `tokio::time::timeout` and unconditionally clean up
1520/// `AUTO_START_IN_PROGRESS` regardless of the outcome.
1521async fn try_auto_start_inner(
1522    slug: &str,
1523    cached: &CachedSlugEntry,
1524    daemon_id: &DaemonId,
1525    worktree_dir: Option<&std::path::Path>,
1526) -> ResolveResult {
1527    let config_dir = worktree_dir.unwrap_or(&cached.dir);
1528
1529    let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(config_dir) {
1530        Ok(pt) => pt,
1531        Err(e) => {
1532            log::warn!(
1533                "Auto-start: failed to load config from {}: {e}",
1534                config_dir.display()
1535            );
1536            return ResolveResult::NotFound;
1537        }
1538    };
1539
1540    let daemon_config = match pt.daemons.get(daemon_id) {
1541        Some(cfg) => cfg,
1542        None => {
1543            log::debug!(
1544                "Auto-start: daemon {daemon_id} not found in config at {}",
1545                config_dir.display()
1546            );
1547            return ResolveResult::NotFound;
1548        }
1549    };
1550
1551    let opts = crate::ipc::batch::StartOptions {
1552        quiet: true,
1553        ..crate::ipc::batch::StartOptions::default()
1554    };
1555    let mut run_opts =
1556        match crate::ipc::batch::build_run_options(daemon_id, daemon_config, Some(&opts)) {
1557            Ok(o) => o,
1558            Err(e) => {
1559                log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
1560                return ResolveResult::Error(format!("Failed to build run options: {e}"));
1561            }
1562        };
1563
1564    // Only set the working directory when the daemon config didn't specify one.
1565    // If the config has an explicit `dir`, respect it even in a worktree context.
1566    if run_opts.dir.0.as_os_str().is_empty() {
1567        run_opts.dir = crate::config_types::Dir(config_dir.to_path_buf());
1568    }
1569
1570    log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
1571
1572    let run_result = SUPERVISOR.run(run_opts).await;
1573
1574    if let Err(e) = run_result {
1575        log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
1576        return ResolveResult::Error(format!("Failed to start daemon: {e}"));
1577    }
1578
1579    let poll_interval = std::time::Duration::from_millis(250);
1580
1581    loop {
1582        let daemons = {
1583            let sf = SUPERVISOR.state_file.lock().await;
1584            sf.daemons.clone()
1585        };
1586
1587        if let Some(d) = daemons.get(daemon_id) {
1588            if d.status.is_running() {
1589                if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
1590                    log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
1591                    return ResolveResult::Ready(port);
1592                }
1593            } else {
1594                log::warn!(
1595                    "Auto-start: daemon {daemon_id} is no longer running (status: {})",
1596                    d.status
1597                );
1598                return ResolveResult::Error(format!(
1599                    "Daemon '{daemon_id}' started but exited unexpectedly.\n\
1600                     Check its logs for errors."
1601                ));
1602            }
1603        } else {
1604            log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
1605            return ResolveResult::Error(format!(
1606                "Daemon '{daemon_id}' started but disappeared from the state file.\n\
1607                 Check its logs for errors."
1608            ));
1609        }
1610
1611        tokio::time::sleep(poll_interval).await;
1612    }
1613}
1614
/// Return the part of `host` that precedes `.{tld}`, if there is one.
///
/// `None` is returned when `host` does not end in `.{tld}` or when nothing
/// is left in front of the dot.
///
/// Examples:
/// - `api.myproject.localhost` with tld `localhost` → `api.myproject`
/// - `api.localhost` with tld `localhost` → `api`
/// - `localhost` with tld `localhost` → `None` (no subdomain)
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    // Peel the TLD first, then the separating dot; either step failing means
    // the host is not of the form `<something>.<tld>`.
    let subdomain = host.strip_suffix(tld)?.strip_suffix('.')?;
    if subdomain.is_empty() {
        None
    } else {
        Some(subdomain.to_owned())
    }
}
1626
/// Format a user-facing message for a failed attempt to bind the proxy port.
///
/// The first line reports the port and the underlying I/O error; the second
/// line is a hint chosen by port range: ports below 1024 get a note about
/// elevated privileges, all others a note about the port being in use.
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    let hint = match port {
        0..=1023 => {
            "ports below 1024 require elevated privileges. \
             Try: sudo pitchfork supervisor start"
        }
        _ => "another process may already be using this port.",
    };
    format!("Failed to bind proxy server to port {port}: {err}\nHint: {hint}")
}
1642
1643/// Build an HTML "Starting…" response that auto-refreshes every 2 seconds.
1644///
1645/// Displayed when a proxy request triggers an auto-start for a stopped daemon.
1646/// Once the daemon is ready, the next refresh will proxy normally to the backend.
1647fn starting_html_response(slug: &str, raw_host: &str) -> Response {
1648    let escaped_slug = slug
1649        .replace('&', "&amp;")
1650        .replace('<', "&lt;")
1651        .replace('>', "&gt;")
1652        .replace('"', "&quot;")
1653        .replace('\'', "&#x27;");
1654    let escaped_host = raw_host
1655        .replace('&', "&amp;")
1656        .replace('<', "&lt;")
1657        .replace('>', "&gt;")
1658        .replace('"', "&quot;")
1659        .replace('\'', "&#x27;");
1660
1661    let html = format!(
1662        r##"<!DOCTYPE html>
1663<html lang="en">
1664<head>
1665    <meta charset="UTF-8">
1666    <meta name="viewport" content="width=device-width, initial-scale=1">
1667    <meta http-equiv="refresh" content="2">
1668    <title>Starting {escaped_slug}… — pitchfork</title>
1669    <style>
1670        * {{ margin: 0; padding: 0; box-sizing: border-box; }}
1671        body {{
1672            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
1673            background: #0f1117;
1674            color: #e1e4e8;
1675            display: flex;
1676            align-items: center;
1677            justify-content: center;
1678            min-height: 100vh;
1679        }}
1680        .container {{
1681            text-align: center;
1682            max-width: 480px;
1683            padding: 2rem;
1684        }}
1685        .spinner {{
1686            width: 48px;
1687            height: 48px;
1688            border: 4px solid rgba(255, 255, 255, 0.1);
1689            border-top-color: #58a6ff;
1690            border-radius: 50%;
1691            animation: spin 0.8s linear infinite;
1692            margin: 0 auto 1.5rem;
1693        }}
1694        @keyframes spin {{
1695            to {{ transform: rotate(360deg); }}
1696        }}
1697        h1 {{
1698            font-size: 1.5rem;
1699            font-weight: 600;
1700            margin-bottom: 0.5rem;
1701        }}
1702        .slug {{
1703            color: #58a6ff;
1704            font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
1705        }}
1706        .host {{
1707            color: #8b949e;
1708            font-size: 0.875rem;
1709            margin-top: 0.25rem;
1710        }}
1711        .hint {{
1712            color: #8b949e;
1713            font-size: 0.8rem;
1714            margin-top: 1.5rem;
1715        }}
1716    </style>
1717</head>
1718<body>
1719    <div class="container">
1720        <div class="spinner"></div>
1721        <h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
1722        <p class="host">{escaped_host}</p>
1723        <p class="hint">This page will refresh automatically when the daemon is ready.</p>
1724    </div>
1725</body>
1726</html>"##
1727    );
1728
1729    Response::builder()
1730        .status(StatusCode::SERVICE_UNAVAILABLE)
1731        .header("content-type", "text/html; charset=utf-8")
1732        .header("retry-after", "2")
1733        .body(Body::from(html))
1734        .unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
1735}
1736
1737/// Handler that redirects plain-HTTP requests to HTTPS.
1738///
1739/// Used when the proxy is configured for HTTPS but receives a plain-HTTP
1740/// request on the same port (after the first-byte peek determines it is
1741/// not a TLS ClientHello).  Returns a 302 redirect to the HTTPS equivalent.
1742///
1743/// WebSocket upgrade attempts over plain HTTP are rejected with 400
1744/// because WS-over-plain-HTTP to a TLS port is inherently broken.
1745async fn redirect_to_https_handler(req: Request) -> Response {
1746    // Reject WebSocket upgrades over plain HTTP
1747    if req.headers().contains_key("upgrade") {
1748        log::warn!("Dropping plain-HTTP WebSocket upgrade attempt — use wss:// instead of ws://");
1749        return (
1750            StatusCode::BAD_REQUEST,
1751            "WebSocket over plain HTTP is not supported on the HTTPS port. Use wss:// instead.",
1752        )
1753            .into_response();
1754    }
1755
1756    let raw_host = get_request_host(&req);
1757    let Some(raw_host) = raw_host else {
1758        return (StatusCode::BAD_REQUEST, "Missing Host header").into_response();
1759    };
1760
1761    // Strip any incoming port from Host and use the configured HTTPS port.
1762    let hostname = if raw_host.starts_with('[') {
1763        // IPv6: "[::1]:port" or "[::1]"
1764        raw_host
1765            .split_once("]:")
1766            .map(|(host, _)| host)
1767            .unwrap_or(&raw_host)
1768            .trim_start_matches('[')
1769            .trim_end_matches(']')
1770    } else {
1771        // IPv4/hostname: "host:port" or "host"
1772        let mut parts = raw_host.rsplitn(2, ':');
1773        let last = parts.next().unwrap_or(&raw_host);
1774        parts.next().unwrap_or(last)
1775    };
1776
1777    let path = req
1778        .uri()
1779        .path_and_query()
1780        .map(|pq| pq.as_str())
1781        .unwrap_or("/");
1782
1783    let https_port = match u16::try_from(settings().proxy.port).ok().filter(|&p| p > 0) {
1784        Some(443) | None => String::new(),
1785        Some(port) => format!(":{port}"),
1786    };
1787
1788    let host_for_url = if raw_host.starts_with('[') {
1789        format!("[{hostname}]")
1790    } else {
1791        hostname.to_string()
1792    };
1793
1794    let location = format!("https://{host_for_url}{https_port}{path}");
1795    (
1796        StatusCode::FOUND,
1797        [(axum::http::header::LOCATION, location)],
1798    )
1799        .into_response()
1800}
1801
1802/// Build a plain-text error response.
1803fn error_response(status: StatusCode, message: &str) -> Response {
1804    (status, message.to_string()).into_response()
1805}
1806
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a slug entry whose slug and daemon name are both `name`, with no
    /// namespace, no worktrees, and a directory of `/tmp/<name>`.
    fn make_entry(name: &str) -> CachedSlugEntry {
        CachedSlugEntry {
            slug: name.to_string(),
            namespace: None,
            daemon_name: name.to_string(),
            dir: std::path::PathBuf::from(format!("/tmp/{name}")),
            worktrees: vec![],
        }
    }

    /// Slug table containing just the `myapp` entry — the fixture most of the
    /// wildcard lookup tests start from.
    fn myapp_only() -> HashMap<String, CachedSlugEntry> {
        HashMap::from([("myapp".to_string(), make_entry("myapp"))])
    }

    #[test]
    fn test_strip_tld() {
        // (host, tld, expected subdomain)
        let cases = [
            ("api.myproject.localhost", "localhost", Some("api.myproject")),
            ("api.localhost", "localhost", Some("api")),
            ("localhost", "localhost", None),
            ("api.myproject.test", "test", Some("api.myproject")),
            ("other.com", "localhost", None),
        ];
        for (host, tld, expected) in cases {
            assert_eq!(strip_tld(host, tld), expected.map(str::to_string));
        }
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_match() {
        let entries = myapp_only();
        // An exact slug is found directly.
        let found = wildcard_slug_lookup("myapp", &entries, true);
        assert!(found.is_some());
        assert_eq!(found.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_subdomain_fallback() {
        let entries = myapp_only();
        // One extra label in front: "tenant.myapp" resolves to "myapp".
        let found = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(found.is_some());
        assert_eq!(found.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_nested_fallback() {
        let entries = myapp_only();
        // Two extra labels: "a.b.myapp" → "b.myapp" → "myapp".
        let found = wildcard_slug_lookup("a.b.myapp", &entries, true);
        assert!(found.is_some());
        assert_eq!(found.unwrap().daemon_name, "myapp");
    }

    #[test]
    fn test_wildcard_slug_lookup_no_match() {
        // Nothing registered, so nothing can match.
        let entries: HashMap<String, CachedSlugEntry> = HashMap::new();
        let found = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(found.is_none());
    }

    #[test]
    fn test_wildcard_slug_lookup_disabled() {
        let entries = myapp_only();
        // Wildcard matching off: the subdomain form must not resolve…
        let via_subdomain = wildcard_slug_lookup("tenant.myapp", &entries, false);
        assert!(via_subdomain.is_none());
        // …while the exact slug still does.
        let via_exact = wildcard_slug_lookup("myapp", &entries, false);
        assert!(via_exact.is_some());
    }

    #[test]
    fn test_wildcard_slug_lookup_exact_beats_wildcard() {
        let mut entries = myapp_only();
        let tenant_entry = CachedSlugEntry {
            slug: "tenant.myapp".to_string(),
            ..make_entry("tenant-daemon")
        };
        entries.insert("tenant.myapp".to_string(), tenant_entry);
        // With both "myapp" and "tenant.myapp" registered, the longer exact
        // slug wins over the wildcard fallback.
        let found = wildcard_slug_lookup("tenant.myapp", &entries, true);
        assert!(found.is_some());
        assert_eq!(found.unwrap().daemon_name, "tenant-daemon");
    }

    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let tmp = tempfile::tempdir().unwrap();
        let cert_path = tmp.path().join("ca.pem");
        let key_path = tmp.path().join("ca-key.pem");

        generate_ca(&cert_path, &key_path).unwrap();

        // Both files must exist on disk…
        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");

        // …and each must look like PEM of the right kind.
        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");

        let key_pem = std::fs::read_to_string(&key_path).unwrap();
        assert!(
            key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY"),
            "should be PEM key"
        );
    }
}