auth_framework/api/
middleware.rs1use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
6use axum::{
7 extract::{Request, State},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11use std::time::{Duration, Instant};
12
13pub async fn auth_middleware(
15 State(state): State<ApiState>,
16 mut request: Request,
17 next: Next,
18) -> Result<Response, Response> {
19 let path = request.uri().path();
21 if is_public_endpoint(path) {
22 return Ok(next.run(request).await);
23 }
24
25 let headers = request.headers();
27 match extract_bearer_token(headers) {
28 Some(token) => {
29 match validate_api_token(&state.auth_framework, &token).await {
31 Ok(auth_token) => {
32 request.extensions_mut().insert(auth_token);
34 Ok(next.run(request).await)
35 }
36 Err(_) => {
37 let error_response = ApiResponse::<()>::unauthorized();
38 Err(error_response.into_response())
39 }
40 }
41 }
42 None => {
43 let error_response = ApiResponse::<()>::unauthorized();
44 Err(error_response.into_response())
45 }
46 }
47}
48
49pub async fn admin_middleware(
51 State(_state): State<ApiState>,
52 request: Request,
53 next: Next,
54) -> Result<Response, Response> {
55 match request.extensions().get::<crate::tokens::AuthToken>() {
57 Some(auth_token) => {
58 if auth_token.roles.contains(&"admin".to_string()) {
59 Ok(next.run(request).await)
60 } else {
61 let error_response = ApiResponse::<()>::forbidden();
62 Err(error_response.into_response())
63 }
64 }
65 None => {
66 let error_response = ApiResponse::<()>::unauthorized();
68 Err(error_response.into_response())
69 }
70 }
71}
72
73pub async fn rate_limit_middleware(request: Request, next: Next) -> Result<Response, Response> {
75 let mut response = next.run(request).await;
79
80 let headers = response.headers_mut();
82 headers.insert("X-RateLimit-Limit", "100".parse().unwrap());
83 headers.insert("X-RateLimit-Remaining", "95".parse().unwrap());
84 headers.insert("X-RateLimit-Reset", "1692278400".parse().unwrap()); Ok(response)
87}
88
89pub async fn cors_middleware(request: Request, next: Next) -> Response {
91 let response = next.run(request).await;
92
93 let mut response = response;
94 let headers = response.headers_mut();
95
96 headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
97 headers.insert(
98 "Access-Control-Allow-Methods",
99 "GET, POST, PUT, DELETE, OPTIONS".parse().unwrap(),
100 );
101 headers.insert(
102 "Access-Control-Allow-Headers",
103 "Content-Type, Authorization".parse().unwrap(),
104 );
105 headers.insert("Access-Control-Max-Age", "3600".parse().unwrap());
106
107 response
108}
109
110pub async fn logging_middleware(request: Request, next: Next) -> Response {
112 let start = Instant::now();
113 let method = request.method().clone();
114 let uri = request.uri().clone();
115 let headers = request.headers().clone();
116
117 let user_agent = headers
119 .get("user-agent")
120 .and_then(|v| v.to_str().ok())
121 .unwrap_or("unknown");
122
123 let forwarded_for = headers
124 .get("x-forwarded-for")
125 .and_then(|v| v.to_str().ok())
126 .unwrap_or("unknown");
127
128 tracing::info!(
129 "Request started: {} {} from {} ({})",
130 method,
131 uri,
132 forwarded_for,
133 user_agent
134 );
135
136 let response = next.run(request).await;
137 let duration = start.elapsed();
138 let status = response.status();
139
140 tracing::info!(
141 "Request completed: {} {} {} in {:?}",
142 method,
143 uri,
144 status,
145 duration
146 );
147
148 response
149}
150
151pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
153 let response = next.run(request).await;
154
155 let mut response = response;
156 let headers = response.headers_mut();
157
158 headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
160 headers.insert("X-Frame-Options", "DENY".parse().unwrap());
161 headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
162 headers.insert(
163 "Strict-Transport-Security",
164 "max-age=31536000; includeSubDomains".parse().unwrap(),
165 );
166 headers.insert(
167 "Referrer-Policy",
168 "strict-origin-when-cross-origin".parse().unwrap(),
169 );
170 headers.insert(
171 "Permissions-Policy",
172 "camera=(), microphone=(), geolocation=()".parse().unwrap(),
173 );
174
175 response
176}
177
178pub async fn timeout_middleware(request: Request, next: Next) -> Result<Response, Response> {
180 match tokio::time::timeout(Duration::from_secs(30), next.run(request)).await {
182 Ok(response) => Ok(response),
183 Err(_) => {
184 let error_response =
185 ApiResponse::<()>::error("REQUEST_TIMEOUT", "Request timed out after 30 seconds");
186 Err(error_response.into_response())
187 }
188 }
189}
190
/// Returns `true` when `path` may be served without authentication.
///
/// A path is public when it exactly equals one of the listed endpoints, or
/// when it starts with `/oauth/.well-known/`. Matching is exact and
/// case-sensitive: a trailing slash or other suffix on a listed endpoint
/// makes it non-public.
fn is_public_endpoint(path: &str) -> bool {
    const PUBLIC_PATHS: &[&str] = &[
        // Health and monitoring.
        "/health",
        "/health/detailed",
        "/metrics",
        "/readiness",
        "/liveness",
        // Authentication entry points.
        "/auth/login",
        "/auth/refresh",
        "/auth/providers",
        // OAuth endpoints.
        "/oauth/authorize",
        "/oauth/token",
        // (Also matched by the prefix rule below.)
        "/oauth/.well-known/openid_configuration",
    ];
    const PUBLIC_PREFIX: &str = "/oauth/.well-known/";

    PUBLIC_PATHS.contains(&path) || path.starts_with(PUBLIC_PREFIX)
}
201
202pub fn check_permission(auth_token: &crate::tokens::AuthToken, required_permission: &str) -> bool {
204 auth_token.permissions.iter().any(|perm| {
205 perm == required_permission
206 || perm == "*"
207 || (perm.ends_with("*") && required_permission.starts_with(&perm[..perm.len() - 1]))
208 })
209}
210
211pub fn check_role(auth_token: &crate::tokens::AuthToken, required_role: &str) -> bool {
213 auth_token.roles.contains(&required_role.to_string())
214 || auth_token.roles.contains(&"admin".to_string()) }
216
217