1use axum::body::Body;
7use axum::http::header::HeaderValue;
8use axum::http::{Request, StatusCode};
9use axum::{extract::State, middleware::Next, response::Response};
10use tracing::{debug, error, warn};
11
12use super::authenticator::authenticate_request;
13use super::state::AuthState;
14use super::types::AuthResult;
15use mockforge_core::security::{
16 emit_security_event, EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
17};
18
19pub async fn auth_middleware(
21 State(state): State<AuthState>,
22 req: Request<Body>,
23 next: Next,
24) -> Response {
25 let path = req.uri().path().to_string();
26 let _method = req.method().clone();
27
28 if path.starts_with("/health") || path.starts_with("/__mockforge") {
30 return next.run(req).await;
31 }
32
33 let auth_header = req
35 .headers()
36 .get("authorization")
37 .and_then(|h| h.to_str().ok())
38 .map(|s| s.to_string());
39
40 let api_key_header = req
41 .headers()
42 .get(
43 state
44 .config
45 .api_key
46 .as_ref()
47 .map(|c| c.header_name.clone())
48 .unwrap_or_else(|| "X-API-Key".to_string()),
49 )
50 .and_then(|h| h.to_str().ok())
51 .map(|s| s.to_string());
52
53 let api_key_query = req.uri().query().and_then(|q| {
54 state
55 .config
56 .api_key
57 .as_ref()
58 .and_then(|c| c.query_name.as_ref())
59 .and_then(|param| {
60 url::form_urlencoded::parse(q.as_bytes())
61 .find(|(k, _)| k == param)
62 .map(|(_, v)| v.to_string())
63 })
64 });
65
66 let ip_address = req
68 .headers()
69 .get("x-forwarded-for")
70 .or_else(|| req.headers().get("x-real-ip"))
71 .and_then(|h| h.to_str().ok())
72 .map(|s| s.to_string())
73 .or_else(|| {
74 req.extensions()
75 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
76 .map(|addr| addr.ip().to_string())
77 });
78
79 let user_agent = req
80 .headers()
81 .get("user-agent")
82 .and_then(|h| h.to_str().ok())
83 .map(|s| s.to_string());
84
85 let auth_result =
87 authenticate_request(&state, &auth_header, &api_key_header, &api_key_query).await;
88
89 match auth_result {
90 AuthResult::Success(claims) => {
91 debug!("Authentication successful for user: {:?}", claims.sub);
92
93 let event = SecurityEvent::new(SecurityEventType::AuthSuccess, None, None)
95 .with_actor(EventActor {
96 user_id: claims.sub.clone(),
97 username: claims.sub.clone(),
98 ip_address: ip_address.clone(),
99 user_agent: user_agent.clone(),
100 })
101 .with_target(EventTarget {
102 resource_type: Some("api".to_string()),
103 resource_id: Some(path.clone()),
104 method: Some(req.method().to_string()),
105 })
106 .with_outcome(EventOutcome {
107 success: true,
108 reason: None,
109 })
110 .with_metadata("auth_method".to_string(), serde_json::json!("jwt"));
111 emit_security_event(event).await;
112
113 use axum::extract::Extension;
116 let mut req = req;
117 req.extensions_mut().insert(Extension(claims));
118 next.run(req).await
119 }
120 AuthResult::Failure(reason) => {
121 warn!("Authentication failed: {}", reason);
122
123 let event = SecurityEvent::new(SecurityEventType::AuthFailure, None, None)
125 .with_actor(EventActor {
126 user_id: None,
127 username: auth_header
128 .as_ref()
129 .and_then(|h| h.strip_prefix("Bearer "))
130 .or_else(|| auth_header.as_ref().and_then(|h| h.strip_prefix("Basic ")))
131 .map(|s| s.to_string()),
132 ip_address: ip_address.clone(),
133 user_agent: user_agent.clone(),
134 })
135 .with_target(EventTarget {
136 resource_type: Some("api".to_string()),
137 resource_id: Some(path.clone()),
138 method: Some(req.method().to_string()),
139 })
140 .with_outcome(EventOutcome {
141 success: false,
142 reason: Some(reason.clone()),
143 })
144 .with_metadata("failure_reason".to_string(), serde_json::json!(reason));
145 emit_security_event(event).await;
146 let mut res = Response::new(Body::from(
147 serde_json::json!({
148 "error": "Authentication failed",
149 "message": reason
150 })
151 .to_string(),
152 ));
153 *res.status_mut() = StatusCode::UNAUTHORIZED;
154 res.headers_mut().insert("www-authenticate", HeaderValue::from_static("Bearer"));
155 res
156 }
157 AuthResult::NetworkError(reason) => {
158 error!("Authentication network error: {}", reason);
159 let mut res = Response::new(Body::from(
160 serde_json::json!({
161 "error": "Authentication service unavailable",
162 "message": "Unable to verify token due to network issues"
163 })
164 .to_string(),
165 ));
166 *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
167 res
168 }
169 AuthResult::ServerError(reason) => {
170 error!("Authentication server error: {}", reason);
171 let mut res = Response::new(Body::from(
172 serde_json::json!({
173 "error": "Authentication service error",
174 "message": "Unable to verify token due to server issues"
175 })
176 .to_string(),
177 ));
178 *res.status_mut() = StatusCode::BAD_GATEWAY;
179 res
180 }
181 AuthResult::TokenExpired => {
182 warn!("Token expired");
183
184 let event = SecurityEvent::new(SecurityEventType::AuthTokenExpired, None, None)
186 .with_actor(EventActor {
187 user_id: None,
188 username: None,
189 ip_address: ip_address.clone(),
190 user_agent: user_agent.clone(),
191 })
192 .with_target(EventTarget {
193 resource_type: Some("api".to_string()),
194 resource_id: Some(path.clone()),
195 method: Some(req.method().to_string()),
196 })
197 .with_outcome(EventOutcome {
198 success: false,
199 reason: Some("Token expired".to_string()),
200 });
201 emit_security_event(event).await;
202 let mut res = Response::new(Body::from(
203 serde_json::json!({
204 "error": "Token expired",
205 "message": "The provided token has expired"
206 })
207 .to_string(),
208 ));
209 *res.status_mut() = StatusCode::UNAUTHORIZED;
210 res.headers_mut().insert(
211 "www-authenticate",
212 HeaderValue::from_static(
213 "Bearer error=\"invalid_token\", error_description=\"The token has expired\"",
214 ),
215 );
216 res
217 }
218 AuthResult::TokenInvalid(reason) => {
219 warn!("Token invalid: {}", reason);
220
221 let event = SecurityEvent::new(SecurityEventType::AuthFailure, None, None)
223 .with_actor(EventActor {
224 user_id: None,
225 username: None,
226 ip_address: ip_address.clone(),
227 user_agent: user_agent.clone(),
228 })
229 .with_target(EventTarget {
230 resource_type: Some("api".to_string()),
231 resource_id: Some(path.clone()),
232 method: Some(req.method().to_string()),
233 })
234 .with_outcome(EventOutcome {
235 success: false,
236 reason: Some(format!("Invalid token: {}", reason)),
237 })
238 .with_metadata("token_invalid".to_string(), serde_json::json!(true));
239 emit_security_event(event).await;
240 let mut res = Response::new(Body::from(
241 serde_json::json!({
242 "error": "Invalid token",
243 "message": reason
244 })
245 .to_string(),
246 ));
247 *res.status_mut() = StatusCode::UNAUTHORIZED;
248 res.headers_mut().insert(
249 "www-authenticate",
250 HeaderValue::from_static("Bearer error=\"invalid_token\""),
251 );
252 res
253 }
254 AuthResult::None => {
255 if state.config.require_auth {
256 let event = SecurityEvent::new(SecurityEventType::AuthzAccessDenied, None, None)
258 .with_actor(EventActor {
259 user_id: None,
260 username: None,
261 ip_address: ip_address.clone(),
262 user_agent: user_agent.clone(),
263 })
264 .with_target(EventTarget {
265 resource_type: Some("api".to_string()),
266 resource_id: Some(path.clone()),
267 method: Some(req.method().to_string()),
268 })
269 .with_outcome(EventOutcome {
270 success: false,
271 reason: Some("Authentication required but not provided".to_string()),
272 });
273 emit_security_event(event).await;
274
275 let mut res = Response::new(Body::from(
276 serde_json::json!({
277 "error": "Authentication required"
278 })
279 .to_string(),
280 ));
281 *res.status_mut() = StatusCode::UNAUTHORIZED;
282 res.headers_mut().insert("www-authenticate", HeaderValue::from_static("Bearer"));
283 res
284 } else {
285 debug!("No authentication provided, proceeding without auth");
286 next.run(req).await
287 }
288 }
289 }
290}