airsprotocols_mcp/oauth2/
context.rs

//! OAuth Authentication Context
//!
//! This module provides authentication context structures that carry
//! validated OAuth 2.1 token information through the request pipeline.
5
6// Layer 1: Standard library imports
7use std::collections::HashMap;
8
9// Layer 2: Third-party crate imports
10use chrono::{DateTime, Duration, Utc};
11use serde::{Deserialize, Serialize};
12
13// Layer 3: Internal module imports
14use crate::oauth2::types::JwtClaims;
15
16/// Authentication context for OAuth 2.1 authenticated requests
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthContext {
19    /// JWT claims from the validated token
20    pub claims: JwtClaims,
21
22    /// User's granted scopes
23    pub scopes: Vec<String>,
24
25    /// Timestamp when this context was created
26    pub created_at: DateTime<Utc>,
27
28    /// Token expiration time (if available)
29    pub expires_at: Option<DateTime<Utc>>,
30
31    /// Request ID for audit logging
32    pub request_id: Option<String>,
33
34    /// Additional context metadata
35    pub metadata: AuthMetadata,
36}
37
38/// Additional authentication metadata
39#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct AuthMetadata {
41    /// Client IP address
42    pub client_ip: Option<String>,
43
44    /// User agent string
45    pub user_agent: Option<String>,
46
47    /// Custom attributes for extensibility
48    pub custom_attributes: HashMap<String, String>,
49}
50
51impl AuthContext {
52    /// Create a new authentication context from validated JWT claims
53    pub fn new(claims: JwtClaims, scopes: Vec<String>) -> Self {
54        let expires_at = claims
55            .exp
56            .map(|exp| DateTime::from_timestamp(exp, 0).unwrap_or_else(Utc::now));
57
58        Self {
59            claims,
60            scopes,
61            created_at: Utc::now(),
62            expires_at,
63            request_id: None,
64            metadata: AuthMetadata::default(),
65        }
66    }
67
68    /// Create a new authentication context with request ID for audit logging
69    pub fn with_request_id(mut self, request_id: String) -> Self {
70        self.request_id = Some(request_id);
71        self
72    }
73
74    /// Add client IP to the context
75    pub fn with_client_ip(mut self, client_ip: String) -> Self {
76        self.metadata.client_ip = Some(client_ip);
77        self
78    }
79
80    /// Add User-Agent to the context
81    pub fn with_user_agent(mut self, user_agent: String) -> Self {
82        self.metadata.user_agent = Some(user_agent);
83        self
84    }
85
86    /// Add custom attribute to the context
87    pub fn with_custom_attribute(mut self, key: String, value: String) -> Self {
88        self.metadata.custom_attributes.insert(key, value);
89        self
90    }
91
92    /// Get the user ID from the token subject
93    pub fn user_id(&self) -> &str {
94        &self.claims.sub
95    }
96
97    /// Get the token audience
98    pub fn audience(&self) -> Option<&str> {
99        self.claims.aud.as_deref()
100    }
101
102    /// Get the token issuer
103    pub fn issuer(&self) -> Option<&str> {
104        self.claims.iss.as_deref()
105    }
106
107    /// Get the JWT ID (jti)
108    pub fn jwt_id(&self) -> Option<&str> {
109        self.claims.jti.as_deref()
110    }
111
112    /// Check if the token is expired
113    pub fn is_expired(&self) -> bool {
114        match self.expires_at {
115            Some(expires_at) => Utc::now() > expires_at,
116            None => false, // No expiration time means never expires
117        }
118    }
119
120    /// Check if the context is still valid (not expired)
121    pub fn is_valid(&self) -> bool {
122        !self.is_expired()
123    }
124
125    /// Get time until expiration
126    pub fn time_until_expiration(&self) -> Option<Duration> {
127        self.expires_at.and_then(|expires_at| {
128            let duration = expires_at - Utc::now();
129            if duration.num_seconds() > 0 {
130                Some(duration)
131            } else {
132                None // Token is already expired
133            }
134        })
135    }
136
137    /// Check if user has a specific scope
138    pub fn has_scope(&self, scope: &str) -> bool {
139        self.scopes.contains(&scope.to_string())
140    }
141
142    /// Check if user has any of the specified scopes
143    pub fn has_any_scope(&self, scopes: &[String]) -> bool {
144        scopes.iter().any(|scope| self.has_scope(scope))
145    }
146
147    /// Check if user has all of the specified scopes
148    pub fn has_all_scopes(&self, scopes: &[String]) -> bool {
149        scopes.iter().all(|scope| self.has_scope(scope))
150    }
151
152    /// Get scopes that match a pattern (e.g., "mcp:tools:*")
153    pub fn get_scopes_matching(&self, pattern: &str) -> Vec<&String> {
154        if let Some(prefix) = pattern.strip_suffix('*') {
155            self.scopes
156                .iter()
157                .filter(|scope| scope.starts_with(prefix))
158                .collect()
159        } else {
160            self.scopes
161                .iter()
162                .filter(|scope| *scope == pattern)
163                .collect()
164        }
165    }
166
167    /// Create an audit log entry for this authentication context
168    pub fn create_audit_entry(&self, action: &str, resource: &str) -> AuditLogEntry {
169        AuditLogEntry {
170            timestamp: Utc::now(),
171            user_id: self.user_id().to_string(),
172            action: action.to_string(),
173            resource: resource.to_string(),
174            scopes: self.scopes.clone(),
175            client_ip: self.metadata.client_ip.clone(),
176            user_agent: self.metadata.user_agent.clone(),
177            request_id: self.request_id.clone(),
178            jwt_id: self.jwt_id().map(|s| s.to_string()),
179            success: true, // Will be updated based on operation result
180        }
181    }
182
183    /// Convert to a summary for logging (without sensitive data)
184    pub fn to_log_summary(&self) -> AuthContextSummary {
185        AuthContextSummary {
186            user_id: self.user_id().to_string(),
187            scopes: self.scopes.clone(),
188            expires_at: self.expires_at,
189            client_ip: self.metadata.client_ip.clone(),
190            request_id: self.request_id.clone(),
191        }
192    }
193}
194
195/// Audit log entry for authentication events
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct AuditLogEntry {
198    pub timestamp: DateTime<Utc>,
199    pub user_id: String,
200    pub action: String,
201    pub resource: String,
202    pub scopes: Vec<String>,
203    pub client_ip: Option<String>,
204    pub user_agent: Option<String>,
205    pub request_id: Option<String>,
206    pub jwt_id: Option<String>,
207    pub success: bool,
208}
209
210/// Authentication context summary for logging (without sensitive data)
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct AuthContextSummary {
213    pub user_id: String,
214    pub scopes: Vec<String>,
215    pub expires_at: Option<DateTime<Utc>>,
216    pub client_ip: Option<String>,
217    pub request_id: Option<String>,
218}
219
220/// Trait for extracting authentication context from request extensions
221pub trait AuthContextExt {
222    /// Extract the authentication context from the request
223    fn auth_context(&self) -> Option<&AuthContext>;
224
225    /// Extract the authentication context mutably
226    fn auth_context_mut(&mut self) -> Option<&mut AuthContext>;
227}
228
229// Implementation for Axum's request extensions
230impl AuthContextExt for axum::http::Extensions {
231    fn auth_context(&self) -> Option<&AuthContext> {
232        self.get::<AuthContext>()
233    }
234
235    fn auth_context_mut(&mut self) -> Option<&mut AuthContext> {
236        self.get_mut::<AuthContext>()
237    }
238}
239
/// Fetch the authentication context from request extensions, or return early.
///
/// Evaluates to the borrowed context. When `$extensions.auth_context()` yields `None`,
/// the enclosing function returns `OAuth2Error::MissingAuthorization` via `?`, so it is
/// converted into the function's error type through `From`.
///
/// The expansion calls `auth_context()` with method syntax, so the `AuthContextExt`
/// trait must be in scope at the call site.
#[macro_export]
macro_rules! require_auth {
    ($extensions:expr) => {
        (match $extensions.auth_context() {
            ::core::option::Option::Some(auth_context) => {
                ::core::result::Result::Ok(auth_context)
            }
            ::core::option::Option::None => ::core::result::Result::Err(
                $crate::oauth2::error::OAuth2Error::MissingAuthorization,
            ),
        })?
    };
}
249
/// Return early with `OAuth2Error::InsufficientScope` unless the context holds a scope.
///
/// `$context` must be an expression with a `has_scope(&str)` method and a `scopes`
/// field of strings (e.g. an `AuthContext` or a reference to one). `$scope` must be
/// usable as `&str`. Each argument is evaluated exactly once; the original expansion
/// repeated both, so side-effecting arguments ran twice.
///
/// The error is returned as-is (no `From` conversion). The enclosing function or
/// closure must therefore return `Result<_, OAuth2Error>`.
#[macro_export]
macro_rules! require_scope {
    ($context:expr, $scope:expr) => {{
        // Bind once so `$context` / `$scope` are not re-evaluated below.
        let context = &$context;
        let scope = $scope;
        if !context.has_scope(scope) {
            return ::core::result::Result::Err(
                $crate::oauth2::error::OAuth2Error::InsufficientScope {
                    required: scope.to_string(),
                    provided: context.scopes.join(" "),
                },
            );
        }
    }};
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth2::types::JwtClaims;

    /// Build an owned scope list from string literals.
    fn owned(scopes: &[&str]) -> Vec<String> {
        scopes.iter().map(|&scope| scope.to_owned()).collect()
    }

    /// Claims for `user123` that expire one hour from now.
    fn create_test_claims() -> JwtClaims {
        JwtClaims {
            sub: "user123".to_string(),
            aud: Some("mcp-server".to_string()),
            iss: Some("https://auth.example.com".to_string()),
            exp: Some(Utc::now().timestamp() + 3600), // Expires in 1 hour
            nbf: None,
            iat: None,
            jti: Some("jwt-123".to_string()),
            scope: Some("mcp:tools:execute mcp:resources:read".to_string()),
            scopes: None,
        }
    }

    #[test]
    fn test_auth_context_creation() {
        let scopes = owned(&["mcp:tools:execute", "mcp:resources:read"]);

        let context = AuthContext::new(create_test_claims(), scopes.clone());

        assert_eq!(context.user_id(), "user123");
        assert_eq!(context.audience(), Some("mcp-server"));
        assert_eq!(context.issuer(), Some("https://auth.example.com"));
        assert_eq!(context.jwt_id(), Some("jwt-123"));
        assert_eq!(context.scopes, scopes);
        assert!(!context.is_expired());
    }

    #[test]
    fn test_auth_context_builders() {
        let context = AuthContext::new(create_test_claims(), owned(&["mcp:tools:execute"]))
            .with_request_id("req-123".to_string())
            .with_client_ip("192.168.1.1".to_string())
            .with_user_agent("TestAgent/1.0".to_string())
            .with_custom_attribute("tenant".to_string(), "example-org".to_string());

        assert_eq!(context.request_id, Some("req-123".to_string()));
        assert_eq!(context.metadata.client_ip, Some("192.168.1.1".to_string()));
        assert_eq!(
            context.metadata.user_agent,
            Some("TestAgent/1.0".to_string())
        );
        assert_eq!(
            context.metadata.custom_attributes.get("tenant"),
            Some(&"example-org".to_string())
        );
    }

    #[test]
    fn test_scope_checking() {
        let context = AuthContext::new(
            create_test_claims(),
            owned(&["mcp:tools:execute", "mcp:resources:read", "mcp:admin:all"]),
        );

        // Individual scope checks
        assert!(context.has_scope("mcp:tools:execute"));
        assert!(context.has_scope("mcp:resources:read"));
        assert!(!context.has_scope("mcp:tools:admin"));

        // Any-of checks
        assert!(context.has_any_scope(&owned(&["mcp:tools:execute", "mcp:unknown:scope"])));
        assert!(!context.has_any_scope(&owned(&["mcp:unknown:scope"])));

        // All-of checks
        assert!(context.has_all_scopes(&owned(&["mcp:tools:execute", "mcp:resources:read"])));
        assert!(!context.has_all_scopes(&owned(&["mcp:tools:execute", "mcp:unknown:scope"])));
    }

    #[test]
    fn test_scope_checks_with_empty_requirements() {
        let context = AuthContext::new(create_test_claims(), owned(&["mcp:tools:execute"]));

        assert!(!context.has_any_scope(&[]));
        assert!(context.has_all_scopes(&[]));
    }

    #[test]
    fn test_scope_pattern_matching() {
        let context = AuthContext::new(
            create_test_claims(),
            owned(&["mcp:tools:execute", "mcp:tools:read", "mcp:resources:read"]),
        );

        // Wildcard suffix keeps matches in granted order
        let tools_scopes = context.get_scopes_matching("mcp:tools:*");
        assert_eq!(tools_scopes, vec!["mcp:tools:execute", "mcp:tools:read"]);

        // A pattern without `*` must match exactly
        let exact_scope = context.get_scopes_matching("mcp:resources:read");
        assert_eq!(exact_scope, vec!["mcp:resources:read"]);
    }

    #[test]
    fn test_scope_pattern_wildcard_and_no_match() {
        let context = AuthContext::new(
            create_test_claims(),
            owned(&["mcp:tools:execute", "mcp:resources:read"]),
        );

        assert_eq!(context.get_scopes_matching("*").len(), 2);
        assert!(context.get_scopes_matching("mcp:admin:*").is_empty());
        assert!(context.get_scopes_matching("mcp:tools").is_empty());
    }

    #[test]
    fn test_expiration_checking() {
        let mut claims = create_test_claims();
        claims.exp = Some(Utc::now().timestamp() - 3600); // Expired 1 hour ago

        let context = AuthContext::new(claims, vec![]);
        assert!(context.is_expired());
        assert!(!context.is_valid());
        assert!(context.time_until_expiration().is_none());
    }

    #[test]
    fn test_time_until_expiration_for_live_token() {
        let context = AuthContext::new(create_test_claims(), vec![]);

        let remaining = context
            .time_until_expiration()
            .expect("token expires one hour from now");
        assert!(remaining > Duration::zero());
        assert!(remaining <= Duration::seconds(3600));
    }

    #[test]
    fn test_token_without_expiration_never_expires() {
        let mut claims = create_test_claims();
        claims.exp = None;

        let context = AuthContext::new(claims, vec![]);
        assert!(context.expires_at.is_none());
        assert!(!context.is_expired());
        assert!(context.is_valid());
        assert!(context.time_until_expiration().is_none());
    }

    #[test]
    fn test_audit_log_entry() {
        let context = AuthContext::new(create_test_claims(), owned(&["mcp:tools:execute"]))
            .with_request_id("req-123".to_string())
            .with_client_ip("192.168.1.1".to_string());

        let audit_entry = context.create_audit_entry("tools/call", "calculator");

        assert_eq!(audit_entry.user_id, "user123");
        assert_eq!(audit_entry.action, "tools/call");
        assert_eq!(audit_entry.resource, "calculator");
        assert_eq!(audit_entry.request_id, Some("req-123".to_string()));
        assert_eq!(audit_entry.client_ip, Some("192.168.1.1".to_string()));
        assert_eq!(audit_entry.jwt_id, Some("jwt-123".to_string()));
        assert!(audit_entry.success);
    }

    #[test]
    fn test_log_summary() {
        let scopes = owned(&["mcp:tools:execute"]);

        let context = AuthContext::new(create_test_claims(), scopes.clone())
            .with_request_id("req-123".to_string())
            .with_client_ip("192.168.1.1".to_string());

        let summary = context.to_log_summary();

        assert_eq!(summary.user_id, "user123");
        assert_eq!(summary.scopes, scopes);
        assert_eq!(summary.expires_at, context.expires_at);
        assert_eq!(summary.request_id, Some("req-123".to_string()));
        assert_eq!(summary.client_ip, Some("192.168.1.1".to_string()));
        // The summary type has no fields for raw JWT claims or other sensitive data.
    }
}