pmcp/server/auth/traits.rs
1//! Core authentication traits for flexible OAuth/auth integration.
2//!
3//! This module provides a provider-agnostic authentication abstraction for MCP servers.
4//! The core design principle is that **your MCP server code should never know about OAuth
5//! providers, tokens, or authentication flows. It only sees `AuthContext`.**
6//!
7//! # Provider Agnosticism
8//!
9//! The authentication system supports multiple OAuth providers (Cognito, Entra, Google,
10//! Okta, Auth0, etc.) through configuration, not code changes. See [`ClaimMappings`] for
11//! how provider-specific claim names are translated to standard names.
12//!
13//! # Example
14//!
15//! ```rust
16//! use pmcp::server::auth::AuthContext;
17//!
18//! fn handle_request(auth: &AuthContext) -> Result<String, &'static str> {
19//! // Require authentication
20//! auth.require_auth()?;
21//!
22//! // Check scopes
23//! auth.require_scope("read:data")?;
24//!
25//! // Access user info (provider-agnostic)
26//! let user_id = auth.user_id();
27//! let email = auth.email().unwrap_or("unknown");
28//!
29//! Ok(format!("Hello, {} ({})", email, user_id))
30//! }
31//! ```
32
33use crate::error::Result;
34use async_trait::async_trait;
35use serde::{de::DeserializeOwned, Deserialize, Serialize};
36use std::collections::HashMap;
37
38/// Authentication context containing validated user information.
39///
40/// This is the **only** auth type your MCP code should interact with.
41/// It provides a provider-agnostic view of the authenticated user, regardless
42/// of whether the token came from Cognito, Entra, Google, Okta, or any other
43/// OIDC provider.
44///
45/// # Provider-Agnostic Access
46///
47/// Use the helper methods like [`email()`](Self::email), [`tenant_id()`](Self::tenant_id),
48/// and [`user_id()`](Self::user_id) instead of directly accessing claims. These methods
49/// handle the different claim names used by various OAuth providers.
50///
51/// # Example
52///
53/// ```rust
54/// use pmcp::server::auth::AuthContext;
55///
56/// fn get_user_greeting(auth: &AuthContext) -> String {
57/// let name = auth.email().unwrap_or(auth.user_id());
58/// format!("Welcome, {}!", name)
59/// }
60/// ```
61#[derive(Debug, Clone, Default, Serialize, Deserialize)]
62pub struct AuthContext {
63 /// Subject identifier (user ID from the `sub` claim).
64 pub subject: String,
65
66 /// Granted scopes/permissions.
67 pub scopes: Vec<String>,
68
69 /// Additional claims from the token.
70 /// Use the helper methods like [`email()`](Self::email) for common claims.
71 pub claims: HashMap<String, serde_json::Value>,
72
73 /// Original token if available (for forwarding to downstream services).
74 #[serde(skip_serializing_if = "Option::is_none")]
75 pub token: Option<String>,
76
77 /// Client ID that authenticated.
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub client_id: Option<String>,
80
81 /// Token expiration timestamp (Unix epoch seconds).
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub expires_at: Option<u64>,
84
85 /// Whether this context represents an authenticated user.
86 #[serde(default)]
87 pub authenticated: bool,
88}
89
90impl AuthContext {
91 /// Create a new authenticated context.
92 pub fn new(subject: impl Into<String>) -> Self {
93 Self {
94 subject: subject.into(),
95 authenticated: true,
96 ..Default::default()
97 }
98 }
99
100 /// Create an anonymous (unauthenticated) context.
101 pub fn anonymous() -> Self {
102 Self {
103 subject: "anonymous".to_string(),
104 authenticated: false,
105 ..Default::default()
106 }
107 }
108
109 /// Get the user ID (alias for subject).
110 ///
111 /// This is the standard user identifier, typically from the `sub` claim.
112 #[inline]
113 pub fn user_id(&self) -> &str {
114 &self.subject
115 }
116
117 /// Get a typed claim value.
118 ///
119 /// # Example
120 ///
121 /// ```rust
122 /// use pmcp::server::auth::AuthContext;
123 ///
124 /// let auth = AuthContext::new("user-123");
125 /// let roles: Option<Vec<String>> = auth.claim("roles");
126 /// ```
127 pub fn claim<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
128 self.claims
129 .get(key)
130 .and_then(|v| serde_json::from_value(v.clone()).ok())
131 }
132
133 /// Get the email address (handles different claim names across providers).
134 ///
135 /// This method checks common claim names used by different OAuth providers:
136 /// - `email` (Cognito, Google, Okta, Auth0)
137 /// - `preferred_username` (Entra ID)
138 /// - `upn` (Entra ID UPN)
139 pub fn email(&self) -> Option<&str> {
140 self.claims
141 .get("email")
142 .or_else(|| self.claims.get("preferred_username"))
143 .or_else(|| self.claims.get("upn"))
144 .and_then(|v| v.as_str())
145 }
146
147 /// Get the display name.
148 ///
149 /// Checks common claim names for user's name:
150 /// - `name` (most providers)
151 /// - `given_name` + `family_name` fallback
152 pub fn name(&self) -> Option<&str> {
153 self.claims.get("name").and_then(|v| v.as_str())
154 }
155
156 /// Get the tenant ID (handles different claim names across providers).
157 ///
158 /// This method checks common claim names used by different OAuth providers:
159 /// - `tenant_id` (normalized)
160 /// - `tid` (Entra ID)
161 /// - `custom:tenant_id` (Cognito custom attribute)
162 /// - `custom:tenant` (Cognito custom attribute)
163 /// - `org_id` (Auth0, Okta)
164 pub fn tenant_id(&self) -> Option<&str> {
165 self.claims
166 .get("tenant_id")
167 .or_else(|| self.claims.get("tid")) // Entra ID
168 .or_else(|| self.claims.get("custom:tenant_id")) // Cognito
169 .or_else(|| self.claims.get("custom:tenant")) // Cognito
170 .or_else(|| self.claims.get("org_id")) // Auth0, Okta
171 .and_then(|v| v.as_str())
172 }
173
174 /// Get groups/roles the user belongs to.
175 ///
176 /// Checks common claim names for group membership:
177 /// - `groups` (Entra ID, Okta)
178 /// - `cognito:groups` (Cognito)
179 /// - `roles` (Auth0)
180 pub fn groups(&self) -> Vec<String> {
181 self.claims
182 .get("groups")
183 .or_else(|| self.claims.get("cognito:groups"))
184 .or_else(|| self.claims.get("roles"))
185 .and_then(|v| serde_json::from_value(v.clone()).ok())
186 .unwrap_or_default()
187 }
188
189 /// Check if the context has a specific scope.
190 pub fn has_scope(&self, scope: &str) -> bool {
191 self.scopes.iter().any(|s| s == scope)
192 }
193
194 /// Check if the context has all specified scopes.
195 pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
196 scopes.iter().all(|scope| self.has_scope(scope))
197 }
198
199 /// Check if the context has any of the specified scopes.
200 pub fn has_any_scope(&self, scopes: &[&str]) -> bool {
201 scopes.iter().any(|scope| self.has_scope(scope))
202 }
203
204 /// Require a scope, returning an error message if missing.
205 ///
206 /// # Example
207 ///
208 /// ```rust
209 /// use pmcp::server::auth::AuthContext;
210 ///
211 /// fn protected_operation(auth: &AuthContext) -> Result<(), &'static str> {
212 /// auth.require_scope("write:data")?;
213 /// // ... perform operation
214 /// Ok(())
215 /// }
216 /// ```
217 pub fn require_scope(&self, scope: &str) -> std::result::Result<(), &'static str> {
218 if self.has_scope(scope) {
219 Ok(())
220 } else {
221 Err("Insufficient scope")
222 }
223 }
224
225 /// Require authentication, returning an error message if not authenticated.
226 ///
227 /// # Example
228 ///
229 /// ```rust
230 /// use pmcp::server::auth::AuthContext;
231 ///
232 /// fn protected_operation(auth: &AuthContext) -> Result<&str, &'static str> {
233 /// auth.require_auth()?;
234 /// Ok(auth.user_id())
235 /// }
236 /// ```
237 pub fn require_auth(&self) -> std::result::Result<(), &'static str> {
238 if self.authenticated {
239 Ok(())
240 } else {
241 Err("Authentication required")
242 }
243 }
244
245 /// Check if the token is expired.
246 pub fn is_expired(&self) -> bool {
247 if let Some(expires_at) = self.expires_at {
248 let now = std::time::SystemTime::now()
249 .duration_since(std::time::UNIX_EPOCH)
250 .unwrap()
251 .as_secs();
252 expires_at < now
253 } else {
254 false
255 }
256 }
257
258 /// Check if the user is in a specific group.
259 pub fn in_group(&self, group: &str) -> bool {
260 self.groups().iter().any(|g| g == group)
261 }
262}
263
264/// Claim mappings for translating provider-specific claims to standard names.
265///
266/// Different OAuth providers use different claim names for the same information.
267/// This struct allows configuring the mapping from provider-specific names to
268/// standard names used by `AuthContext`.
269///
270/// # Provider-Specific Claim Names
271///
272/// | Standard | Cognito | Entra ID | Google | Okta | Auth0 |
273/// |----------|---------|----------|--------|------|-------|
274/// | `user_id` | sub | oid | sub | uid | sub |
275/// | `tenant_id` | `custom:tenant` | tid | N/A | `org_id` | `org_id` |
276/// | email | email | `preferred_username` | email | email | email |
277/// | groups | `cognito:groups` | groups | N/A | groups | roles |
278///
279/// # Example
280///
281/// ```rust
282/// use pmcp::server::auth::ClaimMappings;
283///
284/// // Configure for Entra ID
285/// let mappings = ClaimMappings::entra();
286/// assert_eq!(mappings.user_id, "oid");
287/// assert_eq!(mappings.tenant_id, Some("tid".to_string()));
288/// ```
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct ClaimMappings {
291 /// Claim name for user ID (default: "sub").
292 #[serde(default = "default_user_id_claim")]
293 pub user_id: String,
294
295 /// Claim name for tenant ID.
296 #[serde(skip_serializing_if = "Option::is_none")]
297 pub tenant_id: Option<String>,
298
299 /// Claim name for email.
300 #[serde(skip_serializing_if = "Option::is_none")]
301 pub email: Option<String>,
302
303 /// Claim name for display name.
304 #[serde(skip_serializing_if = "Option::is_none")]
305 pub name: Option<String>,
306
307 /// Claim name for groups/roles.
308 #[serde(skip_serializing_if = "Option::is_none")]
309 pub groups: Option<String>,
310
311 /// Additional custom mappings.
312 #[serde(flatten)]
313 pub custom: HashMap<String, String>,
314}
315
/// Default claim name for the user ID: the standard `sub` claim.
///
/// Used as the serde default for `ClaimMappings::user_id`.
fn default_user_id_claim() -> String {
    String::from("sub")
}
319
320impl Default for ClaimMappings {
321 fn default() -> Self {
322 Self {
323 user_id: default_user_id_claim(),
324 tenant_id: None,
325 email: Some("email".to_string()),
326 name: Some("name".to_string()),
327 groups: None,
328 custom: HashMap::new(),
329 }
330 }
331}
332
333impl ClaimMappings {
334 /// Create claim mappings for AWS Cognito.
335 pub fn cognito() -> Self {
336 Self {
337 user_id: "sub".to_string(),
338 tenant_id: Some("custom:tenant_id".to_string()),
339 email: Some("email".to_string()),
340 name: Some("name".to_string()),
341 groups: Some("cognito:groups".to_string()),
342 custom: HashMap::new(),
343 }
344 }
345
346 /// Create claim mappings for Microsoft Entra ID (Azure AD).
347 pub fn entra() -> Self {
348 Self {
349 user_id: "oid".to_string(),
350 tenant_id: Some("tid".to_string()),
351 email: Some("preferred_username".to_string()),
352 name: Some("name".to_string()),
353 groups: Some("groups".to_string()),
354 custom: HashMap::new(),
355 }
356 }
357
358 /// Create claim mappings for Google Identity.
359 pub fn google() -> Self {
360 Self {
361 user_id: "sub".to_string(),
362 tenant_id: None, // Google doesn't have tenant concept
363 email: Some("email".to_string()),
364 name: Some("name".to_string()),
365 groups: None,
366 custom: HashMap::new(),
367 }
368 }
369
370 /// Create claim mappings for Okta.
371 pub fn okta() -> Self {
372 Self {
373 user_id: "uid".to_string(),
374 tenant_id: Some("org_id".to_string()),
375 email: Some("email".to_string()),
376 name: Some("name".to_string()),
377 groups: Some("groups".to_string()),
378 custom: HashMap::new(),
379 }
380 }
381
382 /// Create claim mappings for Auth0.
383 pub fn auth0() -> Self {
384 Self {
385 user_id: "sub".to_string(),
386 tenant_id: Some("org_id".to_string()),
387 email: Some("email".to_string()),
388 name: Some("name".to_string()),
389 groups: Some("roles".to_string()),
390 custom: HashMap::new(),
391 }
392 }
393
394 /// Apply these mappings to normalize claims from a token.
395 ///
396 /// This transforms provider-specific claims into standard names that
397 /// `AuthContext` helper methods can find.
398 pub fn normalize_claims(
399 &self,
400 claims: &serde_json::Value,
401 ) -> HashMap<String, serde_json::Value> {
402 let mut normalized = HashMap::new();
403
404 if let Some(obj) = claims.as_object() {
405 // Copy all original claims
406 for (key, value) in obj {
407 normalized.insert(key.clone(), value.clone());
408 }
409
410 // Add normalized mappings
411 if let Some(value) = obj.get(&self.user_id) {
412 normalized.insert("sub".to_string(), value.clone());
413 }
414 if let Some(ref tenant_claim) = self.tenant_id {
415 if let Some(value) = obj.get(tenant_claim) {
416 normalized.insert("tenant_id".to_string(), value.clone());
417 }
418 }
419 if let Some(ref email_claim) = self.email {
420 if let Some(value) = obj.get(email_claim) {
421 normalized.insert("email".to_string(), value.clone());
422 }
423 }
424 if let Some(ref name_claim) = self.name {
425 if let Some(value) = obj.get(name_claim) {
426 normalized.insert("name".to_string(), value.clone());
427 }
428 }
429 if let Some(ref groups_claim) = self.groups {
430 if let Some(value) = obj.get(groups_claim) {
431 normalized.insert("groups".to_string(), value.clone());
432 }
433 }
434
435 // Apply custom mappings
436 for (standard_name, provider_name) in &self.custom {
437 if let Some(value) = obj.get(provider_name) {
438 normalized.insert(standard_name.clone(), value.clone());
439 }
440 }
441 }
442
443 normalized
444 }
445}
446
447/// Core authentication provider trait.
448/// This is the main abstraction that MCP servers use for authentication.
449#[async_trait]
450pub trait AuthProvider: Send + Sync {
451 /// Validate an incoming request and extract authentication context.
452 ///
453 /// This method receives the authorization header value and should:
454 /// 1. Parse the authentication token (e.g., Bearer token)
455 /// 2. Validate the token
456 /// 3. Return the authentication context if valid
457 ///
458 /// The `authorization_header` parameter contains the value of the Authorization header,
459 /// if present (e.g., "Bearer eyJhbGci...")
460 async fn validate_request(
461 &self,
462 authorization_header: Option<&str>,
463 ) -> Result<Option<AuthContext>>;
464
465 /// Get the authentication scheme this provider uses (e.g., "Bearer", "Basic").
466 fn auth_scheme(&self) -> &'static str {
467 "Bearer"
468 }
469
470 /// Check if this provider requires authentication for all requests.
471 fn is_required(&self) -> bool {
472 true
473 }
474}
475
476/// Token validator trait for validating access tokens.
477#[async_trait]
478pub trait TokenValidator: Send + Sync {
479 /// Validate an access token and return token information.
480 async fn validate(&self, token: &str) -> Result<AuthContext>;
481
482 /// Optionally validate token with additional context (e.g., required scopes).
483 async fn validate_with_context(
484 &self,
485 token: &str,
486 required_scopes: Option<&[&str]>,
487 ) -> Result<AuthContext> {
488 let auth_context = self.validate(token).await?;
489
490 // Check required scopes if specified
491 if let Some(scopes) = required_scopes {
492 if !auth_context.has_all_scopes(scopes) {
493 return Err(crate::error::Error::protocol(
494 crate::error::ErrorCode::INVALID_REQUEST,
495 "Insufficient scopes",
496 ));
497 }
498 }
499
500 Ok(auth_context)
501 }
502}
503
504/// Session management trait for stateful authentication.
505#[async_trait]
506pub trait SessionManager: Send + Sync {
507 /// Create a new session and return the session ID.
508 async fn create_session(&self, auth: AuthContext) -> Result<String>;
509
510 /// Get session by ID.
511 async fn get_session(&self, session_id: &str) -> Result<Option<AuthContext>>;
512
513 /// Update an existing session.
514 async fn update_session(&self, session_id: &str, auth: AuthContext) -> Result<()>;
515
516 /// Invalidate a session.
517 async fn invalidate_session(&self, session_id: &str) -> Result<()>;
518
519 /// Clean up expired sessions (optional background task).
520 async fn cleanup_expired(&self) -> Result<usize> {
521 Ok(0) // Default no-op implementation
522 }
523}
524
525/// Tool authorization trait for fine-grained access control.
526#[async_trait]
527pub trait ToolAuthorizer: Send + Sync {
528 /// Check if the authenticated context can access a specific tool.
529 async fn can_access_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<bool>;
530
531 /// Get required scopes for a tool.
532 async fn required_scopes_for_tool(&self, tool_name: &str) -> Result<Vec<String>>;
533}
534
/// Tool authorizer driven purely by scopes.
///
/// Each tool can be given its own list of required scopes; a separate
/// default list is kept alongside the per-tool entries.
#[derive(Debug, Clone)]
pub struct ScopeBasedAuthorizer {
    /// Required scopes per tool, keyed by tool name.
    tool_scopes: HashMap<String, Vec<String>>,
    /// The default list of required scopes.
    default_scopes: Vec<String>,
}

impl ScopeBasedAuthorizer {
    /// Create an authorizer with no per-tool entries and a single default
    /// scope, `mcp:tools:use`.
    pub fn new() -> Self {
        let default_scopes = vec![String::from("mcp:tools:use")];
        Self {
            tool_scopes: HashMap::new(),
            default_scopes,
        }
    }

    /// Set the scopes required for `tool_name` (builder style).
    ///
    /// Calling this again for the same tool replaces the earlier list.
    pub fn require_scopes<S, I>(mut self, tool_name: impl Into<String>, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut required: Vec<String> = Vec::new();
        for scope in scopes {
            required.push(scope.as_ref().to_owned());
        }
        self.tool_scopes.insert(tool_name.into(), required);
        self
    }

    /// Replace the default required scopes (builder style).
    pub fn default_scopes(self, scopes: Vec<String>) -> Self {
        Self {
            default_scopes: scopes,
            ..self
        }
    }
}
568
569#[async_trait]
570impl ToolAuthorizer for ScopeBasedAuthorizer {
571 async fn can_access_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<bool> {
572 let required_scopes = self
573 .tool_scopes
574 .get(tool_name)
575 .unwrap_or(&self.default_scopes);
576
577 let scope_refs: Vec<&str> = required_scopes.iter().map(|s| s.as_str()).collect();
578 Ok(auth.has_all_scopes(&scope_refs))
579 }
580
581 async fn required_scopes_for_tool(&self, tool_name: &str) -> Result<Vec<String>> {
582 Ok(self
583 .tool_scopes
584 .get(tool_name)
585 .unwrap_or(&self.default_scopes)
586 .clone())
587 }
588}
589
590impl Default for ScopeBasedAuthorizer {
591 fn default() -> Self {
592 Self::new()
593 }
594}