Skip to main content

crw_server/
app.rs

1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
/// Upper bound on an accepted request body: 1 MiB (1,048,576 bytes), so that
/// oversized payloads cannot exhaust memory.
const MAX_BODY_SIZE: usize = 1 << 20;
18
19use crate::middleware::auth_middleware;
20use crate::routes;
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24    let api_keys = Arc::new(state.config.auth.api_keys.clone());
25    let timeout = Duration::from_secs(state.config.server.request_timeout_secs);
26    let rate_limit_rps = state.config.server.rate_limit_rps;
27
28    let api_routes = Router::new()
29        .route(
30            "/v1/scrape",
31            post(routes::scrape::scrape).fallback(method_not_allowed),
32        )
33        .route(
34            "/v1/crawl",
35            post(routes::crawl::start_crawl).fallback(method_not_allowed),
36        )
37        .route(
38            "/v1/crawl/{id}",
39            get(routes::crawl::get_crawl)
40                .delete(routes::crawl::cancel_crawl)
41                .fallback(method_not_allowed),
42        )
43        .route(
44            "/v1/map",
45            post(routes::map::map).fallback(method_not_allowed),
46        )
47        .route(
48            "/mcp",
49            post(routes::mcp::mcp_handler).fallback(method_not_allowed),
50        );
51
52    let api_routes = if api_keys.is_empty() {
53        api_routes.with_state(state.clone())
54    } else {
55        api_routes
56            .route_layer(axum::middleware::from_fn_with_state(
57                api_keys,
58                auth_middleware,
59            ))
60            .with_state(state.clone())
61    };
62
63    let rate_limiter = if rate_limit_rps > 0 {
64        Some(Arc::new(RateLimiter::new(rate_limit_rps)))
65    } else {
66        None
67    };
68
69    Router::new()
70        .route(
71            "/health",
72            get(routes::health::health).fallback(method_not_allowed),
73        )
74        .with_state(state)
75        .merge(api_routes)
76        .layer(axum::middleware::from_fn(move |req, next| {
77            let limiter = rate_limiter.clone();
78            rate_limit_middleware(limiter, req, next)
79        }))
80        .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
81        .layer(TimeoutLayer::with_status_code(
82            StatusCode::GATEWAY_TIMEOUT,
83            timeout,
84        ))
85        .layer(SetResponseHeaderLayer::overriding(
86            axum::http::header::X_CONTENT_TYPE_OPTIONS,
87            axum::http::HeaderValue::from_static("nosniff"),
88        ))
89        .layer(SetResponseHeaderLayer::overriding(
90            axum::http::header::X_FRAME_OPTIONS,
91            axum::http::HeaderValue::from_static("DENY"),
92        ))
93        .layer(CorsLayer::permissive())
94        .layer(TraceLayer::new_for_http())
95}
96
97async fn method_not_allowed() -> impl IntoResponse {
98    (
99        StatusCode::METHOD_NOT_ALLOWED,
100        axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
101            "Method not allowed",
102            "method_not_allowed",
103        )),
104    )
105}
106
/// Simple rate limiter backed by an atomic token counter.
///
/// The bucket starts with `rps` tokens and every successful
/// [`RateLimiter::try_acquire`] spends one. Once at least one second has passed
/// since the previous refill, the next call tops the bucket back up to
/// `max_tokens`; unspent tokens never accumulate beyond that capacity.
struct RateLimiter {
    /// Tokens still available until the next refill.
    tokens: AtomicU64,
    /// Bucket capacity, and the value `tokens` is reset to on refill.
    max_tokens: u64,
    /// Instant of the most recent refill; the lock serialises refills.
    last_refill: std::sync::Mutex<std::time::Instant>,
}

impl RateLimiter {
    /// Creates a limiter whose bucket starts full with `rps` tokens.
    fn new(rps: u64) -> Self {
        Self {
            tokens: AtomicU64::new(rps),
            max_tokens: rps,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Tries to spend one token.
    ///
    /// Returns `true` when a token was taken and `false` when the bucket is
    /// empty. Never panics and never blocks beyond the short refill lock.
    fn try_acquire(&self) -> bool {
        self.refill_if_due();

        // Atomically decrement unless the counter is already zero.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Resets the bucket to full once at least one second has elapsed since the
    /// last refill.
    fn refill_if_due(&self) {
        // A poisoned lock only means another thread panicked while holding it;
        // the guarded `Instant` is still valid, so keep serving requests.
        let mut last = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if last.elapsed() >= Duration::from_secs(1) {
            // After a full second the earned tokens (`elapsed * max_tokens`) always
            // reach capacity, so reset directly. This avoids the float cast and the
            // unchecked `current + refill` addition, which overflowed for very
            // large capacities.
            self.tokens.store(self.max_tokens, Ordering::Relaxed);
            *last = std::time::Instant::now();
        }
    }
}
154
155async fn rate_limit_middleware(
156    limiter: Option<Arc<RateLimiter>>,
157    req: Request<Body>,
158    next: Next,
159) -> Response {
160    if let Some(limiter) = limiter
161        && req.uri().path() != "/health"
162        && !limiter.try_acquire()
163    {
164        return (
165            StatusCode::TOO_MANY_REQUESTS,
166            axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
167                "Rate limited",
168                "rate_limited",
169            )),
170        )
171            .into_response();
172    }
173    next.run(req).await
174}