auth_framework/authorization_enhanced/
middleware.rs

1//! Enhanced authorization middleware using role-system v1.0
2//!
3//! This module provides comprehensive authorization middleware for Axum,
4//! replacing the basic role checking with enterprise-grade RBAC.
5
6use crate::api::{ApiResponse, ApiState};
7use crate::tokens::AuthToken;
8use axum::{
9    extract::{Request, State},
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13use std::collections::HashMap;
14use tracing::{debug, info, warn};
15
16/// Enhanced RBAC middleware using role-system v1.0
17pub async fn rbac_middleware(
18    State(state): State<ApiState>,
19    request: Request,
20    next: Next,
21) -> Result<Response, Response> {
22    // Skip auth for public endpoints
23    let path = request.uri().path();
24    if is_public_endpoint(path) {
25        return Ok(next.run(request).await);
26    }
27
28    // Get auth token from extensions (should be set by auth_middleware)
29    let auth_token = match request.extensions().get::<AuthToken>() {
30        Some(token) => token.clone(),
31        None => {
32            let error_response = ApiResponse::<()>::unauthorized();
33            return Err(error_response.into_response());
34        }
35    };
36
37    // Build request context for conditional permissions
38    let context = build_request_context(&request, &auth_token);
39
40    // Check authorization using role-system
41    let authorized = match check_authorization(&state, &auth_token, &request, &context).await {
42        Ok(granted) => granted,
43        Err(e) => {
44            warn!("Authorization check failed: {}", e);
45            let error_response = ApiResponse::<()>::forbidden();
46            return Err(error_response.into_response());
47        }
48    };
49
50    if authorized {
51        debug!(
52            "Authorization granted for user '{}' on {}",
53            auth_token.user_id, path
54        );
55        Ok(next.run(request).await)
56    } else {
57        info!(
58            "Authorization denied for user '{}' on {}",
59            auth_token.user_id, path
60        );
61        let error_response = ApiResponse::<()>::forbidden();
62        Err(error_response.into_response())
63    }
64}
65
66/// Conditional permission middleware for time/location-based access
67pub async fn conditional_permission_middleware(
68    State(state): State<ApiState>,
69    request: Request,
70    next: Next,
71) -> Result<Response, Response> {
72    let path = request.uri().path();
73
74    // Apply conditional permissions for sensitive endpoints
75    if is_sensitive_endpoint(path) {
76        let auth_token = match request.extensions().get::<AuthToken>() {
77            Some(token) => token,
78            None => {
79                let error_response = ApiResponse::<()>::unauthorized();
80                return Err(error_response.into_response());
81            }
82        };
83
84        let context = build_conditional_context(&request);
85
86        // Check conditional permissions
87        let has_conditional_access: bool = state
88            .authorization_service
89            .check_permission(&auth_token.user_id, "access", path, Some(&context))
90            .await
91            .unwrap_or_default();
92
93        if !has_conditional_access {
94            info!(
95                "Conditional access denied for user '{}' on {}",
96                auth_token.user_id, path
97            );
98            let error_response = ApiResponse::<()>::error(
99                "CONDITIONAL_ACCESS_DENIED",
100                "Access denied due to conditional permissions (time, location, etc.)",
101            );
102            return Err(error_response.into_response());
103        }
104    }
105
106    Ok(next.run(request).await)
107}
108
109/// Role elevation middleware for administrative actions
110pub async fn role_elevation_middleware(
111    State(state): State<ApiState>,
112    request: Request,
113    next: Next,
114) -> Result<Response, Response> {
115    let path = request.uri().path();
116
117    // Check if this endpoint requires elevated permissions
118    if requires_role_elevation(path) {
119        let auth_token = match request.extensions().get::<AuthToken>() {
120            Some(token) => token,
121            None => {
122                let error_response = ApiResponse::<()>::unauthorized();
123                return Err(error_response.into_response());
124            }
125        };
126
127        // Check if user has elevated permissions
128        let has_elevated_access: bool = state
129            .authorization_service
130            .check_permission(&auth_token.user_id, "elevated", "admin", None)
131            .await
132            .unwrap_or_default();
133
134        if !has_elevated_access {
135            info!(
136                "Elevated access required for user '{}' on {}",
137                auth_token.user_id, path
138            );
139            let error_response = ApiResponse::<()>::error(
140                "ELEVATION_REQUIRED",
141                "This action requires elevated permissions. Please request temporary role elevation.",
142            );
143            return Err(error_response.into_response());
144        }
145    }
146
147    Ok(next.run(request).await)
148}
149
150/// Check authorization using the enhanced authorization service
151async fn check_authorization(
152    state: &ApiState,
153    auth_token: &AuthToken,
154    request: &Request,
155    context: &HashMap<String, String>,
156) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
157    let method = request.method().as_str();
158    let path = request.uri().path();
159
160    // Use the enhanced authorization service
161    state
162        .authorization_service
163        .check_api_permission(&auth_token.user_id, method, path, context)
164        .await
165        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
166}
167
168/// Build request context for conditional permissions
169fn build_request_context(request: &Request, auth_token: &AuthToken) -> HashMap<String, String> {
170    let mut context = HashMap::new();
171
172    // Add user context
173    context.insert("user_id".to_string(), auth_token.user_id.clone());
174    context.insert("roles".to_string(), auth_token.roles.join(","));
175
176    // Add request metadata
177    if let Some(user_agent) = request.headers().get("user-agent")
178        && let Ok(ua_str) = user_agent.to_str()
179    {
180        context.insert("user_agent".to_string(), ua_str.to_string());
181    }
182
183    // Add IP address
184    if let Some(forwarded_for) = request.headers().get("x-forwarded-for")
185        && let Ok(ip_str) = forwarded_for.to_str()
186    {
187        context.insert("ip_address".to_string(), ip_str.to_string());
188    }
189
190    // Add time-based context
191    let current_hour = chrono::Utc::now().format("%H").to_string();
192    let hour: u32 = current_hour.parse().unwrap_or(0);
193
194    if (9..=17).contains(&hour) {
195        context.insert("time".to_string(), "business_hours".to_string());
196    } else {
197        context.insert("time".to_string(), "after_hours".to_string());
198    }
199
200    // Add day of week
201    let day_of_week = chrono::Utc::now().format("%u").to_string(); // 1-7, Monday = 1
202    let weekday: u32 = day_of_week.parse().unwrap_or(1);
203
204    if (1..=5).contains(&weekday) {
205        context.insert("day_type".to_string(), "weekday".to_string());
206    } else {
207        context.insert("day_type".to_string(), "weekend".to_string());
208    }
209
210    context
211}
212
213/// Build conditional context for sensitive operations
214fn build_conditional_context(request: &Request) -> HashMap<String, String> {
215    let mut context = HashMap::new();
216
217    // Check for VPN indicators
218    if let Some(via) = request.headers().get("via")
219        && let Ok(via_str) = via.to_str()
220        && (via_str.contains("vpn") || via_str.contains("proxy"))
221    {
222        context.insert("connection_type".to_string(), "vpn".to_string());
223    }
224
225    // Check for mobile device
226    if let Some(user_agent) = request.headers().get("user-agent")
227        && let Ok(ua_str) = user_agent.to_str()
228    {
229        if ua_str.contains("Mobile") || ua_str.contains("Android") || ua_str.contains("iPhone") {
230            context.insert("device_type".to_string(), "mobile".to_string());
231        } else {
232            context.insert("device_type".to_string(), "desktop".to_string());
233        }
234    }
235
236    // Add security level based on endpoint sensitivity
237    let path = request.uri().path();
238    if path.contains("/admin/") {
239        context.insert("security_level".to_string(), "high".to_string());
240    } else if path.contains("/api/") {
241        context.insert("security_level".to_string(), "medium".to_string());
242    } else {
243        context.insert("security_level".to_string(), "low".to_string());
244    }
245
246    context
247}
248
/// Check if endpoint is public (doesn't require authentication).
///
/// Matching is exact (no trailing-slash normalization), except that anything
/// under `/oauth/.well-known/` is public.
fn is_public_endpoint(path: &str) -> bool {
    // Exact paths reachable without a token. `/oauth/.well-known/...` paths
    // (including `openid_configuration`) are handled by the prefix check below.
    const PUBLIC_PATHS: [&str; 10] = [
        "/health",
        "/health/detailed",
        "/metrics",
        "/readiness",
        "/liveness",
        "/auth/login",
        "/auth/refresh",
        "/auth/providers",
        "/oauth/authorize",
        "/oauth/token",
    ];

    PUBLIC_PATHS.contains(&path) || path.starts_with("/oauth/.well-known/")
}
259
/// Check if endpoint contains sensitive data.
///
/// A path is sensitive when it is exactly `/auth/logout`, starts with
/// `/admin/`, or contains a `/secrets/`, `/config/` or `/keys/` segment
/// anywhere. Note that the bare `/admin` (no trailing slash) does not match.
fn is_sensitive_endpoint(path: &str) -> bool {
    const SENSITIVE_SEGMENTS: [&str; 3] = ["/secrets/", "/config/", "/keys/"];

    // Logout is listed explicitly so it also gets conditional access checks.
    path == "/auth/logout"
        || path.starts_with("/admin/")
        || SENSITIVE_SEGMENTS.iter().any(|&segment| path.contains(segment))
}
271
/// Check if endpoint requires role elevation.
///
/// True when the path starts with `/admin/users/delete` (a raw prefix, so
/// e.g. `/admin/users/deleted` also matches) or `/admin/system/`, or contains
/// a `/sudo/` or `/elevate/` segment anywhere.
fn requires_role_elevation(path: &str) -> bool {
    const ELEVATED_PREFIXES: [&str; 2] = ["/admin/users/delete", "/admin/system/"];
    const ELEVATED_SEGMENTS: [&str; 2] = ["/sudo/", "/elevate/"];

    let by_prefix = ELEVATED_PREFIXES.iter().any(|&prefix| path.starts_with(prefix));
    by_prefix || ELEVATED_SEGMENTS.iter().any(|&segment| path.contains(segment))
}
282
283/// Permission requirement middleware factory
284pub fn require_permission(
285    action: &str,
286    resource: &str,
287) -> impl Fn(
288    Request,
289    Next,
290) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Response>> + Send>>
291+ Clone {
292    let action = action.to_string();
293    let resource = resource.to_string();
294
295    move |request: Request, next: Next| {
296        let action = action.clone();
297        let resource = resource.clone();
298        Box::pin(async move {
299            let auth_token = match request.extensions().get::<AuthToken>() {
300                Some(token) => token,
301                None => {
302                    let error_response = ApiResponse::<()>::unauthorized();
303                    return Err(error_response.into_response());
304                }
305            };
306
307            // For this implementation, we'd need access to the authorization service
308            // This would typically be passed through the request state
309            // For now, we'll do a basic permission check using the auth token
310            if check_token_permission(auth_token, &action, &resource) {
311                Ok(next.run(request).await)
312            } else {
313                let error_response = ApiResponse::<()>::forbidden();
314                Err(error_response.into_response())
315            }
316        })
317    }
318}
319
320/// Basic permission check using auth token (fallback)
321fn check_token_permission(auth_token: &AuthToken, action: &str, resource: &str) -> bool {
322    // Check for admin role (has all permissions)
323    if auth_token.roles.contains(&"admin".to_string()) {
324        return true;
325    }
326
327    // Check explicit permissions
328    let required_permission = format!("{}:{}", action, resource);
329    auth_token.permissions.iter().any(|perm| {
330        perm == &required_permission
331            || perm == "*"
332            || (perm.ends_with("*") && required_permission.starts_with(&perm[..perm.len() - 1]))
333    })
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;

    /// Build an empty-bodied request for `uri` carrying the given headers.
    fn request_with_headers(uri: &str, headers: &[(&str, &str)]) -> Request {
        headers
            .iter()
            .fold(axum::http::Request::builder().uri(uri), |builder, (name, value)| {
                builder.header(*name, *value)
            })
            .body(Body::empty())
            .expect("static test request parts are valid")
    }

    fn value<'a>(context: &'a HashMap<String, String>, key: &str) -> Option<&'a str> {
        context.get(key).map(String::as_str)
    }

    #[test]
    fn test_public_endpoint_detection() {
        assert!(is_public_endpoint("/health"));
        assert!(is_public_endpoint("/auth/login"));
        assert!(is_public_endpoint(
            "/oauth/.well-known/openid_configuration"
        ));
        assert!(!is_public_endpoint("/api/users"));
        assert!(!is_public_endpoint("/admin/roles"));
    }

    #[test]
    fn test_sensitive_endpoint_detection() {
        assert!(is_sensitive_endpoint("/admin/users"));
        assert!(is_sensitive_endpoint("/api/secrets/vault"));
        assert!(is_sensitive_endpoint("/auth/logout"));
        assert!(!is_sensitive_endpoint("/api/health"));
        assert!(!is_sensitive_endpoint("/public/info"));
    }

    #[test]
    fn test_elevation_requirement() {
        assert!(requires_role_elevation("/admin/users/delete/123"));
        assert!(requires_role_elevation("/admin/system/shutdown"));
        assert!(requires_role_elevation("/api/sudo/execute"));
        assert!(!requires_role_elevation("/admin/users"));
        assert!(!requires_role_elevation("/api/profile"));
    }

    #[test]
    fn test_conditional_context_mobile_vpn_admin() {
        let request = request_with_headers(
            "/admin/settings",
            &[
                ("via", "1.1 corp-vpn-gateway"),
                ("user-agent", "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0)"),
            ],
        );
        let context = build_conditional_context(&request);
        assert_eq!(value(&context, "connection_type"), Some("vpn"));
        assert_eq!(value(&context, "device_type"), Some("mobile"));
        assert_eq!(value(&context, "security_level"), Some("high"));
    }

    #[test]
    fn test_conditional_context_desktop_api() {
        let request = request_with_headers(
            "/api/users",
            &[
                ("via", "1.1 cache.example"),
                ("user-agent", "Mozilla/5.0 (X11; Linux x86_64)"),
            ],
        );
        let context = build_conditional_context(&request);
        assert_eq!(value(&context, "connection_type"), None);
        assert_eq!(value(&context, "device_type"), Some("desktop"));
        assert_eq!(value(&context, "security_level"), Some("medium"));
    }

    #[test]
    fn test_conditional_context_without_headers() {
        let request = request_with_headers("/public/info", &[]);
        let context = build_conditional_context(&request);
        // Only the path-derived key is always present.
        assert_eq!(context.len(), 1);
        assert_eq!(value(&context, "security_level"), Some("low"));
    }
}
377
378