Skip to main content

crw_server/
app.rs

1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
16/// Maximum request body size (1 MB) to prevent memory exhaustion from large payloads.
17const MAX_BODY_SIZE: usize = 1024 * 1024;
18
19use crate::middleware::auth_middleware;
20use crate::routes;
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24    let api_keys = Arc::new(state.config.auth.api_keys.clone());
25    // Tower outer timeout. `effective_request_timeout_secs()` widens the
26    // operator baseline so the longest legitimate handler runtime (auto-extended
27    // scrape, search enrichment fan-out, map's 300s ceiling) isn't cancelled
28    // by the outer layer before the inner deadline fires. See issue #35.
29    let timeout = Duration::from_secs(state.config.effective_request_timeout_secs());
30    let rate_limit_rps = state.config.server.rate_limit_rps;
31
32    let api_routes = Router::new()
33        .route(
34            "/v1/scrape",
35            post(routes::scrape::scrape).fallback(method_not_allowed),
36        )
37        .route(
38            "/v1/crawl",
39            post(routes::crawl::start_crawl).fallback(method_not_allowed),
40        )
41        .route(
42            "/v1/crawl/{id}",
43            get(routes::crawl::get_crawl)
44                .delete(routes::crawl::cancel_crawl)
45                .fallback(method_not_allowed),
46        )
47        .route(
48            "/v1/map",
49            post(routes::map::map).fallback(method_not_allowed),
50        )
51        .route(
52            "/v1/search",
53            post(routes::search::search).fallback(method_not_allowed),
54        )
55        .route(
56            "/v1/capabilities",
57            get(routes::capabilities::capabilities).fallback(method_not_allowed),
58        )
59        .route(
60            "/mcp",
61            post(routes::mcp::mcp_handler).fallback(method_not_allowed),
62        );
63
64    let api_routes = if api_keys.is_empty() {
65        api_routes.with_state(state.clone())
66    } else {
67        api_routes
68            .route_layer(axum::middleware::from_fn_with_state(
69                api_keys,
70                auth_middleware,
71            ))
72            .with_state(state.clone())
73    };
74
75    let rate_limiter = if rate_limit_rps > 0 {
76        Some(Arc::new(RateLimiter::new(rate_limit_rps)))
77    } else {
78        None
79    };
80
81    Router::new()
82        .route(
83            "/health",
84            get(routes::health::health).fallback(method_not_allowed),
85        )
86        .route(
87            "/ready",
88            get(routes::health::ready).fallback(method_not_allowed),
89        )
90        .route(
91            "/metrics",
92            get(routes::metrics::metrics).fallback(method_not_allowed),
93        )
94        .route(
95            "/metrics/renderer-breakers",
96            get(routes::breakers::renderer_breakers).fallback(method_not_allowed),
97        )
98        .route(
99            "/admin/breakers/reset",
100            post(routes::breakers::reset_breakers).fallback(method_not_allowed),
101        )
102        .with_state(state)
103        .merge(api_routes)
104        .layer(axum::middleware::from_fn(move |req, next| {
105            let limiter = rate_limiter.clone();
106            rate_limit_middleware(limiter, req, next)
107        }))
108        .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
109        .layer(TimeoutLayer::with_status_code(
110            StatusCode::GATEWAY_TIMEOUT,
111            timeout,
112        ))
113        .layer(SetResponseHeaderLayer::overriding(
114            axum::http::header::X_CONTENT_TYPE_OPTIONS,
115            axum::http::HeaderValue::from_static("nosniff"),
116        ))
117        .layer(SetResponseHeaderLayer::overriding(
118            axum::http::header::X_FRAME_OPTIONS,
119            axum::http::HeaderValue::from_static("DENY"),
120        ))
121        .layer(CorsLayer::permissive())
122        .layer(TraceLayer::new_for_http())
123}
124
125async fn method_not_allowed() -> impl IntoResponse {
126    (
127        StatusCode::METHOD_NOT_ALLOWED,
128        axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
129            "Method not allowed",
130            "method_not_allowed",
131        )),
132    )
133}
134
/// Simple token-bucket rate limiter built on an atomic counter.
///
/// The bucket starts full with `rps` tokens and never holds more than that. It is topped
/// up lazily from [`RateLimiter::try_acquire`]: once at least one second has passed since
/// the previous refill, `elapsed_seconds * rps` tokens are added (capped at `rps`).
struct RateLimiter {
    /// Tokens currently available; one is consumed per admitted call.
    tokens: AtomicU64,
    /// Bucket capacity, equal to the `rps` passed to [`RateLimiter::new`].
    max_tokens: u64,
    /// Instant of the most recent refill; the mutex serialises refills.
    last_refill: std::sync::Mutex<std::time::Instant>,
}

impl RateLimiter {
    /// Creates a limiter whose bucket starts full with `rps` tokens.
    fn new(rps: u64) -> Self {
        Self {
            tokens: AtomicU64::new(rps),
            max_tokens: rps,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Tops the bucket up if at least one second has elapsed since the last refill.
    fn refill(&self) {
        // The guarded `Instant` is always valid, so a poisoned lock (a panic in another
        // thread while it was held) is recovered instead of panicking on every request.
        let mut last = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let elapsed = last.elapsed();
        if elapsed < Duration::from_secs(1) {
            return;
        }

        // The float-to-int `as` cast saturates, so `refill` can be as large as `u64::MAX`.
        let refill = (elapsed.as_secs_f64() * self.max_tokens as f64) as u64;
        // `saturating_add` keeps a huge refill from overflowing the counter. `Relaxed` is
        // sufficient because the counter does not guard any other memory. The closure
        // always returns `Some`, so the update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                Some(available.saturating_add(refill).min(self.max_tokens))
            });
        *last = std::time::Instant::now();
    }

    /// Attempts to take one token from the bucket.
    ///
    /// Returns `true` when a token was consumed and `false` when the bucket is empty.
    fn try_acquire(&self) -> bool {
        self.refill();

        // Atomically decrement unless the bucket is already empty.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }
}
182
183async fn rate_limit_middleware(
184    limiter: Option<Arc<RateLimiter>>,
185    req: Request<Body>,
186    next: Next,
187) -> Response {
188    if let Some(limiter) = limiter
189        && req.uri().path() != "/health"
190        && req.uri().path() != "/ready"
191        && !limiter.try_acquire()
192    {
193        return (
194            StatusCode::TOO_MANY_REQUESTS,
195            axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
196                "Rate limited",
197                "rate_limited",
198            )),
199        )
200            .into_response();
201    }
202    next.run(req).await
203}