1use axum::body::Body;
7use axum::http::{Request, StatusCode};
8use axum::{extract::State, middleware::Next, response::Response};
9use tracing::{debug, error, warn};
10
11use super::authenticator::authenticate_request;
12use super::state::AuthState;
13use super::types::AuthResult;
14use mockforge_core::security::{
15 emit_security_event, EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
16};
17
18pub async fn auth_middleware(
20 State(state): State<AuthState>,
21 req: Request<Body>,
22 next: Next,
23) -> Response {
24 let path = req.uri().path().to_string();
25 let _method = req.method().clone();
26
27 if path.starts_with("/health") || path.starts_with("/__mockforge") {
29 return next.run(req).await;
30 }
31
32 let auth_header = req
34 .headers()
35 .get("authorization")
36 .and_then(|h| h.to_str().ok())
37 .map(|s| s.to_string());
38
39 let api_key_header = req
40 .headers()
41 .get(
42 state
43 .config
44 .api_key
45 .as_ref()
46 .map(|c| c.header_name.clone())
47 .unwrap_or_else(|| "X-API-Key".to_string()),
48 )
49 .and_then(|h| h.to_str().ok())
50 .map(|s| s.to_string());
51
52 let api_key_query = req.uri().query().and_then(|q| {
53 state
54 .config
55 .api_key
56 .as_ref()
57 .and_then(|c| c.query_name.as_ref())
58 .and_then(|param| {
59 url::form_urlencoded::parse(q.as_bytes())
60 .find(|(k, _)| k == param)
61 .map(|(_, v)| v.to_string())
62 })
63 });
64
65 let ip_address = req
67 .headers()
68 .get("x-forwarded-for")
69 .or_else(|| req.headers().get("x-real-ip"))
70 .and_then(|h| h.to_str().ok())
71 .map(|s| s.to_string())
72 .or_else(|| {
73 req.extensions()
74 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
75 .map(|addr| addr.ip().to_string())
76 });
77
78 let user_agent = req
79 .headers()
80 .get("user-agent")
81 .and_then(|h| h.to_str().ok())
82 .map(|s| s.to_string());
83
84 let auth_result =
86 authenticate_request(&state, &auth_header, &api_key_header, &api_key_query).await;
87
88 match auth_result {
89 AuthResult::Success(claims) => {
90 debug!("Authentication successful for user: {:?}", claims.sub);
91
92 let event = SecurityEvent::new(SecurityEventType::AuthSuccess, None, None)
94 .with_actor(EventActor {
95 user_id: claims.sub.clone(),
96 username: claims.sub.clone(),
97 ip_address: ip_address.clone(),
98 user_agent: user_agent.clone(),
99 })
100 .with_target(EventTarget {
101 resource_type: Some("api".to_string()),
102 resource_id: Some(path.clone()),
103 method: Some(req.method().to_string()),
104 })
105 .with_outcome(EventOutcome {
106 success: true,
107 reason: None,
108 })
109 .with_metadata("auth_method".to_string(), serde_json::json!("jwt"));
110 emit_security_event(event).await;
111
112 use axum::extract::Extension;
115 let mut req = req;
116 req.extensions_mut().insert(Extension(claims));
117 next.run(req).await
118 }
119 AuthResult::Failure(reason) => {
120 warn!("Authentication failed: {}", reason);
121
122 let event = SecurityEvent::new(SecurityEventType::AuthFailure, None, None)
124 .with_actor(EventActor {
125 user_id: None,
126 username: auth_header
127 .as_ref()
128 .and_then(|h| h.strip_prefix("Bearer "))
129 .or_else(|| auth_header.as_ref().and_then(|h| h.strip_prefix("Basic ")))
130 .map(|s| s.to_string()),
131 ip_address: ip_address.clone(),
132 user_agent: user_agent.clone(),
133 })
134 .with_target(EventTarget {
135 resource_type: Some("api".to_string()),
136 resource_id: Some(path.clone()),
137 method: Some(req.method().to_string()),
138 })
139 .with_outcome(EventOutcome {
140 success: false,
141 reason: Some(reason.clone()),
142 })
143 .with_metadata("failure_reason".to_string(), serde_json::json!(reason));
144 emit_security_event(event).await;
145 let mut res = Response::new(axum::body::Body::from(
146 serde_json::json!({
147 "error": "Authentication failed",
148 "message": reason
149 })
150 .to_string(),
151 ));
152 *res.status_mut() = StatusCode::UNAUTHORIZED;
153 res.headers_mut().insert("www-authenticate", "Bearer".parse().unwrap());
154 res
155 }
156 AuthResult::NetworkError(reason) => {
157 error!("Authentication network error: {}", reason);
158 let mut res = Response::new(axum::body::Body::from(
159 serde_json::json!({
160 "error": "Authentication service unavailable",
161 "message": "Unable to verify token due to network issues"
162 })
163 .to_string(),
164 ));
165 *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
166 res
167 }
168 AuthResult::ServerError(reason) => {
169 error!("Authentication server error: {}", reason);
170 let mut res = Response::new(axum::body::Body::from(
171 serde_json::json!({
172 "error": "Authentication service error",
173 "message": "Unable to verify token due to server issues"
174 })
175 .to_string(),
176 ));
177 *res.status_mut() = StatusCode::BAD_GATEWAY;
178 res
179 }
180 AuthResult::TokenExpired => {
181 warn!("Token expired");
182
183 let event = SecurityEvent::new(SecurityEventType::AuthTokenExpired, None, None)
185 .with_actor(EventActor {
186 user_id: None,
187 username: None,
188 ip_address: ip_address.clone(),
189 user_agent: user_agent.clone(),
190 })
191 .with_target(EventTarget {
192 resource_type: Some("api".to_string()),
193 resource_id: Some(path.clone()),
194 method: Some(req.method().to_string()),
195 })
196 .with_outcome(EventOutcome {
197 success: false,
198 reason: Some("Token expired".to_string()),
199 });
200 emit_security_event(event).await;
201 let mut res = Response::new(axum::body::Body::from(
202 serde_json::json!({
203 "error": "Token expired",
204 "message": "The provided token has expired"
205 })
206 .to_string(),
207 ));
208 *res.status_mut() = StatusCode::UNAUTHORIZED;
209 res.headers_mut().insert(
210 "www-authenticate",
211 "Bearer error=\"invalid_token\", error_description=\"The token has expired\""
212 .parse()
213 .unwrap(),
214 );
215 res
216 }
217 AuthResult::TokenInvalid(reason) => {
218 warn!("Token invalid: {}", reason);
219
220 let event = SecurityEvent::new(SecurityEventType::AuthFailure, None, None)
222 .with_actor(EventActor {
223 user_id: None,
224 username: None,
225 ip_address: ip_address.clone(),
226 user_agent: user_agent.clone(),
227 })
228 .with_target(EventTarget {
229 resource_type: Some("api".to_string()),
230 resource_id: Some(path.clone()),
231 method: Some(req.method().to_string()),
232 })
233 .with_outcome(EventOutcome {
234 success: false,
235 reason: Some(format!("Invalid token: {}", reason)),
236 })
237 .with_metadata("token_invalid".to_string(), serde_json::json!(true));
238 emit_security_event(event).await;
239 let mut res = Response::new(axum::body::Body::from(
240 serde_json::json!({
241 "error": "Invalid token",
242 "message": reason
243 })
244 .to_string(),
245 ));
246 *res.status_mut() = StatusCode::UNAUTHORIZED;
247 res.headers_mut()
248 .insert("www-authenticate", "Bearer error=\"invalid_token\"".parse().unwrap());
249 res
250 }
251 AuthResult::None => {
252 if state.config.require_auth {
253 let event = SecurityEvent::new(SecurityEventType::AuthzAccessDenied, None, None)
255 .with_actor(EventActor {
256 user_id: None,
257 username: None,
258 ip_address: ip_address.clone(),
259 user_agent: user_agent.clone(),
260 })
261 .with_target(EventTarget {
262 resource_type: Some("api".to_string()),
263 resource_id: Some(path.clone()),
264 method: Some(req.method().to_string()),
265 })
266 .with_outcome(EventOutcome {
267 success: false,
268 reason: Some("Authentication required but not provided".to_string()),
269 });
270 emit_security_event(event).await;
271
272 let mut res = Response::new(axum::body::Body::from(
273 serde_json::json!({
274 "error": "Authentication required"
275 })
276 .to_string(),
277 ));
278 *res.status_mut() = StatusCode::UNAUTHORIZED;
279 res.headers_mut().insert("www-authenticate", "Bearer".parse().unwrap());
280 res
281 } else {
282 debug!("No authentication provided, proceeding without auth");
283 next.run(req).await
284 }
285 }
286 }
287}