mockforge_core/protocol_abstraction/
auth.rs1use super::{
4 MiddlewareAction, Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse,
5 ResponseStatus,
6};
7use crate::config::AuthConfig;
8use crate::Result;
9use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
/// Identity claims attached to an authenticated request.
///
/// Deserialized from a JWT payload, or synthesized by the middleware for API-key and
/// anonymous access.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    /// Subject: the principal the claims describe.
    pub sub: String,
    /// Expiration time, in seconds since the Unix epoch, when present.
    pub exp: Option<usize>,
    /// Issued-at time, when present.
    pub iat: Option<usize>,
    /// Intended audience, when present.
    // NOTE(review): RFC 7519 also allows `aud` to be an array of strings; such tokens would
    // presumably fail to deserialize into this field — confirm whether that matters here.
    pub aud: Option<String>,
    /// Issuer of the claims, when present.
    pub iss: Option<String>,
    /// Every other claim in the payload, collected via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
31
/// Outcome of an authentication attempt.
#[derive(Debug, Clone)]
pub enum AuthResult {
    /// Authentication succeeded; carries the resolved claims.
    Success(Claims),
    /// Authentication was rejected; carries a human-readable reason.
    /// The middleware answers these with HTTP status 401.
    Failure(String),
    /// Authentication could not be completed; carries a human-readable reason.
    /// The middleware answers these with HTTP status 503.
    NetworkError(String),
}
42
/// Protocol-agnostic middleware that authenticates requests using JWTs and/or API keys.
pub struct AuthMiddleware {
    /// Name reported through `ProtocolMiddleware::name`.
    name: String,
    /// Authentication settings shared by every request.
    config: Arc<AuthConfig>,
    /// Claims of tokens that already passed JWT validation, keyed by the raw token string.
    introspection_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
}
52
/// A cache entry holding the claims of a previously validated token.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Claims decoded when the token was validated.
    claims: Claims,
    /// Monotonic instant after which the entry is no longer served from the cache.
    expires_at: std::time::Instant,
}
59
60impl AuthMiddleware {
61 pub fn new(config: AuthConfig) -> Self {
63 Self {
64 name: "AuthMiddleware".to_string(),
65 config: Arc::new(config),
66 introspection_cache: Arc::new(RwLock::new(HashMap::new())),
67 }
68 }
69
70 fn extract_token(&self, request: &ProtocolRequest) -> Option<String> {
72 if let Some(auth_header) = request.metadata.get("authorization") {
74 if let Some(token) = auth_header.strip_prefix("Bearer ") {
76 return Some(token.to_string());
77 }
78 return Some(auth_header.clone());
79 }
80
81 if let Some(api_key_config) = &self.config.api_key {
83 if let Some(api_key) = request.metadata.get(&api_key_config.header_name) {
84 return Some(api_key.clone());
85 }
86 }
87
88 if request.protocol == Protocol::Grpc {
90 if let Some(token) = request.metadata.get("grpc-metadata-authorization") {
91 if let Some(stripped) = token.strip_prefix("Bearer ") {
92 return Some(stripped.to_string());
93 }
94 return Some(token.clone());
95 }
96 }
97
98 None
99 }
100
101 async fn validate_jwt(&self, token: &str) -> AuthResult {
103 if let Some(cached) = self.introspection_cache.read().await.get(token) {
105 if cached.expires_at > std::time::Instant::now() {
106 return AuthResult::Success(cached.claims.clone());
107 }
108 }
109
110 let jwt_config = match &self.config.jwt {
112 Some(config) => config,
113 None => return AuthResult::Failure("JWT not configured".to_string()),
114 };
115
116 let header = match decode_header(token) {
118 Ok(h) => h,
119 Err(e) => return AuthResult::Failure(format!("Invalid token header: {}", e)),
120 };
121
122 let mut validation = Validation::new(header.alg);
124 if let Some(audience) = &jwt_config.audience {
125 validation.set_audience(&[audience]);
126 }
127 if let Some(issuer) = &jwt_config.issuer {
128 validation.set_issuer(&[issuer]);
129 }
130
131 let secret = match &jwt_config.secret {
133 Some(s) => s,
134 None => return AuthResult::Failure("JWT secret not configured".to_string()),
135 };
136
137 let decoding_key = DecodingKey::from_secret(secret.as_bytes());
139 match decode::<Claims>(token, &decoding_key, &validation) {
140 Ok(token_data) => {
141 let claims = token_data.claims;
142
143 let expires_at = if let Some(exp) = claims.exp {
145 let exp_instant =
146 std::time::UNIX_EPOCH + std::time::Duration::from_secs(exp as u64);
147 std::time::Instant::now()
148 + exp_instant.elapsed().unwrap_or(std::time::Duration::from_secs(300))
149 } else {
150 std::time::Instant::now() + std::time::Duration::from_secs(300)
151 };
152
153 self.introspection_cache.write().await.insert(
154 token.to_string(),
155 CachedToken {
156 claims: claims.clone(),
157 expires_at,
158 },
159 );
160
161 AuthResult::Success(claims)
162 }
163 Err(e) => AuthResult::Failure(format!("Token validation failed: {}", e)),
164 }
165 }
166
167 async fn validate_api_key(&self, key: &str) -> AuthResult {
169 let api_key_config = match &self.config.api_key {
170 Some(config) => config,
171 None => return AuthResult::Failure("API key not configured".to_string()),
172 };
173
174 if api_key_config.keys.contains(&key.to_string()) {
176 AuthResult::Success(Claims {
177 sub: "api_key_user".to_string(),
178 exp: None,
179 iat: None,
180 aud: None,
181 iss: Some("mockforge".to_string()),
182 extra: {
183 let mut extra = HashMap::new();
184 extra.insert("auth_type".to_string(), serde_json::json!("api_key"));
185 extra
186 },
187 })
188 } else {
189 AuthResult::Failure("Invalid API key".to_string())
190 }
191 }
192
193 async fn authenticate(&self, request: &ProtocolRequest) -> AuthResult {
195 let token = match self.extract_token(request) {
197 Some(t) => t,
198 None => {
199 if !self.config.require_auth {
201 return AuthResult::Success(Claims {
202 sub: "anonymous".to_string(),
203 exp: None,
204 iat: None,
205 aud: None,
206 iss: Some("mockforge".to_string()),
207 extra: HashMap::new(),
208 });
209 }
210 return AuthResult::Failure("No authentication token provided".to_string());
211 }
212 };
213
214 if self.config.jwt.is_some() {
216 let result = self.validate_jwt(&token).await;
217 if matches!(result, AuthResult::Success(_)) {
218 return result;
219 }
220 }
221
222 if self.config.api_key.is_some() {
224 let result = self.validate_api_key(&token).await;
225 if matches!(result, AuthResult::Success(_)) {
226 return result;
227 }
228 }
229
230 AuthResult::Failure("Authentication failed".to_string())
231 }
232}
233
234#[async_trait::async_trait]
235impl ProtocolMiddleware for AuthMiddleware {
236 fn name(&self) -> &str {
237 &self.name
238 }
239
240 async fn process_request(&self, request: &mut ProtocolRequest) -> Result<MiddlewareAction> {
241 if request.path.starts_with("/health") || request.path.starts_with("/__mockforge") {
243 return Ok(MiddlewareAction::Continue);
244 }
245
246 match self.authenticate(request).await {
248 AuthResult::Success(claims) => {
249 request.metadata.insert("x-auth-sub".to_string(), claims.sub.clone());
251 if let Some(iss) = &claims.iss {
252 request.metadata.insert("x-auth-iss".to_string(), iss.clone());
253 }
254 tracing::debug!(
255 protocol = %request.protocol,
256 user = %claims.sub,
257 "Authentication successful"
258 );
259 Ok(MiddlewareAction::Continue)
260 }
261 AuthResult::Failure(reason) => {
262 tracing::warn!(
263 protocol = %request.protocol,
264 path = %request.path,
265 reason = %reason,
266 "Authentication failed"
267 );
268 Ok(MiddlewareAction::ShortCircuit(ProtocolResponse {
269 status: ResponseStatus::HttpStatus(401),
270 metadata: std::collections::HashMap::new(),
271 body: format!(r#"{{"error":"Authentication failed","reason":"{}"}}"#, reason)
272 .into_bytes(),
273 content_type: "application/json".to_string(),
274 }))
275 }
276 AuthResult::NetworkError(reason) => {
277 tracing::error!(
278 protocol = %request.protocol,
279 reason = %reason,
280 "Authentication network error"
281 );
282 Ok(MiddlewareAction::ShortCircuit(ProtocolResponse {
283 status: ResponseStatus::HttpStatus(503),
284 metadata: std::collections::HashMap::new(),
285 body: format!(
286 r#"{{"error":"Authentication service unavailable","reason":"{}"}}"#,
287 reason
288 )
289 .into_bytes(),
290 content_type: "application/json".to_string(),
291 }))
292 }
293 }
294 }
295
296 async fn process_response(
297 &self,
298 _request: &ProtocolRequest,
299 _response: &mut ProtocolResponse,
300 ) -> Result<()> {
301 Ok(())
303 }
304
305 fn supports_protocol(&self, _protocol: Protocol) -> bool {
306 true
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ApiKeyConfig;

    /// Builds a config that requires auth and accepts only `key`, sent via `X-API-Key`.
    fn api_key_only_config(key: &str) -> AuthConfig {
        AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: Some(ApiKeyConfig {
                header_name: "X-API-Key".to_string(),
                query_name: None,
                keys: vec![key.to_string()],
            }),
            oauth2: None,
            basic_auth: None,
        }
    }

    /// Builds an HTTP `GET /test` request carrying a single metadata entry.
    fn http_get_with_metadata(name: &str, value: &str) -> ProtocolRequest {
        let mut metadata = HashMap::new();
        metadata.insert(name.to_string(), value.to_string());

        ProtocolRequest {
            protocol: Protocol::Http,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            metadata,
            ..Default::default()
        }
    }

    #[test]
    fn test_auth_middleware_creation() {
        let middleware = AuthMiddleware::new(AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: None,
            oauth2: None,
            basic_auth: None,
        });

        assert_eq!(middleware.name(), "AuthMiddleware");
        for protocol in [Protocol::Http, Protocol::Grpc, Protocol::GraphQL] {
            assert!(middleware.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_extract_token_bearer() {
        let middleware = AuthMiddleware::new(AuthConfig::default());
        let request = http_get_with_metadata("authorization", "Bearer test_token");

        // The "Bearer " scheme prefix must be stripped from the header value.
        assert_eq!(middleware.extract_token(&request), Some("test_token".to_string()));
    }

    #[test]
    fn test_extract_token_api_key() {
        let middleware = AuthMiddleware::new(api_key_only_config("test_key"));
        let request = http_get_with_metadata("X-API-Key", "test_key");

        assert_eq!(middleware.extract_token(&request), Some("test_key".to_string()));
    }

    #[tokio::test]
    async fn test_validate_api_key_success() {
        let middleware = AuthMiddleware::new(api_key_only_config("valid_key"));

        let outcome = middleware.validate_api_key("valid_key").await;
        assert!(matches!(outcome, AuthResult::Success(_)));
    }

    #[tokio::test]
    async fn test_validate_api_key_failure() {
        let middleware = AuthMiddleware::new(api_key_only_config("valid_key"));

        let outcome = middleware.validate_api_key("invalid_key").await;
        assert!(matches!(outcome, AuthResult::Failure(_)));
    }
}