Skip to main content

pmcp/server/auth/
traits.rs

//! Core authentication traits for flexible OAuth/auth integration.
//!
//! This module provides a provider-agnostic authentication abstraction for MCP servers.
//! The core design principle is that **your MCP server code should never need to know about
//! OAuth providers, tokens, or authentication flows; it only ever sees `AuthContext`.**
//!
//! # Provider Agnosticism
//!
//! The authentication system supports multiple OAuth providers (Cognito, Entra, Google,
//! Okta, Auth0, etc.) through configuration rather than code changes. See [`ClaimMappings`]
//! for how provider-specific claim names are translated to standard names.
//!
//! # Example
//!
//! ```rust
//! use pmcp::server::auth::AuthContext;
//!
//! fn handle_request(auth: &AuthContext) -> Result<String, &'static str> {
//!     // Require authentication
//!     auth.require_auth()?;
//!
//!     // Check scopes
//!     auth.require_scope("read:data")?;
//!
//!     // Access user info (provider-agnostic)
//!     let user_id = auth.user_id();
//!     let email = auth.email().unwrap_or("unknown");
//!
//!     Ok(format!("Hello, {} ({})", email, user_id))
//! }
//! ```
32
33use crate::error::Result;
34use async_trait::async_trait;
35use serde::{de::DeserializeOwned, Deserialize, Serialize};
36use std::collections::HashMap;
37
38/// Authentication context containing validated user information.
39///
40/// This is the **only** auth type your MCP code should interact with.
41/// It provides a provider-agnostic view of the authenticated user, regardless
42/// of whether the token came from Cognito, Entra, Google, Okta, or any other
43/// OIDC provider.
44///
45/// # Provider-Agnostic Access
46///
47/// Use the helper methods like [`email()`](Self::email), [`tenant_id()`](Self::tenant_id),
48/// and [`user_id()`](Self::user_id) instead of directly accessing claims. These methods
49/// handle the different claim names used by various OAuth providers.
50///
51/// # Example
52///
53/// ```rust
54/// use pmcp::server::auth::AuthContext;
55///
56/// fn get_user_greeting(auth: &AuthContext) -> String {
57///     let name = auth.email().unwrap_or(auth.user_id());
58///     format!("Welcome, {}!", name)
59/// }
60/// ```
61#[derive(Debug, Clone, Default, Serialize, Deserialize)]
62pub struct AuthContext {
63    /// Subject identifier (user ID from the `sub` claim).
64    pub subject: String,
65
66    /// Granted scopes/permissions.
67    pub scopes: Vec<String>,
68
69    /// Additional claims from the token.
70    /// Use the helper methods like [`email()`](Self::email) for common claims.
71    pub claims: HashMap<String, serde_json::Value>,
72
73    /// Original token if available (for forwarding to downstream services).
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub token: Option<String>,
76
77    /// Client ID that authenticated.
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub client_id: Option<String>,
80
81    /// Token expiration timestamp (Unix epoch seconds).
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub expires_at: Option<u64>,
84
85    /// Whether this context represents an authenticated user.
86    #[serde(default)]
87    pub authenticated: bool,
88}
89
90impl AuthContext {
91    /// Create a new authenticated context.
92    pub fn new(subject: impl Into<String>) -> Self {
93        Self {
94            subject: subject.into(),
95            authenticated: true,
96            ..Default::default()
97        }
98    }
99
100    /// Create an anonymous (unauthenticated) context.
101    pub fn anonymous() -> Self {
102        Self {
103            subject: "anonymous".to_string(),
104            authenticated: false,
105            ..Default::default()
106        }
107    }
108
109    /// Get the user ID (alias for subject).
110    ///
111    /// This is the standard user identifier, typically from the `sub` claim.
112    #[inline]
113    pub fn user_id(&self) -> &str {
114        &self.subject
115    }
116
117    /// Get a typed claim value.
118    ///
119    /// # Example
120    ///
121    /// ```rust
122    /// use pmcp::server::auth::AuthContext;
123    ///
124    /// let auth = AuthContext::new("user-123");
125    /// let roles: Option<Vec<String>> = auth.claim("roles");
126    /// ```
127    pub fn claim<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
128        self.claims
129            .get(key)
130            .and_then(|v| serde_json::from_value(v.clone()).ok())
131    }
132
133    /// Get the email address (handles different claim names across providers).
134    ///
135    /// This method checks common claim names used by different OAuth providers:
136    /// - `email` (Cognito, Google, Okta, Auth0)
137    /// - `preferred_username` (Entra ID)
138    /// - `upn` (Entra ID UPN)
139    pub fn email(&self) -> Option<&str> {
140        self.claims
141            .get("email")
142            .or_else(|| self.claims.get("preferred_username"))
143            .or_else(|| self.claims.get("upn"))
144            .and_then(|v| v.as_str())
145    }
146
147    /// Get the display name.
148    ///
149    /// Checks common claim names for user's name:
150    /// - `name` (most providers)
151    /// - `given_name` + `family_name` fallback
152    pub fn name(&self) -> Option<&str> {
153        self.claims.get("name").and_then(|v| v.as_str())
154    }
155
156    /// Get the tenant ID (handles different claim names across providers).
157    ///
158    /// This method checks common claim names used by different OAuth providers:
159    /// - `tenant_id` (normalized)
160    /// - `tid` (Entra ID)
161    /// - `custom:tenant_id` (Cognito custom attribute)
162    /// - `custom:tenant` (Cognito custom attribute)
163    /// - `org_id` (Auth0, Okta)
164    pub fn tenant_id(&self) -> Option<&str> {
165        self.claims
166            .get("tenant_id")
167            .or_else(|| self.claims.get("tid")) // Entra ID
168            .or_else(|| self.claims.get("custom:tenant_id")) // Cognito
169            .or_else(|| self.claims.get("custom:tenant")) // Cognito
170            .or_else(|| self.claims.get("org_id")) // Auth0, Okta
171            .and_then(|v| v.as_str())
172    }
173
174    /// Get groups/roles the user belongs to.
175    ///
176    /// Checks common claim names for group membership:
177    /// - `groups` (Entra ID, Okta)
178    /// - `cognito:groups` (Cognito)
179    /// - `roles` (Auth0)
180    pub fn groups(&self) -> Vec<String> {
181        self.claims
182            .get("groups")
183            .or_else(|| self.claims.get("cognito:groups"))
184            .or_else(|| self.claims.get("roles"))
185            .and_then(|v| serde_json::from_value(v.clone()).ok())
186            .unwrap_or_default()
187    }
188
189    /// Check if the context has a specific scope.
190    pub fn has_scope(&self, scope: &str) -> bool {
191        self.scopes.iter().any(|s| s == scope)
192    }
193
194    /// Check if the context has all specified scopes.
195    pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
196        scopes.iter().all(|scope| self.has_scope(scope))
197    }
198
199    /// Check if the context has any of the specified scopes.
200    pub fn has_any_scope(&self, scopes: &[&str]) -> bool {
201        scopes.iter().any(|scope| self.has_scope(scope))
202    }
203
204    /// Require a scope, returning an error message if missing.
205    ///
206    /// # Example
207    ///
208    /// ```rust
209    /// use pmcp::server::auth::AuthContext;
210    ///
211    /// fn protected_operation(auth: &AuthContext) -> Result<(), &'static str> {
212    ///     auth.require_scope("write:data")?;
213    ///     // ... perform operation
214    ///     Ok(())
215    /// }
216    /// ```
217    pub fn require_scope(&self, scope: &str) -> std::result::Result<(), &'static str> {
218        if self.has_scope(scope) {
219            Ok(())
220        } else {
221            Err("Insufficient scope")
222        }
223    }
224
225    /// Require authentication, returning an error message if not authenticated.
226    ///
227    /// # Example
228    ///
229    /// ```rust
230    /// use pmcp::server::auth::AuthContext;
231    ///
232    /// fn protected_operation(auth: &AuthContext) -> Result<&str, &'static str> {
233    ///     auth.require_auth()?;
234    ///     Ok(auth.user_id())
235    /// }
236    /// ```
237    pub fn require_auth(&self) -> std::result::Result<(), &'static str> {
238        if self.authenticated {
239            Ok(())
240        } else {
241            Err("Authentication required")
242        }
243    }
244
245    /// Check if the token is expired.
246    pub fn is_expired(&self) -> bool {
247        if let Some(expires_at) = self.expires_at {
248            let now = std::time::SystemTime::now()
249                .duration_since(std::time::UNIX_EPOCH)
250                .unwrap()
251                .as_secs();
252            expires_at < now
253        } else {
254            false
255        }
256    }
257
258    /// Check if the user is in a specific group.
259    pub fn in_group(&self, group: &str) -> bool {
260        self.groups().iter().any(|g| g == group)
261    }
262}
263
264/// Claim mappings for translating provider-specific claims to standard names.
265///
266/// Different OAuth providers use different claim names for the same information.
267/// This struct allows configuring the mapping from provider-specific names to
268/// standard names used by `AuthContext`.
269///
270/// # Provider-Specific Claim Names
271///
272/// | Standard | Cognito | Entra ID | Google | Okta | Auth0 |
273/// |----------|---------|----------|--------|------|-------|
274/// | `user_id` | sub | oid | sub | uid | sub |
275/// | `tenant_id` | `custom:tenant` | tid | N/A | `org_id` | `org_id` |
276/// | email | email | `preferred_username` | email | email | email |
277/// | groups | `cognito:groups` | groups | N/A | groups | roles |
278///
279/// # Example
280///
281/// ```rust
282/// use pmcp::server::auth::ClaimMappings;
283///
284/// // Configure for Entra ID
285/// let mappings = ClaimMappings::entra();
286/// assert_eq!(mappings.user_id, "oid");
287/// assert_eq!(mappings.tenant_id, Some("tid".to_string()));
288/// ```
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct ClaimMappings {
291    /// Claim name for user ID (default: "sub").
292    #[serde(default = "default_user_id_claim")]
293    pub user_id: String,
294
295    /// Claim name for tenant ID.
296    #[serde(skip_serializing_if = "Option::is_none")]
297    pub tenant_id: Option<String>,
298
299    /// Claim name for email.
300    #[serde(skip_serializing_if = "Option::is_none")]
301    pub email: Option<String>,
302
303    /// Claim name for display name.
304    #[serde(skip_serializing_if = "Option::is_none")]
305    pub name: Option<String>,
306
307    /// Claim name for groups/roles.
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub groups: Option<String>,
310
311    /// Additional custom mappings.
312    #[serde(flatten)]
313    pub custom: HashMap<String, String>,
314}
315
/// Serde default for `ClaimMappings::user_id`: the standard OIDC `sub` claim.
fn default_user_id_claim() -> String {
    String::from("sub")
}
319
320impl Default for ClaimMappings {
321    fn default() -> Self {
322        Self {
323            user_id: default_user_id_claim(),
324            tenant_id: None,
325            email: Some("email".to_string()),
326            name: Some("name".to_string()),
327            groups: None,
328            custom: HashMap::new(),
329        }
330    }
331}
332
333impl ClaimMappings {
334    /// Create claim mappings for AWS Cognito.
335    pub fn cognito() -> Self {
336        Self {
337            user_id: "sub".to_string(),
338            tenant_id: Some("custom:tenant_id".to_string()),
339            email: Some("email".to_string()),
340            name: Some("name".to_string()),
341            groups: Some("cognito:groups".to_string()),
342            custom: HashMap::new(),
343        }
344    }
345
346    /// Create claim mappings for Microsoft Entra ID (Azure AD).
347    pub fn entra() -> Self {
348        Self {
349            user_id: "oid".to_string(),
350            tenant_id: Some("tid".to_string()),
351            email: Some("preferred_username".to_string()),
352            name: Some("name".to_string()),
353            groups: Some("groups".to_string()),
354            custom: HashMap::new(),
355        }
356    }
357
358    /// Create claim mappings for Google Identity.
359    pub fn google() -> Self {
360        Self {
361            user_id: "sub".to_string(),
362            tenant_id: None, // Google doesn't have tenant concept
363            email: Some("email".to_string()),
364            name: Some("name".to_string()),
365            groups: None,
366            custom: HashMap::new(),
367        }
368    }
369
370    /// Create claim mappings for Okta.
371    pub fn okta() -> Self {
372        Self {
373            user_id: "uid".to_string(),
374            tenant_id: Some("org_id".to_string()),
375            email: Some("email".to_string()),
376            name: Some("name".to_string()),
377            groups: Some("groups".to_string()),
378            custom: HashMap::new(),
379        }
380    }
381
382    /// Create claim mappings for Auth0.
383    pub fn auth0() -> Self {
384        Self {
385            user_id: "sub".to_string(),
386            tenant_id: Some("org_id".to_string()),
387            email: Some("email".to_string()),
388            name: Some("name".to_string()),
389            groups: Some("roles".to_string()),
390            custom: HashMap::new(),
391        }
392    }
393
394    /// Apply these mappings to normalize claims from a token.
395    ///
396    /// This transforms provider-specific claims into standard names that
397    /// `AuthContext` helper methods can find.
398    pub fn normalize_claims(
399        &self,
400        claims: &serde_json::Value,
401    ) -> HashMap<String, serde_json::Value> {
402        let mut normalized = HashMap::new();
403
404        if let Some(obj) = claims.as_object() {
405            // Copy all original claims
406            for (key, value) in obj {
407                normalized.insert(key.clone(), value.clone());
408            }
409
410            // Add normalized mappings
411            if let Some(value) = obj.get(&self.user_id) {
412                normalized.insert("sub".to_string(), value.clone());
413            }
414            if let Some(ref tenant_claim) = self.tenant_id {
415                if let Some(value) = obj.get(tenant_claim) {
416                    normalized.insert("tenant_id".to_string(), value.clone());
417                }
418            }
419            if let Some(ref email_claim) = self.email {
420                if let Some(value) = obj.get(email_claim) {
421                    normalized.insert("email".to_string(), value.clone());
422                }
423            }
424            if let Some(ref name_claim) = self.name {
425                if let Some(value) = obj.get(name_claim) {
426                    normalized.insert("name".to_string(), value.clone());
427                }
428            }
429            if let Some(ref groups_claim) = self.groups {
430                if let Some(value) = obj.get(groups_claim) {
431                    normalized.insert("groups".to_string(), value.clone());
432                }
433            }
434
435            // Apply custom mappings
436            for (standard_name, provider_name) in &self.custom {
437                if let Some(value) = obj.get(provider_name) {
438                    normalized.insert(standard_name.clone(), value.clone());
439                }
440            }
441        }
442
443        normalized
444    }
445}
446
447/// Core authentication provider trait.
448/// This is the main abstraction that MCP servers use for authentication.
449#[async_trait]
450pub trait AuthProvider: Send + Sync {
451    /// Validate an incoming request and extract authentication context.
452    ///
453    /// This method receives the authorization header value and should:
454    /// 1. Parse the authentication token (e.g., Bearer token)
455    /// 2. Validate the token
456    /// 3. Return the authentication context if valid
457    ///
458    /// The `authorization_header` parameter contains the value of the Authorization header,
459    /// if present (e.g., "Bearer eyJhbGci...")
460    async fn validate_request(
461        &self,
462        authorization_header: Option<&str>,
463    ) -> Result<Option<AuthContext>>;
464
465    /// Get the authentication scheme this provider uses (e.g., "Bearer", "Basic").
466    fn auth_scheme(&self) -> &'static str {
467        "Bearer"
468    }
469
470    /// Check if this provider requires authentication for all requests.
471    fn is_required(&self) -> bool {
472        true
473    }
474}
475
476/// Token validator trait for validating access tokens.
477#[async_trait]
478pub trait TokenValidator: Send + Sync {
479    /// Validate an access token and return token information.
480    async fn validate(&self, token: &str) -> Result<AuthContext>;
481
482    /// Optionally validate token with additional context (e.g., required scopes).
483    async fn validate_with_context(
484        &self,
485        token: &str,
486        required_scopes: Option<&[&str]>,
487    ) -> Result<AuthContext> {
488        let auth_context = self.validate(token).await?;
489
490        // Check required scopes if specified
491        if let Some(scopes) = required_scopes {
492            if !auth_context.has_all_scopes(scopes) {
493                return Err(crate::error::Error::protocol(
494                    crate::error::ErrorCode::INVALID_REQUEST,
495                    "Insufficient scopes",
496                ));
497            }
498        }
499
500        Ok(auth_context)
501    }
502}
503
504/// Session management trait for stateful authentication.
505#[async_trait]
506pub trait SessionManager: Send + Sync {
507    /// Create a new session and return the session ID.
508    async fn create_session(&self, auth: AuthContext) -> Result<String>;
509
510    /// Get session by ID.
511    async fn get_session(&self, session_id: &str) -> Result<Option<AuthContext>>;
512
513    /// Update an existing session.
514    async fn update_session(&self, session_id: &str, auth: AuthContext) -> Result<()>;
515
516    /// Invalidate a session.
517    async fn invalidate_session(&self, session_id: &str) -> Result<()>;
518
519    /// Clean up expired sessions (optional background task).
520    async fn cleanup_expired(&self) -> Result<usize> {
521        Ok(0) // Default no-op implementation
522    }
523}
524
525/// Tool authorization trait for fine-grained access control.
526#[async_trait]
527pub trait ToolAuthorizer: Send + Sync {
528    /// Check if the authenticated context can access a specific tool.
529    async fn can_access_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<bool>;
530
531    /// Get required scopes for a tool.
532    async fn required_scopes_for_tool(&self, tool_name: &str) -> Result<Vec<String>>;
533}
534
/// Simple scope-based tool authorizer.
///
/// Each tool may have its own list of required scopes. Tools without an entry fall
/// back to the default scopes. Access is granted only when the caller holds *all*
/// of the applicable scopes.
#[derive(Debug, Clone)]
pub struct ScopeBasedAuthorizer {
    /// Required scopes per tool, keyed by tool name.
    tool_scopes: HashMap<String, Vec<String>>,
    /// Scopes required for any tool without an explicit `tool_scopes` entry.
    default_scopes: Vec<String>,
}
541
542impl ScopeBasedAuthorizer {
543    /// Create a new scope-based authorizer.
544    pub fn new() -> Self {
545        Self {
546            tool_scopes: HashMap::new(),
547            default_scopes: vec!["mcp:tools:use".to_string()],
548        }
549    }
550
551    /// Add required scopes for a tool.
552    pub fn require_scopes<S, I>(mut self, tool_name: impl Into<String>, scopes: I) -> Self
553    where
554        I: IntoIterator<Item = S>,
555        S: AsRef<str>,
556    {
557        let scopes_vec = scopes.into_iter().map(|s| s.as_ref().to_string()).collect();
558        self.tool_scopes.insert(tool_name.into(), scopes_vec);
559        self
560    }
561
562    /// Set default required scopes for all tools.
563    pub fn default_scopes(mut self, scopes: Vec<String>) -> Self {
564        self.default_scopes = scopes;
565        self
566    }
567}
568
569#[async_trait]
570impl ToolAuthorizer for ScopeBasedAuthorizer {
571    async fn can_access_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<bool> {
572        let required_scopes = self
573            .tool_scopes
574            .get(tool_name)
575            .unwrap_or(&self.default_scopes);
576
577        let scope_refs: Vec<&str> = required_scopes.iter().map(|s| s.as_str()).collect();
578        Ok(auth.has_all_scopes(&scope_refs))
579    }
580
581    async fn required_scopes_for_tool(&self, tool_name: &str) -> Result<Vec<String>> {
582        Ok(self
583            .tool_scopes
584            .get(tool_name)
585            .unwrap_or(&self.default_scopes)
586            .clone())
587    }
588}
589
590impl Default for ScopeBasedAuthorizer {
591    fn default() -> Self {
592        Self::new()
593    }
594}