use std::net::SocketAddr;
use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::{HeaderValue, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use hyper::header::HOST;
/// Response header added to every proxied response, marking it as having
/// passed through the pitchfork proxy.
const PITCHFORK_HEADER: &str = "x-pitchfork";
/// Request header counting how many times a request has traversed this
/// proxy; used for loop detection.
const PROXY_HOPS_HEADER: &str = "x-pitchfork-hops";
/// Requests whose hop count reaches this value are rejected with 508.
const MAX_PROXY_HOPS: u64 = 5;
/// Hop-by-hop headers stripped from backend responses when serving over TLS
/// (they describe the backend connection, not the client-facing one).
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "upgrade",
];
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use tokio::net::TcpListener;
use crate::daemon_id::DaemonId;
use crate::settings::settings;
use crate::supervisor::SUPERVISOR;
/// How long the slug → daemon lookup table is served from cache before it is
/// rebuilt from the global slug registry.
const SLUG_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(2);
/// One resolved slug entry, cached so every request does not re-read config.
#[derive(Clone, Debug)]
struct CachedSlugEntry {
    // Namespace derived from the slug's directory, when derivable.
    namespace: Option<String>,
    // Daemon name; falls back to the slug itself when not configured.
    daemon_name: String,
    // Project directory the slug entry points at.
    dir: std::path::PathBuf,
}
/// Time-bounded cache of the slug lookup table.
struct SlugCache {
    // Arc so callers can keep the map after releasing the cache lock.
    entries: Arc<std::collections::HashMap<String, CachedSlugEntry>>,
    // Entries are stale once this instant has passed.
    expires_at: std::time::Instant,
}
/// Global slug cache. Initialized already-expired so the first lookup
/// triggers a build.
static SLUG_CACHE: once_cell::sync::Lazy<tokio::sync::Mutex<SlugCache>> =
    once_cell::sync::Lazy::new(|| {
        tokio::sync::Mutex::new(SlugCache {
            entries: Arc::new(std::collections::HashMap::new()),
            // `now()` means "expired immediately".
            expires_at: std::time::Instant::now(),
        })
    });
/// Build the slug → daemon lookup table from the global slug registry.
///
/// For each slug we record its project directory, the namespace derived from
/// that directory (when derivable), and the daemon name (defaulting to the
/// slug itself when none is configured).
fn build_slug_entries() -> std::collections::HashMap<String, CachedSlugEntry> {
    crate::pitchfork_toml::PitchforkToml::read_global_slugs()
        .iter()
        .map(|(slug, entry)| {
            let namespace =
                crate::pitchfork_toml::PitchforkToml::namespace_for_dir(&entry.dir).ok();
            let daemon_name = entry.daemon.as_deref().unwrap_or(slug).to_string();
            let cached = CachedSlugEntry {
                namespace,
                daemon_name,
                dir: entry.dir.clone(),
            };
            (slug.clone(), cached)
        })
        .collect()
}
/// Return the current slug table, rebuilding it if the cache has expired.
///
/// The (file-reading, potentially slow) rebuild runs without holding the
/// cache lock so readers are not blocked. Before publishing the result we
/// re-check whether a concurrent task already refreshed the cache while we
/// were rebuilding; if so, its fresher entries are reused instead of being
/// clobbered with ours.
async fn get_cached_slugs() -> Arc<std::collections::HashMap<String, CachedSlugEntry>> {
    {
        // Fast path: cache still valid.
        let cache = SLUG_CACHE.lock().await;
        if std::time::Instant::now() < cache.expires_at {
            return Arc::clone(&cache.entries);
        }
    }
    let new_entries = Arc::new(build_slug_entries());
    let mut cache = SLUG_CACHE.lock().await;
    if std::time::Instant::now() < cache.expires_at {
        // Another task won the refresh race while we were rebuilding; keep
        // its (newer) table rather than overwriting it.
        return Arc::clone(&cache.entries);
    }
    cache.entries = Arc::clone(&new_entries);
    cache.expires_at = std::time::Instant::now() + SLUG_CACHE_TTL;
    new_entries
}
/// Look a single slug up in the (possibly just-refreshed) slug cache.
async fn cached_slug_lookup(subdomain: &str) -> Option<CachedSlugEntry> {
    get_cached_slugs().await.get(subdomain).cloned()
}
/// Daemon ids currently being auto-started. Prevents concurrent starts of
/// the same daemon while one request is already driving it up.
static AUTO_START_IN_PROGRESS: once_cell::sync::Lazy<
    tokio::sync::Mutex<std::collections::HashSet<DaemonId>>,
> = once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new(std::collections::HashSet::new()));
/// Outcome of mapping a request host to a backend daemon.
enum ResolveResult {
    /// A running daemon was found; proxy to this local port.
    Ready(u16),
    /// An auto-start for this slug is in flight; serve the wait page.
    Starting { slug: String },
    /// No slug/daemon matches the host.
    NotFound,
    /// Resolution failed; the message becomes the 502 body.
    Error(String),
}
/// Callback invoked with a human-readable message when proxying fails.
type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;
/// Shared state handed to the proxy request handler.
#[derive(Clone)]
struct ProxyState {
    // Pooled hyper client used for all upstream (backend) requests.
    client: Arc<Client<HttpConnector, Body>>,
    // Configured top-level domain, e.g. "localhost".
    tld: String,
    // Whether the client-facing side of the proxy is served over TLS.
    is_tls: bool,
    // Optional error hook; when unset, failures are only logged.
    on_error: Option<OnErrorFn>,
}
/// Entry point for the proxy server task.
///
/// Validates the configured port, builds the shared upstream client and the
/// router, then delegates to either the HTTPS accept loop (with plain-HTTP
/// fallback) or the plain-HTTP loop. The outcome of the initial bind is
/// reported over `bind_tx` so the caller can surface startup errors.
pub async fn serve(
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
) -> crate::Result<()> {
    let s = settings();
    // Reject 0 and anything that does not fit in a u16.
    let Some(effective_port) = u16::try_from(s.proxy.port).ok().filter(|&p| p > 0) else {
        let msg = format!(
            "proxy.port {} is out of valid port range (1-65535), proxy server cannot start",
            s.proxy.port
        );
        let _ = bind_tx.send(Err(msg.clone()));
        miette::bail!("{msg}");
    };
    let mut connector = HttpConnector::new();
    connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
    // One pooled client shared by all requests.
    let client = Client::builder(TokioExecutor::new())
        .pool_idle_timeout(std::time::Duration::from_secs(30))
        .build(connector);
    let state = ProxyState {
        client: Arc::new(client),
        tld: s.proxy.tld.clone(),
        is_tls: s.proxy.https,
        on_error: None,
    };
    // Every request falls through to the proxy handler.
    let app = Router::new().fallback(proxy_handler).with_state(state);
    let bind_ip: std::net::IpAddr = match s.proxy.host.parse() {
        Ok(ip) => ip,
        Err(_) => {
            // Invalid host config degrades to loopback instead of failing.
            log::warn!(
                "proxy.host {:?} is not a valid IP address — falling back to 127.0.0.1. \
                The proxy will only be reachable on the loopback interface.",
                s.proxy.host
            );
            std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
        }
    };
    let addr = SocketAddr::from((bind_ip, effective_port));
    if s.proxy.https {
        serve_https_with_http_fallback(app, addr, s, effective_port, bind_tx).await
    } else {
        serve_http(app, addr, effective_port, bind_tx).await
    }
}
async fn serve_http(
app: Router,
addr: SocketAddr,
effective_port: u16,
bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
) -> crate::Result<()> {
let listener = match TcpListener::bind(addr).await {
Ok(l) => {
let _ = bind_tx.send(Ok(()));
l
}
Err(e) => {
let msg = bind_error_message(effective_port, &e);
let _ = bind_tx.send(Err(msg.clone()));
return Err(miette::miette!("{msg}"));
}
};
log::info!("Proxy server listening on http://{addr}");
if effective_port < 1024 {
log::info!(
"Note: port {effective_port} is a privileged port. \
The supervisor must be started with sudo to bind to this port."
);
}
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.map_err(|e| miette::miette!("Proxy server error: {e}"))?;
Ok(())
}
/// HTTPS accept loop with transparent plain-HTTP fallback on the same port.
///
/// Generates a local CA on first use, then for each accepted connection
/// peeks one byte: a TLS handshake record is served through the TLS
/// acceptor (with per-SNI certificates from [`SniCertResolver`]), anything
/// else is served as plain HTTP. The bind outcome is reported over `bind_tx`.
#[cfg(feature = "proxy-tls")]
async fn serve_https_with_http_fallback(
    app: Router,
    addr: SocketAddr,
    s: &crate::settings::Settings,
    effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
) -> crate::Result<()> {
    use rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;
    let (ca_cert_path, ca_key_path) = resolve_tls_paths(s);
    // First run: create the local CA that signs per-host leaf certs.
    if !ca_cert_path.exists() || !ca_key_path.exists() {
        generate_ca(&ca_cert_path, &ca_key_path)?;
        log::info!(
            "Generated local CA certificate at {}",
            ca_cert_path.display()
        );
        log::info!("To trust the CA in your browser, run: pitchfork proxy trust");
    }
    // Install ring as the process-wide provider; the error (already
    // installed) is deliberately ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();
    let resolver = SniCertResolver::new(&ca_cert_path, &ca_key_path)?;
    let tls_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(resolver));
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
    let listener = match TcpListener::bind(addr).await {
        Ok(l) => {
            let _ = bind_tx.send(Ok(()));
            l
        }
        Err(e) => {
            let msg = bind_error_message(effective_port, &e);
            let _ = bind_tx.send(Err(msg.clone()));
            return Err(miette::miette!("{msg}"));
        }
    };
    log::info!("Proxy server listening on https://{addr} (HTTP also accepted)");
    if effective_port < 1024 {
        log::info!(
            "Note: port {effective_port} is a privileged port. \
            The supervisor must be started with sudo to bind to this port."
        );
    }
    loop {
        let (stream, _peer_addr) = match listener.accept().await {
            Ok(conn) => conn,
            Err(e) => {
                // Transient accept errors (e.g. fd exhaustion) must not kill
                // the loop; back off briefly and retry.
                log::warn!("Accept error (will retry): {e}");
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                continue;
            }
        };
        let acceptor = acceptor.clone();
        let app = app.clone();
        tokio::spawn(async move {
            // Peek one byte to sniff the protocol without consuming it.
            let mut peek_buf = [0u8; 1];
            match stream.peek(&mut peek_buf).await {
                // Peer closed before sending anything, or peek failed.
                Ok(0) | Err(_) => return,
                _ => {}
            }
            // 0x16 is the TLS handshake record type; anything else is
            // treated as plain HTTP on the same port.
            if peek_buf[0] == 0x16 {
                match acceptor.accept(stream).await {
                    Ok(tls_stream) => {
                        let io = hyper_util::rt::TokioIo::new(tls_stream);
                        let svc = hyper_util::service::TowerToHyperService::new(app);
                        let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                            .serve_connection_with_upgrades(io, svc)
                            .await;
                    }
                    Err(e) => {
                        log::debug!("TLS handshake error: {e}");
                    }
                }
            } else {
                let io = hyper_util::rt::TokioIo::new(stream);
                let svc = hyper_util::service::TowerToHyperService::new(app);
                let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                    .serve_connection_with_upgrades(io, svc)
                    .await;
            }
        });
    }
}
/// Stub used when pitchfork is built without the `proxy-tls` feature:
/// reports a configuration error over `bind_tx` instead of serving HTTPS.
#[cfg(not(feature = "proxy-tls"))]
async fn serve_https_with_http_fallback(
    _app: Router,
    _addr: SocketAddr,
    _s: &crate::settings::Settings,
    _effective_port: u16,
    bind_tx: tokio::sync::oneshot::Sender<std::result::Result<(), String>>,
) -> crate::Result<()> {
    let msg = String::from(
        "HTTPS proxy support requires the `proxy-tls` feature.\n\
        Rebuild pitchfork with: cargo build --features proxy-tls",
    );
    let _ = bind_tx.send(Err(msg.clone()));
    miette::bail!("{msg}")
}
/// Resolve the CA certificate and key paths from settings.
///
/// An empty configured path falls back to the default file name under the
/// pitchfork state directory's `proxy/` subdirectory.
#[cfg(feature = "proxy-tls")]
fn resolve_tls_paths(s: &crate::settings::Settings) -> (std::path::PathBuf, std::path::PathBuf) {
    let proxy_dir = crate::env::PITCHFORK_STATE_DIR.join("proxy");
    let pick = |configured: &str, fallback: &str| -> std::path::PathBuf {
        match configured {
            "" => proxy_dir.join(fallback),
            other => std::path::PathBuf::from(other),
        }
    };
    (
        pick(&s.proxy.tls_cert, "ca.pem"),
        pick(&s.proxy.tls_key, "ca-key.pem"),
    )
}
/// Generate a self-signed local CA certificate/key pair and write both files.
///
/// The key file is created with mode 0600 on Unix; on other platforms the
/// platform default permissions apply and a debug note is logged.
///
/// # Errors
/// Fails when the parent directory cannot be created, key generation or
/// self-signing fails, or either file cannot be written.
#[cfg(feature = "proxy-tls")]
pub fn generate_ca(cert_path: &std::path::Path, key_path: &std::path::Path) -> crate::Result<()> {
    use rcgen::{
        BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose,
    };
    if let Some(parent) = cert_path.parent() {
        std::fs::create_dir_all(parent)
            .map_err(|e| miette::miette!("Failed to create proxy cert directory: {e}"))?;
    }
    let mut params = CertificateParams::default();
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "Pitchfork Local CA");
    dn.push(DnType::OrganizationName, "Pitchfork");
    params.distinguished_name = dn;
    // Mark as a CA so it can sign per-host leaf certificates.
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| miette::miette!("Failed to generate CA key pair: {e}"))?;
    let ca_cert = params
        .self_signed(&key_pair)
        .map_err(|e| miette::miette!("Failed to self-sign CA certificate: {e}"))?;
    std::fs::write(cert_path, ca_cert.pem()).map_err(|e| {
        miette::miette!(
            "Failed to write CA certificate to {}: {e}",
            cert_path.display()
        )
    })?;
    {
        use std::io::Write;
        #[cfg(unix)]
        {
            // Restrict the private key to the owner (0600).
            use std::os::unix::fs::OpenOptionsExt;
            std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(key_path)
                .and_then(|mut f| f.write_all(key_pair.serialize_pem().as_bytes()))
                .map_err(|e| {
                    miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
                })?;
        }
        #[cfg(not(unix))]
        {
            std::fs::write(key_path, key_pair.serialize_pem()).map_err(|e| {
                miette::miette!("Failed to write CA key to {}: {e}", key_path.display())
            })?;
            log::debug!(
                "CA private key written to {} (file permissions are not restricted \
                on non-Unix platforms — consider restricting access manually)",
                key_path.display()
            );
        }
    }
    Ok(())
}
/// Resolves TLS server certificates per SNI host name, minting and caching a
/// CA-signed leaf certificate on first use of each domain.
#[cfg(feature = "proxy-tls")]
struct SniCertResolver {
    // Local CA used to sign per-host leaf certificates.
    issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    // Directory where generated leaf certs are persisted across restarts.
    host_certs_dir: std::path::PathBuf,
    // In-memory cache of ready-to-serve certified keys, keyed by domain.
    cache: std::sync::Mutex<std::collections::HashMap<String, Arc<rustls::sign::CertifiedKey>>>,
    // Domains whose certificate is currently being generated.
    pending: std::sync::Mutex<std::collections::HashSet<String>>,
    // Signalled when a pending generation completes (success or failure).
    pending_cv: std::sync::Condvar,
}
/// Manual Debug: the issuer key and cert cache are neither useful nor safe
/// to print, so only the type name is shown.
#[cfg(feature = "proxy-tls")]
impl std::fmt::Debug for SniCertResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SniCertResolver").finish_non_exhaustive()
    }
}
#[cfg(feature = "proxy-tls")]
impl SniCertResolver {
    /// Load the CA key/cert from disk and prepare the resolver.
    ///
    /// Also ensures the `host-certs` directory (next to the CA cert) exists;
    /// generated leaf certificates are persisted there.
    fn new(ca_cert_path: &std::path::Path, ca_key_path: &std::path::Path) -> crate::Result<Self> {
        let ca_key_pem = std::fs::read_to_string(ca_key_path)
            .map_err(|e| miette::miette!("Failed to read CA key {}: {e}", ca_key_path.display()))?;
        let ca_cert_pem = std::fs::read_to_string(ca_cert_path).map_err(|e| {
            miette::miette!("Failed to read CA cert {}: {e}", ca_cert_path.display())
        })?;
        // Cheap sanity check before handing the PEM to rcgen.
        if !ca_cert_pem.contains("BEGIN CERTIFICATE") {
            miette::bail!("CA cert file does not contain a valid PEM certificate");
        }
        let ca_key = rcgen::KeyPair::from_pem(&ca_key_pem)
            .map_err(|e| miette::miette!("Failed to parse CA key: {e}"))?;
        let issuer = rcgen::Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
            .map_err(|e| miette::miette!("Failed to parse CA cert: {e}"))?;
        let host_certs_dir = ca_cert_path
            .parent()
            .unwrap_or(std::path::Path::new("."))
            .join("host-certs");
        std::fs::create_dir_all(&host_certs_dir)
            .map_err(|e| miette::miette!("Failed to create host-certs dir: {e}"))?;
        Ok(Self {
            issuer,
            host_certs_dir,
            cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            pending: std::sync::Mutex::new(std::collections::HashSet::new()),
            pending_cv: std::sync::Condvar::new(),
        })
    }
    /// Fetch the certified key for `domain`, generating it if needed.
    ///
    /// Concurrent requests for the same domain are deduplicated: one caller
    /// generates while the others block on the condvar and then re-check the
    /// cache. Returns `None` when a lock is poisoned or generation fails.
    /// NOTE(review): this does blocking disk I/O and signing on the calling
    /// thread — it runs inside rustls' synchronous resolver callback.
    fn get_or_create(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        {
            // Fast path: already cached.
            let cache = self.cache.lock().ok()?;
            if let Some(ck) = cache.get(domain) {
                return Some(Arc::clone(ck));
            }
        }
        loop {
            {
                let mut pending = self.pending.lock().ok()?;
                if pending.contains(domain) {
                    // Another thread is generating; wait for it to finish.
                    pending = self.pending_cv.wait(pending).ok()?;
                    drop(pending);
                } else {
                    // We won the race; mark the domain as in progress.
                    pending.insert(domain.to_string());
                    break;
                }
            }
            {
                // Woken up: the generator may have populated the cache. If
                // not (generation failed), loop and try to generate ourselves.
                let cache = self.cache.lock().ok()?;
                if let Some(ck) = cache.get(domain) {
                    return Some(Arc::clone(ck));
                }
            }
        }
        let result = self.get_or_create_inner(domain);
        {
            // Always clear the pending marker — even on failure or poisoned
            // lock — so waiters can retry instead of blocking forever.
            let mut pending = match self.pending.lock() {
                Ok(g) => g,
                Err(e) => e.into_inner(),
            };
            pending.remove(domain);
            self.pending_cv.notify_all();
        }
        result
    }
    /// Load a persisted cert for `domain` from disk, or sign a fresh one.
    /// Successful results are inserted into the in-memory cache.
    fn get_or_create_inner(&self, domain: &str) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
        if disk_path.exists() {
            if let Ok(ck) = self.load_from_disk(&disk_path) {
                let ck = Arc::new(ck);
                if let Ok(mut cache) = self.cache.lock() {
                    cache.insert(domain.to_string(), Arc::clone(&ck));
                }
                return Some(ck);
            }
            // Unreadable or expired: discard so we regenerate below.
            let _ = std::fs::remove_file(&disk_path);
        }
        let ck = self.sign_for_domain(domain).ok()?;
        let ck = Arc::new(ck);
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(domain.to_string(), Arc::clone(&ck));
        }
        Some(ck)
    }
    /// Parse a combined cert+key PEM file into a rustls `CertifiedKey`.
    ///
    /// # Errors
    /// Fails when the PEM is unreadable/unparsable or the certificate has
    /// already expired (so the caller knows to regenerate it).
    fn load_from_disk(&self, path: &std::path::Path) -> crate::Result<rustls::sign::CertifiedKey> {
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::{certs, private_key};
        let pem = std::fs::read_to_string(path)
            .map_err(|e| miette::miette!("Failed to read disk cert {}: {e}", path.display()))?;
        let cert_ders: Vec<CertificateDer<'static>> = certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| miette::miette!("Failed to parse certs from {}: {e}", path.display()))?;
        if cert_ders.is_empty() {
            miette::bail!("No certificates found in {}", path.display());
        }
        {
            // Reject expired cached certs so they get re-signed.
            let (_, cert) = x509_parser::parse_x509_certificate(&cert_ders[0]).map_err(|e| {
                miette::miette!("Failed to parse certificate from {}: {e}", path.display())
            })?;
            use chrono::Utc;
            let now_ts = Utc::now().timestamp();
            let not_after_ts = cert.validity().not_after.timestamp();
            if not_after_ts < now_ts {
                miette::bail!(
                    "Cached certificate at {} has expired — will regenerate",
                    path.display()
                );
            }
        }
        let key_der = private_key(&mut pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse key from {}: {e}", path.display()))?
            .ok_or_else(|| miette::miette!("No private key found in {}", path.display()))?;
        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key from disk: {e}"))?;
        Ok(rustls::sign::CertifiedKey::new(cert_ders, signing_key))
    }
    /// Sign a new leaf certificate for `domain` with the local CA and
    /// persist it (cert + key PEM) under `host_certs_dir`.
    ///
    /// Validity runs from yesterday (clock-skew slack) to 397 days out.
    /// A wildcard SAN for the parent domain is added when the parent itself
    /// contains a dot (i.e. has at least two labels).
    fn sign_for_domain(&self, domain: &str) -> crate::Result<rustls::sign::CertifiedKey> {
        use rcgen::date_time_ymd;
        use rcgen::{CertificateParams, DistinguishedName, DnType, SanType};
        use rustls::pki_types::CertificateDer;
        use rustls_pemfile::private_key;
        let mut params = CertificateParams::default();
        let mut dn = DistinguishedName::new();
        dn.push(DnType::CommonName, domain);
        params.distinguished_name = dn;
        {
            use chrono::{Datelike, Duration, Utc};
            // Backdate one day to tolerate small clock skew between machines.
            let yesterday = Utc::now() - Duration::days(1);
            let expiry = Utc::now() + Duration::days(397);
            params.not_before = date_time_ymd(
                yesterday.year(),
                yesterday.month() as u8,
                yesterday.day() as u8,
            );
            params.not_after =
                date_time_ymd(expiry.year(), expiry.month() as u8, expiry.day() as u8);
        }
        let mut sans =
            vec![SanType::DnsName(domain.to_string().try_into().map_err(
                |e| miette::miette!("Invalid domain name '{domain}': {e}"),
            )?)];
        // Add "*.parent" when the parent has at least two labels, so sibling
        // subdomains can reuse this certificate.
        if let Some(dot_pos) = domain.find('.') {
            let parent = &domain[dot_pos + 1..];
            if parent.contains('.') {
                let wildcard = format!("*.{parent}");
                if let Ok(wc) = wildcard.try_into() {
                    sans.push(SanType::DnsName(wc));
                }
            }
        }
        params.subject_alt_names = sans;
        let leaf_key = rcgen::KeyPair::generate()
            .map_err(|e| miette::miette!("Failed to generate leaf key: {e}"))?;
        let leaf_cert = params
            .signed_by(&leaf_key, &self.issuer)
            .map_err(|e| miette::miette!("Failed to sign leaf cert for '{domain}': {e}"))?;
        let cert_der = CertificateDer::from(leaf_cert.der().to_vec());
        let key_pem = leaf_key.serialize_pem();
        let key_der = private_key(&mut key_pem.as_bytes())
            .map_err(|e| miette::miette!("Failed to parse leaf key PEM: {e}"))?
            .ok_or_else(|| miette::miette!("No private key found in generated PEM"))?;
        let signing_key = rustls::crypto::ring::sign::any_supported_type(&key_der)
            .map_err(|e| miette::miette!("Failed to create signing key: {e}"))?;
        let safe_name = domain.replace('.', "_").replace('*', "wildcard");
        let disk_path = self.host_certs_dir.join(format!("{safe_name}.pem"));
        let combined_pem = format!("{}{}", leaf_cert.pem(), key_pem);
        {
            use std::io::Write;
            #[cfg(unix)]
            {
                // Persist owner-only (0600); failure to persist is non-fatal
                // because the cert is already usable from memory.
                use std::os::unix::fs::OpenOptionsExt;
                if let Err(e) = std::fs::OpenOptions::new()
                    .write(true)
                    .create(true)
                    .truncate(true)
                    .mode(0o600)
                    .open(&disk_path)
                    .and_then(|mut f| f.write_all(combined_pem.as_bytes()))
                {
                    log::warn!(
                        "Failed to persist cert for '{domain}' to {}: {e}",
                        disk_path.display()
                    );
                }
            }
            #[cfg(not(unix))]
            {
                if let Err(e) = std::fs::write(&disk_path, combined_pem) {
                    log::warn!(
                        "Failed to persist cert for '{domain}' to {}: {e}",
                        disk_path.display()
                    );
                } else {
                    log::debug!(
                        "Leaf cert for '{domain}' written to {} (file permissions are not \
                        restricted on non-Unix platforms — consider restricting access manually)",
                        disk_path.display()
                    );
                }
            }
        }
        Ok(rustls::sign::CertifiedKey::new(vec![cert_der], signing_key))
    }
}
/// rustls hook: pick (or mint) the certificate matching the client's SNI.
/// Returns `None` — aborting the handshake — when no server name was sent
/// or certificate generation failed.
#[cfg(feature = "proxy-tls")]
impl rustls::server::ResolvesServerCert for SniCertResolver {
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let domain = client_hello.server_name()?;
        self.get_or_create(domain)
    }
}
/// Determine the host the client addressed.
///
/// Prefers a non-empty URI authority (HTTP/2 `:authority`), falling back to
/// the `Host` header for HTTP/1.1. Returns `None` when neither is usable.
fn get_request_host(req: &Request) -> Option<String> {
    if let Some(authority) = req.uri().authority() {
        let s = authority.as_str();
        if !s.is_empty() {
            return Some(s.to_string());
        }
    }
    let header = req.headers().get(HOST)?;
    header.to_str().ok().map(str::to_string)
}
/// Overwrite the `X-Forwarded-*` headers on `req` with values describing the
/// client-facing side of this proxy.
///
/// Client-supplied forwarding headers (including `Forwarded`) are stripped
/// first so a caller cannot spoof its origin. The port is taken from the
/// `Host` header when it carries an explicit numeric port, otherwise the
/// scheme default (443/80) is used.
fn inject_forwarded_headers(req: &mut Request, is_tls: bool, host_header: &str) {
    // Peer IP recorded by axum's connect-info; loopback when missing.
    let remote_addr = req
        .extensions()
        .get::<axum::extract::ConnectInfo<SocketAddr>>()
        .map(|ci| ci.0.ip().to_string())
        .unwrap_or_else(|| "127.0.0.1".to_string());
    let proto = if is_tls { "https" } else { "http" };
    let default_port = if is_tls { "443" } else { "80" };
    let forwarded_for = remote_addr;
    let forwarded_proto = proto.to_string();
    let forwarded_host = host_header.to_string();
    // Extract an explicit port from the Host header. A bracketed IPv6
    // literal without a port (e.g. "[::1]") must not have its last colon
    // mistaken for a port separator, so only an all-digit suffix counts.
    let forwarded_port = if host_header.ends_with(']') {
        default_port.to_string()
    } else {
        host_header
            .rsplit_once(':')
            .map(|(_, port)| port)
            .filter(|port| !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()))
            .map(str::to_string)
            .unwrap_or_else(|| default_port.to_string())
    };
    // Drop anything the client sent so the values below are authoritative.
    // Note: `forwarded` is removed and deliberately not re-added.
    for name in [
        "x-forwarded-for",
        "x-forwarded-proto",
        "x-forwarded-host",
        "x-forwarded-port",
        "forwarded",
    ] {
        if let Ok(header_name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
            req.headers_mut().remove(&header_name);
        }
    }
    let headers = [
        ("x-forwarded-for", forwarded_for),
        ("x-forwarded-proto", forwarded_proto),
        ("x-forwarded-host", forwarded_host),
        ("x-forwarded-port", forwarded_port),
    ];
    for (name, value) in headers {
        if let Ok(v) = HeaderValue::from_str(&value) {
            let header_name = axum::http::HeaderName::from_static(name);
            req.headers_mut().insert(header_name, v);
        }
    }
}
/// Main request handler: resolve the Host header to a local daemon port and
/// forward the request, streaming the response back.
///
/// Also handles proxy-loop detection (hop-count header), the self-refreshing
/// "starting" page during auto-start, Host/X-Forwarded-* rewriting, and
/// WebSocket/connection-upgrade tunnelling.
async fn proxy_handler(State(state): State<ProxyState>, mut req: Request) -> Response {
    let Some(raw_host) = get_request_host(&req) else {
        return error_response(StatusCode::BAD_REQUEST, "Missing Host header");
    };
    // Strip any port (and brackets for IPv6 literals) to get the bare host.
    let host = if raw_host.starts_with('[') {
        raw_host
            .split("]:")
            .next()
            .unwrap_or(&raw_host)
            .trim_start_matches('[')
            .trim_end_matches(']')
            .to_string()
    } else {
        raw_host.split(':').next().unwrap_or(&raw_host).to_string()
    };
    // Requests that already passed through pitchfork carry the hops header;
    // too many traversals indicates a proxy loop.
    let is_from_pitchfork = req.headers().contains_key(PROXY_HOPS_HEADER);
    let hops: u64 = if is_from_pitchfork {
        req.headers()
            .get(PROXY_HOPS_HEADER)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse().ok())
            .unwrap_or(0)
    } else {
        0
    };
    if hops >= MAX_PROXY_HOPS {
        return error_response(
            StatusCode::LOOP_DETECTED,
            &format!(
                "Loop detected for '{host}': request has passed through the proxy {hops} times.\n\
                This usually means a backend is proxying back through pitchfork without rewriting \n\
                the Host header. If you use Vite/webpack proxy, set changeOrigin: true."
            ),
        );
    }
    // Resolve the backend port (may trigger auto-start).
    let target_port = match resolve_target(&host, &state.tld).await {
        ResolveResult::Ready(port) => port,
        ResolveResult::Starting { slug } => {
            return starting_html_response(&slug, &raw_host);
        }
        ResolveResult::NotFound => {
            return error_response(
                StatusCode::BAD_GATEWAY,
                &format!(
                    "No daemon found for host '{host}'.\n\
                    Make sure the daemon has a slug, is running, and has a port configured.\n\
                    Expected format: <slug>.{tld}",
                    tld = state.tld
                ),
            );
        }
        ResolveResult::Error(msg) => {
            return error_response(StatusCode::BAD_GATEWAY, &msg);
        }
    };
    // Rewrite the request URI to point at the local backend.
    let path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or("/");
    let forward_uri = match Uri::builder()
        .scheme("http")
        .authority(format!("localhost:{target_port}"))
        .path_and_query(path_and_query)
        .build()
    {
        Ok(uri) => uri,
        Err(e) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("Failed to build forward URI: {e}"),
            );
        }
    };
    *req.uri_mut() = forward_uri;
    // The backend sees a localhost Host header so it cannot re-enter the
    // proxy via the original host.
    req.headers_mut().insert(
        HOST,
        HeaderValue::from_str(&format!("localhost:{target_port}"))
            .unwrap_or_else(|_| HeaderValue::from_static("localhost")),
    );
    inject_forwarded_headers(&mut req, state.is_tls, &raw_host);
    // Increment the hop counter for loop detection downstream.
    if let Ok(v) = HeaderValue::from_str(&(hops + 1).to_string()) {
        req.headers_mut()
            .insert(axum::http::HeaderName::from_static(PROXY_HOPS_HEADER), v);
    }
    // HTTP/2 pseudo-headers (names starting with ':') must not be forwarded
    // as regular headers.
    let pseudo_headers: Vec<_> = req
        .headers()
        .keys()
        .filter(|k| k.as_str().starts_with(':'))
        .cloned()
        .collect();
    for key in pseudo_headers {
        req.headers_mut().remove(&key);
    }
    // Capture the client-side upgrade future before the request is consumed,
    // in case the backend answers 101 Switching Protocols.
    let client_upgrade = hyper::upgrade::on(&mut req);
    let result = match tokio::time::timeout(
        std::time::Duration::from_secs(120),
        state.client.request(req),
    )
    .await
    {
        Ok(r) => r,
        Err(_elapsed) => {
            let msg = format!(
                "Request to daemon on port {target_port} timed out after 120 s.\n\
                The daemon accepted the connection but did not respond in time."
            );
            log::warn!("{msg}");
            if let Some(ref on_error) = state.on_error {
                on_error(&msg);
            }
            return error_response(StatusCode::GATEWAY_TIMEOUT, &msg);
        }
    };
    match result {
        Ok(mut resp) => {
            let backend_upgrade = hyper::upgrade::on(&mut resp);
            let (mut parts, body) = resp.into_parts();
            // Tag the response so tooling can tell it came via pitchfork.
            parts.headers.insert(
                axum::http::HeaderName::from_static(PITCHFORK_HEADER),
                HeaderValue::from_static("1"),
            );
            parts.headers.remove(PROXY_HOPS_HEADER);
            // Over TLS, hop-by-hop headers describe the backend connection
            // and must not leak to the client — except on 101, where the
            // `connection`/`upgrade` headers are load-bearing.
            if state.is_tls && parts.status != StatusCode::SWITCHING_PROTOCOLS {
                for h in HOP_BY_HOP_HEADERS {
                    if let Ok(name) = axum::http::HeaderName::from_bytes(h.as_bytes()) {
                        parts.headers.remove(&name);
                    }
                }
            }
            if parts.status == StatusCode::SWITCHING_PROTOCOLS {
                // Upgrade accepted: splice the two raw connections together
                // and copy bytes in both directions until either side closes.
                tokio::spawn(async move {
                    if let (Ok(client_upgraded), Ok(backend_upgraded)) =
                        (client_upgrade.await, backend_upgrade.await)
                    {
                        let mut client_io = hyper_util::rt::TokioIo::new(client_upgraded);
                        let mut backend_io = hyper_util::rt::TokioIo::new(backend_upgraded);
                        let _ =
                            tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await;
                    }
                });
                return Response::from_parts(parts, Body::empty());
            }
            Response::from_parts(parts, Body::new(body))
        }
        Err(e) => {
            let msg = format!(
                "Failed to connect to daemon on port {target_port}: {e}\n\
                The daemon may have stopped or is not yet ready."
            );
            if let Some(ref on_error) = state.on_error {
                on_error(&msg);
            } else {
                log::warn!("{msg}");
            }
            error_response(StatusCode::BAD_GATEWAY, &msg)
        }
    }
}
async fn resolve_target(host: &str, tld: &str) -> ResolveResult {
let Some(subdomain) = strip_tld(host, tld) else {
return ResolveResult::NotFound;
};
let Some(cached) = cached_slug_lookup(&subdomain).await else {
return ResolveResult::NotFound;
};
let daemon_name = &cached.daemon_name;
let expected_namespace = &cached.namespace;
let daemons = {
let state_file = SUPERVISOR.state_file.lock().await;
state_file.daemons.clone()
};
let running_matches: Vec<(&DaemonId, &crate::daemon::Daemon)> = daemons
.iter()
.filter(|(id, d)| {
id.name() == daemon_name
&& d.status.is_running()
&& match expected_namespace {
Some(ns) => id.namespace() == ns,
None => true,
}
})
.collect();
match running_matches.as_slice() {
[] => {
try_auto_start(&subdomain, &cached).await
}
[(_, d)] => {
if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
ResolveResult::Ready(port)
} else {
ResolveResult::NotFound
}
}
_ => {
let d = running_matches[0].1;
if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
ResolveResult::Ready(port)
} else {
ResolveResult::NotFound
}
}
}
}
/// Guard that removes a daemon id from [`AUTO_START_IN_PROGRESS`] when the
/// auto-start attempt finishes — on success, failure, or timeout.
struct AutoStartGuard {
    daemon_id: DaemonId,
}
impl Drop for AutoStartGuard {
    fn drop(&mut self) {
        let daemon_id = self.daemon_id.clone();
        // Drop cannot be async, so removal is deferred to a spawned task.
        // NOTE(review): tokio::spawn panics outside a runtime; the guard is
        // created inside async request handling — confirm if that changes.
        tokio::spawn(async move {
            AUTO_START_IN_PROGRESS.lock().await.remove(&daemon_id);
        });
    }
}
/// Attempt to start the daemon behind `slug` on demand.
///
/// Returns `Starting` when another request is already driving the same
/// daemon up, `NotFound` when auto-start is disabled or the daemon id is
/// invalid, otherwise the result of [`try_auto_start_inner`] bounded by the
/// configured `proxy_auto_start_timeout`.
async fn try_auto_start(slug: &str, cached: &CachedSlugEntry) -> ResolveResult {
    let s = settings();
    if !s.proxy.auto_start {
        return ResolveResult::NotFound;
    }
    // Slugs without a derivable namespace fall into "global".
    let ns = cached
        .namespace
        .clone()
        .unwrap_or_else(|| "global".to_string());
    let daemon_id = match DaemonId::try_new(&ns, &cached.daemon_name) {
        Ok(id) => id,
        Err(_) => return ResolveResult::NotFound,
    };
    {
        // Only one auto-start per daemon; late arrivals get the wait page.
        let mut in_progress = AUTO_START_IN_PROGRESS.lock().await;
        if !in_progress.insert(daemon_id.clone()) {
            return ResolveResult::Starting {
                slug: slug.to_string(),
            };
        }
    }
    // Clears the in-progress marker on every exit path below.
    let _guard = AutoStartGuard {
        daemon_id: daemon_id.clone(),
    };
    let timeout = s.proxy_auto_start_timeout();
    match tokio::time::timeout(timeout, try_auto_start_inner(slug, cached, &daemon_id)).await {
        Ok(result) => result,
        Err(_elapsed) => {
            log::warn!("Auto-start: total timeout ({timeout:?}) exceeded for daemon {daemon_id}");
            ResolveResult::Error(format!(
                "Auto-start for '{daemon_id}' timed out after {timeout:?}.\n\
                The daemon did not become ready and bind a port within the configured \
                proxy_auto_start_timeout.\n\
                Increase the timeout or check the daemon's logs for slow startup."
            ))
        }
    }
}
/// Load the daemon's config, start it via the supervisor, then poll until it
/// is running with a port bound (or has failed).
///
/// The polling loop has no exit of its own on the "still starting" path; the
/// caller bounds the whole future with `tokio::time::timeout`.
async fn try_auto_start_inner(
    slug: &str,
    cached: &CachedSlugEntry,
    daemon_id: &DaemonId,
) -> ResolveResult {
    let pt = match crate::pitchfork_toml::PitchforkToml::all_merged_from(&cached.dir) {
        Ok(pt) => pt,
        Err(e) => {
            log::warn!(
                "Auto-start: failed to load config from {}: {e}",
                cached.dir.display()
            );
            return ResolveResult::NotFound;
        }
    };
    let daemon_config = match pt.daemons.get(daemon_id) {
        Some(cfg) => cfg,
        None => {
            log::debug!(
                "Auto-start: daemon {daemon_id} not found in config at {}",
                cached.dir.display()
            );
            return ResolveResult::NotFound;
        }
    };
    let opts = crate::ipc::batch::StartOptions::default();
    let mut run_opts = match crate::ipc::batch::build_run_options(daemon_id, daemon_config, &opts) {
        Ok(o) => o,
        Err(e) => {
            log::warn!("Auto-start: failed to build run options for {daemon_id}: {e}");
            return ResolveResult::Error(format!("Failed to build run options: {e}"));
        }
    };
    // Fall back to the slug's project directory when none was configured.
    if run_opts.dir.as_os_str().is_empty() {
        run_opts.dir = cached.dir.clone();
    }
    log::info!("Auto-start: starting daemon {daemon_id} for slug '{slug}'");
    let run_result = SUPERVISOR.run(run_opts).await;
    if let Err(e) = run_result {
        log::warn!("Auto-start: failed to start daemon {daemon_id}: {e}");
        return ResolveResult::Error(format!("Failed to start daemon: {e}"));
    }
    // Poll the supervisor state until the daemon is running with a port.
    let poll_interval = std::time::Duration::from_millis(250);
    loop {
        let daemons = {
            let sf = SUPERVISOR.state_file.lock().await;
            sf.daemons.clone()
        };
        if let Some(d) = daemons.get(daemon_id) {
            if d.status.is_running() {
                // Running but port not yet bound → keep polling.
                if let Some(port) = d.active_port.or_else(|| d.resolved_port.first().copied()) {
                    log::info!("Auto-start: daemon {daemon_id} is ready on port {port}");
                    return ResolveResult::Ready(port);
                }
            } else {
                log::warn!(
                    "Auto-start: daemon {daemon_id} is no longer running (status: {})",
                    d.status
                );
                return ResolveResult::Error(format!(
                    "Daemon '{daemon_id}' started but exited unexpectedly.\n\
                    Check its logs for errors."
                ));
            }
        } else {
            log::warn!("Auto-start: daemon {daemon_id} not found in state file after start");
            return ResolveResult::Error(format!(
                "Daemon '{daemon_id}' started but disappeared from the state file.\n\
                Check its logs for errors."
            ));
        }
        tokio::time::sleep(poll_interval).await;
    }
}
/// Strip a trailing `.{tld}` from `host`, returning the remaining subdomain.
///
/// Returns `None` when the host does not end in the TLD, or when nothing is
/// left after stripping (i.e. the host *is* the TLD).
fn strip_tld(host: &str, tld: &str) -> Option<String> {
    let suffix = format!(".{tld}");
    match host.strip_suffix(&suffix) {
        Some(rest) if !rest.is_empty() => Some(rest.to_string()),
        _ => None,
    }
}
/// Compose a user-facing error for a failed proxy bind. The appended hint
/// depends on whether the port is privileged (< 1024).
fn bind_error_message(port: u16, err: &std::io::Error) -> String {
    let hint = if port < 1024 {
        "Hint: ports below 1024 require elevated privileges. \
        Try: sudo pitchfork supervisor start"
    } else {
        "Hint: another process may already be using this port."
    };
    format!("Failed to bind proxy server to port {port}: {err}\n{hint}")
}
fn starting_html_response(slug: &str, raw_host: &str) -> Response {
let escaped_slug = slug
.replace('&', "&")
.replace('<', "<")
.replace('>', ">")
.replace('"', """)
.replace('\'', "'");
let escaped_host = raw_host
.replace('&', "&")
.replace('<', "<")
.replace('>', ">")
.replace('"', """)
.replace('\'', "'");
let html = format!(
r##"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta http-equiv="refresh" content="2">
<title>Starting {escaped_slug}… — pitchfork</title>
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
background: #0f1117;
color: #e1e4e8;
display: flex;
align-items: center;
justify-content: center;
min-height: 100vh;
}}
.container {{
text-align: center;
max-width: 480px;
padding: 2rem;
}}
.spinner {{
width: 48px;
height: 48px;
border: 4px solid rgba(255, 255, 255, 0.1);
border-top-color: #58a6ff;
border-radius: 50%;
animation: spin 0.8s linear infinite;
margin: 0 auto 1.5rem;
}}
@keyframes spin {{
to {{ transform: rotate(360deg); }}
}}
h1 {{
font-size: 1.5rem;
font-weight: 600;
margin-bottom: 0.5rem;
}}
.slug {{
color: #58a6ff;
font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
}}
.host {{
color: #8b949e;
font-size: 0.875rem;
margin-top: 0.25rem;
}}
.hint {{
color: #8b949e;
font-size: 0.8rem;
margin-top: 1.5rem;
}}
</style>
</head>
<body>
<div class="container">
<div class="spinner"></div>
<h1>Starting <span class="slug">{escaped_slug}</span>…</h1>
<p class="host">{escaped_host}</p>
<p class="hint">This page will refresh automatically when the daemon is ready.</p>
</div>
</body>
</html>"##
);
Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.header("content-type", "text/html; charset=utf-8")
.header("retry-after", "2")
.body(Body::from(html))
.unwrap_or_else(|_| (StatusCode::SERVICE_UNAVAILABLE, "Starting…").into_response())
}
/// Turn a status code and plain-text message into a response.
fn error_response(status: StatusCode, message: &str) -> Response {
    let body = message.to_string();
    (status, body).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    /// `strip_tld` removes exactly one trailing ".{tld}" label and rejects
    /// hosts equal to the TLD or with a different suffix.
    #[test]
    fn test_strip_tld() {
        assert_eq!(
            strip_tld("api.myproject.localhost", "localhost"),
            Some("api.myproject".to_string())
        );
        assert_eq!(
            strip_tld("api.localhost", "localhost"),
            Some("api".to_string())
        );
        assert_eq!(strip_tld("localhost", "localhost"), None);
        assert_eq!(
            strip_tld("api.myproject.test", "test"),
            Some("api.myproject".to_string())
        );
        assert_eq!(strip_tld("other.com", "localhost"), None);
    }
    /// End-to-end check that CA generation writes a PEM certificate and a
    /// PEM private key to the requested paths.
    #[cfg(feature = "proxy-tls")]
    #[test]
    fn test_generate_ca() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("ca.pem");
        let key_path = dir.path().join("ca-key.pem");
        generate_ca(&cert_path, &key_path).unwrap();
        assert!(cert_path.exists(), "ca.pem should be created");
        assert!(key_path.exists(), "ca-key.pem should be created");
        let cert_pem = std::fs::read_to_string(&cert_path).unwrap();
        let key_pem = std::fs::read_to_string(&key_path).unwrap();
        assert!(cert_pem.contains("BEGIN CERTIFICATE"), "should be PEM cert");
        assert!(
            key_pem.contains("BEGIN") && key_pem.contains("PRIVATE KEY"),
            "should be PEM key"
        );
    }
}