1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
/// Request-body size limit (1 MiB), installed via `DefaultBodyLimit` in `create_app`.
const MAX_BODY_SIZE: usize = 1 << 20;
18
19use crate::middleware::auth_middleware;
20use crate::routes;
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24 let api_keys = Arc::new(state.config.auth.api_keys.clone());
25 let timeout = Duration::from_secs(state.config.effective_request_timeout_secs());
30 let rate_limit_rps = state.config.server.rate_limit_rps;
31
32 let api_routes = Router::new()
33 .route(
34 "/v1/scrape",
35 post(routes::scrape::scrape).fallback(method_not_allowed),
36 )
37 .route(
38 "/v1/crawl",
39 post(routes::crawl::start_crawl).fallback(method_not_allowed),
40 )
41 .route(
42 "/v1/crawl/{id}",
43 get(routes::crawl::get_crawl)
44 .delete(routes::crawl::cancel_crawl)
45 .fallback(method_not_allowed),
46 )
47 .route(
48 "/v1/map",
49 post(routes::map::map).fallback(method_not_allowed),
50 )
51 .route(
52 "/v1/search",
53 post(routes::search::search).fallback(method_not_allowed),
54 )
55 .route(
56 "/v1/capabilities",
57 get(routes::capabilities::capabilities).fallback(method_not_allowed),
58 )
59 .route(
60 "/mcp",
61 post(routes::mcp::mcp_handler).fallback(method_not_allowed),
62 );
63
64 let api_routes = if api_keys.is_empty() {
65 api_routes.with_state(state.clone())
66 } else {
67 api_routes
68 .route_layer(axum::middleware::from_fn_with_state(
69 api_keys,
70 auth_middleware,
71 ))
72 .with_state(state.clone())
73 };
74
75 let rate_limiter = if rate_limit_rps > 0 {
76 Some(Arc::new(RateLimiter::new(rate_limit_rps)))
77 } else {
78 None
79 };
80
81 Router::new()
82 .route(
83 "/health",
84 get(routes::health::health).fallback(method_not_allowed),
85 )
86 .route(
87 "/ready",
88 get(routes::health::ready).fallback(method_not_allowed),
89 )
90 .route(
91 "/metrics",
92 get(routes::metrics::metrics).fallback(method_not_allowed),
93 )
94 .route(
95 "/metrics/renderer-breakers",
96 get(routes::breakers::renderer_breakers).fallback(method_not_allowed),
97 )
98 .route(
99 "/admin/breakers/reset",
100 post(routes::breakers::reset_breakers).fallback(method_not_allowed),
101 )
102 .with_state(state)
103 .merge(api_routes)
104 .layer(axum::middleware::from_fn(move |req, next| {
105 let limiter = rate_limiter.clone();
106 rate_limit_middleware(limiter, req, next)
107 }))
108 .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
109 .layer(TimeoutLayer::with_status_code(
110 StatusCode::GATEWAY_TIMEOUT,
111 timeout,
112 ))
113 .layer(SetResponseHeaderLayer::overriding(
114 axum::http::header::X_CONTENT_TYPE_OPTIONS,
115 axum::http::HeaderValue::from_static("nosniff"),
116 ))
117 .layer(SetResponseHeaderLayer::overriding(
118 axum::http::header::X_FRAME_OPTIONS,
119 axum::http::HeaderValue::from_static("DENY"),
120 ))
121 .layer(CorsLayer::permissive())
122 .layer(TraceLayer::new_for_http())
123}
124
125async fn method_not_allowed() -> impl IntoResponse {
126 (
127 StatusCode::METHOD_NOT_ALLOWED,
128 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
129 "Method not allowed",
130 "method_not_allowed",
131 )),
132 )
133}
134
/// Process-wide limiter admitting at most `rps` requests per second.
///
/// Implemented with the generic cell rate algorithm (GCRA). It behaves like a
/// token bucket holding up to `rps` tokens that refills *continuously* at `rps`
/// tokens per second. Its whole mutable state is a single atomic: the
/// "theoretical arrival time" (TAT) of the next request. Admission is therefore
/// lock-free, and no separate refill step can race with concurrent acquisitions.
///
/// An idle limiter admits a burst of up to `rps` requests at once. After that,
/// one request is admitted every `1s / rps`.
struct RateLimiter {
    /// Configured requests per second; `0` means every request is rejected.
    rps: u64,
    /// Reference instant used to express times as nanosecond offsets.
    epoch: std::time::Instant,
    /// Spacing between requests at the sustained rate, in nanoseconds.
    emission_interval_ns: u64,
    /// How far (in ns) the TAT may run ahead of "now" while still admitting a
    /// request. It equals `rps - 1` intervals, which yields a burst of `rps`.
    burst_tolerance_ns: u64,
    /// Theoretical arrival time of the next request, in ns since `epoch`.
    tat_ns: AtomicU64,
}

impl RateLimiter {
    const NANOS_PER_SEC: u64 = 1_000_000_000;

    /// Creates a limiter allowing `rps` requests per second, with a burst of `rps`.
    ///
    /// `rps == 0` produces a limiter that rejects every request.
    fn new(rps: u64) -> Self {
        // Clamp to 1ns so extremely high rates never yield a zero interval.
        let emission_interval_ns = (Self::NANOS_PER_SEC / rps.max(1)).max(1);
        Self {
            rps,
            epoch: std::time::Instant::now(),
            emission_interval_ns,
            burst_tolerance_ns: rps.saturating_sub(1).saturating_mul(emission_interval_ns),
            tat_ns: AtomicU64::new(0),
        }
    }

    /// Tries to admit one request at the current time.
    ///
    /// Returns `true` if the request is allowed and `false` if it is rate limited.
    fn try_acquire(&self) -> bool {
        self.try_acquire_at(std::time::Instant::now())
    }

    /// Tries to admit one request arriving at `now`.
    ///
    /// Returns `true` if the request is allowed and `false` if it is rate limited.
    /// An instant earlier than the limiter's creation is treated as its creation time.
    fn try_acquire_at(&self, now: std::time::Instant) -> bool {
        if self.rps == 0 {
            return false;
        }
        let now_ns = u64::try_from(now.saturating_duration_since(self.epoch).as_nanos())
            .unwrap_or(u64::MAX);
        // Relaxed is sufficient here. `tat_ns` is the only shared state and publishes
        // no other memory, and the CAS loop inside `fetch_update` guarantees that
        // concurrent admissions are never lost.
        self.tat_ns
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |tat| {
                // A TAT in the past means the limiter has been idle. Raising it to
                // `now` caps the accumulated credit at one full burst.
                let tat = tat.max(now_ns);
                (tat - now_ns <= self.burst_tolerance_ns)
                    .then_some(tat.saturating_add(self.emission_interval_ns))
            })
            .is_ok()
    }
}
182
183async fn rate_limit_middleware(
184 limiter: Option<Arc<RateLimiter>>,
185 req: Request<Body>,
186 next: Next,
187) -> Response {
188 if let Some(limiter) = limiter
189 && req.uri().path() != "/health"
190 && req.uri().path() != "/ready"
191 && !limiter.try_acquire()
192 {
193 return (
194 StatusCode::TOO_MANY_REQUESTS,
195 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
196 "Rate limited",
197 "rate_limited",
198 )),
199 )
200 .into_response();
201 }
202 next.run(req).await
203}