1use axum::body::Body;
7use axum::http::{Request, StatusCode};
8use axum::{extract::State, middleware::Next, response::Response};
9use tracing::{debug, error, warn};
10
11use super::authenticator::authenticate_request;
12use super::state::AuthState;
13use super::types::AuthResult;
14use mockforge_core::security::{
15 emit_security_event, EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
16};
17
18pub async fn auth_middleware(
20 State(state): State<AuthState>,
21 req: Request<Body>,
22 next: Next,
23) -> Response {
24 let path = req.uri().path().to_string();
25 let _method = req.method().clone();
26
27 if path.starts_with("/health") || path.starts_with("/__mockforge") {
29 return next.run(req).await;
30 }
31
32 let auth_header = req
34 .headers()
35 .get("authorization")
36 .and_then(|h| h.to_str().ok())
37 .map(|s| s.to_string());
38
39 let api_key_header = req
40 .headers()
41 .get(
42 state
43 .config
44 .api_key
45 .as_ref()
46 .map(|c| c.header_name.clone())
47 .unwrap_or_else(|| "X-API-Key".to_string()),
48 )
49 .and_then(|h| h.to_str().ok())
50 .map(|s| s.to_string());
51
52 let api_key_query = req.uri().query().and_then(|q| {
53 state
54 .config
55 .api_key
56 .as_ref()
57 .and_then(|c| c.query_name.as_ref())
58 .and_then(|param| {
59 url::form_urlencoded::parse(q.as_bytes())
60 .find(|(k, _)| k == param)
61 .map(|(_, v)| v.to_string())
62 })
63 });
64
65 let ip_address = req
67 .headers()
68 .get("x-forwarded-for")
69 .or_else(|| req.headers().get("x-real-ip"))
70 .and_then(|h| h.to_str().ok())
71 .map(|s| s.to_string())
72 .or_else(|| {
73 req.extensions()
74 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
75 .map(|addr| addr.ip().to_string())
76 });
77
78 let user_agent = req
79 .headers()
80 .get("user-agent")
81 .and_then(|h| h.to_str().ok())
82 .map(|s| s.to_string());
83
84 let auth_result =
86 authenticate_request(&state, &auth_header, &api_key_header, &api_key_query).await;
87
88 match auth_result {
89 AuthResult::Success(claims) => {
90 debug!("Authentication successful for user: {:?}", claims.sub);
91
92 let event = SecurityEvent::new(SecurityEventType::AuthSuccess, None, None)
94 .with_actor(EventActor {
95 user_id: claims.sub.clone(),
96 username: claims.sub.clone(),
97 ip_address: ip_address.clone(),
98 user_agent: user_agent.clone(),
99 })
100 .with_target(EventTarget {
101 resource_type: Some("api".to_string()),
102 resource_id: Some(path.clone()),
103 method: Some(req.method().to_string()),
104 })
105 .with_outcome(EventOutcome {
106 success: true,
107 reason: None,
108 })
109 .with_metadata("auth_method".to_string(), serde_json::json!("jwt"));
110 emit_security_event(event).await;
111
112 let mut req = req;
114 req.extensions_mut().insert(claims);
115 next.run(req).await
116 }
117 AuthResult::Failure(reason) => {
118 warn!("Authentication failed: {}", reason);
119
120 let event = SecurityEvent::new(SecurityEventType::AuthFailure, None, None)
122 .with_actor(EventActor {
123 user_id: None,
124 username: auth_header
125 .as_ref()
126 .and_then(|h| h.strip_prefix("Bearer "))
127 .or_else(|| auth_header.as_ref().and_then(|h| h.strip_prefix("Basic ")))
128 .map(|s| s.to_string()),
129 ip_address: ip_address.clone(),
130 user_agent: user_agent.clone(),
131 })
132 .with_target(EventTarget {
133 resource_type: Some("api".to_string()),
134 resource_id: Some(path.clone()),
135 method: Some(req.method().to_string()),
136 })
137 .with_outcome(EventOutcome {
138 success: false,
139 reason: Some(reason.clone()),
140 })
141 .with_metadata("failure_reason".to_string(), serde_json::json!(reason));
142 emit_security_event(event).await;
143 let mut res = Response::new(axum::body::Body::from(
144 serde_json::json!({
145 "error": "Authentication failed",
146 "message": reason
147 })
148 .to_string(),
149 ));
150 *res.status_mut() = StatusCode::UNAUTHORIZED;
151 res.headers_mut().insert("www-authenticate", "Bearer".parse().unwrap());
152 res
153 }
154 AuthResult::NetworkError(reason) => {
155 error!("Authentication network error: {}", reason);
156 let mut res = Response::new(axum::body::Body::from(
157 serde_json::json!({
158 "error": "Authentication service unavailable",
159 "message": "Unable to verify token due to network issues"
160 })
161 .to_string(),
162 ));
163 *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
164 res
165 }
166 AuthResult::ServerError(reason) => {
167 error!("Authentication server error: {}", reason);
168 let mut res = Response::new(axum::body::Body::from(
169 serde_json::json!({
170 "error": "Authentication service error",
171 "message": "Unable to verify token due to server issues"
172 })
173 .to_string(),
174 ));
175 *res.status_mut() = StatusCode::BAD_GATEWAY;
176 res
177 }
178 AuthResult::TokenExpired => {
179 warn!("Token expired");
180
181 let event = SecurityEvent::new(SecurityEventType::AuthTokenExpired, None, None)
183 .with_actor(EventActor {
184 user_id: None,
185 username: None,
186 ip_address: ip_address.clone(),
187 user_agent: user_agent.clone(),
188 })
189 .with_target(EventTarget {
190 resource_type: Some("api".to_string()),
191 resource_id: Some(path.clone()),
192 method: Some(req.method().to_string()),
193 })
194 .with_outcome(EventOutcome {
195 success: false,
196 reason: Some("Token expired".to_string()),
197 });
198 emit_security_event(event).await;
199 let mut res = Response::new(axum::body::Body::from(
200 serde_json::json!({
201 "error": "Token expired",
202 "message": "The provided token has expired"
203 })
204 .to_string(),
205 ));
206 *res.status_mut() = StatusCode::UNAUTHORIZED;
207 res.headers_mut().insert(
208 "www-authenticate",
209 "Bearer error=\"invalid_token\", error_description=\"The token has expired\""
210 .parse()
211 .unwrap(),
212 );
213 res
214 }
215 AuthResult::TokenInvalid(reason) => {
216 warn!("Token invalid: {}", reason);
217
218 let event = SecurityEvent::new(SecurityEventType::AuthFailure, None, None)
220 .with_actor(EventActor {
221 user_id: None,
222 username: None,
223 ip_address: ip_address.clone(),
224 user_agent: user_agent.clone(),
225 })
226 .with_target(EventTarget {
227 resource_type: Some("api".to_string()),
228 resource_id: Some(path.clone()),
229 method: Some(req.method().to_string()),
230 })
231 .with_outcome(EventOutcome {
232 success: false,
233 reason: Some(format!("Invalid token: {}", reason)),
234 })
235 .with_metadata("token_invalid".to_string(), serde_json::json!(true));
236 emit_security_event(event).await;
237 let mut res = Response::new(axum::body::Body::from(
238 serde_json::json!({
239 "error": "Invalid token",
240 "message": reason
241 })
242 .to_string(),
243 ));
244 *res.status_mut() = StatusCode::UNAUTHORIZED;
245 res.headers_mut()
246 .insert("www-authenticate", "Bearer error=\"invalid_token\"".parse().unwrap());
247 res
248 }
249 AuthResult::None => {
250 if state.config.require_auth {
251 let event = SecurityEvent::new(SecurityEventType::AuthzAccessDenied, None, None)
253 .with_actor(EventActor {
254 user_id: None,
255 username: None,
256 ip_address: ip_address.clone(),
257 user_agent: user_agent.clone(),
258 })
259 .with_target(EventTarget {
260 resource_type: Some("api".to_string()),
261 resource_id: Some(path.clone()),
262 method: Some(req.method().to_string()),
263 })
264 .with_outcome(EventOutcome {
265 success: false,
266 reason: Some("Authentication required but not provided".to_string()),
267 });
268 emit_security_event(event).await;
269
270 let mut res = Response::new(axum::body::Body::from(
271 serde_json::json!({
272 "error": "Authentication required"
273 })
274 .to_string(),
275 ));
276 *res.status_mut() = StatusCode::UNAUTHORIZED;
277 res.headers_mut().insert("www-authenticate", "Bearer".parse().unwrap());
278 res
279 } else {
280 debug!("No authentication provided, proceeding without auth");
281 next.run(req).await
282 }
283 }
284 }
285}