folk-plugin-http 0.2.3

HTTP plugin for Folk — accepts connections via hyper and dispatches to PHP workers
Documentation
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use anyhow::Result;
use axum::Router;
use axum::body::Body;
use axum::extract::{ConnectInfo, State};
use axum::http::{self, Request, Response};
use axum::routing::any;
use ipnet::IpNet;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tower::ServiceBuilder;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer;
use tracing::{error, info};

use crate::config::HttpConfig;
use crate::payload::{decode_response, encode_request};

/// State shared by every handler invocation.
///
/// Cloning is cheap: every field is reference-counted.
#[derive(Clone)]
struct AppState {
    /// Server configuration read per request (body limit, read timeout,
    /// trusted proxies, access-log flag).
    config: Arc<HttpConfig>,
    /// Executes the `http.handle` call that turns an encoded request into a response.
    executor: Arc<dyn folk_api::Executor>,
    /// Number of requests currently inside `handle`.
    active_connections: Arc<AtomicU64>,
}

/// HTTP front-end that forwards every request to `executor`.
///
/// Build it with [`HttpServer::new`] and start it with [`HttpServer::run`].
pub struct HttpServer {
    /// Executor that request handling is dispatched to.
    executor: Arc<dyn folk_api::Executor>,
    /// Shared in-flight request counter, also visible to whoever constructed the server.
    active_connections: Arc<AtomicU64>,
    /// Listener, limits, timeouts, TLS/h2c and compression settings.
    config: HttpConfig,
}

impl HttpServer {
    /// Creates a server from its configuration, the executor requests are
    /// dispatched to, and a shared counter of in-flight requests.
    pub fn new(
        config: HttpConfig,
        executor: Arc<dyn folk_api::Executor>,
        active_connections: Arc<AtomicU64>,
    ) -> Self {
        Self {
            config,
            executor,
            active_connections,
        }
    }

    /// Builds the router and serves it until `shutdown` becomes `true` (or its
    /// sender is dropped).
    ///
    /// Every path, including `/`, is routed to `handle`. The request-body limit
    /// and a whole-request timeout (answered with 504 Gateway Timeout) are always
    /// applied; compression only when `config.compression.enabled` is set.
    ///
    /// Transport selection: TLS (feature `tls`, when `config.tls` is set) wins
    /// over h2c (feature `h2c`, when `config.h2c` is true); otherwise plain
    /// `axum::serve` is used.
    ///
    /// # Errors
    /// Returns an error if the listen address cannot be bound, the TLS
    /// certificate/key cannot be loaded, or the underlying server fails.
    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
        let state = AppState {
            executor: self.executor.clone(),
            config: Arc::new(self.config.clone()),
            active_connections: self.active_connections.clone(),
        };

        let mut app = Router::new()
            .route("/{*path}", any(handle))
            .route("/", any(handle))
            .with_state(state)
            .layer(
                ServiceBuilder::new()
                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
                    .layer(TimeoutLayer::with_status_code(
                        http::StatusCode::GATEWAY_TIMEOUT,
                        self.config.write_timeout,
                    )),
            );

        if self.config.compression.enabled {
            app = app.layer(build_compression_layer(&self.config.compression));
        }

        #[cfg(feature = "tls")]
        if let Some(ref tls) = self.config.tls {
            return self.run_tls(app, tls, shutdown).await;
        }

        #[cfg(feature = "h2c")]
        if self.config.h2c {
            return self.run_h2c(app, shutdown).await;
        }

        self.run_plain(app, shutdown).await
    }

    /// Serves unencrypted HTTP through `axum::serve`, exposing
    /// `ConnectInfo<SocketAddr>` to handlers and shutting down gracefully once
    /// `shutdown_signal` resolves.
    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
        let listener = TcpListener::bind(self.config.listen).await?;

        axum::serve(
            listener,
            app.into_make_service_with_connect_info::<SocketAddr>(),
        )
        .with_graceful_shutdown(shutdown_signal(shutdown))
        .await?;

        Ok(())
    }

    /// Serves HTTPS via `axum_server` + rustls, loading the PEM certificate and
    /// key from `tls.cert` / `tls.key`.
    ///
    /// A background task waits for `shutdown_signal` and then asks the server
    /// handle to shut down gracefully with no deadline (`None`).
    #[cfg(feature = "tls")]
    async fn run_tls(
        &self,
        app: Router,
        tls: &crate::config::TlsConfig,
        shutdown: watch::Receiver<bool>,
    ) -> Result<()> {
        use axum_server::Handle;
        use axum_server::tls_rustls::RustlsConfig;

        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;

        info!(cert = %tls.cert.display(), "TLS enabled");

        let handle = Handle::new();
        let shutdown_handle = handle.clone();
        tokio::spawn(async move {
            shutdown_signal(shutdown).await;
            shutdown_handle.graceful_shutdown(None);
        });

        axum_server::bind_rustls(self.config.listen, rustls_config)
            .handle(handle)
            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
            .await?;

        Ok(())
    }

    /// Serves cleartext HTTP/1 and HTTP/2 (h2c) connections using hyper-util's
    /// auto-detecting connection builder.
    ///
    /// `axum::serve` is bypassed, so this loop does what it would otherwise do:
    /// inject `ConnectInfo`, survive transient accept errors, reap finished
    /// connection tasks, and drain connections gracefully on shutdown.
    #[cfg(feature = "h2c")]
    async fn run_h2c(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
        use hyper_util::rt::{TokioExecutor, TokioIo};
        use hyper_util::server::conn::auto::Builder as AutoBuilder;

        info!("h2c (HTTP/2 cleartext) enabled");

        let listener = TcpListener::bind(self.config.listen).await?;
        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
        let mut tasks = tokio::task::JoinSet::new();

        // A single stop future polled across iterations, instead of rebuilding
        // it every time a connection is accepted.
        let stop = shutdown_signal(shutdown.clone());
        tokio::pin!(stop);

        loop {
            tokio::select! {
                () = &mut stop => break,
                // A JoinSet keeps each completed task's result until it is joined;
                // reap them here so memory does not grow with every connection served.
                Some(_) = tasks.join_next(), if !tasks.is_empty() => {}
                result = listener.accept() => {
                    let (stream, remote_addr) = match result {
                        Ok(accepted) => accepted,
                        Err(e) => {
                            // Same policy as `axum::serve`: a failed handshake only
                            // affects that client; anything else (e.g. fd exhaustion)
                            // is logged and retried after a back-off instead of
                            // terminating the whole listener.
                            let per_connection = matches!(
                                e.kind(),
                                std::io::ErrorKind::ConnectionRefused
                                    | std::io::ErrorKind::ConnectionAborted
                                    | std::io::ErrorKind::ConnectionReset
                            );
                            if !per_connection {
                                error!(error = %e, "accept error");
                                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                            }
                            continue;
                        }
                    };
                    let app = app.clone();
                    let builder = builder.clone();
                    let conn_shutdown = shutdown.clone();
                    tasks.spawn(async move {
                        let svc = hyper::service::service_fn(
                            move |mut req: Request<hyper::body::Incoming>| {
                                // Inject ConnectInfo manually since we bypass axum::serve.
                                req.extensions_mut().insert(ConnectInfo(remote_addr));
                                let mut app = app.clone();
                                async move {
                                    tower::Service::call(&mut app, req)
                                        .await
                                        .map_err(|e| match e {})
                                }
                            },
                        );
                        let conn =
                            builder.serve_connection_with_upgrades(TokioIo::new(stream), svc);
                        tokio::pin!(conn);
                        tokio::select! {
                            _ = conn.as_mut() => {}
                            () = shutdown_signal(conn_shutdown) => {
                                // Let in-flight requests finish, refuse new ones on this
                                // connection, then wait for it to close. Without this,
                                // idle keep-alive / HTTP/2 connections would keep the
                                // drain below waiting forever.
                                conn.as_mut().graceful_shutdown();
                                let _ = conn.await;
                            }
                        }
                    });
                }
            }
        }

        // Every connection task observes the same shutdown signal and closes
        // gracefully; wait for all of them.
        while tasks.join_next().await.is_some() {}

        Ok(())
    }
}

/// Resolves once the shutdown flag is `true`, or once the sending side is dropped.
///
/// The current value is checked before waiting. A receiver that has already
/// observed `true` (for example one subscribed after shutdown was requested)
/// therefore returns immediately, instead of waiting for a further change
/// that will never arrive.
async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
    // `Err` means the sender was dropped; that is treated as a shutdown request too.
    let _ = shutdown.wait_for(|&stop| stop).await;
}

/// RAII guard that decrements the shared in-flight counter when dropped.
///
/// The matching increment is done by whoever creates the guard. Because the
/// decrement lives in `Drop`, it also runs when the owning future is cancelled.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}

/// Entry point for every route.
///
/// Marks the request as active, forwards it to the worker via `handle_inner`,
/// and, when `access_log` is enabled, emits one log line with the resolved
/// client IP, method, URI, status, latency and response `Content-Length`
/// (0 when absent).
async fn handle(
    State(state): State<AppState>,
    connect_info: ConnectInfo<SocketAddr>,
    req: Request<Body>,
) -> Response<Body> {
    state.active_connections.fetch_add(1, Ordering::Relaxed);
    // Decrements the counter on every exit path, including cancellation of this future.
    let _conn_guard = ConnectionGuard(state.active_connections.clone());
    let start = Instant::now();

    // Everything the access log needs must be captured before `req` is moved into
    // `handle_inner`; skip the work (and the header join allocation) when logging is off.
    let log_fields = state.config.access_log.then(|| {
        let client_ip = resolve_client_ip(
            connect_info.0.ip(),
            forwarded_for(req.headers()).as_deref(),
            &state.config.trusted_proxies,
        );
        (client_ip, req.method().clone(), req.uri().clone())
    });

    let response = handle_inner(&state, req).await;

    if let Some((client_ip, method, uri)) = log_fields {
        let duration = start.elapsed();
        let status = response.status().as_u16();
        let response_bytes = response
            .headers()
            .get(http::header::CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(0);
        info!(
            client_ip = %client_ip,
            method = %method,
            uri = %uri,
            status = status,
            duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX),
            response_bytes = response_bytes,
            "http request",
        );
    }

    // _conn_guard drop handles fetch_sub
    response
}

/// Joins every `X-Forwarded-For` header line, in order, into one comma-separated list.
///
/// `HeaderMap::get` returns only the *first* line. Proxies that add their own
/// separate line instead of appending (e.g. HAProxy `option forwardfor`) would
/// then leave the client-supplied line as the only one consulted, letting
/// clients spoof their IP. Multiple field lines are equivalent to one
/// comma-joined line (RFC 9110 §5.3). Lines that are not visible ASCII are
/// skipped; they can only come from the client-controlled left part of the chain.
fn forwarded_for(headers: &http::HeaderMap) -> Option<String> {
    let lines: Vec<&str> = headers
        .get_all("x-forwarded-for")
        .iter()
        .filter_map(|v| v.to_str().ok())
        .collect();
    (!lines.is_empty()).then(|| lines.join(","))
}

/// Encodes the request, dispatches it to the worker and decodes the reply.
///
/// Failure mapping:
/// - body not read within `read_timeout` → 408
/// - encoding failure → 500
/// - worker/executor failure → 502
/// - undecodable worker reply → 500
///
/// All failures except the timeout are logged at error level.
async fn handle_inner(state: &AppState, req: Request<Body>) -> Response<Body> {
    /// Plain-text response with the given status.
    fn plain_response(status: u16, message: &'static str) -> Response<Body> {
        Response::builder()
            .status(status)
            .body(Body::from(message))
            .unwrap()
    }

    let config = &state.config;
    let read_body = encode_request(req, config.max_request_size);
    let payload = match tokio::time::timeout(config.read_timeout, read_body).await {
        Err(_elapsed) => return plain_response(408, "request body read timeout"),
        Ok(Err(e)) => {
            error!(error = ?e, "encode request");
            return plain_response(500, "encode error");
        }
        Ok(Ok(encoded)) => encoded,
    };

    let reply = match state.executor.execute_value("http.handle", payload).await {
        Err(e) => {
            error!(error = ?e, "dispatch to worker");
            return plain_response(502, "worker error");
        }
        Ok(value) => value,
    };

    decode_response(reply).unwrap_or_else(|e| {
        error!(error = ?e, "decode response");
        plain_response(500, "decode error")
    })
}

/// Resolve the real client IP from X-Forwarded-For if the peer is a trusted proxy.
///
/// Walks the X-Forwarded-For chain from right to left, stopping at the first
/// IP that is NOT in a trusted subnet. This is the standard secure algorithm
/// (rightmost non-trusted).
pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
    if trusted.is_empty() {
        return peer_ip;
    }

    if !is_trusted(peer_ip, trusted) {
        return peer_ip;
    }

    let Some(xff) = xff else {
        return peer_ip;
    };

    let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();

    // Walk from right to left — the rightmost non-trusted IP is the client.
    for addr_str in addrs.iter().rev() {
        if let Ok(ip) = addr_str.parse::<IpAddr>() {
            if !is_trusted(ip, trusted) {
                return ip;
            }
        }
    }

    // All IPs in the chain are trusted — use peer IP.
    peer_ip
}

/// Returns `true` if `ip` falls inside any of the `trusted` networks.
///
/// On a dual-stack listener (`[::]`), IPv4 peers are reported as IPv4-mapped
/// IPv6 addresses (`::ffff:10.0.0.1`), which never match an IPv4 CIDR such as
/// `10.0.0.0/8`. Both the address as given and its canonical form are checked,
/// so IPv4 ranges keep working there. Explicitly configured mapped-IPv6 ranges
/// also still match.
fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
    let canonical = ip.to_canonical();
    trusted
        .iter()
        .any(|net| net.contains(&ip) || net.contains(&canonical))
}

/// Builds a compression layer that enables exactly the algorithms listed in
/// `config.algorithms` (all others disabled) and applies a `SizeAbove`
/// predicate with threshold `config.min_size`.
///
/// `SizeAbove` takes a `u16`, so a `min_size` larger than `u16::MAX` is clamped
/// to `u16::MAX` (65 535 bytes). The previous `as u16` cast silently wrapped
/// such values, e.g. 65 536 became 0 (compress everything).
///
/// NOTE(review): `compress_when(SizeAbove)` replaces tower-http's default
/// predicate, which (per tower-http docs) also skips gRPC, images and
/// `text/event-stream`. Confirm that compressing those content types is intended.
fn build_compression_layer(
    config: &crate::config::CompressionConfig,
) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
    use crate::config::CompressionAlgorithm;
    use tower_http::compression::CompressionLayer;

    // Start with everything off so only the configured algorithms are offered.
    let mut layer = CompressionLayer::new()
        .no_gzip()
        .no_br()
        .no_zstd()
        .no_deflate();

    for algo in &config.algorithms {
        layer = match algo {
            CompressionAlgorithm::Gzip => layer.gzip(true),
            CompressionAlgorithm::Br => layer.br(true),
            CompressionAlgorithm::Zstd => layer.zstd(true),
            CompressionAlgorithm::Deflate => layer.deflate(true),
        };
    }

    // Saturate rather than truncate: an oversized threshold should mean
    // "compress only large bodies", not wrap around to a tiny threshold.
    let min_size = u16::try_from(config.min_size).unwrap_or(u16::MAX);
    layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
}