1use crate::api::{ApiResponse, ApiState};
6use crate::distributed::rate_limiting::RateLimitResult;
7use axum::{
8 extract::Request,
9 http::StatusCode,
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13use std::time::{Duration, Instant};
14
/// Makes an untrusted header value safe to embed in a log line.
///
/// Every character for which `char::is_control` is true (CR, LF, tab, DEL and
/// the rest of the C0/C1 ranges) is dropped so a crafted header cannot break
/// or forge log lines. At most 200 surviving characters (not bytes) are kept;
/// dropped control characters do not count toward that limit.
fn sanitize_header_for_log(value: &str) -> String {
    const MAX_LOGGED_CHARS: usize = 200;

    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in value.chars() {
        if kept == MAX_LOGGED_CHARS {
            break;
        }
        // A plain space is never a control character, so it is always kept.
        if ch.is_control() && ch != ' ' {
            continue;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
25
26pub async fn rate_limit_middleware_with_state(
32 state: ApiState,
33 request: Request,
34 next: Next,
35) -> Result<Response, Response> {
36 let client_key = request
43 .headers()
44 .get("x-forwarded-for")
45 .and_then(|v| v.to_str().ok())
46 .and_then(|s| s.split(',').next())
47 .map(str::trim)
48 .and_then(|s| s.parse::<std::net::IpAddr>().ok().map(|ip| ip.to_string()))
49 .or_else(|| {
50 request
51 .headers()
52 .get("x-real-ip")
53 .and_then(|v| v.to_str().ok())
54 .map(str::trim)
55 .and_then(|s| s.parse::<std::net::IpAddr>().ok().map(|ip| ip.to_string()))
56 })
57 .unwrap_or_else(|| {
58 tracing::warn!(
59 "Rate limiter: no identifiable client IP from X-Forwarded-For or X-Real-IP; \
60 falling back to shared 'unidentified' bucket"
61 );
62 "unidentified".to_string()
63 });
64
65 match state.rate_limiter.check_rate_limit(&client_key).await {
66 Ok(RateLimitResult::Allowed {
67 remaining,
68 reset_at,
69 }) => {
70 let mut response = next.run(request).await;
71 let headers = response.headers_mut();
72 let reset_secs = reset_at
73 .checked_duration_since(Instant::now())
74 .unwrap_or(Duration::ZERO)
75 .as_secs();
76 if let Ok(v) = remaining.to_string().parse() {
77 headers.insert("X-RateLimit-Remaining", v);
78 }
79 if let Ok(v) = reset_secs.to_string().parse() {
80 headers.insert("X-RateLimit-Reset", v);
81 }
82 Ok(response)
83 }
84 Ok(RateLimitResult::Denied { retry_after, .. }) => {
85 tracing::warn!(client = %client_key, "Rate limit exceeded");
86 let mut response = ApiResponse::<()>::error(
87 "RATE_LIMIT_EXCEEDED",
88 "Too many requests — please retry after the indicated delay",
89 )
90 .into_response();
91 *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
92 let headers = response.headers_mut();
93 if let Ok(v) = retry_after.as_secs().to_string().parse() {
94 headers.insert("Retry-After", v);
95 }
96 Err(response)
97 }
98 Ok(RateLimitResult::Blocked { unblock_at, reason }) => {
99 tracing::warn!(client = %client_key, reason = %reason, "Client is blocked");
100 let mut response = ApiResponse::<()>::error(
101 "CLIENT_BLOCKED",
102 "Access temporarily blocked due to repeated rate limit violations",
103 )
104 .into_response();
105 *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
106 let unblock_secs = unblock_at
107 .checked_duration_since(Instant::now())
108 .unwrap_or(Duration::ZERO)
109 .as_secs();
110 if let Ok(v) = unblock_secs.to_string().parse() {
111 let headers = response.headers_mut();
112 headers.insert("Retry-After", v);
113 }
114 Err(response)
115 }
116 Err(e) => {
117 tracing::error!(error = %e, "Rate limiter error — rejecting request");
118 let mut response = ApiResponse::<()>::error(
119 "RATE_LIMIT_UNAVAILABLE",
120 "Rate limiting is temporarily unavailable; request rejected to protect the service",
121 )
122 .into_response();
123 *response.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
124 Err(response)
125 }
126 }
127}
128
129pub async fn logging_middleware(request: Request, next: Next) -> Response {
143 let start = Instant::now();
144 let method = request.method().clone();
145 let uri = request.uri().clone();
146 let headers = request.headers().clone();
147
148 let user_agent = headers
150 .get("user-agent")
151 .and_then(|v| v.to_str().ok())
152 .unwrap_or("unknown");
153 let user_agent = sanitize_header_for_log(user_agent);
154
155 let forwarded_for = headers
156 .get("x-forwarded-for")
157 .and_then(|v| v.to_str().ok())
158 .unwrap_or("unknown");
159 let forwarded_for = sanitize_header_for_log(forwarded_for);
160
161 tracing::info!(
162 "Request started: {} {} from {} ({})",
163 method,
164 uri,
165 forwarded_for,
166 user_agent
167 );
168
169 let response = next.run(request).await;
170 let duration = start.elapsed();
171 let status = response.status();
172
173 tracing::info!(
174 "Request completed: {} {} {} in {:?}",
175 method,
176 uri,
177 status,
178 duration
179 );
180
181 response
182}
183
184pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
186 let response = next.run(request).await;
187
188 let mut response = response;
189 let headers = response.headers_mut();
190
191 headers.insert(
193 "X-Content-Type-Options",
194 axum::http::HeaderValue::from_static("nosniff"),
195 );
196 headers.insert(
197 "X-Frame-Options",
198 axum::http::HeaderValue::from_static("DENY"),
199 );
200 headers.insert(
201 "X-XSS-Protection",
202 axum::http::HeaderValue::from_static("1; mode=block"),
203 );
204 headers.insert(
205 "Strict-Transport-Security",
206 axum::http::HeaderValue::from_static("max-age=31536000; includeSubDomains"),
207 );
208 headers.insert(
209 "Referrer-Policy",
210 axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
211 );
212 headers.insert(
213 "Permissions-Policy",
214 axum::http::HeaderValue::from_static("camera=(), microphone=(), geolocation=()"),
215 );
216 headers.insert(
217 "Content-Security-Policy",
218 axum::http::HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
219 );
220
221 response
222}
223
224pub async fn timeout_middleware(request: Request, next: Next) -> Result<Response, Response> {
226 match tokio::time::timeout(Duration::from_secs(30), next.run(request)).await {
228 Ok(response) => Ok(response),
229 Err(_) => {
230 let error_response =
231 ApiResponse::<()>::error("REQUEST_TIMEOUT", "Request timed out after 30 seconds");
232 Err(error_response.into_response())
233 }
234 }
235}
236
237pub fn check_permission(auth_token: &crate::tokens::AuthToken, required_permission: &str) -> bool {
249 auth_token.permissions.iter().any(|perm| {
250 perm == required_permission
251 || perm == "*"
252 || (perm.ends_with("*") && required_permission.starts_with(&perm[..perm.len() - 1]))
253 })
254}
255
256pub fn check_role(auth_token: &crate::tokens::AuthToken, required_role: &str) -> bool {
267 auth_token.roles.contains(&required_role.to_string())
268 || auth_token.roles.contains(&"admin".to_string()) }
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::{AuthToken, TokenMetadata};

    /// Builds a minimal `AuthToken` carrying only the given permissions and roles.
    fn make_token(permissions: Vec<&str>, roles: Vec<&str>) -> AuthToken {
        AuthToken {
            token_id: "tid".into(),
            user_id: "uid".into(),
            access_token: "at".into(),
            token_type: Some("Bearer".into()),
            subject: Some("uid".into()),
            issuer: Some("iss".into()),
            refresh_token: None,
            issued_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now(),
            scopes: vec![].into(),
            auth_method: "jwt".into(),
            client_id: None,
            user_profile: None,
            permissions: permissions
                .into_iter()
                .map(String::from)
                .collect::<Vec<_>>()
                .into(),
            roles: roles
                .into_iter()
                .map(String::from)
                .collect::<Vec<_>>()
                .into(),
            metadata: TokenMetadata::default(),
        }
    }

    // ---- check_permission ----

    #[test]
    fn test_check_permission_exact_match() {
        let token = make_token(vec!["users:read"], vec![]);
        assert!(check_permission(&token, "users:read"));
    }

    #[test]
    fn test_check_permission_no_match() {
        let token = make_token(vec!["users:read"], vec![]);
        assert!(!check_permission(&token, "users:write"));
    }

    #[test]
    fn test_check_permission_wildcard_all() {
        let token = make_token(vec!["*"], vec![]);
        assert!(check_permission(&token, "anything:at:all"));
    }

    #[test]
    fn test_check_permission_wildcard_prefix() {
        let token = make_token(vec!["users:*"], vec![]);
        assert!(check_permission(&token, "users:read"));
        assert!(check_permission(&token, "users:write"));
        assert!(!check_permission(&token, "admin:read"));
    }

    #[test]
    fn test_check_permission_wildcard_prefix_requires_full_prefix() {
        let token = make_token(vec!["users:*"], vec![]);
        // The ":" before the "*" is part of the granted prefix.
        assert!(!check_permission(&token, "users"));
        assert!(!check_permission(&token, "user:read"));
    }

    #[test]
    fn test_check_permission_empty() {
        let token = make_token(vec![], vec![]);
        assert!(!check_permission(&token, "anything"));
    }

    // ---- check_role ----

    #[test]
    fn test_check_role_exact_match() {
        let token = make_token(vec![], vec!["editor"]);
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_no_match() {
        let token = make_token(vec![], vec!["editor"]);
        assert!(!check_role(&token, "moderator"));
    }

    #[test]
    fn test_check_role_admin_has_all_roles() {
        let token = make_token(vec![], vec!["admin"]);
        assert!(check_role(&token, "anything"));
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_admin_among_other_roles() {
        let token = make_token(vec![], vec!["viewer", "admin"]);
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_empty() {
        let token = make_token(vec![], vec![]);
        assert!(!check_role(&token, "user"));
    }

    // ---- sanitize_header_for_log ----

    #[test]
    fn test_sanitize_header_for_log_strips_control_characters() {
        // CR/LF would otherwise let a client forge additional log lines.
        assert_eq!(
            sanitize_header_for_log("curl/8.0\r\nX-Injected: 1"),
            "curl/8.0X-Injected: 1"
        );
        assert_eq!(sanitize_header_for_log("a\tb c"), "ab c");
    }

    #[test]
    fn test_sanitize_header_for_log_truncates_to_200_chars() {
        assert_eq!(sanitize_header_for_log(&"x".repeat(500)).chars().count(), 200);
        // The cap counts characters, not bytes.
        assert_eq!(sanitize_header_for_log(&"é".repeat(300)).chars().count(), 200);
    }

    #[test]
    fn test_sanitize_header_for_log_empty() {
        assert_eq!(sanitize_header_for_log(""), "");
    }
}