1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
/// Largest request body, in bytes (1 MiB), applied router-wide via
/// `DefaultBodyLimit` in `create_app`.
const MAX_BODY_SIZE: usize = 1 << 20;
18
19use crate::middleware::auth_middleware;
20use crate::routes;
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24 let api_keys = Arc::new(state.config.auth.api_keys.clone());
25 let timeout = Duration::from_secs(state.config.effective_request_timeout_secs());
30 let rate_limit_rps = state.config.server.rate_limit_rps;
31
32 let api_routes = Router::new()
33 .route(
34 "/v1/scrape",
35 post(routes::scrape::scrape).fallback(method_not_allowed),
36 )
37 .route(
38 "/v1/crawl",
39 post(routes::crawl::start_crawl).fallback(method_not_allowed),
40 )
41 .route(
42 "/v1/crawl/{id}",
43 get(routes::crawl::get_crawl)
44 .delete(routes::crawl::cancel_crawl)
45 .fallback(method_not_allowed),
46 )
47 .route(
48 "/v1/map",
49 post(routes::map::map).fallback(method_not_allowed),
50 )
51 .route(
52 "/v1/search",
53 post(routes::search::search).fallback(method_not_allowed),
54 )
55 .route(
56 "/mcp",
57 post(routes::mcp::mcp_handler).fallback(method_not_allowed),
58 );
59
60 let api_routes = if api_keys.is_empty() {
61 api_routes.with_state(state.clone())
62 } else {
63 api_routes
64 .route_layer(axum::middleware::from_fn_with_state(
65 api_keys,
66 auth_middleware,
67 ))
68 .with_state(state.clone())
69 };
70
71 let rate_limiter = if rate_limit_rps > 0 {
72 Some(Arc::new(RateLimiter::new(rate_limit_rps)))
73 } else {
74 None
75 };
76
77 Router::new()
78 .route(
79 "/health",
80 get(routes::health::health).fallback(method_not_allowed),
81 )
82 .route(
83 "/ready",
84 get(routes::health::ready).fallback(method_not_allowed),
85 )
86 .route(
87 "/metrics",
88 get(routes::metrics::metrics).fallback(method_not_allowed),
89 )
90 .route(
91 "/metrics/renderer-breakers",
92 get(routes::breakers::renderer_breakers).fallback(method_not_allowed),
93 )
94 .route(
95 "/admin/breakers/reset",
96 post(routes::breakers::reset_breakers).fallback(method_not_allowed),
97 )
98 .with_state(state)
99 .merge(api_routes)
100 .layer(axum::middleware::from_fn(move |req, next| {
101 let limiter = rate_limiter.clone();
102 rate_limit_middleware(limiter, req, next)
103 }))
104 .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
105 .layer(TimeoutLayer::with_status_code(
106 StatusCode::GATEWAY_TIMEOUT,
107 timeout,
108 ))
109 .layer(SetResponseHeaderLayer::overriding(
110 axum::http::header::X_CONTENT_TYPE_OPTIONS,
111 axum::http::HeaderValue::from_static("nosniff"),
112 ))
113 .layer(SetResponseHeaderLayer::overriding(
114 axum::http::header::X_FRAME_OPTIONS,
115 axum::http::HeaderValue::from_static("DENY"),
116 ))
117 .layer(CorsLayer::permissive())
118 .layer(TraceLayer::new_for_http())
119}
120
121async fn method_not_allowed() -> impl IntoResponse {
122 (
123 StatusCode::METHOD_NOT_ALLOWED,
124 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
125 "Method not allowed",
126 "method_not_allowed",
127 )),
128 )
129}
130
/// Token-bucket rate limiter shared by all requests that pass through one router.
///
/// The bucket holds at most `rps` tokens and is refilled continuously at `rps`
/// tokens per second; every admitted request spends one token. Bursts of up to
/// `rps` requests are allowed, and the sustained rate is capped at `rps`
/// requests per second.
///
/// Refill and spend happen under one mutex, so a refill can never overwrite a
/// concurrent spend. Tokens are tracked as `f64` so that the small refills
/// earned between closely spaced requests accumulate instead of being
/// truncated away.
struct RateLimiter {
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_per_sec: f64,
    /// Mutable bucket state.
    state: std::sync::Mutex<RateLimiterState>,
}

/// Mutable part of a [`RateLimiter`], guarded by its mutex.
struct RateLimiterState {
    /// Tokens currently available; may be fractional.
    tokens: f64,
    /// Latest instant up to which refill has already been credited.
    last_refill: std::time::Instant,
}

impl RateLimiter {
    /// Creates a limiter that admits `rps` requests per second and starts with
    /// a full bucket. With `rps == 0` every request is rejected.
    fn new(rps: u64) -> Self {
        // Very large values lose precision as f64, which only makes the limiter
        // more permissive. Unlike the integer arithmetic it replaces, nothing
        // can overflow.
        let capacity = rps as f64;
        Self {
            capacity,
            refill_per_sec: capacity,
            state: std::sync::Mutex::new(RateLimiterState {
                tokens: capacity,
                last_refill: std::time::Instant::now(),
            }),
        }
    }

    /// Attempts to take one token now. Returns `true` if the request may proceed.
    fn try_acquire(&self) -> bool {
        self.try_acquire_at(std::time::Instant::now())
    }

    /// Attempts to take one token as of `now`.
    ///
    /// First credits the tokens earned since the last refill (capped at
    /// `capacity`), then spends one token if at least one is available.
    fn try_acquire_at(&self, now: std::time::Instant) -> bool {
        // The critical section cannot panic, but if the lock is ever poisoned
        // the bucket state is still usable. Keep serving rather than panicking
        // on every subsequent request.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Concurrent callers capture `now` before taking the lock, so `now` may
        // be slightly older than `last_refill`. Saturate to zero elapsed time
        // and never move `last_refill` backwards.
        let elapsed = now.saturating_duration_since(state.last_refill);
        state.tokens =
            (state.tokens + elapsed.as_secs_f64() * self.refill_per_sec).min(self.capacity);
        state.last_refill = state.last_refill.max(now);

        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
178
179async fn rate_limit_middleware(
180 limiter: Option<Arc<RateLimiter>>,
181 req: Request<Body>,
182 next: Next,
183) -> Response {
184 if let Some(limiter) = limiter
185 && req.uri().path() != "/health"
186 && req.uri().path() != "/ready"
187 && !limiter.try_acquire()
188 {
189 return (
190 StatusCode::TOO_MANY_REQUESTS,
191 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
192 "Rate limited",
193 "rate_limited",
194 )),
195 )
196 .into_response();
197 }
198 next.run(req).await
199}