Skip to main content

auth_framework/api/
middleware.rs

1//! API Middleware
2//!
3//! Authentication, authorization, rate limiting, and other middleware
4
5use crate::api::{ApiResponse, ApiState};
6use crate::distributed::rate_limiting::RateLimitResult;
7use axum::{
8    extract::Request,
9    http::StatusCode,
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13use std::time::{Duration, Instant};
14
/// Produce a log-safe copy of a header value.
///
/// Control characters (CR, LF, tab, NUL, ...) are dropped so a hostile header cannot
/// forge additional log lines, and at most 200 of the surviving characters are kept
/// so a single oversized header cannot flood the log.
fn sanitize_header_for_log(value: &str) -> String {
    /// Upper bound on the number of characters copied into the log.
    const MAX_LOGGED_CHARS: usize = 200;

    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in value.chars() {
        if kept == MAX_LOGGED_CHARS {
            break;
        }
        // A plain space is not a control character, so it is always retained.
        if ch.is_control() {
            continue;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
25
26/// Rate limiting middleware backed by [`crate::distributed::rate_limiting::DistributedRateLimiter`].
27///
28/// Uses the client IP address (from `X-Forwarded-For`, then `X-Real-IP`, falling back to
29/// `"unknown"`) as the rate-limiting key.  Returns **429 Too Many Requests** when the limit
30/// is exceeded and adds standard `X-RateLimit-*` response headers on every response.
31pub async fn rate_limit_middleware_with_state(
32    state: ApiState,
33    request: Request,
34    next: Next,
35) -> Result<Response, Response> {
36    // Derive a stable key from the client IP.
37    // SECURITY (H-4): Validate the extracted value as a real IP address before using it
38    // as a rate-limit key.  An attacker who can set arbitrary X-Forwarded-For headers
39    // could bypass per-IP limiting by injecting a fabricated first address.  Requiring a
40    // valid parse means only syntactically correct IPs are accepted; anything else falls
41    // through to the next header or the fallback string "unknown".
42    let client_key = request
43        .headers()
44        .get("x-forwarded-for")
45        .and_then(|v| v.to_str().ok())
46        .and_then(|s| s.split(',').next())
47        .map(str::trim)
48        .and_then(|s| s.parse::<std::net::IpAddr>().ok().map(|ip| ip.to_string()))
49        .or_else(|| {
50            request
51                .headers()
52                .get("x-real-ip")
53                .and_then(|v| v.to_str().ok())
54                .map(str::trim)
55                .and_then(|s| s.parse::<std::net::IpAddr>().ok().map(|ip| ip.to_string()))
56        })
57        .unwrap_or_else(|| {
58            tracing::warn!(
59                "Rate limiter: no identifiable client IP from X-Forwarded-For or X-Real-IP; \
60                 falling back to shared 'unidentified' bucket"
61            );
62            "unidentified".to_string()
63        });
64
65    match state.rate_limiter.check_rate_limit(&client_key).await {
66        Ok(RateLimitResult::Allowed {
67            remaining,
68            reset_at,
69        }) => {
70            let mut response = next.run(request).await;
71            let headers = response.headers_mut();
72            let reset_secs = reset_at
73                .checked_duration_since(Instant::now())
74                .unwrap_or(Duration::ZERO)
75                .as_secs();
76            if let Ok(v) = remaining.to_string().parse() {
77                headers.insert("X-RateLimit-Remaining", v);
78            }
79            if let Ok(v) = reset_secs.to_string().parse() {
80                headers.insert("X-RateLimit-Reset", v);
81            }
82            Ok(response)
83        }
84        Ok(RateLimitResult::Denied { retry_after, .. }) => {
85            tracing::warn!(client = %client_key, "Rate limit exceeded");
86            let mut response = ApiResponse::<()>::error(
87                "RATE_LIMIT_EXCEEDED",
88                "Too many requests — please retry after the indicated delay",
89            )
90            .into_response();
91            *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
92            let headers = response.headers_mut();
93            if let Ok(v) = retry_after.as_secs().to_string().parse() {
94                headers.insert("Retry-After", v);
95            }
96            Err(response)
97        }
98        Ok(RateLimitResult::Blocked { unblock_at, reason }) => {
99            tracing::warn!(client = %client_key, reason = %reason, "Client is blocked");
100            let mut response = ApiResponse::<()>::error(
101                "CLIENT_BLOCKED",
102                "Access temporarily blocked due to repeated rate limit violations",
103            )
104            .into_response();
105            *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
106            let unblock_secs = unblock_at
107                .checked_duration_since(Instant::now())
108                .unwrap_or(Duration::ZERO)
109                .as_secs();
110            if let Ok(v) = unblock_secs.to_string().parse() {
111                let headers = response.headers_mut();
112                headers.insert("Retry-After", v);
113            }
114            Err(response)
115        }
116        Err(e) => {
117            tracing::error!(error = %e, "Rate limiter error — rejecting request");
118            let mut response = ApiResponse::<()>::error(
119                "RATE_LIMIT_UNAVAILABLE",
120                "Rate limiting is temporarily unavailable; request rejected to protect the service",
121            )
122            .into_response();
123            *response.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
124            Err(response)
125        }
126    }
127}
128
129// cors_middleware was removed: it set `Access-Control-Allow-Origin: *` which is
130// inappropriate for an authentication service.  Use tower-http's `CorsLayer`
131// with a configured `AllowOrigin` list instead.
132
133/// Log every incoming HTTP request with method, URI, client IP, and user-agent.
134///
135/// Elapsed time is logged on the response path.
136///
137/// # Example
138/// ```rust,ignore
139/// let app = Router::new()
140///     .layer(axum::middleware::from_fn(logging_middleware));
141/// ```
142pub async fn logging_middleware(request: Request, next: Next) -> Response {
143    let start = Instant::now();
144    let method = request.method().clone();
145    let uri = request.uri().clone();
146    let headers = request.headers().clone();
147
148    // Extract user agent and IP for logging, sanitized to prevent log injection
149    let user_agent = headers
150        .get("user-agent")
151        .and_then(|v| v.to_str().ok())
152        .unwrap_or("unknown");
153    let user_agent = sanitize_header_for_log(user_agent);
154
155    let forwarded_for = headers
156        .get("x-forwarded-for")
157        .and_then(|v| v.to_str().ok())
158        .unwrap_or("unknown");
159    let forwarded_for = sanitize_header_for_log(forwarded_for);
160
161    tracing::info!(
162        "Request started: {} {} from {} ({})",
163        method,
164        uri,
165        forwarded_for,
166        user_agent
167    );
168
169    let response = next.run(request).await;
170    let duration = start.elapsed();
171    let status = response.status();
172
173    tracing::info!(
174        "Request completed: {} {} {} in {:?}",
175        method,
176        uri,
177        status,
178        duration
179    );
180
181    response
182}
183
184/// Security headers middleware
185pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
186    let response = next.run(request).await;
187
188    let mut response = response;
189    let headers = response.headers_mut();
190
191    // Security headers — all values are well-known static strings so from_static is safe.
192    headers.insert(
193        "X-Content-Type-Options",
194        axum::http::HeaderValue::from_static("nosniff"),
195    );
196    headers.insert(
197        "X-Frame-Options",
198        axum::http::HeaderValue::from_static("DENY"),
199    );
200    headers.insert(
201        "X-XSS-Protection",
202        axum::http::HeaderValue::from_static("1; mode=block"),
203    );
204    headers.insert(
205        "Strict-Transport-Security",
206        axum::http::HeaderValue::from_static("max-age=31536000; includeSubDomains"),
207    );
208    headers.insert(
209        "Referrer-Policy",
210        axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
211    );
212    headers.insert(
213        "Permissions-Policy",
214        axum::http::HeaderValue::from_static("camera=(), microphone=(), geolocation=()"),
215    );
216    headers.insert(
217        "Content-Security-Policy",
218        axum::http::HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
219    );
220
221    response
222}
223
224/// Request timeout middleware
225pub async fn timeout_middleware(request: Request, next: Next) -> Result<Response, Response> {
226    // Set a 30-second timeout for all requests
227    match tokio::time::timeout(Duration::from_secs(30), next.run(request)).await {
228        Ok(response) => Ok(response),
229        Err(_) => {
230            let error_response =
231                ApiResponse::<()>::error("REQUEST_TIMEOUT", "Request timed out after 30 seconds");
232            Err(error_response.into_response())
233        }
234    }
235}
236
237/// Check whether `auth_token` carries the given permission.
238///
239/// Supports exact matches, a wildcard `"*"`, and prefix wildcards such as
240/// `"read:*"` (matches `"read:users"`, `"read:settings"`, etc.).
241///
242/// # Example
243/// ```rust,ignore
244/// if check_permission(&token, "users:write") {
245///     // allow the operation
246/// }
247/// ```
248pub fn check_permission(auth_token: &crate::tokens::AuthToken, required_permission: &str) -> bool {
249    auth_token.permissions.iter().any(|perm| {
250        perm == required_permission
251            || perm == "*"
252            || (perm.ends_with("*") && required_permission.starts_with(&perm[..perm.len() - 1]))
253    })
254}
255
256/// Check whether `auth_token` carries the given role.
257///
258/// The `"admin"` role implicitly matches every role.
259///
260/// # Example
261/// ```rust,ignore
262/// if check_role(&token, "moderator") {
263///     // allow moderation actions
264/// }
265/// ```
266pub fn check_role(auth_token: &crate::tokens::AuthToken, required_role: &str) -> bool {
267    auth_token.roles.contains(&required_role.to_string())
268        || auth_token.roles.contains(&"admin".to_string()) // Admin has all roles
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::{AuthToken, TokenMetadata};

    /// Build a minimal token carrying exactly the given permissions and roles.
    /// All other fields hold placeholder values that the checks under test never read.
    fn make_token(permissions: Vec<&str>, roles: Vec<&str>) -> AuthToken {
        AuthToken {
            token_id: "tid".into(),
            user_id: "uid".into(),
            access_token: "at".into(),
            token_type: Some("Bearer".into()),
            subject: Some("uid".into()),
            issuer: Some("iss".into()),
            refresh_token: None,
            issued_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now(),
            scopes: vec![].into(),
            auth_method: "jwt".into(),
            client_id: None,
            user_profile: None,
            permissions: permissions
                .into_iter()
                .map(String::from)
                .collect::<Vec<_>>()
                .into(),
            roles: roles
                .into_iter()
                .map(String::from)
                .collect::<Vec<_>>()
                .into(),
            metadata: TokenMetadata::default(),
        }
    }

    // ── sanitize_header_for_log ─────────────────────────────────────────

    #[test]
    fn test_sanitize_header_passes_plain_text_through() {
        assert_eq!(
            sanitize_header_for_log("Mozilla/5.0 (X11; Linux)"),
            "Mozilla/5.0 (X11; Linux)"
        );
        assert_eq!(sanitize_header_for_log(""), "");
    }

    #[test]
    fn test_sanitize_header_strips_control_characters() {
        // CR/LF are what a log-injection attempt relies on.
        assert_eq!(sanitize_header_for_log("a\r\nb\tc\0d"), "abcd");
    }

    #[test]
    fn test_sanitize_header_truncates_to_200_chars() {
        let long = "x".repeat(500);
        assert_eq!(sanitize_header_for_log(&long), "x".repeat(200));
    }

    #[test]
    fn test_sanitize_header_counts_chars_after_filtering() {
        // Stripped characters must not use up the budget, and the budget is measured
        // in characters rather than bytes.
        let mixed = format!("{}{}", "\n".repeat(300), "é".repeat(250));
        assert_eq!(sanitize_header_for_log(&mixed), "é".repeat(200));
    }

    // ── check_permission ────────────────────────────────────────────────

    #[test]
    fn test_check_permission_exact_match() {
        let token = make_token(vec!["users:read"], vec![]);
        assert!(check_permission(&token, "users:read"));
    }

    #[test]
    fn test_check_permission_no_match() {
        let token = make_token(vec!["users:read"], vec![]);
        assert!(!check_permission(&token, "users:write"));
    }

    #[test]
    fn test_check_permission_wildcard_all() {
        let token = make_token(vec!["*"], vec![]);
        assert!(check_permission(&token, "anything:at:all"));
    }

    #[test]
    fn test_check_permission_wildcard_prefix() {
        let token = make_token(vec!["users:*"], vec![]);
        assert!(check_permission(&token, "users:read"));
        assert!(check_permission(&token, "users:write"));
        assert!(!check_permission(&token, "admin:read"));
    }

    #[test]
    fn test_check_permission_wildcard_prefix_needs_whole_prefix() {
        let token = make_token(vec!["users:*"], vec![]);
        // "users" is shorter than the "users:" prefix, so it must not match.
        assert!(!check_permission(&token, "users"));
        assert!(check_permission(&token, "users:"));
    }

    #[test]
    fn test_check_permission_any_grant_suffices() {
        let token = make_token(vec!["reports:read", "users:*"], vec![]);
        assert!(check_permission(&token, "reports:read"));
        assert!(check_permission(&token, "users:delete"));
        assert!(!check_permission(&token, "reports:write"));
    }

    #[test]
    fn test_check_permission_empty() {
        let token = make_token(vec![], vec![]);
        assert!(!check_permission(&token, "anything"));
    }

    // ── check_role ──────────────────────────────────────────────────────

    #[test]
    fn test_check_role_exact_match() {
        let token = make_token(vec![], vec!["editor"]);
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_no_match() {
        let token = make_token(vec![], vec!["editor"]);
        assert!(!check_role(&token, "moderator"));
    }

    #[test]
    fn test_check_role_admin_has_all_roles() {
        let token = make_token(vec![], vec!["admin"]);
        assert!(check_role(&token, "anything"));
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_admin_among_other_roles() {
        let token = make_token(vec![], vec!["editor", "admin"]);
        assert!(check_role(&token, "moderator"));
    }

    #[test]
    fn test_check_role_empty() {
        let token = make_token(vec![], vec![]);
        assert!(!check_role(&token, "user"));
    }
}
365}