Skip to main content

jax_daemon/http_server/
mod.rs

1use axum::body::Body;
2use axum::extract::DefaultBodyLimit;
3use axum::http::{header, StatusCode, Uri};
4use axum::response::{IntoResponse, Response};
5use axum::routing::get;
6use axum::{Extension, Router};
7use http::header::{ACCEPT, ORIGIN};
8use http::Method;
9use rust_embed::RustEmbed;
10use tokio::sync::watch;
11use tower_http::cors::{Any, CorsLayer};
12use tower_http::trace::TraceLayer;
13use tower_http::trace::{DefaultOnFailure, DefaultOnResponse};
14use tower_http::LatencyUnit;
15
16pub mod api;
17mod config;
18mod gateway;
19mod handlers;
20
21pub use config::Config;
22
23use crate::ServiceState;
24
/// Path prefix under which `api::router` is nested on the API server.
const API_PREFIX: &str = "/api";
/// Path prefix under which `health::router` is nested on both servers.
const STATUS_PREFIX: &str = "/_status";

/// Maximum accepted request body size in bytes: 500 MiB (500 × 2^20).
///
/// Applied as the `DefaultBodyLimit` of the API router.
pub const MAX_UPLOAD_SIZE_BYTES: usize = 500 << 20;
30
31#[derive(RustEmbed)]
32#[folder = "static"]
33struct StaticAssets;
34
35async fn static_handler(uri: Uri) -> impl IntoResponse {
36    let path = uri
37        .path()
38        .trim_start_matches('/')
39        .trim_start_matches("static/");
40
41    match StaticAssets::get(path) {
42        Some(content) => {
43            let mime = mime_guess::from_path(path).first_or_octet_stream();
44            Response::builder()
45                .status(StatusCode::OK)
46                .header(header::CONTENT_TYPE, mime.as_ref())
47                .body(Body::from(content.data.to_vec()))
48                .unwrap()
49        }
50        None => {
51            // Serve 404.html if file not found
52            match StaticAssets::get("404.html") {
53                Some(content) => Response::builder()
54                    .status(StatusCode::NOT_FOUND)
55                    .header(header::CONTENT_TYPE, "text/html")
56                    .body(Body::from(content.data.to_vec()))
57                    .unwrap(),
58                None => Response::builder()
59                    .status(StatusCode::NOT_FOUND)
60                    .body(Body::from("Not Found"))
61                    .unwrap(),
62            }
63        }
64    }
65}
66
67/// Run the API HTTP server (private, serves /_status + /api routes).
68pub async fn run_api(
69    config: Config,
70    state: ServiceState,
71    mut shutdown_rx: watch::Receiver<()>,
72) -> Result<(), HttpServerError> {
73    let listen_addr = config.listen_addr;
74    let log_level = config.log_level;
75    let trace_layer = TraceLayer::new_for_http()
76        .on_response(
77            DefaultOnResponse::new()
78                .include_headers(false)
79                .level(log_level)
80                .latency_unit(LatencyUnit::Micros),
81        )
82        .on_failure(DefaultOnFailure::new().latency_unit(LatencyUnit::Micros));
83
84    let router = Router::new()
85        .nest(STATUS_PREFIX, health::router(state.clone()))
86        .nest(API_PREFIX, api::router(state.clone()))
87        .fallback(handlers::not_found_handler)
88        .layer(DefaultBodyLimit::max(MAX_UPLOAD_SIZE_BYTES))
89        .layer(Extension(config.clone()))
90        .with_state(state)
91        .layer(trace_layer);
92
93    tracing::info!(addr = ?listen_addr, "API server listening");
94    let listener = tokio::net::TcpListener::bind(listen_addr).await?;
95
96    axum::serve(listener, router)
97        .with_graceful_shutdown(async move {
98            let _ = shutdown_rx.changed().await;
99        })
100        .await?;
101
102    Ok(())
103}
104
105/// Run the gateway HTTP server (public, serves /_status + /gw + / + /static routes).
106pub async fn run_gateway(
107    config: Config,
108    state: ServiceState,
109    mut shutdown_rx: watch::Receiver<()>,
110) -> Result<(), HttpServerError> {
111    let listen_addr = config.listen_addr;
112    let log_level = config.log_level;
113    let trace_layer = TraceLayer::new_for_http()
114        .on_response(
115            DefaultOnResponse::new()
116                .include_headers(false)
117                .level(log_level)
118                .latency_unit(LatencyUnit::Micros),
119        )
120        .on_failure(DefaultOnFailure::new().latency_unit(LatencyUnit::Micros));
121
122    tracing::info!("Static files embedded in binary");
123
124    // Gateway CORS (GET only) for gateway routes
125    let gateway_cors = CorsLayer::new()
126        .allow_methods(vec![Method::GET])
127        .allow_headers(vec![ACCEPT, ORIGIN])
128        .allow_origin(Any)
129        .allow_credentials(false);
130
131    // Gateway routes with their own CORS layer
132    let gateway_routes = Router::new()
133        .route("/:bucket_id", get(gateway::root_handler))
134        .route("/:bucket_id/", get(gateway::root_handler))
135        .route("/:bucket_id/*file_path", get(gateway::handler))
136        .with_state(state.clone())
137        .layer(gateway_cors);
138
139    let router = Router::new()
140        .nest(STATUS_PREFIX, health::router(state.clone()))
141        .nest("/gw", gateway_routes)
142        .route("/", get(gateway::index::handler))
143        .route("/static/*path", get(static_handler))
144        .fallback(handlers::not_found_handler)
145        .layer(Extension(config.clone()))
146        .with_state(state)
147        .layer(trace_layer);
148
149    tracing::info!(addr = ?listen_addr, "Gateway server listening");
150    let listener = tokio::net::TcpListener::bind(listen_addr).await?;
151
152    axum::serve(listener, router)
153        .with_graceful_shutdown(async move {
154            let _ = shutdown_rx.changed().await;
155        })
156        .await?;
157
158    Ok(())
159}
160
161pub mod health;
162
163#[derive(Debug, thiserror::Error)]
164pub enum HttpServerError {
165    #[error("an error occurred running the HTTP server: {0}")]
166    ServingFailed(#[from] std::io::Error),
167}