mockforge_core/protocol_abstraction/
auth.rs

1//! Unified authentication middleware for all protocols
2
3use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::config::AuthConfig;
5use crate::Result;
6use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11/// JWT Claims
12#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Claims {
14    pub sub: String,
15    pub exp: Option<usize>,
16    pub iat: Option<usize>,
17    pub aud: Option<String>,
18    pub iss: Option<String>,
19    #[serde(flatten)]
20    pub extra: HashMap<String, serde_json::Value>,
21}
22
23/// Authentication result
24#[derive(Debug, Clone)]
25pub enum AuthResult {
26    /// Authentication successful
27    Success(Claims),
28    /// Authentication failed
29    Failure(String),
30    /// Network error during auth
31    NetworkError(String),
32}
33
34/// Unified authentication middleware
35pub struct AuthMiddleware {
36    /// Middleware name
37    name: String,
38    /// Authentication configuration
39    config: Arc<AuthConfig>,
40    /// Token introspection cache
41    introspection_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
42}
43
44/// Cached token information
45#[derive(Debug, Clone)]
46struct CachedToken {
47    claims: Claims,
48    expires_at: std::time::Instant,
49}
50
51impl AuthMiddleware {
52    /// Create a new auth middleware
53    pub fn new(config: AuthConfig) -> Self {
54        Self {
55            name: "AuthMiddleware".to_string(),
56            config: Arc::new(config),
57            introspection_cache: Arc::new(RwLock::new(HashMap::new())),
58        }
59    }
60
61    /// Extract auth token from request metadata
62    fn extract_token(&self, request: &ProtocolRequest) -> Option<String> {
63        // Try Authorization header first (works for HTTP, GraphQL, WebSocket)
64        if let Some(auth_header) = request.metadata.get("authorization") {
65            // Handle "Bearer <token>" format
66            if let Some(token) = auth_header.strip_prefix("Bearer ") {
67                return Some(token.to_string());
68            }
69            return Some(auth_header.clone());
70        }
71
72        // Try API key header
73        if let Some(api_key_config) = &self.config.api_key {
74            if let Some(api_key) = request.metadata.get(&api_key_config.header_name) {
75                return Some(api_key.clone());
76            }
77        }
78
79        // For gRPC, try metadata
80        if request.protocol == Protocol::Grpc {
81            if let Some(token) = request.metadata.get("grpc-metadata-authorization") {
82                if let Some(stripped) = token.strip_prefix("Bearer ") {
83                    return Some(stripped.to_string());
84                }
85                return Some(token.clone());
86            }
87        }
88
89        None
90    }
91
92    /// Validate JWT token
93    async fn validate_jwt(&self, token: &str) -> AuthResult {
94        // Check cache first
95        if let Some(cached) = self.introspection_cache.read().await.get(token) {
96            if cached.expires_at > std::time::Instant::now() {
97                return AuthResult::Success(cached.claims.clone());
98            }
99        }
100
101        // Get JWT configuration
102        let jwt_config = match &self.config.jwt {
103            Some(config) => config,
104            None => return AuthResult::Failure("JWT not configured".to_string()),
105        };
106
107        // Decode header to get algorithm
108        let header = match decode_header(token) {
109            Ok(h) => h,
110            Err(e) => return AuthResult::Failure(format!("Invalid token header: {}", e)),
111        };
112
113        // Create validation
114        let mut validation = Validation::new(header.alg);
115        if let Some(audience) = &jwt_config.audience {
116            validation.set_audience(&[audience]);
117        }
118        if let Some(issuer) = &jwt_config.issuer {
119            validation.set_issuer(&[issuer]);
120        }
121
122        // Get secret
123        let secret = match &jwt_config.secret {
124            Some(s) => s,
125            None => return AuthResult::Failure("JWT secret not configured".to_string()),
126        };
127
128        // Decode token
129        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
130        match decode::<Claims>(token, &decoding_key, &validation) {
131            Ok(token_data) => {
132                let claims = token_data.claims;
133
134                // Cache the token
135                let expires_at = if let Some(exp) = claims.exp {
136                    let exp_instant =
137                        std::time::UNIX_EPOCH + std::time::Duration::from_secs(exp as u64);
138                    std::time::Instant::now()
139                        + exp_instant.elapsed().unwrap_or(std::time::Duration::from_secs(300))
140                } else {
141                    std::time::Instant::now() + std::time::Duration::from_secs(300)
142                };
143
144                self.introspection_cache.write().await.insert(
145                    token.to_string(),
146                    CachedToken {
147                        claims: claims.clone(),
148                        expires_at,
149                    },
150                );
151
152                AuthResult::Success(claims)
153            }
154            Err(e) => AuthResult::Failure(format!("Token validation failed: {}", e)),
155        }
156    }
157
158    /// Validate API key
159    async fn validate_api_key(&self, key: &str) -> AuthResult {
160        let api_key_config = match &self.config.api_key {
161            Some(config) => config,
162            None => return AuthResult::Failure("API key not configured".to_string()),
163        };
164
165        // Check if the key is valid
166        if api_key_config.keys.contains(&key.to_string()) {
167            AuthResult::Success(Claims {
168                sub: "api_key_user".to_string(),
169                exp: None,
170                iat: None,
171                aud: None,
172                iss: Some("mockforge".to_string()),
173                extra: {
174                    let mut extra = HashMap::new();
175                    extra.insert("auth_type".to_string(), serde_json::json!("api_key"));
176                    extra
177                },
178            })
179        } else {
180            AuthResult::Failure("Invalid API key".to_string())
181        }
182    }
183
184    /// Perform authentication
185    async fn authenticate(&self, request: &ProtocolRequest) -> AuthResult {
186        // Extract token
187        let token = match self.extract_token(request) {
188            Some(t) => t,
189            None => {
190                // If no token and auth is not required, allow
191                if !self.config.require_auth {
192                    return AuthResult::Success(Claims {
193                        sub: "anonymous".to_string(),
194                        exp: None,
195                        iat: None,
196                        aud: None,
197                        iss: Some("mockforge".to_string()),
198                        extra: HashMap::new(),
199                    });
200                }
201                return AuthResult::Failure("No authentication token provided".to_string());
202            }
203        };
204
205        // Try JWT validation first
206        if self.config.jwt.is_some() {
207            let result = self.validate_jwt(&token).await;
208            if matches!(result, AuthResult::Success(_)) {
209                return result;
210            }
211        }
212
213        // Try API key validation
214        if self.config.api_key.is_some() {
215            let result = self.validate_api_key(&token).await;
216            if matches!(result, AuthResult::Success(_)) {
217                return result;
218            }
219        }
220
221        AuthResult::Failure("Authentication failed".to_string())
222    }
223}
224
225#[async_trait::async_trait]
226impl ProtocolMiddleware for AuthMiddleware {
227    fn name(&self) -> &str {
228        &self.name
229    }
230
231    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
232        // Skip authentication for health checks and admin endpoints
233        if request.path.starts_with("/health") || request.path.starts_with("/__mockforge") {
234            return Ok(());
235        }
236
237        // Perform authentication
238        match self.authenticate(request).await {
239            AuthResult::Success(claims) => {
240                // Add claims to request metadata
241                request.metadata.insert("x-auth-sub".to_string(), claims.sub.clone());
242                if let Some(iss) = &claims.iss {
243                    request.metadata.insert("x-auth-iss".to_string(), iss.clone());
244                }
245                tracing::debug!(
246                    protocol = %request.protocol,
247                    user = %claims.sub,
248                    "Authentication successful"
249                );
250                Ok(())
251            }
252            AuthResult::Failure(reason) => {
253                tracing::warn!(
254                    protocol = %request.protocol,
255                    path = %request.path,
256                    reason = %reason,
257                    "Authentication failed"
258                );
259                Err(crate::Error::validation(format!("Authentication failed: {}", reason)))
260            }
261            AuthResult::NetworkError(reason) => {
262                tracing::error!(
263                    protocol = %request.protocol,
264                    reason = %reason,
265                    "Authentication network error"
266                );
267                Err(crate::Error::validation(format!("Authentication error: {}", reason)))
268            }
269        }
270    }
271
272    async fn process_response(
273        &self,
274        _request: &ProtocolRequest,
275        _response: &mut ProtocolResponse,
276    ) -> Result<()> {
277        // No post-processing needed for auth
278        Ok(())
279    }
280
281    fn supports_protocol(&self, _protocol: Protocol) -> bool {
282        // Auth middleware supports all protocols
283        true
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ApiKeyConfig;

    /// Header name the API-key tests configure and send.
    const API_KEY_HEADER: &str = "X-API-Key";

    /// Auth config that requires authentication and enables no scheme at all.
    fn bare_config() -> AuthConfig {
        AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: None,
            oauth2: None,
            basic_auth: None,
        }
    }

    /// Like [`bare_config`], but accepting exactly `keys` via `API_KEY_HEADER`.
    fn api_key_config(keys: &[&str]) -> AuthConfig {
        AuthConfig {
            api_key: Some(ApiKeyConfig {
                header_name: API_KEY_HEADER.to_string(),
                query_name: None,
                keys: keys.iter().map(|key| key.to_string()).collect(),
            }),
            ..bare_config()
        }
    }

    /// Owned metadata map built from borrowed key/value pairs.
    fn metadata(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// A request-response HTTP `GET /test` carrying `metadata`.
    fn get_request(metadata: HashMap<String, String>) -> ProtocolRequest {
        ProtocolRequest {
            protocol: Protocol::Http,
            pattern: crate::MessagePattern::RequestResponse,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            metadata,
            ..Default::default()
        }
    }

    #[test]
    fn test_auth_middleware_creation() {
        let middleware = AuthMiddleware::new(bare_config());
        assert_eq!(middleware.name(), "AuthMiddleware");
        for protocol in [Protocol::Http, Protocol::Grpc, Protocol::GraphQL] {
            assert!(middleware.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_extract_token_bearer() {
        let middleware = AuthMiddleware::new(AuthConfig::default());
        let request = get_request(metadata(&[("authorization", "Bearer test_token")]));

        assert_eq!(middleware.extract_token(&request), Some("test_token".to_string()));
    }

    #[test]
    fn test_extract_token_api_key() {
        let middleware = AuthMiddleware::new(api_key_config(&["test_key"]));
        let request = get_request(metadata(&[(API_KEY_HEADER, "test_key")]));

        assert_eq!(middleware.extract_token(&request), Some("test_key".to_string()));
    }

    #[tokio::test]
    async fn test_validate_api_key_success() {
        let middleware = AuthMiddleware::new(api_key_config(&["valid_key"]));
        let outcome = middleware.validate_api_key("valid_key").await;
        assert!(matches!(outcome, AuthResult::Success(_)));
    }

    #[tokio::test]
    async fn test_validate_api_key_failure() {
        let middleware = AuthMiddleware::new(api_key_config(&["valid_key"]));
        let outcome = middleware.validate_api_key("invalid_key").await;
        assert!(matches!(outcome, AuthResult::Failure(_)));
    }
}