Skip to main content

crw_server/
app.rs

1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
/// Upper bound on an incoming request body, in bytes (1 MiB).
///
/// Applied to the whole app through `DefaultBodyLimit`, so large payloads are
/// refused rather than buffered, guarding against memory exhaustion.
const MAX_BODY_SIZE: usize = 1 << 20;
18
19use crate::middleware::auth_middleware;
20use crate::routes::{self, method_not_allowed};
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24    let api_keys = Arc::new(state.config.auth.api_keys.clone());
25    // Tower outer timeout. `effective_request_timeout_secs()` widens the
26    // operator baseline so the longest legitimate handler runtime (auto-extended
27    // scrape, search enrichment fan-out, map's 300s ceiling) isn't cancelled
28    // by the outer layer before the inner deadline fires. See issue #35.
29    let timeout = Duration::from_secs(state.config.effective_request_timeout_secs());
30    let rate_limit_rps = state.config.server.rate_limit_rps;
31
32    // v1 + v2 routers merged before the shared auth + rate-limit layers, so the
33    // /v2/* surface (issue #62) inherits auth, rate-limiting, body-limit and the
34    // timeout layer identically to v1. `/mcp` is version-less.
35    let api_routes = routes::v1::router().merge(routes::v2::router()).route(
36        "/mcp",
37        post(routes::mcp::mcp_handler).fallback(method_not_allowed),
38    );
39
40    let api_routes = if api_keys.is_empty() {
41        api_routes.with_state(state.clone())
42    } else {
43        api_routes
44            .route_layer(axum::middleware::from_fn_with_state(
45                api_keys,
46                auth_middleware,
47            ))
48            .with_state(state.clone())
49    };
50
51    let rate_limiter = if rate_limit_rps > 0 {
52        Some(Arc::new(RateLimiter::new(rate_limit_rps)))
53    } else {
54        None
55    };
56
57    Router::new()
58        .route(
59            "/health",
60            get(routes::health::health).fallback(method_not_allowed),
61        )
62        .route(
63            "/openapi.json",
64            get(routes::openapi::serve_openapi_3_1).fallback(method_not_allowed),
65        )
66        .route(
67            "/openapi-3.0.json",
68            get(routes::openapi::serve_openapi_3_0).fallback(method_not_allowed),
69        )
70        .route(
71            "/ready",
72            get(routes::health::ready).fallback(method_not_allowed),
73        )
74        .route(
75            "/metrics",
76            get(routes::metrics::metrics).fallback(method_not_allowed),
77        )
78        .route(
79            "/metrics/renderer-breakers",
80            get(routes::breakers::renderer_breakers).fallback(method_not_allowed),
81        )
82        .route(
83            "/admin/breakers/reset",
84            post(routes::breakers::reset_breakers).fallback(method_not_allowed),
85        )
86        .with_state(state)
87        .merge(api_routes)
88        .layer(axum::middleware::from_fn(move |req, next| {
89            let limiter = rate_limiter.clone();
90            rate_limit_middleware(limiter, req, next)
91        }))
92        .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
93        .layer(TimeoutLayer::with_status_code(
94            StatusCode::GATEWAY_TIMEOUT,
95            timeout,
96        ))
97        .layer(SetResponseHeaderLayer::overriding(
98            axum::http::header::X_CONTENT_TYPE_OPTIONS,
99            axum::http::HeaderValue::from_static("nosniff"),
100        ))
101        .layer(SetResponseHeaderLayer::overriding(
102            axum::http::header::X_FRAME_OPTIONS,
103            axum::http::HeaderValue::from_static("DENY"),
104        ))
105        .layer(CorsLayer::permissive())
106        .layer(TraceLayer::new_for_http())
107}
108
/// Simple token-bucket rate limiter built on an atomic counter.
///
/// The bucket holds at most `max_tokens` (the configured `rps`) tokens and
/// starts full. Once at least one second has passed since the last refill,
/// the next [`RateLimiter::try_acquire`] call adds `elapsed_secs * rps`
/// tokens, capped at `max_tokens`. In practice the bucket is topped back up to
/// full at most once per second.
///
/// The bucket is not keyed by client: every caller of `try_acquire` draws
/// from the same pool.
struct RateLimiter {
    /// Tokens currently available; each admitted request consumes one.
    tokens: AtomicU64,
    /// Bucket capacity, equal to the configured requests-per-second.
    max_tokens: u64,
    /// Time of the last refill. The mutex serializes refills so that only one
    /// caller adds tokens for a given elapsed window.
    last_refill: std::sync::Mutex<std::time::Instant>,
}

impl RateLimiter {
    /// Creates a limiter admitting `rps` requests per second, starting full.
    ///
    /// With `rps == 0` the bucket never holds a token, so every
    /// `try_acquire` call is refused.
    fn new(rps: u64) -> Self {
        Self {
            tokens: AtomicU64::new(rps),
            max_tokens: rps,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Attempts to take one token. Returns `true` if the request may proceed.
    fn try_acquire(&self) -> bool {
        self.refill();

        // Decrement atomically unless the bucket is empty. When `checked_sub`
        // returns `None`, the update is aborted and `fetch_update` yields `Err`.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(1)
            })
            .is_ok()
    }

    /// Adds tokens for the time since the last refill, if at least one second
    /// has elapsed.
    fn refill(&self) {
        // A panic while this lock is held must not turn every later request
        // into a panic too, so recover the guard from a poisoned mutex.
        let mut last = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let elapsed = last.elapsed();
        if elapsed < Duration::from_secs(1) {
            return;
        }

        // `f64 as u64` saturates, so `refill` may be `u64::MAX` for a huge
        // `rps`. `saturating_add` keeps the sum from overflowing.
        let refill = (elapsed.as_secs_f64() * self.max_tokens as f64) as u64;
        let max = self.max_tokens;
        // A single atomic read-modify-write, so a concurrent decrement in
        // `try_acquire` is not overwritten by a stale value. The closure always
        // returns `Some`, so this cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(refill).min(max))
            });
        *last = std::time::Instant::now();
    }
}
156
157async fn rate_limit_middleware(
158    limiter: Option<Arc<RateLimiter>>,
159    req: Request<Body>,
160    next: Next,
161) -> Response {
162    if let Some(limiter) = limiter
163        && req.uri().path() != "/health"
164        && req.uri().path() != "/ready"
165        && !limiter.try_acquire()
166    {
167        return (
168            StatusCode::TOO_MANY_REQUESTS,
169            axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
170                "Rate limited",
171                "rate_limited",
172            )),
173        )
174            .into_response();
175    }
176    next.run(req).await
177}