Skip to main content

pitchfork_cli/web/
server.rs

1use crate::Result;
2use crate::settings::settings;
3use axum::{
4    Router,
5    body::Body,
6    http::{Method, Request, StatusCode},
7    middleware::{self, Next},
8    response::{Redirect, Response},
9    routing::{get, post},
10};
11use std::net::SocketAddr;
12
13use super::routes;
14use super::static_files::static_handler;
15
16/// CSRF protection middleware - requires HX-Request header on POST requests.
17/// This prevents cross-origin form submissions since custom headers trigger CORS preflight.
18async fn csrf_protection(request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
19    if request.method() == Method::POST {
20        // htmx automatically sends HX-Request header on all requests
21        // Cross-origin form submissions cannot set custom headers
22        if !request.headers().contains_key("hx-request") {
23            return Err(StatusCode::FORBIDDEN);
24        }
25    }
26    Ok(next.run(request).await)
27}
28
29pub async fn serve(port: u16, web_path: Option<String>) -> Result<()> {
30    let base_path = super::normalize_base_path(web_path.as_deref())?;
31    super::BASE_PATH
32        .set(base_path.clone())
33        .expect("BASE_PATH already set; serve() must only be called once per process");
34    let s = settings();
35    let bind_address = &s.web.bind_address;
36    // port_attempts is stored as i64; clamp to a sane u16 range rather than
37    // silently truncating negative/oversized values with `as u16`.
38    let port_attempts: u16 = u16::try_from(s.web.port_attempts)
39        .unwrap_or_else(|_| {
40            warn!(
41                "web.port_attempts value {} is out of range (1-65535), clamping to 10",
42                s.web.port_attempts
43            );
44            10
45        })
46        .max(1);
47    let inner = Router::new()
48        // Dashboard
49        .route("/", get(routes::index::index))
50        .route("/_stats", get(routes::index::stats_partial))
51        .route("/health", get(|| async { "OK" }))
52        // Daemons
53        .route("/daemons", get(routes::daemons::list))
54        .route("/daemons/_list", get(routes::daemons::list_partial))
55        .route("/daemons/{id}", get(routes::daemons::show))
56        .route("/daemons/{id}/start", post(routes::daemons::start))
57        .route("/daemons/{id}/stop", post(routes::daemons::stop))
58        .route("/daemons/{id}/restart", post(routes::daemons::restart))
59        .route("/daemons/{id}/enable", post(routes::daemons::enable))
60        .route("/daemons/{id}/disable", post(routes::daemons::disable))
61        // Logs
62        .route("/logs", get(routes::logs::index))
63        .route("/logs/{id}", get(routes::logs::show))
64        .route("/logs/{id}/_lines", get(routes::logs::lines_partial))
65        .route("/logs/{id}/stream", get(routes::logs::stream_sse))
66        .route("/logs/{id}/clear", post(routes::logs::clear))
67        // Config
68        .route("/config", get(routes::config::list))
69        .route("/config/edit", get(routes::config::edit))
70        .route("/config/validate", post(routes::config::validate))
71        .route("/config/save", post(routes::config::save))
72        // Static files
73        .route("/static/{*path}", get(static_handler))
74        // CSRF protection for all POST endpoints
75        .layer(middleware::from_fn(csrf_protection));
76
77    let app = if base_path.is_empty() {
78        inner
79    } else {
80        let redirect_target = format!("{base_path}/");
81        Router::new()
82            .route(
83                "/",
84                get(move || async move { Redirect::temporary(&redirect_target) }),
85            )
86            .nest(&base_path, inner)
87    };
88
89    // Parse bind address
90    let ip_addr: std::net::IpAddr = bind_address
91        .parse()
92        .map_err(|e| miette::miette!("Invalid bind address '{}': {}", bind_address, e))?;
93
94    // Try up to port_attempts ports starting from the given port
95    let mut last_error = None;
96    for offset in 0..port_attempts {
97        let try_port = port.saturating_add(offset);
98        let addr = SocketAddr::from((ip_addr, try_port));
99
100        match tokio::net::TcpListener::bind(addr).await {
101            Ok(listener) => {
102                let actual_addr = listener
103                    .local_addr()
104                    .map_err(|e| miette::miette!("Failed to inspect bound web port: {}", e))?;
105                if offset > 0 {
106                    info!(
107                        "Port {port} was in use, using port {} instead",
108                        actual_addr.port()
109                    );
110                }
111                if base_path.is_empty() {
112                    info!("Web UI listening on http://{actual_addr}");
113                } else {
114                    info!("Web UI listening on http://{actual_addr}{base_path}/");
115                }
116
117                return axum::serve(listener, app)
118                    .await
119                    .map_err(|e| miette::miette!("Web server error: {}", e));
120            }
121            Err(e) => {
122                debug!("Port {try_port} unavailable: {e}");
123                last_error = Some(e);
124            }
125        }
126    }
127
128    Err(miette::miette!(
129        "Failed to bind web server: tried ports {}-{}, all in use. Last error: {}",
130        port,
131        port.saturating_add(port_attempts - 1),
132        last_error.map(|e| e.to_string()).unwrap_or_default()
133    ))
134}