Skip to main content

jax_daemon/http_server/
mod.rs

1use axum::body::Body;
2use axum::extract::DefaultBodyLimit;
3use axum::http::{header, StatusCode, Uri};
4use axum::response::{IntoResponse, Response};
5use axum::routing::get;
6use axum::{Extension, Router};
7use http::header::{ACCEPT, ORIGIN};
8use http::Method;
9use rust_embed::RustEmbed;
10use tokio::sync::watch;
11use tower_http::cors::{Any, CorsLayer};
12use tower_http::trace::TraceLayer;
13use tower_http::trace::{DefaultOnFailure, DefaultOnResponse};
14use tower_http::LatencyUnit;
15
16pub mod api;
17mod config;
18mod gateway_index;
19mod handlers;
20mod html;
21
22pub use config::Config;
23
24use crate::ServiceState;
25
/// Path prefix under which the API router is nested.
const API_PREFIX: &str = "/api";
/// Path prefix under which the health router is nested on both servers.
const STATUS_PREFIX: &str = "/_status";

/// Maximum upload size in bytes (500 MiB, i.e. 500 * 1024 * 1024).
///
/// Applied as the default request body limit of the API server.
pub const MAX_UPLOAD_SIZE_BYTES: usize = 500 * 1024 * 1024;
31
/// Static files from the `static` folder, made available through the
/// `RustEmbed` derive and looked up by path with `StaticAssets::get`.
// NOTE(review): the `static` folder is presumably resolved relative to the
// crate root at build time — confirm against the rust-embed documentation.
#[derive(RustEmbed)]
#[folder = "static"]
struct StaticAssets;
35
36async fn static_handler(uri: Uri) -> impl IntoResponse {
37    let path = uri
38        .path()
39        .trim_start_matches('/')
40        .trim_start_matches("static/");
41
42    match StaticAssets::get(path) {
43        Some(content) => {
44            let mime = mime_guess::from_path(path).first_or_octet_stream();
45            Response::builder()
46                .status(StatusCode::OK)
47                .header(header::CONTENT_TYPE, mime.as_ref())
48                .body(Body::from(content.data.to_vec()))
49                .unwrap()
50        }
51        None => {
52            // Serve 404.html if file not found
53            match StaticAssets::get("404.html") {
54                Some(content) => Response::builder()
55                    .status(StatusCode::NOT_FOUND)
56                    .header(header::CONTENT_TYPE, "text/html")
57                    .body(Body::from(content.data.to_vec()))
58                    .unwrap(),
59                None => Response::builder()
60                    .status(StatusCode::NOT_FOUND)
61                    .body(Body::from("Not Found"))
62                    .unwrap(),
63            }
64        }
65    }
66}
67
68/// Run the API HTTP server (private, serves /_status + /api routes).
69pub async fn run_api(
70    config: Config,
71    state: ServiceState,
72    mut shutdown_rx: watch::Receiver<()>,
73) -> Result<(), HttpServerError> {
74    let listen_addr = config.listen_addr;
75    let log_level = config.log_level;
76    let trace_layer = TraceLayer::new_for_http()
77        .on_response(
78            DefaultOnResponse::new()
79                .include_headers(false)
80                .level(log_level)
81                .latency_unit(LatencyUnit::Micros),
82        )
83        .on_failure(DefaultOnFailure::new().latency_unit(LatencyUnit::Micros));
84
85    let router = Router::new()
86        .nest(STATUS_PREFIX, health::router(state.clone()))
87        .nest(API_PREFIX, api::router(state.clone()))
88        .fallback(handlers::not_found_handler)
89        .layer(DefaultBodyLimit::max(MAX_UPLOAD_SIZE_BYTES))
90        .layer(Extension(config.clone()))
91        .with_state(state)
92        .layer(trace_layer);
93
94    tracing::info!(addr = ?listen_addr, "API server listening");
95    let listener = tokio::net::TcpListener::bind(listen_addr).await?;
96
97    axum::serve(listener, router)
98        .with_graceful_shutdown(async move {
99            let _ = shutdown_rx.changed().await;
100        })
101        .await?;
102
103    Ok(())
104}
105
106/// Run the gateway HTTP server (public, serves /_status + /gw + / + /static routes).
107pub async fn run_gateway(
108    config: Config,
109    state: ServiceState,
110    mut shutdown_rx: watch::Receiver<()>,
111) -> Result<(), HttpServerError> {
112    let listen_addr = config.listen_addr;
113    let log_level = config.log_level;
114    let trace_layer = TraceLayer::new_for_http()
115        .on_response(
116            DefaultOnResponse::new()
117                .include_headers(false)
118                .level(log_level)
119                .latency_unit(LatencyUnit::Micros),
120        )
121        .on_failure(DefaultOnFailure::new().latency_unit(LatencyUnit::Micros));
122
123    tracing::info!("Static files embedded in binary");
124
125    // Gateway CORS (GET only) for gateway routes
126    let gateway_cors = CorsLayer::new()
127        .allow_methods(vec![Method::GET])
128        .allow_headers(vec![ACCEPT, ORIGIN])
129        .allow_origin(Any)
130        .allow_credentials(false);
131
132    // Gateway routes with their own CORS layer
133    let gateway_routes = Router::new()
134        .route("/:bucket_id", get(html::gateway::root_handler))
135        .route("/:bucket_id/", get(html::gateway::root_handler))
136        .route("/:bucket_id/*file_path", get(html::gateway::handler))
137        .with_state(state.clone())
138        .layer(gateway_cors);
139
140    let router = Router::new()
141        .nest(STATUS_PREFIX, health::router(state.clone()))
142        .nest("/gw", gateway_routes)
143        .route("/", get(gateway_index::handler))
144        .route("/static/*path", get(static_handler))
145        .fallback(handlers::not_found_handler)
146        .layer(Extension(config.clone()))
147        .with_state(state)
148        .layer(trace_layer);
149
150    tracing::info!(addr = ?listen_addr, "Gateway server listening");
151    let listener = tokio::net::TcpListener::bind(listen_addr).await?;
152
153    axum::serve(listener, router)
154        .with_graceful_shutdown(async move {
155            let _ = shutdown_rx.changed().await;
156        })
157        .await?;
158
159    Ok(())
160}
161
162mod health;
163
/// Error returned by `run_api` and `run_gateway`.
#[derive(Debug, thiserror::Error)]
pub enum HttpServerError {
    /// An I/O error raised while binding the listener or while serving;
    /// both call sites convert `std::io::Error` into this variant via `?`.
    #[error("an error occurred running the HTTP server: {0}")]
    ServingFailed(#[from] std::io::Error),
}