// auth_framework/api/middleware.rs

//! API Middleware
//!
//! Authentication, authorization, rate limiting, CORS, request logging, security-header, and
//! request-timeout middleware.

use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
use axum::{
    extract::{Request, State},
    http::HeaderValue,
    middleware::Next,
    response::{IntoResponse, Response},
};
use std::time::{Duration, Instant};

13/// Authentication middleware
14pub async fn auth_middleware(
15    State(state): State<ApiState>,
16    mut request: Request,
17    next: Next,
18) -> Result<Response, Response> {
19    // Skip auth for public endpoints
20    let path = request.uri().path();
21    if is_public_endpoint(path) {
22        return Ok(next.run(request).await);
23    }
24
25    // Extract token from headers
26    let headers = request.headers();
27    match extract_bearer_token(headers) {
28        Some(token) => {
29            // Validate token
30            match validate_api_token(&state.auth_framework, &token).await {
31                Ok(auth_token) => {
32                    // Add auth token to request extensions for use in handlers
33                    request.extensions_mut().insert(auth_token);
34                    Ok(next.run(request).await)
35                }
36                Err(_) => {
37                    let error_response = ApiResponse::<()>::unauthorized();
38                    Err(error_response.into_response())
39                }
40            }
41        }
42        None => {
43            let error_response = ApiResponse::<()>::unauthorized();
44            Err(error_response.into_response())
45        }
46    }
47}
48
49/// Admin authorization middleware
50pub async fn admin_middleware(
51    State(_state): State<ApiState>,
52    request: Request,
53    next: Next,
54) -> Result<Response, Response> {
55    // Get auth token from extensions (should be set by auth_middleware)
56    match request.extensions().get::<crate::tokens::AuthToken>() {
57        Some(auth_token) => {
58            if auth_token.roles.contains(&"admin".to_string()) {
59                Ok(next.run(request).await)
60            } else {
61                let error_response = ApiResponse::<()>::forbidden();
62                Err(error_response.into_response())
63            }
64        }
65        None => {
66            // If no auth token, user is not authenticated
67            let error_response = ApiResponse::<()>::unauthorized();
68            Err(error_response.into_response())
69        }
70    }
71}
72
73/// Rate limiting middleware
74pub async fn rate_limit_middleware(request: Request, next: Next) -> Result<Response, Response> {
75    // In a real implementation, use a distributed rate limiter like Redis
76    // For now, just add rate limit headers
77
78    let mut response = next.run(request).await;
79
80    // Add rate limit headers
81    let headers = response.headers_mut();
82    headers.insert("X-RateLimit-Limit", "100".parse().unwrap());
83    headers.insert("X-RateLimit-Remaining", "95".parse().unwrap());
84    headers.insert("X-RateLimit-Reset", "1692278400".parse().unwrap()); // Unix timestamp
85
86    Ok(response)
87}
88
89/// CORS middleware
90pub async fn cors_middleware(request: Request, next: Next) -> Response {
91    let response = next.run(request).await;
92
93    let mut response = response;
94    let headers = response.headers_mut();
95
96    headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
97    headers.insert(
98        "Access-Control-Allow-Methods",
99        "GET, POST, PUT, DELETE, OPTIONS".parse().unwrap(),
100    );
101    headers.insert(
102        "Access-Control-Allow-Headers",
103        "Content-Type, Authorization".parse().unwrap(),
104    );
105    headers.insert("Access-Control-Max-Age", "3600".parse().unwrap());
106
107    response
108}
109
110/// Logging middleware
111pub async fn logging_middleware(request: Request, next: Next) -> Response {
112    let start = Instant::now();
113    let method = request.method().clone();
114    let uri = request.uri().clone();
115    let headers = request.headers().clone();
116
117    // Extract user agent and IP for logging
118    let user_agent = headers
119        .get("user-agent")
120        .and_then(|v| v.to_str().ok())
121        .unwrap_or("unknown");
122
123    let forwarded_for = headers
124        .get("x-forwarded-for")
125        .and_then(|v| v.to_str().ok())
126        .unwrap_or("unknown");
127
128    tracing::info!(
129        "Request started: {} {} from {} ({})",
130        method,
131        uri,
132        forwarded_for,
133        user_agent
134    );
135
136    let response = next.run(request).await;
137    let duration = start.elapsed();
138    let status = response.status();
139
140    tracing::info!(
141        "Request completed: {} {} {} in {:?}",
142        method,
143        uri,
144        status,
145        duration
146    );
147
148    response
149}
150
151/// Security headers middleware
152pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
153    let response = next.run(request).await;
154
155    let mut response = response;
156    let headers = response.headers_mut();
157
158    // Security headers
159    headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
160    headers.insert("X-Frame-Options", "DENY".parse().unwrap());
161    headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
162    headers.insert(
163        "Strict-Transport-Security",
164        "max-age=31536000; includeSubDomains".parse().unwrap(),
165    );
166    headers.insert(
167        "Referrer-Policy",
168        "strict-origin-when-cross-origin".parse().unwrap(),
169    );
170    headers.insert(
171        "Permissions-Policy",
172        "camera=(), microphone=(), geolocation=()".parse().unwrap(),
173    );
174
175    response
176}
177
178/// Request timeout middleware
179pub async fn timeout_middleware(request: Request, next: Next) -> Result<Response, Response> {
180    // Set a 30-second timeout for all requests
181    match tokio::time::timeout(Duration::from_secs(30), next.run(request)).await {
182        Ok(response) => Ok(response),
183        Err(_) => {
184            let error_response =
185                ApiResponse::<()>::error("REQUEST_TIMEOUT", "Request timed out after 30 seconds");
186            Err(error_response.into_response())
187        }
188    }
189}
190
/// Check if endpoint is public (doesn't require authentication).
///
/// Returns `true` in two cases:
///
/// - `path` exactly matches one of the listed paths. The comparison is case-sensitive, with
///   no trailing-slash normalisation.
/// - `path` lies under `/oauth/.well-known/`.
fn is_public_endpoint(path: &str) -> bool {
    // Paths that bypass authentication when matched exactly.
    const PUBLIC_PATHS: &[&str] = &[
        // Health and monitoring probes.
        "/health",
        "/health/detailed",
        "/metrics",
        "/readiness",
        "/liveness",
        // Authentication entry points.
        "/auth/login",
        "/auth/refresh",
        "/auth/providers",
        // OAuth entry points.
        "/oauth/authorize",
        "/oauth/token",
    ];

    // The prefix check covers every discovery document (e.g. `openid-configuration`,
    // `jwks.json`), so those need no exact entries of their own.
    PUBLIC_PATHS.contains(&path) || path.starts_with("/oauth/.well-known/")
}

202/// Permission check helper
203pub fn check_permission(auth_token: &crate::tokens::AuthToken, required_permission: &str) -> bool {
204    auth_token.permissions.iter().any(|perm| {
205        perm == required_permission
206            || perm == "*"
207            || (perm.ends_with("*") && required_permission.starts_with(&perm[..perm.len() - 1]))
208    })
209}
210
211/// Role check helper
212pub fn check_role(auth_token: &crate::tokens::AuthToken, required_role: &str) -> bool {
213    auth_token.roles.contains(&required_role.to_string())
214        || auth_token.roles.contains(&"admin".to_string()) // Admin has all roles
215}
216
217