Skip to main content

mockforge_core/protocol_abstraction/
auth.rs

1//! Unified authentication middleware for all protocols
2
3use super::{
4    MiddlewareAction, Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse,
5    ResponseStatus,
6};
7use crate::config::AuthConfig;
8use crate::Result;
9use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14/// JWT token claims extracted from authentication tokens
15#[derive(Debug, Serialize, Deserialize, Clone)]
16pub struct Claims {
17    /// Subject (user ID or identifier)
18    pub sub: String,
19    /// Expiration timestamp
20    pub exp: Option<usize>,
21    /// Issued at timestamp
22    pub iat: Option<usize>,
23    /// Audience (intended recipient)
24    pub aud: Option<String>,
25    /// Issuer (token issuer identifier)
26    pub iss: Option<String>,
27    /// Additional custom claims (flattened into the struct)
28    #[serde(flatten)]
29    pub extra: HashMap<String, serde_json::Value>,
30}
31
32/// Result of authentication attempt
33#[derive(Debug, Clone)]
34pub enum AuthResult {
35    /// Authentication successful with extracted claims
36    Success(Claims),
37    /// Authentication failed with error message
38    Failure(String),
39    /// Network error occurred during authentication
40    NetworkError(String),
41}
42
43/// Unified authentication middleware
44pub struct AuthMiddleware {
45    /// Middleware name
46    name: String,
47    /// Authentication configuration
48    config: Arc<AuthConfig>,
49    /// Token introspection cache
50    introspection_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
51}
52
53/// Cached token information
54#[derive(Debug, Clone)]
55struct CachedToken {
56    claims: Claims,
57    expires_at: std::time::Instant,
58}
59
60impl AuthMiddleware {
61    /// Create a new auth middleware
62    pub fn new(config: AuthConfig) -> Self {
63        Self {
64            name: "AuthMiddleware".to_string(),
65            config: Arc::new(config),
66            introspection_cache: Arc::new(RwLock::new(HashMap::new())),
67        }
68    }
69
70    /// Extract auth token from request metadata
71    fn extract_token(&self, request: &ProtocolRequest) -> Option<String> {
72        // Try Authorization header first (works for HTTP, GraphQL, WebSocket)
73        if let Some(auth_header) = request.metadata.get("authorization") {
74            // Handle "Bearer <token>" format
75            if let Some(token) = auth_header.strip_prefix("Bearer ") {
76                return Some(token.to_string());
77            }
78            return Some(auth_header.clone());
79        }
80
81        // Try API key header
82        if let Some(api_key_config) = &self.config.api_key {
83            if let Some(api_key) = request.metadata.get(&api_key_config.header_name) {
84                return Some(api_key.clone());
85            }
86        }
87
88        // For gRPC, try metadata
89        if request.protocol == Protocol::Grpc {
90            if let Some(token) = request.metadata.get("grpc-metadata-authorization") {
91                if let Some(stripped) = token.strip_prefix("Bearer ") {
92                    return Some(stripped.to_string());
93                }
94                return Some(token.clone());
95            }
96        }
97
98        None
99    }
100
101    /// Validate JWT token
102    async fn validate_jwt(&self, token: &str) -> AuthResult {
103        // Check cache first
104        if let Some(cached) = self.introspection_cache.read().await.get(token) {
105            if cached.expires_at > std::time::Instant::now() {
106                return AuthResult::Success(cached.claims.clone());
107            }
108        }
109
110        // Get JWT configuration
111        let jwt_config = match &self.config.jwt {
112            Some(config) => config,
113            None => return AuthResult::Failure("JWT not configured".to_string()),
114        };
115
116        // Decode header to get algorithm
117        let header = match decode_header(token) {
118            Ok(h) => h,
119            Err(e) => return AuthResult::Failure(format!("Invalid token header: {}", e)),
120        };
121
122        // Create validation
123        let mut validation = Validation::new(header.alg);
124        if let Some(audience) = &jwt_config.audience {
125            validation.set_audience(&[audience]);
126        }
127        if let Some(issuer) = &jwt_config.issuer {
128            validation.set_issuer(&[issuer]);
129        }
130
131        // Get secret
132        let secret = match &jwt_config.secret {
133            Some(s) => s,
134            None => return AuthResult::Failure("JWT secret not configured".to_string()),
135        };
136
137        // Decode token
138        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
139        match decode::<Claims>(token, &decoding_key, &validation) {
140            Ok(token_data) => {
141                let claims = token_data.claims;
142
143                // Cache the token
144                let expires_at = if let Some(exp) = claims.exp {
145                    let exp_instant =
146                        std::time::UNIX_EPOCH + std::time::Duration::from_secs(exp as u64);
147                    std::time::Instant::now()
148                        + exp_instant.elapsed().unwrap_or(std::time::Duration::from_secs(300))
149                } else {
150                    std::time::Instant::now() + std::time::Duration::from_secs(300)
151                };
152
153                self.introspection_cache.write().await.insert(
154                    token.to_string(),
155                    CachedToken {
156                        claims: claims.clone(),
157                        expires_at,
158                    },
159                );
160
161                AuthResult::Success(claims)
162            }
163            Err(e) => AuthResult::Failure(format!("Token validation failed: {}", e)),
164        }
165    }
166
167    /// Validate API key
168    async fn validate_api_key(&self, key: &str) -> AuthResult {
169        let api_key_config = match &self.config.api_key {
170            Some(config) => config,
171            None => return AuthResult::Failure("API key not configured".to_string()),
172        };
173
174        // Check if the key is valid
175        if api_key_config.keys.contains(&key.to_string()) {
176            AuthResult::Success(Claims {
177                sub: "api_key_user".to_string(),
178                exp: None,
179                iat: None,
180                aud: None,
181                iss: Some("mockforge".to_string()),
182                extra: {
183                    let mut extra = HashMap::new();
184                    extra.insert("auth_type".to_string(), serde_json::json!("api_key"));
185                    extra
186                },
187            })
188        } else {
189            AuthResult::Failure("Invalid API key".to_string())
190        }
191    }
192
193    /// Perform authentication
194    async fn authenticate(&self, request: &ProtocolRequest) -> AuthResult {
195        // Extract token
196        let token = match self.extract_token(request) {
197            Some(t) => t,
198            None => {
199                // If no token and auth is not required, allow
200                if !self.config.require_auth {
201                    return AuthResult::Success(Claims {
202                        sub: "anonymous".to_string(),
203                        exp: None,
204                        iat: None,
205                        aud: None,
206                        iss: Some("mockforge".to_string()),
207                        extra: HashMap::new(),
208                    });
209                }
210                return AuthResult::Failure("No authentication token provided".to_string());
211            }
212        };
213
214        // Try JWT validation first
215        if self.config.jwt.is_some() {
216            let result = self.validate_jwt(&token).await;
217            if matches!(result, AuthResult::Success(_)) {
218                return result;
219            }
220        }
221
222        // Try API key validation
223        if self.config.api_key.is_some() {
224            let result = self.validate_api_key(&token).await;
225            if matches!(result, AuthResult::Success(_)) {
226                return result;
227            }
228        }
229
230        AuthResult::Failure("Authentication failed".to_string())
231    }
232}
233
234#[async_trait::async_trait]
235impl ProtocolMiddleware for AuthMiddleware {
236    fn name(&self) -> &str {
237        &self.name
238    }
239
240    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<MiddlewareAction> {
241        // Skip authentication for health checks and admin endpoints
242        if request.path.starts_with("/health") || request.path.starts_with("/__mockforge") {
243            return Ok(MiddlewareAction::Continue);
244        }
245
246        // Perform authentication
247        match self.authenticate(request).await {
248            AuthResult::Success(claims) => {
249                // Add claims to request metadata
250                request.metadata.insert("x-auth-sub".to_string(), claims.sub.clone());
251                if let Some(iss) = &claims.iss {
252                    request.metadata.insert("x-auth-iss".to_string(), iss.clone());
253                }
254                tracing::debug!(
255                    protocol = %request.protocol,
256                    user = %claims.sub,
257                    "Authentication successful"
258                );
259                Ok(MiddlewareAction::Continue)
260            }
261            AuthResult::Failure(reason) => {
262                tracing::warn!(
263                    protocol = %request.protocol,
264                    path = %request.path,
265                    reason = %reason,
266                    "Authentication failed"
267                );
268                Ok(MiddlewareAction::ShortCircuit(ProtocolResponse {
269                    status: ResponseStatus::HttpStatus(401),
270                    metadata: std::collections::HashMap::new(),
271                    body: format!(r#"{{"error":"Authentication failed","reason":"{}"}}"#, reason)
272                        .into_bytes(),
273                    content_type: "application/json".to_string(),
274                }))
275            }
276            AuthResult::NetworkError(reason) => {
277                tracing::error!(
278                    protocol = %request.protocol,
279                    reason = %reason,
280                    "Authentication network error"
281                );
282                Ok(MiddlewareAction::ShortCircuit(ProtocolResponse {
283                    status: ResponseStatus::HttpStatus(503),
284                    metadata: std::collections::HashMap::new(),
285                    body: format!(
286                        r#"{{"error":"Authentication service unavailable","reason":"{}"}}"#,
287                        reason
288                    )
289                    .into_bytes(),
290                    content_type: "application/json".to_string(),
291                }))
292            }
293        }
294    }
295
296    async fn process_response(
297        &self,
298        _request: &ProtocolRequest,
299        _response: &mut ProtocolResponse,
300    ) -> Result<()> {
301        // No post-processing needed for auth
302        Ok(())
303    }
304
305    fn supports_protocol(&self, _protocol: Protocol) -> bool {
306        // Auth middleware supports all protocols
307        true
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ApiKeyConfig;

    /// Config that requires auth and accepts exactly `key` via the `X-API-Key` header.
    fn api_key_only_config(key: &str) -> AuthConfig {
        AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: Some(ApiKeyConfig {
                header_name: "X-API-Key".to_string(),
                query_name: None,
                keys: vec![key.to_string()],
            }),
            oauth2: None,
            basic_auth: None,
        }
    }

    /// Request metadata containing a single `name: value` entry.
    fn single_entry_metadata(name: &str, value: &str) -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert(name.to_string(), value.to_string());
        metadata
    }

    /// The middleware reports its name and accepts every protocol.
    #[test]
    fn test_auth_middleware_creation() {
        let middleware = AuthMiddleware::new(AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: None,
            oauth2: None,
            basic_auth: None,
        });

        assert_eq!(middleware.name(), "AuthMiddleware");
        for protocol in [Protocol::Http, Protocol::Grpc, Protocol::GraphQL] {
            assert!(middleware.supports_protocol(protocol));
        }
    }

    /// A `Bearer` authorization value yields the bare token.
    #[test]
    fn test_extract_token_bearer() {
        let middleware = AuthMiddleware::new(AuthConfig::default());

        let request = ProtocolRequest {
            protocol: Protocol::Http,
            pattern: crate::MessagePattern::RequestResponse,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            topic: None,
            routing_key: None,
            partition: None,
            qos: None,
            metadata: single_entry_metadata("authorization", "Bearer test_token"),
            body: None,
            client_ip: None,
        };

        assert_eq!(middleware.extract_token(&request), Some("test_token".to_string()));
    }

    /// The configured API-key header is used when no authorization entry exists.
    #[test]
    fn test_extract_token_api_key() {
        let middleware = AuthMiddleware::new(api_key_only_config("test_key"));

        let request = ProtocolRequest {
            protocol: Protocol::Http,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            metadata: single_entry_metadata("X-API-Key", "test_key"),
            ..Default::default()
        };

        assert_eq!(middleware.extract_token(&request), Some("test_key".to_string()));
    }

    /// A key from the configured list is accepted.
    #[tokio::test]
    async fn test_validate_api_key_success() {
        let middleware = AuthMiddleware::new(api_key_only_config("valid_key"));

        let outcome = middleware.validate_api_key("valid_key").await;
        assert!(matches!(outcome, AuthResult::Success(_)));
    }

    /// A key outside the configured list is rejected.
    #[tokio::test]
    async fn test_validate_api_key_failure() {
        let middleware = AuthMiddleware::new(api_key_only_config("valid_key"));

        let outcome = middleware.validate_api_key("invalid_key").await;
        assert!(matches!(outcome, AuthResult::Failure(_)));
    }
}
426}