pitchfork-cli 2.6.0

Daemons with DX
Documentation
use crate::Result;
use crate::settings::settings;
use axum::{
    Router,
    body::Body,
    http::{Method, Request, StatusCode},
    middleware::{self, Next},
    response::{Redirect, Response},
    routing::{get, post},
};
use std::net::SocketAddr;

use super::routes;
use super::static_files::static_handler;

/// Middleware that rejects `POST` requests lacking an `HX-Request` header.
///
/// htmx attaches `HX-Request` to every request it issues, whereas a plain
/// cross-origin HTML form submission cannot set custom headers. Custom headers
/// also trigger a CORS preflight. Requiring the header on `POST` therefore
/// blocks cross-origin form submissions.
///
/// Non-`POST` requests are forwarded unchanged. A `POST` without the header is
/// answered with `403 Forbidden` and never reaches the handler.
async fn csrf_protection(request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let missing_htmx_header =
        request.method() == Method::POST && !request.headers().contains_key("hx-request");

    if missing_htmx_header {
        return Err(StatusCode::FORBIDDEN);
    }

    Ok(next.run(request).await)
}

/// Starts the web UI server and serves it until the server stops.
///
/// `web_path` is an optional URL prefix, normalized by
/// `super::normalize_base_path`, under which every route is mounted. The
/// normalized prefix is stored in `super::BASE_PATH`. When the prefix is
/// non-empty, `/` redirects to `{prefix}/`.
///
/// Binding starts at `port` and moves on to the next port whenever a bind
/// fails. At most `web.port_attempts` ports are tried, and the search never
/// goes past port 65535.
///
/// # Errors
///
/// Returns an error if:
/// - `web_path` cannot be normalized;
/// - `web.bind_address` is not a valid IP address;
/// - no port in the attempted range can be bound;
/// - the bound socket cannot be inspected;
/// - the server itself fails.
///
/// # Panics
///
/// Panics if `super::BASE_PATH` was already set, i.e. if `serve` is called
/// more than once in the same process.
pub async fn serve(port: u16, web_path: Option<String>) -> Result<()> {
    let base_path = super::normalize_base_path(web_path.as_deref())?;
    super::BASE_PATH
        .set(base_path.clone())
        .expect("BASE_PATH already set; serve() must only be called once per process");
    let s = settings();
    let bind_address = &s.web.bind_address;
    let port_attempts = resolve_port_attempts(s.web.port_attempts);
    let app = build_router(&base_path);

    // Parse bind address
    let ip_addr: std::net::IpAddr = bind_address
        .parse()
        .map_err(|e| miette::miette!("Invalid bind address '{}': {}", bind_address, e))?;

    // Try up to `port_attempts` consecutive ports starting from `port`.
    let mut last_error = None;
    for offset in 0..port_attempts {
        // Stop at the top of the port range rather than re-trying 65535 repeatedly.
        let Some(try_port) = port.checked_add(offset) else {
            break;
        };
        let addr = SocketAddr::from((ip_addr, try_port));

        match tokio::net::TcpListener::bind(addr).await {
            Ok(listener) => {
                let actual_addr = listener
                    .local_addr()
                    .map_err(|e| miette::miette!("Failed to inspect bound web port: {}", e))?;
                if offset > 0 {
                    info!(
                        "Port {port} was in use, using port {} instead",
                        actual_addr.port()
                    );
                }
                if base_path.is_empty() {
                    info!("Web UI listening on http://{actual_addr}");
                } else {
                    info!("Web UI listening on http://{actual_addr}{base_path}/");
                }

                return axum::serve(listener, app)
                    .await
                    .map_err(|e| miette::miette!("Web server error: {}", e));
            }
            Err(e) => {
                // Keep going even for errors other than "address in use". Some
                // platforms report reserved/excluded ports as permission errors,
                // and a neighbouring port may still be free.
                debug!("Port {try_port} unavailable: {e}");
                last_error = Some(e);
            }
        }
    }

    // The highest port actually tried is min(port + port_attempts - 1, 65535),
    // which is what `saturating_add` yields. `port_attempts >= 1`, so the
    // subtraction cannot underflow.
    Err(miette::miette!(
        "Failed to bind web server on {}: tried ports {}-{}, none could be bound. Last error: {}",
        ip_addr,
        port,
        port.saturating_add(port_attempts - 1),
        last_error.map(|e| e.to_string()).unwrap_or_default()
    ))
}

/// Converts the configured `web.port_attempts` (stored as `i64`) into a port count.
///
/// A value that does not fit in `u16` (negative or above 65535) falls back to
/// 10, with a warning. A value of 0 is raised to 1, with a warning, so that at
/// least the requested port is tried.
fn resolve_port_attempts(configured: i64) -> u16 {
    match u16::try_from(configured) {
        Ok(0) => {
            warn!("web.port_attempts is 0; trying a single port");
            1
        }
        Ok(attempts) => attempts,
        Err(_) => {
            warn!(
                "web.port_attempts value {} is out of range (1-65535), clamping to 10",
                configured
            );
            10
        }
    }
}

/// Builds the application router, nested under `base_path` when it is non-empty.
///
/// Every application route is wrapped in the `csrf_protection` middleware. When
/// `base_path` is non-empty, the routes are nested under it. A bare `/` route is
/// then added outside the middleware, and it issues a temporary redirect to
/// `{base_path}/`.
fn build_router(base_path: &str) -> Router {
    let inner = Router::new()
        // Dashboard
        .route("/", get(routes::index::index))
        .route("/_stats", get(routes::index::stats_partial))
        .route("/health", get(|| async { "OK" }))
        // Daemons
        .route("/daemons", get(routes::daemons::list))
        .route("/daemons/_list", get(routes::daemons::list_partial))
        .route("/daemons/{id}", get(routes::daemons::show))
        .route("/daemons/{id}/start", post(routes::daemons::start))
        .route("/daemons/{id}/stop", post(routes::daemons::stop))
        .route("/daemons/{id}/restart", post(routes::daemons::restart))
        .route("/daemons/{id}/enable", post(routes::daemons::enable))
        .route("/daemons/{id}/disable", post(routes::daemons::disable))
        // Logs
        .route("/logs", get(routes::logs::index))
        .route("/logs/{id}", get(routes::logs::show))
        .route("/logs/{id}/_lines", get(routes::logs::lines_partial))
        .route("/logs/{id}/stream", get(routes::logs::stream_sse))
        .route("/logs/{id}/clear", post(routes::logs::clear))
        // Config
        .route("/config", get(routes::config::list))
        .route("/config/edit", get(routes::config::edit))
        .route("/config/validate", post(routes::config::validate))
        .route("/config/save", post(routes::config::save))
        // Static files
        .route("/static/{*path}", get(static_handler))
        // CSRF protection for all POST endpoints
        .layer(middleware::from_fn(csrf_protection));

    if base_path.is_empty() {
        return inner;
    }

    let redirect_target = format!("{base_path}/");
    Router::new()
        .route(
            "/",
            get(move || async move { Redirect::temporary(&redirect_target) }),
        )
        .nest(base_path, inner)
}