mockforge_core/protocol_abstraction/
auth.rs1use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::config::AuthConfig;
5use crate::Result;
6use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Claims {
14 pub sub: String,
15 pub exp: Option<usize>,
16 pub iat: Option<usize>,
17 pub aud: Option<String>,
18 pub iss: Option<String>,
19 #[serde(flatten)]
20 pub extra: HashMap<String, serde_json::Value>,
21}
22
23#[derive(Debug, Clone)]
25pub enum AuthResult {
26 Success(Claims),
28 Failure(String),
30 NetworkError(String),
32}
33
34pub struct AuthMiddleware {
36 name: String,
38 config: Arc<AuthConfig>,
40 introspection_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
42}
43
44#[derive(Debug, Clone)]
46struct CachedToken {
47 claims: Claims,
48 expires_at: std::time::Instant,
49}
50
51impl AuthMiddleware {
52 pub fn new(config: AuthConfig) -> Self {
54 Self {
55 name: "AuthMiddleware".to_string(),
56 config: Arc::new(config),
57 introspection_cache: Arc::new(RwLock::new(HashMap::new())),
58 }
59 }
60
61 fn extract_token(&self, request: &ProtocolRequest) -> Option<String> {
63 if let Some(auth_header) = request.metadata.get("authorization") {
65 if let Some(token) = auth_header.strip_prefix("Bearer ") {
67 return Some(token.to_string());
68 }
69 return Some(auth_header.clone());
70 }
71
72 if let Some(api_key_config) = &self.config.api_key {
74 if let Some(api_key) = request.metadata.get(&api_key_config.header_name) {
75 return Some(api_key.clone());
76 }
77 }
78
79 if request.protocol == Protocol::Grpc {
81 if let Some(token) = request.metadata.get("grpc-metadata-authorization") {
82 if let Some(stripped) = token.strip_prefix("Bearer ") {
83 return Some(stripped.to_string());
84 }
85 return Some(token.clone());
86 }
87 }
88
89 None
90 }
91
92 async fn validate_jwt(&self, token: &str) -> AuthResult {
94 if let Some(cached) = self.introspection_cache.read().await.get(token) {
96 if cached.expires_at > std::time::Instant::now() {
97 return AuthResult::Success(cached.claims.clone());
98 }
99 }
100
101 let jwt_config = match &self.config.jwt {
103 Some(config) => config,
104 None => return AuthResult::Failure("JWT not configured".to_string()),
105 };
106
107 let header = match decode_header(token) {
109 Ok(h) => h,
110 Err(e) => return AuthResult::Failure(format!("Invalid token header: {}", e)),
111 };
112
113 let mut validation = Validation::new(header.alg);
115 if let Some(audience) = &jwt_config.audience {
116 validation.set_audience(&[audience]);
117 }
118 if let Some(issuer) = &jwt_config.issuer {
119 validation.set_issuer(&[issuer]);
120 }
121
122 let secret = match &jwt_config.secret {
124 Some(s) => s,
125 None => return AuthResult::Failure("JWT secret not configured".to_string()),
126 };
127
128 let decoding_key = DecodingKey::from_secret(secret.as_bytes());
130 match decode::<Claims>(token, &decoding_key, &validation) {
131 Ok(token_data) => {
132 let claims = token_data.claims;
133
134 let expires_at = if let Some(exp) = claims.exp {
136 let exp_instant =
137 std::time::UNIX_EPOCH + std::time::Duration::from_secs(exp as u64);
138 std::time::Instant::now()
139 + exp_instant.elapsed().unwrap_or(std::time::Duration::from_secs(300))
140 } else {
141 std::time::Instant::now() + std::time::Duration::from_secs(300)
142 };
143
144 self.introspection_cache.write().await.insert(
145 token.to_string(),
146 CachedToken {
147 claims: claims.clone(),
148 expires_at,
149 },
150 );
151
152 AuthResult::Success(claims)
153 }
154 Err(e) => AuthResult::Failure(format!("Token validation failed: {}", e)),
155 }
156 }
157
158 async fn validate_api_key(&self, key: &str) -> AuthResult {
160 let api_key_config = match &self.config.api_key {
161 Some(config) => config,
162 None => return AuthResult::Failure("API key not configured".to_string()),
163 };
164
165 if api_key_config.keys.contains(&key.to_string()) {
167 AuthResult::Success(Claims {
168 sub: "api_key_user".to_string(),
169 exp: None,
170 iat: None,
171 aud: None,
172 iss: Some("mockforge".to_string()),
173 extra: {
174 let mut extra = HashMap::new();
175 extra.insert("auth_type".to_string(), serde_json::json!("api_key"));
176 extra
177 },
178 })
179 } else {
180 AuthResult::Failure("Invalid API key".to_string())
181 }
182 }
183
184 async fn authenticate(&self, request: &ProtocolRequest) -> AuthResult {
186 let token = match self.extract_token(request) {
188 Some(t) => t,
189 None => {
190 if !self.config.require_auth {
192 return AuthResult::Success(Claims {
193 sub: "anonymous".to_string(),
194 exp: None,
195 iat: None,
196 aud: None,
197 iss: Some("mockforge".to_string()),
198 extra: HashMap::new(),
199 });
200 }
201 return AuthResult::Failure("No authentication token provided".to_string());
202 }
203 };
204
205 if self.config.jwt.is_some() {
207 let result = self.validate_jwt(&token).await;
208 if matches!(result, AuthResult::Success(_)) {
209 return result;
210 }
211 }
212
213 if self.config.api_key.is_some() {
215 let result = self.validate_api_key(&token).await;
216 if matches!(result, AuthResult::Success(_)) {
217 return result;
218 }
219 }
220
221 AuthResult::Failure("Authentication failed".to_string())
222 }
223}
224
225#[async_trait::async_trait]
226impl ProtocolMiddleware for AuthMiddleware {
227 fn name(&self) -> &str {
228 &self.name
229 }
230
231 async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
232 if request.path.starts_with("/health") || request.path.starts_with("/__mockforge") {
234 return Ok(());
235 }
236
237 match self.authenticate(request).await {
239 AuthResult::Success(claims) => {
240 request.metadata.insert("x-auth-sub".to_string(), claims.sub.clone());
242 if let Some(iss) = &claims.iss {
243 request.metadata.insert("x-auth-iss".to_string(), iss.clone());
244 }
245 tracing::debug!(
246 protocol = %request.protocol,
247 user = %claims.sub,
248 "Authentication successful"
249 );
250 Ok(())
251 }
252 AuthResult::Failure(reason) => {
253 tracing::warn!(
254 protocol = %request.protocol,
255 path = %request.path,
256 reason = %reason,
257 "Authentication failed"
258 );
259 Err(crate::Error::validation(format!("Authentication failed: {}", reason)))
260 }
261 AuthResult::NetworkError(reason) => {
262 tracing::error!(
263 protocol = %request.protocol,
264 reason = %reason,
265 "Authentication network error"
266 );
267 Err(crate::Error::validation(format!("Authentication error: {}", reason)))
268 }
269 }
270 }
271
272 async fn process_response(
273 &self,
274 _request: &ProtocolRequest,
275 _response: &mut ProtocolResponse,
276 ) -> Result<()> {
277 Ok(())
279 }
280
281 fn supports_protocol(&self, _protocol: Protocol) -> bool {
282 true
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ApiKeyConfig;

    /// Builds a config that requires authentication and accepts exactly
    /// `valid_keys` through the `X-API-Key` header.
    fn api_key_auth_config(valid_keys: &[&str]) -> AuthConfig {
        AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: Some(ApiKeyConfig {
                header_name: "X-API-Key".to_string(),
                query_name: None,
                keys: valid_keys.iter().map(|key| key.to_string()).collect(),
            }),
            oauth2: None,
            basic_auth: None,
        }
    }

    #[test]
    fn test_auth_middleware_creation() {
        let middleware = AuthMiddleware::new(AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: None,
            oauth2: None,
            basic_auth: None,
        });

        assert_eq!(middleware.name(), "AuthMiddleware");
        for protocol in [Protocol::Http, Protocol::Grpc, Protocol::GraphQL] {
            assert!(middleware.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_extract_token_bearer() {
        let middleware = AuthMiddleware::new(AuthConfig::default());

        let mut metadata = HashMap::new();
        metadata.insert("authorization".to_string(), "Bearer test_token".to_string());

        let request = ProtocolRequest {
            protocol: Protocol::Http,
            pattern: crate::MessagePattern::RequestResponse,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            topic: None,
            routing_key: None,
            partition: None,
            qos: None,
            metadata,
            body: None,
            client_ip: None,
        };

        assert_eq!(middleware.extract_token(&request).as_deref(), Some("test_token"));
    }

    #[test]
    fn test_extract_token_api_key() {
        let middleware = AuthMiddleware::new(api_key_auth_config(&["test_key"]));

        let mut metadata = HashMap::new();
        metadata.insert("X-API-Key".to_string(), "test_key".to_string());

        let request = ProtocolRequest {
            protocol: Protocol::Http,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            metadata,
            ..Default::default()
        };

        assert_eq!(middleware.extract_token(&request).as_deref(), Some("test_key"));
    }

    #[tokio::test]
    async fn test_validate_api_key_success() {
        let middleware = AuthMiddleware::new(api_key_auth_config(&["valid_key"]));
        let outcome = middleware.validate_api_key("valid_key").await;
        assert!(matches!(outcome, AuthResult::Success(_)));
    }

    #[tokio::test]
    async fn test_validate_api_key_failure() {
        let middleware = AuthMiddleware::new(api_key_auth_config(&["valid_key"]));
        let outcome = middleware.validate_api_key("invalid_key").await;
        assert!(matches!(outcome, AuthResult::Failure(_)));
    }
}