mockforge_core/protocol_abstraction/
auth.rs

1//! Unified authentication middleware for all protocols
2
3use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::config::AuthConfig;
5use crate::Result;
6use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11/// JWT token claims extracted from authentication tokens
12#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Claims {
14    /// Subject (user ID or identifier)
15    pub sub: String,
16    /// Expiration timestamp
17    pub exp: Option<usize>,
18    /// Issued at timestamp
19    pub iat: Option<usize>,
20    /// Audience (intended recipient)
21    pub aud: Option<String>,
22    /// Issuer (token issuer identifier)
23    pub iss: Option<String>,
24    /// Additional custom claims (flattened into the struct)
25    #[serde(flatten)]
26    pub extra: HashMap<String, serde_json::Value>,
27}
28
29/// Result of authentication attempt
30#[derive(Debug, Clone)]
31pub enum AuthResult {
32    /// Authentication successful with extracted claims
33    Success(Claims),
34    /// Authentication failed with error message
35    Failure(String),
36    /// Network error occurred during authentication
37    NetworkError(String),
38}
39
40/// Unified authentication middleware
41pub struct AuthMiddleware {
42    /// Middleware name
43    name: String,
44    /// Authentication configuration
45    config: Arc<AuthConfig>,
46    /// Token introspection cache
47    introspection_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
48}
49
50/// Cached token information
51#[derive(Debug, Clone)]
52struct CachedToken {
53    claims: Claims,
54    expires_at: std::time::Instant,
55}
56
57impl AuthMiddleware {
58    /// Create a new auth middleware
59    pub fn new(config: AuthConfig) -> Self {
60        Self {
61            name: "AuthMiddleware".to_string(),
62            config: Arc::new(config),
63            introspection_cache: Arc::new(RwLock::new(HashMap::new())),
64        }
65    }
66
67    /// Extract auth token from request metadata
68    fn extract_token(&self, request: &ProtocolRequest) -> Option<String> {
69        // Try Authorization header first (works for HTTP, GraphQL, WebSocket)
70        if let Some(auth_header) = request.metadata.get("authorization") {
71            // Handle "Bearer <token>" format
72            if let Some(token) = auth_header.strip_prefix("Bearer ") {
73                return Some(token.to_string());
74            }
75            return Some(auth_header.clone());
76        }
77
78        // Try API key header
79        if let Some(api_key_config) = &self.config.api_key {
80            if let Some(api_key) = request.metadata.get(&api_key_config.header_name) {
81                return Some(api_key.clone());
82            }
83        }
84
85        // For gRPC, try metadata
86        if request.protocol == Protocol::Grpc {
87            if let Some(token) = request.metadata.get("grpc-metadata-authorization") {
88                if let Some(stripped) = token.strip_prefix("Bearer ") {
89                    return Some(stripped.to_string());
90                }
91                return Some(token.clone());
92            }
93        }
94
95        None
96    }
97
98    /// Validate JWT token
99    async fn validate_jwt(&self, token: &str) -> AuthResult {
100        // Check cache first
101        if let Some(cached) = self.introspection_cache.read().await.get(token) {
102            if cached.expires_at > std::time::Instant::now() {
103                return AuthResult::Success(cached.claims.clone());
104            }
105        }
106
107        // Get JWT configuration
108        let jwt_config = match &self.config.jwt {
109            Some(config) => config,
110            None => return AuthResult::Failure("JWT not configured".to_string()),
111        };
112
113        // Decode header to get algorithm
114        let header = match decode_header(token) {
115            Ok(h) => h,
116            Err(e) => return AuthResult::Failure(format!("Invalid token header: {}", e)),
117        };
118
119        // Create validation
120        let mut validation = Validation::new(header.alg);
121        if let Some(audience) = &jwt_config.audience {
122            validation.set_audience(&[audience]);
123        }
124        if let Some(issuer) = &jwt_config.issuer {
125            validation.set_issuer(&[issuer]);
126        }
127
128        // Get secret
129        let secret = match &jwt_config.secret {
130            Some(s) => s,
131            None => return AuthResult::Failure("JWT secret not configured".to_string()),
132        };
133
134        // Decode token
135        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
136        match decode::<Claims>(token, &decoding_key, &validation) {
137            Ok(token_data) => {
138                let claims = token_data.claims;
139
140                // Cache the token
141                let expires_at = if let Some(exp) = claims.exp {
142                    let exp_instant =
143                        std::time::UNIX_EPOCH + std::time::Duration::from_secs(exp as u64);
144                    std::time::Instant::now()
145                        + exp_instant.elapsed().unwrap_or(std::time::Duration::from_secs(300))
146                } else {
147                    std::time::Instant::now() + std::time::Duration::from_secs(300)
148                };
149
150                self.introspection_cache.write().await.insert(
151                    token.to_string(),
152                    CachedToken {
153                        claims: claims.clone(),
154                        expires_at,
155                    },
156                );
157
158                AuthResult::Success(claims)
159            }
160            Err(e) => AuthResult::Failure(format!("Token validation failed: {}", e)),
161        }
162    }
163
164    /// Validate API key
165    async fn validate_api_key(&self, key: &str) -> AuthResult {
166        let api_key_config = match &self.config.api_key {
167            Some(config) => config,
168            None => return AuthResult::Failure("API key not configured".to_string()),
169        };
170
171        // Check if the key is valid
172        if api_key_config.keys.contains(&key.to_string()) {
173            AuthResult::Success(Claims {
174                sub: "api_key_user".to_string(),
175                exp: None,
176                iat: None,
177                aud: None,
178                iss: Some("mockforge".to_string()),
179                extra: {
180                    let mut extra = HashMap::new();
181                    extra.insert("auth_type".to_string(), serde_json::json!("api_key"));
182                    extra
183                },
184            })
185        } else {
186            AuthResult::Failure("Invalid API key".to_string())
187        }
188    }
189
190    /// Perform authentication
191    async fn authenticate(&self, request: &ProtocolRequest) -> AuthResult {
192        // Extract token
193        let token = match self.extract_token(request) {
194            Some(t) => t,
195            None => {
196                // If no token and auth is not required, allow
197                if !self.config.require_auth {
198                    return AuthResult::Success(Claims {
199                        sub: "anonymous".to_string(),
200                        exp: None,
201                        iat: None,
202                        aud: None,
203                        iss: Some("mockforge".to_string()),
204                        extra: HashMap::new(),
205                    });
206                }
207                return AuthResult::Failure("No authentication token provided".to_string());
208            }
209        };
210
211        // Try JWT validation first
212        if self.config.jwt.is_some() {
213            let result = self.validate_jwt(&token).await;
214            if matches!(result, AuthResult::Success(_)) {
215                return result;
216            }
217        }
218
219        // Try API key validation
220        if self.config.api_key.is_some() {
221            let result = self.validate_api_key(&token).await;
222            if matches!(result, AuthResult::Success(_)) {
223                return result;
224            }
225        }
226
227        AuthResult::Failure("Authentication failed".to_string())
228    }
229}
230
231#[async_trait::async_trait]
232impl ProtocolMiddleware for AuthMiddleware {
233    fn name(&self) -> &str {
234        &self.name
235    }
236
237    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
238        // Skip authentication for health checks and admin endpoints
239        if request.path.starts_with("/health") || request.path.starts_with("/__mockforge") {
240            return Ok(());
241        }
242
243        // Perform authentication
244        match self.authenticate(request).await {
245            AuthResult::Success(claims) => {
246                // Add claims to request metadata
247                request.metadata.insert("x-auth-sub".to_string(), claims.sub.clone());
248                if let Some(iss) = &claims.iss {
249                    request.metadata.insert("x-auth-iss".to_string(), iss.clone());
250                }
251                tracing::debug!(
252                    protocol = %request.protocol,
253                    user = %claims.sub,
254                    "Authentication successful"
255                );
256                Ok(())
257            }
258            AuthResult::Failure(reason) => {
259                tracing::warn!(
260                    protocol = %request.protocol,
261                    path = %request.path,
262                    reason = %reason,
263                    "Authentication failed"
264                );
265                Err(crate::Error::validation(format!("Authentication failed: {}", reason)))
266            }
267            AuthResult::NetworkError(reason) => {
268                tracing::error!(
269                    protocol = %request.protocol,
270                    reason = %reason,
271                    "Authentication network error"
272                );
273                Err(crate::Error::validation(format!("Authentication error: {}", reason)))
274            }
275        }
276    }
277
278    async fn process_response(
279        &self,
280        _request: &ProtocolRequest,
281        _response: &mut ProtocolResponse,
282    ) -> Result<()> {
283        // No post-processing needed for auth
284        Ok(())
285    }
286
287    fn supports_protocol(&self, _protocol: Protocol) -> bool {
288        // Auth middleware supports all protocols
289        true
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ApiKeyConfig;

    /// Auth config that requires authentication and accepts exactly `valid_key`
    /// through the `X-API-Key` header.
    fn api_key_only_config(valid_key: &str) -> AuthConfig {
        AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: Some(ApiKeyConfig {
                header_name: "X-API-Key".to_string(),
                query_name: None,
                keys: vec![valid_key.to_string()],
            }),
            oauth2: None,
            basic_auth: None,
        }
    }

    /// A `GET /test` HTTP request whose metadata holds a single `name: value` entry.
    fn http_get_with_metadata(name: &str, value: &str) -> ProtocolRequest {
        let mut metadata = HashMap::new();
        metadata.insert(name.to_string(), value.to_string());
        ProtocolRequest {
            protocol: Protocol::Http,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            metadata,
            ..Default::default()
        }
    }

    #[test]
    fn test_auth_middleware_creation() {
        let middleware = AuthMiddleware::new(AuthConfig {
            require_auth: true,
            jwt: None,
            api_key: None,
            oauth2: None,
            basic_auth: None,
        });

        assert_eq!(middleware.name(), "AuthMiddleware");
        for protocol in [Protocol::Http, Protocol::Grpc, Protocol::GraphQL] {
            assert!(middleware.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_extract_token_bearer() {
        let middleware = AuthMiddleware::new(AuthConfig::default());
        let request = http_get_with_metadata("authorization", "Bearer test_token");

        assert_eq!(middleware.extract_token(&request), Some("test_token".to_string()));
    }

    #[test]
    fn test_extract_token_api_key() {
        let middleware = AuthMiddleware::new(api_key_only_config("test_key"));
        let request = http_get_with_metadata("X-API-Key", "test_key");

        assert_eq!(middleware.extract_token(&request), Some("test_key".to_string()));
    }

    #[tokio::test]
    async fn test_validate_api_key_success() {
        let middleware = AuthMiddleware::new(api_key_only_config("valid_key"));

        let outcome = middleware.validate_api_key("valid_key").await;
        assert!(matches!(outcome, AuthResult::Success(_)));
    }

    #[tokio::test]
    async fn test_validate_api_key_failure() {
        let middleware = AuthMiddleware::new(api_key_only_config("valid_key"));

        let outcome = middleware.validate_api_key("invalid_key").await;
        assert!(matches!(outcome, AuthResult::Failure(_)));
    }
}