mockforge_core/protocol_abstraction/
auth.rs1use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::config::AuthConfig;
5use crate::Result;
6use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Claims {
14 pub sub: String,
16 pub exp: Option<usize>,
18 pub iat: Option<usize>,
20 pub aud: Option<String>,
22 pub iss: Option<String>,
24 #[serde(flatten)]
26 pub extra: HashMap<String, serde_json::Value>,
27}
28
29#[derive(Debug, Clone)]
31pub enum AuthResult {
32 Success(Claims),
34 Failure(String),
36 NetworkError(String),
38}
39
40pub struct AuthMiddleware {
42 name: String,
44 config: Arc<AuthConfig>,
46 introspection_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
48}
49
50#[derive(Debug, Clone)]
52struct CachedToken {
53 claims: Claims,
54 expires_at: std::time::Instant,
55}
56
57impl AuthMiddleware {
58 pub fn new(config: AuthConfig) -> Self {
60 Self {
61 name: "AuthMiddleware".to_string(),
62 config: Arc::new(config),
63 introspection_cache: Arc::new(RwLock::new(HashMap::new())),
64 }
65 }
66
67 fn extract_token(&self, request: &ProtocolRequest) -> Option<String> {
69 if let Some(auth_header) = request.metadata.get("authorization") {
71 if let Some(token) = auth_header.strip_prefix("Bearer ") {
73 return Some(token.to_string());
74 }
75 return Some(auth_header.clone());
76 }
77
78 if let Some(api_key_config) = &self.config.api_key {
80 if let Some(api_key) = request.metadata.get(&api_key_config.header_name) {
81 return Some(api_key.clone());
82 }
83 }
84
85 if request.protocol == Protocol::Grpc {
87 if let Some(token) = request.metadata.get("grpc-metadata-authorization") {
88 if let Some(stripped) = token.strip_prefix("Bearer ") {
89 return Some(stripped.to_string());
90 }
91 return Some(token.clone());
92 }
93 }
94
95 None
96 }
97
98 async fn validate_jwt(&self, token: &str) -> AuthResult {
100 if let Some(cached) = self.introspection_cache.read().await.get(token) {
102 if cached.expires_at > std::time::Instant::now() {
103 return AuthResult::Success(cached.claims.clone());
104 }
105 }
106
107 let jwt_config = match &self.config.jwt {
109 Some(config) => config,
110 None => return AuthResult::Failure("JWT not configured".to_string()),
111 };
112
113 let header = match decode_header(token) {
115 Ok(h) => h,
116 Err(e) => return AuthResult::Failure(format!("Invalid token header: {}", e)),
117 };
118
119 let mut validation = Validation::new(header.alg);
121 if let Some(audience) = &jwt_config.audience {
122 validation.set_audience(&[audience]);
123 }
124 if let Some(issuer) = &jwt_config.issuer {
125 validation.set_issuer(&[issuer]);
126 }
127
128 let secret = match &jwt_config.secret {
130 Some(s) => s,
131 None => return AuthResult::Failure("JWT secret not configured".to_string()),
132 };
133
134 let decoding_key = DecodingKey::from_secret(secret.as_bytes());
136 match decode::<Claims>(token, &decoding_key, &validation) {
137 Ok(token_data) => {
138 let claims = token_data.claims;
139
140 let expires_at = if let Some(exp) = claims.exp {
142 let exp_instant =
143 std::time::UNIX_EPOCH + std::time::Duration::from_secs(exp as u64);
144 std::time::Instant::now()
145 + exp_instant.elapsed().unwrap_or(std::time::Duration::from_secs(300))
146 } else {
147 std::time::Instant::now() + std::time::Duration::from_secs(300)
148 };
149
150 self.introspection_cache.write().await.insert(
151 token.to_string(),
152 CachedToken {
153 claims: claims.clone(),
154 expires_at,
155 },
156 );
157
158 AuthResult::Success(claims)
159 }
160 Err(e) => AuthResult::Failure(format!("Token validation failed: {}", e)),
161 }
162 }
163
164 async fn validate_api_key(&self, key: &str) -> AuthResult {
166 let api_key_config = match &self.config.api_key {
167 Some(config) => config,
168 None => return AuthResult::Failure("API key not configured".to_string()),
169 };
170
171 if api_key_config.keys.contains(&key.to_string()) {
173 AuthResult::Success(Claims {
174 sub: "api_key_user".to_string(),
175 exp: None,
176 iat: None,
177 aud: None,
178 iss: Some("mockforge".to_string()),
179 extra: {
180 let mut extra = HashMap::new();
181 extra.insert("auth_type".to_string(), serde_json::json!("api_key"));
182 extra
183 },
184 })
185 } else {
186 AuthResult::Failure("Invalid API key".to_string())
187 }
188 }
189
190 async fn authenticate(&self, request: &ProtocolRequest) -> AuthResult {
192 let token = match self.extract_token(request) {
194 Some(t) => t,
195 None => {
196 if !self.config.require_auth {
198 return AuthResult::Success(Claims {
199 sub: "anonymous".to_string(),
200 exp: None,
201 iat: None,
202 aud: None,
203 iss: Some("mockforge".to_string()),
204 extra: HashMap::new(),
205 });
206 }
207 return AuthResult::Failure("No authentication token provided".to_string());
208 }
209 };
210
211 if self.config.jwt.is_some() {
213 let result = self.validate_jwt(&token).await;
214 if matches!(result, AuthResult::Success(_)) {
215 return result;
216 }
217 }
218
219 if self.config.api_key.is_some() {
221 let result = self.validate_api_key(&token).await;
222 if matches!(result, AuthResult::Success(_)) {
223 return result;
224 }
225 }
226
227 AuthResult::Failure("Authentication failed".to_string())
228 }
229}
230
231#[async_trait::async_trait]
232impl ProtocolMiddleware for AuthMiddleware {
233 fn name(&self) -> &str {
234 &self.name
235 }
236
237 async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
238 if request.path.starts_with("/health") || request.path.starts_with("/__mockforge") {
240 return Ok(());
241 }
242
243 match self.authenticate(request).await {
245 AuthResult::Success(claims) => {
246 request.metadata.insert("x-auth-sub".to_string(), claims.sub.clone());
248 if let Some(iss) = &claims.iss {
249 request.metadata.insert("x-auth-iss".to_string(), iss.clone());
250 }
251 tracing::debug!(
252 protocol = %request.protocol,
253 user = %claims.sub,
254 "Authentication successful"
255 );
256 Ok(())
257 }
258 AuthResult::Failure(reason) => {
259 tracing::warn!(
260 protocol = %request.protocol,
261 path = %request.path,
262 reason = %reason,
263 "Authentication failed"
264 );
265 Err(crate::Error::validation(format!("Authentication failed: {}", reason)))
266 }
267 AuthResult::NetworkError(reason) => {
268 tracing::error!(
269 protocol = %request.protocol,
270 reason = %reason,
271 "Authentication network error"
272 );
273 Err(crate::Error::validation(format!("Authentication error: {}", reason)))
274 }
275 }
276 }
277
278 async fn process_response(
279 &self,
280 _request: &ProtocolRequest,
281 _response: &mut ProtocolResponse,
282 ) -> Result<()> {
283 Ok(())
285 }
286
287 fn supports_protocol(&self, _protocol: Protocol) -> bool {
288 true
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ApiKeyConfig;

    /// Auth config that requires authentication and accepts exactly `keys` through the
    /// `X-API-Key` header; JWT, OAuth2 and basic auth are disabled.
    fn api_key_config(keys: &[&str]) -> AuthConfig {
        AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: Some(ApiKeyConfig {
                header_name: "X-API-Key".to_string(),
                query_name: None,
                keys: keys.iter().map(|key| key.to_string()).collect(),
            }),
            oauth2: None,
            basic_auth: None,
        }
    }

    /// A `GET /test` HTTP request whose metadata holds exactly `headers`.
    fn get_request(headers: &[(&str, &str)]) -> ProtocolRequest {
        ProtocolRequest {
            protocol: Protocol::Http,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            metadata: headers
                .iter()
                .map(|(name, value)| (name.to_string(), value.to_string()))
                .collect(),
            ..Default::default()
        }
    }

    #[test]
    fn test_auth_middleware_creation() {
        let middleware = AuthMiddleware::new(AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: None,
            oauth2: None,
            basic_auth: None,
        });

        assert_eq!(middleware.name(), "AuthMiddleware");
        for protocol in [Protocol::Http, Protocol::Grpc, Protocol::GraphQL] {
            assert!(middleware.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_extract_token_bearer() {
        let middleware = AuthMiddleware::new(AuthConfig::default());
        let request = get_request(&[("authorization", "Bearer test_token")]);

        assert_eq!(middleware.extract_token(&request), Some("test_token".to_string()));
    }

    #[test]
    fn test_extract_token_api_key() {
        let middleware = AuthMiddleware::new(api_key_config(&["test_key"]));
        let request = get_request(&[("X-API-Key", "test_key")]);

        assert_eq!(middleware.extract_token(&request), Some("test_key".to_string()));
    }

    #[tokio::test]
    async fn test_validate_api_key_success() {
        let middleware = AuthMiddleware::new(api_key_config(&["valid_key"]));

        let outcome = middleware.validate_api_key("valid_key").await;
        assert!(matches!(outcome, AuthResult::Success(_)));
    }

    #[tokio::test]
    async fn test_validate_api_key_failure() {
        let middleware = AuthMiddleware::new(api_key_config(&["valid_key"]));

        let outcome = middleware.validate_api_key("invalid_key").await;
        assert!(matches!(outcome, AuthResult::Failure(_)));
    }
}