Skip to main content

pitchfork_cli/web/
server.rs

1use crate::Result;
2use axum::{
3    Router,
4    body::Body,
5    http::{Method, Request, StatusCode},
6    middleware::{self, Next},
7    response::Response,
8    routing::{get, post},
9};
10use std::net::SocketAddr;
11
12use super::routes;
13use super::static_files::static_handler;
14
15/// CSRF protection middleware - requires HX-Request header on POST requests.
16/// This prevents cross-origin form submissions since custom headers trigger CORS preflight.
17async fn csrf_protection(request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
18    if request.method() == Method::POST {
19        // htmx automatically sends HX-Request header on all requests
20        // Cross-origin form submissions cannot set custom headers
21        if !request.headers().contains_key("hx-request") {
22            return Err(StatusCode::FORBIDDEN);
23        }
24    }
25    Ok(next.run(request).await)
26}
27
28/// Number of ports to try before giving up
29const PORT_ATTEMPTS: u16 = 10;
30
31pub async fn serve(port: u16) -> Result<()> {
32    let app = Router::new()
33        // Dashboard
34        .route("/", get(routes::index::index))
35        .route("/_stats", get(routes::index::stats_partial))
36        .route("/health", get(|| async { "OK" }))
37        // Daemons
38        .route("/daemons", get(routes::daemons::list))
39        .route("/daemons/_list", get(routes::daemons::list_partial))
40        .route("/daemons/{id}", get(routes::daemons::show))
41        .route("/daemons/{id}/start", post(routes::daemons::start))
42        .route("/daemons/{id}/stop", post(routes::daemons::stop))
43        .route("/daemons/{id}/restart", post(routes::daemons::restart))
44        .route("/daemons/{id}/enable", post(routes::daemons::enable))
45        .route("/daemons/{id}/disable", post(routes::daemons::disable))
46        // Logs
47        .route("/logs", get(routes::logs::index))
48        .route("/logs/{id}", get(routes::logs::show))
49        .route("/logs/{id}/_lines", get(routes::logs::lines_partial))
50        .route("/logs/{id}/stream", get(routes::logs::stream_sse))
51        .route("/logs/{id}/clear", post(routes::logs::clear))
52        // Config
53        .route("/config", get(routes::config::list))
54        .route("/config/edit", get(routes::config::edit))
55        .route("/config/validate", post(routes::config::validate))
56        .route("/config/save", post(routes::config::save))
57        // Static files
58        .route("/static/{*path}", get(static_handler))
59        // CSRF protection for all POST endpoints
60        .layer(middleware::from_fn(csrf_protection));
61
62    // Try up to PORT_ATTEMPTS ports starting from the given port
63    let mut last_error = None;
64    for offset in 0..PORT_ATTEMPTS {
65        let try_port = port.saturating_add(offset);
66        let addr = SocketAddr::from(([127, 0, 0, 1], try_port));
67
68        match tokio::net::TcpListener::bind(addr).await {
69            Ok(listener) => {
70                if offset > 0 {
71                    info!("Port {port} was in use, using port {try_port} instead");
72                }
73                info!("Web UI listening on http://{addr}");
74
75                return axum::serve(listener, app)
76                    .await
77                    .map_err(|e| miette::miette!("Web server error: {}", e));
78            }
79            Err(e) => {
80                debug!("Port {try_port} unavailable: {e}");
81                last_error = Some(e);
82            }
83        }
84    }
85
86    Err(miette::miette!(
87        "Failed to bind web server: tried ports {}-{}, all in use. Last error: {}",
88        port,
89        port.saturating_add(PORT_ATTEMPTS - 1),
90        last_error.map(|e| e.to_string()).unwrap_or_default()
91    ))
92}