1use axum::Router;
2use axum::body::Body;
3use axum::extract::DefaultBodyLimit;
4use axum::http::{Request, StatusCode};
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use tower_http::cors::CorsLayer;
12use tower_http::set_header::SetResponseHeaderLayer;
13use tower_http::timeout::TimeoutLayer;
14use tower_http::trace::TraceLayer;
15
/// Largest request body accepted by the router (1 MiB), applied via `DefaultBodyLimit`.
const MAX_BODY_SIZE: usize = 1 << 20;
18
19use crate::middleware::auth_middleware;
20use crate::routes::{self, method_not_allowed};
21use crate::state::AppState;
22
23pub fn create_app(state: AppState) -> Router {
24 let api_keys = Arc::new(state.config.auth.api_keys.clone());
25 let timeout = Duration::from_secs(state.config.effective_request_timeout_secs());
30 let rate_limit_rps = state.config.server.rate_limit_rps;
31
32 let api_routes = routes::v1::router().merge(routes::v2::router()).route(
36 "/mcp",
37 post(routes::mcp::mcp_handler).fallback(method_not_allowed),
38 );
39
40 let api_routes = if api_keys.is_empty() {
41 api_routes.with_state(state.clone())
42 } else {
43 api_routes
44 .route_layer(axum::middleware::from_fn_with_state(
45 api_keys,
46 auth_middleware,
47 ))
48 .with_state(state.clone())
49 };
50
51 let rate_limiter = if rate_limit_rps > 0 {
52 Some(Arc::new(RateLimiter::new(rate_limit_rps)))
53 } else {
54 None
55 };
56
57 Router::new()
58 .route(
59 "/health",
60 get(routes::health::health).fallback(method_not_allowed),
61 )
62 .route(
63 "/openapi.json",
64 get(routes::openapi::serve_openapi_3_1).fallback(method_not_allowed),
65 )
66 .route(
67 "/openapi-3.0.json",
68 get(routes::openapi::serve_openapi_3_0).fallback(method_not_allowed),
69 )
70 .route(
71 "/ready",
72 get(routes::health::ready).fallback(method_not_allowed),
73 )
74 .route(
75 "/metrics",
76 get(routes::metrics::metrics).fallback(method_not_allowed),
77 )
78 .route(
79 "/metrics/renderer-breakers",
80 get(routes::breakers::renderer_breakers).fallback(method_not_allowed),
81 )
82 .route(
83 "/admin/breakers/reset",
84 post(routes::breakers::reset_breakers).fallback(method_not_allowed),
85 )
86 .with_state(state)
87 .merge(api_routes)
88 .layer(axum::middleware::from_fn(move |req, next| {
89 let limiter = rate_limiter.clone();
90 rate_limit_middleware(limiter, req, next)
91 }))
92 .layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
93 .layer(TimeoutLayer::with_status_code(
94 StatusCode::GATEWAY_TIMEOUT,
95 timeout,
96 ))
97 .layer(SetResponseHeaderLayer::overriding(
98 axum::http::header::X_CONTENT_TYPE_OPTIONS,
99 axum::http::HeaderValue::from_static("nosniff"),
100 ))
101 .layer(SetResponseHeaderLayer::overriding(
102 axum::http::header::X_FRAME_OPTIONS,
103 axum::http::HeaderValue::from_static("DENY"),
104 ))
105 .layer(CorsLayer::permissive())
106 .layer(TraceLayer::new_for_http())
107}
108
/// Token-bucket rate limiter shared by every request passing through the
/// router built in `create_app` (consulted from `rate_limit_middleware`).
///
/// The bucket holds at most `max_tokens` tokens (the configured requests per
/// second). It earns `max_tokens` tokens per second of elapsed time, and each
/// admitted request spends one token. Refill bookkeeping is serialized by the
/// `last_refill` mutex. Spending is a lock-free atomic update that can never
/// be overwritten by a refill.
struct RateLimiter {
    /// Tokens currently available for admission.
    tokens: AtomicU64,
    /// Bucket capacity, which is also the refill rate in tokens per second.
    max_tokens: u64,
    /// Instant up to which elapsed time has already been converted into tokens.
    last_refill: std::sync::Mutex<std::time::Instant>,
}

impl RateLimiter {
    /// Creates a limiter admitting `rps` requests per second, starting with a
    /// full bucket.
    ///
    /// With `rps == 0` the bucket starts empty and never refills, so every
    /// acquisition fails. `create_app` only builds a limiter when `rps > 0`.
    fn new(rps: u64) -> Self {
        Self {
            tokens: AtomicU64::new(rps),
            max_tokens: rps,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Attempts to spend one token, returning `true` if the request may proceed.
    fn try_acquire(&self) -> bool {
        self.try_acquire_at(std::time::Instant::now())
    }

    /// [`try_acquire`](Self::try_acquire) evaluated at an explicit instant,
    /// so the refill arithmetic can be exercised deterministically.
    fn try_acquire_at(&self, now: std::time::Instant) -> bool {
        self.refill(now);
        // `checked_sub` yields `None` at zero, which makes `fetch_update` fail
        // without touching the counter.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Credits the whole tokens earned between the last refill and `now`,
    /// carrying any fractional remainder forward to the next call.
    fn refill(&self, now: std::time::Instant) {
        if self.max_tokens == 0 {
            return;
        }
        // The guarded `Instant` is always valid, so a panic elsewhere while the
        // lock was held must not turn every later request into a panic.
        let mut last = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let rate = self.max_tokens as f64;
        let elapsed = now.saturating_duration_since(*last);
        // Float-to-int `as` casts saturate, so a very long idle period clamps
        // to `u64::MAX` instead of wrapping.
        let earned = (elapsed.as_secs_f64() * rate) as u64;
        if earned == 0 {
            // Not a whole token yet: keep `last` so the fraction accumulates.
            return;
        }

        // Add atomically (retrying under contention) so concurrent spends are
        // never lost, and saturate so the sum cannot overflow.
        let previous = match self.tokens.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| {
            Some(t.saturating_add(earned).min(self.max_tokens))
        }) {
            Ok(prev) | Err(prev) => prev,
        };

        if previous.saturating_add(earned) >= self.max_tokens {
            // The bucket is full; surplus time cannot be banked.
            *last = now;
        } else {
            // Charge only the time that paid for whole tokens. Here
            // `earned < max_tokens`, so this step is shorter than one second.
            *last += Duration::from_secs_f64(earned as f64 / rate);
        }
    }
}
156
157async fn rate_limit_middleware(
158 limiter: Option<Arc<RateLimiter>>,
159 req: Request<Body>,
160 next: Next,
161) -> Response {
162 if let Some(limiter) = limiter
163 && req.uri().path() != "/health"
164 && req.uri().path() != "/ready"
165 && !limiter.try_acquire()
166 {
167 return (
168 StatusCode::TOO_MANY_REQUESTS,
169 axum::Json(crw_core::types::ApiResponse::<()>::err_with_code(
170 "Rate limited",
171 "rate_limited",
172 )),
173 )
174 .into_response();
175 }
176 next.run(req).await
177}