1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
/// Largest request body accepted by body extractors: 1 MiB.
const MAX_BODY_SIZE: usize = 1 << 20;
18
19use crate::middleware::auth_middleware;
20use crate::routes;
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24 let api_keys = Arc::new(state.config.auth.api_keys.clone());
25 let timeout = Duration::from_secs(state.config.server.request_timeout_secs);
26 let rate_limit_rps = state.config.server.rate_limit_rps;
27
28 let api_routes = Router::new()
29 .route(
30 "/v1/scrape",
31 post(routes::scrape::scrape).fallback(method_not_allowed),
32 )
33 .route(
34 "/v1/crawl",
35 post(routes::crawl::start_crawl).fallback(method_not_allowed),
36 )
37 .route(
38 "/v1/crawl/{id}",
39 get(routes::crawl::get_crawl)
40 .delete(routes::crawl::cancel_crawl)
41 .fallback(method_not_allowed),
42 )
43 .route(
44 "/v1/map",
45 post(routes::map::map).fallback(method_not_allowed),
46 )
47 .route(
48 "/mcp",
49 post(routes::mcp::mcp_handler).fallback(method_not_allowed),
50 );
51
52 let api_routes = if api_keys.is_empty() {
53 api_routes.with_state(state.clone())
54 } else {
55 api_routes
56 .route_layer(axum::middleware::from_fn_with_state(
57 api_keys,
58 auth_middleware,
59 ))
60 .with_state(state.clone())
61 };
62
63 let rate_limiter = if rate_limit_rps > 0 {
64 Some(Arc::new(RateLimiter::new(rate_limit_rps)))
65 } else {
66 None
67 };
68
69 Router::new()
70 .route(
71 "/health",
72 get(routes::health::health).fallback(method_not_allowed),
73 )
74 .with_state(state)
75 .merge(api_routes)
76 .layer(axum::middleware::from_fn(move |req, next| {
77 let limiter = rate_limiter.clone();
78 rate_limit_middleware(limiter, req, next)
79 }))
80 .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
81 .layer(TimeoutLayer::with_status_code(
82 StatusCode::GATEWAY_TIMEOUT,
83 timeout,
84 ))
85 .layer(SetResponseHeaderLayer::overriding(
86 axum::http::header::X_CONTENT_TYPE_OPTIONS,
87 axum::http::HeaderValue::from_static("nosniff"),
88 ))
89 .layer(SetResponseHeaderLayer::overriding(
90 axum::http::header::X_FRAME_OPTIONS,
91 axum::http::HeaderValue::from_static("DENY"),
92 ))
93 .layer(CorsLayer::permissive())
94 .layer(TraceLayer::new_for_http())
95}
96
97async fn method_not_allowed() -> impl IntoResponse {
98 (
99 StatusCode::METHOD_NOT_ALLOWED,
100 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
101 "Method not allowed",
102 "method_not_allowed",
103 )),
104 )
105}
106
/// Process-wide token-bucket rate limiter.
///
/// The bucket starts full with `max_tokens` tokens (the configured requests per
/// second), and every admitted request consumes one token. Refilling is lazy:
/// on each [`RateLimiter::try_acquire`], if at least one second has passed since
/// the last refill, `elapsed_secs * max_tokens` tokens are added (capped at
/// `max_tokens`).
struct RateLimiter {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity; also the number of tokens granted per elapsed second.
    max_tokens: u64,
    /// Time of the last refill. The mutex ensures only one caller refills at a time.
    last_refill: std::sync::Mutex<std::time::Instant>,
}

impl RateLimiter {
    /// Creates a full bucket that admits `rps` requests per second.
    fn new(rps: u64) -> Self {
        Self {
            tokens: AtomicU64::new(rps),
            max_tokens: rps,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Tries to take one token. Returns `true` if the request may proceed and
    /// `false` if the bucket is empty.
    fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero, which makes `fetch_update` fail
        // without touching the counter.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1))
            .is_ok()
    }

    /// Adds the tokens earned since the last refill, if at least one second has passed.
    fn refill(&self) {
        // Poisoning only means another thread panicked while holding the lock;
        // the stored `Instant` is still valid, so keep serving requests.
        let mut last = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let elapsed = last.elapsed();
        if elapsed < Duration::from_secs(1) {
            return;
        }
        // The float-to-int cast saturates; `saturating_add` below keeps the sum
        // from overflowing in that case.
        let earned = (elapsed.as_secs_f64() * self.max_tokens as f64) as u64;
        let cap = self.max_tokens;
        // Read-modify-write atomically. A separate load/store would overwrite
        // concurrent decrements from `try_acquire`, admitting extra requests.
        let _ = self.tokens.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
            Some(n.saturating_add(earned).min(cap))
        });
        *last = std::time::Instant::now();
    }
}
154
155async fn rate_limit_middleware(
156 limiter: Option<Arc<RateLimiter>>,
157 req: Request<Body>,
158 next: Next,
159) -> Response {
160 if let Some(limiter) = limiter
161 && req.uri().path() != "/health"
162 && !limiter.try_acquire()
163 {
164 return (
165 StatusCode::TOO_MANY_REQUESTS,
166 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
167 "Rate limited",
168 "rate_limited",
169 )),
170 )
171 .into_response();
172 }
173 next.run(req).await
174}