airsprotocols_mcp/oauth2/
context.rs1use std::collections::HashMap;
8
9use chrono::{DateTime, Duration, Utc};
11use serde::{Deserialize, Serialize};
12
13use crate::oauth2::types::JwtClaims;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthContext {
19 pub claims: JwtClaims,
21
22 pub scopes: Vec<String>,
24
25 pub created_at: DateTime<Utc>,
27
28 pub expires_at: Option<DateTime<Utc>>,
30
31 pub request_id: Option<String>,
33
34 pub metadata: AuthMetadata,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct AuthMetadata {
41 pub client_ip: Option<String>,
43
44 pub user_agent: Option<String>,
46
47 pub custom_attributes: HashMap<String, String>,
49}
50
51impl AuthContext {
52 pub fn new(claims: JwtClaims, scopes: Vec<String>) -> Self {
54 let expires_at = claims
55 .exp
56 .map(|exp| DateTime::from_timestamp(exp, 0).unwrap_or_else(Utc::now));
57
58 Self {
59 claims,
60 scopes,
61 created_at: Utc::now(),
62 expires_at,
63 request_id: None,
64 metadata: AuthMetadata::default(),
65 }
66 }
67
68 pub fn with_request_id(mut self, request_id: String) -> Self {
70 self.request_id = Some(request_id);
71 self
72 }
73
74 pub fn with_client_ip(mut self, client_ip: String) -> Self {
76 self.metadata.client_ip = Some(client_ip);
77 self
78 }
79
80 pub fn with_user_agent(mut self, user_agent: String) -> Self {
82 self.metadata.user_agent = Some(user_agent);
83 self
84 }
85
86 pub fn with_custom_attribute(mut self, key: String, value: String) -> Self {
88 self.metadata.custom_attributes.insert(key, value);
89 self
90 }
91
92 pub fn user_id(&self) -> &str {
94 &self.claims.sub
95 }
96
97 pub fn audience(&self) -> Option<&str> {
99 self.claims.aud.as_deref()
100 }
101
102 pub fn issuer(&self) -> Option<&str> {
104 self.claims.iss.as_deref()
105 }
106
107 pub fn jwt_id(&self) -> Option<&str> {
109 self.claims.jti.as_deref()
110 }
111
112 pub fn is_expired(&self) -> bool {
114 match self.expires_at {
115 Some(expires_at) => Utc::now() > expires_at,
116 None => false, }
118 }
119
120 pub fn is_valid(&self) -> bool {
122 !self.is_expired()
123 }
124
125 pub fn time_until_expiration(&self) -> Option<Duration> {
127 self.expires_at.and_then(|expires_at| {
128 let duration = expires_at - Utc::now();
129 if duration.num_seconds() > 0 {
130 Some(duration)
131 } else {
132 None }
134 })
135 }
136
137 pub fn has_scope(&self, scope: &str) -> bool {
139 self.scopes.contains(&scope.to_string())
140 }
141
142 pub fn has_any_scope(&self, scopes: &[String]) -> bool {
144 scopes.iter().any(|scope| self.has_scope(scope))
145 }
146
147 pub fn has_all_scopes(&self, scopes: &[String]) -> bool {
149 scopes.iter().all(|scope| self.has_scope(scope))
150 }
151
152 pub fn get_scopes_matching(&self, pattern: &str) -> Vec<&String> {
154 if let Some(prefix) = pattern.strip_suffix('*') {
155 self.scopes
156 .iter()
157 .filter(|scope| scope.starts_with(prefix))
158 .collect()
159 } else {
160 self.scopes
161 .iter()
162 .filter(|scope| *scope == pattern)
163 .collect()
164 }
165 }
166
167 pub fn create_audit_entry(&self, action: &str, resource: &str) -> AuditLogEntry {
169 AuditLogEntry {
170 timestamp: Utc::now(),
171 user_id: self.user_id().to_string(),
172 action: action.to_string(),
173 resource: resource.to_string(),
174 scopes: self.scopes.clone(),
175 client_ip: self.metadata.client_ip.clone(),
176 user_agent: self.metadata.user_agent.clone(),
177 request_id: self.request_id.clone(),
178 jwt_id: self.jwt_id().map(|s| s.to_string()),
179 success: true, }
181 }
182
183 pub fn to_log_summary(&self) -> AuthContextSummary {
185 AuthContextSummary {
186 user_id: self.user_id().to_string(),
187 scopes: self.scopes.clone(),
188 expires_at: self.expires_at,
189 client_ip: self.metadata.client_ip.clone(),
190 request_id: self.request_id.clone(),
191 }
192 }
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct AuditLogEntry {
198 pub timestamp: DateTime<Utc>,
199 pub user_id: String,
200 pub action: String,
201 pub resource: String,
202 pub scopes: Vec<String>,
203 pub client_ip: Option<String>,
204 pub user_agent: Option<String>,
205 pub request_id: Option<String>,
206 pub jwt_id: Option<String>,
207 pub success: bool,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct AuthContextSummary {
213 pub user_id: String,
214 pub scopes: Vec<String>,
215 pub expires_at: Option<DateTime<Utc>>,
216 pub client_ip: Option<String>,
217 pub request_id: Option<String>,
218}
219
220pub trait AuthContextExt {
222 fn auth_context(&self) -> Option<&AuthContext>;
224
225 fn auth_context_mut(&mut self) -> Option<&mut AuthContext>;
227}
228
229impl AuthContextExt for axum::http::Extensions {
231 fn auth_context(&self) -> Option<&AuthContext> {
232 self.get::<AuthContext>()
233 }
234
235 fn auth_context_mut(&mut self) -> Option<&mut AuthContext> {
236 self.get_mut::<AuthContext>()
237 }
238}
239
/// Evaluates to the `&AuthContext` stored in `$extensions`, or returns early
/// with `OAuth2Error::MissingAuthorization`.
///
/// The early return goes through `?`, so the error is converted with `From`
/// into the enclosing function's error type. Because the expansion calls
/// `.auth_context()` with method syntax, the `AuthContextExt` trait must be in
/// scope at the call site.
#[macro_export]
macro_rules! require_auth {
    ($extensions:expr) => {
        $extensions
            .auth_context()
            .ok_or($crate::oauth2::error::OAuth2Error::MissingAuthorization)?
    };
}
249
/// Returns early with `OAuth2Error::InsufficientScope` unless `$context`
/// grants `$scope`.
///
/// `$context` may be an `AuthContext` or a reference to one. `$scope` must be
/// usable where a `&str` is expected. The error is returned directly with
/// `return Err(..)` (no `From` conversion), so the enclosing function must
/// return `Result<_, OAuth2Error>`.
///
/// Each argument is evaluated exactly once. Earlier versions expanded
/// `$context` and `$scope` twice, which re-ran any side effects or expensive
/// calls in those expressions. A trailing comma is accepted.
#[macro_export]
macro_rules! require_scope {
    ($context:expr, $scope:expr $(,)?) => {{
        // Bind once; macro-local names are hygienic and cannot clash with the caller's.
        let context = &$context;
        let scope = $scope;
        if !context.has_scope(scope) {
            return Err($crate::oauth2::error::OAuth2Error::InsufficientScope {
                required: scope.to_string(),
                provided: context.scopes.join(" "),
            });
        }
    }};
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth2::types::JwtClaims;

    /// Claims for `user123` that expire one hour from now.
    fn create_test_claims() -> JwtClaims {
        JwtClaims {
            sub: "user123".to_string(),
            aud: Some("mcp-server".to_string()),
            iss: Some("https://auth.example.com".to_string()),
            exp: Some(Utc::now().timestamp() + 3600),
            nbf: None,
            iat: None,
            jti: Some("jwt-123".to_string()),
            scope: Some("mcp:tools:execute mcp:resources:read".to_string()),
            scopes: None,
        }
    }

    #[test]
    fn test_auth_context_creation() {
        let claims = create_test_claims();
        let scopes = vec![
            "mcp:tools:execute".to_string(),
            "mcp:resources:read".to_string(),
        ];

        let context = AuthContext::new(claims.clone(), scopes.clone());

        assert_eq!(context.user_id(), "user123");
        assert_eq!(context.audience(), Some("mcp-server"));
        assert_eq!(context.issuer(), Some("https://auth.example.com"));
        assert_eq!(context.jwt_id(), Some("jwt-123"));
        assert_eq!(context.scopes, scopes);
        assert!(!context.is_expired());
    }

    #[test]
    fn test_auth_context_builders() {
        let claims = create_test_claims();
        let scopes = vec!["mcp:tools:execute".to_string()];

        let context = AuthContext::new(claims, scopes)
            .with_request_id("req-123".to_string())
            .with_client_ip("192.168.1.1".to_string())
            .with_user_agent("TestAgent/1.0".to_string())
            .with_custom_attribute("tenant".to_string(), "example-org".to_string());

        assert_eq!(context.request_id, Some("req-123".to_string()));
        assert_eq!(context.metadata.client_ip, Some("192.168.1.1".to_string()));
        assert_eq!(
            context.metadata.user_agent,
            Some("TestAgent/1.0".to_string())
        );
        assert_eq!(
            context.metadata.custom_attributes.get("tenant"),
            Some(&"example-org".to_string())
        );
    }

    #[test]
    fn test_scope_checking() {
        let claims = create_test_claims();
        let scopes = vec![
            "mcp:tools:execute".to_string(),
            "mcp:resources:read".to_string(),
            "mcp:admin:all".to_string(),
        ];

        let context = AuthContext::new(claims, scopes);

        assert!(context.has_scope("mcp:tools:execute"));
        assert!(context.has_scope("mcp:resources:read"));
        assert!(!context.has_scope("mcp:tools:admin"));

        assert!(context.has_any_scope(&[
            "mcp:tools:execute".to_string(),
            "mcp:unknown:scope".to_string(),
        ]));
        assert!(!context.has_any_scope(&["mcp:unknown:scope".to_string()]));

        assert!(context.has_all_scopes(&[
            "mcp:tools:execute".to_string(),
            "mcp:resources:read".to_string(),
        ]));
        assert!(!context.has_all_scopes(&[
            "mcp:tools:execute".to_string(),
            "mcp:unknown:scope".to_string(),
        ]));
    }

    #[test]
    fn test_scope_pattern_matching() {
        let claims = create_test_claims();
        let scopes = vec![
            "mcp:tools:execute".to_string(),
            "mcp:tools:read".to_string(),
            "mcp:resources:read".to_string(),
        ];

        let context = AuthContext::new(claims, scopes);

        let tools_scopes = context.get_scopes_matching("mcp:tools:*");
        assert_eq!(tools_scopes.len(), 2);
        assert!(tools_scopes.contains(&&"mcp:tools:execute".to_string()));
        assert!(tools_scopes.contains(&&"mcp:tools:read".to_string()));

        let exact_scope = context.get_scopes_matching("mcp:resources:read");
        assert_eq!(exact_scope.len(), 1);
        assert!(exact_scope.contains(&&"mcp:resources:read".to_string()));
    }

    #[test]
    fn test_wildcard_and_empty_scope_queries() {
        let context = AuthContext::new(
            create_test_claims(),
            vec![
                "mcp:tools:execute".to_string(),
                "mcp:resources:read".to_string(),
            ],
        );

        // A bare "*" matches everything; without the trailing "*" a prefix is not enough.
        assert_eq!(context.get_scopes_matching("*").len(), 2);
        assert!(context.get_scopes_matching("admin:*").is_empty());
        assert!(context.get_scopes_matching("mcp:tools").is_empty());

        // Vacuous cases of the any/all helpers.
        assert!(!context.has_any_scope(&[]));
        assert!(context.has_all_scopes(&[]));
    }

    #[test]
    fn test_expiration_checking() {
        let mut claims = create_test_claims();
        claims.exp = Some(Utc::now().timestamp() - 3600);

        let context = AuthContext::new(claims, vec![]);
        assert!(context.is_expired());
        assert!(!context.is_valid());
        assert!(context.time_until_expiration().is_none());
    }

    #[test]
    fn test_no_expiration_claim() {
        let mut claims = create_test_claims();
        claims.exp = None;

        let context = AuthContext::new(claims, vec![]);
        assert!(context.expires_at.is_none());
        assert!(!context.is_expired());
        assert!(context.is_valid());
        assert!(context.time_until_expiration().is_none());
    }

    #[test]
    fn test_time_until_expiration_for_live_token() {
        let context = AuthContext::new(create_test_claims(), vec![]);

        let remaining = context
            .time_until_expiration()
            .expect("a token expiring in one hour should report remaining time");
        assert!(remaining > Duration::zero());
        assert!(remaining.num_seconds() <= 3600);
    }

    #[test]
    fn test_audit_log_entry() {
        let claims = create_test_claims();
        let scopes = vec!["mcp:tools:execute".to_string()];

        let context = AuthContext::new(claims, scopes)
            .with_request_id("req-123".to_string())
            .with_client_ip("192.168.1.1".to_string());

        let audit_entry = context.create_audit_entry("tools/call", "calculator");

        assert_eq!(audit_entry.user_id, "user123");
        assert_eq!(audit_entry.action, "tools/call");
        assert_eq!(audit_entry.resource, "calculator");
        assert_eq!(audit_entry.request_id, Some("req-123".to_string()));
        assert_eq!(audit_entry.client_ip, Some("192.168.1.1".to_string()));
        assert_eq!(audit_entry.jwt_id, Some("jwt-123".to_string()));
        assert!(audit_entry.success);
    }

    #[test]
    fn test_log_summary() {
        let claims = create_test_claims();
        let scopes = vec!["mcp:tools:execute".to_string()];

        let context = AuthContext::new(claims, scopes.clone())
            .with_request_id("req-123".to_string())
            .with_client_ip("192.168.1.1".to_string());

        let summary = context.to_log_summary();

        assert_eq!(summary.user_id, "user123");
        assert_eq!(summary.scopes, scopes);
        assert_eq!(summary.request_id, Some("req-123".to_string()));
        assert_eq!(summary.client_ip, Some("192.168.1.1".to_string()));
    }
}